diff --git a/agent-schema.json b/agent-schema.json index 33b549c2e..106c69717 100644 --- a/agent-schema.json +++ b/agent-schema.json @@ -182,7 +182,7 @@ }, "unload_api": { "type": "string", - "description": "Optional path (or absolute URL) to the provider's model-unload endpoint. POSTed with `{\"model\": \"\"}` when the agent wires the `unload` builtin into its `on_agent_switch` hook chain, to free GPU/RAM held by the previous model. Today only Docker Model Runner ships a provider that calls this endpoint; cloud providers don't implement [provider.Unloader] and the hook silently skips them. A relative path is resolved against the scheme+host of base_url; an absolute URL is used verbatim.", + "description": "Optional path (or absolute URL) to the provider's model-unload endpoint. POSTed with `{\"model\": \"\"}` when the agent wires the `unload` builtin into its `on_agent_switch` hook chain, to free GPU/RAM held by the previous model. Today only Docker Model Runner exposes such an endpoint; the `unload` builtin is a pure HTTP hook that silently skips non-DMR providers. A relative path is resolved against the scheme+host of base_url; an absolute URL is used verbatim.", "examples": [ "/engines/_unload", "/api/unload", @@ -881,7 +881,7 @@ }, "type": { "type": "string", - "description": "Type of hook. 'command' executes a shell command; 'builtin' invokes a named in-process Go function registered by the runtime; 'model' asks an LLM and translates its reply into the hook's native output (used for LLM-as-a-judge pre_tool_use, summarizers, etc., with no Go code). The docker-agent runtime ships these builtins: 'add_date' (turn_start: today's date), 'add_environment_info' (session_start: cwd, git, OS, arch), 'add_prompt_files' (turn_start: contents of named files looked up in the workdir hierarchy and the home directory), 'add_git_status' (turn_start: `git status --short --branch`), 'add_git_diff' (turn_start: `git diff --stat`, or full diff with args=['full']), 'add_directory_listing' (session_start: top-level entries of cwd), 'add_user_info' (session_start: current OS user and hostname), 'add_recent_commits' (session_start: `git log --oneline -n N`, default N=10, override via args=['']), 'max_iterations' (before_llm_call: hard stop after N model calls; args=[''] required), 'redact_secrets' (pre_tool_use / before_llm_call / tool_response_transform: scrubs detected secrets from tool arguments, outgoing chat content, and tool output — the same builtin handles all three legs and dispatches on the event; the matching agent-level 'redact_secrets: true' flag auto-injects the entries for all three), 'unload' (on_agent_switch: walks the previous agent's models and calls Unload() on every provider that implements provider.Unloader — e.g. asks Docker Model Runner to release the GPU/RAM held by the just-departing model so the next agent's model can claim it; opt in by adding the entry to the agent's hooks.on_agent_switch list).", + "description": "Type of hook. 'command' executes a shell command; 'builtin' invokes a named in-process Go function registered by the runtime; 'model' asks an LLM and translates its reply into the hook's native output (used for LLM-as-a-judge pre_tool_use, summarizers, etc., with no Go code). The docker-agent runtime ships these builtins: 'add_date' (turn_start: today's date), 'add_environment_info' (session_start: cwd, git, OS, arch), 'add_prompt_files' (turn_start: contents of named files looked up in the workdir hierarchy and the home directory), 'add_git_status' (turn_start: `git status --short --branch`), 'add_git_diff' (turn_start: `git diff --stat`, or full diff with args=['full']), 'add_directory_listing' (session_start: top-level entries of cwd), 'add_user_info' (session_start: current OS user and hostname), 'add_recent_commits' (session_start: `git log --oneline -n N`, default N=10, override via args=['']), 'max_iterations' (before_llm_call: hard stop after N model calls; args=[''] required), 'redact_secrets' (pre_tool_use / before_llm_call / tool_response_transform: scrubs detected secrets from tool arguments, outgoing chat content, and tool output — the same builtin handles all three legs and dispatches on the event; the matching agent-level 'redact_secrets: true' flag auto-injects the entries for all three), 'unload' (on_agent_switch: POSTs `{\"model\": \"\"}` to the previous agent's DMR model endpoints — e.g. asks Docker Model Runner to release the GPU/RAM held by the just-departing model so the next agent's model can claim it. Pure HTTP, no provider-specific runtime coupling; non-DMR providers are silently skipped. Opt in by adding the entry to the agent's hooks.on_agent_switch list).", "enum": [ "command", "builtin", diff --git a/docs/configuration/hooks/index.md b/docs/configuration/hooks/index.md index b58ad2022..83ee57216 100644 --- a/docs/configuration/hooks/index.md +++ b/docs/configuration/hooks/index.md @@ -156,7 +156,7 @@ Built-ins are typically zero-config and faster than equivalent shell hooks becau | `max_iterations` | `before_llm_call` | `[""]` (required) | Hard-stops the agent after `N` model calls. State is per-session and reset at `session_end`. | | `snapshot` | `session_start`, `turn_start`, `turn_end`, `pre_tool_use`, `post_tool_use`, `session_end` | _none_ | Records filesystem snapshots in a shadow git repo under the docker-agent data directory. No-op outside git repos; respects the source repo's ignore rules and skips newly-added files larger than 2 MiB. | | `redact_secrets` | `pre_tool_use`, `before_llm_call`, `tool_response_transform` | _none_ | Scrubs detected secrets (API keys, tokens, private keys, …) out of tool call arguments, outgoing chat content, and tool output. The same builtin handles all three events and dispatches on the event name. Auto-registered on all three events by `redact_secrets: true` on the agent — see [`examples/redact_secrets_hooks.yaml`](https://github.com/docker/docker-agent/blob/main/examples/redact_secrets_hooks.yaml) for the manual wiring. | -| `unload` | `on_agent_switch` | _none_ | Walks the previous agent's models and calls `Unload()` on every provider that implements [`provider.Unloader`](https://pkg.go.dev/github.com/docker/docker-agent/pkg/model/provider#Unloader) — typically Docker Model Runner — to free the GPU/RAM the just-departing model was holding. Cloud-only providers don't implement the interface and are silently skipped. Errors are logged and swallowed; agent switching never blocks on a slow or unreachable engine (each Unload call has a 10 s timeout). See [`examples/unload_on_switch.yaml`](https://github.com/docker/docker-agent/blob/main/examples/unload_on_switch.yaml). | +| `unload` | `on_agent_switch` | _none_ | POSTs `{"model": ""}` to each of the previous agent's DMR model endpoints (`/_unload` by default, overridable per-model via `unload_api`) to free the GPU/RAM the just-departing model was holding. Pure HTTP — reads the model snapshot the runtime ships on `on_agent_switch` and depends on no provider-specific runtime state. Non-DMR providers (OpenAI, Anthropic, …) are silently skipped, so cross-provider chains are safe. Errors are logged and swallowed; agent switching never blocks on a slow or unreachable engine (each call has a 10 s timeout). See [`examples/unload_on_switch.yaml`](https://github.com/docker/docker-agent/blob/main/examples/unload_on_switch.yaml). |
ℹ️ Per-turn vs. per-session @@ -590,7 +590,7 @@ models: model: ai/qwen3-coder ``` -At every transfer the runtime calls `Unload()` on the previous agent's model providers. For Docker Model Runner this hits the engine's `_unload` endpoint; for cloud providers (OpenAI, Anthropic, …) it is a silent no-op. Cross-provider chains are safe — only the providers that actually implement [`provider.Unloader`](https://pkg.go.dev/github.com/docker/docker-agent/pkg/model/provider#Unloader) are touched. See [`examples/unload_on_switch.yaml`](https://github.com/docker/docker-agent/blob/main/examples/unload_on_switch.yaml) for the full file. +At every transfer the runtime ships a snapshot of the previous agent's model endpoints on the `on_agent_switch` hook input, and the `unload` builtin POSTs `{"model": ""}` to each DMR endpoint's `/_unload` URL over plain HTTP. For cloud providers (OpenAI, Anthropic, …) the hook is a silent no-op since they don't expose an HTTP unload endpoint. Cross-provider chains are safe — only DMR endpoints are touched. See [`examples/unload_on_switch.yaml`](https://github.com/docker/docker-agent/blob/main/examples/unload_on_switch.yaml) for the full file. `on_session_resume` fires when the user explicitly approves the runtime to continue past its configured `max_iterations` limit. `previous_max_iterations` carries the cap that was reached and `new_max_iterations` carries the new cap after approval. Useful for alerting on extended-runtime sessions or for billing / quota pipelines that meter resumes. diff --git a/examples/unload_on_switch.yaml b/examples/unload_on_switch.yaml index e4aea29ef..c098e374e 100644 --- a/examples/unload_on_switch.yaml +++ b/examples/unload_on_switch.yaml @@ -4,11 +4,13 @@ # # Two agents share Docker Model Runner but use different models that don't # fit in GPU memory at the same time. Wiring the `unload` builtin into -# each agent's `on_agent_switch` hook chain tells the runtime to call -# Unload on the previous agent's model providers every time the active -# agent transfers control. For DMR this hits the engine's `_unload` -# endpoint; for cloud-only providers (OpenAI, Anthropic, ...) the hook is -# a silent no-op since they don't implement provider.Unloader. +# each agent's `on_agent_switch` hook chain asks the previous agent's +# DMR endpoint(s) to release GPU memory every time the active agent +# transfers control. The hook is pure: it reads the model snapshot the +# runtime ships on every on_agent_switch dispatch and POSTs to DMR's +# `_unload` endpoint over plain HTTP — no provider-specific runtime +# coupling. For cloud-only providers (OpenAI, Anthropic, ...) the hook +# is a silent no-op since they don't expose an HTTP unload endpoint. # # Switching back and forth between `coder` and `reviewer` therefore costs # one model load per switch instead of failing on out-of-memory. diff --git a/pkg/config/latest/types.go b/pkg/config/latest/types.go index bfac3d6a1..0cc7c712c 100644 --- a/pkg/config/latest/types.go +++ b/pkg/config/latest/types.go @@ -238,7 +238,7 @@ type ProviderConfig struct { // models are POSTed `{"model": ""}` here at every switch. // Cloud providers should leave this unset. // - // [unload]: https://pkg.go.dev/github.com/docker/docker-agent/pkg/runtime#BuiltinUnload + // [unload]: https://pkg.go.dev/github.com/docker/docker-agent/pkg/hooks/builtins#Unload UnloadAPI string `json:"unload_api,omitempty"` // TokenKey is the environment variable name containing the API token TokenKey string `json:"token_key,omitempty"` diff --git a/pkg/hooks/builtins/builtins.go b/pkg/hooks/builtins/builtins.go index cdfa98705..80e3a821e 100644 --- a/pkg/hooks/builtins/builtins.go +++ b/pkg/hooks/builtins/builtins.go @@ -12,6 +12,8 @@ // - add_user_info (session_start) — current OS user and host // - add_recent_commits (session_start) — `git log --oneline -n N` // - max_iterations (before_llm_call) — hard stop after N model calls +// - unload (on_agent_switch) — release the previous +// agent's local-engine resources via HTTP unload (DMR today) // - snapshot (session_start, // turn_start, turn_end, // pre_tool_use, post_tool_use, @@ -120,6 +122,7 @@ func Register(r *hooks.Registry) (*State, error) { r.RegisterBuiltin(Snapshot, state.snapshot.hook), r.RegisterBuiltin(RedactSecrets, redactSecrets), r.RegisterBuiltin(HTTPPost, httpPost), + r.RegisterBuiltin(Unload, unload), ); err != nil { return nil, err } diff --git a/pkg/hooks/builtins/builtins_test.go b/pkg/hooks/builtins/builtins_test.go index 8fa1dc43c..2358a9869 100644 --- a/pkg/hooks/builtins/builtins_test.go +++ b/pkg/hooks/builtins/builtins_test.go @@ -37,6 +37,7 @@ func TestRegisterInstallsAllBuiltins(t *testing.T) { builtins.Snapshot, builtins.RedactSecrets, builtins.HTTPPost, + builtins.Unload, } { fn, ok := r.LookupBuiltin(name) assert.True(t, ok, "builtin %q must be registered", name) diff --git a/pkg/hooks/builtins/unload.go b/pkg/hooks/builtins/unload.go new file mode 100644 index 000000000..41f4df23e --- /dev/null +++ b/pkg/hooks/builtins/unload.go @@ -0,0 +1,108 @@ +package builtins + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "log/slog" + "net/http" + "strings" + "time" + + "github.com/docker/docker-agent/pkg/hooks" + "github.com/docker/docker-agent/pkg/model/provider/dmr" +) + +// Unload is the registered name of the on_agent_switch builtin that +// asks the previous agent's local inference engines (today: Docker +// Model Runner) to release the resources they hold. +// +// Wire it into a config with: +// +// hooks: +// on_agent_switch: +// - type: builtin +// command: unload +// +// The hook is pure: it depends only on the [hooks.Input.FromAgentModels] +// snapshot the runtime ships on every on_agent_switch dispatch, plus +// net/http. It carries no runtime-side coupling and silently skips any +// model whose endpoint isn't reachable as plain HTTP (e.g. cloud +// providers that don't expose [hooks.ModelEndpoint.BaseURL]). +// +// Provider dispatch and URL resolution are owned by +// [pkg/model/provider/dmr] (see [dmr.ProviderType] and [dmr.UnloadURL]), +// so this builtin stays a dumb dispatcher and DMR keeps full control +// of its conventions. +const Unload = "unload" + +// unloadTimeout caps each per-model Unload call so a stalled engine +// cannot stall agent switching. +const unloadTimeout = 10 * time.Second + +// unload iterates the [hooks.Input.FromAgentModels] snapshot the +// runtime captured at dispatch time and POSTs `{"model": ""}` to +// the resolved unload endpoint of each DMR model. Errors are logged +// but never propagated — agent switching must never block on a slow +// or unreachable engine. +func unload(ctx context.Context, in *hooks.Input, _ []string) (*hooks.Output, error) { + if in == nil || in.FromAgent == "" || in.FromAgent == in.ToAgent { + return nil, nil + } + for _, m := range in.FromAgentModels { + if m.Provider != dmr.ProviderType { + continue + } + if err := unloadOne(ctx, m); err != nil { + slog.WarnContext(ctx, "unload: failed", + "agent", in.FromAgent, "model", m.Model, "error", err) + } + } + return nil, nil +} + +// unloadOne resolves the unload URL for m and POSTs the model id to +// it, bounded by [unloadTimeout]. A model with no resolvable endpoint +// (no base_url and no unload_api) is a silent no-op so the hook stays +// harmless on test / in-process providers. +func unloadOne(parent context.Context, m hooks.ModelEndpoint) error { + endpoint, err := dmr.UnloadURL(m.BaseURL, m.UnloadAPI) + if err != nil || endpoint == "" { + return err + } + ctx, cancel := context.WithTimeout(parent, unloadTimeout) + defer cancel() + + body, _ := json.Marshal(map[string]string{"model": m.Model}) + req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewReader(body)) + if err != nil { + return fmt.Errorf("building unload request: %w", err) + } + req.Header.Set("Content-Type", "application/json") + + slog.DebugContext(ctx, "Unloading model", "url", endpoint, "model", m.Model) + + // Unlike the http_post builtin, the unload target is the + // operator-configured DMR base URL — typically a loopback engine + // (Docker Desktop socket, 127.0.0.1:12434, …). The SSRF-safe + // dialer used by http_post would refuse those addresses by + // design, so we use the default client here. + resp, err := http.DefaultClient.Do(req) + if err != nil { + return fmt.Errorf("calling unload endpoint %s: %w", endpoint, err) + } + defer resp.Body.Close() + + if resp.StatusCode < 200 || resp.StatusCode >= 300 { + respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 4*1024)) + return fmt.Errorf("unload endpoint returned %d: %s", + resp.StatusCode, strings.TrimSpace(string(respBody))) + } + // Drain the success-path body so the underlying transport can reuse + // the connection (Go's http.Client only re-pools a connection whose + // body has been read to EOF and closed). + _, _ = io.Copy(io.Discard, resp.Body) + return nil +} diff --git a/pkg/hooks/builtins/unload_test.go b/pkg/hooks/builtins/unload_test.go new file mode 100644 index 000000000..959220fb6 --- /dev/null +++ b/pkg/hooks/builtins/unload_test.go @@ -0,0 +1,234 @@ +package builtins + +import ( + "encoding/json" + "io" + "net/http" + "net/http/httptest" + "sync/atomic" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/docker/docker-agent/pkg/hooks" + "github.com/docker/docker-agent/pkg/model/provider/dmr" +) + +// dmrInput builds an on_agent_switch [hooks.Input] carrying a single +// DMR ModelEndpoint pointed at server. Callers pass the relative +// override (or "") plus optional opts to tweak the generic transition +// fields (e.g. emptying FromAgent, equating From/To). Centralising +// this removes a chunk of per-test boilerplate without hiding the +// input shape — every field the test cares about is still visible on +// the call site. +func dmrInput(server *httptest.Server, unloadAPI string, opts ...func(*hooks.Input)) *hooks.Input { + in := &hooks.Input{ + FromAgent: "from", + ToAgent: "to", + FromAgentModels: []hooks.ModelEndpoint{{ + Provider: dmr.ProviderType, + Model: "ai/qwen3", + BaseURL: server.URL + "/engines/v1", + UnloadAPI: unloadAPI, + }}, + } + for _, opt := range opts { + opt(in) + } + return in +} + +// countingServer returns a recording server whose handler runs `mark` +// on every hit; tests that count calls share the same idiom. +func countingServer(t *testing.T, status int, mark func(*http.Request)) *httptest.Server { + t.Helper() + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if mark != nil { + mark(r) + } + w.WriteHeader(status) + })) + t.Cleanup(srv.Close) + return srv +} + +// TestUnloadBuiltin_Registered guarantees the public name is findable +// on a registry built by [Register], so YAML hook entries that name it +// actually resolve. +func TestUnloadBuiltin_Registered(t *testing.T) { + t.Parallel() + + r := hooks.NewRegistry() + _, err := Register(r) + require.NoError(t, err) + + fn, ok := r.LookupBuiltin(Unload) + require.True(t, ok, "%q must be registered on the hook registry", Unload) + require.NotNil(t, fn) +} + +// TestUnload_PostsToDefaultEndpoint exercises the happy path against a +// real httptest server: the builtin must derive the `_unload` URL from +// the model's BaseURL and POST `{"model": ""}`. +func TestUnload_PostsToDefaultEndpoint(t *testing.T) { + t.Parallel() + + var ( + gotPath string + gotBody map[string]string + ) + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + gotPath = r.URL.Path + body, _ := io.ReadAll(r.Body) + _ = json.Unmarshal(body, &gotBody) + w.WriteHeader(http.StatusOK) + })) + defer server.Close() + + in := dmrInput(server, "") + in.FromAgentModels[0].BaseURL = server.URL + "/engines/llama.cpp/v1" + + out, err := unload(t.Context(), in, nil) + require.NoError(t, err) + assert.Nil(t, out, "unload is observational; output must be nil") + assert.Equal(t, "/engines/llama.cpp/_unload", gotPath) + assert.Equal(t, map[string]string{"model": "ai/qwen3"}, gotBody) +} + +// TestUnload_HonoursOverrideUnloadAPI documents that an explicit +// `unload_api` on the model takes precedence over the default +// derivation, and is rebased onto the BaseURL's host when relative. +func TestUnload_HonoursOverrideUnloadAPI(t *testing.T) { + t.Parallel() + + var gotPath string + server := countingServer(t, http.StatusOK, func(r *http.Request) { gotPath = r.URL.Path }) + + _, err := unload(t.Context(), dmrInput(server, "/custom/unload"), nil) + require.NoError(t, err) + assert.Equal(t, "/custom/unload", gotPath) +} + +// TestUnload_FiltersPerElement pins the per-element provider filter: +// when the snapshot mixes DMR and non-DMR endpoints, only the DMR +// ones are POSTed to. The non-DMR entries (cloud providers without a +// reachable unload endpoint) must be silently skipped, not errored +// on, not POSTed to a fabricated URL. +func TestUnload_FiltersPerElement(t *testing.T) { + t.Parallel() + + var gotModels []string + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + var body struct { + Model string `json:"model"` + } + _ = json.NewDecoder(r.Body).Decode(&body) + gotModels = append(gotModels, body.Model) + w.WriteHeader(http.StatusOK) + })) + defer server.Close() + + in := &hooks.Input{ + FromAgent: "from", + ToAgent: "to", + FromAgentModels: []hooks.ModelEndpoint{ + {Provider: "openai", Model: "gpt-4", BaseURL: "https://api.openai.com/v1"}, + {Provider: dmr.ProviderType, Model: "ai/qwen3", BaseURL: server.URL + "/engines/v1"}, + {Provider: "anthropic", Model: "claude", BaseURL: "https://api.anthropic.com"}, + {Provider: dmr.ProviderType, Model: "ai/llama3.2", BaseURL: server.URL + "/engines/llama.cpp/v1"}, + }, + } + + out, err := unload(t.Context(), in, nil) + require.NoError(t, err) + assert.Nil(t, out) + assert.ElementsMatch(t, []string{"ai/qwen3", "ai/llama3.2"}, gotModels, + "only DMR models must be POSTed; cloud providers must be silently skipped") +} + +// TestUnload_NoOpInputs pins the cheap-path properties the agent loop +// relies on: the hook MUST NOT fire any HTTP call when the input +// describes a transition where unloading would be wrong (back to the +// same agent, no previous agent, only cloud providers, or a model +// without a resolvable endpoint). Combining these into one table +// makes the no-op contract obvious from the test body. +func TestUnload_NoOpInputs(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + in func(*httptest.Server) *hooks.Input + }{ + { + name: "nil input", + in: func(*httptest.Server) *hooks.Input { return nil }, + }, + { + name: "empty FromAgent", + in: func(s *httptest.Server) *hooks.Input { + return dmrInput(s, "", func(in *hooks.Input) { in.FromAgent = "" }) + }, + }, + { + name: "FromAgent equals ToAgent", + in: func(s *httptest.Server) *hooks.Input { + return dmrInput(s, "", func(in *hooks.Input) { + in.FromAgent, in.ToAgent = "same", "same" + }) + }, + }, + { + name: "non-DMR providers only", + in: func(s *httptest.Server) *hooks.Input { + return &hooks.Input{ + FromAgent: "from", ToAgent: "to", + FromAgentModels: []hooks.ModelEndpoint{ + {Provider: "openai", Model: "gpt-4", BaseURL: s.URL}, + {Provider: "anthropic", Model: "claude", BaseURL: s.URL}, + }, + } + }, + }, + { + name: "DMR model with no endpoint", + in: func(*httptest.Server) *hooks.Input { + return &hooks.Input{ + FromAgent: "from", ToAgent: "to", + FromAgentModels: []hooks.ModelEndpoint{{Provider: dmr.ProviderType, Model: "ai/qwen3"}}, + } + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + var calls atomic.Int64 + server := countingServer(t, http.StatusOK, func(*http.Request) { calls.Add(1) }) + + out, err := unload(t.Context(), tt.in(server), nil) + require.NoError(t, err) + assert.Nil(t, out) + assert.Zero(t, calls.Load(), "no HTTP call must reach the server") + }) + } +} + +// TestUnload_SwallowsServerErrors verifies the best-effort contract: +// a 5xx from the engine must NOT propagate back as a hook error, +// because agent switching has to keep moving even when the unload +// endpoint is down. +func TestUnload_SwallowsServerErrors(t *testing.T) { + t.Parallel() + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusInternalServerError) + _, _ = w.Write([]byte("boom")) + })) + defer server.Close() + + out, err := unload(t.Context(), dmrInput(server, ""), nil) + require.NoError(t, err) + assert.Nil(t, out) +} diff --git a/pkg/hooks/types.go b/pkg/hooks/types.go index 474c0cf05..60d828cba 100644 --- a/pkg/hooks/types.go +++ b/pkg/hooks/types.go @@ -242,6 +242,15 @@ type Input struct { ToAgent string `json:"to_agent,omitempty"` AgentSwitchKind string `json:"agent_switch_kind,omitempty"` + // FromAgentModels is the snapshot of the previous agent's + // configured model endpoints, captured at on_agent_switch dispatch + // time. Populated only on [EventOnAgentSwitch]; nil for every other + // event. Hooks that act on the previous agent's models (e.g. the + // stock `unload` builtin asking a local inference engine to release + // GPU memory) read this slice instead of poking at the runtime, + // keeping the hook payload self-contained. + FromAgentModels []ModelEndpoint `json:"from_agent_models,omitempty"` + // OnSessionResume specific: the iteration cap that was reached // (PreviousMaxIterations) and the new cap after the user approved // continuation (NewMaxIterations). Carrying both lets audit @@ -278,6 +287,27 @@ type Input struct { Summary string `json:"summary,omitempty"` } +// ModelEndpoint identifies one of an agent's configured models plus +// the HTTP endpoint that hosts it, when one is known. It is the wire +// format used by [Input.FromAgentModels] so hooks can reach a +// model-serving endpoint without depending on runtime-only types. +type ModelEndpoint struct { + // Provider is the provider type ("openai", "anthropic", "dmr", ...). + Provider string `json:"provider,omitempty"` + // Model is the resolved model identifier. + Model string `json:"model,omitempty"` + // BaseURL is the resolved HTTP base URL of the provider, when known + // (set by providers that talk to a configurable HTTP endpoint, e.g. + // Docker Model Runner). Empty for cloud providers that don't expose + // a stable per-instance base URL on the runtime side. + BaseURL string `json:"base_url,omitempty"` + // UnloadAPI is the per-model unload path or absolute URL copied + // verbatim from the model's `unload_api` provider option. Empty + // when the user hasn't configured an override; the unload builtin + // falls back to a provider-specific default in that case. + UnloadAPI string `json:"unload_api,omitempty"` +} + // ToJSON serializes the input. func (i *Input) ToJSON() ([]byte, error) { return json.Marshal(i) } diff --git a/pkg/model/provider/base/base.go b/pkg/model/provider/base/base.go index 5fdc1a07a..b757590b1 100644 --- a/pkg/model/provider/base/base.go +++ b/pkg/model/provider/base/base.go @@ -17,6 +17,17 @@ type Config struct { // Models stores the full models map for providers that need it (e.g., routers). // This enables proper cloning of providers that reference other models. Models map[string]latest.ModelConfig + // BaseURL is the resolved HTTP base URL the client talks to, when + // the provider is reachable over a configurable HTTP endpoint. + // Distinct from [latest.ModelConfig.BaseURL] (the user-typed input): + // providers fill BaseURL with the URL they actually use after auto + // discovery / fallback (e.g. Docker Model Runner picking between + // MODEL_RUNNER_HOST, the desktop socket, and a localhost fallback). + // Surfaced through [Config.BaseConfig] so generic, runtime-free + // consumers like hooks can address the endpoint without duplicating + // resolution logic. Empty for providers that don't expose a stable + // per-instance URL. + BaseURL string } // ID returns the provider and model ID in the format "provider/model". diff --git a/pkg/model/provider/dmr/client.go b/pkg/model/provider/dmr/client.go index 3e8dca4cb..2f9887524 100644 --- a/pkg/model/provider/dmr/client.go +++ b/pkg/model/provider/dmr/client.go @@ -53,7 +53,6 @@ type Client struct { base.Config client openai.Client - baseURL string httpClient *http.Client engine string } @@ -137,9 +136,9 @@ func NewClient(ctx context.Context, cfg *latest.ModelConfig, opts ...options.Opt Config: base.Config{ ModelConfig: *cfg, ModelOptions: globalOptions, + BaseURL: baseURL, }, client: openai.NewClient(clientOptions...), - baseURL: baseURL, httpClient: httpClient, engine: engine, }, nil @@ -159,7 +158,7 @@ func (c *Client) CreateChatCompletionStream(ctx context.Context, messages []chat "model", c.ModelConfig.Model, "message_count", len(messages), "tool_count", len(requestTools), - "base_url", c.baseURL, + "base_url", c.BaseURL, ) if len(messages) == 0 { @@ -286,7 +285,7 @@ func (c *Client) CreateChatCompletionStream(ctx context.Context, messages []chat stream := c.client.Chat.Completions.NewStreaming(ctx, params) - slog.DebugContext(ctx, "DMR chat completion stream created successfully", "model", c.ModelConfig.Model, "base_url", c.baseURL) + slog.DebugContext(ctx, "DMR chat completion stream created successfully", "model", c.ModelConfig.Model, "base_url", c.BaseURL) return newStreamAdapter(stream, trackUsage), nil } diff --git a/pkg/model/provider/dmr/client_test.go b/pkg/model/provider/dmr/client_test.go index 12beb63e6..24835658d 100644 --- a/pkg/model/provider/dmr/client_test.go +++ b/pkg/model/provider/dmr/client_test.go @@ -30,7 +30,7 @@ func TestNewClientWithExplicitBaseURL(t *testing.T) { client, err := NewClient(t.Context(), cfg) require.NoError(t, err) - assert.Equal(t, "https://custom.example.com:8080/api/v1", client.baseURL) + assert.Equal(t, "https://custom.example.com:8080/api/v1", client.BaseURL) } func TestNewClientReturnsErrNotInstalledWhenDockerModelUnsupported(t *testing.T) { diff --git a/pkg/model/provider/dmr/embed.go b/pkg/model/provider/dmr/embed.go index 205efbf4a..aa38185d0 100644 --- a/pkg/model/provider/dmr/embed.go +++ b/pkg/model/provider/dmr/embed.go @@ -43,7 +43,7 @@ func (c *Client) CreateBatchEmbedding(ctx context.Context, texts []string) (*bas return &base.BatchEmbeddingResult{Embeddings: [][]float64{}}, nil } - slog.DebugContext(ctx, "Creating DMR embeddings", "model", c.ModelConfig.Model, "batch_size", len(texts), "base_url", c.baseURL) + slog.DebugContext(ctx, "Creating DMR embeddings", "model", c.ModelConfig.Model, "batch_size", len(texts), "base_url", c.BaseURL) response, err := c.client.Embeddings.New(ctx, openai.EmbeddingNewParams{ Input: openai.EmbeddingNewParamsInputUnion{OfArrayOfStrings: texts}, @@ -127,7 +127,7 @@ func (c *Client) Rerank(ctx context.Context, query string, documents []types.Doc documentStrings[i] = doc.Content } - baseURL, err := rerankBaseURL(c.baseURL) + baseURL, err := rerankBaseURL(c.BaseURL) if err != nil { return nil, err } diff --git a/pkg/model/provider/dmr/unload.go b/pkg/model/provider/dmr/unload.go index 5fb825f4f..07c2f6c8d 100644 --- a/pkg/model/provider/dmr/unload.go +++ b/pkg/model/provider/dmr/unload.go @@ -1,103 +1,52 @@ package dmr import ( - "bytes" - "context" - "encoding/json" "fmt" - "io" - "log/slog" - "net/http" "net/url" "strings" ) -// Unload asks Docker Model Runner to release the resources held for the -// configured model. Invoked by the runtime's `unload` on_agent_switch -// builtin hook. -// -// The unload endpoint is the provider's `unload_api` (relative path or -// absolute URL) when set, otherwise [defaultUnloadURL] derived from the -// OpenAI base URL. When neither is available the call is a no-op. -func (c *Client) Unload(ctx context.Context) error { - endpoint, err := c.resolveUnloadURL() - if err != nil || endpoint == "" { - return err - } - return postUnloadModel(ctx, c.httpClient, endpoint, c.ModelConfig.Model) -} - -func (c *Client) resolveUnloadURL() (string, error) { - if override := c.ModelConfig.UnloadAPI(); override != "" { - return rebaseURL(c.baseURL, override) - } - if c.baseURL == "" { - return "", nil - } - return defaultUnloadURL(c.baseURL), nil -} +// ProviderType is the canonical [latest.ModelConfig.Provider] value +// for Docker Model Runner. Exported so callers outside the package +// (e.g. the `unload` hook builtin) can dispatch on provider type +// without hard-coding the literal. +const ProviderType = "dmr" -// defaultUnloadURL derives the `_unload` endpoint URL from the OpenAI -// base URL by replacing the trailing `/v1` segment, mirroring how -// [buildConfigureURL] derives `_configure`: +// UnloadURL resolves the URL of the per-model unload endpoint for a +// DMR-served model, given the resolved provider base URL and the +// per-model `unload_api` override (both as they appear on +// [hooks.ModelEndpoint]). // -// http://host:port/engines/v1/ → http://host:port/engines/_unload -// http://host:port/engines/llama.cpp/v1/ → http://host:port/engines/llama.cpp/_unload -// http://_/exp/vDD4.40/engines/v1 → http://_/exp/vDD4.40/engines/_unload -func defaultUnloadURL(baseURL string) string { - u, err := url.Parse(baseURL) - if err != nil { - return strings.TrimSuffix(strings.TrimSuffix(baseURL, "/"), "/v1") + "/_unload" +// Resolution order: +// +// 1. unloadAPI is an absolute URL — used verbatim (lets users point +// at a different host than baseURL); +// 2. unloadAPI is set but relative — rebased onto baseURL's +// scheme + host (the model's path is dropped); +// 3. unloadAPI is unset — the default `_unload` URL is derived from +// baseURL by replacing its trailing `/v1` segment, mirroring the +// `/v1` → `/_configure` convention the configure path uses. +// +// Returns ("", nil) when neither baseURL nor unloadAPI is set, so the +// caller can skip without erroring (in-process / test providers). +func UnloadURL(baseURL, unloadAPI string) (string, error) { + if strings.HasPrefix(unloadAPI, "http://") || strings.HasPrefix(unloadAPI, "https://") { + return unloadAPI, nil } - u.Path = strings.TrimSuffix(strings.TrimSuffix(u.Path, "/"), "/v1") + "/_unload" - return u.String() -} - -// rebaseURL returns path verbatim if it is already an absolute URL, -// otherwise attaches it to baseURL's scheme + host (dropping any path -// baseURL may carry). This lets users point base_url at e.g. -// http://localhost:12434/engines/v1 and override unload_api with -// /engines/_unload without the version prefix bleeding through. -func rebaseURL(baseURL, path string) (string, error) { - if strings.HasPrefix(path, "http://") || strings.HasPrefix(path, "https://") { - return path, nil + if baseURL == "" && unloadAPI == "" { + return "", nil } u, err := url.Parse(baseURL) if err != nil || u.Scheme == "" || u.Host == "" { - return "", fmt.Errorf("base_url %q is not absolute; cannot resolve %q", baseURL, path) - } - if !strings.HasPrefix(path, "/") { - path = "/" + path + return "", fmt.Errorf("base_url %q is not absolute; cannot resolve unload endpoint", baseURL) } - return u.Scheme + "://" + u.Host + path, nil -} - -// postUnloadModel issues `POST ` with body `{"model": ""}`. -func postUnloadModel(ctx context.Context, client *http.Client, endpoint, modelID string) error { - body, _ := json.Marshal(map[string]string{"model": modelID}) - - req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewReader(body)) - if err != nil { - return fmt.Errorf("building unload request: %w", err) - } - req.Header.Set("Content-Type", "application/json") - - slog.DebugContext(ctx, "Unloading model", "url", endpoint, "model", modelID) - - resp, err := client.Do(req) - if err != nil { - return fmt.Errorf("calling unload endpoint %s: %w", endpoint, err) - } - defer resp.Body.Close() - - if resp.StatusCode < 200 || resp.StatusCode >= 300 { - respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 4*1024)) - return fmt.Errorf("unload endpoint returned %d: %s", - resp.StatusCode, strings.TrimSpace(string(respBody))) + switch { + case unloadAPI == "": + u.Path = strings.TrimSuffix(strings.TrimSuffix(u.Path, "/"), "/v1") + "/_unload" + case strings.HasPrefix(unloadAPI, "/"): + u.Path = unloadAPI + default: + u.Path = "/" + unloadAPI } - // Drain the success-path body so the underlying transport can reuse - // the connection (Go's http.Client only re-pools a connection whose - // body has been read to EOF and closed). - _, _ = io.Copy(io.Discard, resp.Body) - return nil + return u.String(), nil } diff --git a/pkg/model/provider/dmr/unload_test.go b/pkg/model/provider/dmr/unload_test.go index 0a2fcc463..ed6180469 100644 --- a/pkg/model/provider/dmr/unload_test.go +++ b/pkg/model/provider/dmr/unload_test.go @@ -1,57 +1,83 @@ package dmr import ( - "encoding/json" - "io" - "net/http" - "net/http/httptest" - "sync" "testing" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - - "github.com/docker/docker-agent/pkg/config/latest" - "github.com/docker/docker-agent/pkg/model/provider/base" ) -func TestRebaseURL(t *testing.T) { +// TestUnloadURL covers every branch of the URL-resolution algorithm +// in one place. The builtin and any other consumer of [UnloadURL] +// rely on these properties, so the test pins the convention here +// where DMR owns it (sibling to [buildConfigureURL]'s `_configure` +// math). +func TestUnloadURL(t *testing.T) { t.Parallel() tests := []struct { name string baseURL string - path string + unloadAPI string want string errContains string // empty ⇒ expect success }{ + // Default derivation (no unload_api set). + { + name: "default: standard engines path", + baseURL: "http://127.0.0.1:12434/engines/v1/", + want: "http://127.0.0.1:12434/engines/_unload", + }, + { + name: "default: no trailing slash", + baseURL: "http://127.0.0.1:12434/engines/v1", + want: "http://127.0.0.1:12434/engines/_unload", + }, + { + name: "default: Docker Desktop experimental prefix", + baseURL: "http://_/exp/vDD4.40/engines/v1", + want: "http://_/exp/vDD4.40/engines/_unload", + }, { - name: "absolute https URL is returned verbatim", - baseURL: "http://anything", - path: "https://api.example.com/unload", - want: "https://api.example.com/unload", + name: "default: backend-scoped path", + baseURL: "http://127.0.0.1:12434/engines/llama.cpp/v1/", + want: "http://127.0.0.1:12434/engines/llama.cpp/_unload", }, + + // Override paths and absolute URLs. { - name: "rooted path drops base path", - baseURL: "http://localhost:12434/engines/v1", - path: "/engines/_unload", - want: "http://localhost:12434/engines/_unload", + name: "override: absolute https URL is returned verbatim", + baseURL: "http://anything", + unloadAPI: "https://api.example.com/unload", + want: "https://api.example.com/unload", }, { - name: "relative path is rooted", - baseURL: "http://localhost:12434/engines/v1", - path: "engines/_unload", - want: "http://localhost:12434/engines/_unload", + name: "override: rooted path drops base path", + baseURL: "http://localhost:12434/engines/v1", + unloadAPI: "/engines/_unload", + want: "http://localhost:12434/engines/_unload", }, { - name: "empty base URL with relative path errors", - path: "/engines/_unload", + name: "override: relative path is rooted", + baseURL: "http://localhost:12434/engines/v1", + unloadAPI: "engines/_unload", + want: "http://localhost:12434/engines/_unload", + }, + + // Skip / error cases. + { + name: "skip: no base_url and no unload_api", + want: "", + }, + { + name: "error: unload_api set but base_url empty", + unloadAPI: "/engines/_unload", errContains: "is not absolute", }, { - name: "base URL without scheme errors", + name: "error: base_url without scheme", baseURL: "localhost:12434/engines/v1", - path: "/engines/_unload", + unloadAPI: "/engines/_unload", errContains: "is not absolute", }, } @@ -59,7 +85,7 @@ func TestRebaseURL(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { t.Parallel() - got, err := rebaseURL(tt.baseURL, tt.path) + got, err := UnloadURL(tt.baseURL, tt.unloadAPI) if tt.errContains != "" { require.Error(t, err) assert.Contains(t, err.Error(), tt.errContains) @@ -70,150 +96,3 @@ func TestRebaseURL(t *testing.T) { }) } } - -func TestDefaultUnloadURL(t *testing.T) { - t.Parallel() - - tests := []struct { - name string - baseURL string - expected string - }{ - { - name: "standard engines path", - baseURL: "http://127.0.0.1:12434/engines/v1/", - expected: "http://127.0.0.1:12434/engines/_unload", - }, - { - name: "no trailing slash", - baseURL: "http://127.0.0.1:12434/engines/v1", - expected: "http://127.0.0.1:12434/engines/_unload", - }, - { - name: "Docker Desktop experimental prefix", - baseURL: "http://_/exp/vDD4.40/engines/v1", - expected: "http://_/exp/vDD4.40/engines/_unload", - }, - { - name: "backend-scoped path", - baseURL: "http://127.0.0.1:12434/engines/llama.cpp/v1/", - expected: "http://127.0.0.1:12434/engines/llama.cpp/_unload", - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - t.Parallel() - assert.Equal(t, tt.expected, defaultUnloadURL(tt.baseURL)) - }) - } -} - -// newClient builds a [Client] just well-formed enough to drive Unload. -func newClient(baseURL string, httpClient *http.Client, cfg latest.ModelConfig) *Client { - return &Client{ - Config: base.Config{ModelConfig: cfg}, - baseURL: baseURL, - httpClient: httpClient, - } -} - -// TestClientUnload exercises the full Unload path end-to-end against an -// httptest server, covering: default URL derivation, user-configured -// unload_api override, non-2xx error surfacing, and the no-op branch -// when no endpoint can be determined. -func TestClientUnload(t *testing.T) { - t.Parallel() - - t.Run("posts model id to default unload endpoint", func(t *testing.T) { - t.Parallel() - - var ( - gotPath string - gotBody map[string]string - ) - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - gotPath = r.URL.Path - body, _ := io.ReadAll(r.Body) - _ = json.Unmarshal(body, &gotBody) - w.WriteHeader(http.StatusOK) - })) - defer server.Close() - - c := newClient(server.URL+"/engines/llama.cpp/v1", server.Client(), latest.ModelConfig{ - Provider: "dmr", - Model: "ai/qwen3", - }) - - require.NoError(t, c.Unload(t.Context())) - assert.Equal(t, "/engines/llama.cpp/_unload", gotPath) - assert.Equal(t, map[string]string{"model": "ai/qwen3"}, gotBody) - }) - - t.Run("honours user-configured unload_api path", func(t *testing.T) { - t.Parallel() - - var gotPath string - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - gotPath = r.URL.Path - w.WriteHeader(http.StatusOK) - })) - defer server.Close() - - c := newClient(server.URL+"/engines/v1", server.Client(), latest.ModelConfig{ - Model: "ai/qwen3", - ProviderOpts: map[string]any{"unload_api": "/custom/unload"}, - }) - - require.NoError(t, c.Unload(t.Context())) - assert.Equal(t, "/custom/unload", gotPath) - }) - - t.Run("returns error on non-2xx", func(t *testing.T) { - t.Parallel() - - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { - w.WriteHeader(http.StatusInternalServerError) - _, _ = w.Write([]byte("boom")) - })) - defer server.Close() - - c := newClient(server.URL+"/engines/v1", server.Client(), latest.ModelConfig{ - Model: "ai/qwen3", - }) - - err := c.Unload(t.Context()) - require.Error(t, err) - assert.Contains(t, err.Error(), "500") - assert.Contains(t, err.Error(), "boom") - }) - - t.Run("no-op when neither base URL nor unload_api are set", func(t *testing.T) { - t.Parallel() - c := newClient("", nil, latest.ModelConfig{Model: "ai/qwen3"}) - require.NoError(t, c.Unload(t.Context())) - }) - - t.Run("drains success body so connection can be reused", func(t *testing.T) { - t.Parallel() - - // httptest's default server uses keep-alive; a non-drained body would - // force a fresh TCP connection on the second call. We assert that the - // underlying transport saw a single connection across two POSTs. - var seen sync.Map - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - seen.Store(r.RemoteAddr, struct{}{}) - // Write a non-trivial body so the drain actually has work to do. - _, _ = w.Write([]byte(`{"status":"ok"}`)) - })) - defer server.Close() - - c := newClient(server.URL+"/engines/v1", server.Client(), latest.ModelConfig{Model: "ai/qwen3"}) - require.NoError(t, c.Unload(t.Context())) - require.NoError(t, c.Unload(t.Context())) - - count := 0 - seen.Range(func(_, _ any) bool { count++; return true }) - assert.Equal(t, 1, count, "both POSTs must reuse the same TCP connection") - }) -} diff --git a/pkg/model/provider/provider.go b/pkg/model/provider/provider.go index e637fed13..80b46948b 100644 --- a/pkg/model/provider/provider.go +++ b/pkg/model/provider/provider.go @@ -71,22 +71,6 @@ type RerankingProvider interface { Rerank(ctx context.Context, query string, documents []types.Document, criteria string) ([]float64, error) } -// Unloader is an optional interface for providers that can release -// the resources held for the configured model. Local inference engines -// (today, Docker Model Runner) implement it; cloud APIs typically -// don't. -// -// The runtime's [unload] on_agent_switch builtin hook calls Unload when -// an agent hands control to another agent. Implementations are -// best-effort and should be idempotent — repeated calls on an -// already-unloaded model must succeed. -// -// [unload]: https://pkg.go.dev/github.com/docker/docker-agent/pkg/runtime#BuiltinUnload -type Unloader interface { - Provider - Unload(ctx context.Context) error -} - // New creates a new provider from a model config. // This is a convenience wrapper for NewWithModels with no models map. func New(ctx context.Context, cfg *latest.ModelConfig, env environment.Provider, opts ...options.Opt) (Provider, error) { diff --git a/pkg/runtime/hooks.go b/pkg/runtime/hooks.go index 63172cf30..b6e795e8d 100644 --- a/pkg/runtime/hooks.go +++ b/pkg/runtime/hooks.go @@ -258,15 +258,58 @@ const ( // changes the active agent. Observational; failures are logged. The // hook runs alongside the existing [AgentSwitching] event, so users // who already consume that event see no behaviour change. +// +// The previous agent's model-endpoint snapshot is built only when at +// least one hook is configured for this event so audit-free +// deployments don't pay the team-lookup + per-model allocation on +// every agent switch (matches the cheap-when-unused property of the +// other hook callsites). func (r *LocalRuntime) executeOnAgentSwitchHooks(ctx context.Context, a *agent.Agent, sessionID, fromAgent, toAgent, kind string) { + exec := r.hooksExec(a) + if exec == nil || !exec.Has(hooks.EventOnAgentSwitch) { + return + } r.dispatchHook(ctx, a, hooks.EventOnAgentSwitch, &hooks.Input{ SessionID: sessionID, FromAgent: fromAgent, ToAgent: toAgent, AgentSwitchKind: kind, + FromAgentModels: r.fromAgentModels(fromAgent), }, nil) } +// fromAgentModels snapshots the previous agent's configured model +// endpoints into the wire-friendly [hooks.ModelEndpoint] form. Hooks +// that act on the previous agent's models (e.g. the stock `unload` +// builtin) read this slice instead of poking at the runtime, so the +// hook payload stays self-contained. +// +// Returns nil when there is no previous agent or the lookup fails so +// the JSON wire payload omits the field via `omitempty`. +func (r *LocalRuntime) fromAgentModels(fromAgent string) []hooks.ModelEndpoint { + if fromAgent == "" { + return nil + } + from, err := r.team.Agent(fromAgent) + if err != nil { + slog.Debug("on_agent_switch: from-agent lookup failed", + "agent", fromAgent, "error", err) + return nil + } + configured := from.ConfiguredModels() + out := make([]hooks.ModelEndpoint, 0, len(configured)) + for _, p := range configured { + cfg := p.BaseConfig() + out = append(out, hooks.ModelEndpoint{ + Provider: cfg.ModelConfig.Provider, + Model: cfg.ModelConfig.Model, + BaseURL: cfg.BaseURL, + UnloadAPI: cfg.ModelConfig.UnloadAPI(), + }) + } + return out +} + // executeOnSessionResumeHooks fires on_session_resume when the user // explicitly approves continuation past the configured // max_iterations limit. Observational; failures are logged. The hook diff --git a/pkg/runtime/on_agent_switch_test.go b/pkg/runtime/on_agent_switch_test.go index 1ef60e23d..d9e8c0ca7 100644 --- a/pkg/runtime/on_agent_switch_test.go +++ b/pkg/runtime/on_agent_switch_test.go @@ -9,8 +9,12 @@ import ( "github.com/stretchr/testify/require" "github.com/docker/docker-agent/pkg/agent" + "github.com/docker/docker-agent/pkg/chat" + "github.com/docker/docker-agent/pkg/config/latest" "github.com/docker/docker-agent/pkg/hooks" + "github.com/docker/docker-agent/pkg/model/provider/base" "github.com/docker/docker-agent/pkg/team" + "github.com/docker/docker-agent/pkg/tools" ) // recordingBuiltin captures the [hooks.Input] passed on every dispatch @@ -116,3 +120,83 @@ func TestExecuteOnAgentSwitchHooks_NoopWhenNoHookRegistered(t *testing.T) { // Should be a successful no-op rather than a panic or error. r.executeOnAgentSwitchHooks(t.Context(), a, "s", "root", "next", agentSwitchKindHandoff) } + +// endpointProvider is a minimal [provider.Provider] test double whose +// BaseConfig returns a non-zero base.Config, so we can drive the +// FromAgentModels-population branch of executeOnAgentSwitchHooks. +type endpointProvider struct { + cfg base.Config +} + +func (p *endpointProvider) ID() string { return p.cfg.ID() } + +func (p *endpointProvider) CreateChatCompletionStream(context.Context, []chat.Message, []tools.Tool) (chat.MessageStream, error) { + return &mockStream{}, nil +} + +func (p *endpointProvider) BaseConfig() base.Config { return p.cfg } + +// TestExecuteOnAgentSwitchHooks_PopulatesFromAgentModels pins the +// runtime→hook handoff for the new pure-unload contract: when the +// previous agent has configured models, the runtime must ship a +// snapshot of {Provider, Model, BaseURL, UnloadAPI} for each one on +// Input.FromAgentModels so a runtime-free hook can act on them. +func TestExecuteOnAgentSwitchHooks_PopulatesFromAgentModels(t *testing.T) { + t.Parallel() + + rb := &recordingBuiltin{} + prev := &endpointProvider{cfg: base.Config{ + ModelConfig: latest.ModelConfig{ + Provider: "dmr", + Model: "ai/qwen3", + ProviderOpts: map[string]any{"unload_api": "/engines/_unload"}, + }, + BaseURL: "http://127.0.0.1:12434/engines/v1", + }} + next := &mockProvider{id: "test/next", stream: &mockStream{}} + + from := agent.New("root", "instructions", + agent.WithModel(prev), + agent.WithHooks(&hooks.Config{ + OnAgentSwitch: []hooks.Hook{{ + Type: hooks.HookTypeBuiltin, Command: "test_record_agent_switch", + }}, + })) + to := agent.New("planner", "instructions", agent.WithModel(next)) + tm := team.New(team.WithAgents(from, to)) + + r, err := NewLocalRuntime(tm, WithModelStore(mockModelStore{})) + require.NoError(t, err) + require.NoError(t, r.hooksRegistry.RegisterBuiltin("test_record_agent_switch", rb.hook)) + r.buildHooksExecutors() + + r.executeOnAgentSwitchHooks(t.Context(), from, "s", "root", "planner", agentSwitchKindTransferTask) + + got := rb.snapshot() + require.Len(t, got, 1) + require.Len(t, got[0].FromAgentModels, 1, "runtime must ship one ModelEndpoint per configured model") + ep := got[0].FromAgentModels[0] + assert.Equal(t, "dmr", ep.Provider) + assert.Equal(t, "ai/qwen3", ep.Model) + assert.Equal(t, "http://127.0.0.1:12434/engines/v1", ep.BaseURL) + assert.Equal(t, "/engines/_unload", ep.UnloadAPI) +} + +// TestExecuteOnAgentSwitchHooks_FromAgentModelsNilWhenFromEmpty pins +// the cheap path on the very first switch into the team's default +// agent: there is no previous agent, so FromAgentModels must stay nil +// (no team lookup, no allocation) and the JSON wire form omits the +// field via `omitempty`. +func TestExecuteOnAgentSwitchHooks_FromAgentModelsNilWhenFromEmpty(t *testing.T) { + t.Parallel() + + r, rb := runtimeWithRecordedAgentSwitch(t, "root") + a := r.CurrentAgent() + require.NotNil(t, a) + + r.executeOnAgentSwitchHooks(t.Context(), a, "s", "", "root", agentSwitchKindHandoff) + + got := rb.snapshot() + require.Len(t, got, 1) + assert.Nil(t, got[0].FromAgentModels) +} diff --git a/pkg/runtime/runtime.go b/pkg/runtime/runtime.go index 474d508d8..eaf16eb66 100644 --- a/pkg/runtime/runtime.go +++ b/pkg/runtime/runtime.go @@ -437,14 +437,6 @@ func NewLocalRuntime(agents *team.Team, opts ...Opt) (*LocalRuntime, error) { return nil, fmt.Errorf("register %q builtin: %w", BuiltinCacheResponse, err) } - // unload is registered alongside cache_response for the same - // reason: it needs to walk Input.FromAgent up to the previous agent's - // configured models and dispatch to provider.Unloader implementations, - // which the runtime owns through the team. - if err := hooksRegistry.RegisterBuiltin(BuiltinUnload, r.unloadBuiltin); err != nil { - return nil, fmt.Errorf("register %q builtin: %w", BuiltinUnload, err) - } - // stripUnsupportedModalitiesTransform captures the runtime closure to // resolve the agent from Input.AgentName, so it lives here rather // than as a stateless builtin in pkg/hooks/builtins. It drops image diff --git a/pkg/runtime/unload.go b/pkg/runtime/unload.go deleted file mode 100644 index 6c0fe2f55..000000000 --- a/pkg/runtime/unload.go +++ /dev/null @@ -1,55 +0,0 @@ -package runtime - -import ( - "context" - "log/slog" - "time" - - "github.com/docker/docker-agent/pkg/hooks" - "github.com/docker/docker-agent/pkg/model/provider" -) - -// BuiltinUnload is the on_agent_switch builtin that asks every model on -// the previous agent to release its resources. Opt in via: -// -// hooks: -// on_agent_switch: -// - type: builtin -// command: unload -// -// Today only Docker Model Runner ships a [provider.Unloader]; other -// providers are silently skipped, so wiring the builtin on a -// cross-provider chain is harmless. -const BuiltinUnload = "unload" - -// unloadTimeout caps each Unload call so a stalled engine cannot stall -// agent switching. -const unloadTimeout = 10 * time.Second - -// unloadBuiltin calls Unload on every [provider.Unloader] of the -// previous agent. Errors are logged but never propagated — agent -// switching must never block on a slow or unreachable engine. -func (r *LocalRuntime) unloadBuiltin(ctx context.Context, in *hooks.Input, _ []string) (*hooks.Output, error) { - if in.FromAgent == "" || in.FromAgent == in.ToAgent { - return nil, nil - } - from, err := r.team.Agent(in.FromAgent) - if err != nil { - slog.DebugContext(ctx, "unload: from-agent lookup failed", - "agent", in.FromAgent, "error", err) - return nil, nil - } - for _, p := range from.ConfiguredModels() { - u, ok := p.(provider.Unloader) - if !ok { - continue - } - callCtx, cancel := context.WithTimeout(ctx, unloadTimeout) - if err := u.Unload(callCtx); err != nil { - slog.WarnContext(ctx, "unload: provider unload failed", - "agent", from.Name(), "model", p.ID(), "error", err) - } - cancel() - } - return nil, nil -} diff --git a/pkg/runtime/unload_test.go b/pkg/runtime/unload_test.go deleted file mode 100644 index 1eda21b3d..000000000 --- a/pkg/runtime/unload_test.go +++ /dev/null @@ -1,106 +0,0 @@ -package runtime - -import ( - "context" - "errors" - "sync/atomic" - "testing" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - - "github.com/docker/docker-agent/pkg/agent" - "github.com/docker/docker-agent/pkg/chat" - "github.com/docker/docker-agent/pkg/config/latest" - "github.com/docker/docker-agent/pkg/hooks" - "github.com/docker/docker-agent/pkg/model/provider" - "github.com/docker/docker-agent/pkg/model/provider/base" - "github.com/docker/docker-agent/pkg/team" - "github.com/docker/docker-agent/pkg/tools" -) - -// unloadingProvider is a minimal [provider.Unloader] test double that -// counts Unload calls and (optionally) returns a configured error so -// tests can assert call sites and best-effort error swallowing. -type unloadingProvider struct { - id string - calls atomic.Int64 - unloadErr error -} - -func (p *unloadingProvider) ID() string { return p.id } - -func (p *unloadingProvider) CreateChatCompletionStream(context.Context, []chat.Message, []tools.Tool) (chat.MessageStream, error) { - return &mockStream{}, nil -} - -func (p *unloadingProvider) BaseConfig() base.Config { - return base.Config{ModelConfig: latest.ModelConfig{Provider: "test", Model: p.id}} -} - -func (p *unloadingProvider) Unload(_ context.Context) error { - p.calls.Add(1) - return p.unloadErr -} - -var _ provider.Unloader = (*unloadingProvider)(nil) - -// newUnloadingRuntime wires up agents named after the supplied provider -// IDs and returns a runtime ready to drive [unloadBuiltin]. Each name -// gets its own [unloadingProvider] reachable through the returned map. -func newUnloadingRuntime(t *testing.T, names ...string) (*LocalRuntime, map[string]*unloadingProvider) { - t.Helper() - provs := make(map[string]*unloadingProvider, len(names)) - agents := make([]*agent.Agent, len(names)) - for i, name := range names { - p := &unloadingProvider{id: name} - provs[name] = p - agents[i] = agent.New(name, "instructions", agent.WithModel(p)) - } - rt, err := NewLocalRuntime(team.New(team.WithAgents(agents...)), - WithSessionCompaction(false), WithModelStore(mockModelStore{})) - require.NoError(t, err) - return rt, provs -} - -func TestUnloadBuiltin(t *testing.T) { - t.Parallel() - - t.Run("calls Unload on the previous agent only", func(t *testing.T) { - t.Parallel() - rt, provs := newUnloadingRuntime(t, "from", "to") - _, err := rt.unloadBuiltin(t.Context(), &hooks.Input{FromAgent: "from", ToAgent: "to"}, nil) - require.NoError(t, err) - assert.Equal(t, int64(1), provs["from"].calls.Load(), "from-agent's model must be unloaded once") - assert.Equal(t, int64(0), provs["to"].calls.Load(), "to-agent's model must NOT be unloaded") - }) - - t.Run("swallows Unload errors so agent switch never blocks", func(t *testing.T) { - t.Parallel() - rt, provs := newUnloadingRuntime(t, "from", "to") - provs["from"].unloadErr = errors.New("engine offline") - out, err := rt.unloadBuiltin(t.Context(), &hooks.Input{FromAgent: "from", ToAgent: "to"}, nil) - require.NoError(t, err) - assert.Nil(t, out) - assert.Equal(t, int64(1), provs["from"].calls.Load()) - }) - - t.Run("no-op when from==to", func(t *testing.T) { - t.Parallel() - rt, provs := newUnloadingRuntime(t, "from") - _, err := rt.unloadBuiltin(t.Context(), &hooks.Input{FromAgent: "from", ToAgent: "from"}, nil) - require.NoError(t, err) - assert.Equal(t, int64(0), provs["from"].calls.Load()) - }) -} - -// TestUnloadBuiltin_Registered guarantees the builtin name is findable -// on the runtime registry so YAML hook entries that reference it -// actually resolve. -func TestUnloadBuiltin_Registered(t *testing.T) { - t.Parallel() - rt, _ := newUnloadingRuntime(t, "loader") - fn, ok := rt.hooksRegistry.LookupBuiltin(BuiltinUnload) - require.True(t, ok, "%q must be registered on the runtime's hook registry", BuiltinUnload) - require.NotNil(t, fn) -}