From b6774bf07811b0ad34793c115fa32e65651fe56e Mon Sep 17 00:00:00 2001 From: glmgbj233 <2411434344@qq.com> Date: Thu, 21 May 2026 08:11:11 +0000 Subject: [PATCH 1/2] Harden generated output handling Keep generated output behind the intended trust boundary while preserving the normal safe workflow. Add regression coverage for the unsafe flow and the expected safe behavior. --- sdk/go/agent/agent.go | 30 ++++++++++--- sdk/go/agent/agent_test.go | 91 +++++++++++++++++++++++++++----------- 2 files changed, 90 insertions(+), 31 deletions(-) diff --git a/sdk/go/agent/agent.go b/sdk/go/agent/agent.go index 7f584804c..f7f98f479 100644 --- a/sdk/go/agent/agent.go +++ b/sdk/go/agent/agent.go @@ -2000,13 +2000,15 @@ func (a *Agent) AIWithTools(ctx context.Context, prompt string, config ai.ToolCa }, } + allowedTargets := make(map[string]struct{}, len(tools)) + for _, tool := range tools { + allowedTargets[normalizeToolInvocationTarget(agentToolNameToInvocationTarget(tool.Function.Name))] = struct{}{} + } + callFn := func(ctx context.Context, target string, input map[string]interface{}) (map[string]interface{}, error) { - if strings.Contains(target, ":skill:") { - parts := strings.SplitN(target, ":skill:", 2) - target = parts[0] + "." + parts[1] - } else if strings.Contains(target, ":") { - parts := strings.SplitN(target, ":", 2) - target = parts[0] + "." + parts[1] + target = normalizeToolInvocationTarget(target) + if _, ok := allowedTargets[target]; !ok { + return nil, fmt.Errorf("tool call target %q is not a discovered capability", target) } return a.Call(ctx, target, input) } @@ -2014,6 +2016,22 @@ func (a *Agent) AIWithTools(ctx context.Context, prompt string, config ai.ToolCa return a.aiClient.ExecuteToolCallLoop(ctx, messages, tools, config, callFn) } +func normalizeToolInvocationTarget(target string) string { + if strings.Contains(target, ":skill:") { + parts := strings.SplitN(target, ":skill:", 2) + return parts[0] + "." + parts[1] + } + if strings.Contains(target, ":") { + parts := strings.SplitN(target, ":", 2) + return parts[0] + "." + parts[1] + } + return target +} + +func agentToolNameToInvocationTarget(name string) string { + return strings.ReplaceAll(name, "__", ":") +} + // AIStream makes a streaming AI/LLM call. // Returns channels for streaming chunks and errors. // diff --git a/sdk/go/agent/agent_test.go b/sdk/go/agent/agent_test.go index 412dac41e..79e01dede 100644 --- a/sdk/go/agent/agent_test.go +++ b/sdk/go/agent/agent_test.go @@ -935,6 +935,47 @@ func TestAIWithTools(t *testing.T) { require.Len(t, trace.Calls, 1) assert.Equal(t, "agent-1.lookup", trace.Calls[0].ToolName) }) + + t.Run("rejects model-invented tool targets before outbound execute request", func(t *testing.T) { + var executeCalls atomic.Int32 + var chatRequests atomic.Int32 + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/api/v1/discovery/capabilities": + _, _ = w.Write([]byte(`{"discovered_at":"2025-01-01T00:00:00Z","total_agents":1,"total_reasoners":1,"total_skills":0,"pagination":{"limit":50,"offset":0,"has_more":false},"capabilities":[{"agent_id":"agent-1","reasoners":[{"id":"lookup","invocation_target":"agent-1.lookup","input_schema":{"type":"object"}}],"skills":[]}]}`)) + case "/chat/completions": + if chatRequests.Add(1) == 1 { + _ = json.NewEncoder(w).Encode(ai.Response{Choices: []ai.Choice{{Message: ai.Message{ToolCalls: []ai.ToolCall{{ID: "call-1", Type: "function", Function: ai.ToolCallFunction{Name: "agent-1.admin", Arguments: `{"query":"status"}`}}}}}}}) + return + } + _ = json.NewEncoder(w).Encode(ai.Response{Choices: []ai.Choice{{Message: ai.Message{Content: []ai.ContentPart{{Type: "text", Text: "blocked"}}}}}}) + case "/api/v1/execute/agent-1.lookup": + executeCalls.Add(1) + _, _ = w.Write([]byte(`{"status":"open"}`)) + default: + t.Fatalf("unexpected path %s", r.URL.Path) + } + })) + defer server.Close() + + agent, err := New(Config{ + NodeID: "agent-1", + Version: "1.0.0", + AgentFieldURL: server.URL, + Logger: log.New(io.Discard, "", 0), + AIConfig: &ai.Config{APIKey: "test-key", BaseURL: server.URL, Model: "gpt-4o"}, + }) + require.NoError(t, err) + + resp, trace, err := agent.AIWithTools(context.Background(), "hello", ai.ToolCallConfig{MaxTurns: 1, MaxToolCalls: 1, PromptConfig: &ai.PromptConfig{}}) + require.NoError(t, err) + require.NotNil(t, resp) + require.NotNil(t, trace) + require.Len(t, trace.Calls, 1) + assert.Contains(t, trace.Calls[0].Error, "not a discovered capability") + assert.Equal(t, "blocked", resp.Text()) + assert.Equal(t, int32(0), executeCalls.Load()) + }) } func TestRunAndServe_ShutdownOnContextCancel(t *testing.T) { @@ -1446,28 +1487,28 @@ func TestCallLocalUnknownReasoner(t *testing.T) { } func TestCall_TargetPrefixing(t *testing.T) { - var capturedPath string - - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - capturedPath = r.URL.Path - - resp := map[string]any{ - "status": "succeeded", - "result": map[string]any{"ok": true}, - } - _ = json.NewEncoder(w).Encode(resp) - })) - defer server.Close() - - agent, _ := New(Config{ - NodeID: "node-1", - Version: "1.0.0", - AgentFieldURL: server.URL, - Logger: log.New(io.Discard, "", 0), - }) - - _, err := agent.Call(context.Background(), "lookup", nil) - require.NoError(t, err) - - assert.Contains(t, capturedPath, "/execute/node-1.lookup") -} \ No newline at end of file + var capturedPath string + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + capturedPath = r.URL.Path + + resp := map[string]any{ + "status": "succeeded", + "result": map[string]any{"ok": true}, + } + _ = json.NewEncoder(w).Encode(resp) + })) + defer server.Close() + + agent, _ := New(Config{ + NodeID: "node-1", + Version: "1.0.0", + AgentFieldURL: server.URL, + Logger: log.New(io.Discard, "", 0), + }) + + _, err := agent.Call(context.Background(), "lookup", nil) + require.NoError(t, err) + + assert.Contains(t, capturedPath, "/execute/node-1.lookup") +} From c47d33b1eb29834700771100ce509a94fff38602 Mon Sep 17 00:00:00 2001 From: glmgbj233 <2411434344@qq.com> Date: Mon, 25 May 2026 02:34:02 +0000 Subject: [PATCH 2/2] Add coverage for generated tool target handling --- sdk/go/agent/agent_test.go | 109 +++++++++++++++++++++++++++++++++++++ 1 file changed, 109 insertions(+) diff --git a/sdk/go/agent/agent_test.go b/sdk/go/agent/agent_test.go index 79e01dede..2090b5226 100644 --- a/sdk/go/agent/agent_test.go +++ b/sdk/go/agent/agent_test.go @@ -4,6 +4,7 @@ import ( "bytes" "context" "encoding/json" + "errors" "io" "log" "net/http" @@ -978,6 +979,36 @@ func TestAIWithTools(t *testing.T) { }) } +func TestNormalizeToolInvocationTarget(t *testing.T) { + tests := []struct { + name string + target string + want string + }{ + { + name: "skill target", + target: "agent-1:skill:lookup", + want: "agent-1.lookup", + }, + { + name: "reasoner target", + target: "agent-1:lookup", + want: "agent-1.lookup", + }, + { + name: "already normalized", + target: "agent-1.lookup", + want: "agent-1.lookup", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert.Equal(t, tt.want, normalizeToolInvocationTarget(tt.target)) + }) + } +} + func TestRunAndServe_ShutdownOnContextCancel(t *testing.T) { var shutdownCalls atomic.Int32 server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { @@ -1471,6 +1502,84 @@ func TestCallLocalEmitsStructuredExecutionLogs(t *testing.T) { } } +func TestCallLocalCoversCompletionAndFailureBranches(t *testing.T) { + tests := []struct { + name string + handler func(context.Context, map[string]any) (any, error) + wantEvent string + wantLevel string + wantErr string + }{ + { + name: "success", + handler: func(_ context.Context, input map[string]any) (any, error) { + return map[string]any{"echo": input["msg"]}, nil + }, + wantEvent: "call.local.complete", + wantLevel: "info", + }, + { + name: "failure", + handler: func(context.Context, map[string]any) (any, error) { + return nil, errors.New("handler failed") + }, + wantEvent: "call.local.failed", + wantLevel: "error", + wantErr: "handler failed", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ag, err := New(Config{ + NodeID: "node-1", + Version: "1.0.0", + Logger: log.New(io.Discard, "", 0), + }) + require.NoError(t, err) + ag.RegisterReasoner("child", tt.handler) + + parentCtx := contextWithExecution(context.Background(), ExecutionContext{ + RunID: "run-1", + ExecutionID: "exec-parent", + WorkflowID: "wf-1", + RootWorkflowID: "wf-1", + ReasonerName: "parent", + AgentNodeID: "node-1", + }) + + stdout, _, callErr := captureOutput(t, func() error { + _, err := ag.CallLocal(parentCtx, "child", map[string]any{"msg": "hi"}) + return err + }) + + if tt.wantErr == "" { + require.NoError(t, callErr) + } else { + require.Error(t, callErr) + assert.Contains(t, callErr.Error(), tt.wantErr) + } + + lines := strings.Split(strings.TrimSpace(stdout), "\n") + require.GreaterOrEqual(t, len(lines), 2) + + var seen bool + for _, line := range lines { + var entry ExecutionLogEntry + require.NoError(t, json.Unmarshal([]byte(line), &entry)) + if entry.EventType != tt.wantEvent { + continue + } + seen = true + assert.Equal(t, "child", entry.ReasonerID) + assert.Equal(t, tt.wantLevel, entry.Level) + assert.Equal(t, "sdk.runtime", entry.Source) + } + assert.True(t, seen, "expected %s log entry", tt.wantEvent) + }) + } +} + func TestCallLocalUnknownReasoner(t *testing.T) { cfg := Config{ NodeID: "node-1",