Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ const installRepo = (): InMemoryRepo => {
test('generate native chat-completions target calls provider.callChatCompletions', async () => {
installRepo();
const callChatCompletions = vi.fn(async (): Promise<ProviderStreamResult<ChatCompletionsStreamEvent>> => ({
ok: true, events: makeProtocolFrames(makeChatCompletionsEvents()), modelKey: 'k',
ok: true, events: makeProtocolFrames(makeChatCompletionsEvents()), modelKey: 'k', headers: new Headers(),
}));
const result = await chatCompletionsAttempt.generate({
payload: makePayload(),
Expand All @@ -132,7 +132,7 @@ test('generate native chat-completions target calls provider.callChatCompletions
test('generate translate-to-messages branch routes through messagesAttempt', async () => {
installRepo();
const callMessages = vi.fn(async (): Promise<ProviderStreamResult<MessagesStreamEvent>> => ({
ok: true, events: makeProtocolFrames(makeMessagesEvents()), modelKey: 'k',
ok: true, events: makeProtocolFrames(makeMessagesEvents()), modelKey: 'k', headers: new Headers(),
}));
const result = await chatCompletionsAttempt.generate({
payload: makePayload(),
Expand Down Expand Up @@ -161,6 +161,7 @@ test('generate translate-to-responses branch routes through responsesAttempt', a
ok: true,
events: makeProtocolFrames([{ type: 'response.completed', sequence_number: 0, response: respResp }]),
modelKey: 'k',
headers: new Headers(),
}));
const result = await chatCompletionsAttempt.generate({
payload: makePayload(),
Expand All @@ -180,7 +181,7 @@ test('generate inherits invocation headers across translation to Messages', asyn
let observedHeaders: Record<string, string> | undefined;
const callMessages = vi.fn(async (_model: unknown, _body: unknown, _signal?: AbortSignal, headers?: Record<string, string>): Promise<ProviderStreamResult<MessagesStreamEvent>> => {
observedHeaders = headers;
return { ok: true, events: makeProtocolFrames(makeMessagesEvents()), modelKey: 'k' };
return { ok: true, events: makeProtocolFrames(makeMessagesEvents()), modelKey: 'k', headers: new Headers() };
});
const result = await chatCompletionsAttempt.generate({
payload: makePayload(),
Expand Down Expand Up @@ -212,6 +213,7 @@ test('generate inherits invocation headers across translation to Responses', asy
ok: true,
events: makeProtocolFrames([{ type: 'response.completed', sequence_number: 0, response: respResp }]),
modelKey: 'k',
headers: new Headers(),
};
});
const result = await chatCompletionsAttempt.generate({
Expand All @@ -226,3 +228,25 @@ test('generate inherits invocation headers across translation to Responses', asy
await collectEvents(result.events);
assertEquals(observedHeaders?.['x-test'], 'abc');
});

// Verifies that the headers returned by the stubbed provider call are exposed on the
// `events` result produced by `chatCompletionsAttempt.generate`.
test('generate propagates upstream response headers onto the EventResult so respond can forward them', async () => {
  installRepo();

  // Headers the stubbed provider hands back alongside its event stream.
  const headersFromUpstream = new Headers([
    ['anthropic-ratelimit-unified-status', 'allowed'],
    ['cf-ray', 'cf_ray_cc'],
  ]);

  const callChatCompletions = vi.fn(async (): Promise<ProviderStreamResult<ChatCompletionsStreamEvent>> => {
    const events = makeProtocolFrames(makeChatCompletionsEvents());
    return { ok: true, events, modelKey: 'k', headers: headersFromUpstream };
  });

  const result = await chatCompletionsAttempt.generate({
    payload: makePayload(),
    ctx: makeGatewayCtx(),
    store: createNonResponsesSourceStore(API_KEY_ID),
    candidate: makeCandidate({ callChatCompletions }),
  });

  // Assert on the variant first, then narrow so the compiler accepts `headers` / `events`.
  assertEquals(result.type, 'events');
  if (result.type !== 'events') {
    throw new Error('unreachable');
  }

  const forwarded = result.headers;
  assertEquals(forwarded?.get('anthropic-ratelimit-unified-status'), 'allowed');
  assertEquals(forwarded?.get('cf-ray'), 'cf_ray_cc');

  // Hand the events to collectEvents last, after the header assertions.
  await collectEvents(result.events);
});
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ const makeCandidate = (overrides: {
test('POST /v1/chat/completions streams a successful SSE body', async () => {
installRepo();
const callChatCompletions = vi.fn(async (): Promise<ProviderStreamResult<ChatCompletionsStreamEvent>> => ({
ok: true, events: makeProtocolFrames(makeChatCompletionsEvents()), modelKey: 'k',
ok: true, events: makeProtocolFrames(makeChatCompletionsEvents()), modelKey: 'k', headers: new Headers(),
}));
queueCandidates([makeCandidate({ callChatCompletions })]);

Expand All @@ -127,7 +127,7 @@ test('POST /v1/chat/completions streams a successful SSE body', async () => {
test('POST /v1/chat/completions returns a single JSON body when stream is omitted', async () => {
installRepo();
const callChatCompletions = vi.fn(async (): Promise<ProviderStreamResult<ChatCompletionsStreamEvent>> => ({
ok: true, events: makeProtocolFrames(makeChatCompletionsEvents()), modelKey: 'k',
ok: true, events: makeProtocolFrames(makeChatCompletionsEvents()), modelKey: 'k', headers: new Headers(),
}));
queueCandidates([makeCandidate({ callChatCompletions })]);

Expand All @@ -147,7 +147,7 @@ test('POST /v1/chat/completions returns a single JSON body when stream is omitte
test('POST /v1/chat/completions omits the usage-only chunk unless stream_options.include_usage is set', async () => {
installRepo();
const callChatCompletions = vi.fn(async (): Promise<ProviderStreamResult<ChatCompletionsStreamEvent>> => ({
ok: true, events: makeProtocolFrames(makeChatCompletionsEvents()), modelKey: 'k',
ok: true, events: makeProtocolFrames(makeChatCompletionsEvents()), modelKey: 'k', headers: new Headers(),
}));
queueCandidates([makeCandidate({ callChatCompletions })]);

Expand All @@ -166,7 +166,7 @@ test('POST /v1/chat/completions omits the usage-only chunk unless stream_options
test('POST /v1/chat/completions emits the usage-only chunk when stream_options.include_usage is true', async () => {
installRepo();
const callChatCompletions = vi.fn(async (): Promise<ProviderStreamResult<ChatCompletionsStreamEvent>> => ({
ok: true, events: makeProtocolFrames(makeChatCompletionsEvents()), modelKey: 'k',
ok: true, events: makeProtocolFrames(makeChatCompletionsEvents()), modelKey: 'k', headers: new Headers(),
}));
queueCandidates([makeCandidate({ callChatCompletions })]);

Expand Down Expand Up @@ -194,7 +194,7 @@ test('POST /v1/chat/completions emits the usage-only chunk when stream_options.i
test('POST /v1/chat/completions does not write any non-auth Hono context slot', async () => {
installRepo();
const callChatCompletions = vi.fn(async (): Promise<ProviderStreamResult<ChatCompletionsStreamEvent>> => ({
ok: true, events: makeProtocolFrames(makeChatCompletionsEvents()), modelKey: 'k',
ok: true, events: makeProtocolFrames(makeChatCompletionsEvents()), modelKey: 'k', headers: new Headers(),
}));
queueCandidates([makeCandidate({ callChatCompletions })]);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ import { CHAT_COMPLETIONS_MISSING_TERMINAL_MESSAGE, collectChatCompletionsProtoc
import { chatCompletionsProtocolFrameToSSEFrame } from './events/to-sse.ts';
import { tokenUsage } from '../../shared/telemetry/usage.ts';
import type { GatewayCtx } from '../shared/gateway-ctx.ts';
import { SourceStreamState, eventResultMetadata, plainResultToResponse, recordPerformance, recordUsage } from '../shared/respond.ts';
import { SourceStreamState, eventResultMetadata, forwardUpstreamHeaders, mergeForwardedUpstreamHeaders, plainResultToResponse, recordPerformance, recordUsage } from '../shared/respond.ts';
import { type StreamCompletion, writeSSEFrames } from '../shared/stream/sse.ts';
import type { ChatCompletionsStreamEvent, ChatCompletionsResult } from '@floway-dev/protocols/chat-completions';
import { chatCompletionsErrorPayloadMessage } from '@floway-dev/protocols/chat-completions';
Expand Down Expand Up @@ -47,13 +47,14 @@ export const respondChatCompletions = async (
const usage = response.usage ? tokenUsageFromChatCompletionsUsage(response.usage) : null;
await recordUsage(ctx, metadata.modelIdentity, usage);
recordPerformance(ctx, metadata.performance, state.failed);
return { success: true, response: Response.json(response) };
return { success: true, response: Response.json(response, { headers: mergeForwardedUpstreamHeaders(undefined, result.headers) }) };
} catch (error) {
recordPerformance(ctx, result.performance, true);
return { success: false, response: internalChatCompletionsErrorResponse(502, toInternalDebugError(error, 'chat-completions')) };
}
}

forwardUpstreamHeaders(c, result.headers);
const response = streamSSE(c, async stream => {
let completion: StreamCompletion = 'error';
try {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ const collectEvents = async <TEvent>(events: AsyncIterable<ProtocolFrame<TEvent>
test('generate routes a native Chat Completions candidate end to end', async () => {
installRepo();
const callChatCompletions = vi.fn(async (): Promise<ProviderStreamResult<ChatCompletionsStreamEvent>> => ({
ok: true, events: makeProtocolFrames(makeChatCompletionsEvents()), modelKey: 'test-model-key',
ok: true, events: makeProtocolFrames(makeChatCompletionsEvents()), modelKey: 'test-model-key', headers: new Headers(),
}));
queueCandidates([makeCandidate({ upstream: 'up_a', callChatCompletions })]);

Expand All @@ -166,7 +166,7 @@ test('generate routes a native Chat Completions candidate end to end', async ()
test('generate translates through the Messages target when only that endpoint is exposed', async () => {
installRepo();
const callMessages = vi.fn(async (): Promise<ProviderStreamResult<MessagesStreamEvent>> => ({
ok: true, events: makeProtocolFrames(makeMessagesResultEvents()), modelKey: 'messages-model-key',
ok: true, events: makeProtocolFrames(makeMessagesResultEvents()), modelKey: 'messages-model-key', headers: new Headers(),
}));
queueCandidates([makeCandidate({ upstream: 'up_m', targetApi: 'messages', callMessages })]);

Expand All @@ -185,7 +185,7 @@ test('generate translates through the Messages target when only that endpoint is
test('generate translates through the Responses target when only that endpoint is exposed', async () => {
installRepo();
const callResponses = vi.fn(async (): Promise<ProviderStreamResult<ResponsesStreamEvent>> => ({
ok: true, events: makeProtocolFrames([makeResponsesResultEvent()]), modelKey: 'responses-model-key',
ok: true, events: makeProtocolFrames([makeResponsesResultEvent()]), modelKey: 'responses-model-key', headers: new Headers(),
}));
queueCandidates([makeCandidate({ upstream: 'up_r', targetApi: 'responses', callResponses })]);

Expand All @@ -210,7 +210,7 @@ test('generate stops at the first candidate even when it yields an upstream erro
ok: false, response: firstError, modelKey: 'first-key',
}));
const secondCall = vi.fn(async (): Promise<ProviderStreamResult<ChatCompletionsStreamEvent>> => ({
ok: true, events: makeProtocolFrames(makeChatCompletionsEvents()), modelKey: 'second-key',
ok: true, events: makeProtocolFrames(makeChatCompletionsEvents()), modelKey: 'second-key', headers: new Headers(),
}));
queueCandidates([
makeCandidate({ upstream: 'up_a', callChatCompletions: firstCall }),
Expand All @@ -234,7 +234,7 @@ test('generate stops at the first candidate even when it yields an upstream erro
test('generate is a routing no-op when the payload carries no reasoning carriers (degenerate path)', async () => {
installRepo();
const callChatCompletions = vi.fn(async (): Promise<ProviderStreamResult<ChatCompletionsStreamEvent>> => ({
ok: true, events: makeProtocolFrames(makeChatCompletionsEvents()), modelKey: 'test-model-key',
ok: true, events: makeProtocolFrames(makeChatCompletionsEvents()), modelKey: 'test-model-key', headers: new Headers(),
}));
queueCandidates([
makeCandidate({ upstream: 'up_a', callChatCompletions }),
Expand Down
28 changes: 25 additions & 3 deletions packages/gateway/src/data-plane/llm/gemini/attempt_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ const collectEvents = async <TEvent>(events: AsyncIterable<ProtocolFrame<TEvent>
test('generate translates through Chat Completions when targetApi is chat-completions', async () => {
installRepo();
const callChatCompletions = vi.fn(async (): Promise<ProviderStreamResult<ChatCompletionsStreamEvent>> => ({
ok: true, events: makeProtocolFrames(makeChatCompletionsEvents()), modelKey: 'k',
ok: true, events: makeProtocolFrames(makeChatCompletionsEvents()), modelKey: 'k', headers: new Headers(),
}));
const result = await geminiAttempt.generate({
payload: makePayload(),
Expand All @@ -141,7 +141,7 @@ test('generate translates through Chat Completions when targetApi is chat-comple
test('generate translates through Messages when targetApi is messages', async () => {
installRepo();
const callMessages = vi.fn(async (): Promise<ProviderStreamResult<MessagesStreamEvent>> => ({
ok: true, events: makeProtocolFrames(makeMessagesEvents()), modelKey: 'k',
ok: true, events: makeProtocolFrames(makeMessagesEvents()), modelKey: 'k', headers: new Headers(),
}));
const result = await geminiAttempt.generate({
payload: makePayload(),
Expand All @@ -159,7 +159,7 @@ test('generate translates through Messages when targetApi is messages', async ()
test('generate translates through Responses when targetApi is responses', async () => {
installRepo();
const callResponses = vi.fn(async (): Promise<ProviderStreamResult<ResponsesStreamEvent>> => ({
ok: true, events: makeProtocolFrames([makeResponsesResultEvent()]), modelKey: 'k',
ok: true, events: makeProtocolFrames([makeResponsesResultEvent()]), modelKey: 'k', headers: new Headers(),
}));
const result = await geminiAttempt.generate({
payload: makePayload(),
Expand Down Expand Up @@ -260,3 +260,25 @@ test('countTokens refuses a non-messages candidate', async () => {
if (!(thrown instanceof Error)) throw new Error('expected an Error to be thrown');
assertEquals(thrown.message.includes("targetApi='messages'"), true);
});

// Verifies that the headers returned by the stubbed provider call are exposed on the
// `events` result when `geminiAttempt.generate` runs against a candidate whose
// targetApi is 'chat-completions'.
test('generate propagates upstream response headers through the chat-completions translation', async () => {
  installRepo();

  // Headers the stubbed provider hands back alongside its event stream.
  const headersFromUpstream = new Headers([
    ['anthropic-ratelimit-unified-status', 'allowed'],
    ['x-request-id', 'req_gemini_xyz'],
  ]);

  const callChatCompletions = vi.fn(async (): Promise<ProviderStreamResult<ChatCompletionsStreamEvent>> => {
    const events = makeProtocolFrames(makeChatCompletionsEvents());
    return { ok: true, events, modelKey: 'k', headers: headersFromUpstream };
  });

  const result = await geminiAttempt.generate({
    payload: makePayload(),
    ctx: makeGatewayCtx(),
    store: createNonResponsesSourceStore(API_KEY_ID),
    candidate: makeCandidate({ targetApi: 'chat-completions', callChatCompletions }),
  });

  // Assert on the variant first, then narrow so the compiler accepts `headers` / `events`.
  assertEquals(result.type, 'events');
  if (result.type !== 'events') {
    throw new Error('unreachable');
  }

  const forwarded = result.headers;
  assertEquals(forwarded?.get('anthropic-ratelimit-unified-status'), 'allowed');
  assertEquals(forwarded?.get('x-request-id'), 'req_gemini_xyz');

  // Hand the events to collectEvents last, after the header assertions.
  await collectEvents(result.events);
});
8 changes: 4 additions & 4 deletions packages/gateway/src/data-plane/llm/gemini/http_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ const makeMessagesEvents = (): readonly MessagesStreamEvent[] => [
test('POST /v1beta/models/:model:generateContent returns a single JSON body for non-stream generate', async () => {
installRepo();
const callChatCompletions = vi.fn(async (): Promise<ProviderStreamResult<ChatCompletionsStreamEvent>> => ({
ok: true, events: makeProtocolFrames(makeChatCompletionsEvents()), modelKey: 'k',
ok: true, events: makeProtocolFrames(makeChatCompletionsEvents()), modelKey: 'k', headers: new Headers(),
}));
queueCandidates([makeCandidate({ callChatCompletions })]);

Expand All @@ -133,7 +133,7 @@ test('POST /v1beta/models/:model:generateContent returns a single JSON body for
test('POST /v1beta/models/:model:streamGenerateContent streams a Gemini SSE body', async () => {
installRepo();
const callChatCompletions = vi.fn(async (): Promise<ProviderStreamResult<ChatCompletionsStreamEvent>> => ({
ok: true, events: makeProtocolFrames(makeChatCompletionsEvents()), modelKey: 'k',
ok: true, events: makeProtocolFrames(makeChatCompletionsEvents()), modelKey: 'k', headers: new Headers(),
}));
queueCandidates([makeCandidate({ callChatCompletions })]);

Expand Down Expand Up @@ -196,7 +196,7 @@ test('POST /v1beta/models/:model:countTokens accepts the generateContentRequest
test('POST /v1beta/models/:model:generateContent translates through Messages target end to end', async () => {
installRepo();
const callMessages = vi.fn(async (): Promise<ProviderStreamResult<MessagesStreamEvent>> => ({
ok: true, events: makeProtocolFrames(makeMessagesEvents()), modelKey: 'k',
ok: true, events: makeProtocolFrames(makeMessagesEvents()), modelKey: 'k', headers: new Headers(),
}));
queueCandidates([makeCandidate({ targetApi: 'messages', callMessages })]);

Expand Down Expand Up @@ -234,7 +234,7 @@ test('POST /v1beta/models/models/:model:generateContent accepts the models/ pref
let resolvedModel: string | undefined;
const callChatCompletions = vi.fn(async (model): Promise<ProviderStreamResult<ChatCompletionsStreamEvent>> => {
resolvedModel = (model as { id: string }).id;
return { ok: true, events: makeProtocolFrames(makeChatCompletionsEvents()), modelKey: 'k' };
return { ok: true, events: makeProtocolFrames(makeChatCompletionsEvents()), modelKey: 'k', headers: new Headers() };
});
queueCandidates([makeCandidate({ callChatCompletions })]);

Expand Down
5 changes: 3 additions & 2 deletions packages/gateway/src/data-plane/llm/gemini/respond.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import { GEMINI_MISSING_TERMINAL_MESSAGE, isGeminiErrorEvent, isGeminiTerminalEv
import { geminiProtocolFrameToSSEFrame } from './events/to-sse.ts';
import { tokenUsage } from '../../shared/telemetry/usage.ts';
import type { GatewayCtx } from '../shared/gateway-ctx.ts';
import { SourceStreamState, eventResultMetadata, plainResultToResponse, recordPerformance, recordUsage } from '../shared/respond.ts';
import { SourceStreamState, eventResultMetadata, forwardUpstreamHeaders, mergeForwardedUpstreamHeaders, plainResultToResponse, recordPerformance, recordUsage } from '../shared/respond.ts';
import { type StreamCompletion, writeSSEFrames } from '../shared/stream/sse.ts';
import { type ProtocolFrame, sseCommentFrame, sseFrame } from '@floway-dev/protocols/common';
import type { GeminiErrorResponse, GeminiResult, GeminiStreamEvent, GeminiUsageMetadata } from '@floway-dev/protocols/gemini';
Expand Down Expand Up @@ -45,13 +45,14 @@ export const respondGemini = async (
const metadata = await eventResultMetadata(result);
await recordUsage(ctx, metadata.modelIdentity, tokenUsageFromGeminiResponse(response));
recordPerformance(ctx, metadata.performance, state.failed);
return { success: true, response: Response.json(response) };
return { success: true, response: Response.json(response, { headers: mergeForwardedUpstreamHeaders(undefined, result.headers) }) };
} catch (error) {
recordPerformance(ctx, result.performance, true);
return { success: false, response: geminiCollectErrorResponse(error) };
}
}

forwardUpstreamHeaders(c, result.headers);
const response = streamSSE(c, async stream => {
let completion: StreamCompletion = 'error';
try {
Expand Down
Loading