diff --git a/server/modules/providers/list/claude/claude-models.provider.ts b/server/modules/providers/list/claude/claude-models.provider.ts index ca8e8eb8..b0d845a3 100644 --- a/server/modules/providers/list/claude/claude-models.provider.ts +++ b/server/modules/providers/list/claude/claude-models.provider.ts @@ -1,8 +1,16 @@ +import { spawn } from 'node:child_process'; + import { query, type ModelInfo, type Options } from '@anthropic-ai/claude-agent-sdk'; +import crossSpawn from 'cross-spawn'; import { resolveClaudeCodeExecutablePath } from '@/shared/claude-cli-path.js'; import type { IProviderModels } from '@/shared/interfaces.js'; -import type { ProviderModelOption, ProviderModelsDefinition } from '@/shared/types.js'; +import type { + ProviderCurrentActiveModel, + ProviderModelOption, + ProviderModelsDefinition, +} from '@/shared/types.js'; +import { buildDefaultProviderCurrentActiveModel } from '@/shared/utils.js'; export const CLAUDE_FALLBACK_MODELS: ProviderModelsDefinition = { OPTIONS: [ @@ -17,6 +25,14 @@ export const CLAUDE_FALLBACK_MODELS: ProviderModelsDefinition = { }; type ClaudeModelQueryOptions = Pick; +type ClaudeInitEvent = { + type?: string; + subtype?: string; + model?: string; +}; + +const CLAUDE_ACTIVE_MODEL_TIMEOUT_MS = 20_000; +const claudeSpawn = process.platform === 'win32' ? crossSpawn : spawn; const buildClaudeQueryOptions = (): ClaudeModelQueryOptions => ({ env: { ...process.env }, @@ -58,6 +74,84 @@ const buildClaudeModelsDefinition = (models: ModelInfo[]): ProviderModelsDefinit }; }; +const runClaudeSessionModelCommand = async (sessionId: string): Promise => { + const cliPath = resolveClaudeCodeExecutablePath(process.env.CLAUDE_CLI_PATH); + + return new Promise((resolve, reject) => { + const child = claudeSpawn( + cliPath, + ['-p', '--verbose', '--output-format', 'stream-json', '--resume', sessionId, 'ok'], + { + env: { ...process.env }, + windowsHide: true, + }, + ); + + let stdout = ''; + let stderr = ''; + let settled = false; + + const timer = setTimeout(() => { + child.kill('SIGTERM'); + if (!settled) { + settled = true; + reject(new Error('Claude current-model lookup timed out')); + } + }, CLAUDE_ACTIVE_MODEL_TIMEOUT_MS); + + const finish = (error: Error | null, result: ProviderCurrentActiveModel | null) => { + if (settled) { + return; + } + + settled = true; + clearTimeout(timer); + + if (error) { + reject(error); + return; + } + + resolve(result); + }; + + child.stdout?.on('data', (chunk: Buffer) => { + stdout += chunk.toString(); + }); + + child.stderr?.on('data', (chunk: Buffer) => { + stderr += chunk.toString(); + }); + + child.on('error', (error) => { + finish(error instanceof Error ? error : new Error(String(error)), null); + }); + + child.on('close', () => { + const lines = `${stdout}\n${stderr}` + .split(/\r?\n/) + .map((line) => line.trim()) + .filter(Boolean); + + for (const line of lines) { + try { + const event = JSON.parse(line) as ClaudeInitEvent; + if (event.type === 'system' && event.subtype === 'init' && event.model) { + finish(null, { + model: event.model, + }); + return; + } + } catch { + // The Claude CLI mixes non-JSON lines into verbose output; ignore them. + } + } + + finish(null, null); + }); + }); +}; + export class ClaudeProviderModels implements IProviderModels { async getSupportedModels(): Promise { let queryInstance: ReturnType | null = null; @@ -80,4 +174,21 @@ export class ClaudeProviderModels implements IProviderModels { queryInstance?.close(); } } + + async getCurrentActiveModel(sessionId?: string): Promise { + if (!sessionId?.trim()) { + return buildDefaultProviderCurrentActiveModel(await this.getSupportedModels()); + } + + try { + const activeModel = await runClaudeSessionModelCommand(sessionId); + if (activeModel?.model) { + return activeModel; + } + } catch { + // Fall through to the provider default when the session-backed lookup fails. + } + + return buildDefaultProviderCurrentActiveModel(await this.getSupportedModels()); + } } diff --git a/server/modules/providers/list/codex/codex-models.provider.ts b/server/modules/providers/list/codex/codex-models.provider.ts index e4dcdf09..75a767a3 100644 --- a/server/modules/providers/list/codex/codex-models.provider.ts +++ b/server/modules/providers/list/codex/codex-models.provider.ts @@ -2,9 +2,19 @@ import { readFile } from 'node:fs/promises'; import os from 'node:os'; import path from 'node:path'; +import TOML from '@iarna/toml'; + import type { IProviderModels } from '@/shared/interfaces.js'; -import type { ProviderModelOption, ProviderModelsDefinition } from '@/shared/types.js'; -import { readObjectRecord, readOptionalString } from '@/shared/utils.js'; +import type { + ProviderCurrentActiveModel, + ProviderModelOption, + ProviderModelsDefinition, +} from '@/shared/types.js'; +import { + buildDefaultProviderCurrentActiveModel, + readObjectRecord, + readOptionalString, +} from '@/shared/utils.js'; export const CODEX_FALLBACK_MODELS: ProviderModelsDefinition = { OPTIONS: [ @@ -27,6 +37,7 @@ type CodexCachedModel = { }; const CODEX_MODELS_CACHE_PATH = path.join(os.homedir(), '.codex', 'models_cache.json'); +const CODEX_CONFIG_PATH = path.join(os.homedir(), '.codex', 'config.toml'); const isCodexCachedModel = (value: unknown): value is CodexCachedModel => { const record = readObjectRecord(value); @@ -85,4 +96,21 @@ export class CodexProviderModels implements IProviderModels { return CODEX_FALLBACK_MODELS; } } + + async getCurrentActiveModel(): Promise { + try { + const raw = await readFile(CODEX_CONFIG_PATH, 'utf8'); + const parsed = readObjectRecord(TOML.parse(raw)); + const model = readOptionalString(parsed?.model); + if (!model) { + return buildDefaultProviderCurrentActiveModel(await this.getSupportedModels()); + } + + return { + model, + }; + } catch { + return buildDefaultProviderCurrentActiveModel(await this.getSupportedModels()); + } + } } diff --git a/server/modules/providers/list/cursor/cursor-models.provider.ts b/server/modules/providers/list/cursor/cursor-models.provider.ts index a5ed4f47..9290e870 100644 --- a/server/modules/providers/list/cursor/cursor-models.provider.ts +++ b/server/modules/providers/list/cursor/cursor-models.provider.ts @@ -1,9 +1,20 @@ +import { access, readdir } from 'node:fs/promises'; +import os from 'node:os'; +import path from 'node:path'; import { spawn } from 'node:child_process'; import crossSpawn from 'cross-spawn'; import type { IProviderModels } from '@/shared/interfaces.js'; -import type { ProviderModelOption, ProviderModelsDefinition } from '@/shared/types.js'; +import type { + ProviderCurrentActiveModel, + ProviderModelOption, + ProviderModelsDefinition, +} from '@/shared/types.js'; +import { + buildDefaultProviderCurrentActiveModel, + sanitizeLeafDirectoryName, +} from '@/shared/utils.js'; export const CURSOR_FALLBACK_MODELS: ProviderModelsDefinition = { OPTIONS: [ @@ -24,6 +35,7 @@ type CursorModelRow = { }; const CURSOR_MODELS_TIMEOUT_MS = 10_000; +const CURSOR_CHATS_ROOT = path.join(os.homedir(), '.cursor', 'chats'); const spawnFunction = process.platform === 'win32' ? crossSpawn : spawn; const ANSI_PATTERN = new RegExp( // eslint-disable-next-line no-control-regex @@ -167,6 +179,31 @@ const buildCursorModelsDefinition = (models: CursorModelRow[]): ProviderModelsDe }; }; +const resolveCursorSessionStorePath = async (sessionId: string): Promise => { + const safeSessionId = sanitizeLeafDirectoryName(sessionId, 'cursor session id'); + + try { + const workspaceEntries = await readdir(CURSOR_CHATS_ROOT, { withFileTypes: true }); + for (const workspaceEntry of workspaceEntries) { + if (!workspaceEntry.isDirectory()) { + continue; + } + + const storeDbPath = path.join(CURSOR_CHATS_ROOT, workspaceEntry.name, safeSessionId, 'store.db'); + try { + await access(storeDbPath); + return storeDbPath; + } catch { + // Keep scanning sibling workspaces until the matching session directory is found. + } + } + } catch { + return null; + } + + return null; +}; + export class CursorProviderModels implements IProviderModels { async getSupportedModels(): Promise { try { @@ -177,4 +214,47 @@ export class CursorProviderModels implements IProviderModels { return CURSOR_FALLBACK_MODELS; } } + + async getCurrentActiveModel(sessionId?: string): Promise { + if (!sessionId?.trim()) { + return buildDefaultProviderCurrentActiveModel(await this.getSupportedModels()); + } + + try { + const storeDbPath = await resolveCursorSessionStorePath(sessionId); + if (!storeDbPath) { + return buildDefaultProviderCurrentActiveModel(await this.getSupportedModels()); + } + + const { default: Database } = await import('better-sqlite3'); + const db = new Database(storeDbPath, { readonly: true, fileMustExist: true }); + + try { + const row = db.prepare(`SELECT value FROM meta WHERE key='0' LIMIT 1;`).get() as { + value?: Buffer | string; + } | undefined; + const metadataText = Buffer.isBuffer(row?.value) + ? row.value.toString('utf8') + : typeof row?.value === 'string' && row.value.trim() + ? Buffer.from(row.value.trim(), 'hex').toString('utf8') + : ''; + if (!metadataText) { + return buildDefaultProviderCurrentActiveModel(await this.getSupportedModels()); + } + + const metadata = JSON.parse(metadataText) as { lastUsedModel?: string }; + if (typeof metadata.lastUsedModel === 'string' && metadata.lastUsedModel.trim()) { + return { + model: metadata.lastUsedModel.trim(), + }; + } + } finally { + db.close(); + } + } catch { + // Fall through to the provider default when Cursor metadata cannot be read. + } + + return buildDefaultProviderCurrentActiveModel(await this.getSupportedModels()); + } } diff --git a/server/modules/providers/list/cursor/cursor-sessions.provider.ts b/server/modules/providers/list/cursor/cursor-sessions.provider.ts index 90c9afa0..33f93ea5 100644 --- a/server/modules/providers/list/cursor/cursor-sessions.provider.ts +++ b/server/modules/providers/list/cursor/cursor-sessions.provider.ts @@ -4,7 +4,12 @@ import path from 'node:path'; import type { IProviderSessions } from '@/shared/interfaces.js'; import type { AnyRecord, FetchHistoryOptions, FetchHistoryResult, NormalizedMessage } from '@/shared/types.js'; -import { createNormalizedMessage, generateMessageId, readObjectRecord } from '@/shared/utils.js'; +import { + createNormalizedMessage, + generateMessageId, + readObjectRecord, + sanitizeLeafDirectoryName, +} from '@/shared/utils.js'; const PROVIDER = 'cursor'; @@ -186,24 +191,6 @@ function normalizeCursorToolInput(toolName: string, rawInput: unknown): unknown return normalized; } -function sanitizeCursorSessionId(sessionId: string): string { - const normalized = sessionId.trim(); - if (!normalized) { - throw new Error('Cursor session id is required.'); - } - - if ( - normalized.includes('..') - || normalized.includes(path.posix.sep) - || normalized.includes(path.win32.sep) - || normalized !== path.basename(normalized) - ) { - throw new Error(`Invalid cursor session id "${sessionId}".`); - } - - return normalized; -} - export class CursorSessionsProvider implements IProviderSessions { /** * Loads Cursor's SQLite blob DAG and returns message blobs in conversation @@ -214,7 +201,7 @@ export class CursorSessionsProvider implements IProviderSessions { const { default: Database } = await import('better-sqlite3'); const cwdId = crypto.createHash('md5').update(projectPath || process.cwd()).digest('hex'); - const safeSessionId = sanitizeCursorSessionId(sessionId); + const safeSessionId = sanitizeLeafDirectoryName(sessionId, 'cursor session id'); const baseChatsPath = path.join(os.homedir(), '.cursor', 'chats', cwdId); const storeDbPath = path.join(baseChatsPath, safeSessionId, 'store.db'); const resolvedBaseChatsPath = path.resolve(baseChatsPath); diff --git a/server/modules/providers/list/gemini/gemini-models.provider.ts b/server/modules/providers/list/gemini/gemini-models.provider.ts index 4e2eb480..45102058 100644 --- a/server/modules/providers/list/gemini/gemini-models.provider.ts +++ b/server/modules/providers/list/gemini/gemini-models.provider.ts @@ -1,5 +1,6 @@ import type { IProviderModels } from '@/shared/interfaces.js'; -import type { ProviderModelsDefinition } from '@/shared/types.js'; +import type { ProviderCurrentActiveModel, ProviderModelsDefinition } from '@/shared/types.js'; +import { buildDefaultProviderCurrentActiveModel } from '@/shared/utils.js'; export const GEMINI_FALLBACK_MODELS: ProviderModelsDefinition = { OPTIONS: [ @@ -20,4 +21,8 @@ export class GeminiProviderModels implements IProviderModels { async getSupportedModels(): Promise { return GEMINI_FALLBACK_MODELS; } + + async getCurrentActiveModel(): Promise { + return buildDefaultProviderCurrentActiveModel(GEMINI_FALLBACK_MODELS); + } } diff --git a/server/modules/providers/list/opencode/opencode-models.provider.ts b/server/modules/providers/list/opencode/opencode-models.provider.ts index ef5b7c7b..8dc75ee6 100644 --- a/server/modules/providers/list/opencode/opencode-models.provider.ts +++ b/server/modules/providers/list/opencode/opencode-models.provider.ts @@ -1,9 +1,20 @@ +import Database from 'better-sqlite3'; import { spawn } from 'node:child_process'; import crossSpawn from 'cross-spawn'; import type { IProviderModels } from '@/shared/interfaces.js'; -import type { ProviderModelOption, ProviderModelsDefinition } from '@/shared/types.js'; +import type { + ProviderCurrentActiveModel, + ProviderModelOption, + ProviderModelsDefinition, +} from '@/shared/types.js'; +import { + buildDefaultProviderCurrentActiveModel, + getOpenCodeDatabasePath, + readObjectRecord, + readOptionalString, +} from '@/shared/utils.js'; export const OPENCODE_FALLBACK_MODELS: ProviderModelsDefinition = { OPTIONS: [ @@ -66,6 +77,32 @@ const buildOpenCodeDefinitionFromIds = (ids: string[]): ProviderModelsDefinition }; }; +const parseOpenCodeSessionModelValue = (rawModel: unknown): string | null => { + if (typeof rawModel === 'string') { + const trimmed = rawModel.trim(); + if (!trimmed) { + return null; + } + + try { + return parseOpenCodeSessionModelValue(JSON.parse(trimmed)); + } catch { + return trimmed; + } + } + + const record = readObjectRecord(rawModel); + if (!record) { + return null; + } + + return readOptionalString(record.id) + ?? readOptionalString(record.model) + ?? readOptionalString(record.name) + ?? readOptionalString(record.value) + ?? null; +}; + const runOpenCodeModelsCommand = (): Promise => new Promise((resolve, reject) => { const openCodeProcess = spawnFunction('opencode', ['models'], { cwd: process.cwd(), @@ -136,4 +173,51 @@ export class OpenCodeProviderModels implements IProviderModels { return OPENCODE_FALLBACK_MODELS; } } + + async getCurrentActiveModel(sessionId?: string): Promise { + if (!sessionId?.trim()) { + return buildDefaultProviderCurrentActiveModel(await this.getSupportedModels()); + } + + try { + const dbPath = getOpenCodeDatabasePath(); + const db = new Database(dbPath, { readonly: true, fileMustExist: true }); + + try { + const row = db.prepare(` + SELECT + s.id AS sessionId, + s.model AS model, + s.agent AS agent, + s.directory AS directory, + s.time_updated AS timeUpdated, + s.time_created AS timeCreated + FROM session s + WHERE s.id = ? + ORDER BY COALESCE(s.time_updated, s.time_created, 0) DESC + LIMIT 1 + `).get(sessionId) as { + sessionId?: string; + model?: unknown; + agent?: string | null; + directory?: string | null; + timeUpdated?: number | null; + timeCreated?: number | null; + } | undefined; + + const model = parseOpenCodeSessionModelValue(row?.model); + if (model) { + return { + model, + }; + } + } finally { + db.close(); + } + } catch { + // Fall through to the provider default when OpenCode session lookup fails. + } + + return buildDefaultProviderCurrentActiveModel(await this.getSupportedModels()); + } } diff --git a/server/modules/providers/services/provider-models.service.ts b/server/modules/providers/services/provider-models.service.ts index 031ae30c..e0883fc6 100644 --- a/server/modules/providers/services/provider-models.service.ts +++ b/server/modules/providers/services/provider-models.service.ts @@ -6,6 +6,7 @@ import { providerRegistry } from '@/modules/providers/provider.registry.js'; import type { IProvider } from '@/shared/interfaces.js'; import type { LLMProvider, + ProviderCurrentActiveModel, ProviderModelsCacheInfo, ProviderModelsDefinition, ProviderModelsResult, @@ -264,6 +265,11 @@ export const createProviderModelsService = (dependencies: ProviderModelsServiceD return loadAndCacheModels(provider); }; + const getCurrentActiveModel = async ( + provider: LLMProvider, + sessionId?: string, + ): Promise => resolveProvider(provider).models.getCurrentActiveModel(sessionId); + const clearCache = (): void => { memoryCache.clear(); pendingRequests.clear(); @@ -273,6 +279,7 @@ export const createProviderModelsService = (dependencies: ProviderModelsServiceD return { getProviderModels, + getCurrentActiveModel, clearCache, }; }; diff --git a/server/modules/providers/tests/provider-models.service.test.ts b/server/modules/providers/tests/provider-models.service.test.ts index 2ca107bd..6abc4c45 100644 --- a/server/modules/providers/tests/provider-models.service.test.ts +++ b/server/modules/providers/tests/provider-models.service.test.ts @@ -8,27 +8,42 @@ import { createProviderModelsService, PROVIDER_MODELS_CACHE_TTL_MS, } from '@/modules/providers/services/provider-models.service.js'; -import type { LLMProvider, ProviderModelsDefinition } from '@/shared/types.js'; +import type { + LLMProvider, + ProviderCurrentActiveModel, + ProviderModelsDefinition, +} from '@/shared/types.js'; const createModels = (value: string): ProviderModelsDefinition => ({ OPTIONS: [{ value, label: value }], DEFAULT: value, }); +const createCurrentActiveModel = (model: string): ProviderCurrentActiveModel => ({ + model, +}); + +const createEphemeralCachePath = (): string => path.join( + os.tmpdir(), + `provider-model-cache-${process.pid}-${Date.now()}-${Math.random().toString(16).slice(2)}.json`, +); + test('provider models service delegates to the resolved provider model adapter', async () => { const calls: LLMProvider[] = []; const service = createProviderModelsService({ + cachePath: createEphemeralCachePath(), resolveProvider: (provider) => { calls.push(provider); return { models: { getSupportedModels: async () => createModels(`${provider}-models`), + getCurrentActiveModel: async () => createCurrentActiveModel(`${provider}-active`), }, }; }, }); - const models = await service.getProviderModels('codex'); + const models = await service.getProviderModels('codex', { bypassCache: true }); assert.deepEqual(calls, ['codex']); assert.equal(models.models.DEFAULT, 'codex-models'); @@ -45,14 +60,16 @@ test('provider models service returns each provider adapter result without rewri }; const service = createProviderModelsService({ + cachePath: createEphemeralCachePath(), resolveProvider: () => ({ models: { getSupportedModels: async () => expectedModels, + getCurrentActiveModel: async () => createCurrentActiveModel('cursor-active'), }, }), }); - const models = await service.getProviderModels('cursor'); + const models = await service.getProviderModels('cursor', { bypassCache: true }); assert.deepEqual(models.models, expectedModels); }); @@ -72,6 +89,7 @@ test('provider models are cached for the three-day ttl', async () => { loadCount += 1; return createModels(`${provider}-${loadCount}`); }, + getCurrentActiveModel: async () => createCurrentActiveModel(`${provider}-active`), }, }), }); @@ -105,6 +123,7 @@ test('provider model cache is persisted across service instances', async () => { resolveProvider: () => ({ models: { getSupportedModels: async () => createModels('gemini-cached'), + getCurrentActiveModel: async () => createCurrentActiveModel('gemini-active'), }, }), }); @@ -117,6 +136,7 @@ test('provider model cache is persisted across service instances', async () => { getSupportedModels: async () => { throw new Error('loader should not be called for persisted cache hits'); }, + getCurrentActiveModel: async () => createCurrentActiveModel('gemini-active'), }, }), }); @@ -142,6 +162,7 @@ test('concurrent provider model requests share one load operation', async () => await new Promise((resolve) => setTimeout(resolve, 20)); return createModels('claude-cached'); }, + getCurrentActiveModel: async () => createCurrentActiveModel('claude-active'), }, }), }); @@ -174,6 +195,7 @@ test('bypassCache forces a fresh provider fetch and updates cache metadata', asy loadCount += 1; return createModels(`${provider}-${loadCount}`); }, + getCurrentActiveModel: async () => createCurrentActiveModel(`${provider}-active-${loadCount}`), }, }), }); @@ -191,3 +213,23 @@ test('bypassCache forces a fresh provider fetch and updates cache metadata', asy await rm(tempRoot, { recursive: true, force: true }); } }); + +test('provider models service delegates current active model lookups to the provider adapter', async () => { + const calls: Array<{ provider: LLMProvider; sessionId?: string }> = []; + const service = createProviderModelsService({ + resolveProvider: (provider) => ({ + models: { + getSupportedModels: async () => createModels(`${provider}-models`), + getCurrentActiveModel: async (sessionId) => { + calls.push({ provider, sessionId }); + return createCurrentActiveModel(`${provider}-${sessionId}`); + }, + }, + }), + }); + + const activeModel = await service.getCurrentActiveModel('opencode', 'session-123'); + + assert.deepEqual(calls, [{ provider: 'opencode', sessionId: 'session-123' }]); + assert.equal(activeModel.model, 'opencode-session-123'); +}); diff --git a/server/routes/commands.js b/server/routes/commands.js index 72df1a86..fd197565 100644 --- a/server/routes/commands.js +++ b/server/routes/commands.js @@ -34,20 +34,36 @@ const readModelProvider = (value) => { return MODEL_PROVIDERS.includes(normalized) ? normalized : "claude"; }; +const hasConcreteSessionId = (value) => + typeof value === "string" && value.trim().length > 0; + +const resolveCommandModel = async (provider, catalog, sessionId) => { + if (!hasConcreteSessionId(sessionId)) { + return catalog.DEFAULT; + } + + const currentActiveModel = await providerModelsService.getCurrentActiveModel( + provider, + sessionId, + ); + return currentActiveModel?.model || catalog.DEFAULT; +}; + export const executeModelsCommand = async (args, context) => { const currentProvider = readModelProvider(context?.provider); const result = await providerModelsService.getProviderModels(currentProvider); const catalog = result.models; + const currentModel = await resolveCommandModel( + currentProvider, + catalog, + context?.sessionId, + ); const availableModels = catalog.OPTIONS.map((option) => option.value); const availableOptions = catalog.OPTIONS.map((option) => ({ value: option.value, label: option.label, description: option.description, })); - const currentModel = - typeof context?.model === "string" && context.model - ? context.model - : catalog.DEFAULT; return { type: "builtin", @@ -240,7 +256,7 @@ Custom commands can be created in: const tokenUsage = context?.tokenUsage || {}; const provider = readModelProvider(context?.provider); const catalog = (await providerModelsService.getProviderModels(provider)).models; - const model = context?.model || catalog.DEFAULT; + const model = await resolveCommandModel(provider, catalog, context?.sessionId); const used = Number( @@ -349,6 +365,7 @@ Custom commands can be created in: const statusProvider = readModelProvider(context?.provider); const statusCatalog = (await providerModelsService.getProviderModels(statusProvider)).models; + const model = await resolveCommandModel(statusProvider, statusCatalog, context?.sessionId); const memoryUsage = process.memoryUsage(); return { @@ -359,7 +376,7 @@ Custom commands can be created in: packageName, uptime: uptimeFormatted, uptimeSeconds: Math.floor(uptime), - model: context?.model || statusCatalog.DEFAULT, + model, provider: statusProvider, nodeVersion: process.version, platform: process.platform, diff --git a/server/routes/tests/commands.test.js b/server/routes/tests/commands.test.js index 0dd481a1..ada58dac 100644 --- a/server/routes/tests/commands.test.js +++ b/server/routes/tests/commands.test.js @@ -2,29 +2,81 @@ import assert from 'node:assert/strict'; import test from 'node:test'; import { executeModelsCommand } from '../commands.js'; +import { providerModelsService } from '../../modules/providers/services/provider-models.service.js'; test('models command returns available models only for the active provider', async () => { - const result = await executeModelsCommand([], { - provider: 'codex', - model: 'gpt-5.4', - }); + const originalGetProviderModels = providerModelsService.getProviderModels; + const originalGetCurrentActiveModel = providerModelsService.getCurrentActiveModel; + let getCurrentActiveModelCalls = 0; - assert.equal(result.type, 'builtin'); - assert.equal(result.action, 'models'); - assert.equal(result.data.current.provider, 'codex'); - assert.equal(result.data.current.model, 'gpt-5.4'); - assert.deepEqual(Object.keys(result.data.available), ['codex']); - assert.deepEqual(result.data.available.codex, result.data.availableModels); - assert.ok(result.data.availableModels.includes('gpt-5.4')); - assert.equal(result.data.available.claude, undefined); - assert.equal(result.data.available.cursor, undefined); + providerModelsService.getProviderModels = async () => ({ + models: { + OPTIONS: [{ value: 'gpt-5.4', label: 'gpt-5.4' }], + DEFAULT: 'gpt-5.4', + }, + cache: { + updatedAt: '2026-01-01T00:00:00.000Z', + expiresAt: '2026-01-04T00:00:00.000Z', + source: 'fresh', + }, + }); + providerModelsService.getCurrentActiveModel = async () => { + getCurrentActiveModelCalls += 1; + return { + model: 'gpt-5.3-codex', + }; + }; + + try { + const result = await executeModelsCommand([], { + provider: 'codex', + model: 'gpt-5.4', + }); + + assert.equal(result.type, 'builtin'); + assert.equal(result.action, 'models'); + assert.equal(result.data.current.provider, 'codex'); + assert.equal(result.data.current.model, 'gpt-5.4'); + assert.deepEqual(Object.keys(result.data.available), ['codex']); + assert.deepEqual(result.data.available.codex, result.data.availableModels); + assert.ok(result.data.availableModels.includes('gpt-5.4')); + assert.equal(result.data.available.claude, undefined); + assert.equal(result.data.available.cursor, undefined); + assert.equal(getCurrentActiveModelCalls, 0); + } finally { + providerModelsService.getProviderModels = originalGetProviderModels; + providerModelsService.getCurrentActiveModel = originalGetCurrentActiveModel; + } }); test('models command falls back to claude for unsupported providers', async () => { - const result = await executeModelsCommand([], { - provider: 'unknown-provider', + const originalGetProviderModels = providerModelsService.getProviderModels; + const originalGetCurrentActiveModel = providerModelsService.getCurrentActiveModel; + + providerModelsService.getProviderModels = async () => ({ + models: { + OPTIONS: [{ value: 'default', label: 'Default (recommended)' }], + DEFAULT: 'default', + }, + cache: { + updatedAt: '2026-01-01T00:00:00.000Z', + expiresAt: '2026-01-04T00:00:00.000Z', + source: 'fresh', + }, + }); + providerModelsService.getCurrentActiveModel = async () => ({ + model: 'default', }); - assert.equal(result.data.current.provider, 'claude'); - assert.deepEqual(Object.keys(result.data.available), ['claude']); + try { + const result = await executeModelsCommand([], { + provider: 'unknown-provider', + }); + + assert.equal(result.data.current.provider, 'claude'); + assert.deepEqual(Object.keys(result.data.available), ['claude']); + } finally { + providerModelsService.getProviderModels = originalGetProviderModels; + providerModelsService.getCurrentActiveModel = originalGetCurrentActiveModel; + } }); diff --git a/server/shared/interfaces.ts b/server/shared/interfaces.ts index acf8d516..9228698e 100644 --- a/server/shared/interfaces.ts +++ b/server/shared/interfaces.ts @@ -7,6 +7,7 @@ import type { ProviderSkill, ProviderSkillListOptions, ProviderAuthStatus, + ProviderCurrentActiveModel, ProviderModelsDefinition, ProviderMcpServer, UpsertProviderMcpServerInput, @@ -45,6 +46,15 @@ export interface IProviderModels { * Returns the provider's currently supported model catalog. */ getSupportedModels(): Promise; + + /** + * Returns the currently active model for one session or provider runtime. + * + * Implementations must use the provider-specific lookup mechanism approved + * for that provider and fall back only to the provider catalog default when + * no active model can be resolved. + */ + getCurrentActiveModel(sessionId?: string): Promise; } // --------------------------- diff --git a/server/shared/types.ts b/server/shared/types.ts index c9fb2878..121cc509 100644 --- a/server/shared/types.ts +++ b/server/shared/types.ts @@ -109,6 +109,20 @@ export type ProviderModelsResult = { cache: ProviderModelsCacheInfo; }; +// --------------------------- +//----------------- PROVIDER ACTIVE MODEL TYPES ------------ +/** + * Provider-neutral result for the model that is actively driving a session or + * provider runtime at the time of lookup. + * + * `model` must always be populated. Provider adapters should use the + * provider-specific lookup method requested by the caller, and only fall back + * to the provider catalog `DEFAULT` value when the active model cannot be read. + */ +export type ProviderCurrentActiveModel = { + model: string; +}; + /** * Message/event variants emitted by provider adapters and normalized transports. * diff --git a/server/shared/utils.ts b/server/shared/utils.ts index 62762f39..437ea034 100644 --- a/server/shared/utils.ts +++ b/server/shared/utils.ts @@ -23,6 +23,8 @@ import type { ApiSuccessShape, AppErrorOptions, NormalizedMessage, + ProviderCurrentActiveModel, + ProviderModelsDefinition, ProviderSkillSource, WorkspacePathValidationResult, } from '@/shared/types.js'; @@ -414,6 +416,24 @@ export const readStringRecord = (value: unknown): Record | undef return Object.keys(normalized).length > 0 ? normalized : undefined; }; +// --------------------------- +//----------------- PROVIDER MODEL LOOKUP UTILITIES ------------ +/** + * Builds the standard "default current model" result used when a provider + * cannot resolve a session-backed active model. + * + * Provider model adapters should call this after loading their supported model + * catalog so the fallback stays aligned with the provider's current `DEFAULT` + * selection instead of drifting to a hard-coded duplicate. + */ +export function buildDefaultProviderCurrentActiveModel( + models: ProviderModelsDefinition, +): ProviderCurrentActiveModel { + return { + model: models.DEFAULT, + }; +} + // --------------------------- //----------------- WEBSOCKET PAYLOAD PARSING UTILITIES ------------ /** @@ -742,6 +762,34 @@ export function getOpenCodeDatabasePath(): string { return path.join(os.homedir(), '.local', 'share', 'opencode', 'opencode.db'); } +// --------------------------- +//----------------- SAFE DIRECTORY NAME UTILITIES ------------ +/** + * Validates that a user or provider supplied identifier can safely be treated + * as one leaf directory name under an existing root folder. + * + * Use this before composing paths like `//file.db>` to block + * path traversal and accidental nested paths. The returned string is trimmed but + * otherwise unchanged so callers can still match the provider's on-disk naming. + */ +export function sanitizeLeafDirectoryName(inputName: string, label = 'directory name'): string { + const normalized = inputName.trim(); + if (!normalized) { + throw new Error(`${label} is required.`); + } + + if ( + normalized.includes('..') + || normalized.includes(path.posix.sep) + || normalized.includes(path.win32.sep) + || normalized !== path.basename(normalized) + ) { + throw new Error(`Invalid ${label} "${inputName}".`); + } + + return normalized; +} + // --------------------------- //----------------- SESSION SYNCHRONIZER FILESYSTEM HELPERS ------------ /**