From 041db507dbd2247770d21f0f3bda1d1d5ffe5579 Mon Sep 17 00:00:00 2001 From: Jacob Alber Date: Wed, 15 Apr 2026 14:10:00 -0400 Subject: [PATCH 1/4] fix: Add session support for Handoff-hosted Agents In order to better support using `Workflows` hosted as `AIAgents` inside of Handoff workflows, we need to make proper use of AgentSession. This causes potential issues around checkpointing and making sure that we properly compute only the new incoming messages for each agent invocation. --- .../AIAgentsAbstractionsExtensions.cs | 2 +- .../Specialized/HandoffAgentExecutor.cs | 345 ++++++++---------- .../Specialized/HandoffEndExecutor.cs | 35 +- .../Specialized/HandoffMessagesFilter.cs | 143 ++++++++ .../Specialized/HandoffStartExecutor.cs | 53 ++- .../Specialized/HandoffState.cs | 4 - .../Specialized/MultiPartyConversation.cs | 49 +++ .../HandoffAgentExecutorTests.cs | 147 +++++++- .../Sample/12_HandOff_HostAsAgent.cs | 1 + .../SampleSmokeTest.cs | 81 +++- .../TestRunContext.cs | 12 +- 11 files changed, 613 insertions(+), 259 deletions(-) create mode 100644 dotnet/src/Microsoft.Agents.AI.Workflows/Specialized/HandoffMessagesFilter.cs create mode 100644 dotnet/src/Microsoft.Agents.AI.Workflows/Specialized/MultiPartyConversation.cs diff --git a/dotnet/src/Microsoft.Agents.AI.Workflows/AIAgentsAbstractionsExtensions.cs b/dotnet/src/Microsoft.Agents.AI.Workflows/AIAgentsAbstractionsExtensions.cs index 165de39855..8c94f4aa85 100644 --- a/dotnet/src/Microsoft.Agents.AI.Workflows/AIAgentsAbstractionsExtensions.cs +++ b/dotnet/src/Microsoft.Agents.AI.Workflows/AIAgentsAbstractionsExtensions.cs @@ -48,7 +48,7 @@ private static ChatMessage ChatAssistantToUserIfNotFromNamed(this ChatMessage me /// any that have a different from to /// . /// - public static List? ChangeAssistantToUserForOtherParticipants(this List messages, string targetAgentName) + public static List? ChangeAssistantToUserForOtherParticipants(this IEnumerable messages, string targetAgentName) { List? 
roleChanged = null; foreach (var m in messages) diff --git a/dotnet/src/Microsoft.Agents.AI.Workflows/Specialized/HandoffAgentExecutor.cs b/dotnet/src/Microsoft.Agents.AI.Workflows/Specialized/HandoffAgentExecutor.cs index eac2eb5687..d9a4ae941f 100644 --- a/dotnet/src/Microsoft.Agents.AI.Workflows/Specialized/HandoffAgentExecutor.cs +++ b/dotnet/src/Microsoft.Agents.AI.Workflows/Specialized/HandoffAgentExecutor.cs @@ -6,6 +6,7 @@ using System.Diagnostics.CodeAnalysis; using System.Linq; using System.Text.Json; +using System.Text.Json.Serialization; using System.Threading; using System.Threading.Tasks; using Microsoft.Extensions.AI; @@ -31,140 +32,6 @@ public HandoffAgentExecutorOptions(string? handoffInstructions, bool emitAgentRe public HandoffToolCallFilteringBehavior ToolCallFilteringBehavior { get; set; } = HandoffToolCallFilteringBehavior.HandoffOnly; } -[Experimental(DiagnosticConstants.ExperimentalFeatureDiagnostic)] -internal sealed class HandoffMessagesFilter -{ - private readonly HandoffToolCallFilteringBehavior _filteringBehavior; - - public HandoffMessagesFilter(HandoffToolCallFilteringBehavior filteringBehavior) - { - this._filteringBehavior = filteringBehavior; - } - - [Experimental(DiagnosticConstants.ExperimentalFeatureDiagnostic)] - internal static bool IsHandoffFunctionName(string name) - { - return name.StartsWith(HandoffWorkflowBuilder.FunctionPrefix, StringComparison.Ordinal); - } - - public IEnumerable FilterMessages(List messages) - { - if (this._filteringBehavior == HandoffToolCallFilteringBehavior.None) - { - return messages; - } - - Dictionary filteringCandidates = new(); - List filteredMessages = []; - HashSet messagesToRemove = []; - - bool filterHandoffOnly = this._filteringBehavior == HandoffToolCallFilteringBehavior.HandoffOnly; - foreach (ChatMessage unfilteredMessage in messages) - { - ChatMessage filteredMessage = unfilteredMessage.Clone(); - - // .Clone() is shallow, so we cannot modify the contents of the cloned message in 
place. - List contents = []; - contents.Capacity = unfilteredMessage.Contents?.Count ?? 0; - filteredMessage.Contents = contents; - - // Because this runs after the role changes from assistant to user for the target agent, we cannot rely on tool calls - // originating only from messages with the Assistant role. Instead, we need to inspect the contents of all non-Tool (result) - // FunctionCallContent. - if (unfilteredMessage.Role != ChatRole.Tool) - { - for (int i = 0; i < unfilteredMessage.Contents!.Count; i++) - { - AIContent content = unfilteredMessage.Contents[i]; - if (content is not FunctionCallContent fcc || (filterHandoffOnly && !IsHandoffFunctionName(fcc.Name))) - { - filteredMessage.Contents.Add(content); - - // Track non-handoff function calls so their tool results are preserved in HandoffOnly mode - if (filterHandoffOnly && content is FunctionCallContent nonHandoffFcc) - { - filteringCandidates[nonHandoffFcc.CallId] = new FilterCandidateState(nonHandoffFcc.CallId) - { - IsHandoffFunction = false, - }; - } - } - else if (filterHandoffOnly) - { - if (!filteringCandidates.TryGetValue(fcc.CallId, out FilterCandidateState? 
candidateState)) - { - filteringCandidates[fcc.CallId] = new FilterCandidateState(fcc.CallId) - { - IsHandoffFunction = true, - }; - } - else - { - candidateState.IsHandoffFunction = true; - (int messageIndex, int contentIndex) = candidateState.FunctionCallResultLocation!.Value; - ChatMessage messageToFilter = filteredMessages[messageIndex]; - messageToFilter.Contents.RemoveAt(contentIndex); - if (messageToFilter.Contents.Count == 0) - { - messagesToRemove.Add(messageIndex); - } - } - } - else - { - // All mode: strip all FunctionCallContent - } - } - } - else - { - if (!filterHandoffOnly) - { - continue; - } - - for (int i = 0; i < unfilteredMessage.Contents!.Count; i++) - { - AIContent content = unfilteredMessage.Contents[i]; - if (content is not FunctionResultContent frc - || (filteringCandidates.TryGetValue(frc.CallId, out FilterCandidateState? candidateState) - && candidateState.IsHandoffFunction is false)) - { - // Either this is not a function result content, so we should let it through, or it is a FRC that - // we know is not related to a handoff call. In either case, we should include it. - filteredMessage.Contents.Add(content); - } - else if (candidateState is null) - { - // We haven't seen the corresponding function call yet, so add it as a candidate to be filtered later - filteringCandidates[frc.CallId] = new FilterCandidateState(frc.CallId) - { - FunctionCallResultLocation = (filteredMessages.Count, filteredMessage.Contents.Count), - }; - } - // else we have seen the corresponding function call and it is a handoff, so we should filter it out. - } - } - - if (filteredMessage.Contents.Count > 0) - { - filteredMessages.Add(filteredMessage); - } - } - - return filteredMessages.Where((_, index) => !messagesToRemove.Contains(index)); - } - - private class FilterCandidateState(string callId) - { - public (int MessageIndex, int ContentIndex)? FunctionCallResultLocation { get; set; } - - public string CallId => callId; - - public bool? 
IsHandoffFunction { get; set; } - } -} - internal struct AgentInvocationResult(AgentResponse agentResponse, string? handoffTargetId) { public AgentResponse Response => agentResponse; @@ -175,19 +42,32 @@ internal struct AgentInvocationResult(AgentResponse agentResponse, string? hando public bool IsHandoffRequested => this.HandoffTargetId != null; } -internal record HandoffAgentHostState(HandoffState? CurrentTurnState, List FilteredIncomingMessages, List TurnMessages) +internal record HandoffAgentHostState( + HandoffState? IncomingState, + int ConversationBookmark, + AgentSession? AgentSession) { - public HandoffState PrepareHandoff(AgentInvocationResult invocationResult, string currentAgentId) - { - if (this.CurrentTurnState == null) - { - throw new InvalidOperationException("Cannot create a handoff request: Out of turn."); - } - - IEnumerable allMessages = [.. this.CurrentTurnState.Messages, .. this.TurnMessages, .. invocationResult.Response.Messages]; + [MemberNotNullWhen(true, nameof(IncomingState))] + [JsonIgnore] + public bool IsTakingTurn => this.IncomingState != null; +} - return new(this.CurrentTurnState.TurnToken, invocationResult.HandoffTargetId, allMessages.ToList(), currentAgentId); - } +internal sealed record StateRef(string Key, string? ScopeName) +{ + public ValueTask InvokeWithStateAsync(Func> invocation, + IWorkflowContext context, + CancellationToken cancellationToken) + => context.InvokeWithStateAsync(invocation, this.Key, this.ScopeName, cancellationToken); + + public ValueTask InvokeWithStateAsync(Func invocation, + IWorkflowContext context, + CancellationToken cancellationToken) + => context.InvokeWithStateAsync( + async (state, ctx, ct) => + { + await invocation(state, ctx, ct).ConfigureAwait(false); + return state; + }, this.Key, this.ScopeName, cancellationToken); } /// Executor used to represent an agent in a handoffs workflow, responding to events. 
@@ -208,7 +88,10 @@ internal sealed class HandoffAgentExecutor : private readonly HashSet _handoffFunctionNames = []; private readonly Dictionary _handoffFunctionToAgentId = []; - private static HandoffAgentHostState InitialStateFactory() => new(null, [], []); + private readonly StateRef _sharedStateRef = new(HandoffConstants.HandoffSharedStateKey, + HandoffConstants.HandoffSharedStateScope); + + private static HandoffAgentHostState InitialStateFactory() => new(null, 0, null); public HandoffAgentExecutor(AIAgent agent, HashSet handoffs, HandoffAgentExecutorOptions options) : base(IdFor(agent), InitialStateFactory) @@ -291,13 +174,18 @@ private ValueTask HandleUserInputResponseAsync( // resumes can be processed in one invocation. return this.InvokeWithStateAsync((state, ctx, ct) => { - state.TurnMessages.Add(new ChatMessage(ChatRole.User, [response]) + if (!state.IsTakingTurn) + { + throw new InvalidOperationException("Cannot process user responses when not taking a turn in Handoff Orchestration."); + } + + ChatMessage userMessage = new(ChatRole.User, [response]) { CreatedAt = DateTimeOffset.UtcNow, MessageId = Guid.NewGuid().ToString("N"), - }); + }; - return this.ContinueTurnAsync(state, ctx, ct); + return this.ContinueTurnAsync(state, [userMessage], ctx, ct); }, context, skipCache: false, cancellationToken); } @@ -315,24 +203,44 @@ private ValueTask HandleFunctionResultAsync( // resumes can be processed in one invocation. return this.InvokeWithStateAsync((state, ctx, ct) => { - state.TurnMessages.Add( - new ChatMessage(ChatRole.Tool, [result]) - { - AuthorName = this._agent.Name ?? this._agent.Id, - CreatedAt = DateTimeOffset.UtcNow, - MessageId = Guid.NewGuid().ToString("N"), - }); + if (!state.IsTakingTurn) + { + throw new InvalidOperationException("Cannot process function results when not taking a turn in Handoff Orchestration."); + } + + ChatMessage toolMessage = new(ChatRole.Tool, [result]) + { + AuthorName = this._agent.Name ?? 
this._agent.Id, + CreatedAt = DateTimeOffset.UtcNow, + MessageId = Guid.NewGuid().ToString("N"), + }; - return this.ContinueTurnAsync(state, ctx, ct); + return this.ContinueTurnAsync(state, [toolMessage], ctx, ct); }, context, skipCache: false, cancellationToken); } - private async ValueTask ContinueTurnAsync(HandoffAgentHostState state, IWorkflowContext context, CancellationToken cancellationToken) + private async ValueTask ContinueTurnAsync(HandoffAgentHostState state, List incomingMessages, IWorkflowContext context, CancellationToken cancellationToken, bool skipAddIncoming = false) { - List? roleChanges = state.FilteredIncomingMessages.ChangeAssistantToUserForOtherParticipants(this._agent.Name ?? this._agent.Id); + if (!state.IsTakingTurn) + { + throw new InvalidOperationException("Cannot continue a turn when not taking a turn in Handoff Orchestration."); + } - bool emitUpdateEvents = state.CurrentTurnState!.ShouldEmitStreamingEvents(this._options.EmitAgentResponseUpdateEvents); - AgentInvocationResult result = await this.InvokeAgentAsync([.. state.FilteredIncomingMessages, .. state.TurnMessages], context, emitUpdateEvents, cancellationToken) + // If a handoff was invoked by a previous agent, filter out the handoff function call and tool result messages + // before sending to the underlying agent. These are internal workflow mechanics that confuse the target model + // into ignoring the original user question. + // + // This will not filter out tool responses and approval responses that are part of this agent's turn, which is + // the expected behavior since those are part of the agent's reasoning process. + HandoffMessagesFilter handoffMessagesFilter = new(this._options.ToolCallFilteringBehavior); + IEnumerable messagesForAgent = state.IncomingState.RequestedHandoffTargetAgentId is not null + ? handoffMessagesFilter.FilterMessages(incomingMessages) + : incomingMessages; + + List? 
roleChanges = messagesForAgent.ChangeAssistantToUserForOtherParticipants(this._agent.Name ?? this._agent.Id); + + bool emitUpdateEvents = state.IncomingState!.ShouldEmitStreamingEvents(this._options.EmitAgentResponseUpdateEvents); + AgentInvocationResult result = await this.InvokeAgentAsync(messagesForAgent, context, emitUpdateEvents, cancellationToken) .ConfigureAwait(false); if (this.HasOutstandingRequests && result.IsHandoffRequested) @@ -342,20 +250,40 @@ private ValueTask HandleFunctionResultAsync( roleChanges.ResetUserToAssistantForChangedRoles(); + int newConversationBookmark = state.ConversationBookmark; + await this._sharedStateRef.InvokeWithStateAsync( + (sharedState, ctx, ct) => + { + if (sharedState == null) + { + throw new InvalidOperationException("Handoff Orchestration shared state was not properly initialized."); + } + + if (!skipAddIncoming) + { + sharedState.Conversation.AddMessages(incomingMessages); + } + + newConversationBookmark = sharedState.Conversation.AddMessages(result.Response.Messages); + + return new ValueTask(); + }, + context, + cancellationToken).ConfigureAwait(false); + // We send on the HandoffState even if handoff is not requested because we might be terminating the processing, but this only // happens if we have no outstanding requests. if (!this.HasOutstandingRequests) { - HandoffState outgoingState = state.PrepareHandoff(result, this._agent.Id); + HandoffState outgoingState = new(state.IncomingState.TurnToken, result.HandoffTargetId, this._agent.Id); await context.SendMessageAsync(outgoingState, cancellationToken).ConfigureAwait(false); - // reset the state for the next handoff (return-to-current is modeled as a new handoff turn, as opposed to "HITL", which - // can be a bit confusing.) - return null; + // reset the state for the next handoff, making sure to keep track of the conversation bookmark, and avoid resetting the + // agent session. 
(return-to-current is modeled as a new handoff turn, as opposed to "HITL", which can be a bit confusing.) + return state with { IncomingState = null, ConversationBookmark = newConversationBookmark }; } - state.TurnMessages.AddRange(result.Response.Messages); return state; } @@ -363,28 +291,36 @@ public override ValueTask HandleAsync(HandoffState message, IWorkflowContext con { return this.InvokeWithStateAsync(InvokeContinueTurnAsync, context, skipCache: false, cancellationToken); - ValueTask InvokeContinueTurnAsync(HandoffAgentHostState state, IWorkflowContext context, CancellationToken cancellationToken) + async ValueTask InvokeContinueTurnAsync(HandoffAgentHostState state, IWorkflowContext context, CancellationToken cancellationToken) { // Check that we are not getting this message while in the middle of a turn - if (state.CurrentTurnState != null) + if (state.IsTakingTurn) { throw new InvalidOperationException("Cannot have multiple simultaneous conversations in Handoff Orchestration."); } - // If a handoff was invoked by a previous agent, filter out the handoff function - // call and tool result messages before sending to the underlying agent. These - // are internal workflow mechanics that confuse the target model into ignoring the - // original user question. - HandoffMessagesFilter handoffMessagesFilter = new(this._options.ToolCallFilteringBehavior); - IEnumerable messagesForAgent = message.RequestedHandoffTargetAgentId is not null - ? 
handoffMessagesFilter.FilterMessages(message.Messages) - : message.Messages; + IEnumerable newConversationMessages = []; + int newConversationBookmark = 0; + + await this._sharedStateRef.InvokeWithStateAsync( + (sharedState, ctx, ct) => + { + if (sharedState == null) + { + throw new InvalidOperationException("Handoff Orchestration shared state was not properly initialized."); + } + + (newConversationMessages, newConversationBookmark) = sharedState.Conversation.CollectNewMessages(state.ConversationBookmark); - // This works because the runtime guarantees that a given executor instance will process messages serially, - // though there is no global cross-executor ordering guarantee (and in turn, no canonical message delivery order) - state = new(message, messagesForAgent.ToList(), []); + return new ValueTask(); + }, + context, + cancellationToken).ConfigureAwait(false); - return this.ContinueTurnAsync(state, context, cancellationToken); + state = state with { IncomingState = message, ConversationBookmark = newConversationBookmark }; + + return await this.ContinueTurnAsync(state, newConversationMessages.ToList(), context, cancellationToken, skipAddIncoming: true) + .ConfigureAwait(false); } } @@ -417,31 +353,46 @@ private async ValueTask InvokeAgentAsync(IEnumerable agentStream = this._agent.RunStreamingAsync( - messages, - options: this._agentOptions, - cancellationToken: cancellationToken); - string? 
requestedHandoff = null; List updates = []; List candidateRequests = []; - await foreach (AgentResponseUpdate update in agentStream.ConfigureAwait(false)) - { - await AddUpdateAsync(update, cancellationToken).ConfigureAwait(false); - collector.ProcessAgentResponseUpdate(update, CollectHandoffRequestsFilter); - - bool CollectHandoffRequestsFilter(FunctionCallContent candidateHandoffRequest) + await this.InvokeWithStateAsync( + async (state, ctx, ct) => { - bool isHandoffRequest = this._handoffFunctionNames.Contains(candidateHandoffRequest.Name); - if (isHandoffRequest) + if (state.AgentSession == null) { - candidateRequests.Add(candidateHandoffRequest); + state = state with { AgentSession = await this._agent.CreateSessionAsync(ct).ConfigureAwait(false) }; } - return !isHandoffRequest; - } - } + IAsyncEnumerable agentStream = + this._agent.RunStreamingAsync(messages, + state.AgentSession, + options: this._agentOptions, + cancellationToken: ct); + + await foreach (AgentResponseUpdate update in agentStream.ConfigureAwait(false)) + { + await AddUpdateAsync(update, ct).ConfigureAwait(false); + + collector.ProcessAgentResponseUpdate(update, CollectHandoffRequestsFilter); + + bool CollectHandoffRequestsFilter(FunctionCallContent candidateHandoffRequest) + { + bool isHandoffRequest = this._handoffFunctionNames.Contains(candidateHandoffRequest.Name); + if (isHandoffRequest) + { + candidateRequests.Add(candidateHandoffRequest); + } + + return !isHandoffRequest; + } + } + + return state; + }, + context, + cancellationToken: cancellationToken).ConfigureAwait(false); if (candidateRequests.Count > 1) { diff --git a/dotnet/src/Microsoft.Agents.AI.Workflows/Specialized/HandoffEndExecutor.cs b/dotnet/src/Microsoft.Agents.AI.Workflows/Specialized/HandoffEndExecutor.cs index 0ba8fc3501..c9c75b91c4 100644 --- a/dotnet/src/Microsoft.Agents.AI.Workflows/Specialized/HandoffEndExecutor.cs +++ b/dotnet/src/Microsoft.Agents.AI.Workflows/Specialized/HandoffEndExecutor.cs @@ -1,5 +1,6 @@ // 
Copyright (c) Microsoft. All rights reserved. +using System; using System.Collections.Generic; using System.Threading; using System.Threading.Tasks; @@ -12,23 +13,33 @@ internal sealed class HandoffEndExecutor(bool returnToPrevious) : Executor(Execu { public const string ExecutorId = "HandoffEnd"; + private readonly StateRef _sharedStateRef = new(HandoffConstants.HandoffSharedStateKey, + HandoffConstants.HandoffSharedStateScope); + protected override ProtocolBuilder ConfigureProtocol(ProtocolBuilder protocolBuilder) => - protocolBuilder.ConfigureRoutes(routeBuilder => routeBuilder.AddHandler((handoff, context, cancellationToken) => - this.HandleAsync(handoff, context, cancellationToken))) + protocolBuilder.ConfigureRoutes(routeBuilder => routeBuilder.AddHandler( + (handoff, context, cancellationToken) => this.HandleAsync(handoff, context, cancellationToken))) .YieldsOutput>(); private async ValueTask HandleAsync(HandoffState handoff, IWorkflowContext context, CancellationToken cancellationToken) { - if (returnToPrevious) - { - await context.QueueStateUpdateAsync(HandoffConstants.PreviousAgentTrackerKey, - handoff.PreviousAgentId, - HandoffConstants.PreviousAgentTrackerScope, - cancellationToken) - .ConfigureAwait(false); - } - - await context.YieldOutputAsync(handoff.Messages, cancellationToken).ConfigureAwait(false); + await this._sharedStateRef.InvokeWithStateAsync( + async (HandoffSharedState? 
sharedState, IWorkflowContext context, CancellationToken cancellationToken) => + { + if (sharedState == null) + { + throw new InvalidOperationException("Handoff Orchestration shared state was not properly initialized."); + } + + if (returnToPrevious) + { + sharedState.PreviousAgentId = handoff.PreviousAgentId; + } + + await context.YieldOutputAsync(sharedState.Conversation.CloneAllMessages(), cancellationToken).ConfigureAwait(false); + + return sharedState; + }, context, cancellationToken).ConfigureAwait(false); } public ValueTask ResetAsync() => default; diff --git a/dotnet/src/Microsoft.Agents.AI.Workflows/Specialized/HandoffMessagesFilter.cs b/dotnet/src/Microsoft.Agents.AI.Workflows/Specialized/HandoffMessagesFilter.cs new file mode 100644 index 0000000000..7bc178c2d4 --- /dev/null +++ b/dotnet/src/Microsoft.Agents.AI.Workflows/Specialized/HandoffMessagesFilter.cs @@ -0,0 +1,143 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Collections.Generic; +using System.Diagnostics.CodeAnalysis; +using System.Linq; +using Microsoft.Extensions.AI; + +namespace Microsoft.Agents.AI.Workflows.Specialized; + +[Experimental(DiagnosticConstants.ExperimentalFeatureDiagnostic)] +internal sealed class HandoffMessagesFilter +{ + private readonly HandoffToolCallFilteringBehavior _filteringBehavior; + + public HandoffMessagesFilter(HandoffToolCallFilteringBehavior filteringBehavior) + { + this._filteringBehavior = filteringBehavior; + } + + [Experimental(DiagnosticConstants.ExperimentalFeatureDiagnostic)] + internal static bool IsHandoffFunctionName(string name) + { + return name.StartsWith(HandoffWorkflowBuilder.FunctionPrefix, StringComparison.Ordinal); + } + + public IEnumerable FilterMessages(IEnumerable messages) + { + if (this._filteringBehavior == HandoffToolCallFilteringBehavior.None) + { + return messages; + } + + Dictionary filteringCandidates = new(); + List filteredMessages = []; + HashSet messagesToRemove = []; + + bool 
filterHandoffOnly = this._filteringBehavior == HandoffToolCallFilteringBehavior.HandoffOnly; + foreach (ChatMessage unfilteredMessage in messages) + { + ChatMessage filteredMessage = unfilteredMessage.Clone(); + + // .Clone() is shallow, so we cannot modify the contents of the cloned message in place. + List contents = []; + contents.Capacity = unfilteredMessage.Contents?.Count ?? 0; + filteredMessage.Contents = contents; + + // Because this runs after the role changes from assistant to user for the target agent, we cannot rely on tool calls + // originating only from messages with the Assistant role. Instead, we need to inspect the contents of all non-Tool (result) + // FunctionCallContent. + if (unfilteredMessage.Role != ChatRole.Tool) + { + for (int i = 0; i < unfilteredMessage.Contents!.Count; i++) + { + AIContent content = unfilteredMessage.Contents[i]; + if (content is not FunctionCallContent fcc || (filterHandoffOnly && !IsHandoffFunctionName(fcc.Name))) + { + filteredMessage.Contents.Add(content); + + // Track non-handoff function calls so their tool results are preserved in HandoffOnly mode + if (filterHandoffOnly && content is FunctionCallContent nonHandoffFcc) + { + filteringCandidates[nonHandoffFcc.CallId] = new FilterCandidateState(nonHandoffFcc.CallId) + { + IsHandoffFunction = false, + }; + } + } + else if (filterHandoffOnly) + { + if (!filteringCandidates.TryGetValue(fcc.CallId, out FilterCandidateState? 
candidateState)) + { + filteringCandidates[fcc.CallId] = new FilterCandidateState(fcc.CallId) + { + IsHandoffFunction = true, + }; + } + else + { + candidateState.IsHandoffFunction = true; + (int messageIndex, int contentIndex) = candidateState.FunctionCallResultLocation!.Value; + ChatMessage messageToFilter = filteredMessages[messageIndex]; + messageToFilter.Contents.RemoveAt(contentIndex); + if (messageToFilter.Contents.Count == 0) + { + messagesToRemove.Add(messageIndex); + } + } + } + else + { + // All mode: strip all FunctionCallContent + } + } + } + else + { + if (!filterHandoffOnly) + { + continue; + } + + for (int i = 0; i < unfilteredMessage.Contents!.Count; i++) + { + AIContent content = unfilteredMessage.Contents[i]; + if (content is not FunctionResultContent frc + || (filteringCandidates.TryGetValue(frc.CallId, out FilterCandidateState? candidateState) + && candidateState.IsHandoffFunction is false)) + { + // Either this is not a function result content, so we should let it through, or it is a FRC that + // we know is not related to a handoff call. In either case, we should include it. + filteredMessage.Contents.Add(content); + } + else if (candidateState is null) + { + // We haven't seen the corresponding function call yet, so add it as a candidate to be filtered later + filteringCandidates[frc.CallId] = new FilterCandidateState(frc.CallId) + { + FunctionCallResultLocation = (filteredMessages.Count, filteredMessage.Contents.Count), + }; + } + // else we have seen the corresponding function call and it is a handoff, so we should filter it out. + } + } + + if (filteredMessage.Contents.Count > 0) + { + filteredMessages.Add(filteredMessage); + } + } + + return filteredMessages.Where((_, index) => !messagesToRemove.Contains(index)); + } + + private class FilterCandidateState(string callId) + { + public (int MessageIndex, int ContentIndex)? FunctionCallResultLocation { get; set; } + + public string CallId => callId; + + public bool? 
IsHandoffFunction { get; set; } + } +} diff --git a/dotnet/src/Microsoft.Agents.AI.Workflows/Specialized/HandoffStartExecutor.cs b/dotnet/src/Microsoft.Agents.AI.Workflows/Specialized/HandoffStartExecutor.cs index 063f73bb6f..223517c35f 100644 --- a/dotnet/src/Microsoft.Agents.AI.Workflows/Specialized/HandoffStartExecutor.cs +++ b/dotnet/src/Microsoft.Agents.AI.Workflows/Specialized/HandoffStartExecutor.cs @@ -9,8 +9,23 @@ namespace Microsoft.Agents.AI.Workflows.Specialized; internal static class HandoffConstants { + internal const string HandoffOrchestrationSharedScope = "HandoffOrchestration"; + internal const string PreviousAgentTrackerKey = "LastAgentId"; - internal const string PreviousAgentTrackerScope = "HandoffOrchestration"; + internal const string PreviousAgentTrackerScope = HandoffOrchestrationSharedScope; + + internal const string MultiPartyConversationKey = "MultiPartyConversation"; + internal const string MultiPartyConversationScope = HandoffOrchestrationSharedScope; + + internal const string HandoffSharedStateKey = "SharedState"; + internal const string HandoffSharedStateScope = HandoffOrchestrationSharedScope; +} + +internal sealed class HandoffSharedState +{ + public MultiPartyConversation Conversation { get; } = new(); + + public string? PreviousAgentId { get; set; } } /// Executor used at the start of a handoffs workflow to accumulate messages and emit them as HandoffState upon receiving a turn token. @@ -29,23 +44,25 @@ protected override ProtocolBuilder ConfigureProtocol(ProtocolBuilder protocolBui protected override ValueTask TakeTurnAsync(List messages, IWorkflowContext context, bool? emitEvents, CancellationToken cancellationToken = default) { - if (returnToPrevious) - { - return context.InvokeWithStateAsync( - async (string? 
previousAgentId, IWorkflowContext context, CancellationToken cancellationToken) => - { - HandoffState handoffState = new(new(emitEvents), null, messages, previousAgentId); - await context.SendMessageAsync(handoffState, cancellationToken).ConfigureAwait(false); - - return previousAgentId; - }, - HandoffConstants.PreviousAgentTrackerKey, - HandoffConstants.PreviousAgentTrackerScope, - cancellationToken); - } - - HandoffState handoff = new(new(emitEvents), null, messages); - return context.SendMessageAsync(handoff, cancellationToken); + return context.InvokeWithStateAsync( + async (HandoffSharedState? sharedState, IWorkflowContext context, CancellationToken cancellationToken) => + { + sharedState ??= new HandoffSharedState(); + sharedState.Conversation.AddMessages(messages); + + string? previousAgentId = sharedState.PreviousAgentId; + + // If we are configured to return to the previous agent, include the previous agent id in the handoff state. + // If there was no previousAgent, it will still be null. + HandoffState turnState = new(new(emitEvents), null, returnToPrevious ? previousAgentId : null); + + await context.SendMessageAsync(turnState, cancellationToken).ConfigureAwait(false); + + return sharedState; + }, + HandoffConstants.HandoffSharedStateKey, + HandoffConstants.HandoffSharedStateScope, + cancellationToken); } public new ValueTask ResetAsync() => base.ResetAsync(); diff --git a/dotnet/src/Microsoft.Agents.AI.Workflows/Specialized/HandoffState.cs b/dotnet/src/Microsoft.Agents.AI.Workflows/Specialized/HandoffState.cs index 644bc7df0e..24cf788cb8 100644 --- a/dotnet/src/Microsoft.Agents.AI.Workflows/Specialized/HandoffState.cs +++ b/dotnet/src/Microsoft.Agents.AI.Workflows/Specialized/HandoffState.cs @@ -1,12 +1,8 @@ // Copyright (c) Microsoft. All rights reserved. -using System.Collections.Generic; -using Microsoft.Extensions.AI; - namespace Microsoft.Agents.AI.Workflows.Specialized; internal sealed record class HandoffState( TurnToken TurnToken, string? 
RequestedHandoffTargetAgentId, - List Messages, string? PreviousAgentId = null); diff --git a/dotnet/src/Microsoft.Agents.AI.Workflows/Specialized/MultiPartyConversation.cs b/dotnet/src/Microsoft.Agents.AI.Workflows/Specialized/MultiPartyConversation.cs new file mode 100644 index 0000000000..b587db2197 --- /dev/null +++ b/dotnet/src/Microsoft.Agents.AI.Workflows/Specialized/MultiPartyConversation.cs @@ -0,0 +1,56 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System.Collections.Generic; +using System.Linq; +using Microsoft.Extensions.AI; + +namespace Microsoft.Agents.AI.Workflows.Specialized; + +internal sealed class MultiPartyConversation +{ + private readonly List _history = []; + private readonly object _mutex = new(); + + // Snapshot under the same lock the writers take, so a concurrent AddMessages/AddMessage cannot race with the copy. + public List CloneAllMessages() + { + lock (this._mutex) + { + return this._history.ToList(); + } + } + + public (ChatMessage[], int) CollectNewMessages(int bookmark) + { + lock (this._mutex) + { + int count = this._history.Count - bookmark; + if (count > 0) + { + return (this._history.Skip(bookmark).ToArray(), this.CurrentBookmark); + } + + return ([], this.CurrentBookmark); + } + } + + private int CurrentBookmark => this._history.Count; + + public int AddMessages(IEnumerable messages) + { + lock (this._mutex) + { + this._history.AddRange(messages); + return this.CurrentBookmark; + } + } + + public int AddMessage(ChatMessage message) + { + lock (this._mutex) + { + this._history.Add(message); + return this.CurrentBookmark; + } + } +} diff --git a/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/HandoffAgentExecutorTests.cs b/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/HandoffAgentExecutorTests.cs index 1a5b2ea4d1..f4e1252e2a 100644 --- a/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/HandoffAgentExecutorTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/HandoffAgentExecutorTests.cs @@ -1,13 +1,43 @@ // Copyright (c) Microsoft. All rights reserved. 
+using System; +using System.Collections.Generic; using System.Linq; +using System.Threading; using System.Threading.Tasks; +using FluentAssertions; +using Microsoft.Agents.AI.Workflows.Checkpointing; +using Microsoft.Agents.AI.Workflows.Execution; +using Microsoft.Agents.AI.Workflows.InProc; +using Microsoft.Agents.AI.Workflows.Sample; using Microsoft.Agents.AI.Workflows.Specialized; +using Microsoft.Extensions.AI; namespace Microsoft.Agents.AI.Workflows.UnitTests; public class HandoffAgentExecutorTests : AIAgentHostingExecutorTestsBase { + private static async ValueTask PrepareHandoffSharedStateAsync(TestRunContext? runContext = null, IEnumerable? messages = null) + { + runContext ??= new(); + + HandoffSharedState sharedState = new(); + + if (messages != null) + { + sharedState.Conversation.AddMessages(messages); + } + + await runContext.BindWorkflowContext(nameof(HandoffStartExecutor)) + .QueueStateUpdateAsync(HandoffConstants.HandoffSharedStateKey, + sharedState, + HandoffConstants.HandoffSharedStateScope); + + await runContext.StateManager.PublishUpdatesAsync(null); + + return runContext; + } + [Theory] [InlineData(null, null)] [InlineData(null, true)] @@ -21,7 +51,7 @@ public class HandoffAgentExecutorTests : AIAgentHostingExecutorTestsBase public async Task Test_HandoffAgentExecutor_EmitsStreamingUpdatesIFFConfiguredAsync(bool? executorSetting, bool? 
turnSetting) { // Arrange - TestRunContext testContext = new(); + TestRunContext testContext = await PrepareHandoffSharedStateAsync(); TestReplayAgent agent = new(TestMessages, TestAgentId, TestAgentName); HandoffAgentExecutorOptions options = new("", @@ -33,7 +63,7 @@ public async Task Test_HandoffAgentExecutor_EmitsStreamingUpdatesIFFConfiguredAs testContext.ConfigureExecutor(executor); // Act - HandoffState message = new(new(turnSetting), null, []); + HandoffState message = new(new(turnSetting), null, null); await executor.HandleAsync(message, testContext.BindWorkflowContext(executor.Id)); // Assert @@ -49,7 +79,7 @@ public async Task Test_HandoffAgentExecutor_EmitsStreamingUpdatesIFFConfiguredAs public async Task Test_HandoffAgentExecutor_EmitsResponseIFFConfiguredAsync(bool executorSetting) { // Arrange - TestRunContext testContext = new(); + TestRunContext testContext = await PrepareHandoffSharedStateAsync(); TestReplayAgent agent = new(TestMessages, TestAgentId, TestAgentName); HandoffAgentExecutorOptions options = new("", @@ -61,11 +91,120 @@ public async Task Test_HandoffAgentExecutor_EmitsResponseIFFConfiguredAsync(bool testContext.ConfigureExecutor(executor); // Act - HandoffState message = new(new(false), null, []); + HandoffState message = new(new(false), null, null); await executor.HandleAsync(message, testContext.BindWorkflowContext(executor.Id)); // Assert AgentResponseEvent[] updates = testContext.Events.OfType().ToArray(); CheckResponseEventsAgainstTestMessages(updates, expectingResponse: executorSetting, agent.GetDescriptiveId()); } + + [Fact] + public async Task Test_HandoffAgentExecutor_ComposesWithHITLSubworkflowAsync() + { + // Arrange + TestRunContext testContext = await PrepareHandoffSharedStateAsync(); + + SendsRequestExecutor challengeSender = new(); + Workflow subworkflow = new WorkflowBuilder(challengeSender) + .AddExternalRequest(challengeSender, "SendChallengeToUser") + .WithOutputFrom(challengeSender) + .Build(); + + 
InProcessExecutionEnvironment environment = InProcessExecution.Lockstep.WithCheckpointing(CheckpointManager.CreateInMemory()); + AIAgent subworkflowAgent = subworkflow.AsAIAgent(includeWorkflowOutputsInResponse: true, name: "Subworkflow", executionEnvironment: environment); + HandoffAgentExecutorOptions options = new("", + emitAgentResponseEvents: true, + emitAgentResponseUpdateEvents: true, + HandoffToolCallFilteringBehavior.None); + + HandoffAgentExecutor executor = new(subworkflowAgent, [], options); + Workflow fakeWorkflow = new(executor.Id) { ExecutorBindings = { { executor.Id, executor } } }; + EdgeMap map = new(testContext, fakeWorkflow, null); + + testContext.ConfigureExecutor(executor, map); + + // Validate that our test assumptions hold + string functionCallPortId = $"{HandoffAgentExecutor.IdFor(subworkflowAgent)}_FunctionCall"; + map.TryGetResponsePortExecutorId(functionCallPortId, out string? responsePortExecutorId).Should().BeTrue(); + responsePortExecutorId.Should().Be(executor.Id); + + // Act + HandoffState message = new(new(false), null, null); + await executor.HandleAsync(message, testContext.BindWorkflowContext(executor.Id)); + + await testContext.StateManager.PublishUpdatesAsync(null); + + // Assert + testContext.ExternalRequests.Should().HaveCount(1) + .And.ContainSingle(request => request.IsDataOfType()); + + FunctionCallContent functionCallContent = testContext.ExternalRequests.Single().Data.As()!; + object? requestData = functionCallContent.Arguments!["data"]; + + Challenge? challenge = null; + if (requestData is PortableValue pv) + { + challenge = pv.As(); + } + else + { + challenge = requestData as Challenge; + } + + if (challenge is null) + { + Assert.Fail($"Expected request data to be of type {typeof(Challenge).FullName}, but was {requestData?.GetType().FullName ?? 
"null"}"); + return; // Unreachable, but analysis cannot infer that Debug.Fail will throw/exit, and UnreachableException is not available on net472 + } + + // Act 2 + string challengeResponse = new(challenge.Value.Reverse().ToArray()); + FunctionResultContent responseContent = new(functionCallContent.CallId, new Response(challengeResponse)); + + RequestPortInfo requestPortInfo = new(new(typeof(Challenge)), new(typeof(Response)), functionCallPortId); + string requestId = $"{functionCallPortId.Length}:{functionCallPortId}:{functionCallContent.CallId}"; + DeliveryMapping? mapping = await map.PrepareDeliveryForResponseAsync(new(requestPortInfo, requestId, new(responseContent))); + + mapping!.Deliveries.Should().HaveCount(1); + + MessageDelivery delivery = mapping!.Deliveries.Single(); + + object? result = await executor.ExecuteCoreAsync(delivery.Envelope.Message, + delivery.Envelope.MessageType, + testContext.BindWorkflowContext(executor.Id)); + } +} + +internal sealed record Challenge(string Value); + +internal sealed record Response(string Value); + +[SendsMessage(typeof(Challenge))] +internal sealed partial class SendsRequestExecutor(string? id = null) : ChatProtocolExecutor(id ?? nameof(SendsRequestExecutor), s_chatOptions) +{ + internal const string ChallengeString = "{C7A762AE-7DAA-4D9C-A647-E64E6DBC35AE}"; + private static string ResponseKey { get; } = new(ChallengeString.Reverse().ToArray()); + + private static readonly ChatProtocolExecutorOptions s_chatOptions = new() + { + AutoSendTurnToken = false + }; + + protected override ValueTask TakeTurnAsync(List messages, IWorkflowContext context, bool? 
emitEvents, CancellationToken cancellationToken = default) + => context.SendMessageAsync(new Challenge(ChallengeString), cancellationToken); + + [MessageHandler] + public async ValueTask HandleChallengeResponseAsync(Response response, IWorkflowContext context, CancellationToken cancellationToken = default) + { + if (response.Value != ResponseKey) + { + throw new InvalidOperationException($"Incorrect response received. Expected '{ResponseKey}' but got '{response.Value}'"); + } + + await context.SendMessageAsync(new ChatMessage(ChatRole.Assistant, "Correct response."), cancellationToken) + .ConfigureAwait(false); + + await context.SendMessageAsync(new TurnToken(false), cancellationToken).ConfigureAwait(false); + } } diff --git a/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/Sample/12_HandOff_HostAsAgent.cs b/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/Sample/12_HandOff_HostAsAgent.cs index 993a6d462b..dc1072aa72 100644 --- a/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/Sample/12_HandOff_HostAsAgent.cs +++ b/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/Sample/12_HandOff_HostAsAgent.cs @@ -73,6 +73,7 @@ public static async ValueTask RunAsync(TextWriter writer, IWorkflowExecutionEnvi foreach (string input in inputs) { AgentResponse response; + ResponseContinuationToken? 
continuationToken = null; do { diff --git a/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/SampleSmokeTest.cs b/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/SampleSmokeTest.cs index 247499b72e..a290948ae4 100644 --- a/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/SampleSmokeTest.cs +++ b/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/SampleSmokeTest.cs @@ -314,6 +314,38 @@ internal async Task Test_RunSample_Step11Async(ExecutionEnvironment environment) Action CreateValidator(string expected) => actual => actual.Should().Be(expected); } + public class Step12ExpectedOutputCalculator(int agentCount) + { + private readonly int[] _bookmarks = new int[agentCount]; + private readonly List _history = new(); + private readonly HashSet _skipIndices = new(); + + public IEnumerable ExpectedOutputs => + this._history.Where((element, index) => !this._skipIndices.Contains(index)); + + public void ProcessInput(string newInput) + { + this._skipIndices.Add(this._history.Count); + this._history.Add(newInput); + + for (int i = 0; i < agentCount; i++) + { + int agentId = i + 1; + int agentBookmark = this._bookmarks[i]; + int count = this._history.Count - agentBookmark; + + count.Should().BeGreaterThanOrEqualTo(0); + + foreach (string input in this._history.Skip(agentBookmark).ToList()) + { + this._history.Add($"{agentId}:{input}"); + } + + this._bookmarks[i] = this._history.Count; + } + } + } + [Theory] [InlineData(ExecutionEnvironment.InProcess_Lockstep)] [InlineData(ExecutionEnvironment.InProcess_OffThread)] @@ -322,14 +354,10 @@ internal async Task Test_RunSample_Step12Async(ExecutionEnvironment environment) { List inputs = ["1", "2", "3"]; - using StringWriter writer = new(); - await Step12EntryPoint.RunAsync(writer, environment.ToWorkflowExecutionEnvironment(), inputs); - - string[] lines = writer.ToString().Split(['\r', '\n'], StringSplitOptions.RemoveEmptyEntries); - // The expectation is that each agent will echo each input along with every echo from 
previous agents // E.g.: // (user): 1 + // ----- outputs below // (a1): 1:1 // (a2): 2:1 // (a2): 2:1:1 @@ -340,7 +368,35 @@ internal async Task Test_RunSample_Step12Async(ExecutionEnvironment environment) // (a3): 3:2:1 // (a3): 3:2:1:1 - string[] expected = inputs.SelectMany(input => EchoesForInput(input)).ToArray(); + // If there are multiple inputs (there are), then each successive input adds to the depth of the previous + // ones, so, for example, once we do input = "1", "2": + + // (user): 1 + // (a1): 1:1 <- a1 "last seen" + // (a2): 2:1 + // (a2): 2:1:1 <- a2 "last seen" + // (user): 2 + // ----- outputs below + // (a1): 1:2:1 + // (a1): 1:2:1:1 + // (a1): 1:2 <- from user input, a1 "last seen" + // (a2): 2:2 <- from user input (note that a2 seems like it is seeing these in a different "order" than a1 - but it is not) + // (a2): 2:1:2:1 + // (a2): 2:1:2:1:1 + // (a2): 2:1:2 <- from a1's first echo, a2 "last seen" + + Step12ExpectedOutputCalculator outputGenerator = new(Step12EntryPoint.AgentCount); + foreach (string input in inputs) + { + outputGenerator.ProcessInput(input); + } + + string[] expected = outputGenerator.ExpectedOutputs.ToArray(); + + using StringWriter writer = new(); + await Step12EntryPoint.RunAsync(writer, environment.ToWorkflowExecutionEnvironment(), inputs); + + string[] lines = writer.ToString().Split(['\r', '\n'], StringSplitOptions.RemoveEmptyEntries); Console.Error.WriteLine("Expected lines: "); foreach (string expectedLine in expected) @@ -357,19 +413,6 @@ internal async Task Test_RunSample_Step12Async(ExecutionEnvironment environment) Assert.Collection(lines, expected.Select(CreateValidator).ToArray()); - IEnumerable EchoesForInput(string input) - { - List echoes = [$"{Step12EntryPoint.EchoPrefixForAgent(1)}{input}"]; - for (int i = 2; i <= Step12EntryPoint.AgentCount; i++) - { - string agentPrefix = Step12EntryPoint.EchoPrefixForAgent(i); - List newEchoes = [$"{agentPrefix}{input}", .. 
echoes.Select(echo => $"{agentPrefix}{echo}")]; - echoes.AddRange(newEchoes); - } - - return echoes; - } - Action CreateValidator(string expected) => actual => actual.Should().Be(expected); } diff --git a/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/TestRunContext.cs b/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/TestRunContext.cs index f94c463d59..be0d62528d 100644 --- a/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/TestRunContext.cs +++ b/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/TestRunContext.cs @@ -27,6 +27,9 @@ public IExternalRequestSink RegisterPort(RequestPort port) internal TestRunContext ConfigureExecutor(Executor executor, EdgeMap? map = null) { + // Ensure that we have run the ProtocolBuilder + _ = executor.Protocol.Describe(); + executor.AttachRequestContext(new TestExternalRequestContext(this, executor.Id, map)); this.Executors.Add(executor.Id, executor); return this; @@ -42,6 +45,7 @@ internal TestRunContext ConfigureExecutors(IEnumerable executors, Edge return this; } + internal StateManager StateManager { get; } = new(); private sealed class BoundContext( string executorId, TestRunContext runnerContext, @@ -70,16 +74,16 @@ public ValueTask RequestHaltAsync() => this.AddEventAsync(new RequestHaltEvent()); public ValueTask QueueClearScopeAsync(string? scopeName = null, CancellationToken cancellationToken = default) - => default; + => runnerContext.StateManager.ClearStateAsync(executorId, scopeName); public ValueTask QueueStateUpdateAsync(string key, T? value, string? scopeName = null, CancellationToken cancellationToken = default) - => default; + => runnerContext.StateManager.WriteStateAsync(new ScopeId(executorId, scopeName), key, value); public ValueTask ReadStateAsync(string key, string? scopeName = null, CancellationToken cancellationToken = default) - => new(default(T?)); + => runnerContext.StateManager.ReadStateAsync(new ScopeId(executorId, scopeName), key); public ValueTask> ReadStateKeysAsync(string? 
scopeName = null, CancellationToken cancellationToken = default) - => new([]); + => runnerContext.StateManager.ReadKeysAsync(new ScopeId(executorId, scopeName)); public ValueTask SendMessageAsync(object message, string? targetId = null, CancellationToken cancellationToken = default) => runnerContext.SendMessageAsync(executorId, message, targetId, cancellationToken); From 4fc14f5220b942ad2a5402281d5976e97be56711 Mon Sep 17 00:00:00 2001 From: Jacob Alber Date: Thu, 16 Apr 2026 09:18:47 -0400 Subject: [PATCH 2/4] fix: AgentSession checkpointing using AIAgent's Serialize/Deserialize methods We cannot rely on implicit serialization through `HandoffHostState` because we are missing type information. --- .../Specialized/HandoffAgentExecutor.cs | 36 +++++++++++++------ 1 file changed, 26 insertions(+), 10 deletions(-) diff --git a/dotnet/src/Microsoft.Agents.AI.Workflows/Specialized/HandoffAgentExecutor.cs b/dotnet/src/Microsoft.Agents.AI.Workflows/Specialized/HandoffAgentExecutor.cs index d9a4ae941f..d9acab96d5 100644 --- a/dotnet/src/Microsoft.Agents.AI.Workflows/Specialized/HandoffAgentExecutor.cs +++ b/dotnet/src/Microsoft.Agents.AI.Workflows/Specialized/HandoffAgentExecutor.cs @@ -44,8 +44,7 @@ internal struct AgentInvocationResult(AgentResponse agentResponse, string? hando internal record HandoffAgentHostState( HandoffState? IncomingState, - int ConversationBookmark, - AgentSession? AgentSession) + int ConversationBookmark) { [MemberNotNullWhen(true, nameof(IncomingState))] [JsonIgnore] @@ -91,7 +90,10 @@ internal sealed class HandoffAgentExecutor : private readonly StateRef _sharedStateRef = new(HandoffConstants.HandoffSharedStateKey, HandoffConstants.HandoffSharedStateScope); - private static HandoffAgentHostState InitialStateFactory() => new(null, 0, null); + internal const string AgentSessionKey = nameof(AgentSession); + private AgentSession? 
_session; + + private static HandoffAgentHostState InitialStateFactory() => new(null, 0); public HandoffAgentExecutor(AIAgent agent, HashSet handoffs, HandoffAgentExecutorOptions options) : base(IdFor(agent), InitialStateFactory) @@ -331,18 +333,35 @@ protected internal override async ValueTask OnCheckpointingAsync(IWorkflowContex { Task userInputRequestsTask = this._userInputHandler?.OnCheckpointingAsync(UserInputRequestStateKey, context, cancellationToken).AsTask() ?? Task.CompletedTask; Task functionCallRequestsTask = this._functionCallHandler?.OnCheckpointingAsync(FunctionCallRequestStateKey, context, cancellationToken).AsTask() ?? Task.CompletedTask; + Task agentSessionTask = CheckpointAgentSessionAsync(); Task baseTask = base.OnCheckpointingAsync(context, cancellationToken).AsTask(); - await Task.WhenAll(userInputRequestsTask, functionCallRequestsTask, baseTask).ConfigureAwait(false); + await Task.WhenAll(userInputRequestsTask, functionCallRequestsTask, agentSessionTask, baseTask).ConfigureAwait(false); + + async Task CheckpointAgentSessionAsync() + { + JsonElement? sessionState = this._session is not null ? await this._agent.SerializeSessionAsync(this._session, cancellationToken: cancellationToken).ConfigureAwait(false) : null; + await context.QueueStateUpdateAsync(AgentSessionKey, sessionState, cancellationToken: cancellationToken).ConfigureAwait(false); + } } protected internal override async ValueTask OnCheckpointRestoredAsync(IWorkflowContext context, CancellationToken cancellationToken = default) { Task userInputRestoreTask = this._userInputHandler?.OnCheckpointRestoredAsync(UserInputRequestStateKey, context, cancellationToken).AsTask() ?? Task.CompletedTask; Task functionCallRestoreTask = this._functionCallHandler?.OnCheckpointRestoredAsync(FunctionCallRequestStateKey, context, cancellationToken).AsTask() ?? 
Task.CompletedTask; + Task agentSessionTask = RestoreAgentSessionAsync(); - await Task.WhenAll(userInputRestoreTask, functionCallRestoreTask).ConfigureAwait(false); + await Task.WhenAll(userInputRestoreTask, functionCallRestoreTask, agentSessionTask).ConfigureAwait(false); await base.OnCheckpointRestoredAsync(context, cancellationToken).ConfigureAwait(false); + + async Task RestoreAgentSessionAsync() + { + JsonElement? sessionState = await context.ReadStateAsync(AgentSessionKey, cancellationToken: cancellationToken).ConfigureAwait(false); + if (sessionState.HasValue) + { + this._session = await this._agent.DeserializeSessionAsync(sessionState.Value, cancellationToken: cancellationToken).ConfigureAwait(false); + } + } } private bool HasOutstandingRequests => (this._userInputHandler?.HasPendingRequests == true) || (this._functionCallHandler?.HasPendingRequests == true); @@ -360,14 +379,11 @@ private async ValueTask InvokeAgentAsync(IEnumerable { - if (state.AgentSession == null) - { - state = state with { AgentSession = await this._agent.CreateSessionAsync(ct).ConfigureAwait(false) }; - } + this._session ??= await this._agent.CreateSessionAsync(ct).ConfigureAwait(false); IAsyncEnumerable agentStream = this._agent.RunStreamingAsync(messages, - state.AgentSession, + this._session, options: this._agentOptions, cancellationToken: ct); From e7b4edad6e037f456e3bef986deaa396f55f0011 Mon Sep 17 00:00:00 2001 From: Jacob Alber Date: Thu, 16 Apr 2026 09:19:20 -0400 Subject: [PATCH 3/4] fix: Thread safety issue in `MultiPartyConversation.AllMessages` --- .../Specialized/MultiPartyConversation.cs | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/dotnet/src/Microsoft.Agents.AI.Workflows/Specialized/MultiPartyConversation.cs b/dotnet/src/Microsoft.Agents.AI.Workflows/Specialized/MultiPartyConversation.cs index b587db2197..4d59184f0d 100644 --- a/dotnet/src/Microsoft.Agents.AI.Workflows/Specialized/MultiPartyConversation.cs +++ 
b/dotnet/src/Microsoft.Agents.AI.Workflows/Specialized/MultiPartyConversation.cs @@ -1,5 +1,6 @@ // Copyright (c) Microsoft. All rights reserved. +using System; using System.Collections.Generic; using System.Linq; using Microsoft.Extensions.AI; @@ -11,19 +12,25 @@ internal sealed class MultiPartyConversation private readonly List _history = []; private readonly object _mutex = new(); - public List CloneAllMessages() => this._history.ToList(); + public List CloneAllMessages() + { + lock (this._mutex) + { + return this._history.ToList(); + } + } public (ChatMessage[], int) CollectNewMessages(int bookmark) { lock (this._mutex) { int count = this._history.Count - bookmark; - if (count > 0) + if (count < 0) { - return (this._history.Skip(bookmark).ToArray(), this.CurrentBookmark); + throw new InvalidOperationException($"Bookmark value too large: {bookmark} vs count={count}"); } - return ([], this.CurrentBookmark); + return (this._history.Skip(bookmark).ToArray(), this.CurrentBookmark); } } From f69b9b2c7aced44729796a4ee0bcf85e363112c5 Mon Sep 17 00:00:00 2001 From: Jacob Alber Date: Thu, 16 Apr 2026 10:25:08 -0400 Subject: [PATCH 4/4] fix: Enable unwrapping of FunctionResultContent when ExternalRequest was wrapped into FunctionCallContent --- .../WorkflowHostingExtensions.cs | 2 +- .../WorkflowSession.cs | 37 +++++++++++++++---- 2 files changed, 30 insertions(+), 9 deletions(-) diff --git a/dotnet/src/Microsoft.Agents.AI.Workflows/WorkflowHostingExtensions.cs b/dotnet/src/Microsoft.Agents.AI.Workflows/WorkflowHostingExtensions.cs index 281d0694ac..e91531513d 100644 --- a/dotnet/src/Microsoft.Agents.AI.Workflows/WorkflowHostingExtensions.cs +++ b/dotnet/src/Microsoft.Agents.AI.Workflows/WorkflowHostingExtensions.cs @@ -41,7 +41,7 @@ internal static FunctionCallContent ToFunctionCall(this ExternalRequest request) { Dictionary parameters = new() { - { "data", request.Data} + { "data", request.Data } }; return new FunctionCallContent(request.RequestId, 
request.PortInfo.PortId, parameters); diff --git a/dotnet/src/Microsoft.Agents.AI.Workflows/WorkflowSession.cs b/dotnet/src/Microsoft.Agents.AI.Workflows/WorkflowSession.cs index c1f81f0ecf..d015ecdcee 100644 --- a/dotnet/src/Microsoft.Agents.AI.Workflows/WorkflowSession.cs +++ b/dotnet/src/Microsoft.Agents.AI.Workflows/WorkflowSession.cs @@ -247,7 +247,7 @@ private async ValueTask SendMessagesWithResponseConversionAs hasMatchedResponseForStartExecutor |= string.Equals(responseExecutorId, this._workflow.StartExecutorId, StringComparison.Ordinal); } - AIContent normalizedResponseContent = NormalizeResponseContentForDelivery(content, pendingRequest); + object normalizedResponseContent = NormalizeResponseContentForDelivery(content, pendingRequest); externalResponses.Add((pendingRequest.CreateResponse(normalizedResponseContent), pendingRequest.RequestId)); (matchedContentIds ??= new(StringComparer.Ordinal)).Add(contentId); } @@ -303,14 +303,35 @@ ExternalRequest externalRequest /// /// Rewrites workflow-facing response content back to the original agent-owned content ID. /// - private static AIContent NormalizeResponseContentForDelivery(AIContent content, ExternalRequest request) => content switch + private static object NormalizeResponseContentForDelivery(AIContent content, ExternalRequest request) { - FunctionResultContent functionResultContent when request.TryGetDataAs(out FunctionCallContent? functionCallContent) - => CloneFunctionResultContent(functionResultContent, functionCallContent.CallId), - ToolApprovalResponseContent toolApprovalResponseContent when request.TryGetDataAs(out ToolApprovalRequestContent? 
toolApprovalRequestContent) - => CloneToolApprovalResponseContent(toolApprovalResponseContent, toolApprovalRequestContent.RequestId), - _ => content, - }; + switch (content) + { + // If we got a FRC, and were expecting a FRC (because the request started out as a FCC, rather than getting converted to + // one at the WorkflowSession boundary), clone it and send it in. + case FunctionResultContent functionResultContent when request.TryGetDataAs(out FunctionCallContent? functionCallContent): + return CloneFunctionResultContent(functionResultContent, functionCallContent.CallId); + case FunctionResultContent functionResultContent when !request.PortInfo.ResponseType.IsMatchPolymorphic(typeof(FunctionResultContent)): + { + object? result = functionResultContent.Result; + if (result != null) + { + if (request.PortInfo.ResponseType.IsMatchPolymorphic(result.GetType()) || result is PortableValue) + { + return result; + } + + throw new InvalidOperationException($"Unexpected result type in FunctionResultContent {result.GetType()}; expecting {request.PortInfo.ResponseType}"); + } + + throw new NotSupportedException($"Null result is not supported when using RequestPort with non-AIContent-typed requests. {functionResultContent}"); + } + case ToolApprovalResponseContent toolApprovalResponseContent when request.TryGetDataAs(out ToolApprovalRequestContent? toolApprovalRequestContent): + return CloneToolApprovalResponseContent(toolApprovalResponseContent, toolApprovalRequestContent.RequestId); + default: + return content; + } + } /// /// Gets the workflow-facing request ID from response content types.