diff --git a/control-plane/internal/handlers/coverage_handlers_90_test.go b/control-plane/internal/handlers/coverage_handlers_90_test.go index 943c6c9d2..570829668 100644 --- a/control-plane/internal/handlers/coverage_handlers_90_test.go +++ b/control-plane/internal/handlers/coverage_handlers_90_test.go @@ -240,6 +240,7 @@ func TestExecutionNotesCoverageAdditional(t *testing.T) { req := httptest.NewRequest(http.MethodPost, "/notes", strings.NewReader(`{"message":" kept "}`)) req.Header.Set("Content-Type", "application/json") req.Header.Set("X-Execution-ID", "exec-2") + req.Header.Set("X-Agent-Node-ID", "node-2") rec := httptest.NewRecorder() router.ServeHTTP(rec, req) diff --git a/control-plane/internal/handlers/execution_notes.go b/control-plane/internal/handlers/execution_notes.go index 495263cdb..cc19bce3c 100644 --- a/control-plane/internal/handlers/execution_notes.go +++ b/control-plane/internal/handlers/execution_notes.go @@ -2,12 +2,14 @@ package handlers import ( "context" + "errors" "fmt" "net/http" "strings" "time" "github.com/Agent-Field/agentfield/control-plane/internal/events" + "github.com/Agent-Field/agentfield/control-plane/internal/server/middleware" "github.com/Agent-Field/agentfield/control-plane/pkg/types" "github.com/gin-gonic/gin" @@ -20,6 +22,22 @@ type ExecutionNoteStorage interface { GetExecutionEventBus() *events.ExecutionEventBus } +type executionNoteDIDDocumentLookup interface { + GetDIDDocument(ctx context.Context, did string) (*types.DIDDocumentRecord, error) +} + +type executionNoteAgentDIDLister interface { + ListAgentDIDs(ctx context.Context) ([]*types.AgentDIDInfo, error) +} + +type executionNoteAuthorizationError struct { + message string +} + +func (e *executionNoteAuthorizationError) Error() string { + return e.message +} + // AddNoteRequest represents the request body for adding a note to an execution type AddNoteRequest struct { Message string `json:"message" binding:"required"` @@ -76,12 +94,21 @@ func AddExecutionNoteHandler(storageProvider ExecutionNoteStorage) gin.HandlerFu } // Update the execution with the new note - ctx := context.Background() + ctx := c.Request.Context() + callerAgentID, err := executionNoteCallerAgentID(ctx, c, storageProvider) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("Failed to resolve caller identity: %v", err)}) + return + } + var runID string updated, err := storageProvider.UpdateExecutionRecord(ctx, executionID, func(execution *types.Execution) (*types.Execution, error) { if execution == nil { return nil, fmt.Errorf("execution with ID %s not found", executionID) } + if err := ensureExecutionNoteOwnership(callerAgentID, execution); err != nil { + return nil, err + } // Store run ID for SSE event (run_id is the workflow ID equivalent) runID = execution.RunID @@ -99,6 +126,14 @@ func AddExecutionNoteHandler(storageProvider ExecutionNoteStorage) gin.HandlerFu }) if err != nil { + var authzErr *executionNoteAuthorizationError + if errors.As(err, &authzErr) { + c.JSON(http.StatusForbidden, gin.H{ + "error": "execution_ownership_mismatch", + "message": authzErr.message, + }) + return + } c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("Failed to add note: %v", err)}) return } @@ -130,6 +165,71 @@ func AddExecutionNoteHandler(storageProvider ExecutionNoteStorage) gin.HandlerFu } } +func ensureExecutionNoteOwnership(callerAgentID string, execution *types.Execution) error { + ownerAgentID := strings.TrimSpace(execution.AgentNodeID) + if ownerAgentID == "" { + return &executionNoteAuthorizationError{message: "execution owner is required to add notes"} + } + + if callerAgentID == "" { + return &executionNoteAuthorizationError{message: "caller agent identity is required to add notes to this execution"} + } + if callerAgentID != ownerAgentID { + return &executionNoteAuthorizationError{message: "this execution does not belong to the requesting agent"} + } + + return nil +} + +func executionNoteCallerAgentID(ctx context.Context, c *gin.Context, storageProvider ExecutionNoteStorage) (string, error) { + if callerDID := strings.TrimSpace(middleware.GetVerifiedCallerDID(c)); callerDID != "" { + return resolveExecutionNoteAgentIDByDID(ctx, storageProvider, callerDID) + } + + if callerID, exists := c.Get(string(middleware.CallerAgentIDKey)); exists { + if id, ok := callerID.(string); ok { + if id = strings.TrimSpace(id); id != "" { + return id, nil + } + } + } + if agentID := strings.TrimSpace(c.GetHeader("X-Caller-Agent-ID")); agentID != "" { + return agentID, nil + } + if agentID := strings.TrimSpace(c.GetHeader("X-Agent-Node-ID")); agentID != "" { + return agentID, nil + } + + return "", nil +} + +func resolveExecutionNoteAgentIDByDID(ctx context.Context, storageProvider ExecutionNoteStorage, callerDID string) (string, error) { + if lookup, ok := storageProvider.(executionNoteDIDDocumentLookup); ok { + if record, err := lookup.GetDIDDocument(ctx, callerDID); err == nil && record != nil { + return strings.TrimSpace(record.AgentID), nil + } + } + + lister, ok := storageProvider.(executionNoteAgentDIDLister) + if !ok { + return "", nil + } + agentDIDs, err := lister.ListAgentDIDs(ctx) + if err != nil { + return "", fmt.Errorf("failed to resolve caller DID: %w", err) + } + for _, info := range agentDIDs { + if info == nil { + continue + } + if strings.TrimSpace(info.DID) == callerDID { + return strings.TrimSpace(info.AgentNodeID), nil + } + } + + return "", nil +} + // GetExecutionNotesHandler handles GET /api/v1/executions/:execution_id/notes // Retrieves notes for a specific execution with optional tag filtering func GetExecutionNotesHandler(storageProvider ExecutionNoteStorage) gin.HandlerFunc { diff --git a/control-plane/internal/handlers/execution_notes_test.go b/control-plane/internal/handlers/execution_notes_test.go index b59a0d5f3..c2a016151 100644 --- a/control-plane/internal/handlers/execution_notes_test.go +++ b/control-plane/internal/handlers/execution_notes_test.go @@ -3,28 +3,57 @@ package handlers import ( "context" "encoding/json" + "errors" "net/http" "net/http/httptest" "strings" "testing" "time" + "github.com/Agent-Field/agentfield/control-plane/internal/server/middleware" "github.com/Agent-Field/agentfield/control-plane/pkg/types" "github.com/gin-gonic/gin" "github.com/stretchr/testify/require" ) +type executionNoteDIDAuthStorage struct { + *testExecutionStorage + didDocuments map[string]*types.DIDDocumentRecord + agentDIDs []*types.AgentDIDInfo + didLookupErr error + listErr error +} + +func (s *executionNoteDIDAuthStorage) GetDIDDocument(ctx context.Context, did string) (*types.DIDDocumentRecord, error) { + if s.didLookupErr != nil { + return nil, s.didLookupErr + } + if s.didDocuments == nil { + return nil, nil + } + return s.didDocuments[did], nil +} + +func (s *executionNoteDIDAuthStorage) ListAgentDIDs(ctx context.Context) ([]*types.AgentDIDInfo, error) { + if s.listErr != nil { + return nil, s.listErr + } + return s.agentDIDs, nil +} + func TestAddExecutionNoteHandler_AppendsNoteAndPublishesEvent(t *testing.T) { gin.SetMode(gin.TestMode) executionID := "exec-1" runID := "wf-1" // run_id is the workflow ID equivalent + agentID := "agent-1" storage := newTestExecutionStorage(nil) exec := &types.Execution{ ExecutionID: executionID, RunID: runID, + AgentNodeID: agentID, Notes: []types.ExecutionNote{}, UpdatedAt: time.Now(), } @@ -43,6 +72,7 @@ func TestAddExecutionNoteHandler_AppendsNoteAndPublishesEvent(t *testing.T) { reqBody := `{"message":"This is a note","tags":["debug"]}` req := httptest.NewRequest(http.MethodPost, "/api/v1/executions/note", strings.NewReader(reqBody)) req.Header.Set("Content-Type", "application/json") + req.Header.Set("X-Agent-Node-ID", agentID) resp := httptest.NewRecorder() router.ServeHTTP(resp, req) @@ -72,6 +102,330 @@ func TestAddExecutionNoteHandler_AppendsNoteAndPublishesEvent(t *testing.T) { } } +func TestAddExecutionNoteHandler_RejectsNonOwnerAPIKeyCaller(t *testing.T) { + gin.SetMode(gin.TestMode) + + executionID := "exec-owned-by-b" + storage := newTestExecutionStorage(nil) + require.NoError(t, storage.CreateExecutionRecord(context.Background(), &types.Execution{ + ExecutionID: executionID, + RunID: "run-1", + AgentNodeID: "agent-b", + Notes: []types.ExecutionNote{}, + UpdatedAt: time.Now(), + })) + + router := gin.New() + router.POST("/api/v1/executions/note", AddExecutionNoteHandler(storage)) + + req := httptest.NewRequest(http.MethodPost, "/api/v1/executions/note", strings.NewReader(`{"message":"poisoned note"}`)) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("X-Agent-Node-ID", "agent-a") + req.Header.Set("X-Execution-ID", executionID) + + resp := httptest.NewRecorder() + router.ServeHTTP(resp, req) + + require.Equal(t, http.StatusForbidden, resp.Code) + require.Contains(t, resp.Body.String(), "this execution does not belong to the requesting agent") + + updated, err := storage.GetExecutionRecord(context.Background(), executionID) + require.NoError(t, err) + require.Empty(t, updated.Notes) +} + +func TestAddExecutionNoteHandler_RejectsMissingOwnerOrCaller(t *testing.T) { + gin.SetMode(gin.TestMode) + + tests := []struct { + name string + ownerID string + callerID string + wantStatus int + wantBody string + }{ + { + name: "execution owner missing", + ownerID: "", + callerID: "agent-a", + wantStatus: http.StatusForbidden, + wantBody: "execution owner is required to add notes", + }, + { + name: "caller identity missing", + ownerID: "agent-a", + callerID: "", + wantStatus: http.StatusForbidden, + wantBody: "caller agent identity is required to add notes to this execution", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + executionID := "exec-" + strings.ReplaceAll(tt.name, " ", "-") + storage := newTestExecutionStorage(nil) + require.NoError(t, storage.CreateExecutionRecord(context.Background(), &types.Execution{ + ExecutionID: executionID, + RunID: "run-1", + AgentNodeID: tt.ownerID, + Notes: []types.ExecutionNote{}, + UpdatedAt: time.Now(), + })) + + router := gin.New() + router.POST("/api/v1/executions/note", AddExecutionNoteHandler(storage)) + + req := httptest.NewRequest(http.MethodPost, "/api/v1/executions/note", strings.NewReader(`{"message":"should be rejected"}`)) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("X-Execution-ID", executionID) + if tt.callerID != "" { + req.Header.Set("X-Agent-Node-ID", tt.callerID) + } + + resp := httptest.NewRecorder() + router.ServeHTTP(resp, req) + + require.Equal(t, tt.wantStatus, resp.Code) + require.Contains(t, resp.Body.String(), tt.wantBody) + + updated, err := storage.GetExecutionRecord(context.Background(), executionID) + require.NoError(t, err) + require.Empty(t, updated.Notes) + }) + } +} + +func TestAddExecutionNoteHandler_DIDCallerOwnership(t *testing.T) { + gin.SetMode(gin.TestMode) + + const callerDID = "did:web:example.com:agents:agent-a" + + tests := []struct { + name string + executionOwner string + wantStatus int + wantNotes int + }{ + { + name: "owner write succeeds", + executionOwner: "agent-a", + wantStatus: http.StatusOK, + wantNotes: 1, + }, + { + name: "non owner write forbidden", + executionOwner: "agent-b", + wantStatus: http.StatusForbidden, + wantNotes: 0, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + executionID := "exec-did-" + strings.ReplaceAll(tt.name, " ", "-") + storage := &executionNoteDIDAuthStorage{ + testExecutionStorage: newTestExecutionStorage(nil), + didDocuments: map[string]*types.DIDDocumentRecord{ + callerDID: { + DID: callerDID, + AgentID: "agent-a", + }, + }, + } + require.NoError(t, storage.CreateExecutionRecord(context.Background(), &types.Execution{ + ExecutionID: executionID, + RunID: "run-did", + AgentNodeID: tt.executionOwner, + Notes: []types.ExecutionNote{}, + UpdatedAt: time.Now(), + })) + + router := gin.New() + router.POST("/api/v1/executions/note", func(c *gin.Context) { + c.Set(string(middleware.VerifiedCallerDIDKey), callerDID) + AddExecutionNoteHandler(storage)(c) + }) + + req := httptest.NewRequest(http.MethodPost, "/api/v1/executions/note", strings.NewReader(`{"message":"did note"}`)) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("X-Execution-ID", executionID) + + resp := httptest.NewRecorder() + router.ServeHTTP(resp, req) + + require.Equal(t, tt.wantStatus, resp.Code) + + updated, err := storage.GetExecutionRecord(context.Background(), executionID) + require.NoError(t, err) + require.Len(t, updated.Notes, tt.wantNotes) + if tt.wantStatus == http.StatusForbidden { + require.Contains(t, resp.Body.String(), "this execution does not belong to the requesting agent") + } + }) + } +} + +func TestAddExecutionNoteHandler_DIDResolutionFailure(t *testing.T) { + gin.SetMode(gin.TestMode) + + const callerDID = "did:web:example.com:agents:agent-a" + + tests := []struct { + name string + storage *executionNoteDIDAuthStorage + wantStatus int + wantBody string + }{ + { + name: "DID resolver error returns server error", + storage: &executionNoteDIDAuthStorage{ + testExecutionStorage: newTestExecutionStorage(nil), + listErr: errors.New("DID registry unavailable"), + }, + wantStatus: http.StatusInternalServerError, + wantBody: "Failed to resolve caller identity", + }, + { + name: "unresolved DID fails closed", + storage: &executionNoteDIDAuthStorage{ + testExecutionStorage: newTestExecutionStorage(nil), + agentDIDs: []*types.AgentDIDInfo{ + {DID: "did:web:example.com:agents:other", AgentNodeID: "agent-other"}, + }, + }, + wantStatus: http.StatusForbidden, + wantBody: "caller agent identity is required to add notes to this execution", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + executionID := "exec-" + strings.ReplaceAll(tt.name, " ", "-") + require.NoError(t, tt.storage.CreateExecutionRecord(context.Background(), &types.Execution{ + ExecutionID: executionID, + RunID: "run-did", + AgentNodeID: "agent-a", + Notes: []types.ExecutionNote{}, + UpdatedAt: time.Now(), + })) + + router := gin.New() + router.POST("/api/v1/executions/note", func(c *gin.Context) { + c.Set(string(middleware.VerifiedCallerDIDKey), callerDID) + AddExecutionNoteHandler(tt.storage)(c) + }) + + req := httptest.NewRequest(http.MethodPost, "/api/v1/executions/note", strings.NewReader(`{"message":"did note"}`)) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("X-Execution-ID", executionID) + + resp := httptest.NewRecorder() + router.ServeHTTP(resp, req) + + require.Equal(t, tt.wantStatus, resp.Code) + require.Contains(t, resp.Body.String(), tt.wantBody) + + updated, err := tt.storage.GetExecutionRecord(context.Background(), executionID) + require.NoError(t, err) + require.Empty(t, updated.Notes) + }) + } +} + +func TestExecutionNoteCallerAgentIDResolution(t *testing.T) { + gin.SetMode(gin.TestMode) + + newContext := func() *gin.Context { + resp := httptest.NewRecorder() + c, _ := gin.CreateTestContext(resp) + c.Request = httptest.NewRequest(http.MethodPost, "/api/v1/executions/note", nil) + return c + } + + t.Run("authorization error string", func(t *testing.T) { + err := &executionNoteAuthorizationError{message: "denied"} + require.Equal(t, "denied", err.Error()) + }) + + t.Run("caller context takes precedence", func(t *testing.T) { + c := newContext() + c.Set(string(middleware.CallerAgentIDKey), " agent-from-context ") + + got, err := executionNoteCallerAgentID(context.Background(), c, newTestExecutionStorage(nil)) + + require.NoError(t, err) + require.Equal(t, "agent-from-context", got) + }) + + t.Run("caller header fallback", func(t *testing.T) { + c := newContext() + c.Request.Header.Set("X-Caller-Agent-ID", " agent-from-caller ") + c.Request.Header.Set("X-Agent-Node-ID", "agent-from-node") + + got, err := executionNoteCallerAgentID(context.Background(), c, newTestExecutionStorage(nil)) + + require.NoError(t, err) + require.Equal(t, "agent-from-caller", got) + }) + + t.Run("agent node header fallback", func(t *testing.T) { + c := newContext() + c.Request.Header.Set("X-Agent-Node-ID", " agent-from-node ") + + got, err := executionNoteCallerAgentID(context.Background(), c, newTestExecutionStorage(nil)) + + require.NoError(t, err) + require.Equal(t, "agent-from-node", got) + }) + + t.Run("DID list fallback skips nil entries", func(t *testing.T) { + const callerDID = "did:web:example.com:agents:agent-a" + c := newContext() + c.Set(string(middleware.VerifiedCallerDIDKey), callerDID) + storage := &executionNoteDIDAuthStorage{ + testExecutionStorage: newTestExecutionStorage(nil), + agentDIDs: []*types.AgentDIDInfo{ + nil, + {DID: "did:web:example.com:agents:other", AgentNodeID: "agent-other"}, + {DID: callerDID, AgentNodeID: " agent-a "}, + }, + } + + got, err := executionNoteCallerAgentID(context.Background(), c, storage) + + require.NoError(t, err) + require.Equal(t, "agent-a", got) + }) + + t.Run("DID lookup error falls back to list", func(t *testing.T) { + const callerDID = "did:web:example.com:agents:agent-a" + c := newContext() + c.Set(string(middleware.VerifiedCallerDIDKey), callerDID) + storage := &executionNoteDIDAuthStorage{ + testExecutionStorage: newTestExecutionStorage(nil), + didLookupErr: errors.New("lookup failed"), + agentDIDs: []*types.AgentDIDInfo{ + {DID: callerDID, AgentNodeID: "agent-a"}, + }, + } + + got, err := executionNoteCallerAgentID(context.Background(), c, storage) + + require.NoError(t, err) + require.Equal(t, "agent-a", got) + }) + + t.Run("DID with no resolver returns empty caller", func(t *testing.T) { + c := newContext() + c.Set(string(middleware.VerifiedCallerDIDKey), "did:web:example.com:agents:agent-a") + + got, err := executionNoteCallerAgentID(context.Background(), c, newTestExecutionStorage(nil)) + + require.NoError(t, err) + require.Empty(t, got) + }) +} + func TestGetExecutionNotesHandler_ReturnsFilteredNotes(t *testing.T) { gin.SetMode(gin.TestMode) diff --git a/control-plane/internal/server/middleware/auth.go b/control-plane/internal/server/middleware/auth.go index 6fc5df474..c82e79dc2 100644 --- a/control-plane/internal/server/middleware/auth.go +++ b/control-plane/internal/server/middleware/auth.go @@ -106,9 +106,9 @@ func APIKeyAuth(config AuthConfig) gin.HandlerFunc { "error": "unauthorized", "message": "invalid or missing API key. Provide via X-API-Key header, Authorization: Bearer , or ?api_key= query param", "help": map[string]string{ - "kb": "GET /api/v1/agentic/kb/topics (public, no auth required)", - "guide": "GET /api/v1/agentic/kb/guide?goal= (public)", - "api_discovery": "GET /api/v1/agentic/discover (requires auth)", + "kb": "GET /api/v1/agentic/kb/topics (public, no auth required)", + "guide": "GET /api/v1/agentic/kb/guide?goal= (public)", + "api_discovery": "GET /api/v1/agentic/discover (requires auth)", "agent_discovery": "GET /api/v1/discovery/capabilities (requires auth — lists live agents, reasoners, skills)", }, }) @@ -117,6 +117,11 @@ func APIKeyAuth(config AuthConfig) gin.HandlerFunc { // Set auth level for downstream handlers (used by agentic API for filtering) c.Set("auth_level", "api_key") + if callerAgentID := strings.TrimSpace(c.GetHeader("X-Caller-Agent-ID")); callerAgentID != "" { + c.Set(string(CallerAgentIDKey), callerAgentID) + } else if callerAgentID := strings.TrimSpace(c.GetHeader("X-Agent-Node-ID")); callerAgentID != "" { + c.Set(string(CallerAgentIDKey), callerAgentID) + } c.Next() } } diff --git a/control-plane/internal/server/middleware/auth_test.go b/control-plane/internal/server/middleware/auth_test.go index b0c76249b..0a66e82cf 100644 --- a/control-plane/internal/server/middleware/auth_test.go +++ b/control-plane/internal/server/middleware/auth_test.go @@ -91,6 +91,55 @@ func TestAPIKeyAuth_ValidQueryParam(t *testing.T) { assert.Equal(t, http.StatusOK, w.Code) } +func TestAPIKeyAuth_SetsCallerAgentIDContext(t *testing.T) { + tests := []struct { + name string + headers map[string]string + expected string + }{ + { + name: "caller header takes precedence", + headers: map[string]string{ + "X-Caller-Agent-ID": "agent-from-caller", + "X-Agent-Node-ID": "agent-from-node", + }, + expected: "agent-from-caller", + }, + { + name: "agent node header fallback", + headers: map[string]string{ + "X-Agent-Node-ID": "agent-from-node", + }, + expected: "agent-from-node", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + router := gin.New() + router.Use(APIKeyAuth(AuthConfig{APIKey: "secret-key"})) + router.GET("/api/v1/test", func(c *gin.Context) { + callerID, _ := c.Get(string(CallerAgentIDKey)) + c.JSON(http.StatusOK, gin.H{"caller_agent_id": callerID}) + }) + + req := httptest.NewRequest(http.MethodGet, "/api/v1/test", nil) + req.Header.Set("X-API-Key", "secret-key") + for key, value := range tt.headers { + req.Header.Set(key, value) + } + w := httptest.NewRecorder() + + router.ServeHTTP(w, req) + + require.Equal(t, http.StatusOK, w.Code) + var resp map[string]string + require.NoError(t, json.Unmarshal(w.Body.Bytes(), &resp)) + assert.Equal(t, tt.expected, resp["caller_agent_id"]) + }) + } +} + func TestAPIKeyAuth_InvalidKey(t *testing.T) { router := setupRouter(AuthConfig{APIKey: "secret-key"})