Skip to content

Commit 06ef911

Browse files
committed
Only send the delta on the partial tool call
Signed-off-by: Djordje Lukic <djordje.lukic@docker.com>
1 parent 99f833d commit 06ef911

8 files changed

Lines changed: 165 additions & 14 deletions

File tree

pkg/app/app.go

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package app
22

33
import (
4+
"cmp"
45
"context"
56
"encoding/base64"
67
"errors"
@@ -997,12 +998,24 @@ func (a *App) mergeEvents(events []tea.Msg) []tea.Msg {
997998
result = append(result, merged)
998999

9991000
case *runtime.PartialToolCallEvent:
1000-
// For PartialToolCallEvent, keep only the latest one per tool call ID
1001-
// Only merge consecutive events with the same ID
1001+
// For PartialToolCallEvent, merge consecutive events with the same tool call ID
1002+
// by concatenating argument deltas
10021003
latest := ev
10031004
for i+1 < len(events) {
10041005
if next, ok := events[i+1].(*runtime.PartialToolCallEvent); ok && next.ToolCall.ID == ev.ToolCall.ID {
1005-
latest = next
1006+
latest = &runtime.PartialToolCallEvent{
1007+
Type: ev.Type,
1008+
ToolCall: tools.ToolCall{
1009+
ID: ev.ToolCall.ID,
1010+
Type: ev.ToolCall.Type,
1011+
Function: tools.FunctionCall{
1012+
Name: cmp.Or(next.ToolCall.Function.Name, latest.ToolCall.Function.Name),
1013+
Arguments: latest.ToolCall.Function.Arguments + next.ToolCall.Function.Arguments,
1014+
},
1015+
},
1016+
ToolDefinition: cmp.Or(latest.ToolDefinition, next.ToolDefinition),
1017+
AgentContext: ev.AgentContext,
1018+
}
10061019
i++
10071020
} else {
10081021
break

pkg/cli/runner_test.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -237,3 +237,4 @@ func TestMaxIterationsSafetyCapJSONMode(t *testing.T) {
237237
}
238238
assert.Equal(t, resumes[maxAutoExtensions].Type, runtime.ResumeTypeReject)
239239
}
240+

pkg/runtime/event.go

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -64,15 +64,20 @@ func UserMessage(message, sessionID string, multiContent []chat.MessagePart, ses
6464
type PartialToolCallEvent struct {
6565
Type string `json:"type"`
6666
ToolCall tools.ToolCall `json:"tool_call"`
67-
ToolDefinition tools.Tool `json:"tool_definition"`
67+
ToolDefinition *tools.Tool `json:"tool_definition,omitempty"`
6868
AgentContext
6969
}
7070

7171
func PartialToolCall(toolCall tools.ToolCall, toolDefinition tools.Tool, agentName string) Event {
72+
var toolDef *tools.Tool
73+
if toolDefinition.Name != "" {
74+
def := toolDefinition
75+
toolDef = &def
76+
}
7277
return &PartialToolCallEvent{
7378
Type: "partial_tool_call",
7479
ToolCall: toolCall,
75-
ToolDefinition: toolDefinition,
80+
ToolDefinition: toolDef,
7681
AgentContext: newAgentContext(agentName),
7782
}
7883
}

pkg/runtime/runtime_response_api_test.go

Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
11
package runtime
22

33
import (
4+
"encoding/json"
45
"testing"
56

67
"github.com/stretchr/testify/require"
78

89
"github.com/docker/docker-agent/pkg/session"
10+
"github.com/docker/docker-agent/pkg/tools"
911
)
1012

1113
// TestResponseAPIToolCallHandling verifies that tool calls from the Response API
@@ -74,3 +76,95 @@ func TestResponseAPIMultipleToolCalls(t *testing.T) {
7476
}
7577
require.ElementsMatch(t, []string{"search", "calculate"}, toolCalls, "Expected both tool calls")
7678
}
79+
80+
func TestPartialToolCallEventsContainOnlyNewArgumentBytes(t *testing.T) {
81+
stream := newStreamBuilder().
82+
AddToolCallName("call_abc", "write_file").
83+
AddToolCallArguments("call_abc", `{"path":"story.md"`).
84+
AddToolCallArguments("call_abc", `,"content":"Once upon a time"}`).
85+
AddStopWithUsage(10, 15).
86+
Build()
87+
88+
sess := session.New(session.WithUserMessage("Write a story"))
89+
events := runSession(t, sess, stream)
90+
91+
var partials []*PartialToolCallEvent
92+
for _, event := range events {
93+
if ev, ok := event.(*PartialToolCallEvent); ok {
94+
partials = append(partials, ev)
95+
}
96+
}
97+
98+
require.Len(t, partials, 3)
99+
require.Equal(t, "write_file", partials[0].ToolCall.Function.Name)
100+
require.Empty(t, partials[0].ToolCall.Function.Arguments)
101+
require.Equal(t, `{"path":"story.md"`, partials[1].ToolCall.Function.Arguments)
102+
require.Nil(t, partials[1].ToolDefinition)
103+
require.Equal(t, `,"content":"Once upon a time"}`, partials[2].ToolCall.Function.Arguments)
104+
require.Nil(t, partials[2].ToolDefinition)
105+
106+
secondJSON, err := json.Marshal(partials[1])
107+
require.NoError(t, err)
108+
require.NotContains(t, string(secondJSON), `"tool_definition"`)
109+
}
110+
111+
func TestPartialToolCallEventJSONIncludesToolDefinitionOnlyWhenPresent(t *testing.T) {
112+
toolDef := &tools.Tool{Name: "write_file", Description: "Create file"}
113+
withDef := &PartialToolCallEvent{
114+
Type: "partial_tool_call",
115+
ToolCall: tools.ToolCall{ID: "call_1", Type: "function", Function: tools.FunctionCall{Name: "write_file"}},
116+
ToolDefinition: toolDef,
117+
AgentContext: newAgentContext("root"),
118+
}
119+
withoutDef := &PartialToolCallEvent{
120+
Type: "partial_tool_call",
121+
ToolCall: tools.ToolCall{ID: "call_1", Type: "function", Function: tools.FunctionCall{Name: "write_file", Arguments: `{"path":"story.md"}`}},
122+
AgentContext: newAgentContext("root"),
123+
}
124+
125+
withDefJSON, err := json.Marshal(withDef)
126+
require.NoError(t, err)
127+
require.Contains(t, string(withDefJSON), `"tool_definition"`)
128+
129+
withoutDefJSON, err := json.Marshal(withoutDef)
130+
require.NoError(t, err)
131+
require.NotContains(t, string(withoutDefJSON), `"tool_definition"`)
132+
}
133+
134+
func TestPartialToolCallEventsNormalizeCumulativeArgumentSnapshots(t *testing.T) {
135+
// Some providers resend the full accumulated argument buffer on each tool-call
136+
// update. PartialToolCallEvent should still emit only the new suffix bytes.
137+
stream := newStreamBuilder().
138+
AddToolCallName("call_abc", "write_file").
139+
AddToolCallArguments("call_abc", `{"path":"story.md"`).
140+
AddToolCallArguments("call_abc", `{"path":"story.md","content":"Once upon a time"}`).
141+
AddStopWithUsage(10, 15).
142+
Build()
143+
144+
sess := session.New(session.WithUserMessage("Write a story"))
145+
events := runSession(t, sess, stream)
146+
147+
var partials []*PartialToolCallEvent
148+
for _, event := range events {
149+
if ev, ok := event.(*PartialToolCallEvent); ok {
150+
partials = append(partials, ev)
151+
}
152+
}
153+
154+
require.Len(t, partials, 3)
155+
require.Equal(t, "write_file", partials[0].ToolCall.Function.Name)
156+
require.Empty(t, partials[0].ToolCall.Function.Arguments)
157+
require.Equal(t, `{"path":"story.md"`, partials[1].ToolCall.Function.Arguments)
158+
require.Equal(t, `,"content":"Once upon a time"}`, partials[2].ToolCall.Function.Arguments)
159+
160+
messages := sess.GetAllMessages()
161+
var foundToolCall bool
162+
for _, msg := range messages {
163+
if msg.Message.Role == "assistant" && len(msg.Message.ToolCalls) > 0 {
164+
foundToolCall = true
165+
require.Equal(t, `{"path":"story.md","content":"Once upon a time"}`,
166+
msg.Message.ToolCalls[0].Function.Arguments)
167+
}
168+
}
169+
require.True(t, foundToolCall, "Expected to find complete tool call in session messages")
170+
}

pkg/runtime/streaming.go

Lines changed: 31 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -140,14 +140,39 @@ func (r *LocalRuntime) handleStream(ctx context.Context, stream chat.MessageStre
140140
if delta.Function.Name != "" {
141141
tc.Function.Name = delta.Function.Name
142142
}
143-
if delta.Function.Arguments != "" {
144-
tc.Function.Arguments += delta.Function.Arguments
143+
argsDelta := delta.Function.Arguments
144+
if argsDelta != "" {
145+
// Most providers stream argument deltas, but some may resend the
146+
// full accumulated argument buffer so far. Normalize both shapes to
147+
// a true suffix delta so PartialToolCallEvent always carries only
148+
// newly received bytes and the accumulated tool call remains valid.
149+
if strings.HasPrefix(argsDelta, tc.Function.Arguments) {
150+
argsDelta = argsDelta[len(tc.Function.Arguments):]
151+
tc.Function.Arguments = delta.Function.Arguments
152+
} else {
153+
tc.Function.Arguments += argsDelta
154+
}
145155
}
146156

147-
// Emit PartialToolCall once we have a name, and on subsequent argument deltas
148-
if tc.Function.Name != "" && (learningName || delta.Function.Arguments != "") {
149-
if !emittedPartial[delta.ID] || delta.Function.Arguments != "" {
150-
events <- PartialToolCall(*tc, toolDefMap[tc.Function.Name], a.Name())
157+
// Emit PartialToolCall once we have a name, and on subsequent argument deltas.
158+
// Only the newly received argument bytes are sent, not the full
159+
// accumulated arguments, to avoid re-transmitting the entire payload
160+
// on every token.
161+
if tc.Function.Name != "" && (learningName || argsDelta != "") {
162+
if !emittedPartial[delta.ID] || argsDelta != "" {
163+
partial := tools.ToolCall{
164+
ID: tc.ID,
165+
Type: tc.Type,
166+
Function: tools.FunctionCall{
167+
Name: tc.Function.Name,
168+
Arguments: argsDelta,
169+
},
170+
}
171+
toolDef := tools.Tool{}
172+
if !emittedPartial[delta.ID] {
173+
toolDef = toolDefMap[tc.Function.Name]
174+
}
175+
events <- PartialToolCall(partial, toolDef, a.Name())
151176
emittedPartial[delta.ID] = true
152177
}
153178
}

pkg/tui/components/messages/messages.go

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1301,7 +1301,11 @@ func (m *model) AddOrUpdateToolCall(agentName string, toolCall tools.ToolCall, t
13011301
if msg.Type == types.MessageTypeToolCall && msg.ToolCall.ID == toolCall.ID {
13021302
msg.ToolStatus = status
13031303
if toolCall.Function.Arguments != "" {
1304-
msg.ToolCall.Function.Arguments = toolCall.Function.Arguments
1304+
if status == types.ToolStatusPending {
1305+
msg.ToolCall.Function.Arguments += toolCall.Function.Arguments
1306+
} else {
1307+
msg.ToolCall.Function.Arguments = toolCall.Function.Arguments
1308+
}
13051309
}
13061310
m.invalidateItem(i)
13071311
return nil

pkg/tui/components/reasoningblock/reasoningblock.go

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -206,7 +206,11 @@ func (m *Model) UpdateToolCall(toolCallID string, status types.ToolStatus, args
206206
}
207207
entry.msg.ToolStatus = status
208208
if args != "" {
209-
entry.msg.ToolCall.Function.Arguments = args
209+
if status == types.ToolStatusPending {
210+
entry.msg.ToolCall.Function.Arguments += args
211+
} else {
212+
entry.msg.ToolCall.Function.Arguments = args
213+
}
210214
}
211215
m.toolEntries[i] = entry
212216
return

pkg/tui/page/chat/runtime_events.go

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ import (
99

1010
"github.com/docker/docker-agent/pkg/runtime"
1111
"github.com/docker/docker-agent/pkg/sound"
12+
"github.com/docker/docker-agent/pkg/tools"
1213
"github.com/docker/docker-agent/pkg/tui/components/notification"
1314
"github.com/docker/docker-agent/pkg/tui/components/sidebar"
1415
"github.com/docker/docker-agent/pkg/tui/core"
@@ -273,7 +274,11 @@ func (p *chatPage) handleStreamStopped(msg *runtime.StreamStoppedEvent) tea.Cmd
273274
// "pending" indicator (not animated) to show it's receiving data.
274275
func (p *chatPage) handlePartialToolCall(msg *runtime.PartialToolCallEvent) tea.Cmd {
275276
p.setPendingResponse(false)
276-
toolCmd := p.messages.AddOrUpdateToolCall(msg.AgentName, msg.ToolCall, msg.ToolDefinition, types.ToolStatusPending)
277+
var toolDef tools.Tool
278+
if msg.ToolDefinition != nil {
279+
toolDef = *msg.ToolDefinition
280+
}
281+
toolCmd := p.messages.AddOrUpdateToolCall(msg.AgentName, msg.ToolCall, toolDef, types.ToolStatusPending)
277282
return tea.Batch(toolCmd, p.messages.ScrollToBottom())
278283
}
279284

0 commit comments

Comments
 (0)