Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 13 additions & 4 deletions proxy/anthropic.go
Original file line number Diff line number Diff line change
Expand Up @@ -578,7 +578,7 @@ func (t *anthropicStreamTranslator) translateEvent(eventData []byte) []anthropic
case "response.reasoning_summary_text.delta", "response.reasoning_text.delta":
return t.handleThinkingDelta(eventData)

case "response.function_call_arguments.delta":
case "response.function_call_arguments.delta", "response.custom_tool_call_input.delta":
return t.handleToolInputDelta(eventData)

case "response.output_text.done", "response.reasoning_summary_text.done",
Expand Down Expand Up @@ -645,9 +645,12 @@ func (t *anthropicStreamTranslator) handleOutputItemAdded(data []byte) []anthrop
},
})

case "function_call":
case "function_call", "custom_tool_call":
events = append(events, t.closeCurrentBlock()...)
callID := fromCodexCallID(gjson.GetBytes(data, "item.call_id").String())
if callID == "" {
callID = fromCodexCallID(gjson.GetBytes(data, "item.id").String())
}
name := gjson.GetBytes(data, "item.name").String()
idx := t.contentBlockIndex
t.contentBlockIndex++
Expand Down Expand Up @@ -1077,11 +1080,17 @@ func buildAnthropicResponseFromCompleted(completedData []byte, model string) *an
return true
})

case "function_call":
// function_call → tool_use block
case "function_call", "custom_tool_call":
// function_call/custom_tool_call → tool_use block
callID := fromCodexCallID(item.Get("call_id").String())
if callID == "" {
callID = fromCodexCallID(item.Get("id").String())
}
name := item.Get("name").String()
args := item.Get("arguments").String()
if itemType == "custom_tool_call" {
args = item.Get("input").String()
}
if cleaned := sanitizeToolInputJSON(name, args); cleaned != "" {
args = cleaned
} else {
Expand Down
36 changes: 36 additions & 0 deletions proxy/anthropic_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -570,6 +570,42 @@ func TestAnthropicStreamTranslator_ToolInputBufferedAndCleaned(t *testing.T) {
}
}

func TestAnthropicStreamTranslator_CustomToolCallInputDelta(t *testing.T) {
tr := newAnthropicStreamTranslator("claude-sonnet-4-5")
tr.translateEvent([]byte(`{"type":"response.created"}`))
tr.translateEvent([]byte(`{
"type":"response.output_item.added",
"item":{"type":"custom_tool_call","id":"call_custom","name":"CustomTool"}
}`))

streamed := tr.translateEvent([]byte(`{
"type":"response.custom_tool_call_input.delta",
"delta":"{\"query\":\"hello\"}"
}`))
for _, evt := range streamed {
if evt.Type == "content_block_delta" {
t.Fatalf("expected no content_block_delta during streaming, got %+v", evt)
}
}

closing := tr.translateEvent([]byte(`{"type":"response.output_item.done"}`))
var sawDelta bool
for _, evt := range closing {
if evt.Type == "content_block_delta" {
sawDelta = true
if evt.Delta == nil || evt.Delta.Type != "input_json_delta" {
t.Fatalf("expected input_json_delta, got %+v", evt.Delta)
}
if !jsonEqual(t, evt.Delta.PartialJSON, `{"query":"hello"}`) {
t.Fatalf("custom tool input = %q", evt.Delta.PartialJSON)
}
}
}
if !sawDelta {
t.Fatalf("expected custom tool input_json_delta on close")
}
}

func TestAnthropicResponseAccumulatorUsesStreamDeltasWhenCompletedOutputIsEmpty(t *testing.T) {
tr := newAnthropicStreamTranslator("claude-sonnet-4-5")
acc := newAnthropicResponseAccumulator("claude-sonnet-4-5")
Expand Down
4 changes: 2 additions & 2 deletions proxy/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -3144,7 +3144,7 @@ func (h *Handler) ChatCompletions(c *gin.Context) {
ttftRecorded = true
}
// 累计 delta 字符数(文本 + function call 参数)
if eventType == "response.output_text.delta" || eventType == "response.function_call_arguments.delta" {
if eventType == "response.output_text.delta" || isCodexToolInputDeltaEvent(eventType) {
deltaCharCount += len(parsed.Get("delta").String())
}
if eventType == "response.completed" {
Expand Down Expand Up @@ -3226,7 +3226,7 @@ func (h *Handler) ChatCompletions(c *gin.Context) {
fullContent.WriteString(delta)
case "response.reasoning_summary_text.delta", "response.reasoning_text.delta":
fullReasoning.WriteString(parsed.Get("delta").String())
case "response.function_call_arguments.delta":
case "response.function_call_arguments.delta", "response.custom_tool_call_input.delta":
deltaCharCount += len(parsed.Get("delta").String())
case "response.completed":
usage = extractUsageFromResult(parsed.Get("response.usage"))
Expand Down
4 changes: 2 additions & 2 deletions proxy/handler_anthropic.go
Original file line number Diff line number Diff line change
Expand Up @@ -382,7 +382,7 @@ func (h *Handler) Messages(c *gin.Context) {
}

// 累计 delta 字符数
if eventType == "response.output_text.delta" || eventType == "response.function_call_arguments.delta" {
if eventType == "response.output_text.delta" || isCodexToolInputDeltaEvent(eventType) {
deltaCharCount += len(parsed.Get("delta").String())
}

Expand Down Expand Up @@ -462,7 +462,7 @@ func (h *Handler) Messages(c *gin.Context) {
firstTokenMs = int(time.Since(start).Milliseconds())
ttftRecorded = true
}
if eventType == "response.output_text.delta" || eventType == "response.function_call_arguments.delta" {
if eventType == "response.output_text.delta" || isCodexToolInputDeltaEvent(eventType) {
deltaCharCount += len(parsed.Get("delta").String())
}
if eventType == "response.completed" {
Expand Down
109 changes: 100 additions & 9 deletions proxy/translator.go
Original file line number Diff line number Diff line change
Expand Up @@ -1489,6 +1489,47 @@ func PrepareResponsesWebSocketBody(rawBody []byte) ([]byte, string) {
})
}

const codexReasoningEncryptedContentInclude = "reasoning.encrypted_content"

func ensureDefaultCodexInclude(body map[string]any) {
if body == nil {
return
}
if _, ok := body["include"]; !ok {
body["include"] = []string{codexReasoningEncryptedContentInclude}
}
}

func ensureCodexReasoningInclude(body map[string]any) {
if body == nil {
return
}
if _, ok := body["reasoning"]; !ok {
return
}
if _, ok := body["include"]; !ok {
body["include"] = []string{codexReasoningEncryptedContentInclude}
return
}

switch includes := body["include"].(type) {
case []any:
for _, item := range includes {
if s, ok := item.(string); ok && s == codexReasoningEncryptedContentInclude {
return
}
}
body["include"] = append(includes, codexReasoningEncryptedContentInclude)
case []string:
for _, item := range includes {
if item == codexReasoningEncryptedContentInclude {
return
}
}
body["include"] = append(includes, codexReasoningEncryptedContentInclude)
}
}

func prepareResponsesBodyWithOptions(rawBody []byte, opts responsesBodyPrepareOptions) ([]byte, string) {
var body map[string]any
if err := json.Unmarshal(rawBody, &body); err != nil {
Expand All @@ -1500,9 +1541,7 @@ func prepareResponsesBodyWithOptions(rawBody []byte, opts responsesBodyPrepareOp
if opts.forceStoreFalse {
body["store"] = false
}
if _, ok := body["include"]; !ok {
body["include"] = []string{"reasoning.encrypted_content"}
}
ensureDefaultCodexInclude(body)

normalizeResponsesImageOnlyModel(body)
normalizeResponsesPromptCompat(body)
Expand Down Expand Up @@ -1537,6 +1576,7 @@ func prepareResponsesBodyWithOptions(rawBody []byte, opts responsesBodyPrepareOp
}
}
}
ensureCodexReasoningInclude(body)

// 4. service tier 清理(兼容客户端字段;只有 fast/priority 会显式传给 Codex 上游)
delete(body, "serviceTier")
Expand Down Expand Up @@ -2544,6 +2584,33 @@ func newReasoningChunk(id, model string, created int64, reasoning string) []byte
return b
}

func isCodexToolCallItemType(itemType string) bool {
switch itemType {
case "function_call", "custom_tool_call":
return true
default:
return false
}
}

func isCodexToolInputDeltaEvent(eventType string) bool {
switch eventType {
case "response.function_call_arguments.delta", "response.custom_tool_call_input.delta":
return true
default:
return false
}
}

func isCodexToolInputDoneEvent(eventType string) bool {
switch eventType {
case "response.function_call_arguments.done", "response.custom_tool_call_input.done":
return true
default:
return false
}
}

// newToolCallAnnouncementChunk 构建 tool call 首块(含 id、type、function.name)
func newToolCallAnnouncementChunk(id, model string, created int64, tcIndex int, callID, funcName string) []byte {
chunk := openAIStreamChunk{
Expand Down Expand Up @@ -2622,6 +2689,9 @@ func TranslateStreamChunk(eventData []byte, model string, chunkID string, create
delta := gjson.GetBytes(eventData, "delta").String()
return newReasoningChunk(chunkID, model, created, delta), false

case "response.custom_tool_call_input.delta", "response.custom_tool_call_input.done":
return nil, false

case "response.completed":
usage := extractUsage(eventData)
return newFinalChunk(chunkID, model, created, "stop", usage), true
Expand Down Expand Up @@ -2698,30 +2768,42 @@ func (st *StreamTranslator) TranslateParsed(parsed gjson.Result) ([]byte, bool)

case "response.output_item.added":
itemType := parsed.Get("item.type").String()
if itemType != "function_call" {
if !isCodexToolCallItemType(itemType) {
return nil, false
}
callID := parsed.Get("item.call_id").String()
if callID == "" {
callID = parsed.Get("item.id").String()
}
name := parsed.Get("item.name").String()
itemID := parsed.Get("item.id").String()
if itemID == "" {
itemID = callID
}

tcIdx := st.nextIdx
st.toolCallMap[itemID] = tcIdx
if callID != "" && callID != itemID {
st.toolCallMap[callID] = tcIdx
}
st.nextIdx++
st.HasToolCalls = true

return newToolCallAnnouncementChunk(st.ChunkID, st.Model, st.Created, tcIdx, callID, name), false

case "response.function_call_arguments.delta":
case "response.function_call_arguments.delta", "response.custom_tool_call_input.delta":
itemID := parsed.Get("item_id").String()
if itemID == "" {
itemID = parsed.Get("call_id").String()
}
tcIdx, ok := st.toolCallMap[itemID]
if !ok {
return nil, false
}
delta := parsed.Get("delta").String()
return newToolCallDeltaChunk(st.ChunkID, st.Model, st.Created, tcIdx, delta), false

case "response.function_call_arguments.done":
case "response.function_call_arguments.done", "response.custom_tool_call_input.done":
return nil, false

case "response.completed":
Expand Down Expand Up @@ -2897,11 +2979,20 @@ func ExtractToolCallsFromOutput(eventData []byte) []ToolCallResult {
return nil
}
output.ForEach(func(_, item gjson.Result) bool {
if item.Get("type").String() == "function_call" {
itemType := item.Get("type").String()
if isCodexToolCallItemType(itemType) {
callID := item.Get("call_id").String()
if callID == "" {
callID = item.Get("id").String()
}
arguments := item.Get("arguments").String()
if itemType == "custom_tool_call" {
arguments = item.Get("input").String()
}
toolCalls = append(toolCalls, ToolCallResult{
ID: item.Get("call_id").String(),
ID: callID,
Name: item.Get("name").String(),
Arguments: item.Get("arguments").String(),
Arguments: arguments,
})
}
return true
Expand Down
Loading
Loading