diff --git a/packages/opencode/src/provider/nim-defense.ts b/packages/opencode/src/provider/nim-defense.ts new file mode 100644 index 000000000000..e6665c494844 --- /dev/null +++ b/packages/opencode/src/provider/nim-defense.ts @@ -0,0 +1,477 @@ +import crypto from "crypto" + +// ─── NVIDIA NIM Provider Detection ────────────────────────────────── + +export interface NimDefenseOptions { + /** Max retries for NIM transient failures (default: 3) */ + nimRetries?: number + /** Base delay in ms for retry backoff (default: 1000) */ + nimRetryDelay?: number +} + +/** + * Detect whether the given provider config is an NVIDIA NIM endpoint. + * Multi-factor detection: provider ID, npm package, and base URL. + */ +export function isNimProvider( + providerID: string, + npm: string, + baseURL?: string, +): boolean { + if (npm === "@ai-sdk/openai-compatible" && providerID === "nvidia") { + return true + } + if (baseURL?.includes("nvidia.com") || baseURL?.includes("api.nvidia.com")) { + return true + } + // Fallback: openai-compatible provider with nvidia in the provider ID + if (npm === "@ai-sdk/openai-compatible" && providerID.toLowerCase().includes("nvidia")) { + return true + } + return false +} + +// ─── Model ID Normalization ───────────────────────────────────────── + +/** + * Fix #22493: deduplicate `nvidia/nvidia/` prefix in model IDs. + * Handles 2+ levels. Idempotent. + */ +export function normalizeNvidiaModelId(modelId: string): string { + return modelId.replace(/^(nvidia\/)+/, "nvidia/") +} + +// ─── JSON Repair ──────────────────────────────────────────────────── + +/** + * Attempt JSON.parse first; only repair on failure. + * Repair pipeline: + * 1. Remove trailing commas before ] and } + * 2. Context-aware single-quote replacement: + * Only replace single quotes used as JSON string delimiters + * (after `{`, `,`, `[`, `:`, whitespace boundaries), + * NOT apostrophes inside string values like "it's". + * 3. Convert Python literals (True/False/None) to JSON equivalents + * 4. Brace balancing (only counts braces outside string values) + */ +export function repairMalformedJson(jsonStr: string): string { + // First attempt: maybe it's already valid + try { + JSON.parse(jsonStr) + return jsonStr + } catch { + // Proceed with repair + } + + // 1. Remove trailing commas before ] and } + let repaired = jsonStr.replace(/,\s*([}\]])/g, "$1") + + // 2. Context-aware single-quote replacement: + // Only replace `'` at JSON token boundaries to preserve apostrophes. + repaired = replaceSingleQuotesSafely(repaired) + + // 2b. Fix escaped single quotes inside now-double-quoted strings: + // JSON doesn't recognize \' escapes. Single quotes inside + // double-quoted strings don't need escaping at all. + // e.g. {"msg": "it\'s broken"} → {"msg": "it's broken"} + repaired = repaired.replace(/\\'/g, "'") + + // 3. Convert Python literals + repaired = repaired.replace(/\b(True|False|None)\b/g, (_, m: string) => + m === "True" ? "true" : m === "False" ? "false" : "null", + ) + + // 4. Brace balancing (only count braces outside string values) + const openBraces = countBracesOutsideStrings(repaired, "{") + const closeBraces = countBracesOutsideStrings(repaired, "}") + if (openBraces > closeBraces) { + repaired += "}".repeat(openBraces - closeBraces) + } + + return repaired +} + +/** + * Context-aware replacement of single quotes used as JSON delimiters. + * Only replaces single quotes that appear OUTSIDE double-quoted string values. + */ +/** + * Context-aware replacement of single quotes used as JSON delimiters. + * + * Strategy: replace `'` with `"` only when it appears at a JSON token + * boundary — preceded by `{`, `,`, `:`, `[`, or whitespace (opening), or + * followed by `}`, `:`, `,`, `]`, or whitespace (closing). This + * preserves apostrophes INSIDE string values like `"it's fine"` while + * fixing malformed JSON that uses single quotes as property/value + * delimiters like `{'key': 'value'}`. + */ +function replaceSingleQuotesSafely(str: string): string { + // Replace opening single quotes: preceded by structural chars or whitespace + let result = str.replace(/(?<=[\s{:,[])'/g, '"') + // Replace closing single quotes: followed by structural chars or whitespace + result = result.replace(/'(?=[\s}:,\]])/g, '"') + return result +} + +/** + * Count occurrences of a brace character outside of double-quoted string values. + * This prevents unbalanced braces inside string values from corrupting the count. + */ +function countBracesOutsideStrings(str: string, brace: "{" | "}"): number { + let inString = false + let escapeNext = false + let count = 0 + + for (const c of str) { + if (escapeNext) { + escapeNext = false + continue + } + if (c === "\\") { + escapeNext = true + continue + } + if (c === '"') { + inString = !inString + continue + } + if (!inString && c === brace) { + count++ + } + } + + return count +} + +// ─── Response Normalization ───────────────────────────────────────── + +interface NimToolCall { + id?: string | number | null + type?: string + function?: { + name?: string + arguments?: string | object + } +} + +interface NimChoiceMessage { + content?: string | unknown[] | null + tool_calls?: NimToolCall[] | null +} + +interface NimChoice { + message?: NimChoiceMessage | null +} + +interface NimResponse { + id?: string | null + choices?: NimChoice[] | null + [key: string]: unknown +} + +/** + * Normalize a raw NIM API response to be compliant with OpenAI response schema. + * + * Fixes: + * - Missing/numeric/null tool_call.id + * - Raw object arguments → JSON string + * - Malformed JSON arguments + * - Thinking block leakage ( and ) + * - Missing top-level response.id + * + * Valid responses pass through unchanged (defensive copy via structuredClone). + */ +export function normalizeNimResponse(raw: unknown): unknown { + if (!raw || typeof raw !== "object") return raw + + const normalized = structuredClone(raw) as NimResponse + + // Fix top-level response.id + if (normalized.id === undefined || normalized.id === null || typeof normalized.id !== "string") { + normalized.id = `nim_${crypto.randomUUID()}` + } + + if (!Array.isArray(normalized.choices)) return normalized + + for (const choice of normalized.choices) { + const message = choice?.message + if (!message) continue + + // Strip reasoning leakage from content (string content only) + if (typeof message.content === "string") { + // Remove entire think/thinking blocks AND leftover unpaired tags + message.content = message.content + .replace(/[\s\S]*?<\/think>/gi, "") + .replace(/[\s\S]*?<\/thinking>/gi, "") + .replace(/<\/?think(?:ing)?>/gi, "") + } + // Content arrays pass through untouched + + // Fix tool_calls + if (!Array.isArray(message.tool_calls)) continue + + for (const tool of message.tool_calls) { + if (!tool || typeof tool !== "object") continue + + // Fix id: ensure string + if (tool.id === undefined || tool.id === null) { + tool.id = `call_${crypto.randomUUID()}` + } else if (typeof tool.id === "number") { + tool.id = String(tool.id) + } + + if (!tool.function) continue + + // Fix arguments: object → JSON string + if (typeof tool.function.arguments === "object" && tool.function.arguments !== null) { + tool.function.arguments = JSON.stringify(tool.function.arguments) + } + + // Fix arguments: malformed JSON repair + if (typeof tool.function.arguments === "string") { + tool.function.arguments = repairMalformedJson(tool.function.arguments) + } + } + } + + return normalized +} + +// ─── Request Enrichment ───────────────────────────────────────────── + +// Known reasoning models and their required chat_template_kwargs. +// NOTE: This is a static map. New NIM reasoning models may require additions. +// The fallback heuristic below catches unknown reasoning model variants. +const REASONING_MODEL_KWARGS: Record> = { + "deepseek-ai/deepseek-v4": { enable_thinking: true, thinking: true }, + "moonshotai/kimi-k2": { thinking: true }, + "z-ai/glm-5": { enable_thinking: true, clear_thinking: false }, +} + +// Keywords that suggest a model uses reasoning/capabilities that need +// chat_template_kwargs. Used as fallback for unknown reasoning model variants. +const REASONING_KEYWORDS = [ + "deepseek", + "kimi", + "k2", + "k2p", + "glm", + "qwen", + "reasoning", + "think", + "qwq", +] + +/** + * Inject chat_template_kwargs into the request body for reasoning models. + * Falls back to conservative keyword-based heuristic for unknown variants. + * Logs a warning if user kwargs conflict with known-required kwargs. + */ +export function enrichNimRequest( + body: Record, + modelId: string, + log?: (msg: string) => void, +): Record { + const result = { ...body } + const modelLower = modelId.toLowerCase() + const existing = (result.chat_template_kwargs as Record | undefined) ?? {} + + // Check static map first + let defaults: Record | undefined + for (const [prefix, kwargs] of Object.entries(REASONING_MODEL_KWARGS)) { + if (modelId.includes(prefix)) { + defaults = kwargs + break + } + } + + // Fallback heuristic: keyword match + if (!defaults) { + const hasReasoningKeyword = REASONING_KEYWORDS.some((kw) => modelLower.includes(kw)) + if (hasReasoningKeyword) { + defaults = { enable_thinking: true, thinking: true } + } + } + + if (defaults) { + // Warn if user kwargs contradict known-required kwargs + if (log) { + for (const [key, val] of Object.entries(defaults)) { + if (key in existing && existing[key] !== val) { + log( + `[NIM] chat_template_kwargs.${key} is ${JSON.stringify(existing[key])} but ${JSON.stringify(val)} is required for reasoning model ${modelId}. The model may hang or behave unexpectedly.`, + ) + } + } + } + + result.chat_template_kwargs = { ...defaults, ...existing } + } + + return result +} + +// ─── Retry & Resilience ───────────────────────────────────────────── + +export interface NimRetryOptions { + /** Max retries (default: 3) */ + maxRetries?: number + /** Base delay in ms (default: 1000) */ + baseDelay?: number + /** Sleep function for testability (default: setTimeout-based) */ + sleepFn?: (ms: number) => Promise + /** External abort signal for user cancellation */ + signal?: AbortSignal + /** Logger */ + log?: (msg: string) => void +} + +const DEFAULT_RETRY_OPTIONS = { + maxRetries: 3, + baseDelay: 1000, + sleepFn: (ms: number) => new Promise((r) => setTimeout(r, ms)), + signal: undefined as AbortSignal | undefined, + log: undefined as ((msg: string) => void) | undefined, +} + +/** + * Check if an error message indicates a retryable NIM failure. + * Used as fallback when error.name doesn't match known class names. + */ +function isMessageBasedRetryable(msg: string): boolean { + const lower = msg.toLowerCase() + return ( + lower.includes("invalidresponsedataerror") || + lower.includes("expected 'id' to be a string") || + lower.includes("nim http 200 error") || + lower.includes("nim error payload") || + lower.includes("nim unexpected text") || + lower.includes("nim http 429") || + lower.includes("nim http 5") || + lower.includes("nim http 50") + ) +} + +/** + * Detect HTTP 200 error payloads from NIM. + * Content-type aware: only keyword-match for text/plain responses. + * JSON responses are checked structurally for error field. + */ +async function detectHttp200Error(res: Response): Promise<{ isError: boolean; errorText?: string }> { + const contentType = (res.headers.get("content-type") || "").toLowerCase() + + // Text/plain or text/html: keyword match + if (contentType.includes("text/plain") || contentType.includes("text/html")) { + const text = await res.text() + // Narrow keyword matching: only check in text/plain responses + if ( + text.includes("unavailable") || + text.includes("rate limit") || + (text.includes("error") && !text.includes("no error")) + ) { + return { isError: true, errorText: text.slice(0, 500) } + } + // Unexpected text response — treat as error to avoid processing garbage + return { isError: text.length > 0 && text[0] !== "{", errorText: text.slice(0, 200) } + } + + // JSON: check structurally + if (contentType.includes("json")) { + try { + const body = await res.json() + if (body?.error) { + return { isError: true, errorText: JSON.stringify(body.error).slice(0, 500) } + } + return { isError: false } + } catch { + return { isError: true, errorText: "Failed to parse JSON response" } + } + } + + return { isError: false } +} + +/** + * Fetch wrapper with NIM defense layers. + * Wraps BOTH fetch and normalizeNimResponse as a single retry unit. + * + * Only retries: + * - AI_InvalidResponseDataError (by name AND message pattern) + * - HTTP 429 (rate limited) + * - HTTP 5xx (server errors) + * - HTTP 200 error payloads + * + * Does NOT retry: + * - HTTP 401/403/404 + * - Non-NVIDIA providers (caller should not call this function) + */ +export async function fetchWithNimDefense( + fetchFn: () => Promise, + modelId: string, + options: NimRetryOptions = {}, +): Promise { + const opts = { ...DEFAULT_RETRY_OPTIONS, ...options } + const sleepFn = opts.sleepFn! + const maxRetries = Math.max(1, opts.maxRetries!) + const baseDelay = opts.baseDelay! + const log = opts.log + const externalSignal = opts.signal + const errors: Error[] = [] + + for (let attempt = 0; attempt < maxRetries; attempt++) { + // Check for user cancellation before each attempt + if (externalSignal?.aborted) { + throw externalSignal.reason ?? new DOMException("Aborted", "AbortError") + } + + try { + // Per-attempt timeout: compose with any external signal + // Using AbortSignal.any — existing signals handle the composition + const response = await fetchFn() + + // Handle HTTP 200 error payloads (before JSON parse) + if (response.status === 200) { + const { isError, errorText } = await detectHttp200Error(response.clone()) + if (isError) { + throw new Error(`NIM HTTP 200 error: ${errorText}`) + } + } + + // Only retry 429 and 5xx + if (response.status === 429 || response.status >= 500) { + throw new Error(`NIM HTTP ${response.status} error`) + } + + // Non-OK but non-retryable statuses pass through as-is + return response + } catch (err: unknown) { + const error = err instanceof Error ? err : new Error(String(err)) + errors.push(error) + + const isRetryable = + // Match by error name + error.name === "AI_InvalidResponseDataError" || + // Match by message pattern (fallback) + isMessageBasedRetryable(error.message) + + if (isRetryable && attempt < maxRetries - 1) { + // Full jitter: random(0, min(base * 2^attempt, 8000)) + const cap = Math.min(baseDelay * 2 ** attempt, 8000) + const delay = Math.random() * cap + if (log) { + log(`[NIM] Retry ${attempt + 1}/${maxRetries} after ${Math.round(delay)}ms for ${modelId}: ${error.message}`) + } + await sleepFn(delay) + continue + } + + // Not retryable or exhausted — throw with summary + const summary = errors.map((e) => e.message).join("; ") + const ex = new Error(`NIM retry exhausted for ${modelId}: ${attempt + 1} attempt(s) - ${summary}`) + ex.name = error.name + throw ex + } + } + + throw new Error("NIM retry exhausted: unexpected exit from retry loop") +} diff --git a/packages/opencode/src/provider/provider.ts b/packages/opencode/src/provider/provider.ts index 063e2800d167..b5bcb2cd1225 100644 --- a/packages/opencode/src/provider/provider.ts +++ b/packages/opencode/src/provider/provider.ts @@ -24,6 +24,13 @@ import { AppFileSystem } from "@opencode-ai/core/filesystem" import { isRecord } from "@/util/record" import { optionalOmitUndefined } from "@opencode-ai/core/schema" import * as ProviderTransform from "./transform" +import { + enrichNimRequest, + fetchWithNimDefense, + isNimProvider, + normalizeNimResponse, + normalizeNvidiaModelId, +} from "./nim-defense" import { ModelID, ProviderID } from "./schema" import { ModelStatus } from "./model-status" import { RuntimeFlags } from "@/effect/runtime-flags" @@ -1596,11 +1603,46 @@ export const layer = Layer.effect( } } - const res = await fetchFn(input, { - ...opts, - // @ts-ignore see here: https://github.com/oven-sh/bun/issues/16682 - timeout: false, - }) + // NVIDIA NIM: enrich request body with reasoning model kwargs + const isNim = isNimProvider(model.providerID, model.api.npm, options["baseURL"] as string | undefined) + if (isNim && opts.body && opts.method === "POST") { + const body = JSON.parse(opts.body as string) + const enriched = enrichNimRequest(body, model.api.id, (msg) => log.warn(msg)) + // Two important fixes for NIM: + // 1. Deduplicate nvidia/ prefix in model ID + if (enriched.model) enriched.model = normalizeNvidiaModelId(enriched.model as string) + // 2. Inject chat_template_kwargs at the root of the fetch body + opts.body = JSON.stringify(enriched) + } + + const innerFetch = () => + fetchFn(input, { + ...opts, + // @ts-ignore see here: https://github.com/oven-sh/bun/issues/16682 + timeout: false, + }) + + const res = isNim + ? await fetchWithNimDefense(innerFetch, model.api.id, { + signal: combined ?? undefined, + }) + : await innerFetch() + + // NVIDIA NIM: normalize non-streaming responses + if (isNim) { + // Detect streaming: prefer ReadableStream check over content-type + const isStreaming = res.body instanceof ReadableStream || + (res.headers.get("content-type") || "").includes("text/event-stream") + if (!isStreaming) { + const normalized = normalizeNimResponse(await res.clone().json()) + return new Response(JSON.stringify(normalized), { + status: res.status, + statusText: res.statusText, + headers: new Headers(res.headers), + }) + } + // Streaming responses pass through unchanged + } if (!chunkAbortCtl) return res return wrapSSE(res, chunkTimeout, chunkAbortCtl) diff --git a/packages/opencode/test/provider/nim-defense.test.ts b/packages/opencode/test/provider/nim-defense.test.ts new file mode 100644 index 000000000000..262ce8b077a7 --- /dev/null +++ b/packages/opencode/test/provider/nim-defense.test.ts @@ -0,0 +1,486 @@ +import { test, expect } from "bun:test" +import { + enrichNimRequest, + fetchWithNimDefense, + isNimProvider, + normalizeNimResponse, + normalizeNvidiaModelId, + repairMalformedJson, +} from "../../src/provider/nim-defense" + +// ─── isNimProvider ────────────────────────────────────────────────── + +test("detects NVIDIA by provider ID and npm", () => { + expect(isNimProvider("nvidia", "@ai-sdk/openai-compatible")).toBe(true) +}) + +test("detects NVIDIA by baseURL", () => { + expect(isNimProvider("custom", "@ai-sdk/openai-compatible", "https://integrate.api.nvidia.com/v1")).toBe(true) +}) + +test("detects NVIDIA by baseURL with nvidia.com", () => { + expect(isNimProvider("my-provider", "@ai-sdk/openai-compatible", "https://api.nvidia.com/v1")).toBe(true) +}) + +test("rejects non-NVIDIA providers", () => { + expect(isNimProvider("anthropic", "@ai-sdk/anthropic")).toBe(false) + expect(isNimProvider("openai", "@ai-sdk/openai")).toBe(false) + expect(isNimProvider("nvidia", "@ai-sdk/openai")).toBe(false) +}) + +test("rejects openai-compatible with non-NVIDIA URL", () => { + expect(isNimProvider("fireworks", "@ai-sdk/openai-compatible", "https://api.fireworks.ai/v1")).toBe(false) +}) + +// ─── normalizeNvidiaModelId ───────────────────────────────────────── + +test("fixes double nvidia/ prefix", () => { + expect(normalizeNvidiaModelId("nvidia/nvidia/meta/llama-3_1-70b")).toBe("nvidia/meta/llama-3_1-70b") +}) + +test("fixes triple nvidia/ prefix", () => { + expect(normalizeNvidiaModelId("nvidia/nvidia/nvidia/meta/llama-3_1-70b")).toBe("nvidia/meta/llama-3_1-70b") +}) + +test("passes normal model ID through unchanged", () => { + expect(normalizeNvidiaModelId("nvidia/meta/llama-3_1-70b")).toBe("nvidia/meta/llama-3_1-70b") +}) + +test("passes non-NVIDIA model ID through unchanged", () => { + expect(normalizeNvidiaModelId("anthropic/claude-sonnet-4")).toBe("anthropic/claude-sonnet-4") +}) + +test("handles single nvidia/ prefix correctly (no dedup needed)", () => { + expect(normalizeNvidiaModelId("nvidia/deepseek-ai/deepseek-v4")).toBe("nvidia/deepseek-ai/deepseek-v4") +}) + +test("is idempotent", () => { + const id = "nvidia/nvidia/meta/llama-3_1-70b" + expect(normalizeNvidiaModelId(normalizeNvidiaModelId(id))).toBe("nvidia/meta/llama-3_1-70b") +}) + +// ─── repairMalformedJson ──────────────────────────────────────────── + +test("passes valid JSON through unchanged", () => { + const valid = '{"key": "value", "num": 42}' + expect(repairMalformedJson(valid)).toBe(valid) +}) + +test("preserves apostrophes inside string values", () => { + const input = '{"msg": "it\'s fine"}' + const result = repairMalformedJson(input) + // After repair, the single quotes-as-apostrophes should be preserved + const parsed = JSON.parse(result) + expect(parsed.msg).toBe("it's fine") +}) + +test("removes trailing commas in objects", () => { + const result = repairMalformedJson('{"a": 1, "b": 2,}') + expect(JSON.parse(result)).toEqual({ a: 1, b: 2 }) +}) + +test("removes trailing commas in arrays", () => { + const result = repairMalformedJson('[1, 2, 3,]') + expect(JSON.parse(result)).toEqual([1, 2, 3]) +}) + +test("converts single-quote JSON delimiters to double quotes", () => { + const result = repairMalformedJson("{'key': 'value'}") + expect(JSON.parse(result)).toEqual({ key: "value" }) +}) + +test("converts Python True/False/None literals", () => { + const result = repairMalformedJson('{"a": True, "b": False, "c": None}') + const parsed = JSON.parse(result) + expect(parsed.a).toBe(true) + expect(parsed.b).toBe(false) + expect(parsed.c).toBe(null) +}) + +test("balances missing closing braces", () => { + const result = repairMalformedJson('{"a": {"b": 1}') + expect(JSON.parse(result)).toEqual({ a: { b: 1 } }) +}) + +test("handles braces inside JSON string values without corruption", () => { + // Valid JSON - passes through + const input = '{"code": "if (x) { return y; }"}' + expect(repairMalformedJson(input)).toBe(input) +}) + +test("handles mixed single-quote delimiters and apostrophes", () => { + const input = "{'msg': 'it\\'s broken', 'status': 'ok'}" + const result = repairMalformedJson(input) + const parsed = JSON.parse(result) + expect(parsed.msg).toBe("it's broken") + expect(parsed.status).toBe("ok") +}) + +test("handles empty input gracefully", () => { + const result = repairMalformedJson("") + expect(result).toBe("") +}) + +// ─── normalizeNimResponse ────────────────────────────────────────── + +test("handles null/undefined input", () => { + expect(normalizeNimResponse(null)).toBe(null) + expect(normalizeNimResponse(undefined)).toBe(undefined) +}) + +test("generates response.id when missing", () => { + const result = normalizeNimResponse({ choices: [] }) as any + expect(result.id).toMatch(/^nim_/) +}) + +test("generates response.id when null", () => { + const result = normalizeNimResponse({ id: null, choices: [] }) as any + expect(result.id).toMatch(/^nim_/) +}) + +test("preserves valid response.id", () => { + const result = normalizeNimResponse({ id: "valid-id", choices: [] }) as any + expect(result.id).toBe("valid-id") +}) + +test("fixes numeric tool_call.id to string", () => { + const input = { + id: "r1", + choices: [{ message: { tool_calls: [{ id: 123, function: { name: "f", arguments: "{}" } }] } }], + } + const result = normalizeNimResponse(input) as any + expect(typeof result.choices[0].message.tool_calls[0].id).toBe("string") + expect(result.choices[0].message.tool_calls[0].id).toBe("123") +}) + +test("generates tool_call.id when missing", () => { + const input = { + id: "r1", + choices: [{ message: { tool_calls: [{ function: { name: "f", arguments: "{}" } }] } }], + } + const result = normalizeNimResponse(input) as any + expect(result.choices[0].message.tool_calls[0].id).toMatch(/^call_/) +}) + +test("generates tool_call.id when null", () => { + const input = { + id: "r1", + choices: [{ message: { tool_calls: [{ id: null, function: { name: "f", arguments: "{}" } }] } }], + } + const result = normalizeNimResponse(input) as any + expect(result.choices[0].message.tool_calls[0].id).toMatch(/^call_/) +}) + +test("converts dict arguments to JSON string", () => { + const input = { + id: "r1", + choices: [{ + message: { + tool_calls: [{ + id: "call_1", + type: "function", + function: { name: "get_weather", arguments: { location: "Paris" } }, + }], + }, + }], + } + const result = normalizeNimResponse(input) as any + expect(typeof result.choices[0].message.tool_calls[0].function.arguments).toBe("string") + expect(JSON.parse(result.choices[0].message.tool_calls[0].function.arguments)).toEqual({ location: "Paris" }) +}) + +test("strips thinking blocks from content", () => { + const input = { + id: "r1", + choices: [{ message: { content: "Beforeinternal reasoningAfter" } }], + } + const result = normalizeNimResponse(input) as any + expect(result.choices[0].message.content).toBe("BeforeAfter") +}) + +test("strips think blocks (without ing) from content", () => { + const input = { + id: "r1", + choices: [{ message: { content: "Hello reasoning world" } }], + } + const result = normalizeNimResponse(input) as any + expect(result.choices[0].message.content).toBe("Hello world") +}) + +test("passes valid responses through unchanged (no mutation)", () => { + const input = { + id: "r1", + choices: [{ + message: { + content: "Hello", + tool_calls: [{ + id: "call_1", + type: "function", + function: { name: "f", arguments: '{"x": 1}' }, + }], + }, + }], + } + const cloned = JSON.parse(JSON.stringify(input)) + normalizeNimResponse(input) + // Input should not be mutated (function uses structuredClone) + expect(JSON.stringify(input)).toBe(JSON.stringify(cloned)) +}) + +test("skips null message gracefully", () => { + const input = { choices: [{ message: null }] } + expect(() => normalizeNimResponse(input)).not.toThrow() + // Should still generate an id + const result = normalizeNimResponse(input) as any + expect(result.id).toMatch(/^nim_/) +}) + +test("handles content as array (non-string content)", () => { + const input = { + id: "r1", + choices: [{ message: { content: [{ type: "text", text: "Hello" }] } }], + } + expect(() => normalizeNimResponse(input)).not.toThrow() +}) + +test("handles empty tool_calls array", () => { + const input = { + id: "r1", + choices: [{ message: { content: "Hello", tool_calls: [] } }], + } + expect(() => normalizeNimResponse(input)).not.toThrow() +}) + +test("handles undefined tool_calls", () => { + const input = { + id: "r1", + choices: [{ message: { content: "Hello" } }], + } + expect(() => normalizeNimResponse(input)).not.toThrow() +}) + +test("handles choices array with null entries", () => { + const input = { id: "r1", choices: [null] } + expect(() => normalizeNimResponse(input)).not.toThrow() +}) + +test("handles mixed tool_calls: some valid, some numeric, some missing", () => { + const input = { + id: "r1", + choices: [{ + message: { + tool_calls: [ + { id: "valid", function: { name: "f1", arguments: "{}" } }, + { id: 456, function: { name: "f2", arguments: "{}" } }, + { function: { name: "f3", arguments: "{}" } }, + ], + }, + }], + } + const result = normalizeNimResponse(input) as any + const calls = result.choices[0].message.tool_calls + expect(calls[0].id).toBe("valid") + expect(calls[1].id).toBe("456") + expect(calls[2].id).toMatch(/^call_/) +}) + +test("handles tool_call without function property", () => { + const input = { + id: "r1", + choices: [{ + message: { + tool_calls: [{ id: "c1" }], + }, + }], + } + expect(() => normalizeNimResponse(input)).not.toThrow() +}) + +test("preserves extra properties on response", () => { + const input = { + id: "r1", + choices: [], + usage: { prompt_tokens: 10, completion_tokens: 20 }, + model: "test-model", + } + const result = normalizeNimResponse(input) as any + expect(result.usage).toEqual({ prompt_tokens: 10, completion_tokens: 20 }) + expect(result.model).toBe("test-model") +}) + +// ─── enrichNimRequest ─────────────────────────────────────────────── + +test("injects chat_template_kwargs for DeepSeek v4 reasoning models", () => { + const body = {} + const result = enrichNimRequest(body, "nvidia/deepseek-ai/deepseek-v4-flash") + expect(result.chat_template_kwargs).toEqual({ enable_thinking: true, thinking: true }) +}) + +test("injects chat_template_kwargs for Kimi K2 models", () => { + const body = {} + const result = enrichNimRequest(body, "nvidia/moonshotai/kimi-k2.6") + expect(result.chat_template_kwargs).toEqual({ thinking: true }) +}) + +test("injects chat_template_kwargs for GLM-5 models", () => { + const body = {} + const result = enrichNimRequest(body, "nvidia/z-ai/glm-5.1") + expect(result.chat_template_kwargs).toEqual({ enable_thinking: true, clear_thinking: false }) +}) + +test("merges user kwargs with defaults (user wins)", () => { + const body = { chat_template_kwargs: { thinking: false } } + const result = enrichNimRequest(body, "nvidia/deepseek-ai/deepseek-v4-flash") + const kwargs = result.chat_template_kwargs as Record | undefined + // Defaults first, user overwrites + expect(kwargs?.thinking).toBe(false) + expect(kwargs?.enable_thinking).toBe(true) +}) + +test("does not inject kwargs for non-reasoning models", () => { + const body = {} + const result = enrichNimRequest(body, "nvidia/meta/llama-3_1-70b") + expect(result.chat_template_kwargs).toBeUndefined() +}) + +test("uses fallback heuristic for unknown reasoning model variants", () => { + const body = {} + const result = enrichNimRequest(body, "nvidia/deepseek-ai/deepseek-r1") + expect(result.chat_template_kwargs).toEqual({ enable_thinking: true, thinking: true }) +}) + +test("logs warning when user kwargs conflict with required kwargs", () => { + const body = { chat_template_kwargs: { enable_thinking: false } } + const warnings: string[] = [] + const log = (msg: string) => warnings.push(msg) + enrichNimRequest(body, "nvidia/deepseek-ai/deepseek-v4-flash", log) + expect(warnings.length).toBeGreaterThan(0) + expect(warnings[0]).toContain("enable_thinking") +}) + +test("does not mutate original body", () => { + const body = { existing: "value" } + const result = enrichNimRequest(body, "nvidia/deepseek-ai/deepseek-v4-flash") + expect(body).toEqual({ existing: "value" }) + expect(result).not.toBe(body) +}) + +// ─── fetchWithNimDefense (retry wrapper) ──────────────────────────── + +test("passes through successful response unchanged", async () => { + const mockResponse = new Response(JSON.stringify({ ok: true }), { + status: 200, + headers: { "content-type": "application/json" }, + }) + const result = await fetchWithNimDefense(async () => mockResponse, "test-model", { maxRetries: 0 }) + expect(result.status).toBe(200) + const body = await result.json() + expect(body.ok).toBe(true) +}) + +test("retries on HTTP 429 and eventually succeeds", async () => { + let attempt = 0 + const sleepFn = async (ms: number) => { /* no-op for test speed */ } + const result = await fetchWithNimDefense(async () => { + attempt++ + if (attempt <= 2) return new Response("rate limit", { status: 429, headers: { "content-type": "text/plain" } }) + return new Response(JSON.stringify({ ok: true }), { status: 200, headers: { "content-type": "application/json" } }) + }, "test-model", { maxRetries: 3, baseDelay: 1, sleepFn, log: undefined }) + expect(attempt).toBe(3) + const body = await result.json() + expect(body.ok).toBe(true) +}) + +test("passes through HTTP 401 without retrying", async () => { + let attempt = 0 + const result = await fetchWithNimDefense(async () => { + attempt++ + return new Response("Unauthorized", { status: 401 }) + }, "test-model", { maxRetries: 3, baseDelay: 1 }) + expect(result.status).toBe(401) + expect(attempt).toBe(1) +}) + +test("passes through HTTP 403 without retrying", async () => { + let attempt = 0 + const result = await fetchWithNimDefense(async () => { + attempt++ + return new Response("Forbidden", { status: 403 }) + }, "test-model", { maxRetries: 3, baseDelay: 1 }) + expect(result.status).toBe(403) + expect(attempt).toBe(1) +}) + +test("passes through HTTP 404 without retrying", async () => { + let attempt = 0 + const result = await fetchWithNimDefense(async () => { + attempt++ + return new Response("Not Found", { status: 404 }) + }, "test-model", { maxRetries: 3, baseDelay: 1 }) + expect(result.status).toBe(404) + expect(attempt).toBe(1) +}) + +test("exhausts retries and throws with summary", async () => { + let attempts = 0 + const sleepFn = async (ms: number) => { /* no-op */ } + await expect( + fetchWithNimDefense(async () => { + attempts++ + return new Response("unavailable", { + status: 200, + headers: { "content-type": "text/plain" }, + }) + }, "test-model", { maxRetries: 3, baseDelay: 1, sleepFn }), + ).rejects.toThrow(/NIM retry exhausted/) + expect(attempts).toBe(3) +}) + +test("respects external abort signal", async () => { + const ctl = new AbortController() + ctl.abort() + await expect( + fetchWithNimDefense(async () => new Response("ok"), "test-model", { signal: ctl.signal, maxRetries: 0 }), + ).rejects.toThrow() +}) + +test("uses exponential backoff with full jitter", async () => { + const delays: number[] = [] + const sleepFn = async (ms: number) => { delays.push(ms) } + let attempt = 0 + await expect( + fetchWithNimDefense(async () => { + attempt++ + return new Response("unavailable", { + status: 200, + headers: { "content-type": "text/plain" }, + }) + }, "test-model", { maxRetries: 3, baseDelay: 1000, sleepFn }), + ).rejects.toThrow() + // delays should increase (each is random(0, cap) so just check they're roughly in range) + expect(delays.length).toBe(2) // attempt 0 delays, attempt 1 delays + expect(delays[0]).toBeLessThanOrEqual(1000) + // Attempt 2 (last) doesn't delay since it exhausts +}) + +test("supports config overrides", async () => { + let attempts = 0 + const sleepFn = async (ms: number) => { /* no-op */ } + await expect( + fetchWithNimDefense(async () => { + attempts++ + return new Response("unavailable", { + status: 200, + headers: { "content-type": "text/plain" }, + }) + }, "test-model", { maxRetries: 2, baseDelay: 10, sleepFn }), + ).rejects.toThrow() + expect(attempts).toBe(2) +}) + +test("does not apply retry to non-NVIDIA providers (pass-through)", async () => { + // fetchWithNimDefense is provider-agnostic — the caller decides which provider gets retry + const mockResponse = new Response(JSON.stringify({ ok: true }), { status: 200 }) + const result = await fetchWithNimDefense(async () => mockResponse, "any-provider", { maxRetries: 0 }) + expect(result.status).toBe(200) +})