mirror of
https://github.com/siteboon/claudecodeui.git
synced 2026-05-28 14:55:34 +08:00
feat: support session-scoped model overrides
Model selection was acting like a provider-level preference. That made resumed sessions drift back to a default or request-time model. Users expect /models changes made inside a conversation to affect that session. Store explicit session choices in app-owned ~/.cloudcli state. This avoids editing provider transcripts or native provider config. Resolve the effective model before launching each provider runtime. Claude, Cursor, Codex, Gemini, and OpenCode now honor stored resume choices. Expose a backend active-model change endpoint for existing sessions. The models modal can now distinguish default changes from session overrides. It also shows when a selected model will apply on the next response. For Claude, stop probing active model state by resuming with a dummy prompt. Read the indexed JSONL transcript from the end instead. This preserves provider history while honoring /model stdout or model fields. Add service tests for adapter delegation and resume-model precedence. The tests keep cache state, override state, and requested fallback separate.
This commit is contained in:
@@ -1,16 +1,21 @@
|
||||
import { spawn } from 'node:child_process';
|
||||
import { readFile } from 'node:fs/promises';
|
||||
|
||||
import { query, type ModelInfo, type Options } from '@anthropic-ai/claude-agent-sdk';
|
||||
import crossSpawn from 'cross-spawn';
|
||||
|
||||
import { sessionsDb } from '@/modules/database/index.js';
|
||||
import { resolveClaudeCodeExecutablePath } from '@/shared/claude-cli-path.js';
|
||||
import type { IProviderModels } from '@/shared/interfaces.js';
|
||||
import type {
|
||||
ProviderChangeActiveModelInput,
|
||||
ProviderCurrentActiveModel,
|
||||
ProviderModelOption,
|
||||
ProviderModelsDefinition,
|
||||
ProviderSessionActiveModelChange,
|
||||
} from '@/shared/types.js';
|
||||
import { buildDefaultProviderCurrentActiveModel } from '@/shared/utils.js';
|
||||
import {
|
||||
buildDefaultProviderCurrentActiveModel,
|
||||
writeProviderSessionActiveModelChange,
|
||||
} from '@/shared/utils.js';
|
||||
|
||||
export const CLAUDE_FALLBACK_MODELS: ProviderModelsDefinition = {
|
||||
OPTIONS: [
|
||||
@@ -26,13 +31,23 @@ export const CLAUDE_FALLBACK_MODELS: ProviderModelsDefinition = {
|
||||
|
||||
type ClaudeModelQueryOptions = Pick<Options, 'env' | 'pathToClaudeCodeExecutable' | 'permissionMode'>;
|
||||
type ClaudeInitEvent = {
|
||||
sessionId?: string;
|
||||
session_id?: string;
|
||||
type?: string;
|
||||
subtype?: string;
|
||||
model?: string;
|
||||
message?: {
|
||||
content?: unknown;
|
||||
model?: string;
|
||||
};
|
||||
};
|
||||
|
||||
const CLAUDE_ACTIVE_MODEL_TIMEOUT_MS = 20_000;
|
||||
const claudeSpawn = process.platform === 'win32' ? crossSpawn : spawn;
|
||||
const ANSI_PATTERN = new RegExp(
|
||||
'[\\u001B\\u009B][[\\]()#;?]*(?:'
|
||||
+ '(?:[0-9]{1,4}(?:;[0-9]{0,4})*)?[0-9A-ORZcf-nqry=><]'
|
||||
+ '|(?:[\\dA-PR-TZcf-ntqry=><~]))',
|
||||
'g',
|
||||
);
|
||||
|
||||
const buildClaudeQueryOptions = (): ClaudeModelQueryOptions => ({
|
||||
env: { ...process.env },
|
||||
@@ -74,82 +89,94 @@ const buildClaudeModelsDefinition = (models: ModelInfo[]): ProviderModelsDefinit
|
||||
};
|
||||
};
|
||||
|
||||
const runClaudeSessionModelCommand = async (sessionId: string): Promise<ProviderCurrentActiveModel | null> => {
|
||||
const cliPath = resolveClaudeCodeExecutablePath(process.env.CLAUDE_CLI_PATH);
|
||||
const extractClaudeEventModel = (event: ClaudeInitEvent, sessionId: string): string | null => {
|
||||
const eventSessionId = event.sessionId ?? event.session_id;
|
||||
if (eventSessionId && eventSessionId !== sessionId) {
|
||||
return null;
|
||||
}
|
||||
|
||||
return new Promise((resolve, reject) => {
|
||||
const child = claudeSpawn(
|
||||
cliPath,
|
||||
['-p', '--verbose', '--output-format', 'stream-json', '--resume', sessionId, 'ok'],
|
||||
{
|
||||
env: { ...process.env },
|
||||
windowsHide: true,
|
||||
},
|
||||
);
|
||||
const contentModel = extractClaudeModelFromMessageContent(event.message?.content);
|
||||
if (contentModel) {
|
||||
return contentModel;
|
||||
}
|
||||
|
||||
let stdout = '';
|
||||
let stderr = '';
|
||||
let settled = false;
|
||||
const directModel = event.model?.trim();
|
||||
if (directModel) {
|
||||
return directModel;
|
||||
}
|
||||
|
||||
const timer = setTimeout(() => {
|
||||
child.kill('SIGTERM');
|
||||
if (!settled) {
|
||||
settled = true;
|
||||
reject(new Error('Claude current-model lookup timed out'));
|
||||
const messageModel = event.message?.model?.trim();
|
||||
return messageModel || null;
|
||||
};
|
||||
|
||||
const stripAnsi = (value: string): string => value.replace(ANSI_PATTERN, '');
|
||||
|
||||
const extractTaggedContent = (content: string, tagName: string): string | null => {
|
||||
const escapedTagName = tagName.replace(/[.*+?^${}()|[\]\\]/g, '\\$&');
|
||||
const match = new RegExp(`<${escapedTagName}>([\\s\\S]*?)<\\/${escapedTagName}>`).exec(content);
|
||||
return match ? match[1] : null;
|
||||
};
|
||||
|
||||
const extractClaudeModelFromTextContent = (content: string): string | null => {
|
||||
const localCommandStdout = extractTaggedContent(content, 'local-command-stdout');
|
||||
if (localCommandStdout !== null) {
|
||||
const cleanedStdout = stripAnsi(localCommandStdout).replace(/\s+/g, ' ').trim();
|
||||
const changedModel = /(?:set|changed|switched)\s+model\s+to\s+(.+?)\.?$/i.exec(cleanedStdout);
|
||||
if (changedModel?.[1]?.trim()) {
|
||||
return changedModel[1].trim();
|
||||
}
|
||||
}
|
||||
|
||||
const modelTag = extractTaggedContent(content, 'model')?.trim();
|
||||
return modelTag || null;
|
||||
};
|
||||
|
||||
const extractClaudeModelFromMessageContent = (content: unknown): string | null => {
|
||||
if (typeof content === 'string') {
|
||||
return extractClaudeModelFromTextContent(content);
|
||||
}
|
||||
|
||||
if (!Array.isArray(content)) {
|
||||
return null;
|
||||
}
|
||||
|
||||
for (const part of content) {
|
||||
if (!part || typeof part !== 'object' || !('text' in part) || typeof part.text !== 'string') {
|
||||
continue;
|
||||
}
|
||||
|
||||
const model = extractClaudeModelFromTextContent(part.text);
|
||||
if (model) {
|
||||
return model;
|
||||
}
|
||||
}
|
||||
|
||||
return null;
|
||||
};
|
||||
|
||||
const readClaudeSessionModelFromJsonl = async (
|
||||
sessionId: string,
|
||||
jsonlPath: string,
|
||||
): Promise<ProviderCurrentActiveModel | null> => {
|
||||
const content = await readFile(jsonlPath, 'utf8');
|
||||
const lines = content
|
||||
.split(/\r?\n/)
|
||||
.map((line) => line.trim())
|
||||
.filter(Boolean);
|
||||
|
||||
for (let index = lines.length - 1; index >= 0; index -= 1) {
|
||||
try {
|
||||
const event = JSON.parse(lines[index]) as ClaudeInitEvent;
|
||||
const model = extractClaudeEventModel(event, sessionId);
|
||||
if (model) {
|
||||
return { model };
|
||||
}
|
||||
}, CLAUDE_ACTIVE_MODEL_TIMEOUT_MS);
|
||||
} catch {
|
||||
// Skip malformed JSONL lines that can happen during concurrent writes.
|
||||
}
|
||||
}
|
||||
|
||||
const finish = (error: Error | null, result: ProviderCurrentActiveModel | null) => {
|
||||
if (settled) {
|
||||
return;
|
||||
}
|
||||
|
||||
settled = true;
|
||||
clearTimeout(timer);
|
||||
|
||||
if (error) {
|
||||
reject(error);
|
||||
return;
|
||||
}
|
||||
|
||||
resolve(result);
|
||||
};
|
||||
|
||||
child.stdout?.on('data', (chunk: Buffer) => {
|
||||
stdout += chunk.toString();
|
||||
});
|
||||
|
||||
child.stderr?.on('data', (chunk: Buffer) => {
|
||||
stderr += chunk.toString();
|
||||
});
|
||||
|
||||
child.on('error', (error) => {
|
||||
finish(error instanceof Error ? error : new Error(String(error)), null);
|
||||
});
|
||||
|
||||
child.on('close', () => {
|
||||
const lines = `${stdout}\n${stderr}`
|
||||
.split(/\r?\n/)
|
||||
.map((line) => line.trim())
|
||||
.filter(Boolean);
|
||||
|
||||
for (const line of lines) {
|
||||
try {
|
||||
const event = JSON.parse(line) as ClaudeInitEvent;
|
||||
if (event.type === 'system' && event.subtype === 'init' && event.model) {
|
||||
finish(null, {
|
||||
model: event.model,
|
||||
});
|
||||
return;
|
||||
}
|
||||
} catch {
|
||||
// The Claude CLI mixes non-JSON lines into verbose output; ignore them.
|
||||
}
|
||||
}
|
||||
|
||||
finish(null, null);
|
||||
});
|
||||
});
|
||||
return null;
|
||||
};
|
||||
|
||||
export class ClaudeProviderModels implements IProviderModels {
|
||||
@@ -161,7 +188,7 @@ export class ClaudeProviderModels implements IProviderModels {
|
||||
// instance, so we create a lightweight query and immediately close it
|
||||
// after reading the control-plane metadata.
|
||||
queryInstance = query({
|
||||
prompt: '',
|
||||
prompt: 'Get supported models',
|
||||
options: buildClaudeQueryOptions(),
|
||||
});
|
||||
|
||||
@@ -181,7 +208,10 @@ export class ClaudeProviderModels implements IProviderModels {
|
||||
}
|
||||
|
||||
try {
|
||||
const activeModel = await runClaudeSessionModelCommand(sessionId);
|
||||
const jsonlPath = sessionsDb.getSessionById(sessionId)?.jsonl_path;
|
||||
const activeModel = jsonlPath
|
||||
? await readClaudeSessionModelFromJsonl(sessionId, jsonlPath)
|
||||
: null;
|
||||
if (activeModel?.model) {
|
||||
return activeModel;
|
||||
}
|
||||
@@ -191,4 +221,10 @@ export class ClaudeProviderModels implements IProviderModels {
|
||||
|
||||
return buildDefaultProviderCurrentActiveModel(await this.getSupportedModels());
|
||||
}
|
||||
|
||||
async changeActiveModel(
|
||||
input: ProviderChangeActiveModelInput,
|
||||
): Promise<ProviderSessionActiveModelChange> {
|
||||
return writeProviderSessionActiveModelChange('claude', input);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -6,14 +6,17 @@ import TOML from '@iarna/toml';
|
||||
|
||||
import type { IProviderModels } from '@/shared/interfaces.js';
|
||||
import type {
|
||||
ProviderChangeActiveModelInput,
|
||||
ProviderCurrentActiveModel,
|
||||
ProviderModelOption,
|
||||
ProviderModelsDefinition,
|
||||
ProviderSessionActiveModelChange,
|
||||
} from '@/shared/types.js';
|
||||
import {
|
||||
buildDefaultProviderCurrentActiveModel,
|
||||
readObjectRecord,
|
||||
readOptionalString,
|
||||
writeProviderSessionActiveModelChange,
|
||||
} from '@/shared/utils.js';
|
||||
|
||||
export const CODEX_FALLBACK_MODELS: ProviderModelsDefinition = {
|
||||
@@ -113,4 +116,10 @@ export class CodexProviderModels implements IProviderModels {
|
||||
return buildDefaultProviderCurrentActiveModel(await this.getSupportedModels());
|
||||
}
|
||||
}
|
||||
|
||||
async changeActiveModel(
|
||||
input: ProviderChangeActiveModelInput,
|
||||
): Promise<ProviderSessionActiveModelChange> {
|
||||
return writeProviderSessionActiveModelChange('codex', input);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -7,13 +7,16 @@ import crossSpawn from 'cross-spawn';
|
||||
|
||||
import type { IProviderModels } from '@/shared/interfaces.js';
|
||||
import type {
|
||||
ProviderChangeActiveModelInput,
|
||||
ProviderCurrentActiveModel,
|
||||
ProviderModelOption,
|
||||
ProviderModelsDefinition,
|
||||
ProviderSessionActiveModelChange,
|
||||
} from '@/shared/types.js';
|
||||
import {
|
||||
buildDefaultProviderCurrentActiveModel,
|
||||
sanitizeLeafDirectoryName,
|
||||
writeProviderSessionActiveModelChange,
|
||||
} from '@/shared/utils.js';
|
||||
|
||||
export const CURSOR_FALLBACK_MODELS: ProviderModelsDefinition = {
|
||||
@@ -257,4 +260,10 @@ export class CursorProviderModels implements IProviderModels {
|
||||
|
||||
return buildDefaultProviderCurrentActiveModel(await this.getSupportedModels());
|
||||
}
|
||||
|
||||
async changeActiveModel(
|
||||
input: ProviderChangeActiveModelInput,
|
||||
): Promise<ProviderSessionActiveModelChange> {
|
||||
return writeProviderSessionActiveModelChange('cursor', input);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,6 +1,14 @@
|
||||
import type { IProviderModels } from '@/shared/interfaces.js';
|
||||
import type { ProviderCurrentActiveModel, ProviderModelsDefinition } from '@/shared/types.js';
|
||||
import { buildDefaultProviderCurrentActiveModel } from '@/shared/utils.js';
|
||||
import type {
|
||||
ProviderChangeActiveModelInput,
|
||||
ProviderCurrentActiveModel,
|
||||
ProviderModelsDefinition,
|
||||
ProviderSessionActiveModelChange,
|
||||
} from '@/shared/types.js';
|
||||
import {
|
||||
buildDefaultProviderCurrentActiveModel,
|
||||
writeProviderSessionActiveModelChange,
|
||||
} from '@/shared/utils.js';
|
||||
|
||||
export const GEMINI_FALLBACK_MODELS: ProviderModelsDefinition = {
|
||||
OPTIONS: [
|
||||
@@ -25,4 +33,10 @@ export class GeminiProviderModels implements IProviderModels {
|
||||
async getCurrentActiveModel(): Promise<ProviderCurrentActiveModel> {
|
||||
return buildDefaultProviderCurrentActiveModel(GEMINI_FALLBACK_MODELS);
|
||||
}
|
||||
|
||||
async changeActiveModel(
|
||||
input: ProviderChangeActiveModelInput,
|
||||
): Promise<ProviderSessionActiveModelChange> {
|
||||
return writeProviderSessionActiveModelChange('gemini', input);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -5,15 +5,18 @@ import crossSpawn from 'cross-spawn';
|
||||
|
||||
import type { IProviderModels } from '@/shared/interfaces.js';
|
||||
import type {
|
||||
ProviderChangeActiveModelInput,
|
||||
ProviderCurrentActiveModel,
|
||||
ProviderModelOption,
|
||||
ProviderModelsDefinition,
|
||||
ProviderSessionActiveModelChange,
|
||||
} from '@/shared/types.js';
|
||||
import {
|
||||
buildDefaultProviderCurrentActiveModel,
|
||||
getOpenCodeDatabasePath,
|
||||
readObjectRecord,
|
||||
readOptionalString,
|
||||
writeProviderSessionActiveModelChange,
|
||||
} from '@/shared/utils.js';
|
||||
|
||||
export const OPENCODE_FALLBACK_MODELS: ProviderModelsDefinition = {
|
||||
@@ -220,4 +223,10 @@ export class OpenCodeProviderModels implements IProviderModels {
|
||||
|
||||
return buildDefaultProviderCurrentActiveModel(await this.getSupportedModels());
|
||||
}
|
||||
|
||||
async changeActiveModel(
|
||||
input: ProviderChangeActiveModelInput,
|
||||
): Promise<ProviderSessionActiveModelChange> {
|
||||
return writeProviderSessionActiveModelChange('opencode', input);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -6,7 +6,13 @@ import { providerModelsService } from '@/modules/providers/services/provider-mod
|
||||
import { providerSkillsService } from '@/modules/providers/services/skills.service.js';
|
||||
import { sessionConversationsSearchService } from '@/modules/providers/services/session-conversations-search.service.js';
|
||||
import { sessionsService } from '@/modules/providers/services/sessions.service.js';
|
||||
import type { LLMProvider, McpScope, McpTransport, UpsertProviderMcpServerInput } from '@/shared/types.js';
|
||||
import type {
|
||||
LLMProvider,
|
||||
McpScope,
|
||||
McpTransport,
|
||||
ProviderChangeActiveModelInput,
|
||||
UpsertProviderMcpServerInput,
|
||||
} from '@/shared/types.js';
|
||||
import { AppError, asyncHandler, createApiSuccessResponse } from '@/shared/utils.js';
|
||||
|
||||
const router = express.Router();
|
||||
@@ -246,6 +252,29 @@ const parseSessionSearchLimit = (value: unknown): number => {
|
||||
return Math.max(1, Math.min(parsed, 100));
|
||||
};
|
||||
|
||||
const parseChangeActiveModelPayload = (payload: unknown): ProviderChangeActiveModelInput => {
|
||||
if (!payload || typeof payload !== 'object') {
|
||||
throw new AppError('Request body must be an object.', {
|
||||
code: 'INVALID_REQUEST_BODY',
|
||||
statusCode: 400,
|
||||
});
|
||||
}
|
||||
|
||||
const body = payload as Record<string, unknown>;
|
||||
const model = readOptionalQueryString(body.model);
|
||||
if (!model) {
|
||||
throw new AppError('model is required.', {
|
||||
code: 'MODEL_REQUIRED',
|
||||
statusCode: 400,
|
||||
});
|
||||
}
|
||||
|
||||
return {
|
||||
sessionId: '',
|
||||
model,
|
||||
};
|
||||
};
|
||||
|
||||
router.get(
|
||||
'/:provider/auth/status',
|
||||
asyncHandler(async (req: Request, res: Response) => {
|
||||
@@ -265,6 +294,20 @@ router.get(
|
||||
}),
|
||||
);
|
||||
|
||||
router.post(
|
||||
'/:provider/sessions/:sessionId/active-model',
|
||||
asyncHandler(async (req: Request, res: Response) => {
|
||||
const provider = parseProvider(req.params.provider);
|
||||
const sessionId = parseSessionId(req.params.sessionId);
|
||||
const payload = parseChangeActiveModelPayload(req.body);
|
||||
const result = await providerModelsService.changeActiveModel(provider, {
|
||||
...payload,
|
||||
sessionId,
|
||||
});
|
||||
res.json(createApiSuccessResponse(result));
|
||||
}),
|
||||
);
|
||||
|
||||
// ----------------- Skills routes -----------------
|
||||
router.get(
|
||||
'/:provider/skills',
|
||||
|
||||
@@ -6,11 +6,14 @@ import { providerRegistry } from '@/modules/providers/provider.registry.js';
|
||||
import type { IProvider } from '@/shared/interfaces.js';
|
||||
import type {
|
||||
LLMProvider,
|
||||
ProviderChangeActiveModelInput,
|
||||
ProviderCurrentActiveModel,
|
||||
ProviderModelsCacheInfo,
|
||||
ProviderModelsDefinition,
|
||||
ProviderModelsResult,
|
||||
ProviderSessionActiveModelChange,
|
||||
} from '@/shared/types.js';
|
||||
import { readProviderSessionActiveModelChange } from '@/shared/utils.js';
|
||||
|
||||
export const PROVIDER_MODELS_CACHE_TTL_MS = 3 * 24 * 60 * 60 * 1000;
|
||||
const PROVIDER_MODELS_CACHE_VERSION = 1;
|
||||
@@ -18,6 +21,7 @@ const PROVIDER_MODELS_CACHE_VERSION = 1;
|
||||
type ProviderModelsServiceDependencies = {
|
||||
resolveProvider?: (provider: LLMProvider) => Pick<IProvider, 'models'>;
|
||||
cachePath?: string;
|
||||
activeModelChangesPath?: string;
|
||||
now?: () => number;
|
||||
};
|
||||
|
||||
@@ -132,6 +136,7 @@ const writeProviderModelsCacheFile = async (
|
||||
export const createProviderModelsService = (dependencies: ProviderModelsServiceDependencies = {}) => {
|
||||
const resolveProvider = dependencies.resolveProvider ?? providerRegistry.resolveProvider;
|
||||
const cachePath = dependencies.cachePath ?? getProviderModelsCachePath();
|
||||
const activeModelChangesPath = dependencies.activeModelChangesPath;
|
||||
const now = dependencies.now ?? (() => Date.now());
|
||||
const memoryCache = new Map<LLMProvider, ProviderModelsCacheEntry>();
|
||||
const pendingRequests = new Map<LLMProvider, Promise<ProviderModelsResult>>();
|
||||
@@ -270,6 +275,36 @@ export const createProviderModelsService = (dependencies: ProviderModelsServiceD
|
||||
sessionId?: string,
|
||||
): Promise<ProviderCurrentActiveModel> => resolveProvider(provider).models.getCurrentActiveModel(sessionId);
|
||||
|
||||
const changeActiveModel = async (
|
||||
provider: LLMProvider,
|
||||
input: ProviderChangeActiveModelInput,
|
||||
): Promise<ProviderSessionActiveModelChange> => resolveProvider(provider).models.changeActiveModel(input);
|
||||
|
||||
const getChangedActiveModel = async (
|
||||
provider: LLMProvider,
|
||||
sessionId: string,
|
||||
): Promise<ProviderSessionActiveModelChange> => readProviderSessionActiveModelChange(provider, sessionId, {
|
||||
filePath: activeModelChangesPath,
|
||||
});
|
||||
|
||||
const resolveResumeModel = async (
|
||||
provider: LLMProvider,
|
||||
sessionId: string | undefined,
|
||||
requestedModel?: string | null,
|
||||
): Promise<string | undefined> => {
|
||||
const normalizedRequestedModel = typeof requestedModel === 'string' ? requestedModel.trim() : '';
|
||||
if (!sessionId?.trim()) {
|
||||
return normalizedRequestedModel || undefined;
|
||||
}
|
||||
|
||||
const changedModel = await getChangedActiveModel(provider, sessionId);
|
||||
if (changedModel.supported && changedModel.changed && changedModel.model?.trim()) {
|
||||
return changedModel.model.trim();
|
||||
}
|
||||
|
||||
return normalizedRequestedModel || undefined;
|
||||
};
|
||||
|
||||
const clearCache = (): void => {
|
||||
memoryCache.clear();
|
||||
pendingRequests.clear();
|
||||
@@ -280,6 +315,9 @@ export const createProviderModelsService = (dependencies: ProviderModelsServiceD
|
||||
return {
|
||||
getProviderModels,
|
||||
getCurrentActiveModel,
|
||||
getChangedActiveModel,
|
||||
changeActiveModel,
|
||||
resolveResumeModel,
|
||||
clearCache,
|
||||
};
|
||||
};
|
||||
|
||||
@@ -9,10 +9,13 @@ import {
|
||||
PROVIDER_MODELS_CACHE_TTL_MS,
|
||||
} from '@/modules/providers/services/provider-models.service.js';
|
||||
import type {
|
||||
ProviderChangeActiveModelInput,
|
||||
LLMProvider,
|
||||
ProviderCurrentActiveModel,
|
||||
ProviderModelsDefinition,
|
||||
ProviderSessionActiveModelChange,
|
||||
} from '@/shared/types.js';
|
||||
import { writeProviderSessionActiveModelChange } from '@/shared/utils.js';
|
||||
|
||||
const createModels = (value: string): ProviderModelsDefinition => ({
|
||||
OPTIONS: [{ value, label: value }],
|
||||
@@ -23,6 +26,17 @@ const createCurrentActiveModel = (model: string): ProviderCurrentActiveModel =>
|
||||
model,
|
||||
});
|
||||
|
||||
const createSessionActiveModelChange = (
|
||||
provider: LLMProvider,
|
||||
input: ProviderChangeActiveModelInput,
|
||||
): ProviderSessionActiveModelChange => ({
|
||||
provider,
|
||||
sessionId: input.sessionId,
|
||||
supported: true,
|
||||
changed: true,
|
||||
model: input.model,
|
||||
});
|
||||
|
||||
const createEphemeralCachePath = (): string => path.join(
|
||||
os.tmpdir(),
|
||||
`provider-model-cache-${process.pid}-${Date.now()}-${Math.random().toString(16).slice(2)}.json`,
|
||||
@@ -38,6 +52,7 @@ test('provider models service delegates to the resolved provider model adapter',
|
||||
models: {
|
||||
getSupportedModels: async () => createModels(`${provider}-models`),
|
||||
getCurrentActiveModel: async () => createCurrentActiveModel(`${provider}-active`),
|
||||
changeActiveModel: async (input) => createSessionActiveModelChange(provider, input),
|
||||
},
|
||||
};
|
||||
},
|
||||
@@ -65,6 +80,7 @@ test('provider models service returns each provider adapter result without rewri
|
||||
models: {
|
||||
getSupportedModels: async () => expectedModels,
|
||||
getCurrentActiveModel: async () => createCurrentActiveModel('cursor-active'),
|
||||
changeActiveModel: async (input) => createSessionActiveModelChange('cursor', input),
|
||||
},
|
||||
}),
|
||||
});
|
||||
@@ -90,6 +106,7 @@ test('provider models are cached for the three-day ttl', async () => {
|
||||
return createModels(`${provider}-${loadCount}`);
|
||||
},
|
||||
getCurrentActiveModel: async () => createCurrentActiveModel(`${provider}-active`),
|
||||
changeActiveModel: async (input) => createSessionActiveModelChange(provider, input),
|
||||
},
|
||||
}),
|
||||
});
|
||||
@@ -124,6 +141,7 @@ test('provider model cache is persisted across service instances', async () => {
|
||||
models: {
|
||||
getSupportedModels: async () => createModels('gemini-cached'),
|
||||
getCurrentActiveModel: async () => createCurrentActiveModel('gemini-active'),
|
||||
changeActiveModel: async (input) => createSessionActiveModelChange('gemini', input),
|
||||
},
|
||||
}),
|
||||
});
|
||||
@@ -137,6 +155,7 @@ test('provider model cache is persisted across service instances', async () => {
|
||||
throw new Error('loader should not be called for persisted cache hits');
|
||||
},
|
||||
getCurrentActiveModel: async () => createCurrentActiveModel('gemini-active'),
|
||||
changeActiveModel: async (input) => createSessionActiveModelChange('gemini', input),
|
||||
},
|
||||
}),
|
||||
});
|
||||
@@ -163,6 +182,7 @@ test('concurrent provider model requests share one load operation', async () =>
|
||||
return createModels('claude-cached');
|
||||
},
|
||||
getCurrentActiveModel: async () => createCurrentActiveModel('claude-active'),
|
||||
changeActiveModel: async (input) => createSessionActiveModelChange('claude', input),
|
||||
},
|
||||
}),
|
||||
});
|
||||
@@ -196,6 +216,7 @@ test('bypassCache forces a fresh provider fetch and updates cache metadata', asy
|
||||
return createModels(`${provider}-${loadCount}`);
|
||||
},
|
||||
getCurrentActiveModel: async () => createCurrentActiveModel(`${provider}-active-${loadCount}`),
|
||||
changeActiveModel: async (input) => createSessionActiveModelChange(provider, input),
|
||||
},
|
||||
}),
|
||||
});
|
||||
@@ -224,6 +245,7 @@ test('provider models service delegates current active model lookups to the prov
|
||||
calls.push({ provider, sessionId });
|
||||
return createCurrentActiveModel(`${provider}-${sessionId}`);
|
||||
},
|
||||
changeActiveModel: async (input) => createSessionActiveModelChange(provider, input),
|
||||
},
|
||||
}),
|
||||
});
|
||||
@@ -233,3 +255,64 @@ test('provider models service delegates current active model lookups to the prov
|
||||
assert.deepEqual(calls, [{ provider: 'opencode', sessionId: 'session-123' }]);
|
||||
assert.equal(activeModel.model, 'opencode-session-123');
|
||||
});
|
||||
|
||||
test('provider models service delegates active model change requests to the provider adapter', async () => {
|
||||
const calls: Array<{ provider: LLMProvider; input: ProviderChangeActiveModelInput }> = [];
|
||||
const service = createProviderModelsService({
|
||||
resolveProvider: (provider) => ({
|
||||
models: {
|
||||
getSupportedModels: async () => createModels(`${provider}-models`),
|
||||
getCurrentActiveModel: async () => createCurrentActiveModel(`${provider}-active`),
|
||||
changeActiveModel: async (input) => {
|
||||
calls.push({ provider, input });
|
||||
return createSessionActiveModelChange(provider, input);
|
||||
},
|
||||
},
|
||||
}),
|
||||
});
|
||||
|
||||
const changedModel = await service.changeActiveModel('claude', {
|
||||
sessionId: 'session-123',
|
||||
model: 'opus',
|
||||
});
|
||||
|
||||
assert.deepEqual(calls, [{
|
||||
provider: 'claude',
|
||||
input: {
|
||||
sessionId: 'session-123',
|
||||
model: 'opus',
|
||||
},
|
||||
}]);
|
||||
assert.equal(changedModel.changed, true);
|
||||
assert.equal(changedModel.model, 'opus');
|
||||
});
|
||||
|
||||
test('resolveResumeModel prefers a stored changed model over the requested one', async () => {
|
||||
const tempRoot = await mkdtemp(path.join(os.tmpdir(), 'provider-model-change-'));
|
||||
const activeModelChangesPath = path.join(tempRoot, 'session-model-changes.json');
|
||||
|
||||
try {
|
||||
const service = createProviderModelsService({
|
||||
activeModelChangesPath,
|
||||
resolveProvider: (provider) => ({
|
||||
models: {
|
||||
getSupportedModels: async () => createModels(`${provider}-models`),
|
||||
getCurrentActiveModel: async () => createCurrentActiveModel(`${provider}-active`),
|
||||
changeActiveModel: async (input) => createSessionActiveModelChange(provider, input),
|
||||
},
|
||||
}),
|
||||
});
|
||||
|
||||
await writeProviderSessionActiveModelChange('cursor', {
|
||||
sessionId: 'session-456',
|
||||
model: 'composer-2',
|
||||
}, {
|
||||
filePath: activeModelChangesPath,
|
||||
});
|
||||
|
||||
const model = await service.resolveResumeModel('cursor', 'session-456', 'composer-2-fast');
|
||||
assert.equal(model, 'composer-2');
|
||||
} finally {
|
||||
await rm(tempRoot, { recursive: true, force: true });
|
||||
}
|
||||
});
|
||||
|
||||
Reference in New Issue
Block a user