Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions docs/configuration/hooks/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ Built-ins are typically zero-config and faster than equivalent shell hooks becau
| `add_directory_listing` | `session_start` | _none_ | Adds an alphabetical listing of the cwd's top-level entries (skips dot-files, capped at 100 with a "... and N more"). |
| `add_user_info` | `session_start` | _none_ | Adds the current OS user (username and full name) and the hostname. |
| `add_recent_commits` | `session_start` | _none_, or `["<N>"]` | Adds `git log --oneline -n N`. `N` defaults to 10; pass a positive integer to override. |
| `max_iterations` | `before_llm_call` | `["<N>"]` (required) | Hard-stops the agent after `N` model calls. State is per-session and reset at `session_end`. |
| `max_iterations` | `before_llm_call` | `["<N>"]` (required) | Hard-stops the agent after `N` model calls. Stateless: the runtime supplies the iteration counter on every dispatch. |
| `snapshot` | `session_start`, `turn_start`, `turn_end`, `pre_tool_use`, `post_tool_use`, `session_end` | _none_ | Records filesystem snapshots in a shadow git repo under the docker-agent data directory. No-op outside git repos; respects the source repo's ignore rules and skips newly-added files larger than 2 MiB. |
| `redact_secrets` | `pre_tool_use`, `before_llm_call`, `tool_response_transform` | _none_ | Scrubs detected secrets (API keys, tokens, private keys, …) out of tool call arguments, outgoing chat content, and tool output. The same builtin handles all three events and dispatches on the event name. Auto-registered on all three events by `redact_secrets: true` on the agent — see [`examples/redact_secrets_hooks.yaml`](https://github.com/docker/docker-agent/blob/main/examples/redact_secrets_hooks.yaml) for the manual wiring. |
| `unload` | `on_agent_switch` | _none_ | Walks the previous agent's models and calls `Unload()` on every provider that implements [`provider.Unloader`](https://pkg.go.dev/github.com/docker/docker-agent/pkg/model/provider#Unloader) — typically Docker Model Runner — to free the GPU/RAM the just-departing model was holding. Cloud-only providers don't implement the interface and are silently skipped. Errors are logged and swallowed; agent switching never blocks on a slow or unreachable engine (each Unload call has a 10 s timeout). See [`examples/unload_on_switch.yaml`](https://github.com/docker/docker-agent/blob/main/examples/unload_on_switch.yaml). |
Expand Down Expand Up @@ -255,7 +255,7 @@ In addition to the common fields, each event ships its own payload:
| `user_prompt_submit` | `prompt` — the text the user just submitted |
| `turn_start` | _none_ (just the common fields) |
| `turn_end` | `agent_name`, `reason` — one of `normal`, `continue`, `steered`, `error`, `canceled`, `hook_blocked`, `loop_detected` |
| `before_llm_call` | _none_ |
| `before_llm_call` | `iteration` — 1-based run-loop iteration counter (the model call this hook is gating) |
| `after_llm_call` | `agent_name`, `stop_response`, `last_user_message` |
| `session_end` | `reason` — one of `clear`, `logout`, `prompt_input_exit`, `other` |
| `pre_compact` | `source` — one of `manual`, `auto`, `overflow`, `tool_overflow` |
Expand Down
76 changes: 14 additions & 62 deletions pkg/hooks/builtins/builtins.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,77 +35,29 @@
//
// turn_start builtins recompute every turn (date, git state).
// session_start builtins run once per session for context that's
// stable for its duration. max_iterations is stateful: its
// per-session counter lives on the [State] returned by [Register];
// the runtime clears it via [State.ClearSession] from session_end.
// snapshot is also stateful: it keeps per-session turn/tool snapshot
// hashes and undo checkpoints in memory while the shadow git objects live
// under the data directory. Undo checkpoints intentionally survive the
// RunStream session_end cleanup so /undo can run after the response stops.
// stable for its duration. snapshot is stateful: it keeps per-session
// turn/tool snapshot hashes and undo checkpoints in memory while the
// shadow git objects live under the data directory. Undo checkpoints
// intentionally survive the RunStream session_end cleanup so /undo
// can run after the response stops.
//
// LLM-as-a-judge hooks are NOT shipped here: write `type: model` with
// `schema: pre_tool_use_decision` instead — see
// pkg/hooks/shape_pre_tool_use_decision.go and examples/llm_judge.yaml.
package builtins

import (
"context"
"errors"

"github.com/docker/docker-agent/pkg/hooks"
)

// State holds the per-runtime state of the stateful builtins.
// It is returned by [Register] so callers can clear per-session
// entries on session_end. Stateless builtins don't appear here.
type State struct {
maxIterations *maxIterationsBuiltin
snapshot *snapshotBuiltin
}

// ClearSession drops per-RunStream state that should not survive stream teardown.
// A nil receiver is a no-op.
func (s *State) ClearSession(sessionID string) {
if s == nil || sessionID == "" {
return
}
s.maxIterations.clearSession(sessionID)
}

// UndoLastSnapshot restores files from the latest completed snapshot checkpoint.
func (s *State) UndoLastSnapshot(ctx context.Context, sessionID, cwd string) (files int, ok bool, err error) {
if s == nil || s.snapshot == nil || sessionID == "" || cwd == "" {
return 0, false, nil
}
return s.snapshot.undoLast(ctx, sessionID, cwd)
}

// ListSnapshots returns the completed snapshot checkpoints for a session in
// chronological order (oldest first). Returns nil when no snapshots exist.
func (s *State) ListSnapshots(sessionID string) []SnapshotInfo {
if s == nil || s.snapshot == nil || sessionID == "" {
return nil
}
return s.snapshot.listSnapshots(sessionID)
}

// ResetSnapshot reverts every checkpoint past index keep so the workspace
// returns to the state captured at that snapshot. keep == 0 resets to the
// original (pre-agent) state.
func (s *State) ResetSnapshot(ctx context.Context, sessionID, cwd string, keep int) (files int, ok bool, err error) {
if s == nil || s.snapshot == nil || sessionID == "" || cwd == "" {
return 0, false, nil
}
return s.snapshot.resetSnapshot(ctx, sessionID, cwd, keep)
}

// Register installs the stock builtin hooks on r and returns a [State]
// handle the caller can use for stateful builtin operations.
func Register(r *hooks.Registry) (*State, error) {
state := &State{
maxIterations: newMaxIterations(),
snapshot: newSnapshotBuiltin(),
}
// Register installs the stock builtin hooks on r and returns the
// shared [*Snapshots] tracker so the caller (typically the runtime)
// can drive /undo, /list-snapshots, and /reset against the same
// in-memory checkpoint history the snapshot hook is writing to.
func Register(r *hooks.Registry) (*Snapshots, error) {
snapshots := NewSnapshots()
if err := errors.Join(
r.RegisterBuiltin(AddDate, addDate),
r.RegisterBuiltin(AddEnvironmentInfo, addEnvironmentInfo),
Expand All @@ -115,13 +67,13 @@ func Register(r *hooks.Registry) (*State, error) {
r.RegisterBuiltin(AddDirectoryListing, addDirectoryListing),
r.RegisterBuiltin(AddUserInfo, addUserInfo),
r.RegisterBuiltin(AddRecentCommits, addRecentCommits),
r.RegisterBuiltin(MaxIterations, state.maxIterations.hook),
r.RegisterBuiltin(Snapshot, state.snapshot.hook),
r.RegisterBuiltin(MaxIterations, maxIterations),
r.RegisterBuiltin(Snapshot, snapshots.Hook),
r.RegisterBuiltin(RedactSecrets, redactSecrets),
); err != nil {
return nil, err
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

*[MEDIUM] Register returns a nil Snapshots on error — callers that use the pointer without a nil-guard risk a panic

Register now returns *Snapshots (previously *State). Both old and new code correctly return nil, err on failure, and the only production caller (runtime.go) checks the error before use:

snapshots, err := builtins.Register(hooksRegistry)
if err != nil {
    return nil, fmt.Errorf("register builtin hooks: %w", err)
}

However, the public API contract has subtly changed: *State was an opaque wrapper whose nil-receiver methods were all explicitly nil-safe (e.g. ClearSession had if s == nil { return }). *Snapshots is now a promoted public type whose methods do not carry nil-receiver guards. Any future caller that stores snapshots before checking err — or any test helper that calls Register and ignores the error — will receive a nil pointer that panics on the first method call rather than degrading gracefully.

Consider adding a nil-guard at the top of Snapshots' public methods (e.g. UndoLastSnapshot, ListSnapshots, ResetSnapshot), or documenting that a nil *Snapshots must never be used, to preserve the safety property the old *State wrapper provided.

}
return state, nil
return snapshots, nil
}

// AgentDefaults captures defaults that map onto stock builtin hook entries.
Expand Down
38 changes: 8 additions & 30 deletions pkg/hooks/builtins/max_iterations.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,15 @@ import (
"fmt"
"log/slog"
"strconv"
"sync"

"github.com/docker/docker-agent/pkg/hooks"
)

// MaxIterations is the registered name of the max_iterations builtin.
const MaxIterations = "max_iterations"

// maxIterationsBuiltin counts before_llm_call invocations per session
// and signals a terminating verdict once the configured limit is
// exceeded.
// maxIterations signals a terminating verdict once the runtime has
// dispatched more before_llm_call events than the configured budget.
//
// This is a hard stop with no resume protocol — distinct from the
// agent.MaxIterations flag, which has its own special UX
Expand All @@ -25,19 +23,10 @@ const MaxIterations = "max_iterations"
//
// Args layout: `[limit]`. Missing or invalid args make the hook a
// no-op so a misconfigured YAML doesn't accidentally cap a run at
// zero. State is per-session, keyed by [hooks.Input.SessionID], and
// cleared from session_end via [State.ClearSession].
type maxIterationsBuiltin struct {
mu sync.Mutex
counts map[string]int // SessionID -> calls observed
}

func newMaxIterations() *maxIterationsBuiltin {
return &maxIterationsBuiltin{counts: map[string]int{}}
}

func (b *maxIterationsBuiltin) hook(_ context.Context, in *hooks.Input, args []string) (*hooks.Output, error) {
if in == nil || in.SessionID == "" || len(args) == 0 {
// zero. Stateless: the runtime supplies the 1-based iteration counter
// via [hooks.Input.Iteration].
func maxIterations(_ context.Context, in *hooks.Input, args []string) (*hooks.Output, error) {
if in == nil || in.Iteration <= 0 || len(args) == 0 {
return nil, nil
}
limit, err := strconv.Atoi(args[0])
Expand All @@ -46,17 +35,12 @@ func (b *maxIterationsBuiltin) hook(_ context.Context, in *hooks.Input, args []s
return nil, nil
}

b.mu.Lock()
b.counts[in.SessionID]++
count := b.counts[in.SessionID]
b.mu.Unlock()

if count <= limit {
if in.Iteration <= limit {
return nil, nil
}

slog.Warn("max_iterations tripped",
"count", count, "limit", limit, "session_id", in.SessionID)
"iteration", in.Iteration, "limit", limit, "session_id", in.SessionID)

return &hooks.Output{
Decision: hooks.DecisionBlockValue,
Expand All @@ -65,9 +49,3 @@ func (b *maxIterationsBuiltin) hook(_ context.Context, in *hooks.Input, args []s
limit),
}, nil
}

func (b *maxIterationsBuiltin) clearSession(sessionID string) {
b.mu.Lock()
delete(b.counts, sessionID)
b.mu.Unlock()
}
81 changes: 18 additions & 63 deletions pkg/hooks/builtins/max_iterations_test.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package builtins_test

import (
"sync"
"testing"

"github.com/stretchr/testify/assert"
Expand All @@ -12,54 +11,29 @@ import (
)

// TestMaxIterationsTripsAfterLimit verifies the happy path: with a
// limit of 3, the first three calls are no-ops and the fourth returns
// a block decision. The reason carries the configured limit so the
// runtime's user-facing Error event explains why the run stopped.
// limit of 3, the first three iterations are no-ops and the fourth
// returns a block decision. The reason carries the configured limit
// so the runtime's user-facing Error event explains why the run
// stopped.
func TestMaxIterationsTripsAfterLimit(t *testing.T) {
t.Parallel()

fn := lookup(t, builtins.MaxIterations)
in := &hooks.Input{SessionID: "s1"}
args := []string{"3"}

for i := 1; i <= 3; i++ {
out, err := fn(t.Context(), in, args)
require.NoErrorf(t, err, "call %d must not error", i)
require.Nilf(t, out, "call %d (within limit) must not trip", i)
out, err := fn(t.Context(), &hooks.Input{Iteration: i}, args)
require.NoErrorf(t, err, "iteration %d must not error", i)
require.Nilf(t, out, "iteration %d (within limit) must not trip", i)
}

out, err := fn(t.Context(), in, args)
out, err := fn(t.Context(), &hooks.Input{Iteration: 4}, args)
require.NoError(t, err)
require.NotNil(t, out, "fourth call (over limit) must trip")
require.NotNil(t, out, "iteration 4 (over limit) must trip")
assert.Equal(t, hooks.DecisionBlockValue, out.Decision)
assert.Contains(t, out.Reason, "3", "reason must include the configured limit")
}

// TestMaxIterationsIsolatesSessions documents the per-session
// counter contract: a runtime serving multiple sessions must not let
// session A's calls count against session B's budget.
func TestMaxIterationsIsolatesSessions(t *testing.T) {
t.Parallel()

fn := lookup(t, builtins.MaxIterations)
args := []string{"2"}

// Session A: two no-op calls then one trip.
for range 2 {
out, err := fn(t.Context(), &hooks.Input{SessionID: "A"}, args)
require.NoError(t, err)
require.Nil(t, out)
}
out, err := fn(t.Context(), &hooks.Input{SessionID: "A"}, args)
require.NoError(t, err)
require.NotNil(t, out, "session A trips on its third call")

// Session B: starts fresh, only sees one call so far.
out, err = fn(t.Context(), &hooks.Input{SessionID: "B"}, args)
require.NoError(t, err)
require.Nil(t, out, "session B's counter must not include session A's calls")
}

// TestMaxIterationsNoOpWithoutValidLimit documents the lenient-args
// contract: a missing, non-integer, zero, or negative limit makes
// the builtin a no-op rather than tripping (the safer default for a
Expand All @@ -75,21 +49,21 @@ func TestMaxIterationsNoOpWithoutValidLimit(t *testing.T) {
{"-1"},
}
for _, args := range cases {
fn := lookup(t, builtins.MaxIterations) // fresh state per case
// Drive 50 calls — if the builtin were tripping erroneously,
fn := lookup(t, builtins.MaxIterations)
// Drive 50 iterations — if the builtin were tripping erroneously,
// at least one of these would return a non-nil Output.
for range 50 {
out, err := fn(t.Context(), &hooks.Input{SessionID: "s"}, args)
for i := 1; i <= 50; i++ {
out, err := fn(t.Context(), &hooks.Input{Iteration: i}, args)
require.NoError(t, err)
require.Nilf(t, out, "args=%v: must never trip", args)
}
}
}

// TestMaxIterationsIgnoresIncompleteInput pins the defensive guard:
// missing SessionID produces no state mutation and no output. This
// protects against future dispatch changes where an edge case might
// fire before_llm_call without that field populated.
// missing or non-positive Iteration produces no output. This protects
// against future dispatch changes where an edge case might fire
// before_llm_call without that field populated.
func TestMaxIterationsIgnoresIncompleteInput(t *testing.T) {
t.Parallel()

Expand All @@ -99,28 +73,9 @@ func TestMaxIterationsIgnoresIncompleteInput(t *testing.T) {
require.NoError(t, err)
assert.Nil(t, out)

// Iteration=0 (zero value) means "not populated" — the runtime
// always supplies a 1-based counter on before_llm_call.
out, err = fn(t.Context(), &hooks.Input{}, []string{"1"})
require.NoError(t, err)
assert.Nil(t, out)
}

// TestMaxIterationsConcurrentCallsAreSafe is a smoke test for the
// builtin's mutex. Many goroutines incrementing the same session's
// counter must not race (run with -race).
func TestMaxIterationsConcurrentCallsAreSafe(t *testing.T) {
t.Parallel()

fn := lookup(t, builtins.MaxIterations)
in := &hooks.Input{SessionID: "concurrent"}

const callers = 50
var wg sync.WaitGroup
wg.Add(callers)
for range callers {
go func() {
defer wg.Done()
_, _ = fn(t.Context(), in, []string{"100"})
}()
}
wg.Wait()
}
3 changes: 1 addition & 2 deletions pkg/hooks/builtins/redact_secrets_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -128,9 +128,8 @@ func TestRedactSecretsIsRegistered(t *testing.T) {
t.Parallel()

reg := hooks.NewRegistry()
state, err := Register(reg)
_, err := Register(reg)
require.NoError(t, err)
t.Cleanup(func() { state.ClearSession("") })

handler, ok := reg.LookupBuiltin(RedactSecrets)
require.Truef(t, ok, "builtin %q must be registered", RedactSecrets)
Expand Down
Loading
Loading