diff --git a/server/modules/providers/services/provider-models.service.ts b/server/modules/providers/services/provider-models.service.ts index 1c4d851d..abceb4ab 100644 --- a/server/modules/providers/services/provider-models.service.ts +++ b/server/modules/providers/services/provider-models.service.ts @@ -1,11 +1,16 @@ import { spawn } from 'node:child_process'; import fsSync from 'node:fs'; +import { mkdir, readFile, writeFile } from 'node:fs/promises'; +import os from 'node:os'; +import path from 'node:path'; import crossSpawn from 'cross-spawn'; import type { LLMProvider, ProviderModelOption, ProviderModelsDefinition } from '@/shared/types.js'; const OPEN_CODE_MODELS_TIMEOUT_MS = 20_000; +export const PROVIDER_MODELS_CACHE_TTL_MS = 2 * 24 * 60 * 60 * 1000; +const PROVIDER_MODELS_CACHE_VERSION = 1; /** * Claude (Anthropic) — SDK-style ids used by the UI and claude-sdk.js. @@ -101,6 +106,31 @@ const BUILTIN_BY_PROVIDER: Record, ProviderMode gemini: GEMINI_MODELS, }; +type ProviderModelsOptions = { + cwd?: string; +}; + +type ProviderModelsLoader = ( + provider: LLMProvider, + options?: ProviderModelsOptions, +) => Promise; + +type ProviderModelsCacheEntry = { + expiresAt: number; + models: ProviderModelsDefinition; +}; + +type ProviderModelsCacheFile = { + version: number; + entries: Record; +}; + +type ProviderModelsServiceDependencies = { + cachePath?: string; + loadModels?: ProviderModelsLoader; + now?: () => number; +}; + const MODEL_ID_LINE = /^[a-z0-9][a-z0-9._-]*\/[a-z0-9][a-z0-9._-]*$/i; const parseOpenCodeModelsStdout = (stdout: string): string[] => { @@ -144,6 +174,81 @@ const resolveOpenCodeCwd = (cwd?: string): string => { return process.cwd(); }; +const getProviderModelsCachePath = (): string => + process.env.CLOUDCLI_PROVIDER_MODELS_CACHE_PATH + || path.join(os.homedir(), '.cloudcli', 'provider-models-cache.json'); + +const getProviderModelsCacheKey = ( + provider: LLMProvider, + options?: ProviderModelsOptions, +): string => { + if (provider === 'opencode') { + return `${provider}:${resolveOpenCodeCwd(options?.cwd)}`; + } + + return provider; +}; + +const isProviderModelOption = (value: unknown): value is ProviderModelOption => ( + Boolean(value) + && typeof value === 'object' + && typeof (value as ProviderModelOption).value === 'string' + && typeof (value as ProviderModelOption).label === 'string' +); + +const isProviderModelsDefinition = (value: unknown): value is ProviderModelsDefinition => ( + Boolean(value) + && typeof value === 'object' + && Array.isArray((value as ProviderModelsDefinition).OPTIONS) + && (value as ProviderModelsDefinition).OPTIONS.every(isProviderModelOption) + && typeof (value as ProviderModelsDefinition).DEFAULT === 'string' +); + +const isProviderModelsCacheEntry = (value: unknown): value is ProviderModelsCacheEntry => ( + Boolean(value) + && typeof value === 'object' + && typeof (value as ProviderModelsCacheEntry).expiresAt === 'number' + && isProviderModelsDefinition((value as ProviderModelsCacheEntry).models) +); + +const readProviderModelsCacheFile = async ( + cachePath: string, +): Promise => { + try { + const raw = await readFile(cachePath, 'utf8'); + const parsed = JSON.parse(raw) as Partial; + if (parsed.version !== PROVIDER_MODELS_CACHE_VERSION || !parsed.entries || typeof parsed.entries !== 'object') { + return null; + } + + const entries = Object.fromEntries( + Object.entries(parsed.entries).filter((entry): entry is [string, ProviderModelsCacheEntry] => + isProviderModelsCacheEntry(entry[1]), + ), + ); + return { version: PROVIDER_MODELS_CACHE_VERSION, entries }; + } catch { + return null; + } +}; + +const writeProviderModelsCacheFile = async ( + cachePath: string, + entries: Map, + now: number, +): Promise => { + const serializableEntries = Object.fromEntries( + [...entries.entries()].filter(([, entry]) => entry.expiresAt > now), + ); + const payload: ProviderModelsCacheFile = { + version: PROVIDER_MODELS_CACHE_VERSION, + entries: serializableEntries, + }; + + await mkdir(path.dirname(cachePath), { recursive: true }); + await writeFile(cachePath, `${JSON.stringify(payload, null, 2)}\n`, 'utf8'); +}; + const runOpenCodeModelsCommand = (cwd?: string): Promise => new Promise((resolve, reject) => { const spawnFn = process.platform === 'win32' ? crossSpawn : spawn; @@ -222,7 +327,133 @@ async function getProviderModelsInternal( } } -export const providerModelsService = { - getProviderModels: (provider: LLMProvider, options?: { cwd?: string }): Promise => - getProviderModelsInternal(provider, options), +export const createProviderModelsService = (dependencies: ProviderModelsServiceDependencies = {}) => { + const memoryCache = new Map(); + const pendingRequests = new Map>(); + const loadModels = dependencies.loadModels ?? getProviderModelsInternal; + const now = dependencies.now ?? (() => Date.now()); + let persistedCacheLoaded = false; + let persistedCacheLoadPromise: Promise | null = null; + + const loadPersistedCache = async (cachePath: string): Promise => { + if (persistedCacheLoaded) { + return; + } + + if (!persistedCacheLoadPromise) { + persistedCacheLoadPromise = (async () => { + const cacheFile = await readProviderModelsCacheFile(cachePath); + const currentTime = now(); + for (const [key, entry] of Object.entries(cacheFile?.entries ?? {})) { + if (entry.expiresAt > currentTime) { + memoryCache.set(key, entry); + } + } + persistedCacheLoaded = true; + })().finally(() => { + persistedCacheLoadPromise = null; + }); + } + + await persistedCacheLoadPromise; + }; + + const persistCache = async (cachePath: string): Promise => { + try { + await writeProviderModelsCacheFile(cachePath, memoryCache, now()); + } catch (error) { + console.warn('Unable to persist provider models cache:', error); + } + }; + + const setCacheEntry = async ( + cachePath: string, + cacheKey: string, + models: ProviderModelsDefinition, + ): Promise => { + const entry = { + expiresAt: now() + PROVIDER_MODELS_CACHE_TTL_MS, + models, + }; + memoryCache.set(cacheKey, entry); + + await persistCache(cachePath); + }; + + const loadAndCacheModels = ( + provider: LLMProvider, + options: ProviderModelsOptions | undefined, + cachePath: string, + cacheKey: string, + ): Promise => { + const request = loadModels(provider, options) + .then(async (models) => { + await setCacheEntry(cachePath, cacheKey, models); + return models; + }) + .finally(() => { + pendingRequests.delete(cacheKey); + }); + + pendingRequests.set(cacheKey, request); + return request; + }; + + const pruneExpiredMemoryEntry = (cacheKey: string, currentTime: number): ProviderModelsDefinition | null => { + const cachedEntry = memoryCache.get(cacheKey); + if (!cachedEntry) { + return null; + } + + if (cachedEntry.expiresAt > currentTime) { + return cachedEntry.models; + } + + memoryCache.delete(cacheKey); + return null; + }; + + const getProviderModels = async ( + provider: LLMProvider, + options?: ProviderModelsOptions, + ): Promise => { + const cachePath = dependencies.cachePath ?? getProviderModelsCachePath(); + const cacheKey = getProviderModelsCacheKey(provider, options); + const cachedModels = pruneExpiredMemoryEntry(cacheKey, now()); + if (cachedModels) { + return cachedModels; + } + + const pendingRequest = pendingRequests.get(cacheKey); + if (pendingRequest) { + return pendingRequest; + } + + await loadPersistedCache(cachePath); + const persistedModels = pruneExpiredMemoryEntry(cacheKey, now()); + if (persistedModels) { + return persistedModels; + } + + const postLoadPendingRequest = pendingRequests.get(cacheKey); + if (postLoadPendingRequest) { + return postLoadPendingRequest; + } + + return loadAndCacheModels(provider, options, cachePath, cacheKey); + }; + + const clearCache = (): void => { + memoryCache.clear(); + pendingRequests.clear(); + persistedCacheLoaded = false; + persistedCacheLoadPromise = null; + }; + + return { + getProviderModels, + clearCache, + }; }; + +export const providerModelsService = createProviderModelsService(); diff --git a/server/modules/providers/tests/provider-models.service.test.ts b/server/modules/providers/tests/provider-models.service.test.ts new file mode 100644 index 00000000..003d91be --- /dev/null +++ b/server/modules/providers/tests/provider-models.service.test.ts @@ -0,0 +1,128 @@ +import assert from 'node:assert/strict'; +import { mkdir, mkdtemp, rm } from 'node:fs/promises'; +import os from 'node:os'; +import path from 'node:path'; +import test from 'node:test'; + +import { + createProviderModelsService, + PROVIDER_MODELS_CACHE_TTL_MS, +} from '@/modules/providers/services/provider-models.service.js'; +import type { LLMProvider, ProviderModelsDefinition } from '@/shared/types.js'; + +const createModels = (value: string): ProviderModelsDefinition => ({ + OPTIONS: [{ value, label: value }], + DEFAULT: value, +}); + +test('provider models are cached for the two-day ttl', async () => { + const tempRoot = await mkdtemp(path.join(os.tmpdir(), 'provider-model-cache-ttl-')); + let currentTime = 1_000; + let loadCount = 0; + + try { + const service = createProviderModelsService({ + cachePath: path.join(tempRoot, 'models-cache.json'), + now: () => currentTime, + loadModels: async (provider: LLMProvider) => { + loadCount += 1; + return createModels(`${provider}-${loadCount}`); + }, + }); + + const first = await service.getProviderModels('codex'); + const cached = await service.getProviderModels('codex'); + assert.equal(loadCount, 1); + assert.equal(cached.DEFAULT, first.DEFAULT); + + currentTime += PROVIDER_MODELS_CACHE_TTL_MS - 1; + await service.getProviderModels('codex'); + assert.equal(loadCount, 1); + + currentTime += 2; + const refreshed = await service.getProviderModels('codex'); + assert.equal(loadCount, 2); + assert.equal(refreshed.DEFAULT, 'codex-2'); + } finally { + await rm(tempRoot, { recursive: true, force: true }); + } +}); + +test('provider model cache is persisted across service instances', async () => { + const tempRoot = await mkdtemp(path.join(os.tmpdir(), 'provider-model-cache-file-')); + const cachePath = path.join(tempRoot, 'models-cache.json'); + + try { + const writer = createProviderModelsService({ + cachePath, + loadModels: async () => createModels('gemini-cached'), + }); + await writer.getProviderModels('gemini'); + + const reader = createProviderModelsService({ + cachePath, + loadModels: async () => { + throw new Error('loader should not be called for persisted cache hits'); + }, + }); + const models = await reader.getProviderModels('gemini'); + assert.equal(models.DEFAULT, 'gemini-cached'); + } finally { + await rm(tempRoot, { recursive: true, force: true }); + } +}); + +test('concurrent provider model requests share one load operation', async () => { + const tempRoot = await mkdtemp(path.join(os.tmpdir(), 'provider-model-cache-pending-')); + let loadCount = 0; + + try { + const service = createProviderModelsService({ + cachePath: path.join(tempRoot, 'models-cache.json'), + loadModels: async () => { + loadCount += 1; + await new Promise((resolve) => setTimeout(resolve, 20)); + return createModels('claude-cached'); + }, + }); + + const [first, second] = await Promise.all([ + service.getProviderModels('claude'), + service.getProviderModels('claude'), + ]); + + assert.equal(loadCount, 1); + assert.equal(first.DEFAULT, 'claude-cached'); + assert.equal(second.DEFAULT, 'claude-cached'); + } finally { + await rm(tempRoot, { recursive: true, force: true }); + } +}); + +test('opencode model cache is scoped by workspace cwd', async () => { + const tempRoot = await mkdtemp(path.join(os.tmpdir(), 'provider-model-cache-opencode-')); + const workspaceA = path.join(tempRoot, 'workspace-a'); + const workspaceB = path.join(tempRoot, 'workspace-b'); + let loadCount = 0; + + try { + await mkdir(workspaceA, { recursive: true }); + await mkdir(workspaceB, { recursive: true }); + + const service = createProviderModelsService({ + cachePath: path.join(tempRoot, 'models-cache.json'), + loadModels: async () => { + loadCount += 1; + return createModels(`opencode-${loadCount}`); + }, + }); + + await service.getProviderModels('opencode', { cwd: workspaceA }); + await service.getProviderModels('opencode', { cwd: workspaceA }); + await service.getProviderModels('opencode', { cwd: workspaceB }); + + assert.equal(loadCount, 2); + } finally { + await rm(tempRoot, { recursive: true, force: true }); + } +}); diff --git a/server/routes/commands.js b/server/routes/commands.js index d0ca6bb0..4260f5ad 100644 --- a/server/routes/commands.js +++ b/server/routes/commands.js @@ -15,6 +15,65 @@ const APP_ROOT = findAppRoot(__dirname); const router = express.Router(); +const MODEL_PROVIDERS = ['claude', 'cursor', 'codex', 'gemini', 'opencode']; + +const MODEL_PROVIDER_LABELS = { + claude: 'Claude', + cursor: 'Cursor', + codex: 'Codex', + gemini: 'Gemini', + opencode: 'OpenCode', +}; + +const readModelProvider = (value) => { + if (typeof value !== 'string') { + return 'claude'; + } + + const normalized = value.trim().toLowerCase(); + return MODEL_PROVIDERS.includes(normalized) ? normalized : 'claude'; +}; + +const getProviderModelOptions = (provider, context) => { + if (provider !== 'opencode') { + return undefined; + } + + const cwd = typeof context?.projectPath === 'string' ? context.projectPath : undefined; + return { cwd }; +}; + +export const executeModelsCommand = async (args, context) => { + const currentProvider = readModelProvider(context?.provider); + const catalog = await providerModelsService.getProviderModels( + currentProvider, + getProviderModelOptions(currentProvider, context), + ); + const availableModels = catalog.OPTIONS.map((option) => option.value); + const currentModel = typeof context?.model === 'string' && context.model + ? context.model + : catalog.DEFAULT; + + return { + type: 'builtin', + action: 'models', + data: { + current: { + provider: currentProvider, + providerLabel: MODEL_PROVIDER_LABELS[currentProvider], + model: currentModel + }, + available: { + [currentProvider]: availableModels, + }, + availableModels, + message: args.length > 0 + ? `Switching to model: ${args[0]}` + : `Current model: ${currentModel}` + } + }; +}; + /** * Recursively scan directory for command files (.md) * @param {string} dir - Directory to scan @@ -90,14 +149,8 @@ const builtInCommands = [ metadata: { type: 'builtin' } }, { - name: '/clear', - description: 'Clear the conversation history', - namespace: 'builtin', - metadata: { type: 'builtin' } - }, - { - name: '/model', - description: 'Switch or view the current AI model', + name: '/models', + description: 'View available models for the current provider', namespace: 'builtin', metadata: { type: 'builtin' } }, @@ -125,12 +178,6 @@ const builtInCommands = [ namespace: 'builtin', metadata: { type: 'builtin' } }, - { - name: '/rewind', - description: 'Rewind the conversation to a previous state', - namespace: 'builtin', - metadata: { type: 'builtin' } - } ]; /** @@ -176,58 +223,7 @@ Custom commands can be created in: }; }, - '/clear': async (args, context) => { - return { - type: 'builtin', - action: 'clear', - data: { - message: 'Conversation history cleared' - } - }; - }, - - '/model': async (args, context) => { - const [claude, cursor, codex, gemini, opencode] = await Promise.all([ - providerModelsService.getProviderModels('claude'), - providerModelsService.getProviderModels('cursor'), - providerModelsService.getProviderModels('codex'), - providerModelsService.getProviderModels('gemini'), - providerModelsService.getProviderModels('opencode'), - ]); - - const availableModels = { - claude: claude.OPTIONS.map(o => o.value), - cursor: cursor.OPTIONS.map(o => o.value), - codex: codex.OPTIONS.map(o => o.value), - gemini: gemini.OPTIONS.map(o => o.value), - opencode: opencode.OPTIONS.map(o => o.value), - }; - - const currentProvider = context?.provider || 'claude'; - const defaults = { - claude: claude.DEFAULT, - cursor: cursor.DEFAULT, - codex: codex.DEFAULT, - gemini: gemini.DEFAULT, - opencode: opencode.DEFAULT, - }; - const currentModel = context?.model || defaults[currentProvider] || claude.DEFAULT; - - return { - type: 'builtin', - action: 'model', - data: { - current: { - provider: currentProvider, - model: currentModel - }, - available: availableModels, - message: args.length > 0 - ? `Switching to model: ${args[0]}` - : `Current model: ${currentModel}` - } - }; - }, + '/models': executeModelsCommand, '/cost': async (args, context) => { const tokenUsage = context?.tokenUsage || {}; @@ -392,30 +388,6 @@ Custom commands can be created in: message: 'Opening settings...' } }; - }, - - '/rewind': async (args, context) => { - const steps = args[0] ? parseInt(args[0]) : 1; - - if (isNaN(steps) || steps < 1) { - return { - type: 'builtin', - action: 'rewind', - data: { - error: 'Invalid steps parameter', - message: 'Usage: /rewind [number] - Rewind conversation by N steps (default: 1)' - } - }; - } - - return { - type: 'builtin', - action: 'rewind', - data: { - steps, - message: `Rewinding conversation by ${steps} step${steps > 1 ? 's' : ''}...` - } - }; } }; diff --git a/server/routes/tests/commands.test.js b/server/routes/tests/commands.test.js new file mode 100644 index 00000000..041fdf38 --- /dev/null +++ b/server/routes/tests/commands.test.js @@ -0,0 +1,54 @@ +import assert from 'node:assert/strict'; +import { mkdtemp, rm } from 'node:fs/promises'; +import os from 'node:os'; +import path from 'node:path'; +import test from 'node:test'; + +import { executeModelsCommand } from '../commands.js'; + +const withTemporaryModelsCache = async (callback) => { + const tempRoot = await mkdtemp(path.join(os.tmpdir(), 'commands-model-cache-')); + const previousCachePath = process.env.CLOUDCLI_PROVIDER_MODELS_CACHE_PATH; + process.env.CLOUDCLI_PROVIDER_MODELS_CACHE_PATH = path.join(tempRoot, 'models-cache.json'); + + try { + await callback(); + } finally { + if (previousCachePath === undefined) { + delete process.env.CLOUDCLI_PROVIDER_MODELS_CACHE_PATH; + } else { + process.env.CLOUDCLI_PROVIDER_MODELS_CACHE_PATH = previousCachePath; + } + await rm(tempRoot, { recursive: true, force: true }); + } +}; + +test('models command returns available models only for the active provider', async () => { + await withTemporaryModelsCache(async () => { + const result = await executeModelsCommand([], { + provider: 'codex', + model: 'gpt-5.4', + }); + + assert.equal(result.type, 'builtin'); + assert.equal(result.action, 'models'); + assert.equal(result.data.current.provider, 'codex'); + assert.equal(result.data.current.model, 'gpt-5.4'); + assert.deepEqual(Object.keys(result.data.available), ['codex']); + assert.deepEqual(result.data.available.codex, result.data.availableModels); + assert.ok(result.data.availableModels.includes('gpt-5.4')); + assert.equal(result.data.available.claude, undefined); + assert.equal(result.data.available.cursor, undefined); + }); +}); + +test('models command falls back to claude for unsupported providers', async () => { + await withTemporaryModelsCache(async () => { + const result = await executeModelsCommand([], { + provider: 'unknown-provider', + }); + + assert.equal(result.data.current.provider, 'claude'); + assert.deepEqual(Object.keys(result.data.available), ['claude']); + }); +}); diff --git a/src/components/chat/hooks/useChatComposerState.ts b/src/components/chat/hooks/useChatComposerState.ts index 0b70f306..79fe23f6 100644 --- a/src/components/chat/hooks/useChatComposerState.ts +++ b/src/components/chat/hooks/useChatComposerState.ts @@ -55,8 +55,6 @@ interface UseChatComposerStateArgs { pendingViewSessionRef: { current: PendingViewSession | null }; scrollToBottom: () => void; addMessage: (msg: ChatMessage) => void; - clearMessages: () => void; - rewindMessages: (count: number) => void; setIsLoading: (loading: boolean) => void; setCanAbortSession: (canAbort: boolean) => void; setClaudeStatus: (status: { text: string; tokens: number; can_interrupt: boolean } | null) => void; @@ -78,6 +76,50 @@ interface CommandExecutionResult { hasFileIncludes?: boolean; } +type ModelCommandData = { + current?: { + provider?: string; + providerLabel?: string; + model?: string; + }; + available?: Partial>; + availableModels?: string[]; +}; + +const PROVIDER_LABELS: Record = { + claude: 'Claude', + cursor: 'Cursor', + codex: 'Codex', + gemini: 'Gemini', + opencode: 'OpenCode', +}; + +const isLLMProvider = (value: unknown): value is LLMProvider => ( + value === 'claude' + || value === 'cursor' + || value === 'codex' + || value === 'gemini' + || value === 'opencode' +); + +const formatModelCommandMessage = (data: ModelCommandData): string => { + const currentProvider = isLLMProvider(data.current?.provider) + ? data.current.provider + : 'claude'; + const providerLabel = data.current?.providerLabel || PROVIDER_LABELS[currentProvider]; + const currentModel = data.current?.model || 'Unknown'; + // `availableModels` is the current response shape; the keyed map keeps older + // server responses readable without reintroducing cross-provider rendering. + const availableModels = Array.isArray(data.availableModels) + ? data.availableModels + : data.available?.[currentProvider] ?? []; + const availableText = availableModels.length > 0 + ? availableModels.join(', ') + : 'No models reported for this provider.'; + + return `**Current Model**: ${currentModel}\n\n**Provider**: ${providerLabel}\n\n**Available Models**:\n\n${availableText}`; +}; + const createFakeSubmitEvent = () => { return { preventDefault: () => undefined } as unknown as FormEvent; }; @@ -125,8 +167,6 @@ export function useChatComposerState({ pendingViewSessionRef, scrollToBottom, addMessage, - clearMessages, - rewindMessages, setIsLoading, setCanAbortSession, setClaudeStatus, @@ -159,10 +199,6 @@ export function useChatComposerState({ (result: CommandExecutionResult) => { const { action, data } = result; switch (action) { - case 'clear': - clearMessages(); - break; - case 'help': addMessage({ type: 'assistant', @@ -171,10 +207,10 @@ export function useChatComposerState({ }); break; - case 'model': + case 'models': addMessage({ type: 'assistant', - content: `**Current Model**: ${data.current.model}\n\n**Available Models**:\n\nClaude: ${data.available.claude.join(', ')}\n\nCursor: ${data.available.cursor.join(', ')}`, + content: formatModelCommandMessage(data as ModelCommandData), timestamp: Date.now(), }); break; @@ -214,28 +250,11 @@ export function useChatComposerState({ onShowSettings?.(); break; - case 'rewind': - if (data.error) { - addMessage({ - type: 'assistant', - content: `Warning: ${data.message}`, - timestamp: Date.now(), - }); - } else { - rewindMessages(data.steps * 2); - addMessage({ - type: 'assistant', - content: `Rewound ${data.steps} step(s). ${data.message}`, - timestamp: Date.now(), - }); - } - break; - default: console.warn('Unknown built-in command action:', action); } }, - [onFileOpen, onShowSettings, addMessage, clearMessages, rewindMessages], + [onFileOpen, onShowSettings, addMessage], ); const handleCustomCommand = useCallback(async (result: CommandExecutionResult) => { diff --git a/src/components/chat/view/ChatInterface.tsx b/src/components/chat/view/ChatInterface.tsx index 4a0781d3..ea071212 100644 --- a/src/components/chat/view/ChatInterface.tsx +++ b/src/components/chat/view/ChatInterface.tsx @@ -87,8 +87,6 @@ function ChatInterface({ const { chatMessages, addMessage, - clearMessages, - rewindMessages, isLoading, setIsLoading, currentSessionId, @@ -200,8 +198,6 @@ function ChatInterface({ pendingViewSessionRef, scrollToBottom, addMessage, - clearMessages, - rewindMessages, setIsLoading, setCanAbortSession, setClaudeStatus,