diff --git a/src/browser/features/Settings/Sections/ModelRow.tsx b/src/browser/features/Settings/Sections/ModelRow.tsx index 4e31cb92d8..2a159d1fa4 100644 --- a/src/browser/features/Settings/Sections/ModelRow.tsx +++ b/src/browser/features/Settings/Sections/ModelRow.tsx @@ -179,6 +179,9 @@ export interface ModelRowProps { editModelValue?: string; editContextValue?: string; editMappedToModel?: string; + editMaxOutputTokensValue?: string; + editTemperatureValue?: string; + editTopPValue?: string; editAutofocus?: "model" | "context"; customContextWindowTokens?: number | null; mappedToModel?: string | null; @@ -211,6 +214,9 @@ export interface ModelRowProps { onEditModelChange?: (value: string) => void; onEditContextChange?: (value: string) => void; onEditMappedToModelChange?: (value: string) => void; + onEditMaxOutputTokensChange?: (value: string) => void; + onEditTemperatureChange?: (value: string) => void; + onEditTopPChange?: (value: string) => void; onRemove?: () => void; /** Set/clear explicit route override (null = auto) */ onSetRouteOverride?: (route: string | null) => void; @@ -332,6 +338,52 @@ export function ModelRow(props: ModelRowProps) { /> )} + {props.isCustom && ( +
+ Params +
+ props.onEditMaxOutputTokensChange?.(e.target.value)} + onKeyDown={createEditKeyHandler({ + onSave: () => props.onSaveEdit?.(), + onCancel: () => props.onCancelEdit?.(), + })} + className="bg-modal-bg border-border-medium focus:border-accent min-w-0 rounded border px-2 py-0.5 text-right font-mono text-xs focus:outline-none" + placeholder="max_output_tokens" + title="max_output_tokens" + /> + props.onEditTemperatureChange?.(e.target.value)} + onKeyDown={createEditKeyHandler({ + onSave: () => props.onSaveEdit?.(), + onCancel: () => props.onCancelEdit?.(), + })} + className="bg-modal-bg border-border-medium focus:border-accent min-w-0 rounded border px-2 py-0.5 text-right font-mono text-xs focus:outline-none" + placeholder="temperature" + title="temperature" + /> + props.onEditTopPChange?.(e.target.value)} + onKeyDown={createEditKeyHandler({ + onSave: () => props.onSaveEdit?.(), + onCancel: () => props.onCancelEdit?.(), + })} + className="bg-modal-bg border-border-medium focus:border-accent min-w-0 rounded border px-2 py-0.5 text-right font-mono text-xs focus:outline-none" + placeholder="top_p" + title="top_p" + /> +
+
+ )} {props.editError &&
{props.editError}
} diff --git a/src/browser/features/Settings/Sections/ModelsSection.test.ts b/src/browser/features/Settings/Sections/ModelsSection.test.ts index df3efe7b90..31bb227271 100644 --- a/src/browser/features/Settings/Sections/ModelsSection.test.ts +++ b/src/browser/features/Settings/Sections/ModelsSection.test.ts @@ -1,6 +1,14 @@ import { describe, expect, test } from "bun:test"; import { KNOWN_MODELS } from "@/common/constants/knownModels"; -import { shouldAllowRouteOverrideInSettings, shouldShowModelInSettings } from "./ModelsSection"; +import { + buildUpdatedModelParameters, + migrateModelParameterEntry, + parseBoundedNumberInput, + parsePositiveIntegerInput, + removeModelParameterEntry, + shouldAllowRouteOverrideInSettings, + shouldShowModelInSettings, +} from "./ModelsSection"; describe("shouldShowModelInSettings", () => { test("hides OAuth-required Codex model when OpenAI OAuth is not configured", () => { @@ -45,3 +53,125 @@ describe("shouldAllowRouteOverrideInSettings", () => { expect(shouldAllowRouteOverrideInSettings("ollama:gpt-oss:20b")).toBe(true); }); }); + +describe("model parameter edit helpers", () => { + test("parses positive integer input for max_output_tokens", () => { + expect(parsePositiveIntegerInput("42")).toBe(42); + expect(parsePositiveIntegerInput("0")).toBeNull(); + expect(parsePositiveIntegerInput("1.5")).toBeNull(); + expect(parsePositiveIntegerInput("abc")).toBeNull(); + }); + + test("parses bounded decimal input for temperature and top_p", () => { + expect(parseBoundedNumberInput("0", 0, 2)).toBe(0); + expect(parseBoundedNumberInput("1.5", 0, 2)).toBe(1.5); + expect(parseBoundedNumberInput("2", 0, 2)).toBe(2); + expect(parseBoundedNumberInput("2.1", 0, 2)).toBeNull(); + expect(parseBoundedNumberInput("-0.1", 0, 1)).toBeNull(); + }); + + test("buildUpdatedModelParameters preserves non-editable parameters when clearing editable fields", () => { + const updated = buildUpdatedModelParameters( + { + "gpt-5": { + max_output_tokens: 1024, + 
temperature: 0.7, + top_k: 42, + }, + }, + "gpt-5", + { + max_output_tokens: null, + temperature: null, + top_p: null, + } + ); + + expect(updated).toEqual({ + "gpt-5": { + top_k: 42, + }, + }); + }); + + test("buildUpdatedModelParameters updates editable fields", () => { + const withNewOverrides = buildUpdatedModelParameters(undefined, "renamed", { + max_output_tokens: 2048, + temperature: 0.5, + top_p: null, + }); + + expect(withNewOverrides).toEqual({ + renamed: { + max_output_tokens: 2048, + temperature: 0.5, + }, + }); + }); + + test("rename migration preserves non-editable overrides and removes legacy key", () => { + const migrated = migrateModelParameterEntry( + { + legacy: { + temperature: 0.2, + top_k: 42, + seed: 7, + }, + }, + "legacy", + "renamed" + ); + + const updated = buildUpdatedModelParameters(migrated, "renamed", { + max_output_tokens: null, + temperature: 0.8, + top_p: null, + }); + + expect(updated).toEqual({ + renamed: { + temperature: 0.8, + top_k: 42, + seed: 7, + }, + }); + }); + + test("rename migration merges existing destination overrides", () => { + const migrated = migrateModelParameterEntry( + { + legacy: { + temperature: 0.4, + top_k: 12, + }, + renamed: { + seed: 99, + }, + }, + "legacy", + "renamed" + ); + + expect(migrated).toEqual({ + renamed: { + seed: 99, + temperature: 0.4, + top_k: 12, + }, + }); + }); + + test("removeModelParameterEntry clears old model parameter keys", () => { + const withoutLegacy = removeModelParameterEntry( + { + legacy: { temperature: 0.2 }, + renamed: { top_p: 0.8 }, + }, + "legacy" + ); + + expect(withoutLegacy).toEqual({ + renamed: { top_p: 0.8 }, + }); + }); +}); diff --git a/src/browser/features/Settings/Sections/ModelsSection.tsx b/src/browser/features/Settings/Sections/ModelsSection.tsx index 1ae1d58206..e1eecc931a 100644 --- a/src/browser/features/Settings/Sections/ModelsSection.tsx +++ b/src/browser/features/Settings/Sections/ModelsSection.tsx @@ -57,16 +57,36 @@ function ModelsTableHeader() { 
); } +interface EditableModelParameterOverrides { + max_output_tokens: number | null; + temperature: number | null; + top_p: number | null; +} + interface EditingState { provider: string; originalModelId: string; newModelId: string; contextWindowTokens: string; mappedToModel: string; + maxOutputTokens: string; + temperature: string; + topP: string; focus?: "model" | "context"; } -function parseContextWindowTokensInput(value: string): number | null { +interface CustomModelInfo { + provider: string; + modelId: string; + fullId: string; + contextWindowTokens: number | null; + mappedToModel: string | null; + maxOutputTokens: number | null; + temperature: number | null; + topP: number | null; +} + +export function parsePositiveIntegerInput(value: string): number | null { const trimmed = value.trim(); if (trimmed.length === 0) { return null; @@ -80,6 +100,43 @@ function parseContextWindowTokensInput(value: string): number | null { return parsed; } +function parseContextWindowTokensInput(value: string): number | null { + return parsePositiveIntegerInput(value); +} + +export function parseBoundedNumberInput(value: string, min: number, max: number): number | null { + const trimmed = value.trim(); + if (trimmed.length === 0) { + return null; + } + + const parsed = Number(trimmed); + if (!Number.isFinite(parsed) || parsed < min || parsed > max) { + return null; + } + + return parsed; +} + +function getEditableModelParameterOverrides( + modelParametersByModel: Record< + string, + { max_output_tokens?: number; temperature?: number; top_p?: number } | undefined + >, + modelId: string +): EditableModelParameterOverrides { + const modelOverrides = modelParametersByModel[modelId]; + return { + max_output_tokens: + typeof modelOverrides?.max_output_tokens === "number" + ? modelOverrides.max_output_tokens + : null, + temperature: + typeof modelOverrides?.temperature === "number" ? modelOverrides.temperature : null, + top_p: typeof modelOverrides?.top_p === "number" ? 
modelOverrides.top_p : null, + }; +} + function buildProviderModelEntry( modelId: string, contextWindowTokens: number | null, @@ -100,6 +157,84 @@ function buildProviderModelEntry( return entry; } +export function buildUpdatedModelParameters( + currentModelParameters: Record> | undefined, + modelId: string, + overrides: EditableModelParameterOverrides +): Record> | undefined { + const nextModelParameters = { ...(currentModelParameters ?? {}) }; + const currentOverrides = nextModelParameters[modelId]; + const nextOverrides: Record = { + ...(typeof currentOverrides === "object" && currentOverrides !== null ? currentOverrides : {}), + }; + + if (overrides.max_output_tokens === null) { + delete nextOverrides.max_output_tokens; + } else { + nextOverrides.max_output_tokens = overrides.max_output_tokens; + } + + if (overrides.temperature === null) { + delete nextOverrides.temperature; + } else { + nextOverrides.temperature = overrides.temperature; + } + + if (overrides.top_p === null) { + delete nextOverrides.top_p; + } else { + nextOverrides.top_p = overrides.top_p; + } + + if (Object.keys(nextOverrides).length === 0) { + delete nextModelParameters[modelId]; + } else { + nextModelParameters[modelId] = nextOverrides; + } + + return Object.keys(nextModelParameters).length > 0 ? nextModelParameters : undefined; +} + +export function removeModelParameterEntry( + currentModelParameters: Record> | undefined, + modelId: string +): Record> | undefined { + if (!currentModelParameters || !(modelId in currentModelParameters)) { + return currentModelParameters; + } + + const nextModelParameters = { ...currentModelParameters }; + delete nextModelParameters[modelId]; + return Object.keys(nextModelParameters).length > 0 ? 
nextModelParameters : undefined; +} + +export function migrateModelParameterEntry( + currentModelParameters: Record> | undefined, + originalModelId: string, + nextModelId: string +): Record> | undefined { + if (!currentModelParameters || originalModelId === nextModelId) { + return currentModelParameters; + } + + const originalOverrides = currentModelParameters[originalModelId]; + if (!originalOverrides || typeof originalOverrides !== "object") { + return currentModelParameters; + } + + const destinationOverrides = currentModelParameters[nextModelId]; + const nextModelParameters = { ...currentModelParameters }; + nextModelParameters[nextModelId] = { + ...(destinationOverrides && typeof destinationOverrides === "object" + ? destinationOverrides + : {}), + ...originalOverrides, + }; + delete nextModelParameters[originalModelId]; + + return Object.keys(nextModelParameters).length > 0 ? nextModelParameters : undefined; +} + export function shouldShowModelInSettings(modelId: string, codexOauthConfigured: boolean): boolean { // OpenAI OAuth gating only applies to OpenAI-routed models; other providers can // reuse the same providerModelId string without requiring OpenAI OAuth. 
@@ -126,7 +261,8 @@ export function ModelsSection() { const { api } = useAPI(); const { open: openSettings } = useSettings(); - const { config, loading, updateModelsOptimistically } = useProvidersConfig(); + const { config, loading, refresh, updateModelsOptimistically, updateOptimistically } = + useProvidersConfig(); const [lastProvider, setLastProvider] = usePersistedState(LAST_CUSTOM_MODEL_PROVIDER_KEY, ""); const [newModelId, setNewModelId] = useState(""); const [editing, setEditing] = useState(null); @@ -218,50 +354,68 @@ export function ModelsSection() { models.filter((entry) => getProviderModelEntryId(entry) !== modelId) ); - // Save in background - void api.providers.setModels({ provider, models: updatedModels }); + const providerModelParameters = config[provider]?.modelParameters as + | Record> + | undefined; + updateOptimistically(provider, { + modelParameters: removeModelParameterEntry(providerModelParameters, modelId), + }); + + void (async () => { + const setModelsResult = await api.providers.setModels({ provider, models: updatedModels }); + if (!setModelsResult.success) { + setError(setModelsResult.error); + void refresh(); + return; + } + + const clearOverridesResult = await api.providers.setModelParameters({ + provider, + modelId, + overrides: { + max_output_tokens: null, + temperature: null, + top_p: null, + }, + }); + + if (!clearOverridesResult.success) { + setError(clearOverridesResult.error); + void refresh(); + } + })(); }, - [api, config, updateModelsOptimistically] + [api, config, refresh, updateModelsOptimistically, updateOptimistically] ); + const startModelEdit = useCallback((model: CustomModelInfo, focus: "model" | "context") => { + setEditing({ + provider: model.provider, + originalModelId: model.modelId, + newModelId: model.modelId, + contextWindowTokens: + model.contextWindowTokens === null ? "" : String(model.contextWindowTokens), + mappedToModel: model.mappedToModel ?? "", + maxOutputTokens: model.maxOutputTokens === null ? 
"" : String(model.maxOutputTokens), + temperature: model.temperature === null ? "" : String(model.temperature), + topP: model.topP === null ? "" : String(model.topP), + focus, + }); + setError(null); + }, []); + const handleStartEdit = useCallback( - ( - provider: string, - modelId: string, - contextWindowTokens: number | null, - mappedToModel: string | null - ) => { - setEditing({ - provider, - originalModelId: modelId, - newModelId: modelId, - contextWindowTokens: contextWindowTokens === null ? "" : String(contextWindowTokens), - mappedToModel: mappedToModel ?? "", - focus: "model", - }); - setError(null); + (model: CustomModelInfo) => { + startModelEdit(model, "model"); }, - [] + [startModelEdit] ); const handleStartContextEdit = useCallback( - ( - provider: string, - modelId: string, - contextWindowTokens: number | null, - mappedToModel: string | null - ) => { - setEditing({ - provider, - originalModelId: modelId, - newModelId: modelId, - contextWindowTokens: contextWindowTokens === null ? "" : String(contextWindowTokens), - mappedToModel: mappedToModel ?? 
"", - focus: "context", - }); - setError(null); + (model: CustomModelInfo) => { + startModelEdit(model, "context"); }, - [] + [startModelEdit] ); const handleCancelEdit = useCallback(() => { @@ -285,6 +439,27 @@ export function ModelsSection() { return; } + const maxOutputTokensInput = editing.maxOutputTokens.trim(); + const parsedMaxOutputTokens = parsePositiveIntegerInput(maxOutputTokensInput); + if (maxOutputTokensInput.length > 0 && parsedMaxOutputTokens === null) { + setError("Max output tokens must be a positive integer"); + return; + } + + const temperatureInput = editing.temperature.trim(); + const parsedTemperature = parseBoundedNumberInput(temperatureInput, 0, 2); + if (temperatureInput.length > 0 && parsedTemperature === null) { + setError("Temperature must be a number between 0 and 2"); + return; + } + + const topPInput = editing.topP.trim(); + const parsedTopP = parseBoundedNumberInput(topPInput, 0, 1); + if (topPInput.length > 0 && parsedTopP === null) { + setError("Top P must be a number between 0 and 1"); + return; + } + // Only validate duplicates if the model ID actually changed if (trimmedModelId !== editing.originalModelId) { if (modelExists(editing.provider, trimmedModelId)) { @@ -301,6 +476,11 @@ export function ModelsSection() { parsedContextWindowTokens, mappedTo ); + const overrides: EditableModelParameterOverrides = { + max_output_tokens: parsedMaxOutputTokens, + temperature: parsedTemperature, + top_p: parsedTopP, + }; // Optimistic update - returns new models array for API call const updatedModels = updateModelsOptimistically(editing.provider, (models) => { @@ -323,11 +503,65 @@ export function ModelsSection() { return nextModels; }); + + const providerModelParameters = config[editing.provider]?.modelParameters as + | Record> + | undefined; + let nextModelParameters = providerModelParameters; + + if (trimmedModelId !== editing.originalModelId) { + nextModelParameters = migrateModelParameterEntry( + nextModelParameters, + 
editing.originalModelId, + trimmedModelId + ); + } + + nextModelParameters = buildUpdatedModelParameters( + nextModelParameters, + trimmedModelId, + overrides + ); + + updateOptimistically(editing.provider, { + modelParameters: nextModelParameters, + }); + + const providerId = editing.provider; + const originalModelId = editing.originalModelId; setEditing(null); - // Save in background - void api.providers.setModels({ provider: editing.provider, models: updatedModels }); - }, [api, editing, config, modelExists, updateModelsOptimistically]); + void (async () => { + const setModelsResult = await api.providers.setModels({ + provider: providerId, + models: updatedModels, + }); + if (!setModelsResult.success) { + setError(setModelsResult.error); + void refresh(); + return; + } + + const setOverridesResult = await api.providers.setModelParameters({ + provider: providerId, + modelId: trimmedModelId, + renameFromModelId: trimmedModelId !== originalModelId ? originalModelId : undefined, + overrides, + }); + if (!setOverridesResult.success) { + setError(setOverridesResult.error); + void refresh(); + } + })(); + }, [ + api, + editing, + config, + modelExists, + refresh, + updateModelsOptimistically, + updateOptimistically, + ]); // Show loading state while config is being fetched if (loading || !config) { @@ -340,34 +574,32 @@ export function ModelsSection() { } // Get all custom models across providers (excluding hidden providers like mux-gateway) - const getCustomModels = (): Array<{ - provider: string; - modelId: string; - fullId: string; - contextWindowTokens: number | null; - mappedToModel: string | null; - }> => { - const models: Array<{ - provider: string; - modelId: string; - fullId: string; - contextWindowTokens: number | null; - mappedToModel: string | null; - }> = []; + const getCustomModels = (): CustomModelInfo[] => { + const models: CustomModelInfo[] = []; for (const [provider, providerConfig] of Object.entries(config)) { // Skip hidden providers (mux-gateway 
models are routed, not managed as a standalone list) if (HIDDEN_PROVIDERS.has(provider)) continue; if (!providerConfig.models) continue; + const modelParametersByModel = + (providerConfig.modelParameters as Record< + string, + { max_output_tokens?: number; temperature?: number; top_p?: number } | undefined + > | null) ?? {}; + for (const modelEntry of providerConfig.models) { const modelId = getProviderModelEntryId(modelEntry); + const modelOverrides = getEditableModelParameterOverrides(modelParametersByModel, modelId); models.push({ provider, modelId, fullId: `${provider}:${modelId}`, contextWindowTokens: getProviderModelEntryContextWindowTokens(modelEntry), mappedToModel: getProviderModelEntryMappedTo(modelEntry), + maxOutputTokens: modelOverrides.max_output_tokens, + temperature: modelOverrides.temperature, + topP: modelOverrides.top_p, }); } } @@ -470,6 +702,11 @@ export function ModelsSection() { editModelValue={isModelEditing ? editing.newModelId : undefined} editContextValue={isModelEditing ? editing.contextWindowTokens : undefined} editMappedToModel={isModelEditing ? editing.mappedToModel : undefined} + editMaxOutputTokensValue={ + isModelEditing ? editing.maxOutputTokens : undefined + } + editTemperatureValue={isModelEditing ? editing.temperature : undefined} + editTopPValue={isModelEditing ? editing.topP : undefined} editAutofocus={isModelEditing ? 
editing.focus : undefined} customContextWindowTokens={model.contextWindowTokens} allModels={knownModelIds} @@ -481,22 +718,8 @@ export function ModelsSection() { availableRoutes={routing.availableRoutes(model.fullId)} is1MContextEnabled={has1MContext(model.fullId)} onSetDefault={() => setDefaultModel(model.fullId)} - onStartEdit={() => - handleStartEdit( - model.provider, - model.modelId, - model.contextWindowTokens, - model.mappedToModel - ) - } - onStartContextEdit={() => - handleStartContextEdit( - model.provider, - model.modelId, - model.contextWindowTokens, - model.mappedToModel - ) - } + onStartEdit={() => handleStartEdit(model)} + onStartContextEdit={() => handleStartContextEdit(model)} onSaveEdit={handleSaveEdit} onCancelEdit={handleCancelEdit} onEditModelChange={(value) => @@ -510,6 +733,15 @@ export function ModelsSection() { onEditMappedToModelChange={(value) => setEditing((prev) => (prev ? { ...prev, mappedToModel: value } : null)) } + onEditMaxOutputTokensChange={(value) => + setEditing((prev) => (prev ? { ...prev, maxOutputTokens: value } : null)) + } + onEditTemperatureChange={(value) => + setEditing((prev) => (prev ? { ...prev, temperature: value } : null)) + } + onEditTopPChange={(value) => + setEditing((prev) => (prev ? 
{ ...prev, topP: value } : null)) + } onRemove={() => handleRemoveModel(model.provider, model.modelId)} isHiddenFromSelector={hiddenModels.includes(model.fullId)} onToggleVisibility={() => diff --git a/src/common/orpc/schemas/api.ts b/src/common/orpc/schemas/api.ts index 8e20fd0d8a..d87288715b 100644 --- a/src/common/orpc/schemas/api.ts +++ b/src/common/orpc/schemas/api.ts @@ -110,6 +110,10 @@ import { CodexOauthDefaultAuthSchema, ServiceTierSchema, } from "../../config/schemas/providersConfig"; +import { + ModelParametersByModelSchema, + StandardModelParameterOverridesSchema, +} from "../../config/schemas/modelParameters"; import { ProviderModelEntrySchema } from "../../config/schemas/providerModelEntry"; import { TaskSettingsSchema } from "../../config/schemas/taskSettings"; import { ThinkingLevelSchema } from "../../types/thinking"; @@ -226,6 +230,7 @@ export const ProviderConfigInfoSchema = z.object({ displayName: z.string().optional(), isCustom: z.boolean().optional(), models: z.array(ProviderModelEntrySchema).optional(), + modelParameters: ModelParametersByModelSchema.optional(), /** OpenAI-specific fields */ serviceTier: ServiceTierSchema.optional(), wireFormat: z.enum(["responses", "chatCompletions"]).optional(), @@ -299,6 +304,16 @@ export const CustomProviderMutationErrorSchema = z.discriminatedUnion("code", [ }), ]); +const EditableCustomModelParameterOverridesSchema = StandardModelParameterOverridesSchema.pick({ + max_output_tokens: true, + temperature: true, + top_p: true, +}).extend({ + max_output_tokens: z.number().int().positive().nullish(), + temperature: z.number().min(0).max(2).nullish(), + top_p: z.number().min(0).max(1).nullish(), +}); + export const providers = { addCustomOpenAICompatibleProvider: { input: z.object({ @@ -342,6 +357,15 @@ export const providers = { }), output: ResultSchema(z.void(), z.string()), }, + setModelParameters: { + input: z.object({ + provider: z.string(), + modelId: z.string().min(1), + renameFromModelId: 
z.string().min(1).optional(), + overrides: EditableCustomModelParameterOverridesSchema, + }), + output: ResultSchema(z.void(), z.string()), + }, list: { input: z.void(), output: z.array(z.string()), diff --git a/src/node/orpc/router.ts b/src/node/orpc/router.ts index 6e3fbf5e66..5f522b0595 100644 --- a/src/node/orpc/router.ts +++ b/src/node/orpc/router.ts @@ -1948,6 +1948,17 @@ export const router = (authToken?: string) => { .handler(({ context, input }) => context.providerService.setModels(input.provider, input.models) ), + setModelParameters: t + .input(schemas.providers.setModelParameters.input) + .output(schemas.providers.setModelParameters.output) + .handler(({ context, input }) => + context.providerService.setModelParameters( + input.provider, + input.modelId, + input.overrides, + input.renameFromModelId + ) + ), onConfigChanged: t .input(schemas.providers.onConfigChanged.input) .output(schemas.providers.onConfigChanged.output) diff --git a/src/node/services/providerService.test.ts b/src/node/services/providerService.test.ts index 2259bc95e3..c3a43fcb74 100644 --- a/src/node/services/providerService.test.ts +++ b/src/node/services/providerService.test.ts @@ -628,6 +628,185 @@ describe("ProviderService model normalization", () => { }); }); +describe("ProviderService model parameter overrides", () => { + it("surfaces modelParameters in getConfig", () => { + withTempConfig((config, service) => { + saveOpenAIConfig(config, { + modelParameters: { + "gpt-5": { + max_output_tokens: 2048, + temperature: 0.4, + top_p: 0.9, + }, + }, + }); + + const providers = service.getConfig(); + + expect(providers.openai.modelParameters).toEqual({ + "gpt-5": { + max_output_tokens: 2048, + temperature: 0.4, + top_p: 0.9, + }, + }); + }); + }); + + it("sets model parameters while preserving unrelated model parameter keys", () => { + withTempConfig((config, service) => { + saveOpenAIConfig(config, { + modelParameters: { + "gpt-5": { + max_output_tokens: 1024, + top_k: 32, + }, + }, + 
}); + + const result = service.setModelParameters("openai", "gpt-5", { + max_output_tokens: null, + temperature: 0.7, + top_p: 0.95, + }); + + expect(result.success).toBe(true); + const providersConfig = config.loadProvidersConfig(); + expect(providersConfig?.openai?.modelParameters).toEqual({ + "gpt-5": { + temperature: 0.7, + top_p: 0.95, + top_k: 32, + }, + }); + }); + }); + + it("clears editable keys without dropping non-editable model parameters", () => { + withTempConfig((config, service) => { + saveOpenAIConfig(config, { + models: ["gpt-5"], + modelParameters: { + "gpt-5": { + temperature: 0.3, + top_k: 16, + }, + }, + }); + + const result = service.setModelParameters("openai", "gpt-5", { + max_output_tokens: null, + temperature: null, + top_p: null, + }); + + expect(result.success).toBe(true); + const providersConfig = config.loadProvidersConfig(); + expect(providersConfig?.openai?.modelParameters).toEqual({ + "gpt-5": { + top_k: 16, + }, + }); + }); + }); + + it("retains valid model parameter entries when another entry is malformed", () => { + withTempConfig((config, service) => { + saveOpenAIConfig(config, { + modelParameters: { + "*": { + top_p: 0.8, + custom_passthrough: true, + }, + "bad-model": { + temperature: 3, + }, + "gpt-5": { + temperature: 0.4, + }, + }, + }); + + const result = service.setModelParameters("openai", "gpt-5", { + max_output_tokens: 2048, + temperature: null, + top_p: null, + }); + + expect(result.success).toBe(true); + const providersConfig = config.loadProvidersConfig(); + expect(providersConfig?.openai?.modelParameters).toEqual({ + "*": { + top_p: 0.8, + custom_passthrough: true, + }, + "gpt-5": { + max_output_tokens: 2048, + }, + }); + }); + }); + + it("renames model parameter entries while preserving non-editable overrides", () => { + withTempConfig((config, service) => { + saveOpenAIConfig(config, { + models: ["renamed"], + modelParameters: { + legacy: { + temperature: 0.2, + top_k: 32, + seed: 7, + }, + }, + }); + + const 
renameResult = service.setModelParameters( + "openai", + "renamed", + { + max_output_tokens: 4096, + temperature: 0.5, + top_p: null, + }, + "legacy" + ); + expect(renameResult.success).toBe(true); + + const providersConfig = config.loadProvidersConfig(); + expect(providersConfig?.openai?.modelParameters).toEqual({ + renamed: { + max_output_tokens: 4096, + temperature: 0.5, + top_k: 32, + seed: 7, + }, + }); + }); + }); + + it("removes modelParameters when the last override entry is cleared", () => { + withTempConfig((config, service) => { + saveOpenAIConfig(config, { + modelParameters: { + "gpt-5": { + max_output_tokens: 1024, + }, + }, + }); + + const result = service.setModelParameters("openai", "gpt-5", { + max_output_tokens: null, + temperature: null, + top_p: null, + }); + + expect(result.success).toBe(true); + const providersConfig = config.loadProvidersConfig(); + expect(providersConfig?.openai?.modelParameters).toBeUndefined(); + }); + }); +}); + describe("ProviderService custom provider mutations", () => { it("rejects adding a built-in provider id", () => { withTempConfig((config, service) => { diff --git a/src/node/services/providerService.ts b/src/node/services/providerService.ts index 900e3b9546..1ee3b44292 100644 --- a/src/node/services/providerService.ts +++ b/src/node/services/providerService.ts @@ -6,6 +6,10 @@ import { type ProviderName, } from "@/common/constants/providers"; import type { BaseProviderConfig } from "@/common/config/schemas/providersConfig"; +import { + ModelParameterOverridesSchema, + ModelParametersByModelSchema, +} from "@/common/config/schemas/modelParameters"; import type { Result } from "@/common/types/result"; import type { AddCustomOpenAICompatibleProviderInput, @@ -60,14 +64,96 @@ function filterProviderModelsByPolicy( return models.filter((entry) => allowedModels.includes(getProviderModelEntryId(entry))); } +type EditableModelParameterOverrides = { + max_output_tokens?: number | null; + temperature?: number | null; + 
top_p?: number | null; +}; + +const MODEL_PARAMETER_OVERRIDE_KEYS = ["max_output_tokens", "temperature", "top_p"] as const; + +type ModelParameterOverrideKey = (typeof MODEL_PARAMETER_OVERRIDE_KEYS)[number]; + +function normalizeModelParameters( + modelParameters: unknown +): BaseProviderConfig["modelParameters"] | undefined { + const parsed = ModelParametersByModelSchema.safeParse(modelParameters); + if (parsed.success) { + return parsed.data; + } + + if ( + typeof modelParameters !== "object" || + modelParameters === null || + Array.isArray(modelParameters) + ) { + return undefined; + } + + const recoveredEntries: Array<[string, Record]> = []; + for (const [modelId, overrides] of Object.entries(modelParameters)) { + if (modelId.trim().length === 0) { + continue; + } + + const parsedOverrides = ModelParameterOverridesSchema.safeParse(overrides); + if (!parsedOverrides.success) { + continue; + } + + recoveredEntries.push([modelId, parsedOverrides.data]); + } + + if (recoveredEntries.length === 0) { + return undefined; + } + + return Object.fromEntries(recoveredEntries); +} + +function filterModelParametersByPolicy( + modelParameters: BaseProviderConfig["modelParameters"] | undefined, + allowedModels: string[] | null +): BaseProviderConfig["modelParameters"] | undefined { + if (!modelParameters) { + return undefined; + } + + if (!Array.isArray(allowedModels)) { + return modelParameters; + } + + const filteredEntries = Object.entries(modelParameters).filter( + ([modelId]) => modelId === "*" || allowedModels.includes(modelId) + ); + + if (filteredEntries.length === 0) { + return undefined; + } + + return Object.fromEntries(filteredEntries); +} + +function hasAnyEditableModelParameterOverride(overrides: EditableModelParameterOverrides): boolean { + return MODEL_PARAMETER_OVERRIDE_KEYS.some((key) => { + const value = overrides[key as ModelParameterOverrideKey]; + return typeof value === "number"; + }); +} + function buildCustomProviderConfigInfo( config: 
BaseProviderConfig, policy?: { forcedBaseUrl?: string; allowedModels?: string[] | null } ): ProviderConfigInfo { const baseUrl = policy?.forcedBaseUrl ?? resolveConfigBaseUrl(config); + const allowedModels = policy?.allowedModels ?? null; const models = filterProviderModelsByPolicy( normalizeProviderModelEntries(config.models), - policy?.allowedModels ?? null + allowedModels + ); + const modelParameters = filterModelParametersByPolicy( + normalizeModelParameters(config.modelParameters), + allowedModels ); const apiKeyIsOpRef = isOpReference(config.apiKey); const apiKeySet = typeof config.apiKey === "string" && config.apiKey.trim().length > 0; @@ -83,6 +169,7 @@ function buildCustomProviderConfigInfo( apiKeySource: apiKeySet ? "config" : apiKeyFile ? "file" : "keyless", baseUrl, models, + modelParameters, displayName: config.displayName, providerType: "openai-compatible", isCustom: true, @@ -304,6 +391,7 @@ export class ProviderService { baseUrl?: string; baseURL?: string; models?: unknown[]; + modelParameters?: unknown; serviceTier?: string; wireFormat?: string; store?: unknown; @@ -336,6 +424,10 @@ export class ProviderService { const normalizedModels = config.models === undefined ? undefined : normalizeProviderModelEntries(config.models); const filteredModels = filterProviderModelsByPolicy(normalizedModels, allowedModels); + const filteredModelParameters = filterModelParametersByPolicy( + normalizeModelParameters(config.modelParameters), + allowedModels + ); const codexOauthSet = provider === "openai" && parseCodexOauthAuth(config.codexOauth) !== null; @@ -358,6 +450,7 @@ export class ProviderService { baseUrl: forcedBaseUrl ?? explicitBaseUrl, apiKeyFile: typeof config.apiKeyFile === "string" ? 
config.apiKeyFile : undefined, models: filteredModels, + modelParameters: filteredModelParameters, }; // OpenAI-specific fields @@ -852,6 +945,128 @@ export class ProviderService { } } + /** Sets, replaces or clears the per-model parameter overrides stored under a provider's modelParameters, then saves the providers config and calls notifyFromMutation. @param provider Key of the provider entry in the providers config; an empty entry is created when none exists. @param modelId Model ID the overrides apply to; it is trimmed, and "*" bypasses the allowed-models policy check. @param overrides For each key in MODEL_PARAMETER_OVERRIDE_KEYS, a value of type number is stored and any other value removes that key. @param renameFromModelId Optional previous model ID; when it differs from modelId and has a stored entry, that entry is merged into modelId's entry and removed before the new overrides are applied. @returns { success: true, data: undefined } on success; otherwise { success: false, error } for an empty model ID, an empty rename source, a policy rejection, or any thrown error (reported as "Failed to set model parameters: ..."). */ public setModelParameters( + provider: string, + modelId: string, + overrides: EditableModelParameterOverrides, + renameFromModelId?: string + ): Result { + const normalizedModelId = modelId.trim(); + if (normalizedModelId.length === 0) { + return { success: false, error: "Model ID cannot be empty" }; + } + + const normalizedRenameFromModelId = renameFromModelId?.trim(); + /* An omitted rename source is fine; one that was passed but is blank after trimming is rejected. */ if ( + renameFromModelId !== undefined && + (!normalizedRenameFromModelId || normalizedRenameFromModelId.length === 0) + ) { + return { success: false, error: "Rename source model ID cannot be empty" }; + } + + try { + if (this.policyService?.isEnforced()) { + if (!this.policyService.isProviderAllowed(provider)) { + return { success: false, error: `Provider ${provider} is not allowed by policy` }; + } + + const allowedModels = + this.policyService.getEffectivePolicy()?.providerAccess?.find((p) => p.id === provider) + ?.allowedModels ?? null; + + /* Only an array allow-list restricts models; the "*" ID is exempt from it. */ if ( + normalizedModelId !== "*" && + Array.isArray(allowedModels) && + !allowedModels.includes(normalizedModelId) + ) { + return { + success: false, + error: `Model ${normalizedModelId} is not allowed by policy`, + }; + } + } + + const providersConfig = this.config.loadProvidersConfig() ?? {}; + + if (!providersConfig[provider]) { + providersConfig[provider] = {}; + } + + const providerConfig = providersConfig[provider] as BaseProviderConfig; + const currentModelParameters = normalizeModelParameters(providerConfig.modelParameters) ?? 
{}; + /* Work on a shallow copy of the normalized map. */ const nextModelParameters = { ...currentModelParameters }; + + /* Rename: carry the old ID's entry over to the new ID, then drop the old key. */ if ( + normalizedRenameFromModelId && + normalizedRenameFromModelId !== normalizedModelId && + normalizedRenameFromModelId in nextModelParameters + ) { + const renameSourceOverrides = nextModelParameters[normalizedRenameFromModelId]; + const destinationOverrides = nextModelParameters[normalizedModelId]; + + if ( + renameSourceOverrides && + typeof renameSourceOverrides === "object" && + !Array.isArray(renameSourceOverrides) + ) { + nextModelParameters[normalizedModelId] = { + ...(destinationOverrides && + typeof destinationOverrides === "object" && + !Array.isArray(destinationOverrides) + ? destinationOverrides + : {}), + /* Spread last, so the rename source's values win over the destination's on conflict. */ ...renameSourceOverrides, + }; + } + + delete nextModelParameters[normalizedRenameFromModelId]; + } + + const existingModelOverrides = nextModelParameters[normalizedModelId]; + /* Start from a copy of the current entry when it is a plain object, otherwise from an empty object. */ const nextModelOverrides: Record = + typeof existingModelOverrides === "object" && + existingModelOverrides !== null && + !Array.isArray(existingModelOverrides) + ? 
{ ...existingModelOverrides } + : {}; + + /* A value of type number sets the key; any other value removes it. NOTE(review): typeof also accepts NaN and Infinity - confirm callers validate first. */ for (const key of MODEL_PARAMETER_OVERRIDE_KEYS) { + const value = overrides[key as ModelParameterOverrideKey]; + if (typeof value === "number") { + nextModelOverrides[key] = value; + } else { + delete nextModelOverrides[key]; + } + } + + const modelIsConfigured = normalizeProviderModelEntries(providerConfig.models).some( + (entry) => getProviderModelEntryId(entry) === normalizedModelId + ); + + /* Drop the entry when the model is not in providerConfig.models and no numeric override was supplied, or when no keys remain; otherwise store the updated entry. */ if (!modelIsConfigured && !hasAnyEditableModelParameterOverride(overrides)) { + delete nextModelParameters[normalizedModelId]; + } else if (Object.keys(nextModelOverrides).length === 0) { + delete nextModelParameters[normalizedModelId]; + } else { + nextModelParameters[normalizedModelId] = nextModelOverrides; + } + + /* Remove the modelParameters property entirely rather than storing an empty map. */ if (Object.keys(nextModelParameters).length === 0) { + delete providerConfig.modelParameters; + } else { + providerConfig.modelParameters = nextModelParameters; + } + + this.config.saveProvidersConfig(providersConfig); + this.notifyFromMutation(); + + return { success: true, data: undefined }; + } catch (error) { + const message = getErrorMessage(error); + return { success: false, error: `Failed to set model parameters: ${message}` }; + } + } + /** * After a credential change, sync gateway presence in routePriority. * Configured gateways auto-insert immediately before "direct" in routePriority,