Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions internal/assets/commands/text/mcp.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -346,6 +346,10 @@ mcp.err-unknown-prompt:
short: 'unknown prompt: %s'
mcp.err-uri-required:
short: uri is required
mcp.err-input-too-long:
short: '%s exceeds maximum length (%d bytes)'
mcp.err-unknown-entry-type:
short: 'unknown entry type: %s'
mcp.format-watch-completed:
short: 'Completed: %s'
mcp.format-wrote:
Expand Down
2 changes: 2 additions & 0 deletions internal/config/embed/text/mcp_err.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,4 +18,6 @@ const (
DescKeyMCPErrQueryRequired = "mcp.err-query-required"
DescKeyMCPErrUnknownPrompt = "mcp.err-unknown-prompt"
DescKeyMCPErrURIRequired = "mcp.err-uri-required"
DescKeyMCPErrInputTooLong = "mcp.err-input-too-long"
DescKeyMCPErrUnknownEntryType = "mcp.err-unknown-entry-type"
)
15 changes: 15 additions & 0 deletions internal/config/mcp/cfg/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,23 @@ const (

// DefaultSourceLimit is the max sessions returned by ctx_journal_source.
DefaultSourceLimit = 5
// MaxSourceLimit caps the source limit to prevent unbounded queries.
MaxSourceLimit = 100
// MinWordLen is the shortest word considered for overlap matching.
MinWordLen = 4
// MinWordOverlap is the minimum word matches to signal task completion.
MinWordOverlap = 2

// --- Input length limits (MCP-SAN.1) ---

// MaxContentLen is the maximum byte length for entry content fields.
MaxContentLen = 32_000
// MaxNameLen is the maximum byte length for tool/prompt/resource names.
MaxNameLen = 256
// MaxQueryLen is the maximum byte length for search queries.
MaxQueryLen = 1_000
// MaxCallerLen is the maximum byte length for caller identifiers.
MaxCallerLen = 128
// MaxURILen is the maximum byte length for resource URIs.
MaxURILen = 512
)
345 changes: 345 additions & 0 deletions internal/mcp/proto/schema_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,345 @@
// / ctx: https://ctx.ist
// ,'`./ do you remember?
// `.,'\
// \ Copyright 2026-present Context contributors.
// SPDX-License-Identifier: Apache-2.0

package proto

import (
"encoding/json"
"testing"
)

func roundTrip(t *testing.T, v interface{}, dst interface{}) {
t.Helper()
data, err := json.Marshal(v)
if err != nil {
t.Fatalf("marshal: %v", err)
}
if err := json.Unmarshal(data, dst); err != nil {
t.Fatalf("unmarshal: %v", err)
}
}

func TestRequestRoundTrip(t *testing.T) {
orig := Request{
JSONRPC: "2.0",
ID: json.RawMessage(`1`),
Method: "tools/call",
Params: json.RawMessage(`{"name":"ctx_status"}`),
}
var got Request
roundTrip(t, orig, &got)
if got.JSONRPC != orig.JSONRPC {
t.Errorf("JSONRPC = %q, want %q", got.JSONRPC, orig.JSONRPC)
}
if got.Method != orig.Method {
t.Errorf("Method = %q, want %q", got.Method, orig.Method)
}
if string(got.ID) != string(orig.ID) {
t.Errorf("ID = %s, want %s", got.ID, orig.ID)
}
}

func TestResponseSuccessRoundTrip(t *testing.T) {
orig := Response{
JSONRPC: "2.0",
ID: json.RawMessage(`1`),
Result: map[string]string{"key": "value"},
}
var got Response
roundTrip(t, orig, &got)
if got.JSONRPC != "2.0" {
t.Errorf("JSONRPC = %q, want %q", got.JSONRPC, "2.0")
}
if got.Error != nil {
t.Errorf("unexpected error: %v", got.Error)
}
}

func TestResponseErrorRoundTrip(t *testing.T) {
orig := Response{
JSONRPC: "2.0",
ID: json.RawMessage(`1`),
Error: &RPCError{
Code: ErrCodeNotFound,
Message: "method not found",
},
}
var got Response
roundTrip(t, orig, &got)
if got.Error == nil {
t.Fatal("expected error in response")
}
if got.Error.Code != ErrCodeNotFound {
t.Errorf("Code = %d, want %d", got.Error.Code, ErrCodeNotFound)
}
if got.Error.Message != "method not found" {
t.Errorf("Message = %q, want %q", got.Error.Message, "method not found")
}
}

func TestNotificationRoundTrip(t *testing.T) {
orig := Notification{
JSONRPC: "2.0",
Method: "notifications/initialized",
}
var got Notification
roundTrip(t, orig, &got)
if got.Method != "notifications/initialized" {
t.Errorf("Method = %q, want %q", got.Method, "notifications/initialized")
}
}

func TestRPCErrorWithData(t *testing.T) {
orig := RPCError{
Code: ErrCodeInvalidArg,
Message: "invalid",
Data: map[string]string{"field": "name"},
}
var got RPCError
roundTrip(t, orig, &got)
if got.Code != ErrCodeInvalidArg {
t.Errorf("Code = %d, want %d", got.Code, ErrCodeInvalidArg)
}
}

func TestInitializeParamsRoundTrip(t *testing.T) {
orig := InitializeParams{
ProtocolVersion: ProtocolVersion,
ClientInfo: AppInfo{Name: "test-client", Version: "1.0.0"},
}
var got InitializeParams
roundTrip(t, orig, &got)
if got.ProtocolVersion != ProtocolVersion {
t.Errorf("ProtocolVersion = %q, want %q", got.ProtocolVersion, ProtocolVersion)
}
if got.ClientInfo.Name != "test-client" {
t.Errorf("ClientInfo.Name = %q, want %q", got.ClientInfo.Name, "test-client")
}
}

func TestInitializeResultRoundTrip(t *testing.T) {
orig := InitializeResult{
ProtocolVersion: ProtocolVersion,
Capabilities: ServerCaps{
Resources: &ResourcesCap{Subscribe: true, ListChanged: true},
Tools: &ToolsCap{ListChanged: true},
Prompts: &PromptsCap{ListChanged: false},
},
ServerInfo: AppInfo{Name: "ctx", Version: "0.3.0"},
}
var got InitializeResult
roundTrip(t, orig, &got)
if got.Capabilities.Resources == nil {
t.Fatal("expected Resources capability")
}
if !got.Capabilities.Resources.Subscribe {
t.Error("expected Subscribe = true")
}
}

func TestResourceRoundTrip(t *testing.T) {
orig := Resource{
URI: "ctx://context/tasks",
Name: "tasks",
MimeType: "text/markdown",
}
var got Resource
roundTrip(t, orig, &got)
if got.URI != orig.URI {
t.Errorf("URI = %q, want %q", got.URI, orig.URI)
}
}

func TestToolRoundTrip(t *testing.T) {
orig := Tool{
Name: "ctx_status",
InputSchema: InputSchema{
Type: "object",
Properties: map[string]Property{
"verbose": {Type: "boolean", Description: "Verbose"},
},
Required: []string{"verbose"},
},
Annotations: &ToolAnnotations{ReadOnlyHint: true},
}
var got Tool
roundTrip(t, orig, &got)
if got.Name != "ctx_status" {
t.Errorf("Name = %q, want %q", got.Name, "ctx_status")
}
if got.Annotations == nil || !got.Annotations.ReadOnlyHint {
t.Error("expected ReadOnlyHint = true")
}
}

func TestCallToolParamsRoundTrip(t *testing.T) {
orig := CallToolParams{
Name: "ctx_add",
Arguments: map[string]interface{}{"type": "task", "content": "Test"},
}
var got CallToolParams
roundTrip(t, orig, &got)
if got.Name != "ctx_add" {
t.Errorf("Name = %q, want %q", got.Name, "ctx_add")
}
}

func TestCallToolResultRoundTrip(t *testing.T) {
orig := CallToolResult{
Content: []ToolContent{{Type: "text", Text: "Done"}},
}
var got CallToolResult
roundTrip(t, orig, &got)
if len(got.Content) != 1 {
t.Fatalf("Content count = %d, want 1", len(got.Content))
}
if got.Content[0].Text != "Done" {
t.Errorf("Text = %q, want %q", got.Content[0].Text, "Done")
}
if got.IsError {
t.Error("expected IsError = false")
}
}

func TestCallToolResultErrorRoundTrip(t *testing.T) {
orig := CallToolResult{
Content: []ToolContent{{Type: "text", Text: "failed"}},
IsError: true,
}
var got CallToolResult
roundTrip(t, orig, &got)
if !got.IsError {
t.Error("expected IsError = true")
}
}

func TestPromptRoundTrip(t *testing.T) {
orig := Prompt{
Name: "ctx-session-start",
Arguments: []PromptArgument{
{Name: "content", Required: true},
},
}
var got Prompt
roundTrip(t, orig, &got)
if got.Name != "ctx-session-start" {
t.Errorf("Name = %q, want %q", got.Name, "ctx-session-start")
}
if len(got.Arguments) != 1 || !got.Arguments[0].Required {
t.Error("expected 1 required argument")
}
}

func TestGetPromptResultRoundTrip(t *testing.T) {
orig := GetPromptResult{
Description: "Test",
Messages: []PromptMessage{
{Role: "user", Content: ToolContent{Type: "text", Text: "Hi"}},
},
}
var got GetPromptResult
roundTrip(t, orig, &got)
if len(got.Messages) != 1 {
t.Fatalf("Messages count = %d, want 1", len(got.Messages))
}
if got.Messages[0].Role != "user" {
t.Errorf("Role = %q, want %q", got.Messages[0].Role, "user")
}
}

func TestSubscribeParamsRoundTrip(t *testing.T) {
orig := SubscribeParams{URI: "ctx://context/tasks"}
var got SubscribeParams
roundTrip(t, orig, &got)
if got.URI != orig.URI {
t.Errorf("URI = %q, want %q", got.URI, orig.URI)
}
}

func TestUnsubscribeParamsRoundTrip(t *testing.T) {
orig := UnsubscribeParams{URI: "ctx://context/decisions"}
var got UnsubscribeParams
roundTrip(t, orig, &got)
if got.URI != orig.URI {
t.Errorf("URI = %q, want %q", got.URI, orig.URI)
}
}

func TestResourceUpdatedParamsRoundTrip(t *testing.T) {
orig := ResourceUpdatedParams{URI: "ctx://context/tasks"}
var got ResourceUpdatedParams
roundTrip(t, orig, &got)
if got.URI != orig.URI {
t.Errorf("URI = %q, want %q", got.URI, orig.URI)
}
}

func TestErrorCodeConstants(t *testing.T) {
if ErrCodeParse != -32700 {
t.Errorf("ErrCodeParse = %d, want -32700", ErrCodeParse)
}
if ErrCodeNotFound != -32601 {
t.Errorf("ErrCodeNotFound = %d, want -32601", ErrCodeNotFound)
}
if ErrCodeInvalidArg != -32602 {
t.Errorf("ErrCodeInvalidArg = %d, want -32602", ErrCodeInvalidArg)
}
if ErrCodeInternal != -32603 {
t.Errorf("ErrCodeInternal = %d, want -32603", ErrCodeInternal)
}
}

func TestProtocolVersionValue(t *testing.T) {
if ProtocolVersion != "2024-11-05" {
t.Errorf("ProtocolVersion = %q, want %q", ProtocolVersion, "2024-11-05")
}
}

func TestRequestNilParams(t *testing.T) {
orig := Request{
JSONRPC: "2.0",
ID: json.RawMessage(`"abc"`),
Method: "ping",
}
data, err := json.Marshal(orig)
if err != nil {
t.Fatalf("marshal: %v", err)
}
var got Request
if err := json.Unmarshal(data, &got); err != nil {
t.Fatalf("unmarshal: %v", err)
}
if got.Params != nil {
t.Errorf("expected nil Params, got %s", got.Params)
}
}

func TestResponseNilID(t *testing.T) {
orig := Response{
JSONRPC: "2.0",
Error: &RPCError{Code: ErrCodeParse, Message: "parse error"},
}
var got Response
roundTrip(t, orig, &got)
if got.ID != nil {
t.Errorf("expected nil ID, got %s", got.ID)
}
}

func TestPropertyEnumRoundTrip(t *testing.T) {
orig := Property{
Type: "string",
Enum: []string{"task", "decision", "learning"},
}
var got Property
roundTrip(t, orig, &got)
if len(got.Enum) != 3 {
t.Fatalf("Enum count = %d, want 3", len(got.Enum))
}
if got.Enum[0] != "task" {
t.Errorf("Enum[0] = %q, want %q", got.Enum[0], "task")
}
}
Loading
Loading