diff --git a/README.md b/README.md index cdcec5b..6b9ef18 100644 --- a/README.md +++ b/README.md @@ -57,6 +57,11 @@ If you're building a [Client](https://agentclientprotocol.com/protocol/overview# `acp.NewClientSideConnection(client, stdin, stdout)`. - Call `Initialize`, `NewSession`, and `Prompt` to run a turn and stream updates. +Connections accept options for transport-level behavior. For example, +`acp.WithMaxQueuedNotifications(n)` sets the per-connection capacity of the +inbound notification queue, while leaving outbound notifications and requests +unchanged. Values less than or equal to zero use the default capacity. + Helper constructors are provided to reduce boilerplate when working with union types: - Content blocks: `acp.TextBlock`, `acp.ImageBlock`, `acp.AudioBlock`, diff --git a/acp_test.go b/acp_test.go index 22dc45c..3e94eff 100644 --- a/acp_test.go +++ b/acp_test.go @@ -872,7 +872,7 @@ func TestConnectionFailsFastOnNotificationQueueOverflow(t *testing.T) { } cause := context.Cause(c.ctx) - if !errors.Is(cause, errNotificationQueueOverflow) { + if !errors.Is(cause, ErrNotificationQueueOverflow) { t.Fatalf("expected overflow cancellation cause, got %v", cause) } diff --git a/agent.go b/agent.go index 26efc57..1f782d1 100644 --- a/agent.go +++ b/agent.go @@ -18,11 +18,11 @@ type AgentSideConnection struct { // NewAgentSideConnection creates a new agent-side connection bound to the // provided Agent implementation. -func NewAgentSideConnection(agent Agent, peerInput io.Writer, peerOutput io.Reader) *AgentSideConnection { +func NewAgentSideConnection(agent Agent, peerInput io.Writer, peerOutput io.Reader, opts ...ConnectionOption) *AgentSideConnection { asc := &AgentSideConnection{} asc.agent = agent asc.sessionCancels = make(map[string]context.CancelFunc) - asc.conn = NewConnection(asc.handleWithExtensions, peerInput, peerOutput) + asc.conn = NewConnection(asc.handleWithExtensions, peerInput, peerOutput, opts...) return asc } diff --git a/client.go b/client.go index c5faca7..574df15 100644 --- a/client.go +++ b/client.go @@ -13,10 +13,10 @@ type ClientSideConnection struct { // NewClientSideConnection creates a new client-side connection bound to the // provided Client implementation. -func NewClientSideConnection(client Client, peerInput io.Writer, peerOutput io.Reader) *ClientSideConnection { +func NewClientSideConnection(client Client, peerInput io.Writer, peerOutput io.Reader, opts ...ConnectionOption) *ClientSideConnection { csc := &ClientSideConnection{} csc.client = client - csc.conn = NewConnection(csc.handleWithExtensions, peerInput, peerOutput) + csc.conn = NewConnection(csc.handleWithExtensions, peerInput, peerOutput, opts...) return csc } diff --git a/connection.go b/connection.go index e33beb0..53de047 100644 --- a/connection.go +++ b/connection.go @@ -19,7 +19,29 @@ const ( defaultMaxQueuedNotifications = 1024 ) -var errNotificationQueueOverflow = errors.New("notification queue overflow") +// ErrNotificationQueueOverflow is the cancellation cause used when the inbound +// notification queue reaches its configured per-connection capacity. +var ErrNotificationQueueOverflow = errors.New("notification queue overflow") + +var errNotificationQueueOverflow = ErrNotificationQueueOverflow + +type connectionOptions struct { + maxQueuedNotifications int +} + +// ConnectionOption configures a Connection. +type ConnectionOption func(*connectionOptions) + +// WithMaxQueuedNotifications sets the per-connection capacity of the inbound +// notification queue. Values less than or equal to zero use the default. +// +// This bounds inbound notification buffering only; outbound notifications and +// requests are unaffected. +func WithMaxQueuedNotifications(n int) ConnectionOption { + return func(o *connectionOptions) { + o.maxQueuedNotifications = n + } +} type anyMessage struct { JSONRPC string `json:"jsonrpc"` @@ -88,24 +110,36 @@ type Connection struct { // notificationQueue serializes notification processing to maintain order. // It is bounded to keep memory usage predictable. - notificationQueue chan queuedNotification + maxQueuedNotifications int + notificationQueue chan queuedNotification } -func NewConnection(handler MethodHandler, peerInput io.Writer, peerOutput io.Reader) *Connection { +func NewConnection(handler MethodHandler, peerInput io.Writer, peerOutput io.Reader, opts ...ConnectionOption) *Connection { + options := connectionOptions{maxQueuedNotifications: defaultMaxQueuedNotifications} + for _, opt := range opts { + if opt != nil { + opt(&options) + } + } + if options.maxQueuedNotifications <= 0 { + options.maxQueuedNotifications = defaultMaxQueuedNotifications + } + ctx, cancel := context.WithCancelCause(context.Background()) inboundCtx, inboundCancel := context.WithCancelCause(context.Background()) c := &Connection{ - w: peerInput, - r: peerOutput, - handler: handler, - pending: make(map[string]*pendingResponse), - inflight: make(map[string]context.CancelCauseFunc), - cancelRequestSignal: make(chan struct{}, 1), - ctx: ctx, - cancel: cancel, - inboundCtx: inboundCtx, - inboundCancel: inboundCancel, - notificationQueue: make(chan queuedNotification, defaultMaxQueuedNotifications), + w: peerInput, + r: peerOutput, + handler: handler, + pending: make(map[string]*pendingResponse), + inflight: make(map[string]context.CancelCauseFunc), + cancelRequestSignal: make(chan struct{}, 1), + ctx: ctx, + cancel: cancel, + inboundCtx: inboundCtx, + inboundCancel: inboundCancel, + maxQueuedNotifications: options.maxQueuedNotifications, + notificationQueue: make(chan queuedNotification, options.maxQueuedNotifications), } c.notifyCond = sync.NewCond(&c.notifyMu) go func() { @@ -738,7 +772,7 @@ func (c *Connection) sendCancelRequest(idKey string) { } func (c *Connection) waitForResponse(ctx context.Context, pr *pendingResponse, idKey string) (responseEnvelope, error) { - peerDisconnectedErr := NewInternalError(map[string]any{"error": "peer disconnected before response"}) + peerDisconnectedErr := newInternalErrorWithCause(map[string]any{"error": "peer disconnected before response"}, ErrPeerDisconnected) select { case resp := <-pr.ch: @@ -775,7 +809,7 @@ func (c *Connection) waitNotificationsUpTo(ctx context.Context, target uint64) e return nil } - peerDisconnectedErr := NewInternalError(map[string]any{"error": "peer disconnected while waiting for pre-response notifications"}) + peerDisconnectedErr := newInternalErrorWithCause(map[string]any{"error": "peer disconnected while waiting for pre-response notifications"}, ErrPeerDisconnected) stopWake := make(chan struct{}) defer close(stopWake) diff --git a/connection_cancel_test.go b/connection_cancel_test.go index 8815305..091d190 100644 --- a/connection_cancel_test.go +++ b/connection_cancel_test.go @@ -650,6 +650,9 @@ func TestConnectionWaitForResponse_PeerDisconnectWinsOverDerivedContextCancel(t if re.Code != -32603 { t.Fatalf("iteration %d: expected disconnect error code -32603, got %d (%s)", i, re.Code, re.Message) } + if !errors.Is(err, ErrPeerDisconnected) { + t.Fatalf("iteration %d: expected error to wrap ErrPeerDisconnected, got %v", i, err) + } if _, ok := c.pending[idKey]; ok { t.Fatalf("iteration %d: pending request %q was not cleaned up", i, idKey) diff --git a/connection_queue_test.go b/connection_queue_test.go new file mode 100644 index 0000000..b903afc --- /dev/null +++ b/connection_queue_test.go @@ -0,0 +1,231 @@ +package acp + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "io" + "strings" + "sync/atomic" + "testing" + "time" +) + +type asyncMessageReader struct { + ch chan []byte + buf []byte +} + +func newAsyncMessageReader(capacity int) *asyncMessageReader { + return &asyncMessageReader{ch: make(chan []byte, capacity)} +} + +func (r *asyncMessageReader) Read(p []byte) (int, error) { + for len(r.buf) == 0 { + next, ok := <-r.ch + if !ok { + return 0, io.EOF + } + r.buf = next + } + n := copy(p, r.buf) + r.buf = r.buf[n:] + return n, nil +} + +func (r *asyncMessageReader) send(msg []byte) { + r.ch <- msg +} + +func (r *asyncMessageReader) close() { + close(r.ch) +} + +func sessionUpdateNotificationLine(t *testing.T, seq int) []byte { + t.Helper() + + params, err := json.Marshal(SessionNotification{ + SessionId: SessionId("test-session"), + Update: UpdateAgentMessageText(fmt.Sprintf("update-%d", seq)), + }) + if err != nil { + t.Fatalf("marshal session/update params: %v", err) + } + + line, err := json.Marshal(anyMessage{ + JSONRPC: "2.0", + Method: ClientMethodSessionUpdate, + Params: params, + }) + if err != nil { + t.Fatalf("marshal session/update notification: %v", err) + } + return append(line, '\n') +} + +func TestConnectionMaxQueuedNotificationsOption(t *testing.T) { + t.Run("default", func(t *testing.T) { + c := NewConnection(func(context.Context, string, json.RawMessage) (any, *RequestError) { + return nil, nil + }, io.Discard, strings.NewReader("")) + if got := cap(c.notificationQueue); got != defaultMaxQueuedNotifications { + t.Fatalf("default notification queue cap = %d, want %d", got, defaultMaxQueuedNotifications) + } + if got := c.maxQueuedNotifications; got != defaultMaxQueuedNotifications { + t.Fatalf("default max queued notifications = %d, want %d", got, defaultMaxQueuedNotifications) + } + }) + + t.Run("non_positive_falls_back_to_default", func(t *testing.T) { + c := NewConnection(func(context.Context, string, json.RawMessage) (any, *RequestError) { + return nil, nil + }, io.Discard, strings.NewReader(""), WithMaxQueuedNotifications(0)) + if got := cap(c.notificationQueue); got != defaultMaxQueuedNotifications { + t.Fatalf("fallback notification queue cap = %d, want %d", got, defaultMaxQueuedNotifications) + } + }) + + t.Run("agent_and_client_constructors_apply_option", func(t *testing.T) { + clientConn := NewClientSideConnection(&clientFuncs{}, io.Discard, strings.NewReader(""), WithMaxQueuedNotifications(7)) + if got := cap(clientConn.conn.notificationQueue); got != 7 { + t.Fatalf("client notification queue cap = %d, want 7", got) + } + + agentConn := NewAgentSideConnection(&agentFuncs{}, io.Discard, strings.NewReader(""), WithMaxQueuedNotifications(9)) + if got := cap(agentConn.conn.notificationQueue); got != 9 { + t.Fatalf("agent notification queue cap = %d, want 9", got) + } + }) +} + +func TestClientSideConnectionNotificationBurstQueueCapacity(t *testing.T) { + const totalNotifications = defaultMaxQueuedNotifications + 128 + + t.Run("default_capacity_overflows", func(t *testing.T) { + reader := newAsyncMessageReader(totalNotifications) + defer reader.close() + + firstStarted := make(chan struct{}) + releaseFirst := make(chan struct{}) + var delivered atomic.Int64 + + c := NewClientSideConnection(&clientFuncs{ + SessionUpdateFunc: func(context.Context, SessionNotification) error { + if delivered.Add(1) == 1 { + close(firstStarted) + <-releaseFirst + } + return nil + }, + }, io.Discard, reader) + + reader.send(sessionUpdateNotificationLine(t, 0)) + select { + case <-firstStarted: + case <-time.After(2 * time.Second): + t.Fatal("timeout waiting for first session/update handler") + } + + producerDone := make(chan struct{}) + go func() { + defer close(producerDone) + for i := 1; i < totalNotifications; i++ { + reader.send(sessionUpdateNotificationLine(t, i)) + } + }() + + select { + case <-producerDone: + case <-time.After(2 * time.Second): + t.Fatal("timeout waiting for notification producer") + } + + select { + case <-c.Done(): + case <-time.After(2 * time.Second): + t.Fatal("timeout waiting for connection cancellation on queue overflow") + } + + if cause := context.Cause(c.conn.ctx); !errors.Is(cause, ErrNotificationQueueOverflow) { + t.Fatalf("connection cancellation cause = %v, want %v", cause, ErrNotificationQueueOverflow) + } + + close(releaseFirst) + }) + + t.Run("configured_capacity_delivers_all_notifications", func(t *testing.T) { + reader := newAsyncMessageReader(totalNotifications) + defer reader.close() + + firstStarted := make(chan struct{}) + releaseFirst := make(chan struct{}) + var delivered atomic.Int64 + + c := NewClientSideConnection(&clientFuncs{ + SessionUpdateFunc: func(context.Context, SessionNotification) error { + if delivered.Add(1) == 1 { + close(firstStarted) + <-releaseFirst + } + return nil + }, + }, io.Discard, reader, WithMaxQueuedNotifications(totalNotifications*2)) + + reader.send(sessionUpdateNotificationLine(t, 0)) + select { + case <-firstStarted: + case <-time.After(2 * time.Second): + t.Fatal("timeout waiting for first session/update handler") + } + + producerDone := make(chan struct{}) + go func() { + defer close(producerDone) + for i := 1; i < totalNotifications; i++ { + reader.send(sessionUpdateNotificationLine(t, i)) + } + }() + + select { + case <-producerDone: + case <-time.After(2 * time.Second): + t.Fatal("timeout waiting for notification producer") + } + + select { + case <-c.Done(): + t.Fatalf("connection closed while burst was queued: %v", context.Cause(c.conn.ctx)) + case <-time.After(50 * time.Millisecond): + } + + close(releaseFirst) + waitForDeliveredNotifications(t, &delivered, totalNotifications) + + select { + case <-c.Done(): + t.Fatalf("connection closed after delivering burst: %v", context.Cause(c.conn.ctx)) + default: + } + }) +} + +func waitForDeliveredNotifications(t *testing.T, delivered *atomic.Int64, want int64) { + t.Helper() + + deadline := time.After(2 * time.Second) + ticker := time.NewTicker(10 * time.Millisecond) + defer ticker.Stop() + + for { + if got := delivered.Load(); got == want { + return + } + + select { + case <-deadline: + t.Fatalf("delivered %d notifications, want %d", delivered.Load(), want) + case <-ticker.C: + } + } +} diff --git a/errors.go b/errors.go index fcbe25d..af3c69b 100644 --- a/errors.go +++ b/errors.go @@ -7,11 +7,17 @@ import ( "fmt" ) +// ErrPeerDisconnected is wrapped by request errors returned when the peer +// disconnects before a response is available. +var ErrPeerDisconnected = errors.New("peer disconnected") + // RequestError represents a JSON-RPC error response. type RequestError struct { Code int `json:"code"` Message string `json:"message"` Data any `json:"data,omitempty"` + + cause error } func (e *RequestError) Error() string { @@ -39,6 +45,13 @@ func (e *RequestError) Error() string { return fmt.Sprintf("code %d: %s", e.Code, e.Message) } +func (e *RequestError) Unwrap() error { + if e == nil { + return nil + } + return e.cause +} + func NewParseError(data any) *RequestError { return &RequestError{Code: -32700, Message: "Parse error", Data: data} } @@ -59,6 +72,10 @@ func NewInternalError(data any) *RequestError { return &RequestError{Code: -32603, Message: "Internal error", Data: data} } +func newInternalErrorWithCause(data any, cause error) *RequestError { + return &RequestError{Code: -32603, Message: "Internal error", Data: data, cause: cause} +} + func NewRequestCancelled(data any) *RequestError { return &RequestError{Code: -32800, Message: "Request cancelled", Data: data} }