From 9aa927002e3808f1381d6aeeaabaace8e4f3f074 Mon Sep 17 00:00:00 2001 From: Haileyesus <118998054+blackmammoth@users.noreply.github.com> Date: Mon, 18 May 2026 16:57:29 +0300 Subject: [PATCH] feat: support session-scoped model overrides Model selection was acting like a provider-level preference. That made resumed sessions drift back to a default or request-time model. Users expect /models changes made inside a conversation to affect that session. Store explicit session choices in app-owned ~/.cloudcli state. This avoids editing provider transcripts or native provider config. Resolve the effective model before launching each provider runtime. Claude, Cursor, Codex, Gemini, and OpenCode now honor stored resume choices. Expose a backend active-model change endpoint for existing sessions. The models modal can now distinguish default changes from session overrides. It also shows when a selected model will apply on the next response. For Claude, stop probing active model state by resuming with a dummy prompt. Read the indexed JSONL transcript from the end instead. This preserves provider history while honoring /model stdout or model fields. Add service tests for adapter delegation and resume-model precedence. The tests keep cache state, override state, and requested fallback separate. --- server/claude-sdk.js | 12 +- server/cursor-cli.js | 9 +- server/gemini-cli.js | 8 +- .../list/claude/claude-models.provider.ts | 192 +++++++++------- .../list/codex/codex-models.provider.ts | 9 + .../list/cursor/cursor-models.provider.ts | 9 + .../list/gemini/gemini-models.provider.ts | 18 +- .../list/opencode/opencode-models.provider.ts | 9 + server/modules/providers/provider.routes.ts | 45 +++- .../services/provider-models.service.ts | 38 ++++ .../tests/provider-models.service.test.ts | 83 +++++++ server/openai-codex.js | 9 +- server/opencode-cli.js | 189 ++++++++-------- server/shared/interfaces.ts | 15 ++ server/shared/types.ts | 31 +++ server/shared/utils.ts | 210 ++++++++++++++++++ .../chat/hooks/useChatProviderState.ts | 76 +++++++ src/components/chat/view/ChatInterface.tsx | 3 + .../view/subcomponents/CommandResultModal.tsx | 98 +++++++- 19 files changed, 872 insertions(+), 191 deletions(-) diff --git a/server/claude-sdk.js b/server/claude-sdk.js index dc48a92f..da15142d 100644 --- a/server/claude-sdk.js +++ b/server/claude-sdk.js @@ -18,6 +18,7 @@ import { promises as fs } from 'fs'; import path from 'path'; import os from 'os'; import { CLAUDE_FALLBACK_MODELS } from './modules/providers/list/claude/claude-models.provider.js'; +import { providerModelsService } from './modules/providers/services/provider-models.service.js'; import { resolveClaudeCodeExecutablePath } from './shared/claude-cli-path.js'; import { createNotificationEvent, @@ -491,8 +492,17 @@ async function queryClaudeSDK(command, options = {}, ws) { }; try { + const resolvedModel = await providerModelsService.resolveResumeModel( + 'claude', + sessionId, + options.model, + ); + // Map CLI options to SDK format - const sdkOptions = mapCliOptionsToSDK(options); + const sdkOptions = mapCliOptionsToSDK({ + ...options, + model: resolvedModel || options.model, + }); // Load MCP configuration const mcpServers = await loadMcpConfig(options.cwd); diff --git a/server/cursor-cli.js b/server/cursor-cli.js index 1d5a7d79..6cfd7bac 100644 --- a/server/cursor-cli.js +++ b/server/cursor-cli.js @@ -3,6 +3,7 @@ import crossSpawn from 'cross-spawn'; import { notifyRunFailed, notifyRunStopped } from './services/notification-orchestrator.js'; import { sessionsService } from './modules/providers/services/sessions.service.js'; import { providerAuthService } from './modules/providers/services/provider-auth.service.js'; +import { providerModelsService } from './modules/providers/services/provider-models.service.js'; import { createNormalizedMessage } from './shared/utils.js'; // Use cross-spawn on Windows for better command execution @@ -28,6 +29,7 @@ function isWorkspaceTrustPrompt(text = '') { async function spawnCursor(command, options = {}, ws) { return new Promise(async (resolve, reject) => { const { sessionId, projectPath, cwd, resume, toolsSettings, skipPermissions, model, sessionSummary } = options; + const resolvedModel = await providerModelsService.resolveResumeModel('cursor', sessionId, model); let capturedSessionId = sessionId; // Track session ID throughout the process let sessionCreatedSent = false; // Track if we've already sent session-created event let hasRetriedWithTrust = false; @@ -52,9 +54,10 @@ async function spawnCursor(command, options = {}, ws) { // Provide a prompt (works for both new and resumed sessions) baseArgs.push('-p', command); - // Add model flag if specified (only meaningful for new sessions; harmless on resume) - if (!sessionId && model) { - baseArgs.push('--model', model); + // Model overrides are applied to both new and resumed sessions so a + // session-scoped change request can take effect on the next turn. + if (resolvedModel) { + baseArgs.push('--model', resolvedModel); } // Request streaming JSON when we are providing a prompt diff --git a/server/gemini-cli.js b/server/gemini-cli.js index 1f45682c..ee1fd845 100644 --- a/server/gemini-cli.js +++ b/server/gemini-cli.js @@ -9,6 +9,7 @@ import sessionManager from './sessionManager.js'; import GeminiResponseHandler from './gemini-response-handler.js'; import { notifyRunFailed, notifyRunStopped } from './services/notification-orchestrator.js'; import { providerAuthService } from './modules/providers/services/provider-auth.service.js'; +import { providerModelsService } from './modules/providers/services/provider-models.service.js'; import { createNormalizedMessage } from './shared/utils.js'; // Use cross-spawn on Windows for correct .cmd resolution (same pattern as cursor-cli.js) @@ -120,6 +121,11 @@ async function buildGeminiProcessEnv() { async function spawnGemini(command, options = {}, ws) { const { sessionId, projectPath, cwd, toolsSettings, permissionMode, images, sessionSummary } = options; + const resolvedModel = await providerModelsService.resolveResumeModel( + 'gemini', + sessionId, + options.model + ); let capturedSessionId = sessionId; // Track session ID throughout the process let sessionCreatedSent = false; // Track if we've already sent session-created event let assistantBlocks = []; // Accumulate the full response blocks including tools @@ -244,7 +250,7 @@ async function spawnGemini(command, options = {}, ws) { } // Add model for all sessions (both new and resumed) - let modelToUse = options.model || 'gemini-2.5-flash'; + let modelToUse = resolvedModel || 'gemini-2.5-flash'; args.push('--model', modelToUse); args.push('--output-format', 'stream-json'); diff --git a/server/modules/providers/list/claude/claude-models.provider.ts b/server/modules/providers/list/claude/claude-models.provider.ts index b0d845a3..02e04194 100644 --- a/server/modules/providers/list/claude/claude-models.provider.ts +++ b/server/modules/providers/list/claude/claude-models.provider.ts @@ -1,16 +1,21 @@ -import { spawn } from 'node:child_process'; +import { readFile } from 'node:fs/promises'; import { query, type ModelInfo, type Options } from '@anthropic-ai/claude-agent-sdk'; -import crossSpawn from 'cross-spawn'; +import { sessionsDb } from '@/modules/database/index.js'; import { resolveClaudeCodeExecutablePath } from '@/shared/claude-cli-path.js'; import type { IProviderModels } from '@/shared/interfaces.js'; import type { + ProviderChangeActiveModelInput, ProviderCurrentActiveModel, ProviderModelOption, ProviderModelsDefinition, + ProviderSessionActiveModelChange, } from '@/shared/types.js'; -import { buildDefaultProviderCurrentActiveModel } from '@/shared/utils.js'; +import { + buildDefaultProviderCurrentActiveModel, + writeProviderSessionActiveModelChange, +} from '@/shared/utils.js'; export const CLAUDE_FALLBACK_MODELS: ProviderModelsDefinition = { OPTIONS: [ @@ -26,13 +31,23 @@ export const CLAUDE_FALLBACK_MODELS: ProviderModelsDefinition = { type ClaudeModelQueryOptions = Pick; type ClaudeInitEvent = { + sessionId?: string; + session_id?: string; type?: string; subtype?: string; model?: string; + message?: { + content?: unknown; + model?: string; + }; }; -const CLAUDE_ACTIVE_MODEL_TIMEOUT_MS = 20_000; -const claudeSpawn = process.platform === 'win32' ? crossSpawn : spawn; +const ANSI_PATTERN = new RegExp( + '[\\u001B\\u009B][[\\]()#;?]*(?:' + + '(?:[0-9]{1,4}(?:;[0-9]{0,4})*)?[0-9A-ORZcf-nqry=><]' + + '|(?:[\\dA-PR-TZcf-ntqry=><~]))', + 'g', +); const buildClaudeQueryOptions = (): ClaudeModelQueryOptions => ({ env: { ...process.env }, @@ -74,82 +89,94 @@ const buildClaudeModelsDefinition = (models: ModelInfo[]): ProviderModelsDefinit }; }; -const runClaudeSessionModelCommand = async (sessionId: string): Promise => { - const cliPath = resolveClaudeCodeExecutablePath(process.env.CLAUDE_CLI_PATH); +const extractClaudeEventModel = (event: ClaudeInitEvent, sessionId: string): string | null => { + const eventSessionId = event.sessionId ?? event.session_id; + if (eventSessionId && eventSessionId !== sessionId) { + return null; + } - return new Promise((resolve, reject) => { - const child = claudeSpawn( - cliPath, - ['-p', '--verbose', '--output-format', 'stream-json', '--resume', sessionId, 'ok'], - { - env: { ...process.env }, - windowsHide: true, - }, - ); + const contentModel = extractClaudeModelFromMessageContent(event.message?.content); + if (contentModel) { + return contentModel; + } - let stdout = ''; - let stderr = ''; - let settled = false; + const directModel = event.model?.trim(); + if (directModel) { + return directModel; + } - const timer = setTimeout(() => { - child.kill('SIGTERM'); - if (!settled) { - settled = true; - reject(new Error('Claude current-model lookup timed out')); + const messageModel = event.message?.model?.trim(); + return messageModel || null; +}; + +const stripAnsi = (value: string): string => value.replace(ANSI_PATTERN, ''); + +const extractTaggedContent = (content: string, tagName: string): string | null => { + const escapedTagName = tagName.replace(/[.*+?^${}()|[\]\\]/g, '\\$&'); + const match = new RegExp(`<${escapedTagName}>([\\s\\S]*?)<\\/${escapedTagName}>`).exec(content); + return match ? match[1] : null; +}; + +const extractClaudeModelFromTextContent = (content: string): string | null => { + const localCommandStdout = extractTaggedContent(content, 'local-command-stdout'); + if (localCommandStdout !== null) { + const cleanedStdout = stripAnsi(localCommandStdout).replace(/\s+/g, ' ').trim(); + const changedModel = /(?:set|changed|switched)\s+model\s+to\s+(.+?)\.?$/i.exec(cleanedStdout); + if (changedModel?.[1]?.trim()) { + return changedModel[1].trim(); + } + } + + const modelTag = extractTaggedContent(content, 'model')?.trim(); + return modelTag || null; +}; + +const extractClaudeModelFromMessageContent = (content: unknown): string | null => { + if (typeof content === 'string') { + return extractClaudeModelFromTextContent(content); + } + + if (!Array.isArray(content)) { + return null; + } + + for (const part of content) { + if (!part || typeof part !== 'object' || !('text' in part) || typeof part.text !== 'string') { + continue; + } + + const model = extractClaudeModelFromTextContent(part.text); + if (model) { + return model; + } + } + + return null; +}; + +const readClaudeSessionModelFromJsonl = async ( + sessionId: string, + jsonlPath: string, +): Promise => { + const content = await readFile(jsonlPath, 'utf8'); + const lines = content + .split(/\r?\n/) + .map((line) => line.trim()) + .filter(Boolean); + + for (let index = lines.length - 1; index >= 0; index -= 1) { + try { + const event = JSON.parse(lines[index]) as ClaudeInitEvent; + const model = extractClaudeEventModel(event, sessionId); + if (model) { + return { model }; } - }, CLAUDE_ACTIVE_MODEL_TIMEOUT_MS); + } catch { + // Skip malformed JSONL lines that can happen during concurrent writes. + } + } - const finish = (error: Error | null, result: ProviderCurrentActiveModel | null) => { - if (settled) { - return; - } - - settled = true; - clearTimeout(timer); - - if (error) { - reject(error); - return; - } - - resolve(result); - }; - - child.stdout?.on('data', (chunk: Buffer) => { - stdout += chunk.toString(); - }); - - child.stderr?.on('data', (chunk: Buffer) => { - stderr += chunk.toString(); - }); - - child.on('error', (error) => { - finish(error instanceof Error ? error : new Error(String(error)), null); - }); - - child.on('close', () => { - const lines = `${stdout}\n${stderr}` - .split(/\r?\n/) - .map((line) => line.trim()) - .filter(Boolean); - - for (const line of lines) { - try { - const event = JSON.parse(line) as ClaudeInitEvent; - if (event.type === 'system' && event.subtype === 'init' && event.model) { - finish(null, { - model: event.model, - }); - return; - } - } catch { - // The Claude CLI mixes non-JSON lines into verbose output; ignore them. - } - } - - finish(null, null); - }); - }); + return null; }; export class ClaudeProviderModels implements IProviderModels { @@ -161,7 +188,7 @@ export class ClaudeProviderModels implements IProviderModels { // instance, so we create a lightweight query and immediately close it // after reading the control-plane metadata. queryInstance = query({ - prompt: '', + prompt: 'Get supported models', options: buildClaudeQueryOptions(), }); @@ -181,7 +208,10 @@ export class ClaudeProviderModels implements IProviderModels { } try { - const activeModel = await runClaudeSessionModelCommand(sessionId); + const jsonlPath = sessionsDb.getSessionById(sessionId)?.jsonl_path; + const activeModel = jsonlPath + ? await readClaudeSessionModelFromJsonl(sessionId, jsonlPath) + : null; if (activeModel?.model) { return activeModel; } @@ -191,4 +221,10 @@ export class ClaudeProviderModels implements IProviderModels { return buildDefaultProviderCurrentActiveModel(await this.getSupportedModels()); } + + async changeActiveModel( + input: ProviderChangeActiveModelInput, + ): Promise { + return writeProviderSessionActiveModelChange('claude', input); + } } diff --git a/server/modules/providers/list/codex/codex-models.provider.ts b/server/modules/providers/list/codex/codex-models.provider.ts index 75a767a3..de4de0ea 100644 --- a/server/modules/providers/list/codex/codex-models.provider.ts +++ b/server/modules/providers/list/codex/codex-models.provider.ts @@ -6,14 +6,17 @@ import TOML from '@iarna/toml'; import type { IProviderModels } from '@/shared/interfaces.js'; import type { + ProviderChangeActiveModelInput, ProviderCurrentActiveModel, ProviderModelOption, ProviderModelsDefinition, + ProviderSessionActiveModelChange, } from '@/shared/types.js'; import { buildDefaultProviderCurrentActiveModel, readObjectRecord, readOptionalString, + writeProviderSessionActiveModelChange, } from '@/shared/utils.js'; export const CODEX_FALLBACK_MODELS: ProviderModelsDefinition = { @@ -113,4 +116,10 @@ export class CodexProviderModels implements IProviderModels { return buildDefaultProviderCurrentActiveModel(await this.getSupportedModels()); } } + + async changeActiveModel( + input: ProviderChangeActiveModelInput, + ): Promise { + return writeProviderSessionActiveModelChange('codex', input); + } } diff --git a/server/modules/providers/list/cursor/cursor-models.provider.ts b/server/modules/providers/list/cursor/cursor-models.provider.ts index 9290e870..ccc2288a 100644 --- a/server/modules/providers/list/cursor/cursor-models.provider.ts +++ b/server/modules/providers/list/cursor/cursor-models.provider.ts @@ -7,13 +7,16 @@ import crossSpawn from 'cross-spawn'; import type { IProviderModels } from '@/shared/interfaces.js'; import type { + ProviderChangeActiveModelInput, ProviderCurrentActiveModel, ProviderModelOption, ProviderModelsDefinition, + ProviderSessionActiveModelChange, } from '@/shared/types.js'; import { buildDefaultProviderCurrentActiveModel, sanitizeLeafDirectoryName, + writeProviderSessionActiveModelChange, } from '@/shared/utils.js'; export const CURSOR_FALLBACK_MODELS: ProviderModelsDefinition = { @@ -257,4 +260,10 @@ export class CursorProviderModels implements IProviderModels { return buildDefaultProviderCurrentActiveModel(await this.getSupportedModels()); } + + async changeActiveModel( + input: ProviderChangeActiveModelInput, + ): Promise { + return writeProviderSessionActiveModelChange('cursor', input); + } } diff --git a/server/modules/providers/list/gemini/gemini-models.provider.ts b/server/modules/providers/list/gemini/gemini-models.provider.ts index 45102058..fc830394 100644 --- a/server/modules/providers/list/gemini/gemini-models.provider.ts +++ b/server/modules/providers/list/gemini/gemini-models.provider.ts @@ -1,6 +1,14 @@ import type { IProviderModels } from '@/shared/interfaces.js'; -import type { ProviderCurrentActiveModel, ProviderModelsDefinition } from '@/shared/types.js'; -import { buildDefaultProviderCurrentActiveModel } from '@/shared/utils.js'; +import type { + ProviderChangeActiveModelInput, + ProviderCurrentActiveModel, + ProviderModelsDefinition, + ProviderSessionActiveModelChange, +} from '@/shared/types.js'; +import { + buildDefaultProviderCurrentActiveModel, + writeProviderSessionActiveModelChange, +} from '@/shared/utils.js'; export const GEMINI_FALLBACK_MODELS: ProviderModelsDefinition = { OPTIONS: [ @@ -25,4 +33,10 @@ export class GeminiProviderModels implements IProviderModels { async getCurrentActiveModel(): Promise { return buildDefaultProviderCurrentActiveModel(GEMINI_FALLBACK_MODELS); } + + async changeActiveModel( + input: ProviderChangeActiveModelInput, + ): Promise { + return writeProviderSessionActiveModelChange('gemini', input); + } } diff --git a/server/modules/providers/list/opencode/opencode-models.provider.ts b/server/modules/providers/list/opencode/opencode-models.provider.ts index 8dc75ee6..1b939256 100644 --- a/server/modules/providers/list/opencode/opencode-models.provider.ts +++ b/server/modules/providers/list/opencode/opencode-models.provider.ts @@ -5,15 +5,18 @@ import crossSpawn from 'cross-spawn'; import type { IProviderModels } from '@/shared/interfaces.js'; import type { + ProviderChangeActiveModelInput, ProviderCurrentActiveModel, ProviderModelOption, ProviderModelsDefinition, + ProviderSessionActiveModelChange, } from '@/shared/types.js'; import { buildDefaultProviderCurrentActiveModel, getOpenCodeDatabasePath, readObjectRecord, readOptionalString, + writeProviderSessionActiveModelChange, } from '@/shared/utils.js'; export const OPENCODE_FALLBACK_MODELS: ProviderModelsDefinition = { @@ -220,4 +223,10 @@ export class OpenCodeProviderModels implements IProviderModels { return buildDefaultProviderCurrentActiveModel(await this.getSupportedModels()); } + + async changeActiveModel( + input: ProviderChangeActiveModelInput, + ): Promise { + return writeProviderSessionActiveModelChange('opencode', input); + } } diff --git a/server/modules/providers/provider.routes.ts b/server/modules/providers/provider.routes.ts index 1d128882..2604fcc8 100644 --- a/server/modules/providers/provider.routes.ts +++ b/server/modules/providers/provider.routes.ts @@ -6,7 +6,13 @@ import { providerModelsService } from '@/modules/providers/services/provider-mod import { providerSkillsService } from '@/modules/providers/services/skills.service.js'; import { sessionConversationsSearchService } from '@/modules/providers/services/session-conversations-search.service.js'; import { sessionsService } from '@/modules/providers/services/sessions.service.js'; -import type { LLMProvider, McpScope, McpTransport, UpsertProviderMcpServerInput } from '@/shared/types.js'; +import type { + LLMProvider, + McpScope, + McpTransport, + ProviderChangeActiveModelInput, + UpsertProviderMcpServerInput, +} from '@/shared/types.js'; import { AppError, asyncHandler, createApiSuccessResponse } from '@/shared/utils.js'; const router = express.Router(); @@ -246,6 +252,29 @@ const parseSessionSearchLimit = (value: unknown): number => { return Math.max(1, Math.min(parsed, 100)); }; +const parseChangeActiveModelPayload = (payload: unknown): ProviderChangeActiveModelInput => { + if (!payload || typeof payload !== 'object') { + throw new AppError('Request body must be an object.', { + code: 'INVALID_REQUEST_BODY', + statusCode: 400, + }); + } + + const body = payload as Record; + const model = readOptionalQueryString(body.model); + if (!model) { + throw new AppError('model is required.', { + code: 'MODEL_REQUIRED', + statusCode: 400, + }); + } + + return { + sessionId: '', + model, + }; +}; + router.get( '/:provider/auth/status', asyncHandler(async (req: Request, res: Response) => { @@ -265,6 +294,20 @@ router.get( }), ); +router.post( + '/:provider/sessions/:sessionId/active-model', + asyncHandler(async (req: Request, res: Response) => { + const provider = parseProvider(req.params.provider); + const sessionId = parseSessionId(req.params.sessionId); + const payload = parseChangeActiveModelPayload(req.body); + const result = await providerModelsService.changeActiveModel(provider, { + ...payload, + sessionId, + }); + res.json(createApiSuccessResponse(result)); + }), +); + // ----------------- Skills routes ----------------- router.get( '/:provider/skills', diff --git a/server/modules/providers/services/provider-models.service.ts b/server/modules/providers/services/provider-models.service.ts index e0883fc6..5cb62433 100644 --- a/server/modules/providers/services/provider-models.service.ts +++ b/server/modules/providers/services/provider-models.service.ts @@ -6,11 +6,14 @@ import { providerRegistry } from '@/modules/providers/provider.registry.js'; import type { IProvider } from '@/shared/interfaces.js'; import type { LLMProvider, + ProviderChangeActiveModelInput, ProviderCurrentActiveModel, ProviderModelsCacheInfo, ProviderModelsDefinition, ProviderModelsResult, + ProviderSessionActiveModelChange, } from '@/shared/types.js'; +import { readProviderSessionActiveModelChange } from '@/shared/utils.js'; export const PROVIDER_MODELS_CACHE_TTL_MS = 3 * 24 * 60 * 60 * 1000; const PROVIDER_MODELS_CACHE_VERSION = 1; @@ -18,6 +21,7 @@ const PROVIDER_MODELS_CACHE_VERSION = 1; type ProviderModelsServiceDependencies = { resolveProvider?: (provider: LLMProvider) => Pick; cachePath?: string; + activeModelChangesPath?: string; now?: () => number; }; @@ -132,6 +136,7 @@ const writeProviderModelsCacheFile = async ( export const createProviderModelsService = (dependencies: ProviderModelsServiceDependencies = {}) => { const resolveProvider = dependencies.resolveProvider ?? providerRegistry.resolveProvider; const cachePath = dependencies.cachePath ?? getProviderModelsCachePath(); + const activeModelChangesPath = dependencies.activeModelChangesPath; const now = dependencies.now ?? (() => Date.now()); const memoryCache = new Map(); const pendingRequests = new Map>(); @@ -270,6 +275,36 @@ export const createProviderModelsService = (dependencies: ProviderModelsServiceD sessionId?: string, ): Promise => resolveProvider(provider).models.getCurrentActiveModel(sessionId); + const changeActiveModel = async ( + provider: LLMProvider, + input: ProviderChangeActiveModelInput, + ): Promise => resolveProvider(provider).models.changeActiveModel(input); + + const getChangedActiveModel = async ( + provider: LLMProvider, + sessionId: string, + ): Promise => readProviderSessionActiveModelChange(provider, sessionId, { + filePath: activeModelChangesPath, + }); + + const resolveResumeModel = async ( + provider: LLMProvider, + sessionId: string | undefined, + requestedModel?: string | null, + ): Promise => { + const normalizedRequestedModel = typeof requestedModel === 'string' ? requestedModel.trim() : ''; + if (!sessionId?.trim()) { + return normalizedRequestedModel || undefined; + } + + const changedModel = await getChangedActiveModel(provider, sessionId); + if (changedModel.supported && changedModel.changed && changedModel.model?.trim()) { + return changedModel.model.trim(); + } + + return normalizedRequestedModel || undefined; + }; + const clearCache = (): void => { memoryCache.clear(); pendingRequests.clear(); @@ -280,6 +315,9 @@ export const createProviderModelsService = (dependencies: ProviderModelsServiceD return { getProviderModels, getCurrentActiveModel, + getChangedActiveModel, + changeActiveModel, + resolveResumeModel, clearCache, }; }; diff --git a/server/modules/providers/tests/provider-models.service.test.ts b/server/modules/providers/tests/provider-models.service.test.ts index 6abc4c45..fb9ebf7d 100644 --- a/server/modules/providers/tests/provider-models.service.test.ts +++ b/server/modules/providers/tests/provider-models.service.test.ts @@ -9,10 +9,13 @@ import { PROVIDER_MODELS_CACHE_TTL_MS, } from '@/modules/providers/services/provider-models.service.js'; import type { + ProviderChangeActiveModelInput, LLMProvider, ProviderCurrentActiveModel, ProviderModelsDefinition, + ProviderSessionActiveModelChange, } from '@/shared/types.js'; +import { writeProviderSessionActiveModelChange } from '@/shared/utils.js'; const createModels = (value: string): ProviderModelsDefinition => ({ OPTIONS: [{ value, label: value }], @@ -23,6 +26,17 @@ const createCurrentActiveModel = (model: string): ProviderCurrentActiveModel => model, }); +const createSessionActiveModelChange = ( + provider: LLMProvider, + input: ProviderChangeActiveModelInput, +): ProviderSessionActiveModelChange => ({ + provider, + sessionId: input.sessionId, + supported: true, + changed: true, + model: input.model, +}); + const createEphemeralCachePath = (): string => path.join( os.tmpdir(), `provider-model-cache-${process.pid}-${Date.now()}-${Math.random().toString(16).slice(2)}.json`, @@ -38,6 +52,7 @@ test('provider models service delegates to the resolved provider model adapter', models: { getSupportedModels: async () => createModels(`${provider}-models`), getCurrentActiveModel: async () => createCurrentActiveModel(`${provider}-active`), + changeActiveModel: async (input) => createSessionActiveModelChange(provider, input), }, }; }, @@ -65,6 +80,7 @@ test('provider models service returns each provider adapter result without rewri models: { getSupportedModels: async () => expectedModels, getCurrentActiveModel: async () => createCurrentActiveModel('cursor-active'), + changeActiveModel: async (input) => createSessionActiveModelChange('cursor', input), }, }), }); @@ -90,6 +106,7 @@ test('provider models are cached for the three-day ttl', async () => { return createModels(`${provider}-${loadCount}`); }, getCurrentActiveModel: async () => createCurrentActiveModel(`${provider}-active`), + changeActiveModel: async (input) => createSessionActiveModelChange(provider, input), }, }), }); @@ -124,6 +141,7 @@ test('provider model cache is persisted across service instances', async () => { models: { getSupportedModels: async () => createModels('gemini-cached'), getCurrentActiveModel: async () => createCurrentActiveModel('gemini-active'), + changeActiveModel: async (input) => createSessionActiveModelChange('gemini', input), }, }), }); @@ -137,6 +155,7 @@ test('provider model cache is persisted across service instances', async () => { throw new Error('loader should not be called for persisted cache hits'); }, getCurrentActiveModel: async () => createCurrentActiveModel('gemini-active'), + changeActiveModel: async (input) => createSessionActiveModelChange('gemini', input), }, }), }); @@ -163,6 +182,7 @@ test('concurrent provider model requests share one load operation', async () => return createModels('claude-cached'); }, getCurrentActiveModel: async () => createCurrentActiveModel('claude-active'), + changeActiveModel: async (input) => createSessionActiveModelChange('claude', input), }, }), }); @@ -196,6 +216,7 @@ test('bypassCache forces a fresh provider fetch and updates cache metadata', asy return createModels(`${provider}-${loadCount}`); }, getCurrentActiveModel: async () => createCurrentActiveModel(`${provider}-active-${loadCount}`), + changeActiveModel: async (input) => createSessionActiveModelChange(provider, input), }, }), }); @@ -224,6 +245,7 @@ test('provider models service delegates current active model lookups to the prov calls.push({ provider, sessionId }); return createCurrentActiveModel(`${provider}-${sessionId}`); }, + changeActiveModel: async (input) => createSessionActiveModelChange(provider, input), }, }), }); @@ -233,3 +255,64 @@ test('provider models service delegates current active model lookups to the prov assert.deepEqual(calls, [{ provider: 'opencode', sessionId: 'session-123' }]); assert.equal(activeModel.model, 'opencode-session-123'); }); + +test('provider models service delegates active model change requests to the provider adapter', async () => { + const calls: Array<{ provider: LLMProvider; input: ProviderChangeActiveModelInput }> = []; + const service = createProviderModelsService({ + resolveProvider: (provider) => ({ + models: { + getSupportedModels: async () => createModels(`${provider}-models`), + getCurrentActiveModel: async () => createCurrentActiveModel(`${provider}-active`), + changeActiveModel: async (input) => { + calls.push({ provider, input }); + return createSessionActiveModelChange(provider, input); + }, + }, + }), + }); + + const changedModel = await service.changeActiveModel('claude', { + sessionId: 'session-123', + model: 'opus', + }); + + assert.deepEqual(calls, [{ + provider: 'claude', + input: { + sessionId: 'session-123', + model: 'opus', + }, + }]); + assert.equal(changedModel.changed, true); + assert.equal(changedModel.model, 'opus'); +}); + +test('resolveResumeModel prefers a stored changed model over the requested one', async () => { + const tempRoot = await mkdtemp(path.join(os.tmpdir(), 'provider-model-change-')); + const activeModelChangesPath = path.join(tempRoot, 'session-model-changes.json'); + + try { + const service = createProviderModelsService({ + activeModelChangesPath, + resolveProvider: (provider) => ({ + models: { + getSupportedModels: async () => createModels(`${provider}-models`), + getCurrentActiveModel: async () => createCurrentActiveModel(`${provider}-active`), + changeActiveModel: async (input) => createSessionActiveModelChange(provider, input), + }, + }), + }); + + await writeProviderSessionActiveModelChange('cursor', { + sessionId: 'session-456', + model: 'composer-2', + }, { + filePath: activeModelChangesPath, + }); + + const model = await service.resolveResumeModel('cursor', 'session-456', 'composer-2-fast'); + assert.equal(model, 'composer-2'); + } finally { + await rm(tempRoot, { recursive: true, force: true }); + } +}); diff --git a/server/openai-codex.js b/server/openai-codex.js index 03497c30..95913586 100644 --- a/server/openai-codex.js +++ b/server/openai-codex.js @@ -17,6 +17,7 @@ import { Codex } from '@openai/codex-sdk'; import { notifyRunFailed, notifyRunStopped } from './services/notification-orchestrator.js'; import { sessionsService } from './modules/providers/services/sessions.service.js'; import { providerAuthService } from './modules/providers/services/provider-auth.service.js'; +import { providerModelsService } from './modules/providers/services/provider-models.service.js'; import { createNormalizedMessage } from './shared/utils.js'; // Track active sessions @@ -202,6 +203,12 @@ export async function queryCodex(command, options = {}, ws) { permissionMode = 'default' } = options; + const resolvedModel = await providerModelsService.resolveResumeModel( + 'codex', + sessionId, + model, + ); + const workingDirectory = cwd || projectPath || process.cwd(); const { sandboxMode, approvalPolicy } = mapPermissionModeToCodexOptions(permissionMode); @@ -222,7 +229,7 @@ export async function queryCodex(command, options = {}, ws) { skipGitRepoCheck: true, sandboxMode, approvalPolicy, - model + model: resolvedModel }; // Start or resume thread diff --git a/server/opencode-cli.js b/server/opencode-cli.js index 9d3c3851..ceabc082 100644 --- a/server/opencode-cli.js +++ b/server/opencode-cli.js @@ -4,6 +4,7 @@ import crossSpawn from 'cross-spawn'; import { sessionsService } from './modules/providers/services/sessions.service.js'; import { providerAuthService } from './modules/providers/services/provider-auth.service.js'; +import { providerModelsService } from './modules/providers/services/provider-models.service.js'; import { notifyRunFailed, notifyRunStopped } from './services/notification-orchestrator.js'; import { createNormalizedMessage } from './shared/utils.js'; @@ -29,17 +30,6 @@ async function spawnOpenCode(command, options = {}, ws) { let stdoutLineBuffer = ''; let terminalNotificationSent = false; - const args = ['run', '--format', 'json']; - if (sessionId) { - args.push('--session', sessionId); - } - if (model) { - args.push('--model', model); - } - if (command && command.trim()) { - args.push(command.trim()); - } - const notifyTerminalState = ({ code = null, error = null } = {}) => { if (terminalNotificationSent) { return; @@ -67,16 +57,6 @@ async function spawnOpenCode(command, options = {}, ws) { }); }; - const opencodeProcess = spawnFunction('opencode', args, { - cwd: workingDir, - stdio: ['pipe', 'pipe', 'pipe'], - env: { ...process.env }, - }); - - activeOpenCodeProcesses.set(processKey, opencodeProcess); - opencodeProcess.sessionId = processKey; - opencodeProcess.stdin.end(); - const registerSession = (nextSessionId) => { if (!nextSessionId || capturedSessionId === nextSessionId) { return; @@ -130,89 +110,112 @@ async function spawnOpenCode(command, options = {}, ws) { } }; - opencodeProcess.stdout.on('data', (data) => { - stdoutLineBuffer += data.toString(); - const completeLines = stdoutLineBuffer.split(/\r?\n/); - stdoutLineBuffer = completeLines.pop() || ''; + void providerModelsService.resolveResumeModel('opencode', sessionId, model).then((resolvedModel) => { + const args = ['run', '--format', 'json']; + if (sessionId) { + args.push('--session', sessionId); + } + if (resolvedModel) { + args.push('--model', resolvedModel); + } + if (command && command.trim()) { + args.push(command.trim()); + } - completeLines.forEach((line) => { - processOpenCodeOutputLine(line.trim()); + const opencodeProcess = spawnFunction('opencode', args, { + cwd: workingDir, + stdio: ['pipe', 'pipe', 'pipe'], + env: { ...process.env }, }); - }); - opencodeProcess.stderr.on('data', (data) => { - const stderrText = data.toString(); - if (!stderrText.trim()) { - return; - } + activeOpenCodeProcesses.set(processKey, opencodeProcess); + opencodeProcess.sessionId = processKey; + opencodeProcess.stdin.end(); - ws.send(createNormalizedMessage({ - kind: 'error', - content: stderrText, - sessionId: capturedSessionId || sessionId || null, - provider: 'opencode', - })); - }); + opencodeProcess.stdout.on('data', (data) => { + stdoutLineBuffer += data.toString(); + const completeLines = stdoutLineBuffer.split(/\r?\n/); + stdoutLineBuffer = completeLines.pop() || ''; - opencodeProcess.on('close', async (code) => { - const finalSessionId = capturedSessionId || sessionId || processKey; - activeOpenCodeProcesses.delete(finalSessionId); - activeOpenCodeProcesses.delete(processKey); + completeLines.forEach((line) => { + processOpenCodeOutputLine(line.trim()); + }); + }); - if (stdoutLineBuffer.trim()) { - processOpenCodeOutputLine(stdoutLineBuffer.trim()); - stdoutLineBuffer = ''; - } - - ws.send(createNormalizedMessage({ - kind: 'complete', - exitCode: code, - isNewSession: !sessionId && !!command, - sessionId: finalSessionId, - provider: 'opencode', - })); - - if (code === 0) { - notifyTerminalState({ code }); - resolve(); - return; - } - - if (code === 127 || code === null) { - const installed = await providerAuthService.isProviderInstalled('opencode'); - if (!installed) { - ws.send(createNormalizedMessage({ - kind: 'error', - content: 'OpenCode CLI is not installed. Install it from https://opencode.ai/docs/', - sessionId: finalSessionId, - provider: 'opencode', - })); + opencodeProcess.stderr.on('data', (data) => { + const stderrText = data.toString(); + if (!stderrText.trim()) { + return; } - } - notifyTerminalState({ code }); - reject(new Error(code === null ? 'OpenCode CLI process was terminated' : `OpenCode CLI exited with code ${code}`)); - }); + ws.send(createNormalizedMessage({ + kind: 'error', + content: stderrText, + sessionId: capturedSessionId || sessionId || null, + provider: 'opencode', + })); + }); - opencodeProcess.on('error', async (error) => { - const finalSessionId = capturedSessionId || sessionId || processKey; - activeOpenCodeProcesses.delete(finalSessionId); - activeOpenCodeProcesses.delete(processKey); + opencodeProcess.on('close', async (code) => { + const finalSessionId = capturedSessionId || sessionId || processKey; + activeOpenCodeProcesses.delete(finalSessionId); + activeOpenCodeProcesses.delete(processKey); - const installed = await providerAuthService.isProviderInstalled('opencode'); - const errorContent = !installed - ? 'OpenCode CLI is not installed. Install it from https://opencode.ai/docs/' - : error.message; + if (stdoutLineBuffer.trim()) { + processOpenCodeOutputLine(stdoutLineBuffer.trim()); + stdoutLineBuffer = ''; + } - ws.send(createNormalizedMessage({ - kind: 'error', - content: errorContent, - sessionId: finalSessionId, - provider: 'opencode', - })); - notifyTerminalState({ error }); - reject(error); - }); + ws.send(createNormalizedMessage({ + kind: 'complete', + exitCode: code, + isNewSession: !sessionId && !!command, + sessionId: finalSessionId, + provider: 'opencode', + })); + + if (code === 0) { + notifyTerminalState({ code }); + resolve(); + return; + } + + if (code === 127 || code === null) { + const installed = await providerAuthService.isProviderInstalled('opencode'); + if (!installed) { + ws.send(createNormalizedMessage({ + kind: 'error', + content: 'OpenCode CLI is not installed. Install it from https://opencode.ai/docs/', + sessionId: finalSessionId, + provider: 'opencode', + })); + } + } + + notifyTerminalState({ code }); + reject(new Error(code === null ? 'OpenCode CLI process was terminated' : `OpenCode CLI exited with code ${code}`)); + }); + + opencodeProcess.on('error', async (error) => { + const finalSessionId = capturedSessionId || sessionId || processKey; + activeOpenCodeProcesses.delete(finalSessionId); + activeOpenCodeProcesses.delete(processKey); + + const installed = await providerAuthService.isProviderInstalled('opencode'); + const errorContent = !installed + ? 'OpenCode CLI is not installed. Install it from https://opencode.ai/docs/' + : error.message; + + ws.send(createNormalizedMessage({ + kind: 'error', + content: errorContent, + sessionId: finalSessionId, + provider: 'opencode', + })); + notifyTerminalState({ error }); + reject(error); + }); + }).catch(reject); }); } diff --git a/server/shared/interfaces.ts b/server/shared/interfaces.ts index 9228698e..1864fd02 100644 --- a/server/shared/interfaces.ts +++ b/server/shared/interfaces.ts @@ -7,9 +7,11 @@ import type { ProviderSkill, ProviderSkillListOptions, ProviderAuthStatus, + ProviderChangeActiveModelInput, ProviderCurrentActiveModel, ProviderModelsDefinition, ProviderMcpServer, + ProviderSessionActiveModelChange, UpsertProviderMcpServerInput, } from '@/shared/types.js'; @@ -55,6 +57,19 @@ export interface IProviderModels { * no active model can be resolved. */ getCurrentActiveModel(sessionId?: string): Promise; + + /** + * Persists a session-scoped model override that the next resumed turn should + * honor for this provider. + * + * This does not require the provider to mutate an already running remote + * session in-place. Instead, adapters store the user's explicit model choice + * so the backend resume path can add the correct provider-native model option + * on the next CLI/SDK invocation for the same session. + */ + changeActiveModel( + input: ProviderChangeActiveModelInput, + ): Promise; } // --------------------------- diff --git a/server/shared/types.ts b/server/shared/types.ts index 121cc509..7d981988 100644 --- a/server/shared/types.ts +++ b/server/shared/types.ts @@ -123,6 +123,37 @@ export type ProviderCurrentActiveModel = { model: string; }; +/** + * Input payload used when one session needs to use a different model on its + * next resumed turn. + * + * This is a backend-owned session override, not a claim that the provider has + * already switched the currently running session in-place. Provider adapters + * persist this request so the next CLI/SDK resume can inject the chosen model + * using the provider-specific mechanism supported by that runtime. + */ +export type ProviderChangeActiveModelInput = { + sessionId: string; + model: string; +}; + +/** + * Provider-neutral session model-change state. + * + * `supported` indicates whether the provider adapter supports the app's + * session-scoped resume override flow. `changed` is the persisted boolean the + * resume layer checks before forcing a model on the next resumed turn. When + * `changed` is `false`, `model` is `null` and the runtime should use the + * normal request/default model selection path. + */ +export type ProviderSessionActiveModelChange = { + provider: LLMProvider; + sessionId: string; + supported: boolean; + changed: boolean; + model: string | null; +}; + /** * Message/event variants emitted by provider adapters and normalized transports. * diff --git a/server/shared/utils.ts b/server/shared/utils.ts index 437ea034..821489c5 100644 --- a/server/shared/utils.ts +++ b/server/shared/utils.ts @@ -22,9 +22,12 @@ import type { AnyRecord, ApiSuccessShape, AppErrorOptions, + LLMProvider, NormalizedMessage, + ProviderChangeActiveModelInput, ProviderCurrentActiveModel, ProviderModelsDefinition, + ProviderSessionActiveModelChange, ProviderSkillSource, WorkspacePathValidationResult, } from '@/shared/types.js'; @@ -434,6 +437,213 @@ export function buildDefaultProviderCurrentActiveModel( }; } +// --------------------------- +//----------------- PROVIDER SESSION MODEL CHANGE UTILITIES ------------ +type ProviderSessionActiveModelChangeCacheEntry = ProviderSessionActiveModelChange & { + updatedAt: string; +}; + +type ProviderSessionActiveModelChangeCacheFile = { + version: number; + entries: Record; +}; + +const PROVIDER_SESSION_ACTIVE_MODEL_CHANGE_CACHE_VERSION = 1; + +/** + * Resolves the backend-owned cache file used for session-scoped resume model + * overrides. + * + * The file lives under `~/.cloudcli` because these overrides are an application + * concern rather than a provider-native config file. Providers, routes, and + * runtime command launchers should all use this helper instead of re-creating + * the path so the storage location stays consistent. + */ +export function getProviderSessionActiveModelChangesPath(): string { + return path.join(os.homedir(), '.cloudcli', 'provider-session-active-model-changes.json'); +} + +const buildProviderSessionActiveModelChangeKey = ( + provider: LLMProvider, + sessionId: string, +): string => `${provider}:${sessionId}`; + +const isProviderSessionActiveModelChangeCacheEntry = ( + value: unknown, +): value is ProviderSessionActiveModelChangeCacheEntry => { + const record = readObjectRecord(value); + return Boolean( + record + && typeof record.provider === 'string' + && typeof record.sessionId === 'string' + && typeof record.supported === 'boolean' + && typeof record.changed === 'boolean' + && (typeof record.model === 'string' || record.model === null) + && typeof record.updatedAt === 'string', + ); +}; + +const readProviderSessionActiveModelChangeCacheFile = async ( + filePath: string, +): Promise => { + try { + const raw = await readFile(filePath, 'utf8'); + const parsed = readObjectRecord(JSON.parse(raw)); + if ( + !parsed + || parsed.version !== PROVIDER_SESSION_ACTIVE_MODEL_CHANGE_CACHE_VERSION + || !readObjectRecord(parsed.entries) + ) { + return { + version: PROVIDER_SESSION_ACTIVE_MODEL_CHANGE_CACHE_VERSION, + entries: {}, + }; + } + + const entries = Object.fromEntries( + Object.entries(parsed.entries).filter((entry): entry is [string, ProviderSessionActiveModelChangeCacheEntry] => + isProviderSessionActiveModelChangeCacheEntry(entry[1]), + ), + ); + + return { + version: PROVIDER_SESSION_ACTIVE_MODEL_CHANGE_CACHE_VERSION, + entries, + }; + } catch { + return { + version: PROVIDER_SESSION_ACTIVE_MODEL_CHANGE_CACHE_VERSION, + entries: {}, + }; + } +}; + +const writeProviderSessionActiveModelChangeCacheFile = async ( + filePath: string, + payload: ProviderSessionActiveModelChangeCacheFile, +): Promise => { + await mkdir(path.dirname(filePath), { recursive: true }); + await writeFile(filePath, `${JSON.stringify(payload, null, 2)}\n`, 'utf8'); +}; + +const buildUnsupportedProviderSessionActiveModelChange = ( + provider: LLMProvider, + sessionId: string, +): ProviderSessionActiveModelChange => ({ + provider, + sessionId, + supported: false, + changed: false, + model: null, +}); + +/** + * Reads the persisted session model-change state for one provider session. + * + * Runtime resume paths use this to decide whether they should inject a + * provider-specific model argument/thread option for the next resumed turn. + * Missing cache entries are normalized to `{ changed: false }` so callers can + * treat absence as "use the ordinary model selection flow". + */ +export async function readProviderSessionActiveModelChange( + provider: LLMProvider, + sessionId: string, + options: { + filePath?: string; + supported?: boolean; + } = {}, +): Promise { + const normalizedSessionId = sessionId.trim(); + if (!normalizedSessionId) { + return buildUnsupportedProviderSessionActiveModelChange(provider, normalizedSessionId); + } + + const supported = options.supported ?? true; + if (!supported) { + return buildUnsupportedProviderSessionActiveModelChange(provider, normalizedSessionId); + } + + const filePath = options.filePath ?? getProviderSessionActiveModelChangesPath(); + const cacheFile = await readProviderSessionActiveModelChangeCacheFile(filePath); + const cacheEntry = cacheFile.entries[ + buildProviderSessionActiveModelChangeKey(provider, normalizedSessionId) + ]; + + if (!cacheEntry || !cacheEntry.changed || !cacheEntry.model?.trim()) { + return { + provider, + sessionId: normalizedSessionId, + supported: true, + changed: false, + model: null, + }; + } + + return { + provider, + sessionId: normalizedSessionId, + supported: true, + changed: true, + model: cacheEntry.model.trim(), + }; +} + +/** + * Persists a session model-change request for one provider. + * + * Provider adapters call this when the frontend explicitly selects a different + * model for an existing session. The stored `changed: true` flag is the single + * source of truth used later by resume paths to decide whether they should add + * a provider-native model override on the next invocation. + */ +export async function writeProviderSessionActiveModelChange( + provider: LLMProvider, + input: ProviderChangeActiveModelInput, + options: { + filePath?: string; + supported?: boolean; + } = {}, +): Promise { + const normalizedSessionId = input.sessionId.trim(); + const normalizedModel = input.model.trim(); + const supported = options.supported ?? true; + + if (!supported) { + return buildUnsupportedProviderSessionActiveModelChange(provider, normalizedSessionId); + } + + if (!normalizedSessionId || !normalizedModel) { + return { + provider, + sessionId: normalizedSessionId, + supported: true, + changed: false, + model: null, + }; + } + + const filePath = options.filePath ?? getProviderSessionActiveModelChangesPath(); + const cacheFile = await readProviderSessionActiveModelChangeCacheFile(filePath); + cacheFile.entries[buildProviderSessionActiveModelChangeKey(provider, normalizedSessionId)] = { + provider, + sessionId: normalizedSessionId, + supported: true, + changed: true, + model: normalizedModel, + updatedAt: new Date().toISOString(), + }; + + await writeProviderSessionActiveModelChangeCacheFile(filePath, cacheFile); + + return { + provider, + sessionId: normalizedSessionId, + supported: true, + changed: true, + model: normalizedModel, + }; +} + // --------------------------- //----------------- WEBSOCKET PAYLOAD PARSING UTILITIES ------------ /** diff --git a/src/components/chat/hooks/useChatProviderState.ts b/src/components/chat/hooks/useChatProviderState.ts index e8f0ed5c..33c54d4b 100644 --- a/src/components/chat/hooks/useChatProviderState.ts +++ b/src/components/chat/hooks/useChatProviderState.ts @@ -43,6 +43,17 @@ type ProviderModelsApiResponse = { }; }; +type ChangeActiveModelApiResponse = { + success?: boolean; + data?: { + provider?: LLMProvider; + sessionId?: string; + supported?: boolean; + changed?: boolean; + model?: string | null; + }; +}; + export function useChatProviderState({ selectedSession, selectedProject }: UseChatProviderStateArgs) { const [permissionMode, setPermissionMode] = useState('default'); const [pendingPermissionRequests, setPendingPermissionRequests] = useState([]); @@ -77,6 +88,35 @@ export function useChatProviderState({ selectedSession, selectedProject }: UseCh const lastProviderRef = useRef(provider); const providerModelsRequestIdRef = useRef(0); + const setStoredProviderModel = useCallback((targetProvider: LLMProvider, model: string) => { + if (targetProvider === 'claude') { + setClaudeModel(model); + localStorage.setItem('claude-model', model); + return; + } + + if (targetProvider === 'cursor') { + setCursorModel(model); + localStorage.setItem('cursor-model', model); + return; + } + + if (targetProvider === 'codex') { + setCodexModel(model); + localStorage.setItem('codex-model', model); + return; + } + + if (targetProvider === 'gemini') { + setGeminiModel(model); + localStorage.setItem('gemini-model', model); + return; + } + + setOpenCodeModel(model); + localStorage.setItem('opencode-model', model); + }, []); + const loadProviderModels = useCallback(async (options: { bypassCache?: boolean } = {}) => { const providers: LLMProvider[] = ['claude', 'cursor', 'codex', 'gemini', 'opencode']; const requestId = providerModelsRequestIdRef.current + 1; @@ -289,6 +329,41 @@ export function useChatProviderState({ selectedSession, selectedProject }: UseCh } }, [permissionMode, provider, selectedSession?.id]); + const selectProviderModel = useCallback(async ( + targetProvider: LLMProvider, + model: string, + sessionId?: string | null, + ) => { + const normalizedSessionId = typeof sessionId === 'string' ? sessionId.trim() : ''; + if (!normalizedSessionId) { + setStoredProviderModel(targetProvider, model); + return { + scope: 'default' as const, + changed: false, + model, + }; + } + + const response = await authenticatedFetch( + `/api/providers/${targetProvider}/sessions/${encodeURIComponent(normalizedSessionId)}/active-model`, + { + method: 'POST', + body: JSON.stringify({ model }), + }, + ); + + const body = (await response.json()) as ChangeActiveModelApiResponse; + if (!response.ok || !body.success || !body.data?.supported) { + throw new Error('Unable to change the active model for this session.'); + } + + return { + scope: 'session' as const, + changed: body.data.changed === true, + model: body.data.model || model, + }; + }, [setStoredProviderModel]); + return { provider, setProvider, @@ -312,5 +387,6 @@ export function useChatProviderState({ selectedSession, selectedProject }: UseCh providerModelsLoading, providerModelsRefreshing, hardRefreshProviderModels: () => loadProviderModels({ bypassCache: true }), + selectProviderModel, }; } diff --git a/src/components/chat/view/ChatInterface.tsx b/src/components/chat/view/ChatInterface.tsx index c724bc9c..63e0cbb5 100644 --- a/src/components/chat/view/ChatInterface.tsx +++ b/src/components/chat/view/ChatInterface.tsx @@ -83,6 +83,7 @@ function ChatInterface({ providerModelsLoading, providerModelsRefreshing, hardRefreshProviderModels, + selectProviderModel, } = useChatProviderState({ selectedSession, selectedProject, @@ -441,6 +442,8 @@ function ChatInterface({ providerModelCacheCatalog={providerModelCacheCatalog} providerModelsRefreshing={providerModelsRefreshing} onHardRefreshProviderModels={hardRefreshProviderModels} + currentSessionId={currentSessionId || selectedSession?.id || null} + onSelectProviderModel={selectProviderModel} /> ); diff --git a/src/components/chat/view/subcomponents/CommandResultModal.tsx b/src/components/chat/view/subcomponents/CommandResultModal.tsx index f0163bba..8faa677b 100644 --- a/src/components/chat/view/subcomponents/CommandResultModal.tsx +++ b/src/components/chat/view/subcomponents/CommandResultModal.tsx @@ -38,6 +38,16 @@ type CommandResultModalProps = { providerModelCacheCatalog: Partial>; providerModelsRefreshing: boolean; onHardRefreshProviderModels: () => void; + currentSessionId: string | null; + onSelectProviderModel: ( + provider: LLMProvider, + model: string, + sessionId?: string | null, + ) => Promise<{ + scope: 'default' | 'session'; + changed: boolean; + model: string; + }>; }; type CommandEntry = { @@ -254,15 +264,22 @@ function ModelsContent({ providerModelCacheCatalog, providerModelsRefreshing, onHardRefreshProviderModels, + currentSessionId, + onSelectProviderModel, }: { data: ModelCommandData; providerModelCatalog: Partial>; providerModelCacheCatalog: Partial>; providerModelsRefreshing: boolean; onHardRefreshProviderModels: () => void; + currentSessionId: string | null; + onSelectProviderModel: CommandResultModalProps['onSelectProviderModel']; }) { const [query, setQuery] = useState(''); const [copiedModel, setCopiedModel] = useState(null); + const [changingModel, setChangingModel] = useState(null); + const [pendingSessionModel, setPendingSessionModel] = useState(null); + const [selectionNotice, setSelectionNotice] = useState(null); const currentProvider = (data?.current?.provider || 'claude') as LLMProvider; const currentModel = data?.current?.model || 'Unknown'; const providerLabel = data?.current?.providerLabel || getProviderLabel(currentProvider); @@ -295,6 +312,7 @@ function ModelsContent({ }, [availableOptions, query]); const activeOption = availableOptions.find((option) => option.value === currentModel); + const hasConcreteSessionId = typeof currentSessionId === 'string' && currentSessionId.trim().length > 0; const copyModel = (model: string) => { if (typeof navigator !== 'undefined' && navigator.clipboard) { @@ -306,6 +324,26 @@ function ModelsContent({ }, 1300); }; + const handleSelectModel = async (model: string) => { + setChangingModel(model); + try { + const result = await onSelectProviderModel(currentProvider, model, currentSessionId); + if (result.scope === 'session') { + setPendingSessionModel(result.model); + setSelectionNotice(`Next response will resume with ${result.model}.`); + return; + } + + setPendingSessionModel(null); + setSelectionNotice(`Default ${providerLabel} model set to ${result.model}.`); + } catch (error) { + const message = error instanceof Error ? error.message : 'Unable to change the model right now.'; + setSelectionNotice(message); + } finally { + setChangingModel(null); + } + }; + return (
@@ -331,6 +369,13 @@ function ModelsContent({
+
+ {hasConcreteSessionId + ? 'Selecting a model stores a session override and applies it on the next response for this session.' + : 'Selecting a model updates the default model used for new turns in this provider.'} + {selectionNotice && {selectionNotice}} +
+
@@ -344,6 +389,11 @@ function ModelsContent({ {activeOption?.description && (

{activeOption.description}

)} + {pendingSessionModel && pendingSessionModel !== currentModel && ( +

+ Next response: {pendingSessionModel} +

+ )}
Live
@@ -367,20 +417,27 @@ function ModelsContent({ {filteredOptions.map((option, index) => { const isCurrent = option.value === currentModel; const wasCopied = copiedModel === option.value; + const isPendingSelection = option.value === pendingSessionModel; + const isChanging = option.value === changingModel; return ( - + + +
); })}
@@ -521,6 +593,8 @@ export default function CommandResultModal({ providerModelCacheCatalog, providerModelsRefreshing, onHardRefreshProviderModels, + currentSessionId, + onSelectProviderModel, }: CommandResultModalProps) { const isOpen = Boolean(payload); const kind = payload?.kind; @@ -613,6 +687,8 @@ export default function CommandResultModal({ providerModelCacheCatalog={providerModelCacheCatalog} providerModelsRefreshing={providerModelsRefreshing} onHardRefreshProviderModels={onHardRefreshProviderModels} + currentSessionId={currentSessionId} + onSelectProviderModel={onSelectProviderModel} /> )} {payload?.kind === 'cost' && }