diff --git a/server/claude-sdk.js b/server/claude-sdk.js index dc48a92f..da15142d 100644 --- a/server/claude-sdk.js +++ b/server/claude-sdk.js @@ -18,6 +18,7 @@ import { promises as fs } from 'fs'; import path from 'path'; import os from 'os'; import { CLAUDE_FALLBACK_MODELS } from './modules/providers/list/claude/claude-models.provider.js'; +import { providerModelsService } from './modules/providers/services/provider-models.service.js'; import { resolveClaudeCodeExecutablePath } from './shared/claude-cli-path.js'; import { createNotificationEvent, @@ -491,8 +492,17 @@ async function queryClaudeSDK(command, options = {}, ws) { }; try { + const resolvedModel = await providerModelsService.resolveResumeModel( + 'claude', + sessionId, + options.model, + ); + // Map CLI options to SDK format - const sdkOptions = mapCliOptionsToSDK(options); + const sdkOptions = mapCliOptionsToSDK({ + ...options, + model: resolvedModel || options.model, + }); // Load MCP configuration const mcpServers = await loadMcpConfig(options.cwd); diff --git a/server/cursor-cli.js b/server/cursor-cli.js index 1d5a7d79..6cfd7bac 100644 --- a/server/cursor-cli.js +++ b/server/cursor-cli.js @@ -3,6 +3,7 @@ import crossSpawn from 'cross-spawn'; import { notifyRunFailed, notifyRunStopped } from './services/notification-orchestrator.js'; import { sessionsService } from './modules/providers/services/sessions.service.js'; import { providerAuthService } from './modules/providers/services/provider-auth.service.js'; +import { providerModelsService } from './modules/providers/services/provider-models.service.js'; import { createNormalizedMessage } from './shared/utils.js'; // Use cross-spawn on Windows for better command execution @@ -28,6 +29,7 @@ function isWorkspaceTrustPrompt(text = '') { async function spawnCursor(command, options = {}, ws) { return new Promise(async (resolve, reject) => { const { sessionId, projectPath, cwd, resume, toolsSettings, skipPermissions, model, sessionSummary } = options; + const resolvedModel = await providerModelsService.resolveResumeModel('cursor', sessionId, model); let capturedSessionId = sessionId; // Track session ID throughout the process let sessionCreatedSent = false; // Track if we've already sent session-created event let hasRetriedWithTrust = false; @@ -52,9 +54,10 @@ async function spawnCursor(command, options = {}, ws) { // Provide a prompt (works for both new and resumed sessions) baseArgs.push('-p', command); - // Add model flag if specified (only meaningful for new sessions; harmless on resume) - if (!sessionId && model) { - baseArgs.push('--model', model); + // Model overrides are applied to both new and resumed sessions so a + // session-scoped change request can take effect on the next turn. + if (resolvedModel) { + baseArgs.push('--model', resolvedModel); } // Request streaming JSON when we are providing a prompt diff --git a/server/gemini-cli.js b/server/gemini-cli.js index 1f45682c..ee1fd845 100644 --- a/server/gemini-cli.js +++ b/server/gemini-cli.js @@ -9,6 +9,7 @@ import sessionManager from './sessionManager.js'; import GeminiResponseHandler from './gemini-response-handler.js'; import { notifyRunFailed, notifyRunStopped } from './services/notification-orchestrator.js'; import { providerAuthService } from './modules/providers/services/provider-auth.service.js'; +import { providerModelsService } from './modules/providers/services/provider-models.service.js'; import { createNormalizedMessage } from './shared/utils.js'; // Use cross-spawn on Windows for correct .cmd resolution (same pattern as cursor-cli.js) @@ -120,6 +121,11 @@ async function buildGeminiProcessEnv() { async function spawnGemini(command, options = {}, ws) { const { sessionId, projectPath, cwd, toolsSettings, permissionMode, images, sessionSummary } = options; + const resolvedModel = await providerModelsService.resolveResumeModel( + 'gemini', + sessionId, + options.model + ); let capturedSessionId = sessionId; // Track session ID throughout the process let sessionCreatedSent = false; // Track if we've already sent session-created event let assistantBlocks = []; // Accumulate the full response blocks including tools @@ -244,7 +250,7 @@ async function spawnGemini(command, options = {}, ws) { } // Add model for all sessions (both new and resumed) - let modelToUse = options.model || 'gemini-2.5-flash'; + let modelToUse = resolvedModel || 'gemini-2.5-flash'; args.push('--model', modelToUse); args.push('--output-format', 'stream-json'); diff --git a/server/modules/providers/list/claude/claude-models.provider.ts b/server/modules/providers/list/claude/claude-models.provider.ts index b0d845a3..02e04194 100644 --- a/server/modules/providers/list/claude/claude-models.provider.ts +++ b/server/modules/providers/list/claude/claude-models.provider.ts @@ -1,16 +1,21 @@ -import { spawn } from 'node:child_process'; +import { readFile } from 'node:fs/promises'; import { query, type ModelInfo, type Options } from '@anthropic-ai/claude-agent-sdk'; -import crossSpawn from 'cross-spawn'; +import { sessionsDb } from '@/modules/database/index.js'; import { resolveClaudeCodeExecutablePath } from '@/shared/claude-cli-path.js'; import type { IProviderModels } from '@/shared/interfaces.js'; import type { + ProviderChangeActiveModelInput, ProviderCurrentActiveModel, ProviderModelOption, ProviderModelsDefinition, + ProviderSessionActiveModelChange, } from '@/shared/types.js'; -import { buildDefaultProviderCurrentActiveModel } from '@/shared/utils.js'; +import { + buildDefaultProviderCurrentActiveModel, + writeProviderSessionActiveModelChange, +} from '@/shared/utils.js'; export const CLAUDE_FALLBACK_MODELS: ProviderModelsDefinition = { OPTIONS: [ @@ -26,13 +31,23 @@ export const CLAUDE_FALLBACK_MODELS: ProviderModelsDefinition = { type ClaudeModelQueryOptions = Pick; type ClaudeInitEvent = { + sessionId?: string; + session_id?: string; type?: string; subtype?: string; model?: string; + message?: { + content?: unknown; + model?: string; + }; }; -const CLAUDE_ACTIVE_MODEL_TIMEOUT_MS = 20_000; -const claudeSpawn = process.platform === 'win32' ? crossSpawn : spawn; +const ANSI_PATTERN = new RegExp( + '[\\u001B\\u009B][[\\]()#;?]*(?:' + + '(?:[0-9]{1,4}(?:;[0-9]{0,4})*)?[0-9A-ORZcf-nqry=><]' + + '|(?:[\\dA-PR-TZcf-ntqry=><~]))', + 'g', +); const buildClaudeQueryOptions = (): ClaudeModelQueryOptions => ({ env: { ...process.env }, @@ -74,82 +89,94 @@ const buildClaudeModelsDefinition = (models: ModelInfo[]): ProviderModelsDefinit }; }; -const runClaudeSessionModelCommand = async (sessionId: string): Promise => { - const cliPath = resolveClaudeCodeExecutablePath(process.env.CLAUDE_CLI_PATH); +const extractClaudeEventModel = (event: ClaudeInitEvent, sessionId: string): string | null => { + const eventSessionId = event.sessionId ?? event.session_id; + if (eventSessionId && eventSessionId !== sessionId) { + return null; + } - return new Promise((resolve, reject) => { - const child = claudeSpawn( - cliPath, - ['-p', '--verbose', '--output-format', 'stream-json', '--resume', sessionId, 'ok'], - { - env: { ...process.env }, - windowsHide: true, - }, - ); + const contentModel = extractClaudeModelFromMessageContent(event.message?.content); + if (contentModel) { + return contentModel; + } - let stdout = ''; - let stderr = ''; - let settled = false; + const directModel = event.model?.trim(); + if (directModel) { + return directModel; + } - const timer = setTimeout(() => { - child.kill('SIGTERM'); - if (!settled) { - settled = true; - reject(new Error('Claude current-model lookup timed out')); + const messageModel = event.message?.model?.trim(); + return messageModel || null; +}; + +const stripAnsi = (value: string): string => value.replace(ANSI_PATTERN, ''); + +const extractTaggedContent = (content: string, tagName: string): string | null => { + const escapedTagName = tagName.replace(/[.*+?^${}()|[\]\\]/g, '\\$&'); + const match = new RegExp(`<${escapedTagName}>([\\s\\S]*?)<\\/${escapedTagName}>`).exec(content); + return match ? match[1] : null; +}; + +const extractClaudeModelFromTextContent = (content: string): string | null => { + const localCommandStdout = extractTaggedContent(content, 'local-command-stdout'); + if (localCommandStdout !== null) { + const cleanedStdout = stripAnsi(localCommandStdout).replace(/\s+/g, ' ').trim(); + const changedModel = /(?:set|changed|switched)\s+model\s+to\s+(.+?)\.?$/i.exec(cleanedStdout); + if (changedModel?.[1]?.trim()) { + return changedModel[1].trim(); + } + } + + const modelTag = extractTaggedContent(content, 'model')?.trim(); + return modelTag || null; +}; + +const extractClaudeModelFromMessageContent = (content: unknown): string | null => { + if (typeof content === 'string') { + return extractClaudeModelFromTextContent(content); + } + + if (!Array.isArray(content)) { + return null; + } + + for (const part of content) { + if (!part || typeof part !== 'object' || !('text' in part) || typeof part.text !== 'string') { + continue; + } + + const model = extractClaudeModelFromTextContent(part.text); + if (model) { + return model; + } + } + + return null; +}; + +const readClaudeSessionModelFromJsonl = async ( + sessionId: string, + jsonlPath: string, +): Promise => { + const content = await readFile(jsonlPath, 'utf8'); + const lines = content + .split(/\r?\n/) + .map((line) => line.trim()) + .filter(Boolean); + + for (let index = lines.length - 1; index >= 0; index -= 1) { + try { + const event = JSON.parse(lines[index]) as ClaudeInitEvent; + const model = extractClaudeEventModel(event, sessionId); + if (model) { + return { model }; } - }, CLAUDE_ACTIVE_MODEL_TIMEOUT_MS); + } catch { + // Skip malformed JSONL lines that can happen during concurrent writes. + } + } - const finish = (error: Error | null, result: ProviderCurrentActiveModel | null) => { - if (settled) { - return; - } - - settled = true; - clearTimeout(timer); - - if (error) { - reject(error); - return; - } - - resolve(result); - }; - - child.stdout?.on('data', (chunk: Buffer) => { - stdout += chunk.toString(); - }); - - child.stderr?.on('data', (chunk: Buffer) => { - stderr += chunk.toString(); - }); - - child.on('error', (error) => { - finish(error instanceof Error ? error : new Error(String(error)), null); - }); - - child.on('close', () => { - const lines = `${stdout}\n${stderr}` - .split(/\r?\n/) - .map((line) => line.trim()) - .filter(Boolean); - - for (const line of lines) { - try { - const event = JSON.parse(line) as ClaudeInitEvent; - if (event.type === 'system' && event.subtype === 'init' && event.model) { - finish(null, { - model: event.model, - }); - return; - } - } catch { - // The Claude CLI mixes non-JSON lines into verbose output; ignore them. - } - } - - finish(null, null); - }); - }); + return null; }; export class ClaudeProviderModels implements IProviderModels { @@ -161,7 +188,7 @@ export class ClaudeProviderModels implements IProviderModels { // instance, so we create a lightweight query and immediately close it // after reading the control-plane metadata. queryInstance = query({ - prompt: '', + prompt: 'Get supported models', options: buildClaudeQueryOptions(), }); @@ -181,7 +208,10 @@ export class ClaudeProviderModels implements IProviderModels { } try { - const activeModel = await runClaudeSessionModelCommand(sessionId); + const jsonlPath = sessionsDb.getSessionById(sessionId)?.jsonl_path; + const activeModel = jsonlPath + ? await readClaudeSessionModelFromJsonl(sessionId, jsonlPath) + : null; if (activeModel?.model) { return activeModel; } @@ -191,4 +221,10 @@ export class ClaudeProviderModels implements IProviderModels { return buildDefaultProviderCurrentActiveModel(await this.getSupportedModels()); } + + async changeActiveModel( + input: ProviderChangeActiveModelInput, + ): Promise { + return writeProviderSessionActiveModelChange('claude', input); + } } diff --git a/server/modules/providers/list/codex/codex-models.provider.ts b/server/modules/providers/list/codex/codex-models.provider.ts index 75a767a3..de4de0ea 100644 --- a/server/modules/providers/list/codex/codex-models.provider.ts +++ b/server/modules/providers/list/codex/codex-models.provider.ts @@ -6,14 +6,17 @@ import TOML from '@iarna/toml'; import type { IProviderModels } from '@/shared/interfaces.js'; import type { + ProviderChangeActiveModelInput, ProviderCurrentActiveModel, ProviderModelOption, ProviderModelsDefinition, + ProviderSessionActiveModelChange, } from '@/shared/types.js'; import { buildDefaultProviderCurrentActiveModel, readObjectRecord, readOptionalString, + writeProviderSessionActiveModelChange, } from '@/shared/utils.js'; export const CODEX_FALLBACK_MODELS: ProviderModelsDefinition = { @@ -113,4 +116,10 @@ export class CodexProviderModels implements IProviderModels { return buildDefaultProviderCurrentActiveModel(await this.getSupportedModels()); } } + + async changeActiveModel( + input: ProviderChangeActiveModelInput, + ): Promise { + return writeProviderSessionActiveModelChange('codex', input); + } } diff --git a/server/modules/providers/list/cursor/cursor-models.provider.ts b/server/modules/providers/list/cursor/cursor-models.provider.ts index 9290e870..ccc2288a 100644 --- a/server/modules/providers/list/cursor/cursor-models.provider.ts +++ b/server/modules/providers/list/cursor/cursor-models.provider.ts @@ -7,13 +7,16 @@ import crossSpawn from 'cross-spawn'; import type { IProviderModels } from '@/shared/interfaces.js'; import type { + ProviderChangeActiveModelInput, ProviderCurrentActiveModel, ProviderModelOption, ProviderModelsDefinition, + ProviderSessionActiveModelChange, } from '@/shared/types.js'; import { buildDefaultProviderCurrentActiveModel, sanitizeLeafDirectoryName, + writeProviderSessionActiveModelChange, } from '@/shared/utils.js'; export const CURSOR_FALLBACK_MODELS: ProviderModelsDefinition = { @@ -257,4 +260,10 @@ export class CursorProviderModels implements IProviderModels { return buildDefaultProviderCurrentActiveModel(await this.getSupportedModels()); } + + async changeActiveModel( + input: ProviderChangeActiveModelInput, + ): Promise { + return writeProviderSessionActiveModelChange('cursor', input); + } } diff --git a/server/modules/providers/list/gemini/gemini-models.provider.ts b/server/modules/providers/list/gemini/gemini-models.provider.ts index 45102058..fc830394 100644 --- a/server/modules/providers/list/gemini/gemini-models.provider.ts +++ b/server/modules/providers/list/gemini/gemini-models.provider.ts @@ -1,6 +1,14 @@ import type { IProviderModels } from '@/shared/interfaces.js'; -import type { ProviderCurrentActiveModel, ProviderModelsDefinition } from '@/shared/types.js'; -import { buildDefaultProviderCurrentActiveModel } from '@/shared/utils.js'; +import type { + ProviderChangeActiveModelInput, + ProviderCurrentActiveModel, + ProviderModelsDefinition, + ProviderSessionActiveModelChange, +} from '@/shared/types.js'; +import { + buildDefaultProviderCurrentActiveModel, + writeProviderSessionActiveModelChange, +} from '@/shared/utils.js'; export const GEMINI_FALLBACK_MODELS: ProviderModelsDefinition = { OPTIONS: [ @@ -25,4 +33,10 @@ export class GeminiProviderModels implements IProviderModels { async getCurrentActiveModel(): Promise { return buildDefaultProviderCurrentActiveModel(GEMINI_FALLBACK_MODELS); } + + async changeActiveModel( + input: ProviderChangeActiveModelInput, + ): Promise { + return writeProviderSessionActiveModelChange('gemini', input); + } } diff --git a/server/modules/providers/list/opencode/opencode-models.provider.ts b/server/modules/providers/list/opencode/opencode-models.provider.ts index 8dc75ee6..1b939256 100644 --- a/server/modules/providers/list/opencode/opencode-models.provider.ts +++ b/server/modules/providers/list/opencode/opencode-models.provider.ts @@ -5,15 +5,18 @@ import crossSpawn from 'cross-spawn'; import type { IProviderModels } from '@/shared/interfaces.js'; import type { + ProviderChangeActiveModelInput, ProviderCurrentActiveModel, ProviderModelOption, ProviderModelsDefinition, + ProviderSessionActiveModelChange, } from '@/shared/types.js'; import { buildDefaultProviderCurrentActiveModel, getOpenCodeDatabasePath, readObjectRecord, readOptionalString, + writeProviderSessionActiveModelChange, } from '@/shared/utils.js'; export const OPENCODE_FALLBACK_MODELS: ProviderModelsDefinition = { @@ -220,4 +223,10 @@ export class OpenCodeProviderModels implements IProviderModels { return buildDefaultProviderCurrentActiveModel(await this.getSupportedModels()); } + + async changeActiveModel( + input: ProviderChangeActiveModelInput, + ): Promise { + return writeProviderSessionActiveModelChange('opencode', input); + } } diff --git a/server/modules/providers/provider.routes.ts b/server/modules/providers/provider.routes.ts index 1d128882..2604fcc8 100644 --- a/server/modules/providers/provider.routes.ts +++ b/server/modules/providers/provider.routes.ts @@ -6,7 +6,13 @@ import { providerModelsService } from '@/modules/providers/services/provider-mod import { providerSkillsService } from '@/modules/providers/services/skills.service.js'; import { sessionConversationsSearchService } from '@/modules/providers/services/session-conversations-search.service.js'; import { sessionsService } from '@/modules/providers/services/sessions.service.js'; -import type { LLMProvider, McpScope, McpTransport, UpsertProviderMcpServerInput } from '@/shared/types.js'; +import type { + LLMProvider, + McpScope, + McpTransport, + ProviderChangeActiveModelInput, + UpsertProviderMcpServerInput, +} from '@/shared/types.js'; import { AppError, asyncHandler, createApiSuccessResponse } from '@/shared/utils.js'; const router = express.Router(); @@ -246,6 +252,29 @@ const parseSessionSearchLimit = (value: unknown): number => { return Math.max(1, Math.min(parsed, 100)); }; +const parseChangeActiveModelPayload = (payload: unknown): ProviderChangeActiveModelInput => { + if (!payload || typeof payload !== 'object') { + throw new AppError('Request body must be an object.', { + code: 'INVALID_REQUEST_BODY', + statusCode: 400, + }); + } + + const body = payload as Record; + const model = readOptionalQueryString(body.model); + if (!model) { + throw new AppError('model is required.', { + code: 'MODEL_REQUIRED', + statusCode: 400, + }); + } + + return { + sessionId: '', + model, + }; +}; + router.get( '/:provider/auth/status', asyncHandler(async (req: Request, res: Response) => { @@ -265,6 +294,20 @@ router.get( }), ); +router.post( + '/:provider/sessions/:sessionId/active-model', + asyncHandler(async (req: Request, res: Response) => { + const provider = parseProvider(req.params.provider); + const sessionId = parseSessionId(req.params.sessionId); + const payload = parseChangeActiveModelPayload(req.body); + const result = await providerModelsService.changeActiveModel(provider, { + ...payload, + sessionId, + }); + res.json(createApiSuccessResponse(result)); + }), +); + // ----------------- Skills routes ----------------- router.get( '/:provider/skills', diff --git a/server/modules/providers/services/provider-models.service.ts b/server/modules/providers/services/provider-models.service.ts index e0883fc6..5cb62433 100644 --- a/server/modules/providers/services/provider-models.service.ts +++ b/server/modules/providers/services/provider-models.service.ts @@ -6,11 +6,14 @@ import { providerRegistry } from '@/modules/providers/provider.registry.js'; import type { IProvider } from '@/shared/interfaces.js'; import type { LLMProvider, + ProviderChangeActiveModelInput, ProviderCurrentActiveModel, ProviderModelsCacheInfo, ProviderModelsDefinition, ProviderModelsResult, + ProviderSessionActiveModelChange, } from '@/shared/types.js'; +import { readProviderSessionActiveModelChange } from '@/shared/utils.js'; export const PROVIDER_MODELS_CACHE_TTL_MS = 3 * 24 * 60 * 60 * 1000; const PROVIDER_MODELS_CACHE_VERSION = 1; @@ -18,6 +21,7 @@ const PROVIDER_MODELS_CACHE_VERSION = 1; type ProviderModelsServiceDependencies = { resolveProvider?: (provider: LLMProvider) => Pick; cachePath?: string; + activeModelChangesPath?: string; now?: () => number; }; @@ -132,6 +136,7 @@ const writeProviderModelsCacheFile = async ( export const createProviderModelsService = (dependencies: ProviderModelsServiceDependencies = {}) => { const resolveProvider = dependencies.resolveProvider ?? providerRegistry.resolveProvider; const cachePath = dependencies.cachePath ?? getProviderModelsCachePath(); + const activeModelChangesPath = dependencies.activeModelChangesPath; const now = dependencies.now ?? (() => Date.now()); const memoryCache = new Map(); const pendingRequests = new Map>(); @@ -270,6 +275,36 @@ export const createProviderModelsService = (dependencies: ProviderModelsServiceD sessionId?: string, ): Promise => resolveProvider(provider).models.getCurrentActiveModel(sessionId); + const changeActiveModel = async ( + provider: LLMProvider, + input: ProviderChangeActiveModelInput, + ): Promise => resolveProvider(provider).models.changeActiveModel(input); + + const getChangedActiveModel = async ( + provider: LLMProvider, + sessionId: string, + ): Promise => readProviderSessionActiveModelChange(provider, sessionId, { + filePath: activeModelChangesPath, + }); + + const resolveResumeModel = async ( + provider: LLMProvider, + sessionId: string | undefined, + requestedModel?: string | null, + ): Promise => { + const normalizedRequestedModel = typeof requestedModel === 'string' ? requestedModel.trim() : ''; + if (!sessionId?.trim()) { + return normalizedRequestedModel || undefined; + } + + const changedModel = await getChangedActiveModel(provider, sessionId); + if (changedModel.supported && changedModel.changed && changedModel.model?.trim()) { + return changedModel.model.trim(); + } + + return normalizedRequestedModel || undefined; + }; + const clearCache = (): void => { memoryCache.clear(); pendingRequests.clear(); @@ -280,6 +315,9 @@ export const createProviderModelsService = (dependencies: ProviderModelsServiceD return { getProviderModels, getCurrentActiveModel, + getChangedActiveModel, + changeActiveModel, + resolveResumeModel, clearCache, }; }; diff --git a/server/modules/providers/tests/provider-models.service.test.ts b/server/modules/providers/tests/provider-models.service.test.ts index 6abc4c45..fb9ebf7d 100644 --- a/server/modules/providers/tests/provider-models.service.test.ts +++ b/server/modules/providers/tests/provider-models.service.test.ts @@ -9,10 +9,13 @@ import { PROVIDER_MODELS_CACHE_TTL_MS, } from '@/modules/providers/services/provider-models.service.js'; import type { + ProviderChangeActiveModelInput, LLMProvider, ProviderCurrentActiveModel, ProviderModelsDefinition, + ProviderSessionActiveModelChange, } from '@/shared/types.js'; +import { writeProviderSessionActiveModelChange } from '@/shared/utils.js'; const createModels = (value: string): ProviderModelsDefinition => ({ OPTIONS: [{ value, label: value }], @@ -23,6 +26,17 @@ const createCurrentActiveModel = (model: string): ProviderCurrentActiveModel => model, }); +const createSessionActiveModelChange = ( + provider: LLMProvider, + input: ProviderChangeActiveModelInput, +): ProviderSessionActiveModelChange => ({ + provider, + sessionId: input.sessionId, + supported: true, + changed: true, + model: input.model, +}); + const createEphemeralCachePath = (): string => path.join( os.tmpdir(), `provider-model-cache-${process.pid}-${Date.now()}-${Math.random().toString(16).slice(2)}.json`, @@ -38,6 +52,7 @@ test('provider models service delegates to the resolved provider model adapter', models: { getSupportedModels: async () => createModels(`${provider}-models`), getCurrentActiveModel: async () => createCurrentActiveModel(`${provider}-active`), + changeActiveModel: async (input) => createSessionActiveModelChange(provider, input), }, }; }, @@ -65,6 +80,7 @@ test('provider models service returns each provider adapter result without rewri models: { getSupportedModels: async () => expectedModels, getCurrentActiveModel: async () => createCurrentActiveModel('cursor-active'), + changeActiveModel: async (input) => createSessionActiveModelChange('cursor', input), }, }), }); @@ -90,6 +106,7 @@ test('provider models are cached for the three-day ttl', async () => { return createModels(`${provider}-${loadCount}`); }, getCurrentActiveModel: async () => createCurrentActiveModel(`${provider}-active`), + changeActiveModel: async (input) => createSessionActiveModelChange(provider, input), }, }), }); @@ -124,6 +141,7 @@ test('provider model cache is persisted across service instances', async () => { models: { getSupportedModels: async () => createModels('gemini-cached'), getCurrentActiveModel: async () => createCurrentActiveModel('gemini-active'), + changeActiveModel: async (input) => createSessionActiveModelChange('gemini', input), }, }), }); @@ -137,6 +155,7 @@ test('provider model cache is persisted across service instances', async () => { throw new Error('loader should not be called for persisted cache hits'); }, getCurrentActiveModel: async () => createCurrentActiveModel('gemini-active'), + changeActiveModel: async (input) => createSessionActiveModelChange('gemini', input), }, }), }); @@ -163,6 +182,7 @@ test('concurrent provider model requests share one load operation', async () => return createModels('claude-cached'); }, getCurrentActiveModel: async () => createCurrentActiveModel('claude-active'), + changeActiveModel: async (input) => createSessionActiveModelChange('claude', input), }, }), }); @@ -196,6 +216,7 @@ test('bypassCache forces a fresh provider fetch and updates cache metadata', asy return createModels(`${provider}-${loadCount}`); }, getCurrentActiveModel: async () => createCurrentActiveModel(`${provider}-active-${loadCount}`), + changeActiveModel: async (input) => createSessionActiveModelChange(provider, input), }, }), }); @@ -224,6 +245,7 @@ test('provider models service delegates current active model lookups to the prov calls.push({ provider, sessionId }); return createCurrentActiveModel(`${provider}-${sessionId}`); }, + changeActiveModel: async (input) => createSessionActiveModelChange(provider, input), }, }), }); @@ -233,3 +255,64 @@ test('provider models service delegates current active model lookups to the prov assert.deepEqual(calls, [{ provider: 'opencode', sessionId: 'session-123' }]); assert.equal(activeModel.model, 'opencode-session-123'); }); + +test('provider models service delegates active model change requests to the provider adapter', async () => { + const calls: Array<{ provider: LLMProvider; input: ProviderChangeActiveModelInput }> = []; + const service = createProviderModelsService({ + resolveProvider: (provider) => ({ + models: { + getSupportedModels: async () => createModels(`${provider}-models`), + getCurrentActiveModel: async () => createCurrentActiveModel(`${provider}-active`), + changeActiveModel: async (input) => { + calls.push({ provider, input }); + return createSessionActiveModelChange(provider, input); + }, + }, + }), + }); + + const changedModel = await service.changeActiveModel('claude', { + sessionId: 'session-123', + model: 'opus', + }); + + assert.deepEqual(calls, [{ + provider: 'claude', + input: { + sessionId: 'session-123', + model: 'opus', + }, + }]); + assert.equal(changedModel.changed, true); + assert.equal(changedModel.model, 'opus'); +}); + +test('resolveResumeModel prefers a stored changed model over the requested one', async () => { + const tempRoot = await mkdtemp(path.join(os.tmpdir(), 'provider-model-change-')); + const activeModelChangesPath = path.join(tempRoot, 'session-model-changes.json'); + + try { + const service = createProviderModelsService({ + activeModelChangesPath, + resolveProvider: (provider) => ({ + models: { + getSupportedModels: async () => createModels(`${provider}-models`), + getCurrentActiveModel: async () => createCurrentActiveModel(`${provider}-active`), + changeActiveModel: async (input) => createSessionActiveModelChange(provider, input), + }, + }), + }); + + await writeProviderSessionActiveModelChange('cursor', { + sessionId: 'session-456', + model: 'composer-2', + }, { + filePath: activeModelChangesPath, + }); + + const model = await service.resolveResumeModel('cursor', 'session-456', 'composer-2-fast'); + assert.equal(model, 'composer-2'); + } finally { + await rm(tempRoot, { recursive: true, force: true }); + } +}); diff --git a/server/openai-codex.js b/server/openai-codex.js index 03497c30..95913586 100644 --- a/server/openai-codex.js +++ b/server/openai-codex.js @@ -17,6 +17,7 @@ import { Codex } from '@openai/codex-sdk'; import { notifyRunFailed, notifyRunStopped } from './services/notification-orchestrator.js'; import { sessionsService } from './modules/providers/services/sessions.service.js'; import { providerAuthService } from './modules/providers/services/provider-auth.service.js'; +import { providerModelsService } from './modules/providers/services/provider-models.service.js'; import { createNormalizedMessage } from './shared/utils.js'; // Track active sessions @@ -202,6 +203,12 @@ export async function queryCodex(command, options = {}, ws) { permissionMode = 'default' } = options; + const resolvedModel = await providerModelsService.resolveResumeModel( + 'codex', + sessionId, + model, + ); + const workingDirectory = cwd || projectPath || process.cwd(); const { sandboxMode, approvalPolicy } = mapPermissionModeToCodexOptions(permissionMode); @@ -222,7 +229,7 @@ export async function queryCodex(command, options = {}, ws) { skipGitRepoCheck: true, sandboxMode, approvalPolicy, - model + model: resolvedModel }; // Start or resume thread diff --git a/server/opencode-cli.js b/server/opencode-cli.js index 9d3c3851..ceabc082 100644 --- a/server/opencode-cli.js +++ b/server/opencode-cli.js @@ -4,6 +4,7 @@ import crossSpawn from 'cross-spawn'; import { sessionsService } from './modules/providers/services/sessions.service.js'; import { providerAuthService } from './modules/providers/services/provider-auth.service.js'; +import { providerModelsService } from './modules/providers/services/provider-models.service.js'; import { notifyRunFailed, notifyRunStopped } from './services/notification-orchestrator.js'; import { createNormalizedMessage } from './shared/utils.js'; @@ -29,17 +30,6 @@ async function spawnOpenCode(command, options = {}, ws) { let stdoutLineBuffer = ''; let terminalNotificationSent = false; - const args = ['run', '--format', 'json']; - if (sessionId) { - args.push('--session', sessionId); - } - if (model) { - args.push('--model', model); - } - if (command && command.trim()) { - args.push(command.trim()); - } - const notifyTerminalState = ({ code = null, error = null } = {}) => { if (terminalNotificationSent) { return; @@ -67,16 +57,6 @@ async function spawnOpenCode(command, options = {}, ws) { }); }; - const opencodeProcess = spawnFunction('opencode', args, { - cwd: workingDir, - stdio: ['pipe', 'pipe', 'pipe'], - env: { ...process.env }, - }); - - activeOpenCodeProcesses.set(processKey, opencodeProcess); - opencodeProcess.sessionId = processKey; - opencodeProcess.stdin.end(); - const registerSession = (nextSessionId) => { if (!nextSessionId || capturedSessionId === nextSessionId) { return; @@ -130,89 +110,112 @@ async function spawnOpenCode(command, options = {}, ws) { } }; - opencodeProcess.stdout.on('data', (data) => { - stdoutLineBuffer += data.toString(); - const completeLines = stdoutLineBuffer.split(/\r?\n/); - stdoutLineBuffer = completeLines.pop() || ''; + void providerModelsService.resolveResumeModel('opencode', sessionId, model).then((resolvedModel) => { + const args = ['run', '--format', 'json']; + if (sessionId) { + args.push('--session', sessionId); + } + if (resolvedModel) { + args.push('--model', resolvedModel); + } + if (command && command.trim()) { + args.push(command.trim()); + } - completeLines.forEach((line) => { - processOpenCodeOutputLine(line.trim()); + const opencodeProcess = spawnFunction('opencode', args, { + cwd: workingDir, + stdio: ['pipe', 'pipe', 'pipe'], + env: { ...process.env }, }); - }); - opencodeProcess.stderr.on('data', (data) => { - const stderrText = data.toString(); - if (!stderrText.trim()) { - return; - } + activeOpenCodeProcesses.set(processKey, opencodeProcess); + opencodeProcess.sessionId = processKey; + opencodeProcess.stdin.end(); - ws.send(createNormalizedMessage({ - kind: 'error', - content: stderrText, - sessionId: capturedSessionId || sessionId || null, - provider: 'opencode', - })); - }); + opencodeProcess.stdout.on('data', (data) => { + stdoutLineBuffer += data.toString(); + const completeLines = stdoutLineBuffer.split(/\r?\n/); + stdoutLineBuffer = completeLines.pop() || ''; - opencodeProcess.on('close', async (code) => { - const finalSessionId = capturedSessionId || sessionId || processKey; - activeOpenCodeProcesses.delete(finalSessionId); - activeOpenCodeProcesses.delete(processKey); + completeLines.forEach((line) => { + processOpenCodeOutputLine(line.trim()); + }); + }); - if (stdoutLineBuffer.trim()) { - processOpenCodeOutputLine(stdoutLineBuffer.trim()); - stdoutLineBuffer = ''; - } - - ws.send(createNormalizedMessage({ - kind: 'complete', - exitCode: code, - isNewSession: !sessionId && !!command, - sessionId: finalSessionId, - provider: 'opencode', - })); - - if (code === 0) { - notifyTerminalState({ code }); - resolve(); - return; - } - - if (code === 127 || code === null) { - const installed = await providerAuthService.isProviderInstalled('opencode'); - if (!installed) { - ws.send(createNormalizedMessage({ - kind: 'error', - content: 'OpenCode CLI is not installed. Install it from https://opencode.ai/docs/', - sessionId: finalSessionId, - provider: 'opencode', - })); + opencodeProcess.stderr.on('data', (data) => { + const stderrText = data.toString(); + if (!stderrText.trim()) { + return; } - } - notifyTerminalState({ code }); - reject(new Error(code === null ? 'OpenCode CLI process was terminated' : `OpenCode CLI exited with code ${code}`)); - }); + ws.send(createNormalizedMessage({ + kind: 'error', + content: stderrText, + sessionId: capturedSessionId || sessionId || null, + provider: 'opencode', + })); + }); - opencodeProcess.on('error', async (error) => { - const finalSessionId = capturedSessionId || sessionId || processKey; - activeOpenCodeProcesses.delete(finalSessionId); - activeOpenCodeProcesses.delete(processKey); + opencodeProcess.on('close', async (code) => { + const finalSessionId = capturedSessionId || sessionId || processKey; + activeOpenCodeProcesses.delete(finalSessionId); + activeOpenCodeProcesses.delete(processKey); - const installed = await providerAuthService.isProviderInstalled('opencode'); - const errorContent = !installed - ? 'OpenCode CLI is not installed. Install it from https://opencode.ai/docs/' - : error.message; + if (stdoutLineBuffer.trim()) { + processOpenCodeOutputLine(stdoutLineBuffer.trim()); + stdoutLineBuffer = ''; + } - ws.send(createNormalizedMessage({ - kind: 'error', - content: errorContent, - sessionId: finalSessionId, - provider: 'opencode', - })); - notifyTerminalState({ error }); - reject(error); - }); + ws.send(createNormalizedMessage({ + kind: 'complete', + exitCode: code, + isNewSession: !sessionId && !!command, + sessionId: finalSessionId, + provider: 'opencode', + })); + + if (code === 0) { + notifyTerminalState({ code }); + resolve(); + return; + } + + if (code === 127 || code === null) { + const installed = await providerAuthService.isProviderInstalled('opencode'); + if (!installed) { + ws.send(createNormalizedMessage({ + kind: 'error', + content: 'OpenCode CLI is not installed. Install it from https://opencode.ai/docs/', + sessionId: finalSessionId, + provider: 'opencode', + })); + } + } + + notifyTerminalState({ code }); + reject(new Error(code === null ? 'OpenCode CLI process was terminated' : `OpenCode CLI exited with code ${code}`)); + }); + + opencodeProcess.on('error', async (error) => { + const finalSessionId = capturedSessionId || sessionId || processKey; + activeOpenCodeProcesses.delete(finalSessionId); + activeOpenCodeProcesses.delete(processKey); + + const installed = await providerAuthService.isProviderInstalled('opencode'); + const errorContent = !installed + ? 'OpenCode CLI is not installed. Install it from https://opencode.ai/docs/' + : error.message; + + ws.send(createNormalizedMessage({ + kind: 'error', + content: errorContent, + sessionId: finalSessionId, + provider: 'opencode', + })); + notifyTerminalState({ error }); + reject(error); + }); + }).catch(reject); }); } diff --git a/server/shared/interfaces.ts b/server/shared/interfaces.ts index 9228698e..1864fd02 100644 --- a/server/shared/interfaces.ts +++ b/server/shared/interfaces.ts @@ -7,9 +7,11 @@ import type { ProviderSkill, ProviderSkillListOptions, ProviderAuthStatus, + ProviderChangeActiveModelInput, ProviderCurrentActiveModel, ProviderModelsDefinition, ProviderMcpServer, + ProviderSessionActiveModelChange, UpsertProviderMcpServerInput, } from '@/shared/types.js'; @@ -55,6 +57,19 @@ export interface IProviderModels { * no active model can be resolved. */ getCurrentActiveModel(sessionId?: string): Promise; + + /** + * Persists a session-scoped model override that the next resumed turn should + * honor for this provider. + * + * This does not require the provider to mutate an already running remote + * session in-place. Instead, adapters store the user's explicit model choice + * so the backend resume path can add the correct provider-native model option + * on the next CLI/SDK invocation for the same session. + */ + changeActiveModel( + input: ProviderChangeActiveModelInput, + ): Promise; } // --------------------------- diff --git a/server/shared/types.ts b/server/shared/types.ts index 121cc509..7d981988 100644 --- a/server/shared/types.ts +++ b/server/shared/types.ts @@ -123,6 +123,37 @@ export type ProviderCurrentActiveModel = { model: string; }; +/** + * Input payload used when one session needs to use a different model on its + * next resumed turn. + * + * This is a backend-owned session override, not a claim that the provider has + * already switched the currently running session in-place. Provider adapters + * persist this request so the next CLI/SDK resume can inject the chosen model + * using the provider-specific mechanism supported by that runtime. + */ +export type ProviderChangeActiveModelInput = { + sessionId: string; + model: string; +}; + +/** + * Provider-neutral session model-change state. + * + * `supported` indicates whether the provider adapter supports the app's + * session-scoped resume override flow. `changed` is the persisted boolean the + * resume layer checks before forcing a model on the next resumed turn. When + * `changed` is `false`, `model` is `null` and the runtime should use the + * normal request/default model selection path. + */ +export type ProviderSessionActiveModelChange = { + provider: LLMProvider; + sessionId: string; + supported: boolean; + changed: boolean; + model: string | null; +}; + /** * Message/event variants emitted by provider adapters and normalized transports. * diff --git a/server/shared/utils.ts b/server/shared/utils.ts index 437ea034..821489c5 100644 --- a/server/shared/utils.ts +++ b/server/shared/utils.ts @@ -22,9 +22,12 @@ import type { AnyRecord, ApiSuccessShape, AppErrorOptions, + LLMProvider, NormalizedMessage, + ProviderChangeActiveModelInput, ProviderCurrentActiveModel, ProviderModelsDefinition, + ProviderSessionActiveModelChange, ProviderSkillSource, WorkspacePathValidationResult, } from '@/shared/types.js'; @@ -434,6 +437,213 @@ export function buildDefaultProviderCurrentActiveModel( }; } +// --------------------------- +//----------------- PROVIDER SESSION MODEL CHANGE UTILITIES ------------ +type ProviderSessionActiveModelChangeCacheEntry = ProviderSessionActiveModelChange & { + updatedAt: string; +}; + +type ProviderSessionActiveModelChangeCacheFile = { + version: number; + entries: Record; +}; + +const PROVIDER_SESSION_ACTIVE_MODEL_CHANGE_CACHE_VERSION = 1; + +/** + * Resolves the backend-owned cache file used for session-scoped resume model + * overrides. + * + * The file lives under `~/.cloudcli` because these overrides are an application + * concern rather than a provider-native config file. Providers, routes, and + * runtime command launchers should all use this helper instead of re-creating + * the path so the storage location stays consistent. + */ +export function getProviderSessionActiveModelChangesPath(): string { + return path.join(os.homedir(), '.cloudcli', 'provider-session-active-model-changes.json'); +} + +const buildProviderSessionActiveModelChangeKey = ( + provider: LLMProvider, + sessionId: string, +): string => `${provider}:${sessionId}`; + +const isProviderSessionActiveModelChangeCacheEntry = ( + value: unknown, +): value is ProviderSessionActiveModelChangeCacheEntry => { + const record = readObjectRecord(value); + return Boolean( + record + && typeof record.provider === 'string' + && typeof record.sessionId === 'string' + && typeof record.supported === 'boolean' + && typeof record.changed === 'boolean' + && (typeof record.model === 'string' || record.model === null) + && typeof record.updatedAt === 'string', + ); +}; + +const readProviderSessionActiveModelChangeCacheFile = async ( + filePath: string, +): Promise => { + try { + const raw = await readFile(filePath, 'utf8'); + const parsed = readObjectRecord(JSON.parse(raw)); + if ( + !parsed + || parsed.version !== PROVIDER_SESSION_ACTIVE_MODEL_CHANGE_CACHE_VERSION + || !readObjectRecord(parsed.entries) + ) { + return { + version: PROVIDER_SESSION_ACTIVE_MODEL_CHANGE_CACHE_VERSION, + entries: {}, + }; + } + + const entries = Object.fromEntries( + Object.entries(parsed.entries).filter((entry): entry is [string, ProviderSessionActiveModelChangeCacheEntry] => + isProviderSessionActiveModelChangeCacheEntry(entry[1]), + ), + ); + + return { + version: PROVIDER_SESSION_ACTIVE_MODEL_CHANGE_CACHE_VERSION, + entries, + }; + } catch { + return { + version: PROVIDER_SESSION_ACTIVE_MODEL_CHANGE_CACHE_VERSION, + entries: {}, + }; + } +}; + +const writeProviderSessionActiveModelChangeCacheFile = async ( + filePath: string, + payload: ProviderSessionActiveModelChangeCacheFile, +): Promise => { + await mkdir(path.dirname(filePath), { recursive: true }); + await writeFile(filePath, `${JSON.stringify(payload, null, 2)}\n`, 'utf8'); +}; + +const buildUnsupportedProviderSessionActiveModelChange = ( + provider: LLMProvider, + sessionId: string, +): ProviderSessionActiveModelChange => ({ + provider, + sessionId, + supported: false, + changed: false, + model: null, +}); + +/** + * Reads the persisted session model-change state for one provider session. + * + * Runtime resume paths use this to decide whether they should inject a + * provider-specific model argument/thread option for the next resumed turn. + * Missing cache entries are normalized to `{ changed: false }` so callers can + * treat absence as "use the ordinary model selection flow". + */ +export async function readProviderSessionActiveModelChange( + provider: LLMProvider, + sessionId: string, + options: { + filePath?: string; + supported?: boolean; + } = {}, +): Promise { + const normalizedSessionId = sessionId.trim(); + if (!normalizedSessionId) { + return buildUnsupportedProviderSessionActiveModelChange(provider, normalizedSessionId); + } + + const supported = options.supported ?? true; + if (!supported) { + return buildUnsupportedProviderSessionActiveModelChange(provider, normalizedSessionId); + } + + const filePath = options.filePath ?? getProviderSessionActiveModelChangesPath(); + const cacheFile = await readProviderSessionActiveModelChangeCacheFile(filePath); + const cacheEntry = cacheFile.entries[ + buildProviderSessionActiveModelChangeKey(provider, normalizedSessionId) + ]; + + if (!cacheEntry || !cacheEntry.changed || !cacheEntry.model?.trim()) { + return { + provider, + sessionId: normalizedSessionId, + supported: true, + changed: false, + model: null, + }; + } + + return { + provider, + sessionId: normalizedSessionId, + supported: true, + changed: true, + model: cacheEntry.model.trim(), + }; +} + +/** + * Persists a session model-change request for one provider. + * + * Provider adapters call this when the frontend explicitly selects a different + * model for an existing session. The stored `changed: true` flag is the single + * source of truth used later by resume paths to decide whether they should add + * a provider-native model override on the next invocation. + */ +export async function writeProviderSessionActiveModelChange( + provider: LLMProvider, + input: ProviderChangeActiveModelInput, + options: { + filePath?: string; + supported?: boolean; + } = {}, +): Promise { + const normalizedSessionId = input.sessionId.trim(); + const normalizedModel = input.model.trim(); + const supported = options.supported ?? true; + + if (!supported) { + return buildUnsupportedProviderSessionActiveModelChange(provider, normalizedSessionId); + } + + if (!normalizedSessionId || !normalizedModel) { + return { + provider, + sessionId: normalizedSessionId, + supported: true, + changed: false, + model: null, + }; + } + + const filePath = options.filePath ?? getProviderSessionActiveModelChangesPath(); + const cacheFile = await readProviderSessionActiveModelChangeCacheFile(filePath); + cacheFile.entries[buildProviderSessionActiveModelChangeKey(provider, normalizedSessionId)] = { + provider, + sessionId: normalizedSessionId, + supported: true, + changed: true, + model: normalizedModel, + updatedAt: new Date().toISOString(), + }; + + await writeProviderSessionActiveModelChangeCacheFile(filePath, cacheFile); + + return { + provider, + sessionId: normalizedSessionId, + supported: true, + changed: true, + model: normalizedModel, + }; +} + // --------------------------- //----------------- WEBSOCKET PAYLOAD PARSING UTILITIES ------------ /** diff --git a/src/components/chat/hooks/useChatProviderState.ts b/src/components/chat/hooks/useChatProviderState.ts index e8f0ed5c..33c54d4b 100644 --- a/src/components/chat/hooks/useChatProviderState.ts +++ b/src/components/chat/hooks/useChatProviderState.ts @@ -43,6 +43,17 @@ type ProviderModelsApiResponse = { }; }; +type ChangeActiveModelApiResponse = { + success?: boolean; + data?: { + provider?: LLMProvider; + sessionId?: string; + supported?: boolean; + changed?: boolean; + model?: string | null; + }; +}; + export function useChatProviderState({ selectedSession, selectedProject }: UseChatProviderStateArgs) { const [permissionMode, setPermissionMode] = useState('default'); const [pendingPermissionRequests, setPendingPermissionRequests] = useState([]); @@ -77,6 +88,35 @@ export function useChatProviderState({ selectedSession, selectedProject }: UseCh const lastProviderRef = useRef(provider); const providerModelsRequestIdRef = useRef(0); + const setStoredProviderModel = useCallback((targetProvider: LLMProvider, model: string) => { + if (targetProvider === 'claude') { + setClaudeModel(model); + localStorage.setItem('claude-model', model); + return; + } + + if (targetProvider === 'cursor') { + setCursorModel(model); + localStorage.setItem('cursor-model', model); + return; + } + + if (targetProvider === 'codex') { + setCodexModel(model); + localStorage.setItem('codex-model', model); + return; + } + + if (targetProvider === 'gemini') { + setGeminiModel(model); + localStorage.setItem('gemini-model', model); + return; + } + + setOpenCodeModel(model); + localStorage.setItem('opencode-model', model); + }, []); + const loadProviderModels = useCallback(async (options: { bypassCache?: boolean } = {}) => { const providers: LLMProvider[] = ['claude', 'cursor', 'codex', 'gemini', 'opencode']; const requestId = providerModelsRequestIdRef.current + 1; @@ -289,6 +329,41 @@ export function useChatProviderState({ selectedSession, selectedProject }: UseCh } }, [permissionMode, provider, selectedSession?.id]); + const selectProviderModel = useCallback(async ( + targetProvider: LLMProvider, + model: string, + sessionId?: string | null, + ) => { + const normalizedSessionId = typeof sessionId === 'string' ? sessionId.trim() : ''; + if (!normalizedSessionId) { + setStoredProviderModel(targetProvider, model); + return { + scope: 'default' as const, + changed: false, + model, + }; + } + + const response = await authenticatedFetch( + `/api/providers/${targetProvider}/sessions/${encodeURIComponent(normalizedSessionId)}/active-model`, + { + method: 'POST', + body: JSON.stringify({ model }), + }, + ); + + const body = (await response.json()) as ChangeActiveModelApiResponse; + if (!response.ok || !body.success || !body.data?.supported) { + throw new Error('Unable to change the active model for this session.'); + } + + return { + scope: 'session' as const, + changed: body.data.changed === true, + model: body.data.model || model, + }; + }, [setStoredProviderModel]); + return { provider, setProvider, @@ -312,5 +387,6 @@ export function useChatProviderState({ selectedSession, selectedProject }: UseCh providerModelsLoading, providerModelsRefreshing, hardRefreshProviderModels: () => loadProviderModels({ bypassCache: true }), + selectProviderModel, }; } diff --git a/src/components/chat/view/ChatInterface.tsx b/src/components/chat/view/ChatInterface.tsx index c724bc9c..63e0cbb5 100644 --- a/src/components/chat/view/ChatInterface.tsx +++ b/src/components/chat/view/ChatInterface.tsx @@ -83,6 +83,7 @@ function ChatInterface({ providerModelsLoading, providerModelsRefreshing, hardRefreshProviderModels, + selectProviderModel, } = useChatProviderState({ selectedSession, selectedProject, @@ -441,6 +442,8 @@ function ChatInterface({ providerModelCacheCatalog={providerModelCacheCatalog} providerModelsRefreshing={providerModelsRefreshing} onHardRefreshProviderModels={hardRefreshProviderModels} + currentSessionId={currentSessionId || selectedSession?.id || null} + onSelectProviderModel={selectProviderModel} /> ); diff --git a/src/components/chat/view/subcomponents/CommandResultModal.tsx b/src/components/chat/view/subcomponents/CommandResultModal.tsx index f0163bba..8faa677b 100644 --- a/src/components/chat/view/subcomponents/CommandResultModal.tsx +++ b/src/components/chat/view/subcomponents/CommandResultModal.tsx @@ -38,6 +38,16 @@ type CommandResultModalProps = { providerModelCacheCatalog: Partial>; providerModelsRefreshing: boolean; onHardRefreshProviderModels: () => void; + currentSessionId: string | null; + onSelectProviderModel: ( + provider: LLMProvider, + model: string, + sessionId?: string | null, + ) => Promise<{ + scope: 'default' | 'session'; + changed: boolean; + model: string; + }>; }; type CommandEntry = { @@ -254,15 +264,22 @@ function ModelsContent({ providerModelCacheCatalog, providerModelsRefreshing, onHardRefreshProviderModels, + currentSessionId, + onSelectProviderModel, }: { data: ModelCommandData; providerModelCatalog: Partial>; providerModelCacheCatalog: Partial>; providerModelsRefreshing: boolean; onHardRefreshProviderModels: () => void; + currentSessionId: string | null; + onSelectProviderModel: CommandResultModalProps['onSelectProviderModel']; }) { const [query, setQuery] = useState(''); const [copiedModel, setCopiedModel] = useState(null); + const [changingModel, setChangingModel] = useState(null); + const [pendingSessionModel, setPendingSessionModel] = useState(null); + const [selectionNotice, setSelectionNotice] = useState(null); const currentProvider = (data?.current?.provider || 'claude') as LLMProvider; const currentModel = data?.current?.model || 'Unknown'; const providerLabel = data?.current?.providerLabel || getProviderLabel(currentProvider); @@ -295,6 +312,7 @@ function ModelsContent({ }, [availableOptions, query]); const activeOption = availableOptions.find((option) => option.value === currentModel); + const hasConcreteSessionId = typeof currentSessionId === 'string' && currentSessionId.trim().length > 0; const copyModel = (model: string) => { if (typeof navigator !== 'undefined' && navigator.clipboard) { @@ -306,6 +324,26 @@ function ModelsContent({ }, 1300); }; + const handleSelectModel = async (model: string) => { + setChangingModel(model); + try { + const result = await onSelectProviderModel(currentProvider, model, currentSessionId); + if (result.scope === 'session') { + setPendingSessionModel(result.model); + setSelectionNotice(`Next response will resume with ${result.model}.`); + return; + } + + setPendingSessionModel(null); + setSelectionNotice(`Default ${providerLabel} model set to ${result.model}.`); + } catch (error) { + const message = error instanceof Error ? error.message : 'Unable to change the model right now.'; + setSelectionNotice(message); + } finally { + setChangingModel(null); + } + }; + return (
@@ -331,6 +369,13 @@ function ModelsContent({
+
+ {hasConcreteSessionId + ? 'Selecting a model stores a session override and applies it on the next response for this session.' + : 'Selecting a model updates the default model used for new turns in this provider.'} + {selectionNotice && {selectionNotice}} +
+
@@ -344,6 +389,11 @@ function ModelsContent({ {activeOption?.description && (

{activeOption.description}

)} + {pendingSessionModel && pendingSessionModel !== currentModel && ( +

+ Next response: {pendingSessionModel} +

+ )}
Live
@@ -367,20 +417,27 @@ function ModelsContent({ {filteredOptions.map((option, index) => { const isCurrent = option.value === currentModel; const wasCopied = copiedModel === option.value; + const isPendingSelection = option.value === pendingSessionModel; + const isChanging = option.value === changingModel; return ( - + + +
); })}
@@ -521,6 +593,8 @@ export default function CommandResultModal({ providerModelCacheCatalog, providerModelsRefreshing, onHardRefreshProviderModels, + currentSessionId, + onSelectProviderModel, }: CommandResultModalProps) { const isOpen = Boolean(payload); const kind = payload?.kind; @@ -613,6 +687,8 @@ export default function CommandResultModal({ providerModelCacheCatalog={providerModelCacheCatalog} providerModelsRefreshing={providerModelsRefreshing} onHardRefreshProviderModels={onHardRefreshProviderModels} + currentSessionId={currentSessionId} + onSelectProviderModel={onSelectProviderModel} /> )} {payload?.kind === 'cost' && }