diff --git a/dotnet/src/Microsoft.Agents.AI.Workflows/AIAgentsAbstractionsExtensions.cs b/dotnet/src/Microsoft.Agents.AI.Workflows/AIAgentsAbstractionsExtensions.cs index 165de39855..8c94f4aa85 100644 --- a/dotnet/src/Microsoft.Agents.AI.Workflows/AIAgentsAbstractionsExtensions.cs +++ b/dotnet/src/Microsoft.Agents.AI.Workflows/AIAgentsAbstractionsExtensions.cs @@ -48,7 +48,7 @@ private static ChatMessage ChatAssistantToUserIfNotFromNamed(this ChatMessage me /// any that have a different from to /// . /// - public static List? ChangeAssistantToUserForOtherParticipants(this List messages, string targetAgentName) + public static List? ChangeAssistantToUserForOtherParticipants(this IEnumerable messages, string targetAgentName) { List? roleChanged = null; foreach (var m in messages) diff --git a/dotnet/src/Microsoft.Agents.AI.Workflows/Specialized/HandoffAgentExecutor.cs b/dotnet/src/Microsoft.Agents.AI.Workflows/Specialized/HandoffAgentExecutor.cs index eac2eb5687..d9acab96d5 100644 --- a/dotnet/src/Microsoft.Agents.AI.Workflows/Specialized/HandoffAgentExecutor.cs +++ b/dotnet/src/Microsoft.Agents.AI.Workflows/Specialized/HandoffAgentExecutor.cs @@ -6,6 +6,7 @@ using System.Diagnostics.CodeAnalysis; using System.Linq; using System.Text.Json; +using System.Text.Json.Serialization; using System.Threading; using System.Threading.Tasks; using Microsoft.Extensions.AI; @@ -31,140 +32,6 @@ public HandoffAgentExecutorOptions(string? handoffInstructions, bool emitAgentRe public HandoffToolCallFilteringBehavior ToolCallFilteringBehavior { get; set; } = HandoffToolCallFilteringBehavior.HandoffOnly; } -[Experimental(DiagnosticConstants.ExperimentalFeatureDiagnostic)] -internal sealed class HandoffMessagesFilter -{ - private readonly HandoffToolCallFilteringBehavior _filteringBehavior; - - public HandoffMessagesFilter(HandoffToolCallFilteringBehavior filteringBehavior) - { - this._filteringBehavior = filteringBehavior; - } - - [Experimental(DiagnosticConstants.ExperimentalFeatureDiagnostic)] - internal static bool IsHandoffFunctionName(string name) - { - return name.StartsWith(HandoffWorkflowBuilder.FunctionPrefix, StringComparison.Ordinal); - } - - public IEnumerable FilterMessages(List messages) - { - if (this._filteringBehavior == HandoffToolCallFilteringBehavior.None) - { - return messages; - } - - Dictionary filteringCandidates = new(); - List filteredMessages = []; - HashSet messagesToRemove = []; - - bool filterHandoffOnly = this._filteringBehavior == HandoffToolCallFilteringBehavior.HandoffOnly; - foreach (ChatMessage unfilteredMessage in messages) - { - ChatMessage filteredMessage = unfilteredMessage.Clone(); - - // .Clone() is shallow, so we cannot modify the contents of the cloned message in place. - List contents = []; - contents.Capacity = unfilteredMessage.Contents?.Count ?? 0; - filteredMessage.Contents = contents; - - // Because this runs after the role changes from assistant to user for the target agent, we cannot rely on tool calls - // originating only from messages with the Assistant role. Instead, we need to inspect the contents of all non-Tool (result) - // FunctionCallContent. - if (unfilteredMessage.Role != ChatRole.Tool) - { - for (int i = 0; i < unfilteredMessage.Contents!.Count; i++) - { - AIContent content = unfilteredMessage.Contents[i]; - if (content is not FunctionCallContent fcc || (filterHandoffOnly && !IsHandoffFunctionName(fcc.Name))) - { - filteredMessage.Contents.Add(content); - - // Track non-handoff function calls so their tool results are preserved in HandoffOnly mode - if (filterHandoffOnly && content is FunctionCallContent nonHandoffFcc) - { - filteringCandidates[nonHandoffFcc.CallId] = new FilterCandidateState(nonHandoffFcc.CallId) - { - IsHandoffFunction = false, - }; - } - } - else if (filterHandoffOnly) - { - if (!filteringCandidates.TryGetValue(fcc.CallId, out FilterCandidateState? candidateState)) - { - filteringCandidates[fcc.CallId] = new FilterCandidateState(fcc.CallId) - { - IsHandoffFunction = true, - }; - } - else - { - candidateState.IsHandoffFunction = true; - (int messageIndex, int contentIndex) = candidateState.FunctionCallResultLocation!.Value; - ChatMessage messageToFilter = filteredMessages[messageIndex]; - messageToFilter.Contents.RemoveAt(contentIndex); - if (messageToFilter.Contents.Count == 0) - { - messagesToRemove.Add(messageIndex); - } - } - } - else - { - // All mode: strip all FunctionCallContent - } - } - } - else - { - if (!filterHandoffOnly) - { - continue; - } - - for (int i = 0; i < unfilteredMessage.Contents!.Count; i++) - { - AIContent content = unfilteredMessage.Contents[i]; - if (content is not FunctionResultContent frc - || (filteringCandidates.TryGetValue(frc.CallId, out FilterCandidateState? candidateState) - && candidateState.IsHandoffFunction is false)) - { - // Either this is not a function result content, so we should let it through, or it is a FRC that - // we know is not related to a handoff call. In either case, we should include it. - filteredMessage.Contents.Add(content); - } - else if (candidateState is null) - { - // We haven't seen the corresponding function call yet, so add it as a candidate to be filtered later - filteringCandidates[frc.CallId] = new FilterCandidateState(frc.CallId) - { - FunctionCallResultLocation = (filteredMessages.Count, filteredMessage.Contents.Count), - }; - } - // else we have seen the corresponding function call and it is a handoff, so we should filter it out. - } - } - - if (filteredMessage.Contents.Count > 0) - { - filteredMessages.Add(filteredMessage); - } - } - - return filteredMessages.Where((_, index) => !messagesToRemove.Contains(index)); - } - - private class FilterCandidateState(string callId) - { - public (int MessageIndex, int ContentIndex)? FunctionCallResultLocation { get; set; } - - public string CallId => callId; - - public bool? IsHandoffFunction { get; set; } - } -} - internal struct AgentInvocationResult(AgentResponse agentResponse, string? handoffTargetId) { public AgentResponse Response => agentResponse; @@ -175,19 +42,31 @@ internal struct AgentInvocationResult(AgentResponse agentResponse, string? hando public bool IsHandoffRequested => this.HandoffTargetId != null; } -internal record HandoffAgentHostState(HandoffState? CurrentTurnState, List FilteredIncomingMessages, List TurnMessages) +internal record HandoffAgentHostState( + HandoffState? IncomingState, + int ConversationBookmark) { - public HandoffState PrepareHandoff(AgentInvocationResult invocationResult, string currentAgentId) - { - if (this.CurrentTurnState == null) - { - throw new InvalidOperationException("Cannot create a handoff request: Out of turn."); - } - - IEnumerable allMessages = [.. this.CurrentTurnState.Messages, .. this.TurnMessages, .. invocationResult.Response.Messages]; + [MemberNotNullWhen(true, nameof(IncomingState))] + [JsonIgnore] + public bool IsTakingTurn => this.IncomingState != null; +} - return new(this.CurrentTurnState.TurnToken, invocationResult.HandoffTargetId, allMessages.ToList(), currentAgentId); - } +internal sealed record StateRef(string Key, string? ScopeName) +{ + public ValueTask InvokeWithStateAsync(Func> invocation, + IWorkflowContext context, + CancellationToken cancellationToken) + => context.InvokeWithStateAsync(invocation, this.Key, this.ScopeName, cancellationToken); + + public ValueTask InvokeWithStateAsync(Func invocation, + IWorkflowContext context, + CancellationToken cancellationToken) + => context.InvokeWithStateAsync( + async (state, ctx, ct) => + { + await invocation(state, ctx, ct).ConfigureAwait(false); + return state; + }, this.Key, this.ScopeName, cancellationToken); } /// Executor used to represent an agent in a handoffs workflow, responding to events. @@ -208,7 +87,13 @@ internal sealed class HandoffAgentExecutor : private readonly HashSet _handoffFunctionNames = []; private readonly Dictionary _handoffFunctionToAgentId = []; - private static HandoffAgentHostState InitialStateFactory() => new(null, [], []); + private readonly StateRef _sharedStateRef = new(HandoffConstants.HandoffSharedStateKey, + HandoffConstants.HandoffSharedStateScope); + + internal const string AgentSessionKey = nameof(AgentSession); + private AgentSession? _session; + + private static HandoffAgentHostState InitialStateFactory() => new(null, 0); public HandoffAgentExecutor(AIAgent agent, HashSet handoffs, HandoffAgentExecutorOptions options) : base(IdFor(agent), InitialStateFactory) @@ -291,13 +176,18 @@ private ValueTask HandleUserInputResponseAsync( // resumes can be processed in one invocation. return this.InvokeWithStateAsync((state, ctx, ct) => { - state.TurnMessages.Add(new ChatMessage(ChatRole.User, [response]) + if (!state.IsTakingTurn) + { + throw new InvalidOperationException("Cannot process user responses when not taking a turn in Handoff Orchestration."); + } + + ChatMessage userMessage = new(ChatRole.User, [response]) { CreatedAt = DateTimeOffset.UtcNow, MessageId = Guid.NewGuid().ToString("N"), - }); + }; - return this.ContinueTurnAsync(state, ctx, ct); + return this.ContinueTurnAsync(state, [userMessage], ctx, ct); }, context, skipCache: false, cancellationToken); } @@ -315,24 +205,44 @@ private ValueTask HandleFunctionResultAsync( // resumes can be processed in one invocation. return this.InvokeWithStateAsync((state, ctx, ct) => { - state.TurnMessages.Add( - new ChatMessage(ChatRole.Tool, [result]) - { - AuthorName = this._agent.Name ?? this._agent.Id, - CreatedAt = DateTimeOffset.UtcNow, - MessageId = Guid.NewGuid().ToString("N"), - }); + if (!state.IsTakingTurn) + { + throw new InvalidOperationException("Cannot process user responses in when not taking a turn in Handoff Orchestration."); + } - return this.ContinueTurnAsync(state, ctx, ct); + ChatMessage toolMessage = new(ChatRole.Tool, [result]) + { + AuthorName = this._agent.Name ?? this._agent.Id, + CreatedAt = DateTimeOffset.UtcNow, + MessageId = Guid.NewGuid().ToString("N"), + }; + + return this.ContinueTurnAsync(state, [toolMessage], ctx, ct); }, context, skipCache: false, cancellationToken); } - private async ValueTask ContinueTurnAsync(HandoffAgentHostState state, IWorkflowContext context, CancellationToken cancellationToken) + private async ValueTask ContinueTurnAsync(HandoffAgentHostState state, List incomingMessages, IWorkflowContext context, CancellationToken cancellationToken, bool skipAddIncoming = false) { - List? roleChanges = state.FilteredIncomingMessages.ChangeAssistantToUserForOtherParticipants(this._agent.Name ?? this._agent.Id); + if (!state.IsTakingTurn) + { + throw new InvalidOperationException("Cannot process user responses in when not taking a turn in Handoff Orchestration."); + } - bool emitUpdateEvents = state.CurrentTurnState!.ShouldEmitStreamingEvents(this._options.EmitAgentResponseUpdateEvents); - AgentInvocationResult result = await this.InvokeAgentAsync([.. state.FilteredIncomingMessages, .. state.TurnMessages], context, emitUpdateEvents, cancellationToken) + // If a handoff was invoked by a previous agent, filter out the handoff function call and tool result messages + // before sending to the underlying agent. These are internal workflow mechanics that confuse the target model + // into ignoring the original user question. + // + // This will not filter out tool responses and approval responses that are part of this agent's turn, which is + // the expected behavior since those are part of the agent's reasoning process. + HandoffMessagesFilter handoffMessagesFilter = new(this._options.ToolCallFilteringBehavior); + IEnumerable messagesForAgent = state.IncomingState.RequestedHandoffTargetAgentId is not null + ? handoffMessagesFilter.FilterMessages(incomingMessages) + : incomingMessages; + + List? roleChanges = messagesForAgent.ChangeAssistantToUserForOtherParticipants(this._agent.Name ?? this._agent.Id); + + bool emitUpdateEvents = state.IncomingState!.ShouldEmitStreamingEvents(this._options.EmitAgentResponseUpdateEvents); + AgentInvocationResult result = await this.InvokeAgentAsync(messagesForAgent, context, emitUpdateEvents, cancellationToken) .ConfigureAwait(false); if (this.HasOutstandingRequests && result.IsHandoffRequested) @@ -342,20 +252,40 @@ private ValueTask HandleFunctionResultAsync( roleChanges.ResetUserToAssistantForChangedRoles(); + int newConversationBookmark = state.ConversationBookmark; + await this._sharedStateRef.InvokeWithStateAsync( + (sharedState, ctx, ct) => + { + if (sharedState == null) + { + throw new InvalidOperationException("Handoff Orchestration shared state was not properly initialized."); + } + + if (!skipAddIncoming) + { + sharedState.Conversation.AddMessages(incomingMessages); + } + + newConversationBookmark = sharedState.Conversation.AddMessages(result.Response.Messages); + + return new ValueTask(); + }, + context, + cancellationToken).ConfigureAwait(false); + // We send on the HandoffState even if handoff is not requested because we might be terminating the processing, but this only // happens if we have no outstanding requests. if (!this.HasOutstandingRequests) { - HandoffState outgoingState = state.PrepareHandoff(result, this._agent.Id); + HandoffState outgoingState = new(state.IncomingState.TurnToken, result.HandoffTargetId, this._agent.Id); await context.SendMessageAsync(outgoingState, cancellationToken).ConfigureAwait(false); - // reset the state for the next handoff (return-to-current is modeled as a new handoff turn, as opposed to "HITL", which - // can be a bit confusing.) - return null; + // reset the state for the next handoff, making sure to keep track of the conversation bookmark, and avoid resetting the + // agent session. (return-to-current is modeled as a new handoff turn, as opposed to "HITL", which can be a bit confusing.) + return state with { IncomingState = null, ConversationBookmark = newConversationBookmark }; } - state.TurnMessages.AddRange(result.Response.Messages); return state; } @@ -363,28 +293,36 @@ public override ValueTask HandleAsync(HandoffState message, IWorkflowContext con { return this.InvokeWithStateAsync(InvokeContinueTurnAsync, context, skipCache: false, cancellationToken); - ValueTask InvokeContinueTurnAsync(HandoffAgentHostState state, IWorkflowContext context, CancellationToken cancellationToken) + async ValueTask InvokeContinueTurnAsync(HandoffAgentHostState state, IWorkflowContext context, CancellationToken cancellationToken) { // Check that we are not getting this message while in the middle of a turn - if (state.CurrentTurnState != null) + if (state.IsTakingTurn) { throw new InvalidOperationException("Cannot have multiple simultaneous conversations in Handoff Orchestration."); } - // If a handoff was invoked by a previous agent, filter out the handoff function - // call and tool result messages before sending to the underlying agent. These - // are internal workflow mechanics that confuse the target model into ignoring the - // original user question. - HandoffMessagesFilter handoffMessagesFilter = new(this._options.ToolCallFilteringBehavior); - IEnumerable messagesForAgent = message.RequestedHandoffTargetAgentId is not null - ? handoffMessagesFilter.FilterMessages(message.Messages) - : message.Messages; + IEnumerable newConversationMessages = []; + int newConversationBookmark = 0; + + await this._sharedStateRef.InvokeWithStateAsync( + (sharedState, ctx, ct) => + { + if (sharedState == null) + { + throw new InvalidOperationException("Handoff Orchestration shared state was not properly initialized."); + } + + (newConversationMessages, newConversationBookmark) = sharedState.Conversation.CollectNewMessages(state.ConversationBookmark); + + return new ValueTask(); + }, + context, + cancellationToken).ConfigureAwait(false); - // This works because the runtime guarantees that a given executor instance will process messages serially, - // though there is no global cross-executor ordering guarantee (and in turn, no canonical message delivery order) - state = new(message, messagesForAgent.ToList(), []); + state = state with { IncomingState = message, ConversationBookmark = newConversationBookmark }; - return this.ContinueTurnAsync(state, context, cancellationToken); + return await this.ContinueTurnAsync(state, newConversationMessages.ToList(), context, cancellationToken, skipAddIncoming: true) + .ConfigureAwait(false); } } @@ -395,18 +333,35 @@ protected internal override async ValueTask OnCheckpointingAsync(IWorkflowContex { Task userInputRequestsTask = this._userInputHandler?.OnCheckpointingAsync(UserInputRequestStateKey, context, cancellationToken).AsTask() ?? Task.CompletedTask; Task functionCallRequestsTask = this._functionCallHandler?.OnCheckpointingAsync(FunctionCallRequestStateKey, context, cancellationToken).AsTask() ?? Task.CompletedTask; + Task agentSessionTask = CheckpointAgentSessionAsync(); Task baseTask = base.OnCheckpointingAsync(context, cancellationToken).AsTask(); - await Task.WhenAll(userInputRequestsTask, functionCallRequestsTask, baseTask).ConfigureAwait(false); + await Task.WhenAll(userInputRequestsTask, functionCallRequestsTask, agentSessionTask, baseTask).ConfigureAwait(false); + + async Task CheckpointAgentSessionAsync() + { + JsonElement? sessionState = this._session is not null ? await this._agent.SerializeSessionAsync(this._session, cancellationToken: cancellationToken).ConfigureAwait(false) : null; + await context.QueueStateUpdateAsync(AgentSessionKey, sessionState, cancellationToken: cancellationToken).ConfigureAwait(false); + } } protected internal override async ValueTask OnCheckpointRestoredAsync(IWorkflowContext context, CancellationToken cancellationToken = default) { Task userInputRestoreTask = this._userInputHandler?.OnCheckpointRestoredAsync(UserInputRequestStateKey, context, cancellationToken).AsTask() ?? Task.CompletedTask; Task functionCallRestoreTask = this._functionCallHandler?.OnCheckpointRestoredAsync(FunctionCallRequestStateKey, context, cancellationToken).AsTask() ?? Task.CompletedTask; + Task agentSessionTask = RestoreAgentSessionAsync(); - await Task.WhenAll(userInputRestoreTask, functionCallRestoreTask).ConfigureAwait(false); + await Task.WhenAll(userInputRestoreTask, functionCallRestoreTask, agentSessionTask).ConfigureAwait(false); await base.OnCheckpointRestoredAsync(context, cancellationToken).ConfigureAwait(false); + + async Task RestoreAgentSessionAsync() + { + JsonElement? sessionState = await context.ReadStateAsync(AgentSessionKey, cancellationToken: cancellationToken).ConfigureAwait(false); + if (sessionState.HasValue) + { + this._session = await this._agent.DeserializeSessionAsync(sessionState.Value, cancellationToken: cancellationToken).ConfigureAwait(false); + } + } } private bool HasOutstandingRequests => (this._userInputHandler?.HasPendingRequests == true) || (this._functionCallHandler?.HasPendingRequests == true); @@ -417,31 +372,43 @@ private async ValueTask InvokeAgentAsync(IEnumerable agentStream = this._agent.RunStreamingAsync( - messages, - options: this._agentOptions, - cancellationToken: cancellationToken); - string? requestedHandoff = null; List updates = []; List candidateRequests = []; - await foreach (AgentResponseUpdate update in agentStream.ConfigureAwait(false)) - { - await AddUpdateAsync(update, cancellationToken).ConfigureAwait(false); - - collector.ProcessAgentResponseUpdate(update, CollectHandoffRequestsFilter); - bool CollectHandoffRequestsFilter(FunctionCallContent candidateHandoffRequest) + await this.InvokeWithStateAsync( + async (state, ctx, ct) => { - bool isHandoffRequest = this._handoffFunctionNames.Contains(candidateHandoffRequest.Name); - if (isHandoffRequest) + this._session ??= await this._agent.CreateSessionAsync(ct).ConfigureAwait(false); + + IAsyncEnumerable agentStream = + this._agent.RunStreamingAsync(messages, + this._session, + options: this._agentOptions, + cancellationToken: ct); + + await foreach (AgentResponseUpdate update in agentStream.ConfigureAwait(false)) { - candidateRequests.Add(candidateHandoffRequest); + await AddUpdateAsync(update, ct).ConfigureAwait(false); + + collector.ProcessAgentResponseUpdate(update, CollectHandoffRequestsFilter); + + bool CollectHandoffRequestsFilter(FunctionCallContent candidateHandoffRequest) + { + bool isHandoffRequest = this._handoffFunctionNames.Contains(candidateHandoffRequest.Name); + if (isHandoffRequest) + { + candidateRequests.Add(candidateHandoffRequest); + } + + return !isHandoffRequest; + } } - return !isHandoffRequest; - } - } + return state; + }, + context, + cancellationToken: cancellationToken).ConfigureAwait(false); if (candidateRequests.Count > 1) { diff --git a/dotnet/src/Microsoft.Agents.AI.Workflows/Specialized/HandoffEndExecutor.cs b/dotnet/src/Microsoft.Agents.AI.Workflows/Specialized/HandoffEndExecutor.cs index 0ba8fc3501..c9c75b91c4 100644 --- a/dotnet/src/Microsoft.Agents.AI.Workflows/Specialized/HandoffEndExecutor.cs +++ b/dotnet/src/Microsoft.Agents.AI.Workflows/Specialized/HandoffEndExecutor.cs @@ -1,5 +1,6 @@ // Copyright (c) Microsoft. All rights reserved. +using System; using System.Collections.Generic; using System.Threading; using System.Threading.Tasks; @@ -12,23 +13,33 @@ internal sealed class HandoffEndExecutor(bool returnToPrevious) : Executor(Execu { public const string ExecutorId = "HandoffEnd"; + private readonly StateRef _sharedStateRef = new(HandoffConstants.HandoffSharedStateKey, + HandoffConstants.HandoffSharedStateScope); + protected override ProtocolBuilder ConfigureProtocol(ProtocolBuilder protocolBuilder) => - protocolBuilder.ConfigureRoutes(routeBuilder => routeBuilder.AddHandler((handoff, context, cancellationToken) => - this.HandleAsync(handoff, context, cancellationToken))) + protocolBuilder.ConfigureRoutes(routeBuilder => routeBuilder.AddHandler( + (handoff, context, cancellationToken) => this.HandleAsync(handoff, context, cancellationToken))) .YieldsOutput>(); private async ValueTask HandleAsync(HandoffState handoff, IWorkflowContext context, CancellationToken cancellationToken) { - if (returnToPrevious) - { - await context.QueueStateUpdateAsync(HandoffConstants.PreviousAgentTrackerKey, - handoff.PreviousAgentId, - HandoffConstants.PreviousAgentTrackerScope, - cancellationToken) - .ConfigureAwait(false); - } - - await context.YieldOutputAsync(handoff.Messages, cancellationToken).ConfigureAwait(false); + await this._sharedStateRef.InvokeWithStateAsync( + async (HandoffSharedState? sharedState, IWorkflowContext context, CancellationToken cancellationToken) => + { + if (sharedState == null) + { + throw new InvalidOperationException("Handoff Orchestration shared state was not properly initialized."); + } + + if (returnToPrevious) + { + sharedState.PreviousAgentId = handoff.PreviousAgentId; + } + + await context.YieldOutputAsync(sharedState.Conversation.CloneAllMessages(), cancellationToken).ConfigureAwait(false); + + return sharedState; + }, context, cancellationToken).ConfigureAwait(false); } public ValueTask ResetAsync() => default; diff --git a/dotnet/src/Microsoft.Agents.AI.Workflows/Specialized/HandoffMessagesFilter.cs b/dotnet/src/Microsoft.Agents.AI.Workflows/Specialized/HandoffMessagesFilter.cs new file mode 100644 index 0000000000..7bc178c2d4 --- /dev/null +++ b/dotnet/src/Microsoft.Agents.AI.Workflows/Specialized/HandoffMessagesFilter.cs @@ -0,0 +1,143 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Collections.Generic; +using System.Diagnostics.CodeAnalysis; +using System.Linq; +using Microsoft.Extensions.AI; + +namespace Microsoft.Agents.AI.Workflows.Specialized; + +[Experimental(DiagnosticConstants.ExperimentalFeatureDiagnostic)] +internal sealed class HandoffMessagesFilter +{ + private readonly HandoffToolCallFilteringBehavior _filteringBehavior; + + public HandoffMessagesFilter(HandoffToolCallFilteringBehavior filteringBehavior) + { + this._filteringBehavior = filteringBehavior; + } + + [Experimental(DiagnosticConstants.ExperimentalFeatureDiagnostic)] + internal static bool IsHandoffFunctionName(string name) + { + return name.StartsWith(HandoffWorkflowBuilder.FunctionPrefix, StringComparison.Ordinal); + } + + public IEnumerable FilterMessages(IEnumerable messages) + { + if (this._filteringBehavior == HandoffToolCallFilteringBehavior.None) + { + return messages; + } + + Dictionary filteringCandidates = new(); + List filteredMessages = []; + HashSet messagesToRemove = []; + + bool filterHandoffOnly = this._filteringBehavior == HandoffToolCallFilteringBehavior.HandoffOnly; + foreach (ChatMessage unfilteredMessage in messages) + { + ChatMessage filteredMessage = unfilteredMessage.Clone(); + + // .Clone() is shallow, so we cannot modify the contents of the cloned message in place. + List contents = []; + contents.Capacity = unfilteredMessage.Contents?.Count ?? 0; + filteredMessage.Contents = contents; + + // Because this runs after the role changes from assistant to user for the target agent, we cannot rely on tool calls + // originating only from messages with the Assistant role. Instead, we need to inspect the contents of all non-Tool (result) + // FunctionCallContent. + if (unfilteredMessage.Role != ChatRole.Tool) + { + for (int i = 0; i < unfilteredMessage.Contents!.Count; i++) + { + AIContent content = unfilteredMessage.Contents[i]; + if (content is not FunctionCallContent fcc || (filterHandoffOnly && !IsHandoffFunctionName(fcc.Name))) + { + filteredMessage.Contents.Add(content); + + // Track non-handoff function calls so their tool results are preserved in HandoffOnly mode + if (filterHandoffOnly && content is FunctionCallContent nonHandoffFcc) + { + filteringCandidates[nonHandoffFcc.CallId] = new FilterCandidateState(nonHandoffFcc.CallId) + { + IsHandoffFunction = false, + }; + } + } + else if (filterHandoffOnly) + { + if (!filteringCandidates.TryGetValue(fcc.CallId, out FilterCandidateState? candidateState)) + { + filteringCandidates[fcc.CallId] = new FilterCandidateState(fcc.CallId) + { + IsHandoffFunction = true, + }; + } + else + { + candidateState.IsHandoffFunction = true; + (int messageIndex, int contentIndex) = candidateState.FunctionCallResultLocation!.Value; + ChatMessage messageToFilter = filteredMessages[messageIndex]; + messageToFilter.Contents.RemoveAt(contentIndex); + if (messageToFilter.Contents.Count == 0) + { + messagesToRemove.Add(messageIndex); + } + } + } + else + { + // All mode: strip all FunctionCallContent + } + } + } + else + { + if (!filterHandoffOnly) + { + continue; + } + + for (int i = 0; i < unfilteredMessage.Contents!.Count; i++) + { + AIContent content = unfilteredMessage.Contents[i]; + if (content is not FunctionResultContent frc + || (filteringCandidates.TryGetValue(frc.CallId, out FilterCandidateState? candidateState) + && candidateState.IsHandoffFunction is false)) + { + // Either this is not a function result content, so we should let it through, or it is a FRC that + // we know is not related to a handoff call. In either case, we should include it. + filteredMessage.Contents.Add(content); + } + else if (candidateState is null) + { + // We haven't seen the corresponding function call yet, so add it as a candidate to be filtered later + filteringCandidates[frc.CallId] = new FilterCandidateState(frc.CallId) + { + FunctionCallResultLocation = (filteredMessages.Count, filteredMessage.Contents.Count), + }; + } + // else we have seen the corresponding function call and it is a handoff, so we should filter it out. + } + } + + if (filteredMessage.Contents.Count > 0) + { + filteredMessages.Add(filteredMessage); + } + } + + return filteredMessages.Where((_, index) => !messagesToRemove.Contains(index)); + } + + private class FilterCandidateState(string callId) + { + public (int MessageIndex, int ContentIndex)? FunctionCallResultLocation { get; set; } + + public string CallId => callId; + + public bool? IsHandoffFunction { get; set; } + } +} diff --git a/dotnet/src/Microsoft.Agents.AI.Workflows/Specialized/HandoffStartExecutor.cs b/dotnet/src/Microsoft.Agents.AI.Workflows/Specialized/HandoffStartExecutor.cs index 063f73bb6f..223517c35f 100644 --- a/dotnet/src/Microsoft.Agents.AI.Workflows/Specialized/HandoffStartExecutor.cs +++ b/dotnet/src/Microsoft.Agents.AI.Workflows/Specialized/HandoffStartExecutor.cs @@ -9,8 +9,23 @@ namespace Microsoft.Agents.AI.Workflows.Specialized; internal static class HandoffConstants { + internal const string HandoffOrchestrationSharedScope = "HandoffOrchestration"; + internal const string PreviousAgentTrackerKey = "LastAgentId"; - internal const string PreviousAgentTrackerScope = "HandoffOrchestration"; + internal const string PreviousAgentTrackerScope = HandoffOrchestrationSharedScope; + + internal const string MultiPartyConversationKey = "MultiPartyConversation"; + internal const string MultiPartyConversationScope = HandoffOrchestrationSharedScope; + + internal const string HandoffSharedStateKey = "SharedState"; + internal const string HandoffSharedStateScope = HandoffOrchestrationSharedScope; +} + +internal sealed class HandoffSharedState +{ + public MultiPartyConversation Conversation { get; } = new(); + + public string? PreviousAgentId { get; set; } } /// Executor used at the start of a handoffs workflow to accumulate messages and emit them as HandoffState upon receiving a turn token. @@ -29,23 +44,25 @@ protected override ProtocolBuilder ConfigureProtocol(ProtocolBuilder protocolBui protected override ValueTask TakeTurnAsync(List messages, IWorkflowContext context, bool? emitEvents, CancellationToken cancellationToken = default) { - if (returnToPrevious) - { - return context.InvokeWithStateAsync( - async (string? previousAgentId, IWorkflowContext context, CancellationToken cancellationToken) => - { - HandoffState handoffState = new(new(emitEvents), null, messages, previousAgentId); - await context.SendMessageAsync(handoffState, cancellationToken).ConfigureAwait(false); - - return previousAgentId; - }, - HandoffConstants.PreviousAgentTrackerKey, - HandoffConstants.PreviousAgentTrackerScope, - cancellationToken); - } - - HandoffState handoff = new(new(emitEvents), null, messages); - return context.SendMessageAsync(handoff, cancellationToken); + return context.InvokeWithStateAsync( + async (HandoffSharedState? sharedState, IWorkflowContext context, CancellationToken cancellationToken) => + { + sharedState ??= new HandoffSharedState(); + sharedState.Conversation.AddMessages(messages); + + string? previousAgentId = sharedState.PreviousAgentId; + + // If we are configured to return to the previous agent, include the previous agent id in the handoff state. + // If there was no previousAgent, it will still be null. + HandoffState turnState = new(new(emitEvents), null, returnToPrevious ? previousAgentId : null); + + await context.SendMessageAsync(turnState, cancellationToken).ConfigureAwait(false); + + return sharedState; + }, + HandoffConstants.HandoffSharedStateKey, + HandoffConstants.HandoffSharedStateScope, + cancellationToken); } public new ValueTask ResetAsync() => base.ResetAsync(); diff --git a/dotnet/src/Microsoft.Agents.AI.Workflows/Specialized/HandoffState.cs b/dotnet/src/Microsoft.Agents.AI.Workflows/Specialized/HandoffState.cs index 644bc7df0e..24cf788cb8 100644 --- a/dotnet/src/Microsoft.Agents.AI.Workflows/Specialized/HandoffState.cs +++ b/dotnet/src/Microsoft.Agents.AI.Workflows/Specialized/HandoffState.cs @@ -1,12 +1,8 @@ // Copyright (c) Microsoft. All rights reserved. -using System.Collections.Generic; -using Microsoft.Extensions.AI; - namespace Microsoft.Agents.AI.Workflows.Specialized; internal sealed record class HandoffState( TurnToken TurnToken, string? RequestedHandoffTargetAgentId, - List Messages, string? PreviousAgentId = null); diff --git a/dotnet/src/Microsoft.Agents.AI.Workflows/Specialized/MultiPartyConversation.cs b/dotnet/src/Microsoft.Agents.AI.Workflows/Specialized/MultiPartyConversation.cs new file mode 100644 index 0000000000..4d59184f0d --- /dev/null +++ b/dotnet/src/Microsoft.Agents.AI.Workflows/Specialized/MultiPartyConversation.cs @@ -0,0 +1,56 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Collections.Generic; +using System.Linq; +using Microsoft.Extensions.AI; + +namespace Microsoft.Agents.AI.Workflows.Specialized; + +internal sealed class MultiPartyConversation +{ + private readonly List _history = []; + private readonly object _mutex = new(); + + public List CloneAllMessages() + { + lock (this._mutex) + { + return this._history.ToList(); + } + } + + public (ChatMessage[], int) CollectNewMessages(int bookmark) + { + lock (this._mutex) + { + int count = this._history.Count - bookmark; + if (count < 0) + { + throw new InvalidOperationException($"Bookmark value too large: {bookmark} vs count={count}"); + } + + return (this._history.Skip(bookmark).ToArray(), this.CurrentBookmark); + } + } + + private int CurrentBookmark => this._history.Count; + + public int AddMessages(IEnumerable messages) + { + lock (this._mutex) + { + this._history.AddRange(messages); + return this.CurrentBookmark; + } + } + + public int AddMessage(ChatMessage message) + { + lock (this._mutex) + { + this._history.Add(message); + return this.CurrentBookmark; + } + } +} diff --git a/dotnet/src/Microsoft.Agents.AI.Workflows/WorkflowHostingExtensions.cs b/dotnet/src/Microsoft.Agents.AI.Workflows/WorkflowHostingExtensions.cs index 281d0694ac..e91531513d 100644 --- a/dotnet/src/Microsoft.Agents.AI.Workflows/WorkflowHostingExtensions.cs +++ b/dotnet/src/Microsoft.Agents.AI.Workflows/WorkflowHostingExtensions.cs @@ -41,7 +41,7 @@ internal static FunctionCallContent ToFunctionCall(this ExternalRequest request) { Dictionary parameters = new() { - { "data", request.Data} + { "data", request.Data } }; return new FunctionCallContent(request.RequestId, request.PortInfo.PortId, parameters); diff --git a/dotnet/src/Microsoft.Agents.AI.Workflows/WorkflowSession.cs b/dotnet/src/Microsoft.Agents.AI.Workflows/WorkflowSession.cs index c1f81f0ecf..d015ecdcee 100644 --- a/dotnet/src/Microsoft.Agents.AI.Workflows/WorkflowSession.cs +++ b/dotnet/src/Microsoft.Agents.AI.Workflows/WorkflowSession.cs @@ -247,7 +247,7 @@ private async ValueTask SendMessagesWithResponseConversionAs hasMatchedResponseForStartExecutor |= string.Equals(responseExecutorId, this._workflow.StartExecutorId, StringComparison.Ordinal); } - AIContent normalizedResponseContent = NormalizeResponseContentForDelivery(content, pendingRequest); + object normalizedResponseContent = NormalizeResponseContentForDelivery(content, pendingRequest); externalResponses.Add((pendingRequest.CreateResponse(normalizedResponseContent), pendingRequest.RequestId)); (matchedContentIds ??= new(StringComparer.Ordinal)).Add(contentId); } @@ -303,14 +303,35 @@ ExternalRequest externalRequest /// /// Rewrites workflow-facing response content back to the original agent-owned content ID. /// - private static AIContent NormalizeResponseContentForDelivery(AIContent content, ExternalRequest request) => content switch + private static object NormalizeResponseContentForDelivery(AIContent content, ExternalRequest request) { - FunctionResultContent functionResultContent when request.TryGetDataAs(out FunctionCallContent? functionCallContent) - => CloneFunctionResultContent(functionResultContent, functionCallContent.CallId), - ToolApprovalResponseContent toolApprovalResponseContent when request.TryGetDataAs(out ToolApprovalRequestContent? toolApprovalRequestContent) - => CloneToolApprovalResponseContent(toolApprovalResponseContent, toolApprovalRequestContent.RequestId), - _ => content, - }; + switch (content) + { + // If we got a FRC, and were expecting a FRC (because the request started out as a FCC, rather than getting converted to + // on at the WorkflowSession boundary), clone it and send it in. + case FunctionResultContent functionResultContent when request.TryGetDataAs(out FunctionCallContent? functionCallContent): + return CloneFunctionResultContent(functionResultContent, functionCallContent.CallId); + case FunctionResultContent functionResultContent when !request.PortInfo.ResponseType.IsMatchPolymorphic(typeof(FunctionResultContent)): + { + object? result = functionResultContent.Result; + if (result != null) + { + if (request.PortInfo.ResponseType.IsMatchPolymorphic(result.GetType()) || result is PortableValue) + { + return result; + } + + throw new InvalidOperationException($"Unexpected result type in FunctionResultContent {result.GetType()}; expecting {request.PortInfo.ResponseType}"); + } + + throw new NotSupportedException($"Null result is not supported when using RequestPort with non-AIContent-typed requests. {functionResultContent}"); + } + case ToolApprovalResponseContent toolApprovalResponseContent when request.TryGetDataAs(out ToolApprovalRequestContent? toolApprovalRequestContent): + return CloneToolApprovalResponseContent(toolApprovalResponseContent, toolApprovalRequestContent.RequestId); + default: + return content; + } + } /// /// Gets the workflow-facing request ID from response content types. diff --git a/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/HandoffAgentExecutorTests.cs b/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/HandoffAgentExecutorTests.cs index 1a5b2ea4d1..f4e1252e2a 100644 --- a/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/HandoffAgentExecutorTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/HandoffAgentExecutorTests.cs @@ -1,13 +1,43 @@ // Copyright (c) Microsoft. All rights reserved. +using System; +using System.Collections.Generic; using System.Linq; +using System.Threading; using System.Threading.Tasks; +using FluentAssertions; +using Microsoft.Agents.AI.Workflows.Checkpointing; +using Microsoft.Agents.AI.Workflows.Execution; +using Microsoft.Agents.AI.Workflows.InProc; +using Microsoft.Agents.AI.Workflows.Sample; using Microsoft.Agents.AI.Workflows.Specialized; +using Microsoft.Extensions.AI; namespace Microsoft.Agents.AI.Workflows.UnitTests; public class HandoffAgentExecutorTests : AIAgentHostingExecutorTestsBase { + private static async ValueTask PrepareHandoffSharedStateAsync(TestRunContext? runContext = null, IEnumerable? messages = null) + { + runContext ??= new(); + + HandoffSharedState sharedState = new(); + + if (messages != null) + { + sharedState.Conversation.AddMessages(messages); + } + + await runContext.BindWorkflowContext(nameof(HandoffStartExecutor)) + .QueueStateUpdateAsync(HandoffConstants.HandoffSharedStateKey, + sharedState, + HandoffConstants.HandoffSharedStateScope); + + await runContext.StateManager.PublishUpdatesAsync(null); + + return runContext; + } + [Theory] [InlineData(null, null)] [InlineData(null, true)] @@ -21,7 +51,7 @@ public class HandoffAgentExecutorTests : AIAgentHostingExecutorTestsBase public async Task Test_HandoffAgentExecutor_EmitsStreamingUpdatesIFFConfiguredAsync(bool? executorSetting, bool? turnSetting) { // Arrange - TestRunContext testContext = new(); + TestRunContext testContext = await PrepareHandoffSharedStateAsync(); TestReplayAgent agent = new(TestMessages, TestAgentId, TestAgentName); HandoffAgentExecutorOptions options = new("", @@ -33,7 +63,7 @@ public async Task Test_HandoffAgentExecutor_EmitsStreamingUpdatesIFFConfiguredAs testContext.ConfigureExecutor(executor); // Act - HandoffState message = new(new(turnSetting), null, []); + HandoffState message = new(new(turnSetting), null, null); await executor.HandleAsync(message, testContext.BindWorkflowContext(executor.Id)); // Assert @@ -49,7 +79,7 @@ public async Task Test_HandoffAgentExecutor_EmitsStreamingUpdatesIFFConfiguredAs public async Task Test_HandoffAgentExecutor_EmitsResponseIFFConfiguredAsync(bool executorSetting) { // Arrange - TestRunContext testContext = new(); + TestRunContext testContext = await PrepareHandoffSharedStateAsync(); TestReplayAgent agent = new(TestMessages, TestAgentId, TestAgentName); HandoffAgentExecutorOptions options = new("", @@ -61,11 +91,120 @@ public async Task Test_HandoffAgentExecutor_EmitsResponseIFFConfiguredAsync(bool testContext.ConfigureExecutor(executor); // Act - HandoffState message = new(new(false), null, []); + HandoffState message = new(new(false), null, null); await executor.HandleAsync(message, testContext.BindWorkflowContext(executor.Id)); // Assert AgentResponseEvent[] updates = testContext.Events.OfType().ToArray(); CheckResponseEventsAgainstTestMessages(updates, expectingResponse: executorSetting, agent.GetDescriptiveId()); } + + [Fact] + public async Task Test_HandoffAgentExecutor_ComposesWithHITLSubworkflowAsync() + { + // Arrange + TestRunContext testContext = await PrepareHandoffSharedStateAsync(); + + SendsRequestExecutor challengeSender = new(); + Workflow subworkflow = new WorkflowBuilder(challengeSender) + .AddExternalRequest(challengeSender, "SendChallengeToUser") + .WithOutputFrom(challengeSender) + .Build(); + + InProcessExecutionEnvironment environment = InProcessExecution.Lockstep.WithCheckpointing(CheckpointManager.CreateInMemory()); + AIAgent subworkflowAgent = subworkflow.AsAIAgent(includeWorkflowOutputsInResponse: true, name: "Subworkflow", executionEnvironment: environment); + HandoffAgentExecutorOptions options = new("", + emitAgentResponseEvents: true, + emitAgentResponseUpdateEvents: true, + HandoffToolCallFilteringBehavior.None); + + HandoffAgentExecutor executor = new(subworkflowAgent, [], options); + Workflow fakeWorkflow = new(executor.Id) { ExecutorBindings = { { executor.Id, executor } } }; + EdgeMap map = new(testContext, fakeWorkflow, null); + + testContext.ConfigureExecutor(executor, map); + + // Validate that our test assumptions hold + string functionCallPortId = $"{HandoffAgentExecutor.IdFor(subworkflowAgent)}_FunctionCall"; + map.TryGetResponsePortExecutorId(functionCallPortId, out string? responsePortExecutorId).Should().BeTrue(); + responsePortExecutorId.Should().Be(executor.Id); + + // Act + HandoffState message = new(new(false), null, null); + await executor.HandleAsync(message, testContext.BindWorkflowContext(executor.Id)); + + await testContext.StateManager.PublishUpdatesAsync(null); + + // Assert + testContext.ExternalRequests.Should().HaveCount(1) + .And.ContainSingle(request => request.IsDataOfType()); + + FunctionCallContent functionCallContent = testContext.ExternalRequests.Single().Data.As()!; + object? requestData = functionCallContent.Arguments!["data"]; + + Challenge? challenge = null; + if (requestData is PortableValue pv) + { + challenge = pv.As(); + } + else + { + challenge = requestData as Challenge; + } + + if (challenge is null) + { + Assert.Fail($"Expected request data to be of type {typeof(Challenge).FullName}, but was {requestData?.GetType().FullName ?? "null"}"); + return; // Unreachable, but analysis cannot infer that Debug.Fail will throw/exit, and UnreachableException is not available on net472 + } + + // Act 2 + string challengeResponse = new(challenge.Value.Reverse().ToArray()); + FunctionResultContent responseContent = new(functionCallContent.CallId, new Response(challengeResponse)); + + RequestPortInfo requestPortInfo = new(new(typeof(Challenge)), new(typeof(Response)), functionCallPortId); + string requestId = $"{functionCallPortId.Length}:{functionCallPortId}:{functionCallContent.CallId}"; + DeliveryMapping? mapping = await map.PrepareDeliveryForResponseAsync(new(requestPortInfo, requestId, new(responseContent))); + + mapping!.Deliveries.Should().HaveCount(1); + + MessageDelivery delivery = mapping!.Deliveries.Single(); + + object? result = await executor.ExecuteCoreAsync(delivery.Envelope.Message, + delivery.Envelope.MessageType, + testContext.BindWorkflowContext(executor.Id)); + } +} + +internal sealed record Challenge(string Value); + +internal sealed record Response(string Value); + +[SendsMessage(typeof(Challenge))] +internal sealed partial class SendsRequestExecutor(string? id = null) : ChatProtocolExecutor(id ?? nameof(SendsRequestExecutor), s_chatOptions) +{ + internal const string ChallengeString = "{C7A762AE-7DAA-4D9C-A647-E64E6DBC35AE}"; + private static string ResponseKey { get; } = new(ChallengeString.Reverse().ToArray()); + + private static readonly ChatProtocolExecutorOptions s_chatOptions = new() + { + AutoSendTurnToken = false + }; + + protected override ValueTask TakeTurnAsync(List messages, IWorkflowContext context, bool? emitEvents, CancellationToken cancellationToken = default) + => context.SendMessageAsync(new Challenge(ChallengeString), cancellationToken); + + [MessageHandler] + public async ValueTask HandleChallengeResponseAsync(Response response, IWorkflowContext context, CancellationToken cancellationToken = default) + { + if (response.Value != ResponseKey) + { + throw new InvalidOperationException($"Incorrect response received. Expected '{ResponseKey}' but got '{response.Value}'"); + } + + await context.SendMessageAsync(new ChatMessage(ChatRole.Assistant, "Correct response."), cancellationToken) + .ConfigureAwait(false); + + await context.SendMessageAsync(new TurnToken(false), cancellationToken).ConfigureAwait(false); + } } diff --git a/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/Sample/12_HandOff_HostAsAgent.cs b/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/Sample/12_HandOff_HostAsAgent.cs index 993a6d462b..dc1072aa72 100644 --- a/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/Sample/12_HandOff_HostAsAgent.cs +++ b/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/Sample/12_HandOff_HostAsAgent.cs @@ -73,6 +73,7 @@ public static async ValueTask RunAsync(TextWriter writer, IWorkflowExecutionEnvi foreach (string input in inputs) { AgentResponse response; + ResponseContinuationToken? continuationToken = null; do { diff --git a/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/SampleSmokeTest.cs b/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/SampleSmokeTest.cs index 247499b72e..a290948ae4 100644 --- a/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/SampleSmokeTest.cs +++ b/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/SampleSmokeTest.cs @@ -314,6 +314,38 @@ internal async Task Test_RunSample_Step11Async(ExecutionEnvironment environment) Action CreateValidator(string expected) => actual => actual.Should().Be(expected); } + public class Step12ExpectedOutputCalculator(int agentCount) + { + private readonly int[] _bookmarks = new int[agentCount]; + private readonly List _history = new(); + private readonly HashSet _skipIndices = new(); + + public IEnumerable ExpectedOutputs => + this._history.Where((element, index) => !this._skipIndices.Contains(index)); + + public void ProcessInput(string newInput) + { + this._skipIndices.Add(this._history.Count); + this._history.Add(newInput); + + for (int i = 0; i < agentCount; i++) + { + int agentId = i + 1; + int agentBookmark = this._bookmarks[i]; + int count = this._history.Count - agentBookmark; + + count.Should().BeGreaterThanOrEqualTo(0); + + foreach (string input in this._history.Skip(agentBookmark).ToList()) + { + this._history.Add($"{agentId}:{input}"); + } + + this._bookmarks[i] = this._history.Count; + } + } + } + [Theory] [InlineData(ExecutionEnvironment.InProcess_Lockstep)] [InlineData(ExecutionEnvironment.InProcess_OffThread)] @@ -322,14 +354,10 @@ internal async Task Test_RunSample_Step12Async(ExecutionEnvironment environment) { List inputs = ["1", "2", "3"]; - using StringWriter writer = new(); - await Step12EntryPoint.RunAsync(writer, environment.ToWorkflowExecutionEnvironment(), inputs); - - string[] lines = writer.ToString().Split(['\r', '\n'], StringSplitOptions.RemoveEmptyEntries); - // The expectation is that each agent will echo each input along with every echo from previous agents // E.g.: // (user): 1 + // ----- outputs below // (a1): 1:1 // (a2): 2:1 // (a2): 2:1:1 @@ -340,7 +368,35 @@ internal async Task Test_RunSample_Step12Async(ExecutionEnvironment environment) // (a3): 3:2:1 // (a3): 3:2:1:1 - string[] expected = inputs.SelectMany(input => EchoesForInput(input)).ToArray(); + // If there are multiple inputs (there are), then each successive input adds to the depth of the previous + // ones, so, for example, once we do input = "1", "2": + + // (user): 1 + // (a1): 1:1 <- a1 "last seen" + // (a2): 2:1 + // (a2): 2:1:1 <- a2 "last seen" + // (user): 2 + // ----- outputs below + // (a1): 1:2:1 + // (a1): 1:2:1:1 + // (a1): 1:2 <- from user input, a1 "last seen" + // (a2): 2:2 <- from user input (note that a2 seems like it is seeing these in a different "order" than a1 - but it is not) + // (a2): 2:1:2:1 + // (a2): 2:1:2:1:1 + // (a2): 2:1:2 <- from a1's first echo, a2 "last seen" + + Step12ExpectedOutputCalculator outputGenerator = new(Step12EntryPoint.AgentCount); + foreach (string input in inputs) + { + outputGenerator.ProcessInput(input); + } + + string[] expected = outputGenerator.ExpectedOutputs.ToArray(); + + using StringWriter writer = new(); + await Step12EntryPoint.RunAsync(writer, environment.ToWorkflowExecutionEnvironment(), inputs); + + string[] lines = writer.ToString().Split(['\r', '\n'], StringSplitOptions.RemoveEmptyEntries); Console.Error.WriteLine("Expected lines: "); foreach (string expectedLine in expected) @@ -357,19 +413,6 @@ internal async Task Test_RunSample_Step12Async(ExecutionEnvironment environment) Assert.Collection(lines, expected.Select(CreateValidator).ToArray()); - IEnumerable EchoesForInput(string input) - { - List echoes = [$"{Step12EntryPoint.EchoPrefixForAgent(1)}{input}"]; - for (int i = 2; i <= Step12EntryPoint.AgentCount; i++) - { - string agentPrefix = Step12EntryPoint.EchoPrefixForAgent(i); - List newEchoes = [$"{agentPrefix}{input}", .. echoes.Select(echo => $"{agentPrefix}{echo}")]; - echoes.AddRange(newEchoes); - } - - return echoes; - } - Action CreateValidator(string expected) => actual => actual.Should().Be(expected); } diff --git a/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/TestRunContext.cs b/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/TestRunContext.cs index f94c463d59..be0d62528d 100644 --- a/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/TestRunContext.cs +++ b/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/TestRunContext.cs @@ -27,6 +27,9 @@ public IExternalRequestSink RegisterPort(RequestPort port) internal TestRunContext ConfigureExecutor(Executor executor, EdgeMap? map = null) { + // Ensure that we have run the ProtocolBuilder + _ = executor.Protocol.Describe(); + executor.AttachRequestContext(new TestExternalRequestContext(this, executor.Id, map)); this.Executors.Add(executor.Id, executor); return this; @@ -42,6 +45,7 @@ internal TestRunContext ConfigureExecutors(IEnumerable executors, Edge return this; } + internal StateManager StateManager { get; } = new(); private sealed class BoundContext( string executorId, TestRunContext runnerContext, @@ -70,16 +74,16 @@ public ValueTask RequestHaltAsync() => this.AddEventAsync(new RequestHaltEvent()); public ValueTask QueueClearScopeAsync(string? scopeName = null, CancellationToken cancellationToken = default) - => default; + => runnerContext.StateManager.ClearStateAsync(executorId, scopeName); public ValueTask QueueStateUpdateAsync(string key, T? value, string? scopeName = null, CancellationToken cancellationToken = default) - => default; + => runnerContext.StateManager.WriteStateAsync(new ScopeId(executorId, scopeName), key, value); public ValueTask ReadStateAsync(string key, string? scopeName = null, CancellationToken cancellationToken = default) - => new(default(T?)); + => runnerContext.StateManager.ReadStateAsync(new ScopeId(executorId, scopeName), key); public ValueTask> ReadStateKeysAsync(string? scopeName = null, CancellationToken cancellationToken = default) - => new([]); + => runnerContext.StateManager.ReadKeysAsync(new ScopeId(executorId, scopeName)); public ValueTask SendMessageAsync(object message, string? targetId = null, CancellationToken cancellationToken = default) => runnerContext.SendMessageAsync(executorId, message, targetId, cancellationToken);