Skip to content

Commit e8df478

Browse files
committed
Merge branch 'main' into batuhan/prob
2 parents 5eb3499 + 06d1901 commit e8df478

17 files changed

Lines changed: 632 additions & 68 deletions

bridges/ai/abort_helpers.go

Lines changed: 173 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -3,31 +3,187 @@ package ai
33
import (
44
"context"
55
"fmt"
6+
"strings"
7+
"unicode"
8+
"unicode/utf8"
69

710
"maunium.net/go/mautrix/bridgev2"
11+
"maunium.net/go/mautrix/id"
812
)
913

10-
func formatAbortNotice(stopped int) string {
11-
if stopped <= 0 {
12-
return "Agent was aborted."
14+
type stopPlanKind string
15+
16+
const (
17+
stopPlanKindNoMatch stopPlanKind = "no-match"
18+
stopPlanKindRoomWide stopPlanKind = "room-wide"
19+
stopPlanKindActive stopPlanKind = "active-turn"
20+
stopPlanKindQueued stopPlanKind = "queued-turn"
21+
)
22+
23+
type userStopRequest struct {
24+
Portal *bridgev2.Portal
25+
Meta *PortalMetadata
26+
ReplyTo id.EventID
27+
RequestedByEventID id.EventID
28+
RequestedVia string
29+
}
30+
31+
type userStopPlan struct {
32+
Kind stopPlanKind
33+
Scope string
34+
TargetKind string
35+
TargetEventID id.EventID
36+
}
37+
38+
type userStopResult struct {
39+
Plan userStopPlan
40+
ActiveStopped bool
41+
QueuedStopped int
42+
SubagentsStopped int
43+
}
44+
45+
func stopLabel(count int, singular string) string {
46+
if count == 1 {
47+
return singular
48+
}
49+
return singular + "s"
50+
}
51+
52+
func formatAbortNotice(result userStopResult) string {
53+
switch result.Plan.Kind {
54+
case stopPlanKindNoMatch:
55+
return "No matching active or queued turn found for that reply."
56+
case stopPlanKindActive:
57+
if result.SubagentsStopped > 0 {
58+
return fmt.Sprintf("Stopped that turn. Stopped %d %s.", result.SubagentsStopped, stopLabel(result.SubagentsStopped, "sub-agent"))
59+
}
60+
return "Stopped that turn."
61+
case stopPlanKindQueued:
62+
if result.QueuedStopped <= 1 {
63+
return "Stopped that queued turn."
64+
}
65+
return fmt.Sprintf("Stopped %d queued %s.", result.QueuedStopped, stopLabel(result.QueuedStopped, "turn"))
66+
case stopPlanKindRoomWide:
67+
parts := make([]string, 0, 3)
68+
if result.ActiveStopped {
69+
parts = append(parts, "stopped the active turn")
70+
}
71+
if result.QueuedStopped > 0 {
72+
parts = append(parts, fmt.Sprintf("removed %d queued %s", result.QueuedStopped, stopLabel(result.QueuedStopped, "turn")))
73+
}
74+
if result.SubagentsStopped > 0 {
75+
parts = append(parts, fmt.Sprintf("stopped %d %s", result.SubagentsStopped, stopLabel(result.SubagentsStopped, "sub-agent")))
76+
}
77+
if len(parts) == 0 {
78+
return "No active or queued turns to stop."
79+
}
80+
for i := range parts {
81+
r, size := utf8.DecodeRuneInString(parts[i])
82+
parts[i] = string(unicode.ToUpper(r)) + parts[i][size:]
83+
}
84+
return strings.Join(parts, ". ") + "."
85+
default:
86+
return "No active or queued turns to stop."
87+
}
88+
}
89+
90+
func buildStopMetadata(plan userStopPlan, req userStopRequest) *assistantStopMetadata {
91+
return &assistantStopMetadata{
92+
Reason: "user_stop",
93+
Scope: plan.Scope,
94+
TargetKind: plan.TargetKind,
95+
TargetEventID: plan.TargetEventID.String(),
96+
RequestedByEventID: req.RequestedByEventID.String(),
97+
RequestedVia: strings.TrimSpace(req.RequestedVia),
1398
}
14-
label := "sub-agents"
15-
if stopped == 1 {
16-
label = "sub-agent"
99+
}
100+
101+
func (oc *AIClient) resolveUserStopPlan(req userStopRequest) userStopPlan {
102+
if req.Portal == nil || req.Portal.MXID == "" {
103+
return userStopPlan{Kind: stopPlanKindNoMatch}
104+
}
105+
if req.ReplyTo == "" {
106+
return userStopPlan{
107+
Kind: stopPlanKindRoomWide,
108+
Scope: "room",
109+
TargetKind: "all",
110+
}
111+
}
112+
113+
_, sourceEventID, initialEventID, _ := oc.roomRunTarget(req.Portal.MXID)
114+
if initialEventID != "" && req.ReplyTo == initialEventID {
115+
return userStopPlan{
116+
Kind: stopPlanKindActive,
117+
Scope: "turn",
118+
TargetKind: "placeholder_event",
119+
TargetEventID: req.ReplyTo,
120+
}
121+
}
122+
if sourceEventID != "" && req.ReplyTo == sourceEventID {
123+
return userStopPlan{
124+
Kind: stopPlanKindActive,
125+
Scope: "turn",
126+
TargetKind: "source_event",
127+
TargetEventID: req.ReplyTo,
128+
}
129+
}
130+
return userStopPlan{
131+
Kind: stopPlanKindQueued,
132+
Scope: "turn",
133+
TargetKind: "source_event",
134+
TargetEventID: req.ReplyTo,
17135
}
18-
return fmt.Sprintf("Agent was aborted. Stopped %d %s.", stopped, label)
19136
}
20137

21-
func (oc *AIClient) abortRoom(ctx context.Context, portal *bridgev2.Portal, meta *PortalMetadata) int {
22-
if portal == nil {
23-
return 0
138+
func (oc *AIClient) finalizeStoppedQueueItems(ctx context.Context, items []pendingQueueItem) int {
139+
for _, item := range items {
140+
oc.removePendingAckReactions(oc.backgroundContext(ctx), item.pending.Portal, item.pending)
141+
oc.sendQueueRejectedStatus(ctx, item.pending.Portal, item.pending.Event, item.pending.StatusEvents, "Stopped.")
24142
}
25-
oc.cancelRoomRun(portal.MXID)
26-
oc.clearPendingQueue(portal.MXID)
27-
stopped := oc.stopSubagentRuns(portal.MXID)
28-
if meta != nil {
29-
meta.AbortedLastRun = true
30-
oc.savePortalQuiet(ctx, portal, "abort")
143+
return len(items)
144+
}
145+
146+
func (oc *AIClient) executeUserStopPlan(ctx context.Context, req userStopRequest, plan userStopPlan) userStopResult {
147+
result := userStopResult{Plan: plan}
148+
if req.Portal == nil || req.Portal.MXID == "" {
149+
return result
150+
}
151+
roomID := req.Portal.MXID
152+
switch plan.Kind {
153+
case stopPlanKindRoomWide:
154+
if oc.markRoomRunStopped(roomID, buildStopMetadata(plan, req)) {
155+
result.ActiveStopped = oc.cancelRoomRun(roomID)
156+
}
157+
result.QueuedStopped = oc.finalizeStoppedQueueItems(ctx, oc.drainPendingQueue(roomID))
158+
result.SubagentsStopped = oc.stopSubagentRuns(ctx, roomID)
159+
case stopPlanKindActive:
160+
markedStopped := oc.markRoomRunStopped(roomID, buildStopMetadata(plan, req))
161+
if markedStopped {
162+
result.ActiveStopped = oc.cancelRoomRun(roomID)
163+
}
164+
if result.ActiveStopped {
165+
result.SubagentsStopped = oc.stopSubagentRuns(ctx, roomID)
166+
} else {
167+
result.Plan.Kind = stopPlanKindNoMatch
168+
}
169+
case stopPlanKindQueued:
170+
result.QueuedStopped = oc.finalizeStoppedQueueItems(ctx, oc.removePendingQueueBySourceEvent(roomID, plan.TargetEventID))
171+
if result.QueuedStopped == 0 {
172+
result.Plan.Kind = stopPlanKindNoMatch
173+
}
174+
}
175+
176+
if req.Meta != nil && (result.ActiveStopped || result.QueuedStopped > 0 || result.SubagentsStopped > 0) {
177+
req.Meta.AbortedLastRun = true
178+
oc.savePortalQuiet(ctx, req.Portal, "stop")
31179
}
32-
return stopped
180+
if req.Meta != nil && result.QueuedStopped > 0 {
181+
oc.notifySessionMutation(ctx, req.Portal, req.Meta, false)
182+
}
183+
return result
184+
}
185+
186+
func (oc *AIClient) handleUserStop(ctx context.Context, req userStopRequest) userStopResult {
187+
plan := oc.resolveUserStopPlan(req)
188+
return oc.executeUserStopPlan(ctx, req, plan)
33189
}

bridges/ai/abort_helpers_test.go

Lines changed: 187 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,187 @@
1+
package ai
2+
3+
import (
4+
"context"
5+
"testing"
6+
7+
"maunium.net/go/mautrix/bridgev2"
8+
"maunium.net/go/mautrix/bridgev2/database"
9+
"maunium.net/go/mautrix/id"
10+
11+
bridgesdk "github.com/beeper/agentremote/sdk"
12+
)
13+
14+
func TestResolveUserStopPlanRoomWideWithoutReply(t *testing.T) {
15+
oc := &AIClient{}
16+
portal := &bridgev2.Portal{Portal: &database.Portal{MXID: "!room:test"}}
17+
req := userStopRequest{Portal: portal, RequestedVia: "command"}
18+
19+
plan := oc.resolveUserStopPlan(req)
20+
if plan.Kind != stopPlanKindRoomWide {
21+
t.Fatalf("expected room-wide stop, got %#v", plan)
22+
}
23+
if plan.TargetKind != "all" || plan.Scope != "room" {
24+
t.Fatalf("unexpected room-wide stop plan: %#v", plan)
25+
}
26+
}
27+
28+
func TestResolveUserStopPlanMatchesActiveReplyTargets(t *testing.T) {
29+
roomID := id.RoomID("!room:test")
30+
oc := &AIClient{
31+
activeRoomRuns: map[id.RoomID]*roomRunState{
32+
roomID: {
33+
sourceEvent: id.EventID("$user"),
34+
initialEvent: id.EventID("$assistant"),
35+
},
36+
},
37+
}
38+
portal := &bridgev2.Portal{Portal: &database.Portal{MXID: roomID}}
39+
40+
placeholderPlan := oc.resolveUserStopPlan(userStopRequest{
41+
Portal: portal,
42+
ReplyTo: id.EventID("$assistant"),
43+
})
44+
if placeholderPlan.Kind != stopPlanKindActive || placeholderPlan.TargetKind != "placeholder_event" {
45+
t.Fatalf("expected placeholder-targeted active stop, got %#v", placeholderPlan)
46+
}
47+
48+
sourcePlan := oc.resolveUserStopPlan(userStopRequest{
49+
Portal: portal,
50+
ReplyTo: id.EventID("$user"),
51+
})
52+
if sourcePlan.Kind != stopPlanKindActive || sourcePlan.TargetKind != "source_event" {
53+
t.Fatalf("expected source-targeted active stop, got %#v", sourcePlan)
54+
}
55+
}
56+
57+
func TestResolveUserStopPlanSpeculativelyReturnsQueued(t *testing.T) {
58+
oc := &AIClient{}
59+
portal := &bridgev2.Portal{Portal: &database.Portal{MXID: "!room:test"}}
60+
61+
plan := oc.resolveUserStopPlan(userStopRequest{
62+
Portal: portal,
63+
ReplyTo: id.EventID("$unknown"),
64+
})
65+
if plan.Kind != stopPlanKindQueued || plan.TargetKind != "source_event" {
66+
t.Fatalf("expected speculative queued stop plan, got %#v", plan)
67+
}
68+
}
69+
70+
func TestExecuteUserStopPlanFallsBackToNoMatch(t *testing.T) {
71+
oc := &AIClient{}
72+
portal := &bridgev2.Portal{Portal: &database.Portal{MXID: "!room:test"}}
73+
74+
result := oc.executeUserStopPlan(context.Background(), userStopRequest{
75+
Portal: portal,
76+
}, userStopPlan{
77+
Kind: stopPlanKindQueued,
78+
Scope: "turn",
79+
TargetKind: "source_event",
80+
TargetEventID: id.EventID("$nonexistent"),
81+
})
82+
if result.Plan.Kind != stopPlanKindNoMatch {
83+
t.Fatalf("expected no-match fallback, got %#v", result.Plan)
84+
}
85+
if result.QueuedStopped != 0 {
86+
t.Fatalf("expected zero queued stopped, got %d", result.QueuedStopped)
87+
}
88+
}
89+
90+
func TestExecuteUserStopPlanRemovesOnlyTargetedQueuedTurn(t *testing.T) {
91+
roomID := id.RoomID("!room:test")
92+
oc := &AIClient{
93+
pendingQueues: map[id.RoomID]*pendingQueue{
94+
roomID: {
95+
items: []pendingQueueItem{
96+
{pending: pendingMessage{SourceEventID: id.EventID("$one")}},
97+
{pending: pendingMessage{SourceEventID: id.EventID("$two")}},
98+
},
99+
},
100+
},
101+
}
102+
portal := &bridgev2.Portal{Portal: &database.Portal{MXID: roomID}}
103+
104+
result := oc.executeUserStopPlan(context.Background(), userStopRequest{
105+
Portal: portal,
106+
}, userStopPlan{
107+
Kind: stopPlanKindQueued,
108+
Scope: "turn",
109+
TargetKind: "source_event",
110+
TargetEventID: id.EventID("$one"),
111+
})
112+
if result.QueuedStopped != 1 {
113+
t.Fatalf("expected one queued turn to stop, got %#v", result)
114+
}
115+
snapshot := oc.getQueueSnapshot(roomID)
116+
if snapshot == nil || len(snapshot.items) != 1 {
117+
t.Fatalf("expected one queued item to remain, got %#v", snapshot)
118+
}
119+
if got := snapshot.items[0].pending.sourceEventID(); got != id.EventID("$two") {
120+
t.Fatalf("expected remaining queued event $two, got %q", got)
121+
}
122+
}
123+
124+
func TestExecuteUserStopPlanActiveNoOpFallsBackToNoMatch(t *testing.T) {
125+
roomID := id.RoomID("!room:test")
126+
oc := &AIClient{
127+
activeRoomRuns: map[id.RoomID]*roomRunState{
128+
roomID: {
129+
sourceEvent: id.EventID("$user"),
130+
},
131+
},
132+
}
133+
portal := &bridgev2.Portal{Portal: &database.Portal{MXID: roomID}}
134+
135+
result := oc.executeUserStopPlan(context.Background(), userStopRequest{
136+
Portal: portal,
137+
ReplyTo: id.EventID("$user"),
138+
}, userStopPlan{
139+
Kind: stopPlanKindActive,
140+
Scope: "turn",
141+
TargetKind: "source_event",
142+
TargetEventID: id.EventID("$user"),
143+
})
144+
if result.Plan.Kind != stopPlanKindNoMatch {
145+
t.Fatalf("expected no-match fallback for no-op active stop, got %#v", result.Plan)
146+
}
147+
if result.ActiveStopped {
148+
t.Fatalf("expected active stop to report false, got %#v", result)
149+
}
150+
}
151+
152+
func TestBuildStreamUIMessageIncludesStopMetadata(t *testing.T) {
153+
oc := &AIClient{}
154+
conv := bridgesdk.NewConversation[*AIClient, *Config](context.Background(), nil, nil, bridgev2.EventSender{}, nil, nil)
155+
turn := conv.StartTurn(context.Background(), nil, &bridgesdk.SourceRef{EventID: "$user", SenderID: "@user:test"})
156+
turn.SetID("turn-stop")
157+
state := &streamingState{
158+
turn: turn,
159+
finishReason: "stop",
160+
responseID: "resp_123",
161+
completedAtMs: 1,
162+
}
163+
state.stop.Store(&assistantStopMetadata{
164+
Reason: "user_stop",
165+
Scope: "turn",
166+
TargetKind: "source_event",
167+
TargetEventID: "$user",
168+
RequestedByEventID: "$stop",
169+
RequestedVia: "command",
170+
})
171+
172+
ui := oc.buildStreamUIMessage(state, nil, nil)
173+
metadata, ok := ui["metadata"].(map[string]any)
174+
if !ok {
175+
t.Fatalf("expected metadata map, got %T", ui["metadata"])
176+
}
177+
stop, ok := metadata["stop"].(map[string]any)
178+
if !ok {
179+
t.Fatalf("expected nested stop metadata, got %#v", metadata["stop"])
180+
}
181+
if stop["reason"] != "user_stop" || stop["requested_via"] != "command" {
182+
t.Fatalf("unexpected stop metadata: %#v", stop)
183+
}
184+
if metadata["response_status"] != "cancelled" {
185+
t.Fatalf("expected cancelled response status for stopped turn, got %#v", metadata["response_status"])
186+
}
187+
}

0 commit comments

Comments
 (0)