From ed34203b4341d9dd07f4566aca023d9f6237a560 Mon Sep 17 00:00:00 2001 From: David Gageot Date: Thu, 7 May 2026 17:43:42 +0200 Subject: [PATCH 1/3] make max_iterations builtin stateless The runtime now passes the loop iteration via hooks.Input.Iteration, so the builtin no longer needs a per-session map, mutex, or session_end cleanup. Fixes #2698. --- docs/configuration/hooks/index.md | 4 +- pkg/hooks/builtins/builtins.go | 34 ++-- pkg/hooks/builtins/max_iterations.go | 38 +---- pkg/hooks/builtins/max_iterations_test.go | 81 +++------- pkg/hooks/builtins/redact_secrets_test.go | 3 +- pkg/hooks/builtins/snapshot_test.go | 1 - pkg/hooks/types.go | 8 + pkg/runtime/before_llm_call_test.go | 186 +++++++++++++++++++++- pkg/runtime/hooks.go | 21 ++- pkg/runtime/loop.go | 5 +- pkg/runtime/runtime.go | 9 +- 11 files changed, 250 insertions(+), 140 deletions(-) diff --git a/docs/configuration/hooks/index.md b/docs/configuration/hooks/index.md index b58ad2022..155202fb5 100644 --- a/docs/configuration/hooks/index.md +++ b/docs/configuration/hooks/index.md @@ -153,7 +153,7 @@ Built-ins are typically zero-config and faster than equivalent shell hooks becau | `add_directory_listing` | `session_start` | _none_ | Adds an alphabetical listing of the cwd's top-level entries (skips dot-files, capped at 100 with a "... and N more"). | | `add_user_info` | `session_start` | _none_ | Adds the current OS user (username and full name) and the hostname. | | `add_recent_commits` | `session_start` | _none_, or `[""]` | Adds `git log --oneline -n N`. `N` defaults to 10; pass a positive integer to override. | -| `max_iterations` | `before_llm_call` | `[""]` (required) | Hard-stops the agent after `N` model calls. State is per-session and reset at `session_end`. | +| `max_iterations` | `before_llm_call` | `[""]` (required) | Hard-stops the agent after `N` model calls. Stateless: the runtime supplies the iteration counter on every dispatch. | | `snapshot` | `session_start`, `turn_start`, `turn_end`, `pre_tool_use`, `post_tool_use`, `session_end` | _none_ | Records filesystem snapshots in a shadow git repo under the docker-agent data directory. No-op outside git repos; respects the source repo's ignore rules and skips newly-added files larger than 2 MiB. | | `redact_secrets` | `pre_tool_use`, `before_llm_call`, `tool_response_transform` | _none_ | Scrubs detected secrets (API keys, tokens, private keys, …) out of tool call arguments, outgoing chat content, and tool output. The same builtin handles all three events and dispatches on the event name. Auto-registered on all three events by `redact_secrets: true` on the agent — see [`examples/redact_secrets_hooks.yaml`](https://github.com/docker/docker-agent/blob/main/examples/redact_secrets_hooks.yaml) for the manual wiring. | | `unload` | `on_agent_switch` | _none_ | Walks the previous agent's models and calls `Unload()` on every provider that implements [`provider.Unloader`](https://pkg.go.dev/github.com/docker/docker-agent/pkg/model/provider#Unloader) — typically Docker Model Runner — to free the GPU/RAM the just-departing model was holding. Cloud-only providers don't implement the interface and are silently skipped. Errors are logged and swallowed; agent switching never blocks on a slow or unreachable engine (each Unload call has a 10 s timeout). See [`examples/unload_on_switch.yaml`](https://github.com/docker/docker-agent/blob/main/examples/unload_on_switch.yaml). | @@ -255,7 +255,7 @@ In addition to the common fields, each event ships its own payload: | `user_prompt_submit` | `prompt` — the text the user just submitted | | `turn_start` | _none_ (just the common fields) | | `turn_end` | `agent_name`, `reason` — one of `normal`, `continue`, `steered`, `error`, `canceled`, `hook_blocked`, `loop_detected` | -| `before_llm_call` | _none_ | +| `before_llm_call` | `iteration` — 1-based run-loop iteration counter (the model call this hook is gating) | | `after_llm_call` | `agent_name`, `stop_response`, `last_user_message` | | `session_end` | `reason` — one of `clear`, `logout`, `prompt_input_exit`, `other` | | `pre_compact` | `source` — one of `manual`, `auto`, `overflow`, `tool_overflow` | diff --git a/pkg/hooks/builtins/builtins.go b/pkg/hooks/builtins/builtins.go index cdfa98705..eac16c1bc 100644 --- a/pkg/hooks/builtins/builtins.go +++ b/pkg/hooks/builtins/builtins.go @@ -36,13 +36,11 @@ // // turn_start builtins recompute every turn (date, git state). // session_start builtins run once per session for context that's -// stable for its duration. max_iterations is stateful: its -// per-session counter lives on the [State] returned by [Register]; -// the runtime clears it via [State.ClearSession] from session_end. -// snapshot is also stateful: it keeps per-session turn/tool snapshot -// hashes and undo checkpoints in memory while the shadow git objects live -// under the data directory. Undo checkpoints intentionally survive the -// RunStream session_end cleanup so /undo can run after the response stops. +// stable for its duration. snapshot is stateful: it keeps per-session +// turn/tool snapshot hashes and undo checkpoints in memory while the +// shadow git objects live under the data directory. Undo checkpoints +// intentionally survive the RunStream session_end cleanup so /undo +// can run after the response stops. // // LLM-as-a-judge hooks are NOT shipped here: write `type: model` with // `schema: pre_tool_use_decision` instead — see @@ -57,20 +55,11 @@ import ( ) // State holds the per-runtime state of the stateful builtins. -// It is returned by [Register] so callers can clear per-session -// entries on session_end. Stateless builtins don't appear here. +// It is returned by [Register] so callers can reach into +// snapshot operations (undo / list / reset) without poking at +// builtin internals. Stateless builtins don't appear here. type State struct { - maxIterations *maxIterationsBuiltin - snapshot *snapshotBuiltin -} - -// ClearSession drops per-RunStream state that should not survive stream teardown. -// A nil receiver is a no-op. -func (s *State) ClearSession(sessionID string) { - if s == nil || sessionID == "" { - return - } - s.maxIterations.clearSession(sessionID) + snapshot *snapshotBuiltin } // UndoLastSnapshot restores files from the latest completed snapshot checkpoint. @@ -104,8 +93,7 @@ func (s *State) ResetSnapshot(ctx context.Context, sessionID, cwd string, keep i // handle the caller can use for stateful builtin operations. func Register(r *hooks.Registry) (*State, error) { state := &State{ - maxIterations: newMaxIterations(), - snapshot: newSnapshotBuiltin(), + snapshot: newSnapshotBuiltin(), } if err := errors.Join( r.RegisterBuiltin(AddDate, addDate), @@ -116,7 +104,7 @@ func Register(r *hooks.Registry) (*State, error) { r.RegisterBuiltin(AddDirectoryListing, addDirectoryListing), r.RegisterBuiltin(AddUserInfo, addUserInfo), r.RegisterBuiltin(AddRecentCommits, addRecentCommits), - r.RegisterBuiltin(MaxIterations, state.maxIterations.hook), + r.RegisterBuiltin(MaxIterations, maxIterations), r.RegisterBuiltin(Snapshot, state.snapshot.hook), r.RegisterBuiltin(RedactSecrets, redactSecrets), r.RegisterBuiltin(HTTPPost, httpPost), diff --git a/pkg/hooks/builtins/max_iterations.go b/pkg/hooks/builtins/max_iterations.go index c01997cb0..8e80ccbf9 100644 --- a/pkg/hooks/builtins/max_iterations.go +++ b/pkg/hooks/builtins/max_iterations.go @@ -5,7 +5,6 @@ import ( "fmt" "log/slog" "strconv" - "sync" "github.com/docker/docker-agent/pkg/hooks" ) @@ -13,9 +12,8 @@ import ( // MaxIterations is the registered name of the max_iterations builtin. const MaxIterations = "max_iterations" -// maxIterationsBuiltin counts before_llm_call invocations per session -// and signals a terminating verdict once the configured limit is -// exceeded. +// maxIterations signals a terminating verdict once the runtime has +// dispatched more before_llm_call events than the configured budget. // // This is a hard stop with no resume protocol — distinct from the // agent.MaxIterations flag, which has its own special UX @@ -25,19 +23,10 @@ const MaxIterations = "max_iterations" // // Args layout: `[limit]`. Missing or invalid args make the hook a // no-op so a misconfigured YAML doesn't accidentally cap a run at -// zero. State is per-session, keyed by [hooks.Input.SessionID], and -// cleared from session_end via [State.ClearSession]. -type maxIterationsBuiltin struct { - mu sync.Mutex - counts map[string]int // SessionID -> calls observed -} - -func newMaxIterations() *maxIterationsBuiltin { - return &maxIterationsBuiltin{counts: map[string]int{}} -} - -func (b *maxIterationsBuiltin) hook(_ context.Context, in *hooks.Input, args []string) (*hooks.Output, error) { - if in == nil || in.SessionID == "" || len(args) == 0 { +// zero. Stateless: the runtime supplies the 1-based iteration counter +// via [hooks.Input.Iteration]. +func maxIterations(_ context.Context, in *hooks.Input, args []string) (*hooks.Output, error) { + if in == nil || in.Iteration <= 0 || len(args) == 0 { return nil, nil } limit, err := strconv.Atoi(args[0]) @@ -46,17 +35,12 @@ func (b *maxIterationsBuiltin) hook(_ context.Context, in *hooks.Input, args []s return nil, nil } - b.mu.Lock() - b.counts[in.SessionID]++ - count := b.counts[in.SessionID] - b.mu.Unlock() - - if count <= limit { + if in.Iteration <= limit { return nil, nil } slog.Warn("max_iterations tripped", - "count", count, "limit", limit, "session_id", in.SessionID) + "iteration", in.Iteration, "limit", limit, "session_id", in.SessionID) return &hooks.Output{ Decision: hooks.DecisionBlockValue, @@ -65,9 +49,3 @@ func (b *maxIterationsBuiltin) hook(_ context.Context, in *hooks.Input, args []s limit), }, nil } - -func (b *maxIterationsBuiltin) clearSession(sessionID string) { - b.mu.Lock() - delete(b.counts, sessionID) - b.mu.Unlock() -} diff --git a/pkg/hooks/builtins/max_iterations_test.go b/pkg/hooks/builtins/max_iterations_test.go index ca9d82a1f..fac064e95 100644 --- a/pkg/hooks/builtins/max_iterations_test.go +++ b/pkg/hooks/builtins/max_iterations_test.go @@ -1,7 +1,6 @@ package builtins_test import ( - "sync" "testing" "github.com/stretchr/testify/assert" @@ -12,54 +11,29 @@ import ( ) // TestMaxIterationsTripsAfterLimit verifies the happy path: with a -// limit of 3, the first three calls are no-ops and the fourth returns -// a block decision. The reason carries the configured limit so the -// runtime's user-facing Error event explains why the run stopped. +// limit of 3, the first three iterations are no-ops and the fourth +// returns a block decision. The reason carries the configured limit +// so the runtime's user-facing Error event explains why the run +// stopped. func TestMaxIterationsTripsAfterLimit(t *testing.T) { t.Parallel() fn := lookup(t, builtins.MaxIterations) - in := &hooks.Input{SessionID: "s1"} args := []string{"3"} for i := 1; i <= 3; i++ { - out, err := fn(t.Context(), in, args) - require.NoErrorf(t, err, "call %d must not error", i) - require.Nilf(t, out, "call %d (within limit) must not trip", i) + out, err := fn(t.Context(), &hooks.Input{Iteration: i}, args) + require.NoErrorf(t, err, "iteration %d must not error", i) + require.Nilf(t, out, "iteration %d (within limit) must not trip", i) } - out, err := fn(t.Context(), in, args) + out, err := fn(t.Context(), &hooks.Input{Iteration: 4}, args) require.NoError(t, err) - require.NotNil(t, out, "fourth call (over limit) must trip") + require.NotNil(t, out, "iteration 4 (over limit) must trip") assert.Equal(t, hooks.DecisionBlockValue, out.Decision) assert.Contains(t, out.Reason, "3", "reason must include the configured limit") } -// TestMaxIterationsIsolatesSessions documents the per-session -// counter contract: a runtime serving multiple sessions must not let -// session A's calls count against session B's budget. -func TestMaxIterationsIsolatesSessions(t *testing.T) { - t.Parallel() - - fn := lookup(t, builtins.MaxIterations) - args := []string{"2"} - - // Session A: two no-op calls then one trip. - for range 2 { - out, err := fn(t.Context(), &hooks.Input{SessionID: "A"}, args) - require.NoError(t, err) - require.Nil(t, out) - } - out, err := fn(t.Context(), &hooks.Input{SessionID: "A"}, args) - require.NoError(t, err) - require.NotNil(t, out, "session A trips on its third call") - - // Session B: starts fresh, only sees one call so far. - out, err = fn(t.Context(), &hooks.Input{SessionID: "B"}, args) - require.NoError(t, err) - require.Nil(t, out, "session B's counter must not include session A's calls") -} - // TestMaxIterationsNoOpWithoutValidLimit documents the lenient-args // contract: a missing, non-integer, zero, or negative limit makes // the builtin a no-op rather than tripping (the safer default for a @@ -75,11 +49,11 @@ func TestMaxIterationsNoOpWithoutValidLimit(t *testing.T) { {"-1"}, } for _, args := range cases { - fn := lookup(t, builtins.MaxIterations) // fresh state per case - // Drive 50 calls — if the builtin were tripping erroneously, + fn := lookup(t, builtins.MaxIterations) + // Drive 50 iterations — if the builtin were tripping erroneously, // at least one of these would return a non-nil Output. - for range 50 { - out, err := fn(t.Context(), &hooks.Input{SessionID: "s"}, args) + for i := 1; i <= 50; i++ { + out, err := fn(t.Context(), &hooks.Input{Iteration: i}, args) require.NoError(t, err) require.Nilf(t, out, "args=%v: must never trip", args) } @@ -87,9 +61,9 @@ func TestMaxIterationsNoOpWithoutValidLimit(t *testing.T) { } // TestMaxIterationsIgnoresIncompleteInput pins the defensive guard: -// missing SessionID produces no state mutation and no output. This -// protects against future dispatch changes where an edge case might -// fire before_llm_call without that field populated. +// missing or non-positive Iteration produces no output. This protects +// against future dispatch changes where an edge case might fire +// before_llm_call without that field populated. func TestMaxIterationsIgnoresIncompleteInput(t *testing.T) { t.Parallel() @@ -99,28 +73,9 @@ func TestMaxIterationsIgnoresIncompleteInput(t *testing.T) { require.NoError(t, err) assert.Nil(t, out) + // Iteration=0 (zero value) means "not populated" — the runtime + // always supplies a 1-based counter on before_llm_call. out, err = fn(t.Context(), &hooks.Input{}, []string{"1"}) require.NoError(t, err) assert.Nil(t, out) } - -// TestMaxIterationsConcurrentCallsAreSafe is a smoke test for the -// builtin's mutex. Many goroutines incrementing the same session's -// counter must not race (run with -race). -func TestMaxIterationsConcurrentCallsAreSafe(t *testing.T) { - t.Parallel() - - fn := lookup(t, builtins.MaxIterations) - in := &hooks.Input{SessionID: "concurrent"} - - const callers = 50 - var wg sync.WaitGroup - wg.Add(callers) - for range callers { - go func() { - defer wg.Done() - _, _ = fn(t.Context(), in, []string{"100"}) - }() - } - wg.Wait() -} diff --git a/pkg/hooks/builtins/redact_secrets_test.go b/pkg/hooks/builtins/redact_secrets_test.go index 94a9b07b2..f7888929d 100644 --- a/pkg/hooks/builtins/redact_secrets_test.go +++ b/pkg/hooks/builtins/redact_secrets_test.go @@ -128,9 +128,8 @@ func TestRedactSecretsIsRegistered(t *testing.T) { t.Parallel() reg := hooks.NewRegistry() - state, err := Register(reg) + _, err := Register(reg) require.NoError(t, err) - t.Cleanup(func() { state.ClearSession("") }) handler, ok := reg.LookupBuiltin(RedactSecrets) require.Truef(t, ok, "builtin %q must be registered", RedactSecrets) diff --git a/pkg/hooks/builtins/snapshot_test.go b/pkg/hooks/builtins/snapshot_test.go index d748e1172..b8217bee2 100644 --- a/pkg/hooks/builtins/snapshot_test.go +++ b/pkg/hooks/builtins/snapshot_test.go @@ -68,7 +68,6 @@ func TestSnapshotBuiltinUndoSurvivesStreamEnd(t *testing.T) { Reason: "stream_ended", }, nil) require.NoError(t, err) - state.ClearSession("s") entries, err := os.ReadDir(filepath.Join(paths.GetDataDir(), "snapshot")) require.NoError(t, err) diff --git a/pkg/hooks/types.go b/pkg/hooks/types.go index 474c0cf05..97ceb42ac 100644 --- a/pkg/hooks/types.go +++ b/pkg/hooks/types.go @@ -185,6 +185,14 @@ type Input struct { // model-call-scoped. ModelID string `json:"model_id,omitempty"` + // Iteration is the 1-based run-loop iteration counter for the + // model call this dispatch is gating. Populated for + // [EventBeforeLLMCall] (1 for the first call of the RunStream, 2 + // for the second, ...); zero for events not tied to a loop + // iteration. The max_iterations builtin compares it to a configured + // budget without per-session state. + Iteration int `json:"iteration,omitempty"` + // LastUserMessage is the text content of the latest user message in // the session at dispatch time. Populated for events that respond to // a user turn (stop, after_llm_call). Empty for events that aren't diff --git a/pkg/runtime/before_llm_call_test.go b/pkg/runtime/before_llm_call_test.go index 958950d1a..16220f321 100644 --- a/pkg/runtime/before_llm_call_test.go +++ b/pkg/runtime/before_llm_call_test.go @@ -2,6 +2,7 @@ package runtime import ( "context" + "strings" "sync/atomic" "testing" @@ -9,10 +10,13 @@ import ( "github.com/stretchr/testify/require" "github.com/docker/docker-agent/pkg/agent" + "github.com/docker/docker-agent/pkg/chat" "github.com/docker/docker-agent/pkg/config/latest" "github.com/docker/docker-agent/pkg/hooks" + "github.com/docker/docker-agent/pkg/hooks/builtins" "github.com/docker/docker-agent/pkg/session" "github.com/docker/docker-agent/pkg/team" + "github.com/docker/docker-agent/pkg/tools" ) // TestBeforeLLMCallHookFiresOncePerLoopIteration is a regression test @@ -21,11 +25,17 @@ import ( // bug would silently break stateful before_llm_call hooks (the // max_iterations builtin would have tripped at half its configured // limit). A single-turn session must observe exactly one fire. +// +// It also pins the [hooks.Input.Iteration] contract: the runtime +// surfaces a 1-based loop counter on every dispatch so stateless +// guards (like the max_iterations builtin) can compare it to a +// configured budget without keeping their own per-session counter. func TestBeforeLLMCallHookFiresOncePerLoopIteration(t *testing.T) { t.Parallel() const counterName = "test-before-llm-counter" var calls atomic.Int32 + var lastIteration atomic.Int32 stream := newStreamBuilder(). AddContent("Hello"). @@ -53,8 +63,9 @@ func TestBeforeLLMCallHookFiresOncePerLoopIteration(t *testing.T) { // so registering after NewLocalRuntime is sufficient. require.NoError(t, rt.hooksRegistry.RegisterBuiltin( counterName, - func(_ context.Context, _ *hooks.Input, _ []string) (*hooks.Output, error) { + func(_ context.Context, in *hooks.Input, _ []string) (*hooks.Output, error) { calls.Add(1) + lastIteration.Store(int32(in.Iteration)) return nil, nil }, )) @@ -68,4 +79,177 @@ func TestBeforeLLMCallHookFiresOncePerLoopIteration(t *testing.T) { assert.Equal(t, int32(1), calls.Load(), "before_llm_call must fire exactly once per loop iteration; "+ "a duplicate dispatch would silently break stateful hooks like max_iterations") + assert.Equal(t, int32(1), lastIteration.Load(), + "first model call of a RunStream must carry Iteration=1 "+ + "so the max_iterations builtin can compare it to its configured budget") +} + +// TestMaxIterationsBuiltin_TripsAfterConfiguredLimit is the real e2e +// test for the max_iterations builtin: it stands up an agent whose +// model issues a tool call on every iteration (so the loop never +// terminates on its own), wires the builtin in via +// before_llm_call, and asserts that the loop is hard-stopped at +// exactly the configured budget — not under it (nothing was lost +// by going stateless), not over it (the gate held). +// +// This is the regression test that pins issue #2698: the runtime +// surfaces [hooks.Input.Iteration] on every dispatch so the builtin +// can compare it to its limit without keeping a per-session counter +// or relying on session_end cleanup. +func TestMaxIterationsBuiltin_TripsAfterConfiguredLimit(t *testing.T) { + t.Parallel() + + const limit = 3 + + // A tool that always succeeds. Every iteration the model issues a + // call to it, which means the loop would run forever if not for the + // max_iterations gate. + loopTool := tools.Tool{ + Name: "do_thing", + Parameters: map[string]any{}, + Handler: func(_ context.Context, _ tools.ToolCall) (*tools.ToolCallResult, error) { + return tools.ResultSuccess("ok"), nil + }, + } + + // Build a fresh tool-call stream per iteration. We queue more + // streams than `limit` so a regression that lets the loop run an + // extra time still has a stream to consume — the test then catches + // the over-run via prov.callIdx. + newToolCallStream := func(callID string) *mockStream { + return newStreamBuilder(). + AddToolCallName(callID, loopTool.Name). + AddToolCallArguments(callID, `{}`). + AddToolCallStopWithUsage(2, 2). + Build() + } + prov := &queueProvider{ + id: "test/mock-model", + streams: []chat.MessageStream{ + newToolCallStream("call_1"), + newToolCallStream("call_2"), + newToolCallStream("call_3"), + // Extra streams so an off-by-one regression can be detected + // rather than masked by an empty queue defaulting to a stop. + newToolCallStream("call_4"), + newToolCallStream("call_5"), + }, + } + + root := agent.New("root", "test agent", + agent.WithModel(prov), + agent.WithToolSets(newStubToolSet(nil, []tools.Tool{loopTool}, nil)), + agent.WithHooks(&latest.HooksConfig{ + BeforeLLMCall: []latest.HookDefinition{ + {Type: "builtin", Command: builtins.MaxIterations, Args: []string{"3"}}, + }, + }), + ) + tm := team.New(team.WithAgents(root)) + + rt, err := NewLocalRuntime(tm, + WithSessionCompaction(false), + WithModelStore(mockModelStore{}), + ) + require.NoError(t, err) + rt.registerDefaultTools() + + sess := session.New( + session.WithUserMessage("loop forever"), + session.WithToolsApproved(true), + ) + sess.Title = "max_iterations e2e" + + var events []Event + for ev := range rt.RunStream(t.Context(), sess) { + events = append(events, ev) + } + + // The model MUST have been called exactly `limit` times. The 4th + // dispatch is where the builtin trips, before any model call. + prov.mu.Lock() + // queueProvider doesn't track callIdx, but it pops from streams as + // it runs — the residual length tells us how many calls happened. + callsMade := 5 - len(prov.streams) + prov.mu.Unlock() + assert.Equal(t, limit, callsMade, + "max_iterations(%d) must allow exactly %d model calls, got %d", + limit, limit, callsMade) + + // The runtime must surface the builtin's block reason as an + // ErrorEvent so the user sees why the run stopped — not just + // silently exit. + var errEvt *ErrorEvent + for _, ev := range events { + if e, ok := ev.(*ErrorEvent); ok { + errEvt = e + break + } + } + require.NotNil(t, errEvt, "max_iterations trip must surface as an ErrorEvent") + assert.Contains(t, strings.ToLower(errEvt.Error), "max_iterations", + "error must mention the tripping builtin so users can identify the cause") + assert.Contains(t, errEvt.Error, "3", + "error must include the configured limit so users can adjust their YAML") +} + +// TestMaxIterationsBuiltin_NoOpOnInvalidLimit asserts the lenient-args +// contract of the builtin from end-to-end: when the YAML configures an +// invalid limit (zero, negative, non-numeric, missing), the runtime +// must NOT trip prematurely. A misconfigured limit becomes a no-op +// rather than an instant kill switch — a safer default for users. +func TestMaxIterationsBuiltin_NoOpOnInvalidLimit(t *testing.T) { + t.Parallel() + + cases := []struct { + name string + args []string + }{ + {"missing", nil}, + {"empty", []string{}}, + {"zero", []string{"0"}}, + {"negative", []string{"-5"}}, + {"non_numeric", []string{"three"}}, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + stream := newStreamBuilder(). + AddContent("hello"). + AddStopWithUsage(2, 2). + Build() + prov := &mockProvider{id: "test/mock-model", stream: stream} + + root := agent.New("root", "test agent", + agent.WithModel(prov), + agent.WithHooks(&latest.HooksConfig{ + BeforeLLMCall: []latest.HookDefinition{ + {Type: "builtin", Command: builtins.MaxIterations, Args: tc.args}, + }, + }), + ) + tm := team.New(team.WithAgents(root)) + + rt, err := NewLocalRuntime(tm, + WithSessionCompaction(false), + WithModelStore(mockModelStore{}), + ) + require.NoError(t, err) + + sess := session.New(session.WithUserMessage("hi")) + sess.Title = "Unit Test" + + var events []Event + for ev := range rt.RunStream(t.Context(), sess) { + events = append(events, ev) + } + + for _, ev := range events { + if e, ok := ev.(*ErrorEvent); ok { + t.Fatalf("invalid limit %v must be a no-op, but the run was terminated: %s", tc.args, e.Error) + } + } + }) + } } diff --git a/pkg/runtime/hooks.go b/pkg/runtime/hooks.go index 63172cf30..39b9405b1 100644 --- a/pkg/runtime/hooks.go +++ b/pkg/runtime/hooks.go @@ -189,15 +189,12 @@ func contextMessages(result *hooks.Result) []chat.Message { }} } -// executeSessionEndHooks fires session_end when the run loop exits -// and clears any per-session state held by stateful builtins so a -// long-running runtime stays bounded. +// executeSessionEndHooks fires session_end when the run loop exits. func (r *LocalRuntime) executeSessionEndHooks(ctx context.Context, sess *session.Session, a *agent.Agent) { r.dispatchHook(ctx, a, hooks.EventSessionEnd, &hooks.Input{ SessionID: sess.ID, Reason: "stream_ended", }, nil) - r.builtinsState.ClearSession(sess.ID) } // executeStopHooks fires stop hooks when the model finishes responding, @@ -329,10 +326,11 @@ func (r *LocalRuntime) executeOnToolApprovalDecisionHooks( // the contract. Hooks that just want to contribute system messages // should target turn_start instead. // -// modelID is the canonical model identifier the loop has just -// resolved (after per-tool overrides and alloy-mode selection); it's -// surfaced to hooks via [hooks.Input.ModelID] so handlers don't need -// to recompute it from the agent. +// modelID and iteration are surfaced verbatim via +// [hooks.Input.ModelID] / [hooks.Input.Iteration] so handlers (notably +// the max_iterations builtin) don't have to recompute them — the +// loop's resolved values reflect per-tool overrides and alloy-mode +// selection that an Agent.Model() lookup would miss. // // messages is the conversation snapshot the runtime is about to send // to the model. Hooks may return a rewrite via @@ -340,15 +338,13 @@ func (r *LocalRuntime) executeOnToolApprovalDecisionHooks( // builtin scrubbing outbound chat content); the rewrite is returned // in the third tuple value when present, nil otherwise. Callers must // swap the rewrite in BEFORE the model call so the LLM never sees the -// original content. messages is passed through to hooks only when at -// least one before_llm_call hook is configured (see [dispatchHook]), -// so observational hook configurations don't pay the JSON-encoding -// cost on every model call. +// original content. func (r *LocalRuntime) executeBeforeLLMCallHooks( ctx context.Context, sess *session.Session, a *agent.Agent, modelID string, + iteration int, messages []chat.Message, ) (stop bool, message string, rewritten []chat.Message) { exec := r.hooksExec(a) @@ -362,6 +358,7 @@ func (r *LocalRuntime) executeBeforeLLMCallHooks( SessionID: sess.ID, AgentName: a.Name(), ModelID: modelID, + Iteration: iteration, Messages: messages, }, nil) if result == nil { diff --git a/pkg/runtime/loop.go b/pkg/runtime/loop.go index 4c387161d..4952f3f91 100644 --- a/pkg/runtime/loop.go +++ b/pkg/runtime/loop.go @@ -386,7 +386,7 @@ func (r *LocalRuntime) runStreamLoop(ctx context.Context, sess *session.Session, // AFTER the closure body has assigned both, so callers see the same // reason the runtime took. ctrl drives the outer for-loop's // continue-or-exit decision. - ctrl := r.runTurn(ctx, sess, a, m, model, modelID, contextLimit, sessionSpan, + ctrl := r.runTurn(ctx, sess, a, m, model, modelID, iteration, contextLimit, sessionSpan, slices.Concat(sessionStartMsgs, userPromptMsgs), agentTools, loopDetector, &overflowCompactions, &toolModelOverride, events) switch ctrl { @@ -436,6 +436,7 @@ func (r *LocalRuntime) runTurn( m *modelsdev.Model, model provider.Provider, modelID string, + iteration int, contextLimit int64, sessionSpan trace.Span, priorExtras []chat.Message, @@ -504,7 +505,7 @@ func (r *LocalRuntime) runTurn( // runtime's Go-only message transforms so a hook that drops a // message (e.g. a custom "strip system reminders") doesn't get // silently overridden by a transform later in the chain. - stop, msg, rewritten := r.executeBeforeLLMCallHooks(ctx, sess, a, modelID, messages) + stop, msg, rewritten := r.executeBeforeLLMCallHooks(ctx, sess, a, modelID, iteration, messages) if stop { slog.WarnContext(ctx, "before_llm_call hook signalled run termination", "agent", a.Name(), "session_id", sess.ID, "reason", msg) diff --git a/pkg/runtime/runtime.go b/pkg/runtime/runtime.go index 474d508d8..5802f5e4a 100644 --- a/pkg/runtime/runtime.go +++ b/pkg/runtime/runtime.go @@ -165,10 +165,11 @@ type LocalRuntime struct { // touching any process-wide state. hooksRegistry *hooks.Registry - // builtinsState holds per-session state for the stateful builtins - // (loop_detector, max_iterations). The runtime calls - // builtinsState.ClearSession from session_end so a long-running - // runtime serving many sessions stays bounded. + // builtinsState exposes per-runtime state of the stateful builtins + // (currently just the snapshot store). It's the handle the runtime + // uses for snapshot operations like /undo, and was originally where + // the max_iterations counter lived before that builtin became + // stateless via [hooks.Input.Iteration]. builtinsState *builtins.State // hooksExecByAgent holds the per-agent [hooks.Executor], keyed by From e9d6bc07c05f4e6ac6d2f2f43d2130118f1e7c87 Mon Sep 17 00:00:00 2001 From: David Gageot Date: Thu, 7 May 2026 17:45:27 +0200 Subject: [PATCH 2/3] drop builtins.State wrapper, expose Snapshots directly MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit With max_iterations gone stateless, State was a single-field struct with three pure pass-through methods around the snapshot tracker. Promote *Snapshots to the package's public type and have Register return it directly — one fewer layer between the runtime and the checkpoint history. --- pkg/hooks/builtins/builtins.go | 52 +++++------------------------ pkg/hooks/builtins/snapshot.go | 51 +++++++++++++++++----------- pkg/hooks/builtins/snapshot_test.go | 30 ++++++++--------- pkg/runtime/runtime.go | 16 ++++----- pkg/runtime/snapshot.go | 6 ++-- 5 files changed, 65 insertions(+), 90 deletions(-) diff --git a/pkg/hooks/builtins/builtins.go b/pkg/hooks/builtins/builtins.go index eac16c1bc..b5af32656 100644 --- a/pkg/hooks/builtins/builtins.go +++ b/pkg/hooks/builtins/builtins.go @@ -48,53 +48,17 @@ package builtins import ( - "context" "errors" "github.com/docker/docker-agent/pkg/hooks" ) -// State holds the per-runtime state of the stateful builtins. -// It is returned by [Register] so callers can reach into -// snapshot operations (undo / list / reset) without poking at -// builtin internals. Stateless builtins don't appear here. -type State struct { - snapshot *snapshotBuiltin -} - -// UndoLastSnapshot restores files from the latest completed snapshot checkpoint. -func (s *State) UndoLastSnapshot(ctx context.Context, sessionID, cwd string) (files int, ok bool, err error) { - if s == nil || s.snapshot == nil || sessionID == "" || cwd == "" { - return 0, false, nil - } - return s.snapshot.undoLast(ctx, sessionID, cwd) -} - -// ListSnapshots returns the completed snapshot checkpoints for a session in -// chronological order (oldest first). Returns nil when no snapshots exist. -func (s *State) ListSnapshots(sessionID string) []SnapshotInfo { - if s == nil || s.snapshot == nil || sessionID == "" { - return nil - } - return s.snapshot.listSnapshots(sessionID) -} - -// ResetSnapshot reverts every checkpoint past index keep so the workspace -// returns to the state captured at that snapshot. keep == 0 resets to the -// original (pre-agent) state. -func (s *State) ResetSnapshot(ctx context.Context, sessionID, cwd string, keep int) (files int, ok bool, err error) { - if s == nil || s.snapshot == nil || sessionID == "" || cwd == "" { - return 0, false, nil - } - return s.snapshot.resetSnapshot(ctx, sessionID, cwd, keep) -} - -// Register installs the stock builtin hooks on r and returns a [State] -// handle the caller can use for stateful builtin operations. -func Register(r *hooks.Registry) (*State, error) { - state := &State{ - snapshot: newSnapshotBuiltin(), - } +// Register installs the stock builtin hooks on r and returns the +// shared [*Snapshots] tracker so the caller (typically the runtime) +// can drive /undo, /list-snapshots, and /reset against the same +// in-memory checkpoint history the snapshot hook is writing to. +func Register(r *hooks.Registry) (*Snapshots, error) { + snapshots := NewSnapshots() if err := errors.Join( r.RegisterBuiltin(AddDate, addDate), r.RegisterBuiltin(AddEnvironmentInfo, addEnvironmentInfo), @@ -105,13 +69,13 @@ func Register(r *hooks.Registry) (*State, error) { r.RegisterBuiltin(AddUserInfo, addUserInfo), r.RegisterBuiltin(AddRecentCommits, addRecentCommits), r.RegisterBuiltin(MaxIterations, maxIterations), - r.RegisterBuiltin(Snapshot, state.snapshot.hook), + r.RegisterBuiltin(Snapshot, snapshots.Hook), r.RegisterBuiltin(RedactSecrets, redactSecrets), r.RegisterBuiltin(HTTPPost, httpPost), ); err != nil { return nil, err } - return state, nil + return snapshots, nil } // AgentDefaults captures defaults that map onto stock builtin hook entries. diff --git a/pkg/hooks/builtins/snapshot.go b/pkg/hooks/builtins/snapshot.go index fb437cbab..317cc4986 100644 --- a/pkg/hooks/builtins/snapshot.go +++ b/pkg/hooks/builtins/snapshot.go @@ -19,7 +19,7 @@ type SnapshotInfo struct { Files int } -type snapshotBuiltin struct { +type Snapshots struct { manager *snapshot.Manager mu sync.Mutex session map[string]*snapshotSession @@ -36,14 +36,21 @@ type snapshotCheckpoint struct { files []string } -func newSnapshotBuiltin() *snapshotBuiltin { - return &snapshotBuiltin{ +// NewSnapshots returns a fresh snapshot tracker. Held by the runtime +// for /undo, /list-snapshots and /reset; the same instance backs the +// snapshot hook registered under [Snapshot]. +func NewSnapshots() *Snapshots { + return &Snapshots{ manager: snapshot.NewManager(""), session: map[string]*snapshotSession{}, } } -func (b *snapshotBuiltin) hook(ctx context.Context, in *hooks.Input, _ []string) (*hooks.Output, error) { +// Hook is the [hooks.BuiltinFunc] dispatched on every snapshot event. +// It tracks per-session turn/tool hashes, captures patches at +// turn_end / post_tool_use, and runs the shadow-repo cleanup at +// session_end. +func (b *Snapshots) Hook(ctx context.Context, in *hooks.Input, _ []string) (*hooks.Output, error) { if in == nil || in.Cwd == "" || in.SessionID == "" { return nil, nil } @@ -116,7 +123,7 @@ func (b *snapshotBuiltin) hook(ctx context.Context, in *hooks.Input, _ []string) return nil, nil } -func (b *snapshotBuiltin) getSession(sessionID string) *snapshotSession { +func (b *Snapshots) getSession(sessionID string) *snapshotSession { s := b.session[sessionID] if s == nil { s = &snapshotSession{tools: map[string]string{}} @@ -125,13 +132,13 @@ func (b *snapshotBuiltin) getSession(sessionID string) *snapshotSession { return s } -func (b *snapshotBuiltin) setTurn(sessionID, hash string) { +func (b *Snapshots) setTurn(sessionID, hash string) { b.mu.Lock() defer b.mu.Unlock() b.getSession(sessionID).turn = hash } -func (b *snapshotBuiltin) popTurn(sessionID string) string { +func (b *Snapshots) popTurn(sessionID string) string { b.mu.Lock() defer b.mu.Unlock() s := b.session[sessionID] @@ -143,7 +150,7 @@ func (b *snapshotBuiltin) popTurn(sessionID string) string { return hash } -func (b *snapshotBuiltin) setTool(sessionID, toolUseID, hash string) { +func (b *Snapshots) setTool(sessionID, toolUseID, hash string) { if toolUseID == "" { return } @@ -152,7 +159,7 @@ func (b *snapshotBuiltin) setTool(sessionID, toolUseID, hash string) { b.getSession(sessionID).tools[toolUseID] = hash } -func (b *snapshotBuiltin) popTool(sessionID, toolUseID string) string { +func (b *Snapshots) popTool(sessionID, toolUseID string) string { if toolUseID == "" { return "" } @@ -167,7 +174,7 @@ func (b *snapshotBuiltin) popTool(sessionID, toolUseID string) string { return hash } -func (b *snapshotBuiltin) pushCheckpoint(sessionID string, checkpoint snapshotCheckpoint) { +func (b *Snapshots) pushCheckpoint(sessionID string, checkpoint snapshotCheckpoint) { if len(checkpoint.files) == 0 { return } @@ -177,7 +184,7 @@ func (b *snapshotBuiltin) pushCheckpoint(sessionID string, checkpoint snapshotCh s.history = append(s.history, checkpoint) } -func (b *snapshotBuiltin) popCheckpoint(sessionID string) (snapshotCheckpoint, bool) { +func (b *Snapshots) popCheckpoint(sessionID string) (snapshotCheckpoint, bool) { b.mu.Lock() defer b.mu.Unlock() s := b.session[sessionID] @@ -191,7 +198,10 @@ func (b *snapshotBuiltin) popCheckpoint(sessionID string) (snapshotCheckpoint, b return checkpoint, true } -func (b *snapshotBuiltin) undoLast(ctx context.Context, sessionID, cwd string) (files int, ok bool, err error) { +// UndoLast restores the files captured by the most recent checkpoint. +// Returns (filesRestored, true, nil) on success, (0, false, nil) when +// there is nothing to undo. +func (b *Snapshots) UndoLast(ctx context.Context, sessionID, cwd string) (files int, ok bool, err error) { checkpoint, ok := b.popCheckpoint(sessionID) if !ok { return 0, false, nil @@ -210,9 +220,9 @@ func (b *snapshotBuiltin) undoLast(ctx context.Context, sessionID, cwd string) ( return len(checkpoint.files), true, nil } -// listSnapshots returns the completed checkpoints for a session in chronological +// List returns the completed checkpoints for a session in chronological // order (oldest first). The returned slice may be empty. -func (b *snapshotBuiltin) listSnapshots(sessionID string) []SnapshotInfo { +func (b *Snapshots) List(sessionID string) []SnapshotInfo { b.mu.Lock() defer b.mu.Unlock() s := b.session[sessionID] @@ -226,11 +236,12 @@ func (b *snapshotBuiltin) listSnapshots(sessionID string) []SnapshotInfo { return out } -// resetSnapshot reverts every checkpoint with index >= keep so the workspace -// returns to the state captured at snapshot keep. keep == 0 means "reset to -// the original state". A keep value greater than or equal to the snapshot -// count is a no-op. Reverted checkpoints are dropped from the session history. -func (b *snapshotBuiltin) resetSnapshot(ctx context.Context, sessionID, cwd string, keep int) (files int, ok bool, err error) { +// Reset reverts every checkpoint with index >= keep so the workspace +// returns to the state captured at snapshot keep. keep == 0 means +// "reset to the original state". A keep value greater than or equal +// to the snapshot count is a no-op. Reverted checkpoints are dropped +// from the session history. +func (b *Snapshots) Reset(ctx context.Context, sessionID, cwd string, keep int) (files int, ok bool, err error) { tail := b.popHistoryTail(sessionID, keep) if len(tail) == 0 { return 0, false, nil @@ -257,7 +268,7 @@ func (b *snapshotBuiltin) resetSnapshot(ctx context.Context, sessionID, cwd stri // the surviving prefix in the session history. keep is clamped to [0, len]. // The popped slots in the backing array are zeroed so the dropped file lists // can be garbage-collected before the slice grows past them again. -func (b *snapshotBuiltin) popHistoryTail(sessionID string, keep int) []snapshotCheckpoint { +func (b *Snapshots) popHistoryTail(sessionID string, keep int) []snapshotCheckpoint { b.mu.Lock() defer b.mu.Unlock() s := b.session[sessionID] diff --git a/pkg/hooks/builtins/snapshot_test.go b/pkg/hooks/builtins/snapshot_test.go index b8217bee2..1d07ec5a8 100644 --- a/pkg/hooks/builtins/snapshot_test.go +++ b/pkg/hooks/builtins/snapshot_test.go @@ -22,7 +22,7 @@ func TestSnapshotBuiltinUndoSurvivesStreamEnd(t *testing.T) { t.Cleanup(func() { paths.SetDataDir("") }) r := hooks.NewRegistry() - state, err := builtins.Register(r) + snapshots, err := builtins.Register(r) require.NoError(t, err) fn, ok := r.LookupBuiltin(builtins.Snapshot) require.True(t, ok) @@ -74,7 +74,7 @@ func TestSnapshotBuiltinUndoSurvivesStreamEnd(t *testing.T) { require.Len(t, entries, 1) require.DirExists(t, filepath.Join(paths.GetDataDir(), "snapshot", entries[0].Name())) - files, restored, err := state.UndoLastSnapshot(t.Context(), "s", dir) + files, restored, err := snapshots.UndoLast(t.Context(), "s", dir) require.NoError(t, err) assert.True(t, restored) assert.Equal(t, 1, files) @@ -89,7 +89,7 @@ func TestSnapshotBuiltinListAndReset(t *testing.T) { t.Cleanup(func() { paths.SetDataDir("") }) r := hooks.NewRegistry() - state, err := builtins.Register(r) + snapshots, err := builtins.Register(r) require.NoError(t, err) fn, ok := r.LookupBuiltin(builtins.Snapshot) require.True(t, ok) @@ -97,7 +97,7 @@ func TestSnapshotBuiltinListAndReset(t *testing.T) { dir := snapshotBuiltinRepo(t) // Initially: no checkpoints. - assert.Empty(t, state.ListSnapshots("s")) + assert.Empty(t, snapshots.List("s")) // Capture three snapshots: each turn modifies one file. recordTurn := func(t *testing.T, name, contents string) { @@ -122,33 +122,33 @@ func TestSnapshotBuiltinListAndReset(t *testing.T) { recordTurn(t, "b.txt", "b") recordTurn(t, "c.txt", "c") - snaps := state.ListSnapshots("s") + snaps := snapshots.List("s") require.Len(t, snaps, 3) assert.Equal(t, 1, snaps[0].Files) assert.Equal(t, 1, snaps[1].Files) assert.Equal(t, 1, snaps[2].Files) // Reset to snapshot 2: revert turn 3 only, leaving a.txt and b.txt intact. - files, restored, err := state.ResetSnapshot(t.Context(), "s", dir, 2) + files, restored, err := snapshots.Reset(t.Context(), "s", dir, 2) require.NoError(t, err) assert.True(t, restored) assert.Equal(t, 1, files) assert.FileExists(t, filepath.Join(dir, "a.txt")) assert.FileExists(t, filepath.Join(dir, "b.txt")) assert.NoFileExists(t, filepath.Join(dir, "c.txt")) - require.Len(t, state.ListSnapshots("s"), 2) + require.Len(t, snapshots.List("s"), 2) // Reset to original: revert remaining checkpoints, deleting all three files. - files, restored, err = state.ResetSnapshot(t.Context(), "s", dir, 0) + files, restored, err = snapshots.Reset(t.Context(), "s", dir, 0) require.NoError(t, err) assert.True(t, restored) assert.Equal(t, 2, files) assert.NoFileExists(t, filepath.Join(dir, "a.txt")) assert.NoFileExists(t, filepath.Join(dir, "b.txt")) - assert.Empty(t, state.ListSnapshots("s")) + assert.Empty(t, snapshots.List("s")) // Subsequent reset is a no-op (nothing to revert). - _, restored, err = state.ResetSnapshot(t.Context(), "s", dir, 0) + _, restored, err = snapshots.Reset(t.Context(), "s", dir, 0) require.NoError(t, err) assert.False(t, restored) } @@ -161,7 +161,7 @@ func TestSnapshotBuiltinResetKeepBeyondHistoryIsNoop(t *testing.T) { t.Cleanup(func() { paths.SetDataDir("") }) r := hooks.NewRegistry() - state, err := builtins.Register(r) + snapshots, err := builtins.Register(r) require.NoError(t, err) fn, ok := r.LookupBuiltin(builtins.Snapshot) require.True(t, ok) @@ -183,18 +183,18 @@ func TestSnapshotBuiltinResetKeepBeyondHistoryIsNoop(t *testing.T) { require.NoError(t, err) // keep == len(history) means "keep everything" — no checkpoints reverted. - files, restored, err := state.ResetSnapshot(t.Context(), "s", dir, 1) + files, restored, err := snapshots.Reset(t.Context(), "s", dir, 1) require.NoError(t, err) assert.False(t, restored) assert.Equal(t, 0, files) assert.FileExists(t, filepath.Join(dir, "a.txt")) - require.Len(t, state.ListSnapshots("s"), 1) + require.Len(t, snapshots.List("s"), 1) // keep way past the end is also a no-op. - _, restored, err = state.ResetSnapshot(t.Context(), "s", dir, 99) + _, restored, err = snapshots.Reset(t.Context(), "s", dir, 99) require.NoError(t, err) assert.False(t, restored) - require.Len(t, state.ListSnapshots("s"), 1) + require.Len(t, snapshots.List("s"), 1) } func snapshotBuiltinRepo(t *testing.T) string { diff --git a/pkg/runtime/runtime.go b/pkg/runtime/runtime.go index 5802f5e4a..af3c315f9 100644 --- a/pkg/runtime/runtime.go +++ b/pkg/runtime/runtime.go @@ -165,12 +165,12 @@ type LocalRuntime struct { // touching any process-wide state. hooksRegistry *hooks.Registry - // builtinsState exposes per-runtime state of the stateful builtins - // (currently just the snapshot store). It's the handle the runtime - // uses for snapshot operations like /undo, and was originally where - // the max_iterations counter lived before that builtin became - // stateless via [hooks.Input.Iteration]. - builtinsState *builtins.State + // snapshots is the shared shadow-git snapshot tracker. The hook + // side ([Snapshots.Hook], wired in via [builtins.Register]) writes + // checkpoints during the run; the runtime side + // ([UndoLastSnapshot] / [ListSnapshots] / [ResetSnapshot]) + // reads them for /undo and /reset commands. + snapshots *builtins.Snapshots // hooksExecByAgent holds the per-agent [hooks.Executor], keyed by // agent name. Built once in [NewLocalRuntime.buildHooksExecutors] @@ -404,7 +404,7 @@ func NewLocalRuntime(agents *team.Team, opts ...Opt) (*LocalRuntime, error) { } hooksRegistry := hooks.NewRegistry() - builtinsState, err := builtins.Register(hooksRegistry) + snapshots, err := builtins.Register(hooksRegistry) if err != nil { return nil, fmt.Errorf("register builtin hooks: %w", err) } @@ -422,7 +422,7 @@ func NewLocalRuntime(agents *team.Team, opts ...Opt) (*LocalRuntime, error) { managedOAuth: true, sessionStore: session.NewInMemorySessionStore(), hooksRegistry: hooksRegistry, - builtinsState: builtinsState, + snapshots: snapshots, fallback: newFallbackExecutor(), now: time.Now, telemetry: defaultTelemetry{}, diff --git a/pkg/runtime/snapshot.go b/pkg/runtime/snapshot.go index b80beafce..d527e80e1 100644 --- a/pkg/runtime/snapshot.go +++ b/pkg/runtime/snapshot.go @@ -28,7 +28,7 @@ func (r *LocalRuntime) UndoLastSnapshot(ctx context.Context, sess *session.Sessi if cwd == "" { return 0, false, nil } - return r.builtinsState.UndoLastSnapshot(ctx, sess.ID, cwd) + return r.snapshots.UndoLast(ctx, sess.ID, cwd) } // ListSnapshots returns the completed snapshot checkpoints recorded for the @@ -37,7 +37,7 @@ func (r *LocalRuntime) ListSnapshots(sess *session.Session) []builtins.SnapshotI if r == nil || sess == nil { return nil } - return r.builtinsState.ListSnapshots(sess.ID) + return r.snapshots.List(sess.ID) } // ResetSnapshot reverts every checkpoint past index keep so the workspace @@ -48,7 +48,7 @@ func (r *LocalRuntime) ResetSnapshot(ctx context.Context, sess *session.Session, if cwd == "" { return 0, false, nil } - return r.builtinsState.ResetSnapshot(ctx, sess.ID, cwd, keep) + return r.snapshots.Reset(ctx, sess.ID, cwd, keep) } // snapshotCwd resolves the working directory used to open the shadow From b3672045d9be3fdb30527aeba613047f38ccca56 Mon Sep 17 00:00:00 2001 From: David Gageot Date: Thu, 7 May 2026 17:59:04 +0200 Subject: [PATCH 3/3] docs: add godoc on Snapshots type --- pkg/hooks/builtins/snapshot.go | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/pkg/hooks/builtins/snapshot.go b/pkg/hooks/builtins/snapshot.go index 317cc4986..5a2d8656c 100644 --- a/pkg/hooks/builtins/snapshot.go +++ b/pkg/hooks/builtins/snapshot.go @@ -19,6 +19,11 @@ type SnapshotInfo struct { Files int } +// Snapshots tracks per-session shadow-git checkpoints. The same +// instance is dispatched as the snapshot builtin (registered under +// [Snapshot] via [Hook]) and exposed to the runtime for /undo, +// /list-snapshots and /reset (via [UndoLast] / [List] / [Reset]). +// Construct with [NewSnapshots]; the zero value is not usable. type Snapshots struct { manager *snapshot.Manager mu sync.Mutex