Skip to content

Commit 56ac1c8

Browse files
committed
fix(a2a): guarantee artifact streams accurately propagate completion chunks and multi-turn handoffs
- Fixed `lastChunk` flag assignment in `ARTIFACT_PER_RUN` mode to dynamically adhere to `partial` state, instead of remaining `false` indefinitely. - Fixed `append` initialization in `EventProcessor` to ensure the first streamed chunk signals `append=false`. - Leaves the `Message` payload in the terminal `TaskStatusUpdateEvent` completion event as `null` by default, aligning identically with the reference Go implementation. - Refactored `EventConverter.findUserFunctionCall` to ignore `transfer_to_agent` function responses, ensuring remote A2A handoffs dispatch full conversational histories instead of terminating the prompt context. - Modified `RemoteA2AAgent.StreamHandler` to reliably enforce task progression via standard ADK Payload `TaskState.COMPLETED`, preventing premature stream cancellations by non-gRPC servers executing `!streaming` validations. - Added unit tests to `AgentExecutorTest` confirming `lastChunk` signals flip properly and affirming the `null` expectation for `Message`.
1 parent 7407e37 commit 56ac1c8

File tree

4 files changed

+76
-6
lines changed

4 files changed

+76
-6
lines changed

a2a/src/main/java/com/google/adk/a2a/agent/RemoteA2AAgent.java

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -328,9 +328,9 @@ synchronized void handleEvent(ClientEvent clientEvent, AgentCard unused) {
328328
emitter.onNext(event);
329329
});
330330

331-
// For non-streaming communication, complete the flow; for streaming, wait until the client
332-
// marks the completion.
333-
if (isCompleted(clientEvent) || !streaming) {
331+
// Wait until the client receives a status payload marking the completion of the task
332+
// regardless of the underlying streaming or non-streaming protocol configuration.
333+
if (isCompleted(clientEvent)) {
334334
// Only complete the flow once.
335335
if (!done) {
336336
emitAggregatedEventAndClearBuffer(clientEvent);

a2a/src/main/java/com/google/adk/a2a/converters/EventConverter.java

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,9 @@ public static String contextId(Event event) {
7474
return null;
7575
}
7676
FunctionResponse functionResponse = findUserFunctionResponse(candidate);
77-
if (functionResponse == null || functionResponse.id().isEmpty()) {
77+
if (functionResponse == null
78+
|| functionResponse.id().isEmpty()
79+
|| "transfer_to_agent".equals(functionResponse.name().orElse(""))) {
7880
return null;
7981
}
8082
for (int i = events.size() - 2; i >= 0; i--) {

a2a/src/main/java/com/google/adk/a2a/executor/AgentExecutor.java

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -299,6 +299,7 @@ private static class EventProcessor {
299299
private final String runArtifactId;
300300
private final AgentExecutorConfig.OutputMode outputMode;
301301
private final Map<String, String> lastAgentPartialArtifact = new ConcurrentHashMap<>();
302+
private boolean isFirstEventForRun = true;
302303

303304
// All artifacts related to the invocation should have the same artifact id.
304305
private EventProcessor(AgentExecutorConfig.OutputMode outputMode) {
@@ -329,8 +330,9 @@ private Maybe<TaskArtifactUpdateEvent> process(
329330
}
330331
}
331332

332-
Boolean append = true;
333-
Boolean lastChunk = false;
333+
Boolean append = !isFirstEventForRun;
334+
isFirstEventForRun = false;
335+
Boolean lastChunk = !event.partial().orElse(false);
334336
String artifactId = runArtifactId;
335337

336338
if (outputMode == AgentExecutorConfig.OutputMode.ARTIFACT_PER_EVENT) {

a2a/src/test/java/com/google/adk/a2a/executor/AgentExecutorTest.java

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -439,6 +439,72 @@ public void process_statefulAggregation_tracksArtifactIdAndAppendForAuthor() {
439439
assertThat(ev3.getArtifact().artifactId()).isEqualTo(firstArtifactId);
440440
}
441441

442+
@Test
443+
public void execute_withDefaultArtifactPerRun_emitsMessageAndLastChunk() {
444+
Event partialEvent =
445+
Event.builder()
446+
.partial(true)
447+
.author("agent")
448+
.content(
449+
Content.builder()
450+
.parts(ImmutableList.of(Part.builder().text("chunk1").build()))
451+
.build())
452+
.build();
453+
Event finalEvent =
454+
Event.builder()
455+
.partial(false)
456+
.author("agent")
457+
.content(
458+
Content.builder()
459+
.parts(ImmutableList.of(Part.builder().text("chunk1chunk2").build()))
460+
.build())
461+
.build();
462+
463+
testAgent.setEventsToEmit(Flowable.just(partialEvent, finalEvent));
464+
AgentExecutor executor =
465+
new AgentExecutor.Builder()
466+
.app(App.builder().name("test_app").rootAgent(testAgent).build())
467+
.sessionService(new InMemorySessionService())
468+
.artifactService(new InMemoryArtifactService())
469+
.agentExecutorConfig(
470+
AgentExecutorConfig.builder()
471+
.outputMode(AgentExecutorConfig.OutputMode.ARTIFACT_PER_RUN)
472+
.build())
473+
.build();
474+
475+
RequestContext requestContext = createRequestContext();
476+
executor.execute(requestContext, eventQueue);
477+
478+
// Verify events were correctly formed.
479+
ImmutableList<TaskArtifactUpdateEvent> artifactEvents =
480+
enqueuedEvents.stream()
481+
.filter(e -> e instanceof TaskArtifactUpdateEvent)
482+
.map(e -> (TaskArtifactUpdateEvent) e)
483+
.collect(toImmutableList());
484+
485+
assertThat(artifactEvents).hasSize(2);
486+
// Partial event has lastChunk = false
487+
assertThat(artifactEvents.get(0).isLastChunk()).isFalse();
488+
// Final event has lastChunk = true
489+
assertThat(artifactEvents.get(1).isLastChunk()).isTrue();
490+
491+
// First chunk appends=false, subsequent chunks append=true
492+
assertThat(artifactEvents.get(0).isAppend()).isFalse();
493+
assertThat(artifactEvents.get(1).isAppend()).isTrue();
494+
495+
// Now verify the final TaskStatusUpdateEvent has a null message as expected
496+
Optional<TaskStatusUpdateEvent> statusEvent =
497+
enqueuedEvents.stream()
498+
.filter(e -> e instanceof TaskStatusUpdateEvent)
499+
.map(e -> (TaskStatusUpdateEvent) e)
500+
.filter(TaskStatusUpdateEvent::isFinal)
501+
.findFirst();
502+
503+
assertThat(statusEvent).isPresent();
504+
Message finalMessage = statusEvent.get().getStatus().message();
505+
assertThat(finalMessage).isNull();
506+
}
507+
442508
private static final class TestAgent extends BaseAgent {
443509
private Flowable<Event> eventsToEmit;
444510

0 commit comments

Comments
 (0)