// Contents of pkg/tools/repair.go (package tools). The code below relies on
// the standard-library packages encoding/json, reflect and strings.

// repairKind identifies which shape repair was applied to a single field.
// Repairs are kept narrow and named so per-(model, tool) telemetry can be
// aggregated and so unintended repairs are obvious in logs.
type repairKind string

const (
	// repairDropNull removes a field whose value is JSON null when the field
	// type is one Go's json package would otherwise leave as a zero value.
	// In Go this is rarely needed (json.Unmarshal accepts null for slices,
	// pointers, maps, and interfaces and treats it as a no-op for primitive
	// scalars) but it stays here for symmetry with the framing in
	// https://x.com — primarily as a safety net for fields whose custom
	// UnmarshalJSON may otherwise reject null.
	repairDropNull repairKind = "drop_null"

	// repairUnwrapStringArray turns a JSON-encoded array delivered as a
	// string into a real array. Models routinely send
	//   "paths": "[\"a\",\"b\"]"
	// instead of
	//   "paths": ["a","b"]
	// The repair parses the string value as JSON; if it is an array we
	// substitute the array. This must run BEFORE the bare-string wrap,
	// otherwise '["a","b"]' (a literal stringified array) gets wrapped as
	// ['["a","b"]'].
	repairUnwrapStringArray repairKind = "unwrap_string_array"

	// repairWrapObjectInArray turns a single-key object placeholder into a
	// one-element array. Models sometimes emit
	//   "paths": {"path": "foo.txt"}
	// when the schema expects ["foo.txt"]. We only fire this when the object
	// has exactly one entry whose value matches the slice's element kind, to
	// keep the repair narrow.
	repairWrapObjectInArray repairKind = "wrap_object_in_array"

	// repairWrapInArray wraps a bare scalar in a one-element array when the
	// schema expects an array of that scalar's kind. Catches the common
	//   "paths": "foo.txt" → ["foo.txt"]
	// failure mode.
	repairWrapInArray repairKind = "wrap_in_array"
)

// tryRepairToolArgs attempts the four shape repairs at the top level of a
// tool argument payload. It walks the destination struct's reflect.Type and
// looks for shape mismatches between each typed field and the corresponding
// raw JSON value, applying a small set of targeted fixes.
//
// The design follows validate-then-repair: callers MUST first try a strict
// json.Unmarshal and only invoke this function when that strict parse
// failed. The schema (the destination type) is the prior, and we only spend
// repair budget at the exact field paths the schema disagreed at.
//
// Fields that are not repaired are carried through as json.RawMessage, and
// inspected values are decoded with json.Number, so numeric literals are
// never rounded through float64 on their way back out (integers above 2^53
// survive unchanged).
//
// Returns the repaired JSON bytes, the list of repairs applied (for
// telemetry, in struct-field order), and a boolean indicating whether any
// repair was applied. When the second value is empty / third is false the
// caller should surface the original validation error unchanged.
func tryRepairToolArgs(data []byte, paramsType reflect.Type) ([]byte, []repairKind, bool) {
	if paramsType == nil {
		return nil, nil, false
	}
	for paramsType.Kind() == reflect.Ptr {
		paramsType = paramsType.Elem()
	}
	if paramsType.Kind() != reflect.Struct {
		return nil, nil, false
	}

	// Keep every member as its original bytes; only the members we actually
	// repair are re-encoded.
	var raw map[string]json.RawMessage
	if err := json.Unmarshal(data, &raw); err != nil {
		// The payload isn't even a JSON object at the top level. Shape
		// repairs operate on object fields, so we have nothing to do.
		return nil, nil, false
	}

	// reflect.VisibleFields walks promoted fields from embedded structs,
	// matching how encoding/json marshals struct values into JSON objects:
	// fields lifted by embedding share the same top-level object, so they
	// must be inspected in the same `raw` map. Iterating
	// `paramsType.NumField()` directly would silently skip them — relevant
	// today for ReferencesArgs/RenameArgs/CallHierarchyArgs/TypeHierarchyArgs
	// which all embed PositionArgs.
	var repairs []repairKind
	for _, field := range reflect.VisibleFields(paramsType) {
		if field.Anonymous {
			// The embedding itself is a marker; its promoted children are
			// emitted as separate visible fields by VisibleFields.
			continue
		}
		name, ok := jsonFieldName(field)
		if !ok {
			continue
		}
		encoded, present := raw[name]
		if !present {
			continue
		}
		val, ok := decodeJSONValue(encoded)
		if !ok {
			continue
		}

		newVal, kind, repaired := repairFieldValue(val, field.Type)
		if !repaired {
			continue
		}
		if kind == repairDropNull {
			delete(raw, name)
		} else {
			fixed, err := json.Marshal(newVal)
			if err != nil {
				return nil, nil, false
			}
			raw[name] = fixed
		}
		repairs = append(repairs, kind)
	}

	if len(repairs) == 0 {
		return nil, nil, false
	}

	out, err := json.Marshal(raw)
	if err != nil {
		return nil, nil, false
	}
	return out, repairs, true
}

// decodeJSONValue decodes exactly one JSON value into its generic Go form,
// representing numbers as json.Number so their literal text is preserved.
// It reports false when encoded is not a single, complete JSON value;
// json.Valid is consulted first because a json.Decoder on its own would
// accept trailing data after the first value.
func decodeJSONValue(encoded []byte) (any, bool) {
	if !json.Valid(encoded) {
		return nil, false
	}
	dec := json.NewDecoder(strings.NewReader(string(encoded)))
	dec.UseNumber()
	var v any
	if err := dec.Decode(&v); err != nil {
		return nil, false
	}
	return v, true
}

// jsonFieldName returns the JSON object key that a struct field marshals to,
// or false if the field is unexported / explicitly skipped via `json:"-"`.
func jsonFieldName(field reflect.StructField) (string, bool) {
	if !field.IsExported() {
		return "", false
	}
	tag := field.Tag.Get("json")
	if tag == "-" {
		return "", false
	}
	// Only the part before the first comma names the key; an empty name
	// (no tag, or a tag such as ",omitempty") falls back to the Go name.
	name, _, _ := strings.Cut(tag, ",")
	if name == "" {
		return field.Name, true
	}
	return name, true
}

// repairFieldValue applies the four repairs to a single field. It returns
// the new value, the repair kind, and whether a repair fired. The function
// is intentionally conservative — when in doubt, return false and let the
// original validation error surface.
//
// val is the generically decoded JSON value (numbers may be float64 or
// json.Number); fieldType is the destination field's type, with any pointer
// indirection stripped before inspection.
func repairFieldValue(val any, fieldType reflect.Type) (any, repairKind, bool) {
	for fieldType.Kind() == reflect.Ptr {
		fieldType = fieldType.Elem()
	}

	// Repair 1: null-valued primitives. Slices/maps/pointers/interfaces
	// already accept null in Go's json package, so we only nudge for scalar
	// kinds where a custom UnmarshalJSON might object.
	if val == nil {
		if isScalarKind(fieldType.Kind()) {
			return nil, repairDropNull, true
		}
		return nil, "", false
	}

	// Remaining repairs all target slice fields. Anything else is left
	// alone — we explicitly do not generalise to maps or nested structs in
	// this layer. Recursion would expand the blast radius without evidence
	// it is needed.
	if fieldType.Kind() != reflect.Slice {
		return nil, "", false
	}

	elemType := fieldType.Elem()
	for elemType.Kind() == reflect.Ptr {
		elemType = elemType.Elem()
	}
	elemKind := elemType.Kind()

	switch v := val.(type) {
	case []any:
		// The value already arrived as an array, so there is nothing to fix
		// at this level — even if individual elements are wrong, the schema
		// can surface that error itself.
		return nil, "", false

	case string:
		// Repair 2: stringified JSON array. Try this BEFORE the bare-string
		// wrap, otherwise a stringified array would be wrapped as a single
		// element and we would silently corrupt the input.
		trimmed := strings.TrimSpace(v)
		if strings.HasPrefix(trimmed, "[") && strings.HasSuffix(trimmed, "]") {
			if decoded, ok := decodeJSONValue([]byte(trimmed)); ok {
				if arr, isArray := decoded.([]any); isArray {
					return arr, repairUnwrapStringArray, true
				}
			}
		}
		// Repair 4: bare string where the schema expects an array of
		// strings. The element kind must itself be string: wrapping a
		// string for a numeric or bool slice can never parse, and a string
		// aimed at a []byte field is base64 data that is valid as-is.
		if elemKind == reflect.String {
			return []any{v}, repairWrapInArray, true
		}
		return nil, "", false

	case map[string]any:
		// Repair 3: object placeholder. Models sometimes emit
		//   {"paths": {"path": "foo"}} → ["foo"]
		// We accept exactly the narrow case of a single-entry object whose
		// value is a scalar matching the slice's element kind.
		if len(v) != 1 {
			return nil, "", false
		}
		for _, inner := range v {
			if scalarMatchesKind(inner, elemKind) {
				return []any{inner}, repairWrapObjectInArray, true
			}
		}
		return nil, "", false
	}

	// Bare scalar of a different type (number, bool) where an array is
	// expected. Wrap when element kinds line up.
	if scalarMatchesKind(val, elemKind) {
		return []any{val}, repairWrapInArray, true
	}

	return nil, "", false
}

// scalarMatchesKind reports whether v, a generically decoded JSON scalar, is
// compatible with the scalar reflect.Kind k. Unlike matchesScalarKind it
// also recognises json.Number, which is how numbers arrive when decoding
// with UseNumber; any numeric kind accepts either number representation.
func scalarMatchesKind(v any, k reflect.Kind) bool {
	switch v.(type) {
	case string:
		return k == reflect.String
	case bool:
		return k == reflect.Bool
	case json.Number, float64:
		return k != reflect.String && k != reflect.Bool && isScalarKind(k)
	}
	return false
}

// isScalarKind reports whether k is one of the primitive JSON-compatible
// reflect kinds the repair layer is willing to wrap into a one-element array.
func isScalarKind(k reflect.Kind) bool {
	switch k {
	case reflect.String, reflect.Bool,
		reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
		reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64,
		reflect.Float32, reflect.Float64:
		return true
	}
	return false
}

// matchesScalarKind reports whether v (a value parsed from JSON via
// map[string]any) is compatible with the given scalar reflect.Kind. JSON
// numbers always parse to float64, so any numeric kind matches a float64.
+func matchesScalarKind(v any, k reflect.Kind) bool { + switch v.(type) { + case string: + return k == reflect.String + case bool: + return k == reflect.Bool + case float64: + switch k { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, + reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, + reflect.Float32, reflect.Float64: + return true + } + } + return false +} diff --git a/pkg/tools/repair_test.go b/pkg/tools/repair_test.go new file mode 100644 index 000000000..cdec770f3 --- /dev/null +++ b/pkg/tools/repair_test.go @@ -0,0 +1,201 @@ +package tools + +import ( + "context" + "reflect" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// argsWithStrings exercises the slice-of-string repair path that's by far +// the most commonly broken in real LLM tool calls (paths, urls, patterns). +type argsWithStrings struct { + Paths []string `json:"paths"` + JSON bool `json:"json,omitempty"` +} + +type argsWithInt struct { + N int `json:"n"` + Tags []string `json:"tags,omitempty"` +} + +func TestRepair_UnwrapsStringifiedArray(t *testing.T) { + // Common DeepSeek/Qwen mistake: send an array as a JSON string. + in := []byte(`{"paths": "[\"a.txt\",\"b.txt\"]"}`) + out, kinds, ok := tryRepairToolArgs(in, reflect.TypeFor[argsWithStrings]()) + require.True(t, ok) + assert.Equal(t, []repairKind{repairUnwrapStringArray}, kinds) + assert.JSONEq(t, `{"paths":["a.txt","b.txt"]}`, string(out)) +} + +func TestRepair_WrapsBareString(t *testing.T) { + // Single-string-instead-of-array, the most common shape mistake. + in := []byte(`{"paths": "only.txt"}`) + out, kinds, ok := tryRepairToolArgs(in, reflect.TypeFor[argsWithStrings]()) + require.True(t, ok) + assert.Equal(t, []repairKind{repairWrapInArray}, kinds) + assert.JSONEq(t, `{"paths":["only.txt"]}`, string(out)) +} + +func TestRepair_WrapsSingleObjectPlaceholder(t *testing.T) { + // Some models wrap a single argument in an object. 
+ in := []byte(`{"paths": {"path": "only.txt"}}`) + out, kinds, ok := tryRepairToolArgs(in, reflect.TypeFor[argsWithStrings]()) + require.True(t, ok) + assert.Equal(t, []repairKind{repairWrapObjectInArray}, kinds) + assert.JSONEq(t, `{"paths":["only.txt"]}`, string(out)) +} + +func TestRepair_OrderingPreventsDoubleWrap(t *testing.T) { + // If the bare-string-wrap fired before the unwrap-stringified-array + // repair, this input would become [`["a","b"]`] instead of ["a","b"]. + // The fact that we get a clean array out is the load-bearing assertion + // of this test. + in := []byte(`{"paths": "[\"a\",\"b\"]"}`) + out, kinds, ok := tryRepairToolArgs(in, reflect.TypeFor[argsWithStrings]()) + require.True(t, ok) + assert.Equal(t, []repairKind{repairUnwrapStringArray}, kinds) + assert.JSONEq(t, `{"paths":["a","b"]}`, string(out)) +} + +func TestRepair_DropsNullForPrimitive(t *testing.T) { + // Some custom UnmarshalJSON impls trip on null where a primitive is + // expected. Dropping the field lets the type's zero value win. + in := []byte(`{"n": null}`) + out, kinds, ok := tryRepairToolArgs(in, reflect.TypeFor[argsWithInt]()) + require.True(t, ok) + assert.Equal(t, []repairKind{repairDropNull}, kinds) + assert.JSONEq(t, `{}`, string(out)) +} + +func TestRepair_LeavesValidArrayUntouched(t *testing.T) { + // The repair entry point should only be reached after a strict parse + // already failed, but defensively ensure that a well-formed array is + // not "repaired" if we ever do get called with one. + in := []byte(`{"paths": ["a","b"]}`) + _, _, ok := tryRepairToolArgs(in, reflect.TypeFor[argsWithStrings]()) + assert.False(t, ok) +} + +func TestRepair_LeavesUnknownFieldsAlone(t *testing.T) { + // Field not declared on the struct: out of repair scope. 
+ in := []byte(`{"unknown": "foo"}`) + _, _, ok := tryRepairToolArgs(in, reflect.TypeFor[argsWithStrings]()) + assert.False(t, ok) +} + +func TestRepair_ReturnsFalseOnNonObjectInput(t *testing.T) { + // Top-level non-object payloads are unparseable as field-shape errors. + in := []byte(`"just a string"`) + _, _, ok := tryRepairToolArgs(in, reflect.TypeFor[argsWithStrings]()) + assert.False(t, ok) +} + +func TestRepair_RefusesMultiKeyObjectAsArray(t *testing.T) { + // Two keys in the placeholder object — too ambiguous to safely wrap. + in := []byte(`{"paths": {"path": "a.txt", "extra": "ignore"}}`) + _, _, ok := tryRepairToolArgs(in, reflect.TypeFor[argsWithStrings]()) + assert.False(t, ok) +} + +// TestRepair_VisitsPromotedFieldsFromEmbeddedStruct mirrors the shape of +// LSP arg structs in pkg/tools/builtin/lsp.go (ReferencesArgs and friends +// embed PositionArgs). reflect.VisibleFields is what makes promoted fields +// visit-able; iterating NumField/Field directly would silently skip them. 
+func TestRepair_VisitsPromotedFieldsFromEmbeddedStruct(t *testing.T) { + type Base struct { + Files []string `json:"files"` + } + type WithEmbedding struct { + Base + + Extra string `json:"extra,omitempty"` + } + in := []byte(`{"files":"only.txt"}`) + out, kinds, ok := tryRepairToolArgs(in, reflect.TypeFor[WithEmbedding]()) + require.True(t, ok) + assert.Equal(t, []repairKind{repairWrapInArray}, kinds) + assert.JSONEq(t, `{"files":["only.txt"]}`, string(out)) +} + +func TestRepair_RepairsMultipleFieldsInOneCall(t *testing.T) { + type combo struct { + Paths []string `json:"paths"` + Tags []string `json:"tags"` + } + in := []byte(`{"paths":"only.txt","tags":"[\"go\",\"ai\"]"}`) + out, kinds, ok := tryRepairToolArgs(in, reflect.TypeFor[combo]()) + require.True(t, ok) + assert.Len(t, kinds, 2) + assert.JSONEq(t, `{"paths":["only.txt"],"tags":["go","ai"]}`, string(out)) +} + +// End-to-end test that exercises the NewHandler integration: invalid input +// reshaped by repair, handler called with the typed value. 
+func TestNewHandler_RepairsBareStringArray(t *testing.T) { + type fileArgs struct { + Paths []string `json:"paths"` + } + var got fileArgs + handler := NewHandler(func(_ context.Context, args fileArgs) (*ToolCallResult, error) { + got = args + return ResultSuccess("ok"), nil + }) + + result, err := handler(t.Context(), ToolCall{ + ID: "call_1", + Type: "function", + Function: FunctionCall{ + Name: "read_multiple_files", + Arguments: `{"paths":"only.txt"}`, + }, + }) + require.NoError(t, err) + assert.Equal(t, "ok", result.Output) + assert.Equal(t, []string{"only.txt"}, got.Paths) +} + +func TestNewHandler_RepairsStringifiedArray(t *testing.T) { + type fileArgs struct { + Paths []string `json:"paths"` + } + var got fileArgs + handler := NewHandler(func(_ context.Context, args fileArgs) (*ToolCallResult, error) { + got = args + return ResultSuccess("ok"), nil + }) + + result, err := handler(t.Context(), ToolCall{ + ID: "call_1", + Type: "function", + Function: FunctionCall{ + Name: "read_multiple_files", + Arguments: `{"paths":"[\"a.txt\",\"b.txt\"]"}`, + }, + }) + require.NoError(t, err) + assert.Equal(t, "ok", result.Output) + assert.Equal(t, []string{"a.txt", "b.txt"}, got.Paths) +} + +func TestNewHandler_UnrepairableInputReturnsOriginalError(t *testing.T) { + type fileArgs struct { + Paths []string `json:"paths"` + } + handler := NewHandler(func(_ context.Context, _ fileArgs) (*ToolCallResult, error) { + t.Fatal("handler should not be called for unrepairable input") + return nil, nil + }) + + _, err := handler(t.Context(), ToolCall{ + ID: "call_1", + Type: "function", + Function: FunctionCall{ + Name: "read_multiple_files", + Arguments: `{not even json`, + }, + }) + require.Error(t, err) +} diff --git a/pkg/tools/tools.go b/pkg/tools/tools.go index cfb6fb98a..4bcefac6d 100644 --- a/pkg/tools/tools.go +++ b/pkg/tools/tools.go @@ -3,6 +3,8 @@ package tools import ( "context" "encoding/json" + "log/slog" + "reflect" "github.com/modelcontextprotocol/go-sdk/mcp" ) @@ 
-12,8 +14,14 @@ type ToolSet interface { Tools(ctx context.Context) ([]Tool, error) } -// NewHandler creates a type-safe tool handler from a function that accepts typed parameters. -// It handles JSON unmarshaling of the tool call arguments into the specified type T. +// NewHandler creates a type-safe tool handler from a function that accepts +// typed parameters. It first runs a strict json.Unmarshal into T; on success +// the typed function is called with zero overhead. On failure the handler +// invokes the input-shape repair layer (see repair.go) which targets the +// four common LLM mistakes: null-for-required, JSON-stringified array, single +// object placeholder where an array is expected, and bare scalar where an +// array is expected. Repaired calls emit a tool_input_repaired log entry so +// per-(model, tool) repair rates can be tracked. func NewHandler[T any](fn func(context.Context, T) (*ToolCallResult, error)) ToolHandler { return func(ctx context.Context, toolCall ToolCall) (*ToolCallResult, error) { var params T @@ -21,10 +29,31 @@ func NewHandler[T any](fn func(context.Context, T) (*ToolCallResult, error)) Too if args == "" { args = "{}" } - if err := json.Unmarshal([]byte(args), ¶ms); err != nil { + + err := json.Unmarshal([]byte(args), ¶ms) + if err == nil { + return fn(ctx, params) + } + + // Strict parse failed. Try the four shape repairs at the field + // paths the schema disagreed at, then re-parse. Valid inputs are + // never reached by this code path so well-formed calls pay nothing. + repaired, kinds, ok := tryRepairToolArgs([]byte(args), reflect.TypeFor[T]()) + if !ok { + return nil, err + } + var retry T + if rerr := json.Unmarshal(repaired, &retry); rerr != nil { + // Repair did not produce a parseable payload. Surface the + // original error so the model sees the schema's complaint, not + // the repair-layer's complaint about a synthesised payload. 
return nil, err } - return fn(ctx, params) + slog.Info("tool_input_repaired", + "tool", toolCall.Function.Name, + "repairs", kinds, + ) + return fn(ctx, retry) } }