Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,11 @@ If you're building a [Client](https://agentclientprotocol.com/protocol/overview#
`acp.NewClientSideConnection(client, stdin, stdout)`.
- Call `Initialize`, `NewSession`, and `Prompt` to run a turn and stream updates.

Connections accept options for transport-level behavior. For example,
`acp.WithMaxQueuedNotifications(n)` sets the per-connection capacity of the
inbound notification queue, while leaving outbound notifications and requests
unchanged. Values less than or equal to zero use the default capacity.

Helper constructors are provided to reduce boilerplate when working with union types:

- Content blocks: `acp.TextBlock`, `acp.ImageBlock`, `acp.AudioBlock`,
Expand Down
2 changes: 1 addition & 1 deletion acp_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -872,7 +872,7 @@ func TestConnectionFailsFastOnNotificationQueueOverflow(t *testing.T) {
}

cause := context.Cause(c.ctx)
if !errors.Is(cause, errNotificationQueueOverflow) {
if !errors.Is(cause, ErrNotificationQueueOverflow) {
t.Fatalf("expected overflow cancellation cause, got %v", cause)
}

Expand Down
4 changes: 2 additions & 2 deletions agent.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,11 @@ type AgentSideConnection struct {

// NewAgentSideConnection creates a new agent-side connection bound to the
// provided Agent implementation.
func NewAgentSideConnection(agent Agent, peerInput io.Writer, peerOutput io.Reader) *AgentSideConnection {
func NewAgentSideConnection(agent Agent, peerInput io.Writer, peerOutput io.Reader, opts ...ConnectionOption) *AgentSideConnection {
asc := &AgentSideConnection{}
asc.agent = agent
asc.sessionCancels = make(map[string]context.CancelFunc)
asc.conn = NewConnection(asc.handleWithExtensions, peerInput, peerOutput)
asc.conn = NewConnection(asc.handleWithExtensions, peerInput, peerOutput, opts...)
return asc
}

Expand Down
4 changes: 2 additions & 2 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,10 @@ type ClientSideConnection struct {

// NewClientSideConnection creates a new client-side connection bound to the
// provided Client implementation.
func NewClientSideConnection(client Client, peerInput io.Writer, peerOutput io.Reader) *ClientSideConnection {
func NewClientSideConnection(client Client, peerInput io.Writer, peerOutput io.Reader, opts ...ConnectionOption) *ClientSideConnection {
csc := &ClientSideConnection{}
csc.client = client
csc.conn = NewConnection(csc.handleWithExtensions, peerInput, peerOutput)
csc.conn = NewConnection(csc.handleWithExtensions, peerInput, peerOutput, opts...)
return csc
}

Expand Down
66 changes: 50 additions & 16 deletions connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,29 @@ const (
defaultMaxQueuedNotifications = 1024
)

var errNotificationQueueOverflow = errors.New("notification queue overflow")
// ErrNotificationQueueOverflow is the cancellation cause used when the inbound
// notification queue reaches its configured per-connection capacity.
var ErrNotificationQueueOverflow = errors.New("notification queue overflow")

var errNotificationQueueOverflow = ErrNotificationQueueOverflow

type connectionOptions struct {
maxQueuedNotifications int
}

// ConnectionOption configures a Connection.
type ConnectionOption func(*connectionOptions)

// WithMaxQueuedNotifications sets the per-connection capacity of the inbound
// notification queue. Values less than or equal to zero use the default.
//
// This bounds inbound notification buffering only; outbound notifications and
// requests are unaffected.
func WithMaxQueuedNotifications(n int) ConnectionOption {
return func(o *connectionOptions) {
o.maxQueuedNotifications = n
}
}

type anyMessage struct {
JSONRPC string `json:"jsonrpc"`
Expand Down Expand Up @@ -88,24 +110,36 @@ type Connection struct {

// notificationQueue serializes notification processing to maintain order.
// It is bounded to keep memory usage predictable.
notificationQueue chan queuedNotification
maxQueuedNotifications int
notificationQueue chan queuedNotification
}

func NewConnection(handler MethodHandler, peerInput io.Writer, peerOutput io.Reader) *Connection {
func NewConnection(handler MethodHandler, peerInput io.Writer, peerOutput io.Reader, opts ...ConnectionOption) *Connection {
options := connectionOptions{maxQueuedNotifications: defaultMaxQueuedNotifications}
for _, opt := range opts {
if opt != nil {
opt(&options)
}
}
if options.maxQueuedNotifications <= 0 {
options.maxQueuedNotifications = defaultMaxQueuedNotifications
}

ctx, cancel := context.WithCancelCause(context.Background())
inboundCtx, inboundCancel := context.WithCancelCause(context.Background())
c := &Connection{
w: peerInput,
r: peerOutput,
handler: handler,
pending: make(map[string]*pendingResponse),
inflight: make(map[string]context.CancelCauseFunc),
cancelRequestSignal: make(chan struct{}, 1),
ctx: ctx,
cancel: cancel,
inboundCtx: inboundCtx,
inboundCancel: inboundCancel,
notificationQueue: make(chan queuedNotification, defaultMaxQueuedNotifications),
w: peerInput,
r: peerOutput,
handler: handler,
pending: make(map[string]*pendingResponse),
inflight: make(map[string]context.CancelCauseFunc),
cancelRequestSignal: make(chan struct{}, 1),
ctx: ctx,
cancel: cancel,
inboundCtx: inboundCtx,
inboundCancel: inboundCancel,
maxQueuedNotifications: options.maxQueuedNotifications,
notificationQueue: make(chan queuedNotification, options.maxQueuedNotifications),
}
c.notifyCond = sync.NewCond(&c.notifyMu)
go func() {
Expand Down Expand Up @@ -738,7 +772,7 @@ func (c *Connection) sendCancelRequest(idKey string) {
}

func (c *Connection) waitForResponse(ctx context.Context, pr *pendingResponse, idKey string) (responseEnvelope, error) {
peerDisconnectedErr := NewInternalError(map[string]any{"error": "peer disconnected before response"})
peerDisconnectedErr := newInternalErrorWithCause(map[string]any{"error": "peer disconnected before response"}, ErrPeerDisconnected)

select {
case resp := <-pr.ch:
Expand Down Expand Up @@ -775,7 +809,7 @@ func (c *Connection) waitNotificationsUpTo(ctx context.Context, target uint64) e
return nil
}

peerDisconnectedErr := NewInternalError(map[string]any{"error": "peer disconnected while waiting for pre-response notifications"})
peerDisconnectedErr := newInternalErrorWithCause(map[string]any{"error": "peer disconnected while waiting for pre-response notifications"}, ErrPeerDisconnected)
stopWake := make(chan struct{})
defer close(stopWake)

Expand Down
3 changes: 3 additions & 0 deletions connection_cancel_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -650,6 +650,9 @@ func TestConnectionWaitForResponse_PeerDisconnectWinsOverDerivedContextCancel(t
if re.Code != -32603 {
t.Fatalf("iteration %d: expected disconnect error code -32603, got %d (%s)", i, re.Code, re.Message)
}
if !errors.Is(err, ErrPeerDisconnected) {
t.Fatalf("iteration %d: expected error to wrap ErrPeerDisconnected, got %v", i, err)
}

if _, ok := c.pending[idKey]; ok {
t.Fatalf("iteration %d: pending request %q was not cleaned up", i, idKey)
Expand Down
231 changes: 231 additions & 0 deletions connection_queue_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,231 @@
package acp

import (
"context"
"encoding/json"
"errors"
"fmt"
"io"
"strings"
"sync/atomic"
"testing"
"time"
)

type asyncMessageReader struct {
ch chan []byte
buf []byte
}

func newAsyncMessageReader(capacity int) *asyncMessageReader {
return &asyncMessageReader{ch: make(chan []byte, capacity)}
}

func (r *asyncMessageReader) Read(p []byte) (int, error) {
for len(r.buf) == 0 {
next, ok := <-r.ch
if !ok {
return 0, io.EOF
}
r.buf = next
}
n := copy(p, r.buf)
r.buf = r.buf[n:]
return n, nil
}

func (r *asyncMessageReader) send(msg []byte) {
r.ch <- msg
}

func (r *asyncMessageReader) close() {
close(r.ch)
}

func sessionUpdateNotificationLine(t *testing.T, seq int) []byte {
t.Helper()

params, err := json.Marshal(SessionNotification{
SessionId: SessionId("test-session"),
Update: UpdateAgentMessageText(fmt.Sprintf("update-%d", seq)),
})
if err != nil {
t.Fatalf("marshal session/update params: %v", err)
}

line, err := json.Marshal(anyMessage{
JSONRPC: "2.0",
Method: ClientMethodSessionUpdate,
Params: params,
})
if err != nil {
t.Fatalf("marshal session/update notification: %v", err)
}
return append(line, '\n')
}

func TestConnectionMaxQueuedNotificationsOption(t *testing.T) {
t.Run("default", func(t *testing.T) {
c := NewConnection(func(context.Context, string, json.RawMessage) (any, *RequestError) {
return nil, nil
}, io.Discard, strings.NewReader(""))
if got := cap(c.notificationQueue); got != defaultMaxQueuedNotifications {
t.Fatalf("default notification queue cap = %d, want %d", got, defaultMaxQueuedNotifications)
}
if got := c.maxQueuedNotifications; got != defaultMaxQueuedNotifications {
t.Fatalf("default max queued notifications = %d, want %d", got, defaultMaxQueuedNotifications)
}
})

t.Run("non_positive_falls_back_to_default", func(t *testing.T) {
c := NewConnection(func(context.Context, string, json.RawMessage) (any, *RequestError) {
return nil, nil
}, io.Discard, strings.NewReader(""), WithMaxQueuedNotifications(0))
if got := cap(c.notificationQueue); got != defaultMaxQueuedNotifications {
t.Fatalf("fallback notification queue cap = %d, want %d", got, defaultMaxQueuedNotifications)
}
})

t.Run("agent_and_client_constructors_apply_option", func(t *testing.T) {
clientConn := NewClientSideConnection(&clientFuncs{}, io.Discard, strings.NewReader(""), WithMaxQueuedNotifications(7))
if got := cap(clientConn.conn.notificationQueue); got != 7 {
t.Fatalf("client notification queue cap = %d, want 7", got)
}

agentConn := NewAgentSideConnection(&agentFuncs{}, io.Discard, strings.NewReader(""), WithMaxQueuedNotifications(9))
if got := cap(agentConn.conn.notificationQueue); got != 9 {
t.Fatalf("agent notification queue cap = %d, want 9", got)
}
})
}

func TestClientSideConnectionNotificationBurstQueueCapacity(t *testing.T) {
const totalNotifications = defaultMaxQueuedNotifications + 128

t.Run("default_capacity_overflows", func(t *testing.T) {
reader := newAsyncMessageReader(totalNotifications)
defer reader.close()

firstStarted := make(chan struct{})
releaseFirst := make(chan struct{})
var delivered atomic.Int64

c := NewClientSideConnection(&clientFuncs{
SessionUpdateFunc: func(context.Context, SessionNotification) error {
if delivered.Add(1) == 1 {
close(firstStarted)
<-releaseFirst
}
return nil
},
}, io.Discard, reader)

reader.send(sessionUpdateNotificationLine(t, 0))
select {
case <-firstStarted:
case <-time.After(2 * time.Second):
t.Fatal("timeout waiting for first session/update handler")
}

producerDone := make(chan struct{})
go func() {
defer close(producerDone)
for i := 1; i < totalNotifications; i++ {
reader.send(sessionUpdateNotificationLine(t, i))
}
}()

select {
case <-producerDone:
case <-time.After(2 * time.Second):
t.Fatal("timeout waiting for notification producer")
}

select {
case <-c.Done():
case <-time.After(2 * time.Second):
t.Fatal("timeout waiting for connection cancellation on queue overflow")
}

if cause := context.Cause(c.conn.ctx); !errors.Is(cause, ErrNotificationQueueOverflow) {
t.Fatalf("connection cancellation cause = %v, want %v", cause, ErrNotificationQueueOverflow)
}

close(releaseFirst)
})

t.Run("configured_capacity_delivers_all_notifications", func(t *testing.T) {
reader := newAsyncMessageReader(totalNotifications)
defer reader.close()

firstStarted := make(chan struct{})
releaseFirst := make(chan struct{})
var delivered atomic.Int64

c := NewClientSideConnection(&clientFuncs{
SessionUpdateFunc: func(context.Context, SessionNotification) error {
if delivered.Add(1) == 1 {
close(firstStarted)
<-releaseFirst
}
return nil
},
}, io.Discard, reader, WithMaxQueuedNotifications(totalNotifications*2))

reader.send(sessionUpdateNotificationLine(t, 0))
select {
case <-firstStarted:
case <-time.After(2 * time.Second):
t.Fatal("timeout waiting for first session/update handler")
}

producerDone := make(chan struct{})
go func() {
defer close(producerDone)
for i := 1; i < totalNotifications; i++ {
reader.send(sessionUpdateNotificationLine(t, i))
}
}()

select {
case <-producerDone:
case <-time.After(2 * time.Second):
t.Fatal("timeout waiting for notification producer")
}

select {
case <-c.Done():
t.Fatalf("connection closed while burst was queued: %v", context.Cause(c.conn.ctx))
case <-time.After(50 * time.Millisecond):
}

close(releaseFirst)
waitForDeliveredNotifications(t, &delivered, totalNotifications)

select {
case <-c.Done():
t.Fatalf("connection closed after delivering burst: %v", context.Cause(c.conn.ctx))
default:
}
})
}

func waitForDeliveredNotifications(t *testing.T, delivered *atomic.Int64, want int64) {
t.Helper()

deadline := time.After(2 * time.Second)
ticker := time.NewTicker(10 * time.Millisecond)
defer ticker.Stop()

for {
if got := delivered.Load(); got == want {
return
}

select {
case <-deadline:
t.Fatalf("delivered %d notifications, want %d", delivered.Load(), want)
case <-ticker.C:
}
}
}
Loading