From 373cce13a490deb99efc6ca85093732ec5aa1cd6 Mon Sep 17 00:00:00 2001 From: CoderMungan Date: Tue, 31 Mar 2026 00:25:20 +0300 Subject: [PATCH] feat(mcp): add input sanitization and test coverage MCP-SAN (#49): Input sanitization for the MCP server layer. - Add sanitize package: Content (Markdown structure injection), Reflect (truncate + strip control chars for error messages), SessionID (path-safe session identifiers), StripControl, Truncate - Sanitize all reflected user inputs in dispatch error messages (tool names, prompt names, resource URIs) via sanitize.Reflect - Reject unknown entry types before writing to .context/ files - Enforce MaxContentLen (32,000 bytes) on entry content in extract.EntryArgs - Sanitize entry content and optional fields via sanitize.Content and extract.SanitizedOpts before writing - Cap journal source limit to MaxSourceLimit (100) - Sanitize caller identifiers in session events - Add input length constants to config/mcp/cfg - Add error message keys for input-too-long and unknown-entry-type MCP-COV (#50): Comprehensive test coverage for MCP subsystem. 
- internal/mcp/proto: 22 schema round-trip and edge-case tests - internal/mcp/session: 7 state lifecycle tests (100% coverage) - internal/mcp/server: 4 integration tests (Serve edge cases, prompt add-learning) - internal/mcp/server/def/tool: 9 tool definition tests - internal/mcp/server/def/prompt: 9 prompt definition tests - internal/mcp/server/extract: 7 extraction and sanitization tests - internal/mcp/server/io: 3 WriteJSON tests (100% coverage) - internal/mcp/server/out: 8 response builder tests (100% coverage) - internal/mcp/server/parse: 3 request parsing tests (100% coverage) - internal/mcp/server/stat: 2 statistics tests (100% coverage) - internal/sanitize: 22 sanitization tests (Content, Reflect, SessionID, StripControl, Truncate + existing Filename) - Server package coverage: 73% -> 92% Closes #49 Closes #50 Signed-off-by: CoderMungan --- internal/assets/commands/text/mcp.yaml | 4 + internal/config/embed/text/mcp_err.go | 2 + internal/config/mcp/cfg/config.go | 15 + internal/mcp/proto/schema_test.go | 345 ++++++++++++++++++ internal/mcp/server/def/prompt/prompt_test.go | 150 ++++++++ internal/mcp/server/def/tool/tool_test.go | 136 +++++++ internal/mcp/server/extract/extract.go | 37 +- internal/mcp/server/extract/extract_test.go | 113 ++++++ internal/mcp/server/io/write_test.go | 48 +++ internal/mcp/server/out/out_test.go | 113 ++++++ internal/mcp/server/parse/parse_test.go | 56 +++ internal/mcp/server/resource/dispatch.go | 4 +- internal/mcp/server/route/prompt/dispatch.go | 4 +- internal/mcp/server/route/tool/dispatch.go | 4 +- internal/mcp/server/route/tool/tool.go | 42 ++- internal/mcp/server/server_test.go | 107 ++++++ internal/mcp/server/stat/stat_test.go | 22 ++ internal/mcp/session/state_test.go | 145 ++++++++ internal/sanitize/content.go | 78 ++++ internal/sanitize/doc.go | 5 +- internal/sanitize/path.go | 51 +++ internal/sanitize/reflect.go | 44 +++ internal/sanitize/sanitize_test.go | 172 +++++++++ 23 files changed, 1690 insertions(+), 7 
deletions(-) create mode 100644 internal/mcp/proto/schema_test.go create mode 100644 internal/mcp/server/def/prompt/prompt_test.go create mode 100644 internal/mcp/server/def/tool/tool_test.go create mode 100644 internal/mcp/server/extract/extract_test.go create mode 100644 internal/mcp/server/io/write_test.go create mode 100644 internal/mcp/server/out/out_test.go create mode 100644 internal/mcp/server/parse/parse_test.go create mode 100644 internal/mcp/server/stat/stat_test.go create mode 100644 internal/mcp/session/state_test.go create mode 100644 internal/sanitize/content.go create mode 100644 internal/sanitize/path.go create mode 100644 internal/sanitize/reflect.go create mode 100644 internal/sanitize/sanitize_test.go diff --git a/internal/assets/commands/text/mcp.yaml b/internal/assets/commands/text/mcp.yaml index 980d3e9af..51f950457 100644 --- a/internal/assets/commands/text/mcp.yaml +++ b/internal/assets/commands/text/mcp.yaml @@ -346,6 +346,10 @@ mcp.err-unknown-prompt: short: 'unknown prompt: %s' mcp.err-uri-required: short: uri is required +mcp.err-input-too-long: + short: '%s exceeds maximum length (%d bytes)' +mcp.err-unknown-entry-type: + short: 'unknown entry type: %s' mcp.format-watch-completed: short: 'Completed: %s' mcp.format-wrote: diff --git a/internal/config/embed/text/mcp_err.go b/internal/config/embed/text/mcp_err.go index 704bb0925..f985a7950 100644 --- a/internal/config/embed/text/mcp_err.go +++ b/internal/config/embed/text/mcp_err.go @@ -18,4 +18,6 @@ const ( DescKeyMCPErrQueryRequired = "mcp.err-query-required" DescKeyMCPErrUnknownPrompt = "mcp.err-unknown-prompt" DescKeyMCPErrURIRequired = "mcp.err-uri-required" + DescKeyMCPErrInputTooLong = "mcp.err-input-too-long" + DescKeyMCPErrUnknownEntryType = "mcp.err-unknown-entry-type" ) diff --git a/internal/config/mcp/cfg/config.go b/internal/config/mcp/cfg/config.go index 43e21ff52..23e600b00 100644 --- a/internal/config/mcp/cfg/config.go +++ b/internal/config/mcp/cfg/config.go @@ -12,8 
+12,23 @@ const ( // DefaultSourceLimit is the max sessions returned by ctx_journal_source. DefaultSourceLimit = 5 + // MaxSourceLimit caps the source limit to prevent unbounded queries. + MaxSourceLimit = 100 // MinWordLen is the shortest word considered for overlap matching. MinWordLen = 4 // MinWordOverlap is the minimum word matches to signal task completion. MinWordOverlap = 2 + + // --- Input length limits (MCP-SAN.1) --- + + // MaxContentLen is the maximum byte length for entry content fields. + MaxContentLen = 32_000 + // MaxNameLen is the maximum byte length for tool/prompt/resource names. + MaxNameLen = 256 + // MaxQueryLen is the maximum byte length for search queries. + MaxQueryLen = 1_000 + // MaxCallerLen is the maximum byte length for caller identifiers. + MaxCallerLen = 128 + // MaxURILen is the maximum byte length for resource URIs. + MaxURILen = 512 ) diff --git a/internal/mcp/proto/schema_test.go b/internal/mcp/proto/schema_test.go new file mode 100644 index 000000000..228bda61e --- /dev/null +++ b/internal/mcp/proto/schema_test.go @@ -0,0 +1,345 @@ +// / ctx: https://ctx.ist +// ,'`./ do you remember? +// `.,'\ +// \ Copyright 2026-present Context contributors. 
+// SPDX-License-Identifier: Apache-2.0 + +package proto + +import ( + "encoding/json" + "testing" +) + +func roundTrip(t *testing.T, v interface{}, dst interface{}) { + t.Helper() + data, err := json.Marshal(v) + if err != nil { + t.Fatalf("marshal: %v", err) + } + if err := json.Unmarshal(data, dst); err != nil { + t.Fatalf("unmarshal: %v", err) + } +} + +func TestRequestRoundTrip(t *testing.T) { + orig := Request{ + JSONRPC: "2.0", + ID: json.RawMessage(`1`), + Method: "tools/call", + Params: json.RawMessage(`{"name":"ctx_status"}`), + } + var got Request + roundTrip(t, orig, &got) + if got.JSONRPC != orig.JSONRPC { + t.Errorf("JSONRPC = %q, want %q", got.JSONRPC, orig.JSONRPC) + } + if got.Method != orig.Method { + t.Errorf("Method = %q, want %q", got.Method, orig.Method) + } + if string(got.ID) != string(orig.ID) { + t.Errorf("ID = %s, want %s", got.ID, orig.ID) + } +} + +func TestResponseSuccessRoundTrip(t *testing.T) { + orig := Response{ + JSONRPC: "2.0", + ID: json.RawMessage(`1`), + Result: map[string]string{"key": "value"}, + } + var got Response + roundTrip(t, orig, &got) + if got.JSONRPC != "2.0" { + t.Errorf("JSONRPC = %q, want %q", got.JSONRPC, "2.0") + } + if got.Error != nil { + t.Errorf("unexpected error: %v", got.Error) + } +} + +func TestResponseErrorRoundTrip(t *testing.T) { + orig := Response{ + JSONRPC: "2.0", + ID: json.RawMessage(`1`), + Error: &RPCError{ + Code: ErrCodeNotFound, + Message: "method not found", + }, + } + var got Response + roundTrip(t, orig, &got) + if got.Error == nil { + t.Fatal("expected error in response") + } + if got.Error.Code != ErrCodeNotFound { + t.Errorf("Code = %d, want %d", got.Error.Code, ErrCodeNotFound) + } + if got.Error.Message != "method not found" { + t.Errorf("Message = %q, want %q", got.Error.Message, "method not found") + } +} + +func TestNotificationRoundTrip(t *testing.T) { + orig := Notification{ + JSONRPC: "2.0", + Method: "notifications/initialized", + } + var got Notification + roundTrip(t, 
orig, &got) + if got.Method != "notifications/initialized" { + t.Errorf("Method = %q, want %q", got.Method, "notifications/initialized") + } +} + +func TestRPCErrorWithData(t *testing.T) { + orig := RPCError{ + Code: ErrCodeInvalidArg, + Message: "invalid", + Data: map[string]string{"field": "name"}, + } + var got RPCError + roundTrip(t, orig, &got) + if got.Code != ErrCodeInvalidArg { + t.Errorf("Code = %d, want %d", got.Code, ErrCodeInvalidArg) + } +} + +func TestInitializeParamsRoundTrip(t *testing.T) { + orig := InitializeParams{ + ProtocolVersion: ProtocolVersion, + ClientInfo: AppInfo{Name: "test-client", Version: "1.0.0"}, + } + var got InitializeParams + roundTrip(t, orig, &got) + if got.ProtocolVersion != ProtocolVersion { + t.Errorf("ProtocolVersion = %q, want %q", got.ProtocolVersion, ProtocolVersion) + } + if got.ClientInfo.Name != "test-client" { + t.Errorf("ClientInfo.Name = %q, want %q", got.ClientInfo.Name, "test-client") + } +} + +func TestInitializeResultRoundTrip(t *testing.T) { + orig := InitializeResult{ + ProtocolVersion: ProtocolVersion, + Capabilities: ServerCaps{ + Resources: &ResourcesCap{Subscribe: true, ListChanged: true}, + Tools: &ToolsCap{ListChanged: true}, + Prompts: &PromptsCap{ListChanged: false}, + }, + ServerInfo: AppInfo{Name: "ctx", Version: "0.3.0"}, + } + var got InitializeResult + roundTrip(t, orig, &got) + if got.Capabilities.Resources == nil { + t.Fatal("expected Resources capability") + } + if !got.Capabilities.Resources.Subscribe { + t.Error("expected Subscribe = true") + } +} + +func TestResourceRoundTrip(t *testing.T) { + orig := Resource{ + URI: "ctx://context/tasks", + Name: "tasks", + MimeType: "text/markdown", + } + var got Resource + roundTrip(t, orig, &got) + if got.URI != orig.URI { + t.Errorf("URI = %q, want %q", got.URI, orig.URI) + } +} + +func TestToolRoundTrip(t *testing.T) { + orig := Tool{ + Name: "ctx_status", + InputSchema: InputSchema{ + Type: "object", + Properties: map[string]Property{ + "verbose": 
{Type: "boolean", Description: "Verbose"}, + }, + Required: []string{"verbose"}, + }, + Annotations: &ToolAnnotations{ReadOnlyHint: true}, + } + var got Tool + roundTrip(t, orig, &got) + if got.Name != "ctx_status" { + t.Errorf("Name = %q, want %q", got.Name, "ctx_status") + } + if got.Annotations == nil || !got.Annotations.ReadOnlyHint { + t.Error("expected ReadOnlyHint = true") + } +} + +func TestCallToolParamsRoundTrip(t *testing.T) { + orig := CallToolParams{ + Name: "ctx_add", + Arguments: map[string]interface{}{"type": "task", "content": "Test"}, + } + var got CallToolParams + roundTrip(t, orig, &got) + if got.Name != "ctx_add" { + t.Errorf("Name = %q, want %q", got.Name, "ctx_add") + } +} + +func TestCallToolResultRoundTrip(t *testing.T) { + orig := CallToolResult{ + Content: []ToolContent{{Type: "text", Text: "Done"}}, + } + var got CallToolResult + roundTrip(t, orig, &got) + if len(got.Content) != 1 { + t.Fatalf("Content count = %d, want 1", len(got.Content)) + } + if got.Content[0].Text != "Done" { + t.Errorf("Text = %q, want %q", got.Content[0].Text, "Done") + } + if got.IsError { + t.Error("expected IsError = false") + } +} + +func TestCallToolResultErrorRoundTrip(t *testing.T) { + orig := CallToolResult{ + Content: []ToolContent{{Type: "text", Text: "failed"}}, + IsError: true, + } + var got CallToolResult + roundTrip(t, orig, &got) + if !got.IsError { + t.Error("expected IsError = true") + } +} + +func TestPromptRoundTrip(t *testing.T) { + orig := Prompt{ + Name: "ctx-session-start", + Arguments: []PromptArgument{ + {Name: "content", Required: true}, + }, + } + var got Prompt + roundTrip(t, orig, &got) + if got.Name != "ctx-session-start" { + t.Errorf("Name = %q, want %q", got.Name, "ctx-session-start") + } + if len(got.Arguments) != 1 || !got.Arguments[0].Required { + t.Error("expected 1 required argument") + } +} + +func TestGetPromptResultRoundTrip(t *testing.T) { + orig := GetPromptResult{ + Description: "Test", + Messages: []PromptMessage{ + 
{Role: "user", Content: ToolContent{Type: "text", Text: "Hi"}}, + }, + } + var got GetPromptResult + roundTrip(t, orig, &got) + if len(got.Messages) != 1 { + t.Fatalf("Messages count = %d, want 1", len(got.Messages)) + } + if got.Messages[0].Role != "user" { + t.Errorf("Role = %q, want %q", got.Messages[0].Role, "user") + } +} + +func TestSubscribeParamsRoundTrip(t *testing.T) { + orig := SubscribeParams{URI: "ctx://context/tasks"} + var got SubscribeParams + roundTrip(t, orig, &got) + if got.URI != orig.URI { + t.Errorf("URI = %q, want %q", got.URI, orig.URI) + } +} + +func TestUnsubscribeParamsRoundTrip(t *testing.T) { + orig := UnsubscribeParams{URI: "ctx://context/decisions"} + var got UnsubscribeParams + roundTrip(t, orig, &got) + if got.URI != orig.URI { + t.Errorf("URI = %q, want %q", got.URI, orig.URI) + } +} + +func TestResourceUpdatedParamsRoundTrip(t *testing.T) { + orig := ResourceUpdatedParams{URI: "ctx://context/tasks"} + var got ResourceUpdatedParams + roundTrip(t, orig, &got) + if got.URI != orig.URI { + t.Errorf("URI = %q, want %q", got.URI, orig.URI) + } +} + +func TestErrorCodeConstants(t *testing.T) { + if ErrCodeParse != -32700 { + t.Errorf("ErrCodeParse = %d, want -32700", ErrCodeParse) + } + if ErrCodeNotFound != -32601 { + t.Errorf("ErrCodeNotFound = %d, want -32601", ErrCodeNotFound) + } + if ErrCodeInvalidArg != -32602 { + t.Errorf("ErrCodeInvalidArg = %d, want -32602", ErrCodeInvalidArg) + } + if ErrCodeInternal != -32603 { + t.Errorf("ErrCodeInternal = %d, want -32603", ErrCodeInternal) + } +} + +func TestProtocolVersionValue(t *testing.T) { + if ProtocolVersion != "2024-11-05" { + t.Errorf("ProtocolVersion = %q, want %q", ProtocolVersion, "2024-11-05") + } +} + +func TestRequestNilParams(t *testing.T) { + orig := Request{ + JSONRPC: "2.0", + ID: json.RawMessage(`"abc"`), + Method: "ping", + } + data, err := json.Marshal(orig) + if err != nil { + t.Fatalf("marshal: %v", err) + } + var got Request + if err := json.Unmarshal(data, &got); 
err != nil { + t.Fatalf("unmarshal: %v", err) + } + if got.Params != nil { + t.Errorf("expected nil Params, got %s", got.Params) + } +} + +func TestResponseNilID(t *testing.T) { + orig := Response{ + JSONRPC: "2.0", + Error: &RPCError{Code: ErrCodeParse, Message: "parse error"}, + } + var got Response + roundTrip(t, orig, &got) + if got.ID != nil { + t.Errorf("expected nil ID, got %s", got.ID) + } +} + +func TestPropertyEnumRoundTrip(t *testing.T) { + orig := Property{ + Type: "string", + Enum: []string{"task", "decision", "learning"}, + } + var got Property + roundTrip(t, orig, &got) + if len(got.Enum) != 3 { + t.Fatalf("Enum count = %d, want 3", len(got.Enum)) + } + if got.Enum[0] != "task" { + t.Errorf("Enum[0] = %q, want %q", got.Enum[0], "task") + } +} diff --git a/internal/mcp/server/def/prompt/prompt_test.go b/internal/mcp/server/def/prompt/prompt_test.go new file mode 100644 index 000000000..a3cebdee8 --- /dev/null +++ b/internal/mcp/server/def/prompt/prompt_test.go @@ -0,0 +1,150 @@ +// / ctx: https://ctx.ist +// ,'`./ do you remember? +// `.,'\ +// \ Copyright 2026-present Context contributors. 
+// SPDX-License-Identifier: Apache-2.0 + +package prompt + +import ( + "testing" + + cfgPrompt "github.com/ActiveMemory/ctx/internal/config/mcp/prompt" +) + +func TestDefsCount(t *testing.T) { + if len(Defs) != 5 { + t.Errorf("prompt count = %d, want 5", len(Defs)) + } +} + +func TestDefsNoDuplicateNames(t *testing.T) { + seen := make(map[string]bool) + for _, d := range Defs { + if seen[d.Name] { + t.Errorf("duplicate prompt name: %s", d.Name) + } + seen[d.Name] = true + } +} + +func TestDefsAllNamed(t *testing.T) { + for i, d := range Defs { + if d.Name == "" { + t.Errorf("prompt[%d] has empty name", i) + } + } +} + +func TestDefsContainsAllConfigPrompts(t *testing.T) { + want := []string{ + cfgPrompt.SessionStart, + cfgPrompt.AddDecision, + cfgPrompt.AddLearning, + cfgPrompt.Reflect, + cfgPrompt.Checkpoint, + } + names := make(map[string]bool) + for _, d := range Defs { + names[d.Name] = true + } + for _, w := range want { + if !names[w] { + t.Errorf("missing prompt: %s", w) + } + } +} + +func TestDefsAddDecisionArgs(t *testing.T) { + for _, d := range Defs { + if d.Name != cfgPrompt.AddDecision { + continue + } + if len(d.Arguments) != 4 { + t.Errorf( + "add-decision argument count = %d, want 4", + len(d.Arguments), + ) + } + for _, a := range d.Arguments { + if !a.Required { + t.Errorf( + "argument %q should be required", a.Name, + ) + } + } + return + } + t.Error("add-decision prompt not found") +} + +func TestDefsAddLearningArgs(t *testing.T) { + for _, d := range Defs { + if d.Name != cfgPrompt.AddLearning { + continue + } + if len(d.Arguments) != 4 { + t.Errorf( + "add-learning argument count = %d, want 4", + len(d.Arguments), + ) + } + for _, a := range d.Arguments { + if !a.Required { + t.Errorf( + "argument %q should be required", a.Name, + ) + } + } + return + } + t.Error("add-learning prompt not found") +} + +func TestDefsSessionStartNoArgs(t *testing.T) { + for _, d := range Defs { + if d.Name != cfgPrompt.SessionStart { + continue + } + if 
len(d.Arguments) != 0 { + t.Errorf( + "session-start should have 0 args, got %d", + len(d.Arguments), + ) + } + return + } + t.Error("session-start prompt not found") +} + +func TestDefsReflectNoArgs(t *testing.T) { + for _, d := range Defs { + if d.Name != cfgPrompt.Reflect { + continue + } + if len(d.Arguments) != 0 { + t.Errorf( + "reflect should have 0 args, got %d", + len(d.Arguments), + ) + } + return + } + t.Error("reflect prompt not found") +} + +func TestDefsCheckpointNoArgs(t *testing.T) { + for _, d := range Defs { + if d.Name != cfgPrompt.Checkpoint { + continue + } + if len(d.Arguments) != 0 { + t.Errorf( + "checkpoint should have 0 args, got %d", + len(d.Arguments), + ) + } + return + } + t.Error("checkpoint prompt not found") +} diff --git a/internal/mcp/server/def/tool/tool_test.go b/internal/mcp/server/def/tool/tool_test.go new file mode 100644 index 000000000..580c17370 --- /dev/null +++ b/internal/mcp/server/def/tool/tool_test.go @@ -0,0 +1,136 @@ +// / ctx: https://ctx.ist +// ,'`./ do you remember? +// `.,'\ +// \ Copyright 2026-present Context contributors. +// SPDX-License-Identifier: Apache-2.0 + +package tool + +import ( + "testing" + + cfgMcpTool "github.com/ActiveMemory/ctx/internal/config/mcp/tool" + "github.com/ActiveMemory/ctx/internal/mcp/proto" +) + +func TestDefsCount(t *testing.T) { + if len(Defs) != 11 { + t.Errorf("tool count = %d, want 11", len(Defs)) + } +} + +func TestDefsNoDuplicateNames(t *testing.T) { + seen := make(map[string]bool) + for _, d := range Defs { + if seen[d.Name] { + t.Errorf("duplicate tool name: %s", d.Name) + } + seen[d.Name] = true + } +} + +func TestDefsAllNamed(t *testing.T) { + for i, d := range Defs { + if d.Name == "" { + t.Errorf("tool[%d] has empty name", i) + } + } +} + +// Note: Description fields are populated by desc.Text() at package +// init time. They are verified as non-empty in the server integration +// tests where lookup.Init() runs before this package is imported. 
+ +func TestDefsAllHaveObjectSchema(t *testing.T) { + for _, d := range Defs { + if d.InputSchema.Type != "object" { + t.Errorf( + "tool %q schema type = %q, want %q", + d.Name, d.InputSchema.Type, "object", + ) + } + } +} + +func TestDefsContainsAllConfigTools(t *testing.T) { + want := []string{ + cfgMcpTool.Status, + cfgMcpTool.Add, + cfgMcpTool.Complete, + cfgMcpTool.Drift, + cfgMcpTool.JournalSource, + cfgMcpTool.WatchUpdate, + cfgMcpTool.Compact, + cfgMcpTool.Next, + cfgMcpTool.CheckTaskCompletion, + cfgMcpTool.SessionEvent, + cfgMcpTool.Remind, + } + names := make(map[string]bool) + for _, d := range Defs { + names[d.Name] = true + } + for _, w := range want { + if !names[w] { + t.Errorf("missing tool: %s", w) + } + } +} + +func TestDefsAnnotations(t *testing.T) { + for _, d := range Defs { + if d.Annotations == nil { + t.Errorf( + "tool %q has nil annotations", d.Name, + ) + } + } +} + +func TestDefsAddRequiredFields(t *testing.T) { + for _, d := range Defs { + if d.Name != cfgMcpTool.Add { + continue + } + if len(d.InputSchema.Required) < 2 { + t.Errorf( + "add tool requires at least 2 fields, got %d", + len(d.InputSchema.Required), + ) + } + return + } + t.Error("add tool not found in Defs") +} + +func TestDefsMergeProps(t *testing.T) { + dst := map[string]proto.Property{ + "a": {Type: "string"}, + } + src := map[string]proto.Property{ + "b": {Type: "number"}, + } + result := MergeProps(dst, src) + if len(result) != 2 { + t.Errorf("merged length = %d, want 2", len(result)) + } + if result["b"].Type != "number" { + t.Errorf( + "result[b].Type = %q, want %q", + result["b"].Type, "number", + ) + } +} + +func TestDefsEntryAttrProps(t *testing.T) { + props := EntryAttrProps("test.key") + expected := []string{ + "context", "rationale", "consequence", + "lesson", "application", + } + for _, key := range expected { + if _, ok := props[key]; !ok { + t.Errorf("missing entry attr prop: %s", key) + } + } +} diff --git a/internal/mcp/server/extract/extract.go 
b/internal/mcp/server/extract/extract.go index 441503b0f..5e22b6c7b 100644 --- a/internal/mcp/server/extract/extract.go +++ b/internal/mcp/server/extract/extract.go @@ -7,21 +7,30 @@ package extract import ( + "fmt" + + "github.com/ActiveMemory/ctx/internal/assets/read/desc" "github.com/ActiveMemory/ctx/internal/config/cli" + "github.com/ActiveMemory/ctx/internal/config/embed/text" + "github.com/ActiveMemory/ctx/internal/config/mcp/cfg" "github.com/ActiveMemory/ctx/internal/config/mcp/field" errMcp "github.com/ActiveMemory/ctx/internal/err/mcp" "github.com/ActiveMemory/ctx/internal/mcp/handler" + "github.com/ActiveMemory/ctx/internal/sanitize" ) // EntryArgs extracts required type/content from MCP args. // +// Validates that both fields are present and that content does not +// exceed MaxContentLen. +// // Parameters: // - args: MCP tool arguments // // Returns: // - string: extracted entry type // - string: extracted content string -// - error: non-nil if type or content is missing +// - error: non-nil if type or content is missing, or content too long func EntryArgs( args map[string]interface{}, ) (string, string, error) { @@ -32,6 +41,14 @@ func EntryArgs( return "", "", errMcp.TypeContentRequired() } + // MCP-SAN.1: Enforce input length limits. + if len(content) > cfg.MaxContentLen { + return "", "", fmt.Errorf( + desc.Text(text.DescKeyMCPErrInputTooLong), + "content", cfg.MaxContentLen, + ) + } + return entryType, content, nil } @@ -64,3 +81,21 @@ func Opts(args map[string]interface{}) handler.EntryOpts { } return opts } + +// SanitizedOpts builds EntryOpts with content sanitization applied +// to all text fields. 
+// +// Parameters: +// - args: MCP tool arguments with optional entry fields +// +// Returns: +// - handler.EntryOpts: sanitized options struct +func SanitizedOpts(args map[string]interface{}) handler.EntryOpts { + opts := Opts(args) + opts.Context = sanitize.Content(opts.Context) + opts.Rationale = sanitize.Content(opts.Rationale) + opts.Consequence = sanitize.Content(opts.Consequence) + opts.Lesson = sanitize.Content(opts.Lesson) + opts.Application = sanitize.Content(opts.Application) + return opts +} diff --git a/internal/mcp/server/extract/extract_test.go b/internal/mcp/server/extract/extract_test.go new file mode 100644 index 000000000..1ac47af8d --- /dev/null +++ b/internal/mcp/server/extract/extract_test.go @@ -0,0 +1,113 @@ +// / ctx: https://ctx.ist +// ,'`./ do you remember? +// `.,'\ +// \ Copyright 2026-present Context contributors. +// SPDX-License-Identifier: Apache-2.0 + +package extract + +import ( + "os" + "strings" + "testing" + + "github.com/ActiveMemory/ctx/internal/assets/read/lookup" + "github.com/ActiveMemory/ctx/internal/config/mcp/cfg" +) + +func TestMain(m *testing.M) { + lookup.Init() + os.Exit(m.Run()) +} + +func TestEntryArgsValid(t *testing.T) { + args := map[string]interface{}{ + "type": "decision", + "content": "Use Go", + } + typ, content, err := EntryArgs(args) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if typ != "decision" { + t.Errorf("type = %q, want decision", typ) + } + if content != "Use Go" { + t.Errorf("content = %q, want Use Go", content) + } +} + +func TestEntryArgsMissingType(t *testing.T) { + args := map[string]interface{}{"content": "ok"} + _, _, err := EntryArgs(args) + if err == nil { + t.Fatal("expected error for missing type") + } +} + +func TestEntryArgsMissingContent(t *testing.T) { + args := map[string]interface{}{"type": "decision"} + _, _, err := EntryArgs(args) + if err == nil { + t.Fatal("expected error for missing content") + } +} + +func TestEntryArgsTooLong(t *testing.T) { + args := 
map[string]interface{}{ + "type": "decision", + "content": strings.Repeat("x", cfg.MaxContentLen+1), + } + _, _, err := EntryArgs(args) + if err == nil { + t.Fatal("expected error for content too long") + } +} + +func TestOptsAllFields(t *testing.T) { + args := map[string]interface{}{ + "priority": "high", + "context": "ctx", + "rationale": "because", + "consequence": "result", + "lesson": "learned", + "application": "apply", + } + opts := Opts(args) + if opts.Priority != "high" { + t.Errorf("priority = %q", opts.Priority) + } + if opts.Context != "ctx" { + t.Errorf("context = %q", opts.Context) + } + if opts.Rationale != "because" { + t.Errorf("rationale = %q", opts.Rationale) + } + if opts.Consequence != "result" { + t.Errorf("consequence = %q", opts.Consequence) + } + if opts.Lesson != "learned" { + t.Errorf("lesson = %q", opts.Lesson) + } + if opts.Application != "apply" { + t.Errorf("application = %q", opts.Application) + } +} + +func TestOptsEmpty(t *testing.T) { + opts := Opts(map[string]interface{}{}) + if opts.Priority != "" { + t.Error("expected empty priority") + } +} + +func TestSanitizedOpts(t *testing.T) { + args := map[string]interface{}{ + "context": "safe text", + "rationale": "good reason", + } + opts := SanitizedOpts(args) + if opts.Context != "safe text" { + t.Errorf("context = %q", opts.Context) + } +} diff --git a/internal/mcp/server/io/write_test.go b/internal/mcp/server/io/write_test.go new file mode 100644 index 000000000..ba6157de8 --- /dev/null +++ b/internal/mcp/server/io/write_test.go @@ -0,0 +1,48 @@ +// / ctx: https://ctx.ist +// ,'`./ do you remember? +// `.,'\ +// \ Copyright 2026-present Context contributors. 
+// SPDX-License-Identifier: Apache-2.0 + +package io + +import ( + "bytes" + "os" + "testing" +) + +func TestWriteJSONSuccess(t *testing.T) { + var buf bytes.Buffer + w := NewWriter(&buf) + err := w.WriteJSON(map[string]int{"a": 1}) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if got := buf.String(); got != "{\"a\":1}\n" { + t.Errorf("output = %q", got) + } +} + +func TestWriteJSONMarshalError(t *testing.T) { + var buf bytes.Buffer + w := NewWriter(&buf) + err := w.WriteJSON(make(chan int)) + if err == nil { + t.Fatal("expected marshal error") + } +} + +type errWriter struct{} + +func (errWriter) Write([]byte) (int, error) { + return 0, os.ErrClosed +} + +func TestWriteJSONWriteError(t *testing.T) { + w := NewWriter(errWriter{}) + err := w.WriteJSON("hello") + if err == nil { + t.Fatal("expected write error") + } +} diff --git a/internal/mcp/server/out/out_test.go b/internal/mcp/server/out/out_test.go new file mode 100644 index 000000000..9ff16cdea --- /dev/null +++ b/internal/mcp/server/out/out_test.go @@ -0,0 +1,113 @@ +// / ctx: https://ctx.ist +// ,'`./ do you remember? +// `.,'\ +// \ Copyright 2026-present Context contributors. 
+// SPDX-License-Identifier: Apache-2.0 + +package out + +import ( + "encoding/json" + "errors" + "testing" + + "github.com/ActiveMemory/ctx/internal/mcp/proto" +) + +func TestOkResponse(t *testing.T) { + id, _ := json.Marshal(1) + resp := OkResponse(id, map[string]string{"k": "v"}) + if resp.JSONRPC != "2.0" { + t.Errorf("jsonrpc = %q", resp.JSONRPC) + } + if resp.Error != nil { + t.Error("unexpected error field") + } +} + +func TestErrResponse(t *testing.T) { + id, _ := json.Marshal(1) + resp := ErrResponse(id, proto.ErrCodeInternal, "boom") + if resp.Error == nil { + t.Fatal("expected error") + } + if resp.Error.Code != proto.ErrCodeInternal { + t.Errorf("code = %d", resp.Error.Code) + } + if resp.Error.Message != "boom" { + t.Errorf("msg = %q", resp.Error.Message) + } +} + +func TestToolOK(t *testing.T) { + id, _ := json.Marshal(1) + resp := ToolOK(id, "ok") + raw, _ := json.Marshal(resp.Result) + var r proto.CallToolResult + _ = json.Unmarshal(raw, &r) + if r.IsError { + t.Error("unexpected isError") + } + if r.Content[0].Text != "ok" { + t.Errorf("text = %q", r.Content[0].Text) + } +} + +func TestToolError(t *testing.T) { + id, _ := json.Marshal(1) + resp := ToolError(id, "fail") + raw, _ := json.Marshal(resp.Result) + var r proto.CallToolResult + _ = json.Unmarshal(raw, &r) + if !r.IsError { + t.Error("expected isError") + } +} + +func TestToolResultSuccess(t *testing.T) { + id, _ := json.Marshal(1) + resp := ToolResult(id, "done", nil) + raw, _ := json.Marshal(resp.Result) + var r proto.CallToolResult + _ = json.Unmarshal(raw, &r) + if r.IsError { + t.Error("unexpected isError") + } +} + +func TestToolResultError(t *testing.T) { + id, _ := json.Marshal(1) + resp := ToolResult(id, "", errors.New("bad")) + raw, _ := json.Marshal(resp.Result) + var r proto.CallToolResult + _ = json.Unmarshal(raw, &r) + if !r.IsError { + t.Error("expected isError") + } +} + +func TestCallSuccess(t *testing.T) { + id, _ := json.Marshal(1) + resp := Call(id, func() (string, 
error) { + return "ok", nil + }) + raw, _ := json.Marshal(resp.Result) + var r proto.CallToolResult + _ = json.Unmarshal(raw, &r) + if r.IsError { + t.Error("unexpected isError") + } +} + +func TestCallError(t *testing.T) { + id, _ := json.Marshal(1) + resp := Call(id, func() (string, error) { + return "", errors.New("oops") + }) + raw, _ := json.Marshal(resp.Result) + var r proto.CallToolResult + _ = json.Unmarshal(raw, &r) + if !r.IsError { + t.Error("expected isError") + } +} diff --git a/internal/mcp/server/parse/parse_test.go b/internal/mcp/server/parse/parse_test.go new file mode 100644 index 000000000..0532556a0 --- /dev/null +++ b/internal/mcp/server/parse/parse_test.go @@ -0,0 +1,56 @@ +// / ctx: https://ctx.ist +// ,'`./ do you remember? +// `.,'\ +// \ Copyright 2026-present Context contributors. +// SPDX-License-Identifier: Apache-2.0 + +package parse + +import ( + "os" + "testing" + + "github.com/ActiveMemory/ctx/internal/assets/read/lookup" +) + +func TestMain(m *testing.M) { + lookup.Init() + os.Exit(m.Run()) +} + +func TestRequestValid(t *testing.T) { + data := []byte(`{"jsonrpc":"2.0","id":1,"method":"ping"}`) + req, errResp := Request(data) + switch { + case errResp != nil: + t.Fatal("unexpected error response") + case req == nil: + t.Fatal("expected non-nil request") + case req.Method != "ping": + t.Errorf("method = %q, want ping", req.Method) + } +} + +func TestRequestMalformed(t *testing.T) { + req, errResp := Request([]byte(`not-json`)) + if req != nil { + t.Fatal("expected nil request") + } + if errResp == nil || errResp.Error == nil { + t.Fatal("expected error response") + } + if errResp.Error.Code != -32700 { + t.Errorf("code = %d, want -32700", errResp.Error.Code) + } +} + +func TestRequestNotification(t *testing.T) { + data := []byte(`{"jsonrpc":"2.0","method":"notify"}`) + req, errResp := Request(data) + if req != nil { + t.Error("expected nil request for notification") + } + if errResp != nil { + t.Error("expected nil error for 
notification") + } +} diff --git a/internal/mcp/server/resource/dispatch.go b/internal/mcp/server/resource/dispatch.go index 469e64ba3..5ac1e4acd 100644 --- a/internal/mcp/server/resource/dispatch.go +++ b/internal/mcp/server/resource/dispatch.go @@ -12,10 +12,12 @@ import ( "github.com/ActiveMemory/ctx/internal/assets/read/desc" "github.com/ActiveMemory/ctx/internal/config/embed/text" + "github.com/ActiveMemory/ctx/internal/config/mcp/cfg" "github.com/ActiveMemory/ctx/internal/context/load" "github.com/ActiveMemory/ctx/internal/mcp/proto" "github.com/ActiveMemory/ctx/internal/mcp/server/catalog" "github.com/ActiveMemory/ctx/internal/mcp/server/out" + "github.com/ActiveMemory/ctx/internal/sanitize" ) // DispatchList returns the pre-built resource list. @@ -77,7 +79,7 @@ func DispatchRead( return out.ErrResponse(req.ID, proto.ErrCodeInvalidArg, fmt.Sprintf( desc.Text(text.DescKeyMCPErrUnknownResource), - params.URI, + sanitize.Reflect(params.URI, cfg.MaxURILen), )) } diff --git a/internal/mcp/server/route/prompt/dispatch.go b/internal/mcp/server/route/prompt/dispatch.go index dcce69264..b7cf89ddf 100644 --- a/internal/mcp/server/route/prompt/dispatch.go +++ b/internal/mcp/server/route/prompt/dispatch.go @@ -12,11 +12,13 @@ import ( "github.com/ActiveMemory/ctx/internal/assets/read/desc" "github.com/ActiveMemory/ctx/internal/config/embed/text" + "github.com/ActiveMemory/ctx/internal/config/mcp/cfg" "github.com/ActiveMemory/ctx/internal/config/mcp/prompt" "github.com/ActiveMemory/ctx/internal/mcp/handler" "github.com/ActiveMemory/ctx/internal/mcp/proto" defPrompt "github.com/ActiveMemory/ctx/internal/mcp/server/def/prompt" "github.com/ActiveMemory/ctx/internal/mcp/server/out" + "github.com/ActiveMemory/ctx/internal/sanitize" ) // DispatchList returns all available prompts. 
@@ -71,7 +73,7 @@ func DispatchGet( req.ID, proto.ErrCodeNotFound, fmt.Sprintf( desc.Text(text.DescKeyMCPErrUnknownPrompt), - params.Name, + sanitize.Reflect(params.Name, cfg.MaxNameLen), ), ) } diff --git a/internal/mcp/server/route/tool/dispatch.go b/internal/mcp/server/route/tool/dispatch.go index f9495d7ec..87f18a40e 100644 --- a/internal/mcp/server/route/tool/dispatch.go +++ b/internal/mcp/server/route/tool/dispatch.go @@ -12,11 +12,13 @@ import ( "github.com/ActiveMemory/ctx/internal/assets/read/desc" "github.com/ActiveMemory/ctx/internal/config/embed/text" + "github.com/ActiveMemory/ctx/internal/config/mcp/cfg" "github.com/ActiveMemory/ctx/internal/config/mcp/tool" "github.com/ActiveMemory/ctx/internal/mcp/handler" "github.com/ActiveMemory/ctx/internal/mcp/proto" defTool "github.com/ActiveMemory/ctx/internal/mcp/server/def/tool" "github.com/ActiveMemory/ctx/internal/mcp/server/out" + "github.com/ActiveMemory/ctx/internal/sanitize" ) // DispatchList returns all available tools. @@ -82,7 +84,7 @@ func DispatchCall( req.ID, proto.ErrCodeNotFound, fmt.Sprintf( desc.Text(text.DescKeyMCPErrUnknownTool), - params.Name, + sanitize.Reflect(params.Name, cfg.MaxNameLen), ), ) } diff --git a/internal/mcp/server/route/tool/tool.go b/internal/mcp/server/route/tool/tool.go index 7a0f02637..90280b5de 100644 --- a/internal/mcp/server/route/tool/tool.go +++ b/internal/mcp/server/route/tool/tool.go @@ -14,6 +14,7 @@ import ( "github.com/ActiveMemory/ctx/internal/assets/read/desc" "github.com/ActiveMemory/ctx/internal/config/cli" "github.com/ActiveMemory/ctx/internal/config/embed/text" + "github.com/ActiveMemory/ctx/internal/config/entry" "github.com/ActiveMemory/ctx/internal/config/mcp/cfg" "github.com/ActiveMemory/ctx/internal/config/mcp/field" cfgTime "github.com/ActiveMemory/ctx/internal/config/time" @@ -21,6 +22,7 @@ import ( "github.com/ActiveMemory/ctx/internal/mcp/proto" "github.com/ActiveMemory/ctx/internal/mcp/server/extract" 
"github.com/ActiveMemory/ctx/internal/mcp/server/out" + "github.com/ActiveMemory/ctx/internal/sanitize" ) // add extracts MCP args and delegates to handler.Add. @@ -40,7 +42,19 @@ func add( if extractErr != nil { return out.ToolError(id, extractErr.Error()) } - t, addErr := h.Add(entryType, content, extract.Opts(args)) + + // MCP-SAN.2: Reject unknown entry types before writing. + if _, ok := entry.CtxFile(entryType); !ok { + return out.ToolError(id, fmt.Sprintf( + desc.Text(text.DescKeyMCPErrUnknownEntryType), + sanitize.Reflect(entryType, cfg.MaxNameLen), + )) + } + + // MCP-SAN.3: Sanitize content before writing to .context/. + content = sanitize.Content(content) + + t, addErr := h.Add(entryType, content, extract.SanitizedOpts(args)) return out.ToolResult(id, t, addErr) } @@ -85,6 +99,11 @@ func journalSource( limit = int(v) } + // MCP-SAN.1: Cap source limit to a reasonable upper bound. + if limit > cfg.MaxSourceLimit { + limit = cfg.MaxSourceLimit + } + var since time.Time if sinceStr, _ := args[field.Since].(string); sinceStr != "" { var parseErr error @@ -121,8 +140,23 @@ func watchUpdate( if extractErr != nil { return out.ToolError(id, extractErr.Error()) } + + // MCP-SAN.2: Reject unknown entry types (allow "complete" as + // special case handled by handler.WatchUpdate). + if entryType != entry.Complete { + if _, ok := entry.CtxFile(entryType); !ok { + return out.ToolError(id, fmt.Sprintf( + desc.Text(text.DescKeyMCPErrUnknownEntryType), + sanitize.Reflect(entryType, cfg.MaxNameLen), + )) + } + } + + // MCP-SAN.3: Sanitize content before writing to .context/. + content = sanitize.Content(content) + t, updateErr := h.WatchUpdate( - entryType, content, extract.Opts(args), + entryType, content, extract.SanitizedOpts(args), ) return out.ToolResult(id, t, updateErr) } @@ -189,6 +223,10 @@ func sessionEvent( ) } caller, _ := args[field.Caller].(string) + + // MCP-SAN.4: Sanitize caller before reflecting in response. 
+ caller = sanitize.Reflect(caller, cfg.MaxCallerLen) + t, eventErr := fn(eventType, caller) return out.ToolResult(id, t, eventErr) } diff --git a/internal/mcp/server/server_test.go b/internal/mcp/server/server_test.go index 8324f5e61..fac988d2c 100644 --- a/internal/mcp/server/server_test.go +++ b/internal/mcp/server/server_test.go @@ -991,6 +991,36 @@ func TestPromptAddDecision(t *testing.T) { } } +func TestPromptAddLearning(t *testing.T) { + srv, _ := newTestServer(t) + resp := request(t, srv, "prompts/get", proto.GetPromptParams{ + Name: "ctx-add-learning", + Arguments: map[string]string{ + "content": "Always validate inputs", + "context": "MCP sanitization work", + "lesson": "Never trust external input", + "application": "Add validation at boundaries", + }, + }) + if resp.Error != nil { + t.Fatalf("unexpected error: %v", resp.Error.Message) + } + raw, _ := json.Marshal(resp.Result) + var result proto.GetPromptResult + if err := json.Unmarshal(raw, &result); err != nil { + t.Fatalf("unmarshal: %v", err) + } + if len(result.Messages) == 0 { + t.Fatal("expected message in learning prompt") + } + text := result.Messages[0].Content.Text + if !strings.Contains(text, "Always validate inputs") { + t.Errorf( + "expected learning content in text, got: %s", text, + ) + } +} + func TestPromptReflect(t *testing.T) { srv, _ := newTestServer(t) resp := request(t, srv, "prompts/get", proto.GetPromptParams{ @@ -1154,3 +1184,80 @@ func TestResourcePollerNotification(t *testing.T) { srv.poller.Stop() } + +// --- Serve edge-case tests --- + +// errWriter is an io.Writer that always returns an error. +type errWriter struct{} + +func (errWriter) Write([]byte) (int, error) { + return 0, os.ErrClosed +} + +func TestServeEmptyLines(t *testing.T) { + srv, _ := newTestServer(t) + + // Feed an empty line followed by a valid ping. 
+ idBytes, _ := json.Marshal(1) + req := proto.Request{ + JSONRPC: "2.0", + ID: idBytes, + Method: "ping", + } + line, _ := json.Marshal(req) + + // Empty line + valid request. + input := append([]byte("\n"), line...) + input = append(input, '\n') + + var out bytes.Buffer + srv.in = bytes.NewReader(input) + srv.out = mcpIO.NewWriter(&out) + if err := srv.Serve(); err != nil { + t.Fatalf("serve: %v", err) + } + + var resp proto.Response + if err := json.Unmarshal(out.Bytes(), &resp); err != nil { + t.Fatalf("unmarshal: %v", err) + } + if resp.Error != nil { + t.Errorf("unexpected error: %v", resp.Error.Message) + } +} + +func TestServeParseErrorWriteFailure(t *testing.T) { + srv, _ := newTestServer(t) + + // Feed invalid JSON to trigger a parse error. + srv.in = bytes.NewReader([]byte("not-json\n")) + srv.out = mcpIO.NewWriter(errWriter{}) + + err := srv.Serve() + if err == nil { + t.Fatal("expected write error, got nil") + } +} + +func TestServeDispatchWriteFailure(t *testing.T) { + srv, _ := newTestServer(t) + + // Feed a valid request but use an errWriter for output. + idBytes, _ := json.Marshal(1) + req := proto.Request{ + JSONRPC: "2.0", + ID: idBytes, + Method: "ping", + } + line, _ := json.Marshal(req) + + srv.in = bytes.NewReader(append(line, '\n')) + srv.out = mcpIO.NewWriter(errWriter{}) + + // The marshal itself succeeds but the write fails, triggering + // the fallback error path which also fails, returning the error. + err := srv.Serve() + if err == nil { + t.Fatal("expected write error, got nil") + } +} diff --git a/internal/mcp/server/stat/stat_test.go b/internal/mcp/server/stat/stat_test.go new file mode 100644 index 000000000..be0b30ce3 --- /dev/null +++ b/internal/mcp/server/stat/stat_test.go @@ -0,0 +1,22 @@ +// / ctx: https://ctx.ist +// ,'`./ do you remember? +// `.,'\ +// \ Copyright 2026-present Context contributors. 
+// SPDX-License-Identifier: Apache-2.0 + +package stat + +import "testing" + +func TestTotalAddsEmpty(t *testing.T) { + if got := TotalAdds(nil); got != 0 { + t.Errorf("TotalAdds(nil) = %d, want 0", got) + } +} + +func TestTotalAddsMultiple(t *testing.T) { + m := map[string]int{"decision": 2, "learning": 3, "convention": 1} + if got := TotalAdds(m); got != 6 { + t.Errorf("TotalAdds = %d, want 6", got) + } +} diff --git a/internal/mcp/session/state_test.go b/internal/mcp/session/state_test.go new file mode 100644 index 000000000..daa4fdcaf --- /dev/null +++ b/internal/mcp/session/state_test.go @@ -0,0 +1,145 @@ +// / ctx: https://ctx.ist +// ,'`./ do you remember? +// `.,'\ +// \ Copyright 2026-present Context contributors. +// SPDX-License-Identifier: Apache-2.0 + +package session + +import ( + "testing" + "time" +) + +func TestNewState(t *testing.T) { + s := NewState("/tmp/.context") + if s.contextDir != "/tmp/.context" { + t.Errorf( + "contextDir = %q, want %q", + s.contextDir, "/tmp/.context", + ) + } + if s.ToolCalls != 0 { + t.Errorf("ToolCalls = %d, want 0", s.ToolCalls) + } + if s.AddsPerformed == nil { + t.Fatal("AddsPerformed should be initialized") + } + if len(s.AddsPerformed) != 0 { + t.Errorf( + "AddsPerformed length = %d, want 0", + len(s.AddsPerformed), + ) + } + if s.sessionStartedAt.IsZero() { + t.Error("sessionStartedAt should be set") + } + if len(s.PendingFlush) != 0 { + t.Errorf( + "PendingFlush length = %d, want 0", + len(s.PendingFlush), + ) + } +} + +func TestRecordToolCall(t *testing.T) { + s := NewState("/tmp/.context") + s.RecordToolCall() + if s.ToolCalls != 1 { + t.Errorf("ToolCalls = %d, want 1", s.ToolCalls) + } + s.RecordToolCall() + s.RecordToolCall() + if s.ToolCalls != 3 { + t.Errorf("ToolCalls = %d, want 3", s.ToolCalls) + } +} + +func TestRecordAdd(t *testing.T) { + s := NewState("/tmp/.context") + s.RecordAdd("task") + s.RecordAdd("task") + s.RecordAdd("decision") + if s.AddsPerformed["task"] != 2 { + t.Errorf( + "task adds = 
%d, want 2", + s.AddsPerformed["task"], + ) + } + if s.AddsPerformed["decision"] != 1 { + t.Errorf( + "decision adds = %d, want 1", + s.AddsPerformed["decision"], + ) + } +} + +func TestQueuePendingUpdate(t *testing.T) { + s := NewState("/tmp/.context") + now := time.Now() + s.QueuePendingUpdate(PendingUpdate{ + Type: "task", + Content: "Build feature", + QueuedAt: now, + }) + if len(s.PendingFlush) != 1 { + t.Fatalf( + "PendingFlush length = %d, want 1", + len(s.PendingFlush), + ) + } + pu := s.PendingFlush[0] + if pu.Type != "task" { + t.Errorf("Type = %q, want %q", pu.Type, "task") + } + if pu.Content != "Build feature" { + t.Errorf( + "Content = %q, want %q", + pu.Content, "Build feature", + ) + } +} + +func TestQueueMultiplePendingUpdates(t *testing.T) { + s := NewState("/tmp/.context") + s.QueuePendingUpdate(PendingUpdate{Type: "task", Content: "A"}) + s.QueuePendingUpdate(PendingUpdate{Type: "decision", Content: "B"}) + s.QueuePendingUpdate(PendingUpdate{Type: "learning", Content: "C"}) + if len(s.PendingFlush) != 3 { + t.Errorf( + "PendingFlush length = %d, want 3", + len(s.PendingFlush), + ) + } +} + +func TestPendingCount(t *testing.T) { + s := NewState("/tmp/.context") + if s.PendingCount() != 0 { + t.Errorf("PendingCount = %d, want 0", s.PendingCount()) + } + s.QueuePendingUpdate(PendingUpdate{Type: "task", Content: "X"}) + if s.PendingCount() != 1 { + t.Errorf("PendingCount = %d, want 1", s.PendingCount()) + } + s.QueuePendingUpdate(PendingUpdate{Type: "task", Content: "Y"}) + if s.PendingCount() != 2 { + t.Errorf("PendingCount = %d, want 2", s.PendingCount()) + } +} + +func TestPendingUpdateAttrs(t *testing.T) { + s := NewState("/tmp/.context") + s.QueuePendingUpdate(PendingUpdate{ + Type: "task", + Content: "Implement feature", + Attrs: map[string]string{"file": "TASKS.md"}, + }) + pu := s.PendingFlush[0] + if pu.Attrs["file"] != "TASKS.md" { + t.Errorf( + "Attrs[file] = %q, want %q", + pu.Attrs["file"], "TASKS.md", + ) + } +} diff --git 
// ---- internal/sanitize/content.go (new file) ----
// (package sanitize; imports "regexp", "strings", "unicode")
//
// / ctx: https://ctx.ist
// ,'`./ do you remember?
// `.,'\
// \ Copyright 2026-present Context contributors.
// SPDX-License-Identifier: Apache-2.0

// entryHeaderRe matches entry headers like "## [2026-".
var entryHeaderRe = regexp.MustCompile(`(?m)^##\s+\[\d{4}-`)

// taskCheckboxRe matches task checkboxes "- [ ]" and "- [x]".
var taskCheckboxRe = regexp.MustCompile(`(?m)^-\s+\[[x ]\]`)

// constitutionRuleRe matches constitution rule format "- [ ] **Never".
//
// Every match begins with a taskCheckboxRe match, so once the checkbox
// pass has run this pattern finds nothing new. It is kept in the list
// so the protected structures stay spelled out in one place.
var constitutionRuleRe = regexp.MustCompile(
	`(?m)^-\s+\[[x ]\]\s+\*\*[A-Z]`,
)

// structuralRes lists, in application order, the line-start patterns
// that Content escapes.
var structuralRes = []*regexp.Regexp{
	entryHeaderRe,
	taskCheckboxRe,
	constitutionRuleRe,
}

// Content neutralizes Markdown structure characters in entry content
// that could corrupt .context/ file parsing.
//
// Null bytes are removed first, then entry headers, task checkboxes,
// and constitution rule patterns at the start of a line are prefixed
// with a backslash so they render as literal text instead of
// structural elements. Stripping null bytes before escaping matters:
// a leading "\x00" would otherwise hide a marker from the line-start
// patterns and leave it unescaped once the null byte is dropped.
//
// Only the forms the patterns spell out are escaped: markers at the
// very start of a "\n"-delimited line, with a lower-case "x" or a
// space inside the checkbox.
//
// Parameters:
//   - s: raw content string from MCP client
//
// Returns:
//   - string: content safe for appending to .context/ Markdown files
func Content(s string) string {
	s = strings.ReplaceAll(s, "\x00", "")

	for _, re := range structuralRes {
		// `\` is literal in a regexp replacement template and ${0}
		// is the whole match: "## [2026-" → "\## [2026-".
		s = re.ReplaceAllString(s, `\${0}`)
	}

	return s
}

// StripControl removes Unicode control characters (as reported by
// unicode.IsControl) from a string, keeping tab, newline, and
// carriage return.
//
// Parameters:
//   - s: input string potentially containing control characters
//
// Returns:
//   - string: input with control characters removed
func StripControl(s string) string {
	return strings.Map(func(r rune) rune {
		if r == '\t' || r == '\n' || r == '\r' {
			return r
		}
		if unicode.IsControl(r) {
			return -1
		}
		return r
	}, s)
}

// ---- internal/sanitize/doc.go (modified; visible part of the
// package comment after this change) ----
//
// Unlike validation (which rejects bad input), sanitization mutates
// input to conform to constraints. [Filename] converts arbitrary
// strings into safe filename components, [Content] neutralizes
// Markdown structure injections, [Reflect] truncates and strips
// control characters for error messages, and [SessionID] produces
// path-safe session identifiers.
// Part of the internal subsystem.

// ---- internal/sanitize/path.go (new file) ----
// (package sanitize; imports "regexp", "strings")
//
// / ctx: https://ctx.ist
// ,'`./ do you remember?
// `.,'\
// \ Copyright 2026-present Context contributors.
// SPDX-License-Identifier: Apache-2.0

// maxSessionIDLen is the maximum byte length of a sanitized session
// ID.
const maxSessionIDLen = 128

// sessionIDUnsafe matches characters not safe for session IDs in
// file paths: anything outside [a-zA-Z0-9._-].
var sessionIDUnsafe = regexp.MustCompile(`[^a-zA-Z0-9._-]`)

// sessionIDStrip deletes null bytes and both path separators in a
// single pass.
var sessionIDStrip = strings.NewReplacer("\x00", "", "/", "", "\\", "")

// SessionID sanitizes a session identifier for safe use in file
// paths.
//
// Strips null bytes and path separators, then traversal sequences.
// Replaces remaining unsafe characters with hyphens, trims leading
// and trailing hyphens, and limits length to 128 bytes. An ID that
// reduces to a lone "." yields the empty string.
//
// Parameters:
//   - s: raw session ID from MCP client
//
// Returns:
//   - string: path-safe session ID (may be empty)
func SessionID(s string) string {
	// Separators and null bytes go first. Removing ".." before them
	// lets inputs such as "./." or `.\.` collapse into ".." afterwards.
	s = sessionIDStrip.Replace(s)

	// Collapse path traversal sequences. Nothing below deletes a
	// character from between two dots, so ".." cannot reappear.
	s = strings.ReplaceAll(s, "..", "")

	// Replace remaining unsafe chars.
	s = sessionIDUnsafe.ReplaceAllString(s, "-")

	// Remove leading/trailing hyphens.
	s = strings.Trim(s, "-")

	// Limit length. The string is pure ASCII at this point, so a byte
	// cut cannot split a character; re-trim a hyphen the cut exposed.
	if len(s) > maxSessionIDLen {
		s = strings.TrimRight(s[:maxSessionIDLen], "-")
	}

	// "." names the current directory when used as a path element.
	if s == "." {
		return ""
	}

	return s
}

// ---- internal/sanitize/reflect.go (new file) ----
// (package sanitize; no imports)
//
// / ctx: https://ctx.ist
// ,'`./ do you remember?
// `.,'\
// \ Copyright 2026-present Context contributors.
// SPDX-License-Identifier: Apache-2.0

// cutAtRuneBoundary caps s to at most maxLen bytes without splitting
// a UTF-8 encoded character. A maxLen <= 0 disables truncation.
func cutAtRuneBoundary(s string, maxLen int) string {
	if maxLen <= 0 || len(s) <= maxLen {
		return s
	}

	// s[cut] is the first byte that would be dropped. If it is a
	// continuation byte (10xxxxxx) the cut lands inside a character,
	// so back up to that character's lead byte. A UTF-8 sequence has
	// at most three continuation bytes.
	cut := maxLen
	for i := 0; i < 3 && cut > 0 && s[cut]&0xC0 == 0x80; i++ {
		cut--
	}
	if s[cut]&0xC0 == 0x80 {
		// Still on a continuation byte: the input is not valid UTF-8
		// here, so fall back to the plain byte cut.
		cut = maxLen
	}

	return s[:cut]
}

// Reflect truncates a string and strips control characters for safe
// inclusion in error or log messages.
//
// Use this for any client-supplied value that gets reflected back in
// JSON-RPC error messages (tool names, prompt names, URIs, caller
// identifiers). Truncation never splits a UTF-8 character, so the
// result may be a few bytes shorter than maxLen.
//
// NOTE(review): tab, newline, and carriage return survive
// StripControl, so a reflected value can still span several lines.
// Confirm that is acceptable wherever the result is written to logs.
//
// Parameters:
//   - s: untrusted input string
//   - maxLen: maximum allowed length in bytes (0 = no truncation)
//
// Returns:
//   - string: truncated, control-character-free string
func Reflect(s string, maxLen int) string {
	return cutAtRuneBoundary(StripControl(s), maxLen)
}

// Truncate limits a string to maxLen bytes. If truncated, no
// ellipsis is appended — the caller controls presentation.
//
// The cut is moved back to the nearest UTF-8 character boundary so a
// multi-byte character is never split; the result can therefore be
// up to three bytes shorter than maxLen.
//
// Parameters:
//   - s: input string
//   - maxLen: maximum byte length (0 or negative = no truncation)
//
// Returns:
//   - string: input capped to maxLen bytes
func Truncate(s string, maxLen int) string {
	return cutAtRuneBoundary(s, maxLen)
}

// ---- internal/sanitize/sanitize_test.go (new file) ----
// (package sanitize; imports "strings", "testing")
//
// / ctx: https://ctx.ist
// ,'`./ do you remember?
// `.,'\
// \ Copyright 2026-present Context contributors.
// SPDX-License-Identifier: Apache-2.0

func TestContentEscapesEntryHeaders(t *testing.T) {
	input := "## [2026-03-15] Decision title"
	got := Content(input)
	want := `\## [2026-03-15] Decision title`
	if got != want {
		t.Errorf("Content(%q) = %q, want %q", input, got, want)
	}
}

func TestContentEscapesTaskCheckboxUnchecked(t *testing.T) {
	got := Content("- [ ] New task")
	want := `\- [ ] New task`
	if got != want {
		t.Errorf("got %q, want %q", got, want)
	}
}

func TestContentEscapesTaskCheckboxChecked(t *testing.T) {
	got := Content("- [x] Done task")
	want := `\- [x] Done task`
	if got != want {
		t.Errorf("got %q, want %q", got, want)
	}
}

func TestContentEscapesConstitutionRules(t *testing.T) {
	input := "- [ ] **Never break the constitution"
	got := Content(input)
	if !strings.HasPrefix(got, `\- [ ] **Never`) {
		t.Errorf("got %q, want constitution rule escaped", got)
	}
}

func TestContentStripsNullBytes(t *testing.T) {
	got := Content("hello\x00world")
	if got != "helloworld" {
		t.Errorf("got %q, want %q", got, "helloworld")
	}
}

// A null byte in front of a marker must not shield it from escaping.
func TestContentStripsNullBeforeEscaping(t *testing.T) {
	got := Content("ok\n\x00## [2026-01-01] Injected\n\x00- [ ] Fake")
	want := "ok\n\\## [2026-01-01] Injected\n\\- [ ] Fake"
	if got != want {
		t.Errorf("got %q, want %q", got, want)
	}
}

func TestContentPreservesNormalText(t *testing.T) {
	input := "This is a normal architecture decision."
	got := Content(input)
	if got != input {
		t.Errorf("got %q, want unchanged", got)
	}
}

func TestContentMultilineInjection(t *testing.T) {
	input := "Legit\n## [2026-01-01] Injected\n- [ ] Fake"
	got := Content(input)
	if strings.Contains(got, "\n## [2026") {
		t.Error("entry header injection not escaped")
	}
	if strings.Contains(got, "\n- [ ] Fake") {
		t.Error("checkbox injection not escaped")
	}
}

func TestReflectTruncates(t *testing.T) {
	got := Reflect(strings.Repeat("a", 500), 256)
	if len(got) != 256 {
		t.Errorf("len = %d, want 256", len(got))
	}
}

// Truncation inside a multi-byte character backs up to its start.
func TestReflectRuneBoundary(t *testing.T) {
	if got := Reflect("日本語", 4); got != "日" {
		t.Errorf("got %q, want %q", got, "日")
	}
}

func TestReflectStripsControlChars(t *testing.T) {
	got := Reflect("tool\x07name\x1b[31m", 0)
	if got != "toolname[31m" {
		t.Errorf("got %q, want %q", got, "toolname[31m")
	}
}

func TestReflectPreservesNormal(t *testing.T) {
	got := Reflect("ctx_status", 256)
	if got != "ctx_status" {
		t.Errorf("got %q, want unchanged", got)
	}
}

func TestReflectZeroMaxLen(t *testing.T) {
	got := Reflect(strings.Repeat("x", 1000), 0)
	if len(got) != 1000 {
		t.Errorf("len = %d, want 1000 (no truncation)", len(got))
	}
}

func TestTruncateShort(t *testing.T) {
	if got := Truncate("short", 100); got != "short" {
		t.Errorf("got %q", got)
	}
}

func TestTruncateLong(t *testing.T) {
	if got := Truncate("long input", 4); got != "long" {
		t.Errorf("got %q", got)
	}
}

func TestTruncateZero(t *testing.T) {
	if got := Truncate("any", 0); got != "any" {
		t.Errorf("got %q", got)
	}
}

func TestTruncateRuneBoundary(t *testing.T) {
	if got := Truncate("héllo", 2); got != "h" {
		t.Errorf("got %q, want %q", got, "h")
	}
}

func TestStripControlPreservesWhitespace(t *testing.T) {
	input := "a\nb\tc\r"
	if got := StripControl(input); got != input {
		t.Errorf("got %q, want unchanged", got)
	}
}

func TestStripControlRemovesBell(t *testing.T) {
	if got := StripControl("hello\x07world"); got != "helloworld" {
		t.Errorf("got %q", got)
	}
}

func TestSessionIDSafe(t *testing.T) {
	input := "session-2026-03-15"
	if got := SessionID(input); got != input {
		t.Errorf("got %q, want unchanged", got)
	}
}

func TestSessionIDStripsTraversal(t *testing.T) {
	got := SessionID("../../etc/passwd")
	if strings.Contains(got, "..") || strings.Contains(got, "/") {
		t.Errorf("got %q, contains traversal", got)
	}
}

func TestSessionIDStripsBackslashTraversal(t *testing.T) {
	got := SessionID(`..\..\windows\system32`)
	if strings.Contains(got, "..") || strings.Contains(got, `\`) {
		t.Errorf("got %q, contains traversal", got)
	}
}

// Dots separated only by a stripped character must not merge into a
// traversal sequence.
func TestSessionIDSplitTraversal(t *testing.T) {
	for _, in := range []string{"./.", `.\.`, "a/./.b", "./\\."} {
		if got := SessionID(in); strings.Contains(got, "..") {
			t.Errorf("SessionID(%q) = %q, contains traversal", in, got)
		}
	}
}

func TestSessionIDLoneDot(t *testing.T) {
	if got := SessionID("."); got != "" {
		t.Errorf("got %q, want empty", got)
	}
}

func TestSessionIDStripsNullBytes(t *testing.T) {
	got := SessionID("session\x00evil")
	if strings.Contains(got, "\x00") {
		t.Errorf("got %q, contains null byte", got)
	}
}

func TestSessionIDLimitsLength(t *testing.T) {
	got := SessionID(strings.Repeat("a", 300))
	if len(got) > 128 {
		t.Errorf("len = %d, want <= 128", len(got))
	}
}

func TestSessionIDReplacesUnsafe(t *testing.T) {
	got := SessionID("session with spaces!@#$")
	if strings.ContainsAny(got, " !@#$") {
		t.Errorf("got %q, contains unsafe chars", got)
	}
}