diff --git a/src/Infrastructure/BotSharp.Abstraction/Agents/AgentHookBase.cs b/src/Infrastructure/BotSharp.Abstraction/Agents/AgentHookBase.cs index ff3ac9408..f735da9e6 100644 --- a/src/Infrastructure/BotSharp.Abstraction/Agents/AgentHookBase.cs +++ b/src/Infrastructure/BotSharp.Abstraction/Agents/AgentHookBase.cs @@ -24,44 +24,45 @@ public void SetAgent(Agent agent) _agent = agent; } - public virtual bool OnAgentLoading(ref string id) + public virtual Task OnAgentLoading(string id) { - return true; + return Task.FromResult(id); } - public virtual bool OnInstructionLoaded(string template, IDictionary dict) + public virtual Task OnInstructionLoaded(string template, IDictionary dict) { dict["current_date"] = $"{DateTime.Now:MMM dd, yyyy}"; dict["current_time"] = $"{DateTime.Now:hh:mm tt}"; dict["current_weekday"] = $"{DateTime.Now:dddd}"; dict["current_utc_datetime"] = $"{DateTime.UtcNow}"; - return true; + return Task.FromResult(true); } - public virtual bool OnFunctionsLoaded(List functions) + public virtual Task OnFunctionsLoaded(List functions) { _agent.Functions = functions; - return true; + return Task.FromResult(true); } - public virtual bool OnSamplesLoaded(List samples) + public virtual Task OnSamplesLoaded(List samples) { _agent.Samples = samples; - return true; + return Task.FromResult(true); } - public virtual void OnAgentLoaded(Agent agent) + public virtual Task OnAgentLoaded(Agent agent) { + return Task.CompletedTask; } - public virtual void OnAgentUtilityLoaded(Agent agent) + public virtual Task OnAgentUtilityLoaded(Agent agent) { - + return Task.CompletedTask; } - public virtual void OnAgentMcpToolLoaded(Agent agent) + public virtual Task OnAgentMcpToolLoaded(Agent agent) { - + return Task.CompletedTask; } } diff --git a/src/Infrastructure/BotSharp.Abstraction/Agents/IAgentHook.cs b/src/Infrastructure/BotSharp.Abstraction/Agents/IAgentHook.cs index ca6f042a7..9e6fbab7f 100644 --- a/src/Infrastructure/BotSharp.Abstraction/Agents/IAgentHook.cs +++ b/src/Infrastructure/BotSharp.Abstraction/Agents/IAgentHook.cs @@ -10,26 +10,26 @@ public interface IAgentHook : IHookBase /// /// Triggered when agent is loading. - /// Return different agent for redirection purpose. + /// Return different agent id for redirection purpose. /// /// Agent Id - /// - bool OnAgentLoading(ref string id); + /// New agent id if redirection is needed, null otherwise + Task OnAgentLoading(string id); - bool OnInstructionLoaded(string template, IDictionary dict); + Task OnInstructionLoaded(string template, IDictionary dict); - bool OnFunctionsLoaded(List functions); + Task OnFunctionsLoaded(List functions); - bool OnSamplesLoaded(List samples); + Task OnSamplesLoaded(List samples); - void OnAgentUtilityLoaded(Agent agent); + Task OnAgentUtilityLoaded(Agent agent); - void OnAgentMcpToolLoaded(Agent agent); + Task OnAgentMcpToolLoaded(Agent agent); /// /// Triggered when agent is loaded completely. /// /// /// - void OnAgentLoaded(Agent agent); + Task OnAgentLoaded(Agent agent); } diff --git a/src/Infrastructure/BotSharp.Core.A2A/Hooks/A2AAgentHook.cs b/src/Infrastructure/BotSharp.Core.A2A/Hooks/A2AAgentHook.cs index 132453d84..d3d913128 100644 --- a/src/Infrastructure/BotSharp.Core.A2A/Hooks/A2AAgentHook.cs +++ b/src/Infrastructure/BotSharp.Core.A2A/Hooks/A2AAgentHook.cs @@ -23,29 +23,29 @@ public A2AAgentHook(IServiceProvider services, IA2AService a2aService, A2ASettin _a2aSettings = a2aSettings; } - public override bool OnAgentLoading(ref string id) + public override async Task OnAgentLoading(string id) { - var agentId = id; - var remoteConfig = _a2aSettings.Agents?.FirstOrDefault(x => x.Id == agentId); + var remoteConfig = _a2aSettings.Agents?.FirstOrDefault(x => x.Id == id); if (remoteConfig != null) { - return true; + return id; // No redirection needed, continue with current id } - return base.OnAgentLoading(ref id); + return await base.OnAgentLoading(id); } - public override void OnAgentLoaded(Agent agent) + public override async Task OnAgentLoaded(Agent agent) { // Check if this is an A2A remote agent if (agent.Type != AgentType.A2ARemote) { + await base.OnAgentLoaded(agent); return; } var remoteConfig = _a2aSettings.Agents?.FirstOrDefault(x => x.Id == agent.Id); if (remoteConfig != null) { - var agentCard = _a2aService.GetCapabilitiesAsync(remoteConfig.Endpoint).GetAwaiter().GetResult(); + var agentCard = await _a2aService.GetCapabilitiesAsync(remoteConfig.Endpoint); if (agentCard != null) { agent.Name = agentCard.Name; @@ -83,6 +83,6 @@ public override void OnAgentLoaded(Agent agent) }); } } - base.OnAgentLoaded(agent); + await base.OnAgentLoaded(agent); } } diff --git a/src/Infrastructure/BotSharp.Core/Agents/Hooks/BasicAgentHook.cs b/src/Infrastructure/BotSharp.Core/Agents/Hooks/BasicAgentHook.cs index cadeb3b47..332337492 100644 --- a/src/Infrastructure/BotSharp.Core/Agents/Hooks/BasicAgentHook.cs +++ b/src/Infrastructure/BotSharp.Core/Agents/Hooks/BasicAgentHook.cs @@ -11,11 +11,11 @@ public BasicAgentHook(IServiceProvider services, AgentSettings settings) { } - public override void OnAgentUtilityLoaded(Agent agent) + public override Task OnAgentUtilityLoaded(Agent agent) { var conv = _services.GetRequiredService(); var isConvMode = conv.IsConversationMode(); - if (!isConvMode) return; + if (!isConvMode) return Task.CompletedTask; agent.Utilities ??= []; agent.SecondaryFunctions ??= []; @@ -26,6 +26,8 @@ public override void OnAgentUtilityLoaded(Agent agent) agent.SecondaryFunctions = agent.SecondaryFunctions.Concat(functions).DistinctBy(x => x.Name, StringComparer.OrdinalIgnoreCase).ToList(); var contents = templates.Select(x => x.Content); agent.SecondaryInstructions = agent.SecondaryInstructions.Concat(contents).Distinct(StringComparer.OrdinalIgnoreCase).ToList(); + + return Task.CompletedTask; } diff --git a/src/Infrastructure/BotSharp.Core/Agents/Services/AgentService.LoadAgent.cs b/src/Infrastructure/BotSharp.Core/Agents/Services/AgentService.LoadAgent.cs index 031ee9bcf..62647401b 100644 --- a/src/Infrastructure/BotSharp.Core/Agents/Services/AgentService.LoadAgent.cs +++ b/src/Infrastructure/BotSharp.Core/Agents/Services/AgentService.LoadAgent.cs @@ -1,3 +1,4 @@ +using BotSharp.Abstraction.Hooks; using BotSharp.Abstraction.Routing.Models; using System.Collections.Concurrent; @@ -15,7 +16,16 @@ public async Task LoadAgent(string id, bool loadUtility = true) return null; } - HookEmitter.Emit(_services, hook => hook.OnAgentLoading(ref id), id); + var hooks = _services.GetHooks(id); + foreach (var hook in hooks) + { + var newId = await hook.OnAgentLoading(id); + if (!string.IsNullOrEmpty(newId) && newId != id) + { + id = newId; + break; // Only the first hook that redirects takes effect + } + } var agent = await GetAgent(id); if (agent == null) @@ -35,35 +45,35 @@ public async Task LoadAgent(string id, bool loadUtility = true) PopulateState(agent); // After agent is loaded - HookEmitter.Emit(_services, hook => { + await HookEmitter.Emit(_services, async hook => { hook.SetAgent(agent); if (!string.IsNullOrEmpty(agent.Instruction)) { - hook.OnInstructionLoaded(agent.Instruction, agent.TemplateDict); + await hook.OnInstructionLoaded(agent.Instruction, agent.TemplateDict); } if (agent.Functions != null) { - hook.OnFunctionsLoaded(agent.Functions); + await hook.OnFunctionsLoaded(agent.Functions); } if (agent.Samples != null) { - hook.OnSamplesLoaded(agent.Samples); + await hook.OnSamplesLoaded(agent.Samples); } if (loadUtility && !agent.Utilities.IsNullOrEmpty()) { - hook.OnAgentUtilityLoaded(agent); + await hook.OnAgentUtilityLoaded(agent); } if (!agent.McpTools.IsNullOrEmpty()) { - hook.OnAgentMcpToolLoaded(agent); + await hook.OnAgentMcpToolLoaded(agent); } - hook.OnAgentLoaded(agent); + await hook.OnAgentLoaded(agent); }, id); diff --git a/src/Infrastructure/BotSharp.Core/MCP/Hooks/MCPToolAgentHook.cs b/src/Infrastructure/BotSharp.Core/MCP/Hooks/MCPToolAgentHook.cs index b117555fd..ede3dd111 100644 --- a/src/Infrastructure/BotSharp.Core/MCP/Hooks/MCPToolAgentHook.cs +++ b/src/Infrastructure/BotSharp.Core/MCP/Hooks/MCPToolAgentHook.cs @@ -14,7 +14,7 @@ public McpToolAgentHook(IServiceProvider services, AgentSettings settings) { } - public override void OnAgentMcpToolLoaded(Agent agent) + public override async Task OnAgentMcpToolLoaded(Agent agent) { if (agent.Type == AgentType.Routing) { @@ -27,7 +27,7 @@ public override void OnAgentMcpToolLoaded(Agent agent) agent.SecondaryFunctions ??= []; - var functions = GetMcpContent(agent).ConfigureAwait(false).GetAwaiter().GetResult(); + var functions = await GetMcpContent(agent); agent.SecondaryFunctions = agent.SecondaryFunctions.Concat(functions).DistinctBy(x => x.Name, StringComparer.OrdinalIgnoreCase).ToList(); } diff --git a/src/Infrastructure/BotSharp.Core/Routing/Hooks/RoutingAgentHook.cs b/src/Infrastructure/BotSharp.Core/Routing/Hooks/RoutingAgentHook.cs index 4fc978b1c..7522a869e 100644 --- a/src/Infrastructure/BotSharp.Core/Routing/Hooks/RoutingAgentHook.cs +++ b/src/Infrastructure/BotSharp.Core/Routing/Hooks/RoutingAgentHook.cs @@ -13,11 +13,11 @@ public RoutingAgentHook(IServiceProvider services, AgentSettings settings, Routi _routingSetting = routingSetting; } - public override bool OnInstructionLoaded(string template, IDictionary dict) + public override async Task OnInstructionLoaded(string template, IDictionary dict) { if (_agent.Type != AgentType.Routing) { - return base.OnInstructionLoaded(template, dict); + return await base.OnInstructionLoaded(template, dict); } dict["router"] = _agent; @@ -59,10 +59,10 @@ public override bool OnInstructionLoaded(string template, IDictionary functions) + public override async Task OnFunctionsLoaded(List functions) { if (_agent.Type == AgentType.Task) { @@ -73,7 +73,7 @@ public override bool OnFunctionsLoaded(List functions) if (rule != null) { var agentService = _services.GetRequiredService(); - var redirectAgent = agentService.GetAgent(rule.RedirectTo).ConfigureAwait(false).GetAwaiter().GetResult(); + var redirectAgent = await agentService.GetAgent(rule.RedirectTo); var json = JsonSerializer.Serialize(new { @@ -111,6 +111,6 @@ public override bool OnFunctionsLoaded(List functions) } } - return base.OnFunctionsLoaded(functions); + return await base.OnFunctionsLoaded(functions); } } diff --git a/src/Plugins/BotSharp.Plugin.Planner/SqlGeneration/Hooks/SqlPlannerAgentHook.cs b/src/Plugins/BotSharp.Plugin.Planner/SqlGeneration/Hooks/SqlPlannerAgentHook.cs index 2b9c1389b..6a490eeb2 100644 --- a/src/Plugins/BotSharp.Plugin.Planner/SqlGeneration/Hooks/SqlPlannerAgentHook.cs +++ b/src/Plugins/BotSharp.Plugin.Planner/SqlGeneration/Hooks/SqlPlannerAgentHook.cs @@ -1,3 +1,9 @@ +using BotSharp.Abstraction.Agents; +using BotSharp.Abstraction.Agents.Settings; +using BotSharp.Abstraction.Knowledges; +using BotSharp.Abstraction.MLTasks; +using BotSharp.Plugin.Planner.Enums; + namespace BotSharp.Plugin.Planner.SqlGeneration.Hooks; public class SqlPlannerAgentHook : AgentHookBase @@ -9,7 +15,7 @@ public SqlPlannerAgentHook(IServiceProvider services, AgentSettings settings) { } - public override bool OnInstructionLoaded(string template, IDictionary dict) + public override async Task OnInstructionLoaded(string template, IDictionary dict) { var knowledgeHooks = _services.GetServices(); @@ -17,10 +23,10 @@ public override bool OnInstructionLoaded(string template, IDictionary(); foreach (var hook in knowledgeHooks) { - var k = hook.GetGlobalKnowledges(new RoleDialogModel(AgentRole.User, template) + var k = await hook.GetGlobalKnowledges(new RoleDialogModel(AgentRole.User, template) { CurrentAgentId = PlannerAgentId.SqlPlanner - }).ConfigureAwait(false).GetAwaiter().GetResult(); + }); knowledges.AddRange(k); } dict["global_knowledges"] = knowledges; diff --git a/tests/BotSharp.Plugin.PizzaBot/Hooks/CommonAgentHook.cs b/tests/BotSharp.Plugin.PizzaBot/Hooks/CommonAgentHook.cs index bc70387dc..f69d5bdcd 100644 --- a/tests/BotSharp.Plugin.PizzaBot/Hooks/CommonAgentHook.cs +++ b/tests/BotSharp.Plugin.PizzaBot/Hooks/CommonAgentHook.cs @@ -11,11 +11,11 @@ public CommonAgentHook(IServiceProvider services, AgentSettings settings) { } - public override bool OnInstructionLoaded(string template, IDictionary dict) + public override async Task OnInstructionLoaded(string template, IDictionary dict) { dict["current_date"] = DateTime.Now.ToString("MM/dd/yyyy"); dict["current_time"] = DateTime.Now.ToString("hh:mm tt"); dict["current_weekday"] = DateTime.Now.DayOfWeek; - return base.OnInstructionLoaded(template, dict); + return await base.OnInstructionLoaded(template, dict); } } diff --git a/tests/BotSharp.Plugin.PizzaBot/Hooks/PizzaBotAgentHook.cs b/tests/BotSharp.Plugin.PizzaBot/Hooks/PizzaBotAgentHook.cs index cbd26eb86..f0378530e 100644 --- a/tests/BotSharp.Plugin.PizzaBot/Hooks/PizzaBotAgentHook.cs +++ b/tests/BotSharp.Plugin.PizzaBot/Hooks/PizzaBotAgentHook.cs @@ -12,8 +12,8 @@ public PizzaBotAgentHook(IServiceProvider services, AgentSettings settings) { } - public override bool OnInstructionLoaded(string template, IDictionary dict) + public override async Task OnInstructionLoaded(string template, IDictionary dict) { - return base.OnInstructionLoaded(template, dict); + return await base.OnInstructionLoaded(template, dict); } }