Skip to content

Commit 1c8c7aa

Browse files
committed
Track active prompts per ACP session
1 parent e486eff commit 1c8c7aa

2 files changed

Lines changed: 164 additions & 85 deletions

File tree

acp-core/src/main/java/com/agentclientprotocol/sdk/spec/AcpAgentSession.java

Lines changed: 57 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,11 @@
66

77
import java.time.Duration;
88
import java.util.Map;
9+
import java.util.Set;
910
import java.util.UUID;
1011
import java.util.concurrent.ConcurrentHashMap;
1112
import java.util.concurrent.Executors;
1213
import java.util.concurrent.atomic.AtomicLong;
13-
import java.util.concurrent.atomic.AtomicReference;
1414

1515
import reactor.core.scheduler.Scheduler;
1616
import reactor.core.scheduler.Schedulers;
@@ -78,10 +78,17 @@ public class AcpAgentSession implements AcpSession {
7878
private final AtomicLong requestCounter = new AtomicLong(0);
7979

8080
/**
81-
* Active prompt tracking for single-turn enforcement.
82-
* Only ONE prompt can be active at a time per ACP session.
81+
* Active prompt tracking for single-turn enforcement, keyed by logical ACP
82+
* sessionId.
83+
*
84+
* <p>
85+
* Kotlin SDK precedent: its Agent.SessionWrapper owns a single active prompt guard
86+
* per logical session wrapper. This Java session can multiplex multiple logical ACP
87+
* sessionIds over one transport connection, so the same single-turn rule needs to
88+
* be applied per sessionId instead of once for the whole connection.
89+
* </p>
8390
*/
84-
private final AtomicReference<ActivePrompt> activePrompt = new AtomicReference<>(null);
91+
private final ConcurrentHashMap<String, ActivePrompt> activePrompts = new ConcurrentHashMap<>();
8592

8693
/**
8794
* Represents an active prompt session for single-turn enforcement.
@@ -235,12 +242,12 @@ private Mono<AcpSchema.JSONRPCResponse> handleIncomingRequest(AcpSchema.JSONRPCR
235242
String sessionId = extractSessionId(request.params());
236243
ActivePrompt newPrompt = new ActivePrompt(sessionId, request.id());
237244

238-
// Try to set as active prompt - fails if another prompt is active
239-
if (!activePrompt.compareAndSet(null, newPrompt)) {
240-
ActivePrompt current = activePrompt.get();
241-
logger.warn("Rejected concurrent prompt request. Active prompt: sessionId={}, requestId={}",
242-
current != null ? current.sessionId() : "unknown",
243-
current != null ? current.requestId() : "unknown");
245+
// Try to set as active prompt - fails if this logical session already has
246+
// a prompt active.
247+
ActivePrompt current = activePrompts.putIfAbsent(sessionId, newPrompt);
248+
if (current != null) {
249+
logger.warn("Rejected concurrent prompt request for sessionId={}. Active requestId={}", sessionId,
250+
current.requestId());
244251
return Mono.just(new AcpSchema.JSONRPCResponse(AcpSchema.JSONRPC_VERSION, request.id(), null,
245252
new AcpSchema.JSONRPCError(-32000, "There is already an active prompt execution", null)));
246253
}
@@ -249,8 +256,8 @@ private Mono<AcpSchema.JSONRPCResponse> handleIncomingRequest(AcpSchema.JSONRPCR
249256
return handler.handle(request.params())
250257
.map(result -> new AcpSchema.JSONRPCResponse(AcpSchema.JSONRPC_VERSION, request.id(), result, null))
251258
.doFinally(signal -> {
252-
activePrompt.compareAndSet(newPrompt, null);
253-
logger.debug("Prompt completed with signal: {}", signal);
259+
activePrompts.remove(sessionId, newPrompt);
260+
logger.debug("Prompt completed for sessionId={} with signal: {}", sessionId, signal);
254261
});
255262
}
256263

@@ -262,8 +269,13 @@ private Mono<AcpSchema.JSONRPCResponse> handleIncomingRequest(AcpSchema.JSONRPCR
262269
/**
263270
* Extracts the sessionId from request parameters.
264271
*/
265-
@SuppressWarnings("unchecked")
266272
private String extractSessionId(Object params) {
273+
if (params instanceof AcpSchema.PromptRequest promptRequest) {
274+
return promptRequest.sessionId() != null ? promptRequest.sessionId() : "unknown";
275+
}
276+
if (params instanceof AcpSchema.CancelNotification cancelNotification) {
277+
return cancelNotification.sessionId() != null ? cancelNotification.sessionId() : "unknown";
278+
}
267279
if (params instanceof Map<?, ?> map) {
268280
Object sessionId = map.get("sessionId");
269281
return sessionId != null ? sessionId.toString() : "unknown";
@@ -289,9 +301,8 @@ private Mono<Void> handleIncomingNotification(AcpSchema.JSONRPCNotification noti
289301
// Handle cancel notification specially
290302
if (AcpSchema.METHOD_SESSION_CANCEL.equals(notification.method())) {
291303
String sessionId = extractSessionId(notification.params());
292-
ActivePrompt current = activePrompt.get();
293-
if (current != null && sessionId.equals(current.sessionId())) {
294-
activePrompt.compareAndSet(current, null);
304+
ActivePrompt current = activePrompts.remove(sessionId);
305+
if (current != null) {
295306
logger.debug("Cancelled active prompt for session: {}", sessionId);
296307
}
297308
}
@@ -372,16 +383,39 @@ public Mono<Void> sendNotification(String method, Object params) {
372383
* @return true if a prompt is currently active
373384
*/
374385
public boolean hasActivePrompt() {
375-
return activePrompt.get() != null;
386+
return !activePrompts.isEmpty();
387+
}
388+
389+
/**
390+
* Checks if there is an active prompt being processed for the specified logical
391+
* ACP session.
392+
* @param sessionId the logical ACP session ID
393+
* @return true if a prompt is currently active for the session
394+
*/
395+
public boolean hasActivePrompt(String sessionId) {
396+
Assert.hasText(sessionId, "The sessionId can not be empty");
397+
return activePrompts.containsKey(sessionId);
376398
}
377399

378400
/**
379-
* Gets the session ID of the active prompt, if any.
380-
* @return the session ID or null if no prompt is active
401+
* Gets one active prompt session ID, if any.
402+
*
403+
* <p>
404+
* This is a legacy aggregate view. When multiple logical ACP sessions are active on
405+
* the same transport connection, the returned session ID is arbitrary.
406+
* </p>
407+
* @return one active session ID or null if no prompt is active
381408
*/
382409
public String getActivePromptSessionId() {
383-
ActivePrompt current = activePrompt.get();
384-
return current != null ? current.sessionId() : null;
410+
return activePrompts.keySet().stream().findFirst().orElse(null);
411+
}
412+
413+
/**
414+
* Gets the logical ACP session IDs that currently have active prompts.
415+
* @return an immutable snapshot of active prompt session IDs
416+
*/
417+
public Set<String> getActivePromptSessionIds() {
418+
return Set.copyOf(activePrompts.keySet());
385419
}
386420

387421
/**
@@ -391,7 +425,7 @@ public String getActivePromptSessionId() {
391425
@Override
392426
public Mono<Void> closeGracefully() {
393427
return Mono.fromRunnable(() -> {
394-
activePrompt.set(null);
428+
activePrompts.clear();
395429
dismissPendingResponses();
396430
timeoutScheduler.dispose();
397431
}).then(this.transport.closeGracefully());
@@ -402,7 +436,7 @@ public Mono<Void> closeGracefully() {
402436
*/
403437
@Override
404438
public void close() {
405-
activePrompt.set(null);
439+
activePrompts.clear();
406440
dismissPendingResponses();
407441
timeoutScheduler.dispose();
408442
transport.close();

acp-core/src/test/java/com/agentclientprotocol/sdk/spec/AcpAgentSessionTest.java

Lines changed: 107 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,12 @@
55
package com.agentclientprotocol.sdk.spec;
66

77
import java.time.Duration;
8-
import java.util.HashMap;
98
import java.util.List;
109
import java.util.Map;
1110
import java.util.concurrent.CountDownLatch;
11+
import java.util.concurrent.CopyOnWriteArrayList;
1212
import java.util.concurrent.TimeUnit;
13+
import java.util.concurrent.atomic.AtomicInteger;
1314
import java.util.concurrent.atomic.AtomicReference;
1415

1516
import com.agentclientprotocol.sdk.test.InMemoryTransportPair;
@@ -159,16 +160,17 @@ void handlesNotification() throws Exception {
159160
}
160161

161162
@Test
162-
void singleTurnEnforcementRejectsConcurrentPrompts() throws Exception {
163+
void singleTurnEnforcementRejectsConcurrentPromptsForSameSession() throws Exception {
163164
var transportPair = InMemoryTransportPair.create();
164165
try {
165-
// Create a handler that uses a Mono.delay to simulate async processing
166-
AtomicReference<CountDownLatch> promptCanProceedRef = new AtomicReference<>(new CountDownLatch(1));
166+
CountDownLatch handlerStarted = new CountDownLatch(1);
167+
AtomicInteger handlerInvocations = new AtomicInteger();
167168

168169
Map<String, AcpAgentSession.RequestHandler<?>> requestHandlers = Map.of(AcpSchema.METHOD_SESSION_PROMPT,
169170
params -> Mono.defer(() -> {
170-
// First call gets blocked, second call should be rejected before getting here
171-
return Mono.delay(Duration.ofMillis(100))
171+
handlerInvocations.incrementAndGet();
172+
handlerStarted.countDown();
173+
return Mono.delay(Duration.ofMillis(250))
172174
.map(ignored -> new AcpSchema.PromptResponse(AcpSchema.StopReason.END_TURN));
173175
}));
174176

@@ -177,52 +179,85 @@ void singleTurnEnforcementRejectsConcurrentPrompts() throws Exception {
177179

178180
Thread.sleep(100);
179181

180-
// Manually set active prompt to simulate an in-progress prompt
181-
// We use reflection to access the activePrompt field for testing
182-
java.lang.reflect.Field activePromptField = AcpAgentSession.class.getDeclaredField("activePrompt");
183-
activePromptField.setAccessible(true);
184-
@SuppressWarnings("unchecked")
185-
AtomicReference<Object> activePromptRef = (AtomicReference<Object>) activePromptField.get(session);
186-
187-
// Create an ActivePrompt instance using reflection
188-
Class<?> activePromptClass = Class.forName(
189-
"com.agentclientprotocol.sdk.spec.AcpAgentSession$ActivePrompt");
190-
java.lang.reflect.Constructor<?> constructor = activePromptClass.getDeclaredConstructor(String.class,
191-
Object.class);
192-
constructor.setAccessible(true);
193-
Object activePrompt = constructor.newInstance("session-1", "existing-request-id");
194-
activePromptRef.set(activePrompt);
195-
196-
// Verify active prompt is set
197-
assertThat(session.hasActivePrompt()).isTrue();
182+
CountDownLatch responseLatch = new CountDownLatch(2);
183+
List<AcpSchema.JSONRPCResponse> responses = new CopyOnWriteArrayList<>();
198184

199-
// Set up client to receive response
200-
CountDownLatch responseLatch = new CountDownLatch(1);
201-
AtomicReference<AcpSchema.JSONRPCResponse> response = new AtomicReference<>();
185+
transportPair.clientTransport().connect(mono -> mono.doOnNext(msg -> {
186+
if (msg instanceof AcpSchema.JSONRPCResponse response) {
187+
responses.add(response);
188+
}
189+
responseLatch.countDown();
190+
}).then(Mono.empty())).subscribe();
191+
192+
Thread.sleep(50);
193+
194+
transportPair.clientTransport().sendMessage(promptRequest("1", "session-1", "first")).block(TIMEOUT);
195+
assertThat(handlerStarted.await(5, TimeUnit.SECONDS)).isTrue();
196+
assertThat(session.hasActivePrompt("session-1")).isTrue();
197+
198+
transportPair.clientTransport().sendMessage(promptRequest("2", "session-1", "second")).block(TIMEOUT);
199+
200+
assertThat(responseLatch.await(5, TimeUnit.SECONDS)).isTrue();
201+
202+
AcpSchema.JSONRPCResponse rejectedResponse = responseById(responses, "2");
203+
assertThat(rejectedResponse.error()).isNotNull();
204+
assertThat(rejectedResponse.error().code()).isEqualTo(-32000);
205+
assertThat(rejectedResponse.error().message()).contains("already an active prompt");
206+
assertThat(handlerInvocations.get()).isEqualTo(1);
207+
assertThat(session.hasActivePrompt()).isFalse();
208+
}
209+
finally {
210+
transportPair.closeGracefully().block(TIMEOUT);
211+
}
212+
}
213+
214+
@Test
215+
void singleTurnEnforcementAllowsConcurrentPromptsForDifferentSessions() throws Exception {
216+
var transportPair = InMemoryTransportPair.create();
217+
try {
218+
CountDownLatch handlersStarted = new CountDownLatch(2);
219+
AtomicInteger handlerInvocations = new AtomicInteger();
220+
221+
Map<String, AcpAgentSession.RequestHandler<?>> requestHandlers = Map.of(AcpSchema.METHOD_SESSION_PROMPT,
222+
params -> Mono.defer(() -> {
223+
handlerInvocations.incrementAndGet();
224+
handlersStarted.countDown();
225+
return Mono.delay(Duration.ofMillis(250))
226+
.map(ignored -> new AcpSchema.PromptResponse(AcpSchema.StopReason.END_TURN));
227+
}));
228+
229+
AcpAgentSession session = new AcpAgentSession(TIMEOUT, transportPair.agentTransport(), requestHandlers,
230+
Map.of());
231+
232+
Thread.sleep(100);
233+
234+
CountDownLatch responseLatch = new CountDownLatch(2);
235+
List<AcpSchema.JSONRPCResponse> responses = new CopyOnWriteArrayList<>();
202236

203237
transportPair.clientTransport().connect(mono -> mono.doOnNext(msg -> {
204-
response.set((AcpSchema.JSONRPCResponse) msg);
238+
if (msg instanceof AcpSchema.JSONRPCResponse response) {
239+
responses.add(response);
240+
}
205241
responseLatch.countDown();
206242
}).then(Mono.empty())).subscribe();
207243

208244
Thread.sleep(50);
209245

210-
// Send prompt request while another is "active"
211-
Map<String, Object> params = new HashMap<>();
212-
params.put("sessionId", "session-1");
213-
params.put("prompt", List.of(new AcpSchema.TextContent("Hello")));
214-
AcpSchema.JSONRPCRequest request = new AcpSchema.JSONRPCRequest(AcpSchema.JSONRPC_VERSION, "1",
215-
AcpSchema.METHOD_SESSION_PROMPT, params);
216-
transportPair.clientTransport().sendMessage(request).block(TIMEOUT);
246+
transportPair.clientTransport().sendMessage(promptRequest("1", "session-1", "first")).block(TIMEOUT);
247+
transportPair.clientTransport().sendMessage(promptRequest("2", "session-2", "second")).block(TIMEOUT);
248+
249+
assertThat(handlersStarted.await(5, TimeUnit.SECONDS)).isTrue();
250+
assertThat(session.hasActivePrompt("session-1")).isTrue();
251+
assertThat(session.hasActivePrompt("session-2")).isTrue();
252+
assertThat(session.getActivePromptSessionIds()).containsExactlyInAnyOrder("session-1", "session-2");
217253

218-
// Wait for response
219254
assertThat(responseLatch.await(5, TimeUnit.SECONDS)).isTrue();
220255

221-
// Should be rejected with error
222-
assertThat(response.get()).isNotNull();
223-
assertThat(response.get().error()).isNotNull();
224-
assertThat(response.get().error().code()).isEqualTo(-32000);
225-
assertThat(response.get().error().message()).contains("already an active prompt");
256+
assertThat(responseById(responses, "1").error()).isNull();
257+
assertThat(responseById(responses, "2").error()).isNull();
258+
assertThat(handlerInvocations.get()).isEqualTo(2);
259+
assertThat(session.hasActivePrompt()).isFalse();
260+
assertThat(session.getActivePromptSessionIds()).isEmpty();
226261
}
227262
finally {
228263
transportPair.closeGracefully().block(TIMEOUT);
@@ -233,42 +268,43 @@ void singleTurnEnforcementRejectsConcurrentPrompts() throws Exception {
233268
void hasActivePromptReturnsCorrectState() throws Exception {
234269
var transportPair = InMemoryTransportPair.create();
235270
try {
271+
CountDownLatch handlerStarted = new CountDownLatch(1);
272+
236273
Map<String, AcpAgentSession.RequestHandler<?>> requestHandlers = Map.of(AcpSchema.METHOD_SESSION_PROMPT,
237-
params -> Mono.just(new AcpSchema.PromptResponse(AcpSchema.StopReason.END_TURN)));
274+
params -> Mono.defer(() -> {
275+
handlerStarted.countDown();
276+
return Mono.delay(Duration.ofMillis(250))
277+
.map(ignored -> new AcpSchema.PromptResponse(AcpSchema.StopReason.END_TURN));
278+
}));
238279

239280
AcpAgentSession session = new AcpAgentSession(TIMEOUT, transportPair.agentTransport(), requestHandlers,
240281
Map.of());
241282

242283
Thread.sleep(100);
243284

244-
// Initially no active prompt
245285
assertThat(session.hasActivePrompt()).isFalse();
286+
assertThat(session.hasActivePrompt("session-1")).isFalse();
246287
assertThat(session.getActivePromptSessionId()).isNull();
288+
assertThat(session.getActivePromptSessionIds()).isEmpty();
289+
290+
CountDownLatch responseLatch = new CountDownLatch(1);
291+
transportPair.clientTransport().connect(mono -> mono.doOnNext(msg -> responseLatch.countDown())
292+
.then(Mono.empty())).subscribe();
293+
294+
Thread.sleep(50);
295+
transportPair.clientTransport().sendMessage(promptRequest("1", "session-1", "hello")).block(TIMEOUT);
247296

248-
// Manually set active prompt using reflection to test the getter methods
249-
java.lang.reflect.Field activePromptField = AcpAgentSession.class.getDeclaredField("activePrompt");
250-
activePromptField.setAccessible(true);
251-
@SuppressWarnings("unchecked")
252-
AtomicReference<Object> activePromptRef = (AtomicReference<Object>) activePromptField.get(session);
253-
254-
// Create an ActivePrompt instance using reflection
255-
Class<?> activePromptClass = Class.forName(
256-
"com.agentclientprotocol.sdk.spec.AcpAgentSession$ActivePrompt");
257-
java.lang.reflect.Constructor<?> constructor = activePromptClass.getDeclaredConstructor(String.class,
258-
Object.class);
259-
constructor.setAccessible(true);
260-
Object activePrompt = constructor.newInstance("session-1", "request-1");
261-
activePromptRef.set(activePrompt);
262-
263-
// Now there should be an active prompt
297+
assertThat(handlerStarted.await(5, TimeUnit.SECONDS)).isTrue();
264298
assertThat(session.hasActivePrompt()).isTrue();
299+
assertThat(session.hasActivePrompt("session-1")).isTrue();
300+
assertThat(session.getActivePromptSessionIds()).containsExactly("session-1");
265301
assertThat(session.getActivePromptSessionId()).isEqualTo("session-1");
266302

267-
// Clear active prompt
268-
activePromptRef.set(null);
303+
assertThat(responseLatch.await(5, TimeUnit.SECONDS)).isTrue();
269304

270-
// Active prompt should be cleared
271305
assertThat(session.hasActivePrompt()).isFalse();
306+
assertThat(session.hasActivePrompt("session-1")).isFalse();
307+
assertThat(session.getActivePromptSessionIds()).isEmpty();
272308
assertThat(session.getActivePromptSessionId()).isNull();
273309
}
274310
finally {
@@ -327,4 +363,13 @@ void handlerErrorReturnsJsonRpcError() throws Exception {
327363
}
328364
}
329365

366+
private static AcpSchema.JSONRPCRequest promptRequest(String id, String sessionId, String text) {
367+
return new AcpSchema.JSONRPCRequest(AcpSchema.JSONRPC_VERSION, id, AcpSchema.METHOD_SESSION_PROMPT,
368+
new AcpSchema.PromptRequest(sessionId, List.of(new AcpSchema.TextContent(text))));
369+
}
370+
371+
private static AcpSchema.JSONRPCResponse responseById(List<AcpSchema.JSONRPCResponse> responses, Object id) {
372+
return responses.stream().filter(response -> id.equals(response.id())).findFirst().orElseThrow();
373+
}
374+
330375
}

0 commit comments

Comments
 (0)