From 5daf79638e07d633315a1555418aa3e1244476fa Mon Sep 17 00:00:00 2001
From: Djordje Lukic
Date: Tue, 5 May 2026 22:16:05 +0200
Subject: [PATCH] Fix finish_reason stop when tracking usage

When a client sets stream_options.include_usage, the server now sends
the accumulated usage in one extra final chunk, with empty choices,
after the finish_reason chunk and before [DONE], the same way OpenAI's
API does.

Signed-off-by: Djordje Lukic
---
 pkg/chatserver/agent.go       |  7 +------
 pkg/chatserver/openapi.json   |  9 +++++++++
 pkg/chatserver/server.go      | 31 +++++++++++++++++++++++++++++++++++++++--
 pkg/chatserver/server_test.go | 40 ++++++++++++++++++++++++++++++++++++++--------------
 pkg/chatserver/types.go       | 14 +++++++++++---
 5 files changed, 76 insertions(+), 25 deletions(-)

diff --git a/pkg/chatserver/agent.go b/pkg/chatserver/agent.go
index 752286cdf..8661c5a43 100644
--- a/pkg/chatserver/agent.go
+++ b/pkg/chatserver/agent.go
@@ -247,13 +247,8 @@ func runAgentLoop(ctx context.Context, rt runtime.Runtime, sess *session.Session
 	return errors.Join(runErrs...)
 }
 
-// sessionUsage extracts approximate token usage from a completed session,
-// returning nil when nothing is known so we can omit the field entirely
-// rather than reporting zeroes.
+// sessionUsage extracts approximate token usage from a completed session.
 func sessionUsage(sess *session.Session) *ChatCompletionUsage {
-	if sess.InputTokens == 0 && sess.OutputTokens == 0 {
-		return nil
-	}
 	return &ChatCompletionUsage{
 		PromptTokens:     sess.InputTokens,
 		CompletionTokens: sess.OutputTokens,
diff --git a/pkg/chatserver/openapi.json b/pkg/chatserver/openapi.json
index e34ded003..2856d5dc2 100644
--- a/pkg/chatserver/openapi.json
+++ b/pkg/chatserver/openapi.json
@@ -118,6 +118,15 @@
           "items": { "$ref": "#/components/schemas/ChatCompletionMessage" }
         },
         "stream": { "type": "boolean", "default": false },
+        "stream_options": {
+          "type": "object",
+          "properties": {
+            "include_usage": {
+              "type": "boolean",
+              "description": "When true and stream=true, emit an extra final chunk with usage and empty choices before [DONE]."
+            }
+          }
+        },
         "temperature": {
           "type": "number",
           "minimum": 0,
diff --git a/pkg/chatserver/server.go b/pkg/chatserver/server.go
index 928e56636..60ab69d51 100644
--- a/pkg/chatserver/server.go
+++ b/pkg/chatserver/server.go
@@ -395,7 +395,7 @@ func (s *server) handleChatCompletions(c echo.Context) error {
 	}
 
 	if req.Stream {
-		err := s.streamChatCompletion(c, rt, sess, model)
+		err := s.streamChatCompletion(c, rt, sess, model, req.StreamOptions != nil && req.StreamOptions.IncludeUsage)
 		s.maybeStoreConversation(conversationID, sess)
 		return err
 	}
@@ -469,7 +469,7 @@ func (s *server) chatCompletion(c echo.Context, rt runtime.Runtime, sess *sessio
 // The error return is reserved for future use (e.g. surfacing a write
 // failure to the request logger). Today every error is converted into an
 // in-band SSE error event, so the function always returns nil.
-func (s *server) streamChatCompletion(c echo.Context, rt runtime.Runtime, sess *session.Session, model string) error { //nolint:unparam // see comment
+func (s *server) streamChatCompletion(c echo.Context, rt runtime.Runtime, sess *session.Session, model string, includeUsage bool) error { //nolint:unparam // see comment
 	stream := newSSEStream(c.Response(), newChatID(), model)
 
 	// Initial "role: assistant" delta so clients can start rendering.
@@ -499,8 +499,11 @@
 	// from a normal completion.
 		stream.sendError(runErr)
 		stream.send(ChatCompletionStreamDelta{}, "error")
 	} else {
 		stream.send(ChatCompletionStreamDelta{}, "stop")
 	}
+	if includeUsage {
+		stream.sendUsage(sessionUsage(sess))
+	}
 	stream.done()
 	return nil
@@ -524,6 +527,30 @@ func newSSEStream(w http.ResponseWriter, id, model string) *sseStream {
 	return &sseStream{w: w, id: id, model: model, created: time.Now().Unix()}
 }
 
+// sendUsage emits a usage-only chunk, with empty choices, before [DONE],
+// mirroring OpenAI's stream_options.include_usage behavior.
+func (s *sseStream) sendUsage(usage *ChatCompletionUsage) {
+	if usage == nil {
+		return
+	}
+	chunk := ChatCompletionStreamResponse{
+		ID:      s.id,
+		Object:  "chat.completion.chunk",
+		Created: s.created,
+		Model:   s.model,
+		Choices: []ChatCompletionStreamChoice{},
+		Usage:   usage,
+	}
+	data, err := json.Marshal(chunk)
+	if err != nil {
+		return
+	}
+	_, _ = fmt.Fprintf(s.w, "data: %s\n\n", data)
+	if f, ok := s.w.(http.Flusher); ok {
+		f.Flush()
+	}
+}
+
 func (s *sseStream) send(delta ChatCompletionStreamDelta, finishReason string) {
 	chunk := ChatCompletionStreamResponse{
 		ID:      s.id,
diff --git a/pkg/chatserver/server_test.go b/pkg/chatserver/server_test.go
index e4df5074b..b3951d81d 100644
--- a/pkg/chatserver/server_test.go
+++ b/pkg/chatserver/server_test.go
@@ -15,7 +15,6 @@ import (
 	"github.com/stretchr/testify/require"
 
 	"github.com/docker/docker-agent/pkg/chat"
-	"github.com/docker/docker-agent/pkg/session"
 )
 
 func TestBuildSession_RequiresUserMessage(t *testing.T) {
@@ -120,19 +119,6 @@ func TestBuildSession_UnknownRoleTreatedAsUser(t *testing.T) {
 	assert.Equal(t, "do this", all[0].Message.Content)
 }
 
-func TestSessionUsage_OmitsZero(t *testing.T) {
-	sess := session.New()
-	assert.Nil(t, sessionUsage(sess))
-
-	sess.InputTokens = 5
-	sess.OutputTokens = 7
-	usage := sessionUsage(sess)
-	require.NotNil(t, usage)
-	assert.Equal(t, int64(5), usage.PromptTokens)
-	assert.Equal(t, int64(7), usage.CompletionTokens)
-	assert.Equal(t, int64(12), usage.TotalTokens)
-}
-
 func TestAgentPolicy_Pick(t *testing.T) {
 	p := agentPolicy{exposed: []string{"root", "reviewer"}, fallback: "root"}
 
@@ -457,6 +443,18 @@ func TestStopSequences_UnmarshalJSON(t *testing.T) {
 	}
 }
 
+func TestChatCompletionRequest_UnmarshalStreamOptions(t *testing.T) {
+	var req ChatCompletionRequest
+	require.NoError(t, json.Unmarshal([]byte(`{
+		"messages": [{"role":"user","content":"hi"}],
+		"stream": true,
+		"stream_options": {"include_usage": true}
+	}`), &req))
+	require.NotNil(t, req.StreamOptions)
+	assert.True(t, req.Stream)
+	assert.True(t, req.StreamOptions.IncludeUsage)
+}
+
 func TestSSEStream_ToolCallDelta(t *testing.T) {
 	rec := httptest.NewRecorder()
 	s := newSSEStream(rec, "chatcmpl-x", "root")
@@ -477,6 +475,20 @@ func TestSSEStream_ToolCallDelta(t *testing.T) {
 	assert.Contains(t, body, `"arguments":"{\"q\":\"docker\"}"`)
 }
 
+func TestSSEStream_SendUsage(t *testing.T) {
+	rec := httptest.NewRecorder()
+	s := newSSEStream(rec, "chatcmpl-x", "root")
+	s.send(ChatCompletionStreamDelta{}, "stop")
+	s.sendUsage(&ChatCompletionUsage{PromptTokens: 5, CompletionTokens: 7, TotalTokens: 12})
+	s.done()
+
+	body := rec.Body.String()
+	assert.Contains(t, body, `"finish_reason":"stop"`)
+	assert.Contains(t, body, `"choices":[]`)
+	assert.Contains(t, body, `"usage":{"prompt_tokens":5,"completion_tokens":7,"total_tokens":12}`)
+	assert.Contains(t, body, "data: [DONE]")
+}
+
 func TestSSEStream_SendError(t *testing.T) {
 	rec := httptest.NewRecorder()
 	s := newSSEStream(rec, "chatcmpl-x", "root")
diff --git a/pkg/chatserver/types.go b/pkg/chatserver/types.go
index bbf3a4e06..8eafdb102 100644
--- a/pkg/chatserver/types.go
+++ b/pkg/chatserver/types.go
@@ -23,9 +23,10 @@ import (
 // declare every field commonly sent by OpenAI clients so they are accepted
 // without surprise. Whether each field is *acted on* is documented inline.
 type ChatCompletionRequest struct {
-	Model    string                  `json:"model"`
-	Messages []ChatCompletionMessage `json:"messages"`
-	Stream   bool                    `json:"stream,omitempty"`
+	Model         string                       `json:"model"`
+	Messages      []ChatCompletionMessage      `json:"messages"`
+	Stream        bool                         `json:"stream,omitempty"`
+	StreamOptions *ChatCompletionStreamOptions `json:"stream_options,omitempty"`
 
 	// Temperature is parsed and range-checked but not yet plumbed through
 	// to the runtime/model layer (no per-request override exists today).
@@ -252,6 +253,7 @@ type ChatCompletionStreamResponse struct {
 	Created int64                        `json:"created"`
 	Model   string                       `json:"model"`
 	Choices []ChatCompletionStreamChoice `json:"choices"`
+	Usage   *ChatCompletionUsage         `json:"usage,omitempty"`
 }
 
 type ChatCompletionStreamChoice struct {
@@ -260,6 +262,12 @@ type ChatCompletionStreamChoice struct {
 	FinishReason string `json:"finish_reason,omitempty"`
 }
 
+// ChatCompletionStreamOptions mirrors the subset of OpenAI's
+// stream_options object we currently support. It is a pointer field on
+// ChatCompletionRequest so an omitted stream_options stays nil.
+type ChatCompletionStreamOptions struct {
+	IncludeUsage bool `json:"include_usage,omitempty"`
+}
+
 type ChatCompletionStreamDelta struct {
 	Role    string `json:"role,omitempty"`
 	Content string `json:"content,omitempty"`
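
Reviewer note (not part of the patch): with stream_options.include_usage
set, the tail of the stream should now look like the transcript below. It
is illustrative, matching the shapes asserted in TestSSEStream_SendUsage;
the id, created timestamp, and token counts are placeholders, and the stop
chunk is abridged:

  data: {"id":"chatcmpl-x","object":"chat.completion.chunk","created":1746000000,"model":"root","choices":[{"index":0,"delta":{},"finish_reason":"stop"}]}

  data: {"id":"chatcmpl-x","object":"chat.completion.chunk","created":1746000000,"model":"root","choices":[],"usage":{"prompt_tokens":5,"completion_tokens":7,"total_tokens":12}}

  data: [DONE]

Without include_usage the usage chunk is simply skipped and the stream
still ends with the finish_reason chunk followed by data: [DONE].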
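For completeness, here is a minimal Go client sketch (also not part of the
patch) that opts in and picks the usage chunk out of the stream. The host,
port, and OpenAI-style path /v1/chat/completions are assumptions about the
deployment, and usageChunk is a hypothetical helper type that decodes only
the fields the sketch needs:

package main

import (
	"bufio"
	"bytes"
	"encoding/json"
	"fmt"
	"log"
	"net/http"
	"strings"
)

// usageChunk decodes just enough of a stream chunk to spot the usage chunk.
type usageChunk struct {
	Choices []json.RawMessage `json:"choices"`
	Usage   *struct {
		PromptTokens     int64 `json:"prompt_tokens"`
		CompletionTokens int64 `json:"completion_tokens"`
		TotalTokens      int64 `json:"total_tokens"`
	} `json:"usage"`
}

func main() {
	body := []byte(`{
		"model": "root",
		"messages": [{"role": "user", "content": "hi"}],
		"stream": true,
		"stream_options": {"include_usage": true}
	}`)

	// Assumed host/port and path; adjust for your deployment.
	resp, err := http.Post("http://localhost:8080/v1/chat/completions", "application/json", bytes.NewReader(body))
	if err != nil {
		log.Fatal(err)
	}
	defer resp.Body.Close()

	scanner := bufio.NewScanner(resp.Body)
	for scanner.Scan() {
		raw := scanner.Text()
		if !strings.HasPrefix(raw, "data: ") {
			continue // skip the blank separator lines between SSE events
		}
		payload := strings.TrimPrefix(raw, "data: ")
		if payload == "[DONE]" {
			break
		}
		var chunk usageChunk
		if err := json.Unmarshal([]byte(payload), &chunk); err != nil {
			continue // skip anything that isn't a JSON chunk
		}
		// The chunk this patch adds has empty choices plus a usage object.
		if chunk.Usage != nil && len(chunk.Choices) == 0 {
			fmt.Printf("prompt=%d completion=%d total=%d\n",
				chunk.Usage.PromptTokens, chunk.Usage.CompletionTokens, chunk.Usage.TotalTokens)
		}
	}
	if err := scanner.Err(); err != nil {
		log.Fatal(err)
	}
}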