-
-
Notifications
You must be signed in to change notification settings - Fork 3.2k
Expand file tree
/
Copy pathmodelStrings.ts
More file actions
167 lines (154 loc) · 5.15 KB
/
modelStrings.ts
File metadata and controls
167 lines (154 loc) · 5.15 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
import {
getModelStrings as getModelStringsState,
setModelStrings as setModelStringsState,
} from 'src/bootstrap/state.js'
import { logError } from '../log.js'
import { sequential } from '../sequential.js'
import { getInitialSettings } from '../settings/settings.js'
import { findFirstMatch, getBedrockInferenceProfiles } from './bedrock.js'
import {
ALL_MODEL_CONFIGS,
CANONICAL_ID_TO_KEY,
type CanonicalModelId,
type ModelKey,
} from './configs.js'
import { type APIProvider, getAPIProvider } from './providers.js'
/**
 * Maps each model version to its provider-specific model ID string.
 * Derived from ALL_MODEL_CONFIGS — adding a model there extends this type.
 */
export type ModelStrings = Record<ModelKey, string>
// Every key of ALL_MODEL_CONFIGS, captured once at module load. The cast is
// needed because Object.keys() is typed as returning string[].
const MODEL_KEYS = Object.keys(ALL_MODEL_CONFIGS) as ModelKey[]
/**
 * Build the static model-string table for `provider` directly from
 * ALL_MODEL_CONFIGS.
 *
 * @param provider - API provider whose model IDs should be looked up.
 * @returns One entry per model key. When a config has no ID for the
 *   provider, that model's first-party ID is used instead.
 */
function getBuiltinModelStrings(provider: APIProvider): ModelStrings {
  const pairs = MODEL_KEYS.map((modelKey): [ModelKey, string] => {
    const config = ALL_MODEL_CONFIGS[modelKey]
    return [modelKey, config[provider] ?? config.firstParty]
  })
  return Object.fromEntries(pairs) as ModelStrings
}
/**
 * Resolve Bedrock model strings by matching each model against the inference
 * profiles returned by getBedrockInferenceProfiles().
 *
 * @returns The built-in Bedrock table when the profile lookup throws (the
 *   error is logged) or yields no profiles; otherwise a table in which each
 *   model uses its first matching profile, keeping the built-in ID for models
 *   with no match.
 */
async function getBedrockModelStrings(): Promise<ModelStrings> {
  const builtin = getBuiltinModelStrings('bedrock')

  let availableProfiles: string[] | undefined
  try {
    availableProfiles = await getBedrockInferenceProfiles()
  } catch (err) {
    logError(err as Error)
    return builtin
  }
  if (!availableProfiles || availableProfiles.length === 0) {
    return builtin
  }

  // Start from the hardcoded Bedrock IDs and overwrite only the models for
  // which a profile is found. The search term is the config's firstParty ID,
  // a canonical substring of the profile name (e.g. "claude-opus-4-6" matches
  // "eu.anthropic.claude-opus-4-6-v1").
  const resolved = { ...builtin }
  for (const modelKey of MODEL_KEYS) {
    const matchedProfile = findFirstMatch(
      availableProfiles,
      ALL_MODEL_CONFIGS[modelKey].firstParty,
    )
    if (matchedProfile) {
      resolved[modelKey] = matchedProfile
    }
  }
  return resolved
}
/**
 * Apply the user's `modelOverrides` setting on top of a provider-derived
 * model-string table. Override entries are keyed by canonical first-party
 * model ID (e.g. "claude-opus-4-6"); their values are arbitrary
 * provider-specific strings — typically Bedrock inference profile ARNs.
 *
 * @param ms - Provider-derived model strings; never mutated.
 * @returns `ms` itself when no overrides are configured, otherwise a copy with
 *   each recognized, non-empty override applied.
 */
function applyModelOverrides(ms: ModelStrings): ModelStrings {
  const configured = getInitialSettings().modelOverrides
  if (!configured) {
    return ms
  }

  const merged = { ...ms }
  Object.entries(configured).forEach(([canonicalId, replacement]) => {
    const modelKey = CANONICAL_ID_TO_KEY[canonicalId as CanonicalModelId]
    // Ignore IDs that map to no known model key, and empty override values.
    if (!modelKey || !replacement) {
      return
    }
    merged[modelKey] = replacement
  })
  return merged
}
/**
 * Map an overridden model ID (e.g. a Bedrock ARN) back to the canonical
 * first-party model ID it was configured for.
 *
 * Safe to call during module init: if reading settings throws, the input is
 * returned as-is.
 *
 * @param modelId - Model ID that may be the value of a configured override.
 * @returns The canonical ID of the first override whose value equals
 *   `modelId`, or `modelId` unchanged when nothing matches.
 */
export function resolveOverriddenModel(modelId: string): string {
  const readOverrides = (): Record<string, string> | undefined => {
    try {
      return getInitialSettings().modelOverrides
    } catch {
      // Settings are not readable yet — behave as if nothing is overridden.
      return undefined
    }
  }

  const configured = readOverrides()
  if (!configured) {
    return modelId
  }

  const hit = Object.entries(configured).find(
    ([, overrideValue]) => overrideValue === modelId,
  )
  return hit ? hit[0] : modelId
}
/**
 * Fetch the Bedrock model strings and store them in shared state, unless the
 * state is already populated. Failures are logged and leave the state unset.
 *
 * The "already initialized" check lives inside the `sequential`-wrapped
 * callback on purpose: it lets the test suite reset the state between tests
 * while still preventing multiple API calls in production.
 */
const updateBedrockModelStrings = sequential(async () => {
  if (getModelStringsState() === null) {
    try {
      const fetched = await getBedrockModelStrings()
      setModelStringsState(fetched)
    } catch (err) {
      logError(err as Error)
    }
  }
})
/**
 * Start model-string initialization if it has not happened yet.
 *
 * Non-Bedrock providers get their built-in table stored immediately. On
 * Bedrock, the update is started in the background and this function returns
 * without waiting for it.
 */
function initModelStrings(): void {
  if (getModelStringsState() !== null) {
    // Nothing to do: state is already populated.
    return
  }

  if (getAPIProvider() === 'bedrock') {
    // Fire and forget. The state is deliberately left unset here so that
    // `updateBedrockModelStrings` (wrapped in `sequential`) can check for
    // existing state itself across repeated calls.
    void updateBedrockModelStrings()
    return
  }

  // Every other provider: initialize with the default values.
  setModelStringsState(getBuiltinModelStrings(getAPIProvider()))
}
/**
 * Return the current model strings with user overrides applied.
 *
 * If the shared state is not populated yet, initialization is triggered and
 * the built-in table for the current provider is returned in the meantime.
 *
 * @returns Model strings for every model key, overrides included.
 */
export function getModelStrings(): ModelStrings {
  const cached = getModelStringsState()
  if (cached !== null) {
    return applyModelOverrides(cached)
  }

  initModelStrings()
  // On Bedrock the profile fetch is still running in the background at this
  // point, so serve the interim defaults — with overrides still honored.
  return applyModelOverrides(getBuiltinModelStrings(getAPIProvider()))
}
/**
 * Make sure model strings have been initialized before continuing.
 * For Bedrock users this waits for the profile fetch to complete; for every
 * other provider the built-in table is stored synchronously.
 * Call this before generating model options to ensure correct region strings.
 *
 * @returns A promise that settles once initialization has been attempted.
 */
export async function ensureModelStringsInitialized(): Promise<void> {
  if (getModelStringsState() !== null) {
    return
  }

  if (getAPIProvider() === 'bedrock') {
    // Bedrock: block until the profile fetch has finished.
    await updateBedrockModelStrings()
    return
  }

  // Non-Bedrock: nothing to wait for, store the built-in table right away.
  setModelStringsState(getBuiltinModelStrings(getAPIProvider()))
}