Skip to content

Commit f490865

Browse files
committed
Finalize streaming state and handle terminal events
Add a finalized flag to streamingState (markFinalized/isFinalized) and use it to guard finalization paths across the streaming lifecycle. Ensure streams are closed in runAgentLoopStreamStep and short-circuit when state is nil in finishStreamingWithFailure. Update responsesTurnAdapter/FinaleAgentLoop and chatCompletionsTurnAdapter to use isFinalized checks, make processResponseStreamEvent explicitly handle response.failed and response.incomplete (finalize metadata, log, and return loop-stop), and return on response.completed. Add early-finalize guard in completeStreamingSuccess. Include new tests to cover completed, failed, and finalize behavior. These changes prevent duplicate finalization/races and ensure terminal events stop the loop and close resources correctly.
1 parent b775ea0 commit f490865

7 files changed

Lines changed: 153 additions & 4 deletions

bridges/ai/agent_loop_runtime.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,9 @@ func runAgentLoopStreamStep[T any](
9090
handleEvent func(T) (done bool, cle *ContextLengthError, err error),
9191
handleErr func(error) (cle *ContextLengthError, err error),
9292
) (bool, *ContextLengthError, error) {
93+
if stream != nil {
94+
defer stream.Close()
95+
}
9396
writer := state.writer()
9497
writer.StepStart(ctx)
9598
defer writer.StepFinish(ctx)

bridges/ai/streaming_chat_completions.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -203,7 +203,7 @@ func (a *chatCompletionsTurnAdapter) FinalizeAgentLoop(ctx context.Context) {
203203
state := a.state
204204
portal := a.portal
205205
meta := a.meta
206-
if state == nil || state.completedAtMs != 0 {
206+
if state == nil || state.isFinalized() {
207207
return
208208
}
209209

bridges/ai/streaming_error_handling.go

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,12 @@ func (oc *AIClient) finishStreamingWithFailure(
4040
reason string,
4141
err error,
4242
) error {
43+
if state == nil {
44+
return err
45+
}
46+
if !state.markFinalized() {
47+
return streamFailureError(state, err)
48+
}
4349
if state != nil && state.stop.Load() != nil && reason == "cancelled" {
4450
reason = "stop"
4551
}

bridges/ai/streaming_lifecycle_cluster_test.go

Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -162,3 +162,104 @@ func TestProcessResponseStreamEventUpdatesCompletedResponseStatus(t *testing.T)
162162
t.Fatalf("expected writer metadata to be completed, got %#v", metadata["response_status"])
163163
}
164164
}
165+
166+
func TestProcessResponseStreamEventCompletedSignalsLoopStop(t *testing.T) {
167+
state := newTestStreamingStateWithTurn()
168+
oc := &AIClient{}
169+
170+
rsc := &responseStreamContext{
171+
base: &agentLoopProviderBase{
172+
oc: oc,
173+
log: zerolog.Nop(),
174+
state: state,
175+
},
176+
}
177+
178+
done, cle, err := oc.processResponseStreamEvent(context.Background(), rsc, responses.ResponseStreamEventUnion{
179+
Type: "response.completed",
180+
Response: responses.Response{
181+
ID: "resp_done",
182+
Status: "completed",
183+
},
184+
}, false)
185+
if !done {
186+
t.Fatal("expected completed response event to stop the stream loop")
187+
}
188+
if cle != nil {
189+
t.Fatalf("did not expect context-length error, got %#v", cle)
190+
}
191+
if err != nil {
192+
t.Fatalf("did not expect error, got %v", err)
193+
}
194+
}
195+
196+
func TestResponsesTurnAdapterFinalizeAgentLoopDoesNotSkipTerminalLifecycle(t *testing.T) {
197+
state := newTestStreamingStateWithTurn()
198+
state.turn.SetSuppressSend(true)
199+
state.writer().TextDelta(context.Background(), "done")
200+
state.completedAtMs = 123
201+
state.finishReason = "stop"
202+
203+
adapter := &responsesTurnAdapter{
204+
agentLoopProviderBase: agentLoopProviderBase{
205+
oc: &AIClient{},
206+
log: zerolog.Nop(),
207+
state: state,
208+
},
209+
}
210+
211+
adapter.FinalizeAgentLoop(context.Background())
212+
213+
if !state.isFinalized() {
214+
t.Fatal("expected finalize agent loop to finalize terminal response state")
215+
}
216+
217+
message := streamui.SnapshotUIMessage(state.turn.UIState())
218+
metadata, _ := message["metadata"].(map[string]any)
219+
if metadata["finish_reason"] != "stop" {
220+
t.Fatalf("expected finalized UI message finish_reason stop, got %#v", metadata["finish_reason"])
221+
}
222+
}
223+
224+
func TestProcessResponseStreamEventFailedFinalizesAsError(t *testing.T) {
225+
state := newTestStreamingStateWithTurn()
226+
state.turn.SetSuppressSend(true)
227+
state.writer().TextDelta(context.Background(), "hello")
228+
oc := &AIClient{}
229+
230+
rsc := &responseStreamContext{
231+
base: &agentLoopProviderBase{
232+
oc: oc,
233+
log: zerolog.Nop(),
234+
state: state,
235+
},
236+
}
237+
238+
done, cle, err := oc.processResponseStreamEvent(context.Background(), rsc, responses.ResponseStreamEventUnion{
239+
Type: "response.failed",
240+
Response: responses.Response{
241+
ID: "resp_failed",
242+
Status: "failed",
243+
Error: responses.ResponseError{
244+
Message: "boom",
245+
},
246+
},
247+
}, false)
248+
if !done {
249+
t.Fatal("expected failed response event to stop the stream loop")
250+
}
251+
if cle != nil {
252+
t.Fatalf("did not expect context-length error, got %#v", cle)
253+
}
254+
if err == nil {
255+
t.Fatal("expected failed response event to return an error")
256+
}
257+
if !state.isFinalized() {
258+
t.Fatal("expected failed response event to finalize the turn")
259+
}
260+
message := streamui.SnapshotUIMessage(state.turn.UIState())
261+
metadata, _ := message["metadata"].(map[string]any)
262+
if metadata["finish_reason"] != "error" {
263+
t.Fatalf("expected error finish_reason, got %#v", metadata["finish_reason"])
264+
}
265+
}

bridges/ai/streaming_responses_api.go

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -176,7 +176,7 @@ func (a *responsesTurnAdapter) RunAgentTurn(
176176
}
177177

178178
func (a *responsesTurnAdapter) FinalizeAgentLoop(ctx context.Context) {
179-
if a.state == nil || a.state.completedAtMs != 0 {
179+
if a.state == nil || a.state.isFinalized() {
180180
return
181181
}
182182
a.oc.finalizeResponsesStream(ctx, a.log, a.portal, a.state, a.meta)
@@ -217,9 +217,29 @@ func (oc *AIClient) processResponseStreamEvent(
217217
)
218218

219219
switch streamEvent.Type {
220-
case "response.created", "response.queued", "response.in_progress", "response.failed", "response.incomplete":
220+
case "response.created", "response.queued", "response.in_progress":
221221
oc.handleResponseLifecycleEvent(ctx, portal, state, meta, streamEvent.Type, streamEvent.Response)
222222

223+
case "response.failed":
224+
oc.handleResponseLifecycleEvent(ctx, portal, state, meta, streamEvent.Type, streamEvent.Response)
225+
state.completedAtMs = time.Now().UnixMilli()
226+
errText := strings.TrimSpace(streamEvent.Response.Error.Message)
227+
if errText == "" {
228+
errText = "response failed"
229+
}
230+
return true, nil, oc.finishStreamingWithFailure(ctx, log, portal, state, meta, "error", errors.New(errText))
231+
232+
case "response.incomplete":
233+
oc.handleResponseLifecycleEvent(ctx, portal, state, meta, streamEvent.Type, streamEvent.Response)
234+
state.completedAtMs = time.Now().UnixMilli()
235+
actions.finalizeMetadata()
236+
log.Debug().
237+
Str("reason", state.finishReason).
238+
Str("response_id", state.responseID).
239+
Str("response_status", state.responseStatus).
240+
Msg("Response stream ended incomplete" + contSuffix)
241+
return true, nil, nil
242+
223243
case "response.output_item.added":
224244
actions.outputItemAdded(streamEvent.Item)
225245

@@ -377,6 +397,7 @@ func (oc *AIClient) processResponseStreamEvent(
377397
}
378398
log.Debug().Str("reason", state.finishReason).Str("response_id", state.responseID).Int("images", len(state.pendingImages)).
379399
Msg("Response stream completed" + contSuffix)
400+
return true, nil, nil
380401

381402
case "error":
382403
apiErr := fmt.Errorf("API error: %s", streamEvent.Message)

bridges/ai/streaming_state.go

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,8 @@ type streamingState struct {
7171
pendingMcpApprovals []mcpApprovalRequest
7272
pendingMcpApprovalsSeen map[string]bool
7373

74-
stop atomic.Pointer[assistantStopMetadata]
74+
finalized atomic.Bool
75+
stop atomic.Pointer[assistantStopMetadata]
7576
}
7677

7778
// sourceEventID returns the triggering user message event ID from the turn's source ref.
@@ -109,6 +110,20 @@ func (s *streamingState) writer() *sdk.Writer {
109110
return s.turn.Writer()
110111
}
111112

113+
func (s *streamingState) markFinalized() bool {
114+
if s == nil {
115+
return false
116+
}
117+
return s.finalized.CompareAndSwap(false, true)
118+
}
119+
120+
func (s *streamingState) isFinalized() bool {
121+
if s == nil {
122+
return false
123+
}
124+
return s.finalized.Load()
125+
}
126+
112127
func (s *streamingState) nextMessageTiming() agentremote.EventTiming {
113128
if s == nil {
114129
return agentremote.ResolveEventTiming(time.Time{}, 0)

bridges/ai/streaming_success.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,9 @@ func (oc *AIClient) completeStreamingSuccess(
1717
state *streamingState,
1818
meta *PortalMetadata,
1919
) {
20+
if state == nil || !state.markFinalized() {
21+
return
22+
}
2023
state.completedAtMs = time.Now().UnixMilli()
2124
if state.finishReason == "" {
2225
state.finishReason = "stop"

0 commit comments

Comments
 (0)