From e818273679f94ef38f8f765a06b0430028e23280 Mon Sep 17 00:00:00 2001 From: David Gageot Date: Thu, 7 May 2026 17:07:23 +0200 Subject: [PATCH 1/3] refactor(hooks): make the unload builtin pure, decouple from runtime Move the on_agent_switch `unload` builtin out of pkg/runtime and pkg/model/provider/dmr into pkg/hooks/builtins. The builtin no longer depends on the runtime team or the provider.Unloader interface; it reads a snapshot of the previous agent's model endpoints from hooks.Input.FromAgentModels and POSTs to the resolved `_unload` URL over plain HTTP. The runtime ships generic data on every on_agent_switch dispatch (provider, model id, resolved base URL, optional unload_api override) and no longer knows the word 'unload'. Cross-provider chains stay safe: non-DMR endpoints are silently skipped. --- agent-schema.json | 4 +- docs/configuration/hooks/index.md | 4 +- examples/unload_on_switch.yaml | 12 +- pkg/config/latest/types.go | 2 +- pkg/hooks/builtins/builtins.go | 3 + pkg/hooks/builtins/builtins_test.go | 1 + pkg/hooks/builtins/unload.go | 157 +++++++++++++ pkg/hooks/builtins/unload_test.go | 312 ++++++++++++++++++++++++++ pkg/hooks/types.go | 30 +++ pkg/model/provider/base/base.go | 11 + pkg/model/provider/dmr/client.go | 7 +- pkg/model/provider/dmr/client_test.go | 2 +- pkg/model/provider/dmr/embed.go | 4 +- pkg/model/provider/dmr/unload.go | 103 --------- pkg/model/provider/dmr/unload_test.go | 219 ------------------ pkg/model/provider/provider.go | 16 -- pkg/runtime/hooks.go | 37 +++ pkg/runtime/on_agent_switch_test.go | 84 +++++++ pkg/runtime/runtime.go | 8 - pkg/runtime/unload.go | 55 ----- pkg/runtime/unload_test.go | 106 --------- 21 files changed, 653 insertions(+), 524 deletions(-) create mode 100644 pkg/hooks/builtins/unload.go create mode 100644 pkg/hooks/builtins/unload_test.go delete mode 100644 pkg/model/provider/dmr/unload.go delete mode 100644 pkg/model/provider/dmr/unload_test.go delete mode 100644 pkg/runtime/unload.go delete mode 100644 pkg/runtime/unload_test.go diff --git a/agent-schema.json b/agent-schema.json index 33b549c2e..106c69717 100644 --- a/agent-schema.json +++ b/agent-schema.json @@ -182,7 +182,7 @@ }, "unload_api": { "type": "string", - "description": "Optional path (or absolute URL) to the provider's model-unload endpoint. POSTed with `{\"model\": \"\"}` when the agent wires the `unload` builtin into its `on_agent_switch` hook chain, to free GPU/RAM held by the previous model. Today only Docker Model Runner ships a provider that calls this endpoint; cloud providers don't implement [provider.Unloader] and the hook silently skips them. A relative path is resolved against the scheme+host of base_url; an absolute URL is used verbatim.", + "description": "Optional path (or absolute URL) to the provider's model-unload endpoint. POSTed with `{\"model\": \"\"}` when the agent wires the `unload` builtin into its `on_agent_switch` hook chain, to free GPU/RAM held by the previous model. Today only Docker Model Runner exposes such an endpoint; the `unload` builtin is a pure HTTP hook that silently skips non-DMR providers. A relative path is resolved against the scheme+host of base_url; an absolute URL is used verbatim.", "examples": [ "/engines/_unload", "/api/unload", @@ -881,7 +881,7 @@ }, "type": { "type": "string", - "description": "Type of hook. 'command' executes a shell command; 'builtin' invokes a named in-process Go function registered by the runtime; 'model' asks an LLM and translates its reply into the hook's native output (used for LLM-as-a-judge pre_tool_use, summarizers, etc., with no Go code). The docker-agent runtime ships these builtins: 'add_date' (turn_start: today's date), 'add_environment_info' (session_start: cwd, git, OS, arch), 'add_prompt_files' (turn_start: contents of named files looked up in the workdir hierarchy and the home directory), 'add_git_status' (turn_start: `git status --short --branch`), 'add_git_diff' (turn_start: `git diff --stat`, or full diff with args=['full']), 'add_directory_listing' (session_start: top-level entries of cwd), 'add_user_info' (session_start: current OS user and hostname), 'add_recent_commits' (session_start: `git log --oneline -n N`, default N=10, override via args=['']), 'max_iterations' (before_llm_call: hard stop after N model calls; args=[''] required), 'redact_secrets' (pre_tool_use / before_llm_call / tool_response_transform: scrubs detected secrets from tool arguments, outgoing chat content, and tool output — the same builtin handles all three legs and dispatches on the event; the matching agent-level 'redact_secrets: true' flag auto-injects the entries for all three), 'unload' (on_agent_switch: walks the previous agent's models and calls Unload() on every provider that implements provider.Unloader — e.g. asks Docker Model Runner to release the GPU/RAM held by the just-departing model so the next agent's model can claim it; opt in by adding the entry to the agent's hooks.on_agent_switch list).", + "description": "Type of hook. 'command' executes a shell command; 'builtin' invokes a named in-process Go function registered by the runtime; 'model' asks an LLM and translates its reply into the hook's native output (used for LLM-as-a-judge pre_tool_use, summarizers, etc., with no Go code). The docker-agent runtime ships these builtins: 'add_date' (turn_start: today's date), 'add_environment_info' (session_start: cwd, git, OS, arch), 'add_prompt_files' (turn_start: contents of named files looked up in the workdir hierarchy and the home directory), 'add_git_status' (turn_start: `git status --short --branch`), 'add_git_diff' (turn_start: `git diff --stat`, or full diff with args=['full']), 'add_directory_listing' (session_start: top-level entries of cwd), 'add_user_info' (session_start: current OS user and hostname), 'add_recent_commits' (session_start: `git log --oneline -n N`, default N=10, override via args=['']), 'max_iterations' (before_llm_call: hard stop after N model calls; args=[''] required), 'redact_secrets' (pre_tool_use / before_llm_call / tool_response_transform: scrubs detected secrets from tool arguments, outgoing chat content, and tool output — the same builtin handles all three legs and dispatches on the event; the matching agent-level 'redact_secrets: true' flag auto-injects the entries for all three), 'unload' (on_agent_switch: POSTs `{\"model\": \"\"}` to the previous agent's DMR model endpoints — e.g. asks Docker Model Runner to release the GPU/RAM held by the just-departing model so the next agent's model can claim it. Pure HTTP, no provider-specific runtime coupling; non-DMR providers are silently skipped. Opt in by adding the entry to the agent's hooks.on_agent_switch list).", "enum": [ "command", "builtin", diff --git a/docs/configuration/hooks/index.md b/docs/configuration/hooks/index.md index b58ad2022..83ee57216 100644 --- a/docs/configuration/hooks/index.md +++ b/docs/configuration/hooks/index.md @@ -156,7 +156,7 @@ Built-ins are typically zero-config and faster than equivalent shell hooks becau | `max_iterations` | `before_llm_call` | `[""]` (required) | Hard-stops the agent after `N` model calls. State is per-session and reset at `session_end`. | | `snapshot` | `session_start`, `turn_start`, `turn_end`, `pre_tool_use`, `post_tool_use`, `session_end` | _none_ | Records filesystem snapshots in a shadow git repo under the docker-agent data directory. No-op outside git repos; respects the source repo's ignore rules and skips newly-added files larger than 2 MiB. | | `redact_secrets` | `pre_tool_use`, `before_llm_call`, `tool_response_transform` | _none_ | Scrubs detected secrets (API keys, tokens, private keys, …) out of tool call arguments, outgoing chat content, and tool output. The same builtin handles all three events and dispatches on the event name. Auto-registered on all three events by `redact_secrets: true` on the agent — see [`examples/redact_secrets_hooks.yaml`](https://github.com/docker/docker-agent/blob/main/examples/redact_secrets_hooks.yaml) for the manual wiring. | -| `unload` | `on_agent_switch` | _none_ | Walks the previous agent's models and calls `Unload()` on every provider that implements [`provider.Unloader`](https://pkg.go.dev/github.com/docker/docker-agent/pkg/model/provider#Unloader) — typically Docker Model Runner — to free the GPU/RAM the just-departing model was holding. Cloud-only providers don't implement the interface and are silently skipped. Errors are logged and swallowed; agent switching never blocks on a slow or unreachable engine (each Unload call has a 10 s timeout). See [`examples/unload_on_switch.yaml`](https://github.com/docker/docker-agent/blob/main/examples/unload_on_switch.yaml). | +| `unload` | `on_agent_switch` | _none_ | POSTs `{"model": ""}` to each of the previous agent's DMR model endpoints (`/_unload` by default, overridable per-model via `unload_api`) to free the GPU/RAM the just-departing model was holding. Pure HTTP — reads the model snapshot the runtime ships on `on_agent_switch` and depends on no provider-specific runtime state. Non-DMR providers (OpenAI, Anthropic, …) are silently skipped, so cross-provider chains are safe. Errors are logged and swallowed; agent switching never blocks on a slow or unreachable engine (each call has a 10 s timeout). See [`examples/unload_on_switch.yaml`](https://github.com/docker/docker-agent/blob/main/examples/unload_on_switch.yaml). |
ℹ️ Per-turn vs. per-session @@ -590,7 +590,7 @@ models: model: ai/qwen3-coder ``` -At every transfer the runtime calls `Unload()` on the previous agent's model providers. For Docker Model Runner this hits the engine's `_unload` endpoint; for cloud providers (OpenAI, Anthropic, …) it is a silent no-op. Cross-provider chains are safe — only the providers that actually implement [`provider.Unloader`](https://pkg.go.dev/github.com/docker/docker-agent/pkg/model/provider#Unloader) are touched. See [`examples/unload_on_switch.yaml`](https://github.com/docker/docker-agent/blob/main/examples/unload_on_switch.yaml) for the full file. +At every transfer the runtime ships a snapshot of the previous agent's model endpoints on the `on_agent_switch` hook input, and the `unload` builtin POSTs `{"model": ""}` to each DMR endpoint's `/_unload` URL over plain HTTP. For cloud providers (OpenAI, Anthropic, …) the hook is a silent no-op since they don't expose an HTTP unload endpoint. Cross-provider chains are safe — only DMR endpoints are touched. See [`examples/unload_on_switch.yaml`](https://github.com/docker/docker-agent/blob/main/examples/unload_on_switch.yaml) for the full file. `on_session_resume` fires when the user explicitly approves the runtime to continue past its configured `max_iterations` limit. `previous_max_iterations` carries the cap that was reached and `new_max_iterations` carries the new cap after approval. Useful for alerting on extended-runtime sessions or for billing / quota pipelines that meter resumes. diff --git a/examples/unload_on_switch.yaml b/examples/unload_on_switch.yaml index e4aea29ef..c098e374e 100644 --- a/examples/unload_on_switch.yaml +++ b/examples/unload_on_switch.yaml @@ -4,11 +4,13 @@ # # Two agents share Docker Model Runner but use different models that don't # fit in GPU memory at the same time. Wiring the `unload` builtin into -# each agent's `on_agent_switch` hook chain tells the runtime to call -# Unload on the previous agent's model providers every time the active -# agent transfers control. For DMR this hits the engine's `_unload` -# endpoint; for cloud-only providers (OpenAI, Anthropic, ...) the hook is -# a silent no-op since they don't implement provider.Unloader. +# each agent's `on_agent_switch` hook chain asks the previous agent's +# DMR endpoint(s) to release GPU memory every time the active agent +# transfers control. The hook is pure: it reads the model snapshot the +# runtime ships on every on_agent_switch dispatch and POSTs to DMR's +# `_unload` endpoint over plain HTTP — no provider-specific runtime +# coupling. For cloud-only providers (OpenAI, Anthropic, ...) the hook +# is a silent no-op since they don't expose an HTTP unload endpoint. # # Switching back and forth between `coder` and `reviewer` therefore costs # one model load per switch instead of failing on out-of-memory. diff --git a/pkg/config/latest/types.go b/pkg/config/latest/types.go index bfac3d6a1..0cc7c712c 100644 --- a/pkg/config/latest/types.go +++ b/pkg/config/latest/types.go @@ -238,7 +238,7 @@ type ProviderConfig struct { // models are POSTed `{"model": ""}` here at every switch. // Cloud providers should leave this unset. // - // [unload]: https://pkg.go.dev/github.com/docker/docker-agent/pkg/runtime#BuiltinUnload + // [unload]: https://pkg.go.dev/github.com/docker/docker-agent/pkg/hooks/builtins#Unload UnloadAPI string `json:"unload_api,omitempty"` // TokenKey is the environment variable name containing the API token TokenKey string `json:"token_key,omitempty"` diff --git a/pkg/hooks/builtins/builtins.go b/pkg/hooks/builtins/builtins.go index cdfa98705..80e3a821e 100644 --- a/pkg/hooks/builtins/builtins.go +++ b/pkg/hooks/builtins/builtins.go @@ -12,6 +12,8 @@ // - add_user_info (session_start) — current OS user and host // - add_recent_commits (session_start) — `git log --oneline -n N` // - max_iterations (before_llm_call) — hard stop after N model calls +// - unload (on_agent_switch) — release the previous +// agent's local-engine resources via HTTP unload (DMR today) // - snapshot (session_start, // turn_start, turn_end, // pre_tool_use, post_tool_use, @@ -120,6 +122,7 @@ func Register(r *hooks.Registry) (*State, error) { r.RegisterBuiltin(Snapshot, state.snapshot.hook), r.RegisterBuiltin(RedactSecrets, redactSecrets), r.RegisterBuiltin(HTTPPost, httpPost), + r.RegisterBuiltin(Unload, unload), ); err != nil { return nil, err } diff --git a/pkg/hooks/builtins/builtins_test.go b/pkg/hooks/builtins/builtins_test.go index 8fa1dc43c..2358a9869 100644 --- a/pkg/hooks/builtins/builtins_test.go +++ b/pkg/hooks/builtins/builtins_test.go @@ -37,6 +37,7 @@ func TestRegisterInstallsAllBuiltins(t *testing.T) { builtins.Snapshot, builtins.RedactSecrets, builtins.HTTPPost, + builtins.Unload, } { fn, ok := r.LookupBuiltin(name) assert.True(t, ok, "builtin %q must be registered", name) diff --git a/pkg/hooks/builtins/unload.go b/pkg/hooks/builtins/unload.go new file mode 100644 index 000000000..0c61216ad --- /dev/null +++ b/pkg/hooks/builtins/unload.go @@ -0,0 +1,157 @@ +package builtins + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "log/slog" + "net/http" + "net/url" + "strings" + "time" + + "github.com/docker/docker-agent/pkg/hooks" +) + +// Unload is the registered name of the on_agent_switch builtin that +// asks the previous agent's local inference engines (today: Docker +// Model Runner) to release the resources they hold. +// +// Wire it into a config with: +// +// hooks: +// on_agent_switch: +// - type: builtin +// command: unload +// +// The hook is pure: it depends only on the [hooks.Input.FromAgentModels] +// snapshot the runtime ships on every on_agent_switch dispatch, plus +// net/http. It carries no runtime-side coupling and silently skips any +// model whose endpoint isn't reachable as plain HTTP (e.g. cloud +// providers that don't expose [hooks.ModelEndpoint.BaseURL]). +const Unload = "unload" + +// unloadProviderDMR is the only provider type the builtin currently +// knows how to unload against. Other providers ship through a no-op so +// wiring this hook on a heterogeneous chain is harmless. +const unloadProviderDMR = "dmr" + +// unloadTimeout caps each Unload call so a stalled engine cannot stall +// agent switching. Each model gets its own deadline so a slow first +// model can't starve the rest. +const unloadTimeout = 10 * time.Second + +// unload iterates the [hooks.Input.FromAgentModels] snapshot the +// runtime captured at dispatch time and POSTs `{"model": ""}` to +// the resolved unload endpoint of each one we know how to unload. +// Errors are logged but never propagated — agent switching must never +// block on a slow or unreachable engine. +func unload(ctx context.Context, in *hooks.Input, _ []string) (*hooks.Output, error) { + if in == nil || in.FromAgent == "" || in.FromAgent == in.ToAgent { + return nil, nil + } + for _, m := range in.FromAgentModels { + if m.Provider != unloadProviderDMR { + continue + } + endpoint, err := resolveUnloadURL(m.BaseURL, m.UnloadAPI) + if err != nil { + slog.WarnContext(ctx, "unload: resolving endpoint failed", + "agent", in.FromAgent, "model", m.Model, "error", err) + continue + } + if endpoint == "" { + continue + } + callCtx, cancel := context.WithTimeout(ctx, unloadTimeout) + if err := postUnloadModel(callCtx, http.DefaultClient, endpoint, m.Model); err != nil { + slog.WarnContext(ctx, "unload: provider unload failed", + "agent", in.FromAgent, "model", m.Model, "error", err) + } + cancel() + } + return nil, nil +} + +// resolveUnloadURL picks the unload endpoint for one model: +// the configured `unload_api` override (rebased against baseURL when +// it isn't already absolute) wins, otherwise [defaultUnloadURL] +// derives one from baseURL by replacing the trailing /v1 segment. +// Returns ("", nil) when no endpoint can be determined so the caller +// can skip without erroring. +func resolveUnloadURL(baseURL, override string) (string, error) { + if override != "" { + return rebaseURL(baseURL, override) + } + if baseURL == "" { + return "", nil + } + return defaultUnloadURL(baseURL), nil +} + +// defaultUnloadURL derives the `_unload` endpoint URL from the OpenAI +// base URL by replacing the trailing `/v1` segment, mirroring how the +// DMR client derives `_configure`: +// +// http://host:port/engines/v1/ → http://host:port/engines/_unload +// http://host:port/engines/llama.cpp/v1/ → http://host:port/engines/llama.cpp/_unload +// http://_/exp/vDD4.40/engines/v1 → http://_/exp/vDD4.40/engines/_unload +func defaultUnloadURL(baseURL string) string { + u, err := url.Parse(baseURL) + if err != nil { + return strings.TrimSuffix(strings.TrimSuffix(baseURL, "/"), "/v1") + "/_unload" + } + u.Path = strings.TrimSuffix(strings.TrimSuffix(u.Path, "/"), "/v1") + "/_unload" + return u.String() +} + +// rebaseURL returns path verbatim if it is already an absolute URL, +// otherwise attaches it to baseURL's scheme + host (dropping any path +// baseURL may carry). This lets users point base_url at e.g. +// http://localhost:12434/engines/v1 and override unload_api with +// /engines/_unload without the version prefix bleeding through. +func rebaseURL(baseURL, path string) (string, error) { + if strings.HasPrefix(path, "http://") || strings.HasPrefix(path, "https://") { + return path, nil + } + u, err := url.Parse(baseURL) + if err != nil || u.Scheme == "" || u.Host == "" { + return "", fmt.Errorf("base_url %q is not absolute; cannot resolve %q", baseURL, path) + } + if !strings.HasPrefix(path, "/") { + path = "/" + path + } + return u.Scheme + "://" + u.Host + path, nil +} + +// postUnloadModel issues `POST ` with body `{"model": ""}`. +func postUnloadModel(ctx context.Context, client *http.Client, endpoint, modelID string) error { + body, _ := json.Marshal(map[string]string{"model": modelID}) + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewReader(body)) + if err != nil { + return fmt.Errorf("building unload request: %w", err) + } + req.Header.Set("Content-Type", "application/json") + + slog.DebugContext(ctx, "Unloading model", "url", endpoint, "model", modelID) + + resp, err := client.Do(req) + if err != nil { + return fmt.Errorf("calling unload endpoint %s: %w", endpoint, err) + } + defer resp.Body.Close() + + if resp.StatusCode < 200 || resp.StatusCode >= 300 { + respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 4*1024)) + return fmt.Errorf("unload endpoint returned %d: %s", + resp.StatusCode, strings.TrimSpace(string(respBody))) + } + // Drain the success-path body so the underlying transport can reuse + // the connection (Go's http.Client only re-pools a connection whose + // body has been read to EOF and closed). + _, _ = io.Copy(io.Discard, resp.Body) + return nil +} diff --git a/pkg/hooks/builtins/unload_test.go b/pkg/hooks/builtins/unload_test.go new file mode 100644 index 000000000..3cd9d8ad8 --- /dev/null +++ b/pkg/hooks/builtins/unload_test.go @@ -0,0 +1,312 @@ +package builtins + +import ( + "encoding/json" + "io" + "net/http" + "net/http/httptest" + "sync/atomic" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/docker/docker-agent/pkg/hooks" +) + +// TestUnloadBuiltin_Registered guarantees the public name is findable +// on a registry built by [Register], so YAML hook entries that name it +// actually resolve. +func TestUnloadBuiltin_Registered(t *testing.T) { + t.Parallel() + + r := hooks.NewRegistry() + _, err := Register(r) + require.NoError(t, err) + + fn, ok := r.LookupBuiltin(Unload) + require.True(t, ok, "%q must be registered on the hook registry", Unload) + require.NotNil(t, fn) +} + +// TestUnload_PostsToDefaultEndpoint exercises the happy path against a +// real httptest server: the builtin must derive the `_unload` URL from +// the model's BaseURL and POST `{"model": ""}`. +func TestUnload_PostsToDefaultEndpoint(t *testing.T) { + t.Parallel() + + var ( + gotPath string + gotBody map[string]string + ) + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + gotPath = r.URL.Path + body, _ := io.ReadAll(r.Body) + _ = json.Unmarshal(body, &gotBody) + w.WriteHeader(http.StatusOK) + })) + defer server.Close() + + in := &hooks.Input{ + FromAgent: "from", + ToAgent: "to", + FromAgentModels: []hooks.ModelEndpoint{{ + Provider: "dmr", + Model: "ai/qwen3", + BaseURL: server.URL + "/engines/llama.cpp/v1", + }}, + } + out, err := unload(t.Context(), in, nil) + require.NoError(t, err) + assert.Nil(t, out, "unload is observational; output must be nil") + assert.Equal(t, "/engines/llama.cpp/_unload", gotPath) + assert.Equal(t, map[string]string{"model": "ai/qwen3"}, gotBody) +} + +// TestUnload_HonoursOverrideUnloadAPI documents that an explicit +// `unload_api` on the model takes precedence over the default +// derivation, and is rebased onto the BaseURL's host when relative. +func TestUnload_HonoursOverrideUnloadAPI(t *testing.T) { + t.Parallel() + + var gotPath string + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + gotPath = r.URL.Path + w.WriteHeader(http.StatusOK) + })) + defer server.Close() + + in := &hooks.Input{ + FromAgent: "from", + ToAgent: "to", + FromAgentModels: []hooks.ModelEndpoint{{ + Provider: "dmr", + Model: "ai/qwen3", + BaseURL: server.URL + "/engines/v1", + UnloadAPI: "/custom/unload", + }}, + } + _, err := unload(t.Context(), in, nil) + require.NoError(t, err) + assert.Equal(t, "/custom/unload", gotPath) +} + +// TestUnload_SkipsNonDMRProviders pins the cross-provider safety +// property: wiring `unload` on a heterogeneous chain is harmless +// because non-DMR endpoints are silently skipped. +func TestUnload_SkipsNonDMRProviders(t *testing.T) { + t.Parallel() + + var calls atomic.Int64 + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + calls.Add(1) + w.WriteHeader(http.StatusOK) + })) + defer server.Close() + + in := &hooks.Input{ + FromAgent: "from", + ToAgent: "to", + FromAgentModels: []hooks.ModelEndpoint{ + {Provider: "openai", Model: "gpt-4", BaseURL: server.URL}, + {Provider: "anthropic", Model: "claude", BaseURL: server.URL}, + }, + } + _, err := unload(t.Context(), in, nil) + require.NoError(t, err) + assert.Equal(t, int64(0), calls.Load(), "no HTTP call must reach a non-DMR endpoint") +} + +// TestUnload_NoopWhenFromEqualsTo documents the no-self-unload guard: +// transferring back into the same agent must not unload the model the +// next turn is about to use. +func TestUnload_NoopWhenFromEqualsTo(t *testing.T) { + t.Parallel() + + var calls atomic.Int64 + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + calls.Add(1) + w.WriteHeader(http.StatusOK) + })) + defer server.Close() + + in := &hooks.Input{ + FromAgent: "same", + ToAgent: "same", + FromAgentModels: []hooks.ModelEndpoint{{ + Provider: "dmr", + Model: "ai/qwen3", + BaseURL: server.URL + "/engines/v1", + }}, + } + _, err := unload(t.Context(), in, nil) + require.NoError(t, err) + assert.Equal(t, int64(0), calls.Load()) +} + +// TestUnload_NoopWhenFromAgentEmpty documents the cheap path: the very +// first switch into the team's default agent has no previous agent and +// must not fire any HTTP call. +func TestUnload_NoopWhenFromAgentEmpty(t *testing.T) { + t.Parallel() + + var calls atomic.Int64 + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + calls.Add(1) + w.WriteHeader(http.StatusOK) + })) + defer server.Close() + + in := &hooks.Input{ + ToAgent: "to", + FromAgentModels: []hooks.ModelEndpoint{{ + Provider: "dmr", + Model: "ai/qwen3", + BaseURL: server.URL + "/engines/v1", + }}, + } + _, err := unload(t.Context(), in, nil) + require.NoError(t, err) + assert.Equal(t, int64(0), calls.Load()) +} + +// TestUnload_SwallowsServerErrors verifies the best-effort contract: +// a 5xx from the engine must NOT propagate back as a hook error, +// because agent switching has to keep moving even when the unload +// endpoint is down. +func TestUnload_SwallowsServerErrors(t *testing.T) { + t.Parallel() + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusInternalServerError) + _, _ = w.Write([]byte("boom")) + })) + defer server.Close() + + in := &hooks.Input{ + FromAgent: "from", + ToAgent: "to", + FromAgentModels: []hooks.ModelEndpoint{{ + Provider: "dmr", + Model: "ai/qwen3", + BaseURL: server.URL + "/engines/v1", + }}, + } + out, err := unload(t.Context(), in, nil) + require.NoError(t, err) + assert.Nil(t, out) +} + +// TestUnload_NoopWhenNoEndpoint documents that a DMR model without a +// BaseURL or unload_api is silently skipped (rather than erroring) so +// the hook stays harmless when wired against an in-process / test +// provider that hasn't resolved an HTTP endpoint. +func TestUnload_NoopWhenNoEndpoint(t *testing.T) { + t.Parallel() + + in := &hooks.Input{ + FromAgent: "from", + ToAgent: "to", + FromAgentModels: []hooks.ModelEndpoint{{ + Provider: "dmr", + Model: "ai/qwen3", + }}, + } + out, err := unload(t.Context(), in, nil) + require.NoError(t, err) + assert.Nil(t, out) +} + +func TestRebaseURL(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + baseURL string + path string + want string + errContains string // empty ⇒ expect success + }{ + { + name: "absolute https URL is returned verbatim", + baseURL: "http://anything", + path: "https://api.example.com/unload", + want: "https://api.example.com/unload", + }, + { + name: "rooted path drops base path", + baseURL: "http://localhost:12434/engines/v1", + path: "/engines/_unload", + want: "http://localhost:12434/engines/_unload", + }, + { + name: "relative path is rooted", + baseURL: "http://localhost:12434/engines/v1", + path: "engines/_unload", + want: "http://localhost:12434/engines/_unload", + }, + { + name: "empty base URL with relative path errors", + path: "/engines/_unload", + errContains: "is not absolute", + }, + { + name: "base URL without scheme errors", + baseURL: "localhost:12434/engines/v1", + path: "/engines/_unload", + errContains: "is not absolute", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + got, err := rebaseURL(tt.baseURL, tt.path) + if tt.errContains != "" { + require.Error(t, err) + assert.Contains(t, err.Error(), tt.errContains) + return + } + require.NoError(t, err) + assert.Equal(t, tt.want, got) + }) + } +} + +func TestDefaultUnloadURL(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + baseURL string + expected string + }{ + { + name: "standard engines path", + baseURL: "http://127.0.0.1:12434/engines/v1/", + expected: "http://127.0.0.1:12434/engines/_unload", + }, + { + name: "no trailing slash", + baseURL: "http://127.0.0.1:12434/engines/v1", + expected: "http://127.0.0.1:12434/engines/_unload", + }, + { + name: "Docker Desktop experimental prefix", + baseURL: "http://_/exp/vDD4.40/engines/v1", + expected: "http://_/exp/vDD4.40/engines/_unload", + }, + { + name: "backend-scoped path", + baseURL: "http://127.0.0.1:12434/engines/llama.cpp/v1/", + expected: "http://127.0.0.1:12434/engines/llama.cpp/_unload", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + assert.Equal(t, tt.expected, defaultUnloadURL(tt.baseURL)) + }) + } +} diff --git a/pkg/hooks/types.go b/pkg/hooks/types.go index 474c0cf05..60d828cba 100644 --- a/pkg/hooks/types.go +++ b/pkg/hooks/types.go @@ -242,6 +242,15 @@ type Input struct { ToAgent string `json:"to_agent,omitempty"` AgentSwitchKind string `json:"agent_switch_kind,omitempty"` + // FromAgentModels is the snapshot of the previous agent's + // configured model endpoints, captured at on_agent_switch dispatch + // time. Populated only on [EventOnAgentSwitch]; nil for every other + // event. Hooks that act on the previous agent's models (e.g. the + // stock `unload` builtin asking a local inference engine to release + // GPU memory) read this slice instead of poking at the runtime, + // keeping the hook payload self-contained. + FromAgentModels []ModelEndpoint `json:"from_agent_models,omitempty"` + // OnSessionResume specific: the iteration cap that was reached // (PreviousMaxIterations) and the new cap after the user approved // continuation (NewMaxIterations). Carrying both lets audit @@ -278,6 +287,27 @@ type Input struct { Summary string `json:"summary,omitempty"` } +// ModelEndpoint identifies one of an agent's configured models plus +// the HTTP endpoint that hosts it, when one is known. It is the wire +// format used by [Input.FromAgentModels] so hooks can reach a +// model-serving endpoint without depending on runtime-only types. +type ModelEndpoint struct { + // Provider is the provider type ("openai", "anthropic", "dmr", ...). + Provider string `json:"provider,omitempty"` + // Model is the resolved model identifier. + Model string `json:"model,omitempty"` + // BaseURL is the resolved HTTP base URL of the provider, when known + // (set by providers that talk to a configurable HTTP endpoint, e.g. + // Docker Model Runner). Empty for cloud providers that don't expose + // a stable per-instance base URL on the runtime side. + BaseURL string `json:"base_url,omitempty"` + // UnloadAPI is the per-model unload path or absolute URL copied + // verbatim from the model's `unload_api` provider option. Empty + // when the user hasn't configured an override; the unload builtin + // falls back to a provider-specific default in that case. + UnloadAPI string `json:"unload_api,omitempty"` +} + // ToJSON serializes the input. func (i *Input) ToJSON() ([]byte, error) { return json.Marshal(i) } diff --git a/pkg/model/provider/base/base.go b/pkg/model/provider/base/base.go index 5fdc1a07a..b757590b1 100644 --- a/pkg/model/provider/base/base.go +++ b/pkg/model/provider/base/base.go @@ -17,6 +17,17 @@ type Config struct { // Models stores the full models map for providers that need it (e.g., routers). // This enables proper cloning of providers that reference other models. Models map[string]latest.ModelConfig + // BaseURL is the resolved HTTP base URL the client talks to, when + // the provider is reachable over a configurable HTTP endpoint. + // Distinct from [latest.ModelConfig.BaseURL] (the user-typed input): + // providers fill BaseURL with the URL they actually use after auto + // discovery / fallback (e.g. Docker Model Runner picking between + // MODEL_RUNNER_HOST, the desktop socket, and a localhost fallback). + // Surfaced through [Config.BaseConfig] so generic, runtime-free + // consumers like hooks can address the endpoint without duplicating + // resolution logic. Empty for providers that don't expose a stable + // per-instance URL. + BaseURL string } // ID returns the provider and model ID in the format "provider/model". diff --git a/pkg/model/provider/dmr/client.go b/pkg/model/provider/dmr/client.go index 3e8dca4cb..2f9887524 100644 --- a/pkg/model/provider/dmr/client.go +++ b/pkg/model/provider/dmr/client.go @@ -53,7 +53,6 @@ type Client struct { base.Config client openai.Client - baseURL string httpClient *http.Client engine string } @@ -137,9 +136,9 @@ func NewClient(ctx context.Context, cfg *latest.ModelConfig, opts ...options.Opt Config: base.Config{ ModelConfig: *cfg, ModelOptions: globalOptions, + BaseURL: baseURL, }, client: openai.NewClient(clientOptions...), - baseURL: baseURL, httpClient: httpClient, engine: engine, }, nil @@ -159,7 +158,7 @@ func (c *Client) CreateChatCompletionStream(ctx context.Context, messages []chat "model", c.ModelConfig.Model, "message_count", len(messages), "tool_count", len(requestTools), - "base_url", c.baseURL, + "base_url", c.BaseURL, ) if len(messages) == 0 { @@ -286,7 +285,7 @@ func (c *Client) CreateChatCompletionStream(ctx context.Context, messages []chat stream := c.client.Chat.Completions.NewStreaming(ctx, params) - slog.DebugContext(ctx, "DMR chat completion stream created successfully", "model", c.ModelConfig.Model, "base_url", c.baseURL) + slog.DebugContext(ctx, "DMR chat completion stream created successfully", "model", c.ModelConfig.Model, "base_url", c.BaseURL) return newStreamAdapter(stream, trackUsage), nil } diff --git a/pkg/model/provider/dmr/client_test.go b/pkg/model/provider/dmr/client_test.go index 12beb63e6..24835658d 100644 --- a/pkg/model/provider/dmr/client_test.go +++ b/pkg/model/provider/dmr/client_test.go @@ -30,7 +30,7 @@ func TestNewClientWithExplicitBaseURL(t *testing.T) { client, err := NewClient(t.Context(), cfg) require.NoError(t, err) - assert.Equal(t, "https://custom.example.com:8080/api/v1", client.baseURL) + assert.Equal(t, "https://custom.example.com:8080/api/v1", client.BaseURL) } func TestNewClientReturnsErrNotInstalledWhenDockerModelUnsupported(t *testing.T) { diff --git a/pkg/model/provider/dmr/embed.go b/pkg/model/provider/dmr/embed.go index 205efbf4a..aa38185d0 100644 --- a/pkg/model/provider/dmr/embed.go +++ b/pkg/model/provider/dmr/embed.go @@ -43,7 +43,7 @@ func (c *Client) CreateBatchEmbedding(ctx context.Context, texts []string) (*bas return &base.BatchEmbeddingResult{Embeddings: [][]float64{}}, nil } - slog.DebugContext(ctx, "Creating DMR embeddings", "model", c.ModelConfig.Model, "batch_size", len(texts), "base_url", c.baseURL) + slog.DebugContext(ctx, "Creating DMR embeddings", "model", c.ModelConfig.Model, "batch_size", len(texts), "base_url", c.BaseURL) response, err := c.client.Embeddings.New(ctx, openai.EmbeddingNewParams{ Input: openai.EmbeddingNewParamsInputUnion{OfArrayOfStrings: texts}, @@ -127,7 +127,7 @@ func (c *Client) Rerank(ctx context.Context, query string, documents []types.Doc documentStrings[i] = doc.Content } - baseURL, err := rerankBaseURL(c.baseURL) + baseURL, err := rerankBaseURL(c.BaseURL) if err != nil { return nil, err } diff --git a/pkg/model/provider/dmr/unload.go b/pkg/model/provider/dmr/unload.go deleted file mode 100644 index 5fb825f4f..000000000 --- a/pkg/model/provider/dmr/unload.go +++ /dev/null @@ -1,103 +0,0 @@ -package dmr - -import ( - "bytes" - "context" - "encoding/json" - "fmt" - "io" - "log/slog" - "net/http" - "net/url" - "strings" -) - -// Unload asks Docker Model Runner to release the resources held for the -// configured model. Invoked by the runtime's `unload` on_agent_switch -// builtin hook. -// -// The unload endpoint is the provider's `unload_api` (relative path or -// absolute URL) when set, otherwise [defaultUnloadURL] derived from the -// OpenAI base URL. When neither is available the call is a no-op. -func (c *Client) Unload(ctx context.Context) error { - endpoint, err := c.resolveUnloadURL() - if err != nil || endpoint == "" { - return err - } - return postUnloadModel(ctx, c.httpClient, endpoint, c.ModelConfig.Model) -} - -func (c *Client) resolveUnloadURL() (string, error) { - if override := c.ModelConfig.UnloadAPI(); override != "" { - return rebaseURL(c.baseURL, override) - } - if c.baseURL == "" { - return "", nil - } - return defaultUnloadURL(c.baseURL), nil -} - -// defaultUnloadURL derives the `_unload` endpoint URL from the OpenAI -// base URL by replacing the trailing `/v1` segment, mirroring how -// [buildConfigureURL] derives `_configure`: -// -// http://host:port/engines/v1/ → http://host:port/engines/_unload -// http://host:port/engines/llama.cpp/v1/ → http://host:port/engines/llama.cpp/_unload -// http://_/exp/vDD4.40/engines/v1 → http://_/exp/vDD4.40/engines/_unload -func defaultUnloadURL(baseURL string) string { - u, err := url.Parse(baseURL) - if err != nil { - return strings.TrimSuffix(strings.TrimSuffix(baseURL, "/"), "/v1") + "/_unload" - } - u.Path = strings.TrimSuffix(strings.TrimSuffix(u.Path, "/"), "/v1") + "/_unload" - return u.String() -} - -// rebaseURL returns path verbatim if it is already an absolute URL, -// otherwise attaches it to baseURL's scheme + host (dropping any path -// baseURL may carry). This lets users point base_url at e.g. -// http://localhost:12434/engines/v1 and override unload_api with -// /engines/_unload without the version prefix bleeding through. -func rebaseURL(baseURL, path string) (string, error) { - if strings.HasPrefix(path, "http://") || strings.HasPrefix(path, "https://") { - return path, nil - } - u, err := url.Parse(baseURL) - if err != nil || u.Scheme == "" || u.Host == "" { - return "", fmt.Errorf("base_url %q is not absolute; cannot resolve %q", baseURL, path) - } - if !strings.HasPrefix(path, "/") { - path = "/" + path - } - return u.Scheme + "://" + u.Host + path, nil -} - -// postUnloadModel issues `POST ` with body `{"model": ""}`. -func postUnloadModel(ctx context.Context, client *http.Client, endpoint, modelID string) error { - body, _ := json.Marshal(map[string]string{"model": modelID}) - - req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewReader(body)) - if err != nil { - return fmt.Errorf("building unload request: %w", err) - } - req.Header.Set("Content-Type", "application/json") - - slog.DebugContext(ctx, "Unloading model", "url", endpoint, "model", modelID) - - resp, err := client.Do(req) - if err != nil { - return fmt.Errorf("calling unload endpoint %s: %w", endpoint, err) - } - defer resp.Body.Close() - - if resp.StatusCode < 200 || resp.StatusCode >= 300 { - respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 4*1024)) - return fmt.Errorf("unload endpoint returned %d: %s", - resp.StatusCode, strings.TrimSpace(string(respBody))) - } - // Drain the success-path body so the underlying transport can reuse - // the connection (Go's http.Client only re-pools a connection whose - // body has been read to EOF and closed). - _, _ = io.Copy(io.Discard, resp.Body) - return nil -} diff --git a/pkg/model/provider/dmr/unload_test.go b/pkg/model/provider/dmr/unload_test.go deleted file mode 100644 index 0a2fcc463..000000000 --- a/pkg/model/provider/dmr/unload_test.go +++ /dev/null @@ -1,219 +0,0 @@ -package dmr - -import ( - "encoding/json" - "io" - "net/http" - "net/http/httptest" - "sync" - "testing" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - - "github.com/docker/docker-agent/pkg/config/latest" - "github.com/docker/docker-agent/pkg/model/provider/base" -) - -func TestRebaseURL(t *testing.T) { - t.Parallel() - - tests := []struct { - name string - baseURL string - path string - want string - errContains string // empty ⇒ expect success - }{ - { - name: "absolute https URL is returned verbatim", - baseURL: "http://anything", - path: "https://api.example.com/unload", - want: "https://api.example.com/unload", - }, - { - name: "rooted path drops base path", - baseURL: "http://localhost:12434/engines/v1", - path: "/engines/_unload", - want: "http://localhost:12434/engines/_unload", - }, - { - name: "relative path is rooted", - baseURL: "http://localhost:12434/engines/v1", - path: "engines/_unload", - want: "http://localhost:12434/engines/_unload", - }, - { - name: "empty base URL with relative path errors", - path: "/engines/_unload", - errContains: "is not absolute", - }, - { - name: "base URL without scheme errors", - baseURL: "localhost:12434/engines/v1", - path: "/engines/_unload", - errContains: "is not absolute", - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - t.Parallel() - got, err := rebaseURL(tt.baseURL, tt.path) - if tt.errContains != "" { - require.Error(t, err) - assert.Contains(t, err.Error(), tt.errContains) - return - } - require.NoError(t, err) - assert.Equal(t, tt.want, got) - }) - } -} - -func TestDefaultUnloadURL(t *testing.T) { - t.Parallel() - - tests := []struct { - name string - baseURL string - expected string - }{ - { - name: "standard engines path", - baseURL: "http://127.0.0.1:12434/engines/v1/", - expected: "http://127.0.0.1:12434/engines/_unload", - }, - { - name: "no trailing slash", - baseURL: "http://127.0.0.1:12434/engines/v1", - expected: "http://127.0.0.1:12434/engines/_unload", - }, - { - name: "Docker Desktop experimental prefix", - baseURL: "http://_/exp/vDD4.40/engines/v1", - expected: "http://_/exp/vDD4.40/engines/_unload", - }, - { - name: "backend-scoped path", - baseURL: "http://127.0.0.1:12434/engines/llama.cpp/v1/", - expected: "http://127.0.0.1:12434/engines/llama.cpp/_unload", - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - t.Parallel() - assert.Equal(t, tt.expected, defaultUnloadURL(tt.baseURL)) - }) - } -} - -// newClient builds a [Client] just well-formed enough to drive Unload. -func newClient(baseURL string, httpClient *http.Client, cfg latest.ModelConfig) *Client { - return &Client{ - Config: base.Config{ModelConfig: cfg}, - baseURL: baseURL, - httpClient: httpClient, - } -} - -// TestClientUnload exercises the full Unload path end-to-end against an -// httptest server, covering: default URL derivation, user-configured -// unload_api override, non-2xx error surfacing, and the no-op branch -// when no endpoint can be determined. -func TestClientUnload(t *testing.T) { - t.Parallel() - - t.Run("posts model id to default unload endpoint", func(t *testing.T) { - t.Parallel() - - var ( - gotPath string - gotBody map[string]string - ) - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - gotPath = r.URL.Path - body, _ := io.ReadAll(r.Body) - _ = json.Unmarshal(body, &gotBody) - w.WriteHeader(http.StatusOK) - })) - defer server.Close() - - c := newClient(server.URL+"/engines/llama.cpp/v1", server.Client(), latest.ModelConfig{ - Provider: "dmr", - Model: "ai/qwen3", - }) - - require.NoError(t, c.Unload(t.Context())) - assert.Equal(t, "/engines/llama.cpp/_unload", gotPath) - assert.Equal(t, map[string]string{"model": "ai/qwen3"}, gotBody) - }) - - t.Run("honours user-configured unload_api path", func(t *testing.T) { - t.Parallel() - - var gotPath string - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - gotPath = r.URL.Path - w.WriteHeader(http.StatusOK) - })) - defer server.Close() - - c := newClient(server.URL+"/engines/v1", server.Client(), latest.ModelConfig{ - Model: "ai/qwen3", - ProviderOpts: map[string]any{"unload_api": "/custom/unload"}, - }) - - require.NoError(t, c.Unload(t.Context())) - assert.Equal(t, "/custom/unload", gotPath) - }) - - t.Run("returns error on non-2xx", func(t *testing.T) { - t.Parallel() - - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { - w.WriteHeader(http.StatusInternalServerError) - _, _ = w.Write([]byte("boom")) - })) - defer server.Close() - - c := newClient(server.URL+"/engines/v1", server.Client(), latest.ModelConfig{ - Model: "ai/qwen3", - }) - - err := c.Unload(t.Context()) - require.Error(t, err) - assert.Contains(t, err.Error(), "500") - assert.Contains(t, err.Error(), "boom") - }) - - t.Run("no-op when neither base URL nor unload_api are set", func(t *testing.T) { - t.Parallel() - c := newClient("", nil, latest.ModelConfig{Model: "ai/qwen3"}) - require.NoError(t, c.Unload(t.Context())) - }) - - t.Run("drains success body so connection can be reused", func(t *testing.T) { - t.Parallel() - - // httptest's default server uses keep-alive; a non-drained body would - // force a fresh TCP connection on the second call. We assert that the - // underlying transport saw a single connection across two POSTs. - var seen sync.Map - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - seen.Store(r.RemoteAddr, struct{}{}) - // Write a non-trivial body so the drain actually has work to do. - _, _ = w.Write([]byte(`{"status":"ok"}`)) - })) - defer server.Close() - - c := newClient(server.URL+"/engines/v1", server.Client(), latest.ModelConfig{Model: "ai/qwen3"}) - require.NoError(t, c.Unload(t.Context())) - require.NoError(t, c.Unload(t.Context())) - - count := 0 - seen.Range(func(_, _ any) bool { count++; return true }) - assert.Equal(t, 1, count, "both POSTs must reuse the same TCP connection") - }) -} diff --git a/pkg/model/provider/provider.go b/pkg/model/provider/provider.go index e637fed13..80b46948b 100644 --- a/pkg/model/provider/provider.go +++ b/pkg/model/provider/provider.go @@ -71,22 +71,6 @@ type RerankingProvider interface { Rerank(ctx context.Context, query string, documents []types.Document, criteria string) ([]float64, error) } -// Unloader is an optional interface for providers that can release -// the resources held for the configured model. Local inference engines -// (today, Docker Model Runner) implement it; cloud APIs typically -// don't. -// -// The runtime's [unload] on_agent_switch builtin hook calls Unload when -// an agent hands control to another agent. Implementations are -// best-effort and should be idempotent — repeated calls on an -// already-unloaded model must succeed. -// -// [unload]: https://pkg.go.dev/github.com/docker/docker-agent/pkg/runtime#BuiltinUnload -type Unloader interface { - Provider - Unload(ctx context.Context) error -} - // New creates a new provider from a model config. // This is a convenience wrapper for NewWithModels with no models map. func New(ctx context.Context, cfg *latest.ModelConfig, env environment.Provider, opts ...options.Opt) (Provider, error) { diff --git a/pkg/runtime/hooks.go b/pkg/runtime/hooks.go index 63172cf30..b8aae23b5 100644 --- a/pkg/runtime/hooks.go +++ b/pkg/runtime/hooks.go @@ -264,9 +264,46 @@ func (r *LocalRuntime) executeOnAgentSwitchHooks(ctx context.Context, a *agent.A FromAgent: fromAgent, ToAgent: toAgent, AgentSwitchKind: kind, + FromAgentModels: r.fromAgentModels(fromAgent), }, nil) } +// fromAgentModels snapshots the previous agent's configured model +// endpoints into the wire-friendly [hooks.ModelEndpoint] form. Hooks +// that act on the previous agent's models (e.g. the stock `unload` +// builtin) read this slice instead of poking at the runtime, so the +// hook payload stays self-contained. +// +// Returns nil — not an empty slice — when there is nothing to ship +// (no fromAgent, the agent isn't on the team, or it has no models) +// so the JSON wire payload omits the field via `omitempty`. +func (r *LocalRuntime) fromAgentModels(fromAgent string) []hooks.ModelEndpoint { + if fromAgent == "" { + return nil + } + from, err := r.team.Agent(fromAgent) + if err != nil { + slog.Debug("on_agent_switch: from-agent lookup failed", + "agent", fromAgent, "error", err) + return nil + } + configured := from.ConfiguredModels() + if len(configured) == 0 { + return nil + } + out := make([]hooks.ModelEndpoint, 0, len(configured)) + for _, p := range configured { + cfg := p.BaseConfig() + out = append(out, hooks.ModelEndpoint{ + Provider: cfg.ModelConfig.Provider, + Model: cfg.ModelConfig.Model, + BaseURL: cfg.BaseURL, + UnloadAPI: cfg.ModelConfig.UnloadAPI(), + }) + } + return out +} + // executeOnSessionResumeHooks fires on_session_resume when the user // explicitly approves continuation past the configured // max_iterations limit. Observational; failures are logged. The hook diff --git a/pkg/runtime/on_agent_switch_test.go b/pkg/runtime/on_agent_switch_test.go index 1ef60e23d..d9e8c0ca7 100644 --- a/pkg/runtime/on_agent_switch_test.go +++ b/pkg/runtime/on_agent_switch_test.go @@ -9,8 +9,12 @@ import ( "github.com/stretchr/testify/require" "github.com/docker/docker-agent/pkg/agent" + "github.com/docker/docker-agent/pkg/chat" + "github.com/docker/docker-agent/pkg/config/latest" "github.com/docker/docker-agent/pkg/hooks" + "github.com/docker/docker-agent/pkg/model/provider/base" "github.com/docker/docker-agent/pkg/team" + "github.com/docker/docker-agent/pkg/tools" ) // recordingBuiltin captures the [hooks.Input] passed on every dispatch @@ -116,3 +120,83 @@ func TestExecuteOnAgentSwitchHooks_NoopWhenNoHookRegistered(t *testing.T) { // Should be a successful no-op rather than a panic or error. r.executeOnAgentSwitchHooks(t.Context(), a, "s", "root", "next", agentSwitchKindHandoff) } + +// endpointProvider is a minimal [provider.Provider] test double whose +// BaseConfig returns a non-zero base.Config, so we can drive the +// FromAgentModels-population branch of executeOnAgentSwitchHooks. +type endpointProvider struct { + cfg base.Config +} + +func (p *endpointProvider) ID() string { return p.cfg.ID() } + +func (p *endpointProvider) CreateChatCompletionStream(context.Context, []chat.Message, []tools.Tool) (chat.MessageStream, error) { + return &mockStream{}, nil +} + +func (p *endpointProvider) BaseConfig() base.Config { return p.cfg } + +// TestExecuteOnAgentSwitchHooks_PopulatesFromAgentModels pins the +// runtime→hook handoff for the new pure-unload contract: when the +// previous agent has configured models, the runtime must ship a +// snapshot of {Provider, Model, BaseURL, UnloadAPI} for each one on +// Input.FromAgentModels so a runtime-free hook can act on them. +func TestExecuteOnAgentSwitchHooks_PopulatesFromAgentModels(t *testing.T) { + t.Parallel() + + rb := &recordingBuiltin{} + prev := &endpointProvider{cfg: base.Config{ + ModelConfig: latest.ModelConfig{ + Provider: "dmr", + Model: "ai/qwen3", + ProviderOpts: map[string]any{"unload_api": "/engines/_unload"}, + }, + BaseURL: "http://127.0.0.1:12434/engines/v1", + }} + next := &mockProvider{id: "test/next", stream: &mockStream{}} + + from := agent.New("root", "instructions", + agent.WithModel(prev), + agent.WithHooks(&hooks.Config{ + OnAgentSwitch: []hooks.Hook{{ + Type: hooks.HookTypeBuiltin, Command: "test_record_agent_switch", + }}, + })) + to := agent.New("planner", "instructions", agent.WithModel(next)) + tm := team.New(team.WithAgents(from, to)) + + r, err := NewLocalRuntime(tm, WithModelStore(mockModelStore{})) + require.NoError(t, err) + require.NoError(t, r.hooksRegistry.RegisterBuiltin("test_record_agent_switch", rb.hook)) + r.buildHooksExecutors() + + r.executeOnAgentSwitchHooks(t.Context(), from, "s", "root", "planner", agentSwitchKindTransferTask) + + got := rb.snapshot() + require.Len(t, got, 1) + require.Len(t, got[0].FromAgentModels, 1, "runtime must ship one ModelEndpoint per configured model") + ep := got[0].FromAgentModels[0] + assert.Equal(t, "dmr", ep.Provider) + assert.Equal(t, "ai/qwen3", ep.Model) + assert.Equal(t, "http://127.0.0.1:12434/engines/v1", ep.BaseURL) + assert.Equal(t, "/engines/_unload", ep.UnloadAPI) +} + +// TestExecuteOnAgentSwitchHooks_FromAgentModelsNilWhenFromEmpty pins +// the cheap path on the very first switch into the team's default +// agent: there is no previous agent, so FromAgentModels must stay nil +// (no team lookup, no allocation) and the JSON wire form omits the +// field via `omitempty`. +func TestExecuteOnAgentSwitchHooks_FromAgentModelsNilWhenFromEmpty(t *testing.T) { + t.Parallel() + + r, rb := runtimeWithRecordedAgentSwitch(t, "root") + a := r.CurrentAgent() + require.NotNil(t, a) + + r.executeOnAgentSwitchHooks(t.Context(), a, "s", "", "root", agentSwitchKindHandoff) + + got := rb.snapshot() + require.Len(t, got, 1) + assert.Nil(t, got[0].FromAgentModels) +} diff --git a/pkg/runtime/runtime.go b/pkg/runtime/runtime.go index 474d508d8..eaf16eb66 100644 --- a/pkg/runtime/runtime.go +++ b/pkg/runtime/runtime.go @@ -437,14 +437,6 @@ func NewLocalRuntime(agents *team.Team, opts ...Opt) (*LocalRuntime, error) { return nil, fmt.Errorf("register %q builtin: %w", BuiltinCacheResponse, err) } - // unload is registered alongside cache_response for the same - // reason: it needs to walk Input.FromAgent up to the previous agent's - // configured models and dispatch to provider.Unloader implementations, - // which the runtime owns through the team. - if err := hooksRegistry.RegisterBuiltin(BuiltinUnload, r.unloadBuiltin); err != nil { - return nil, fmt.Errorf("register %q builtin: %w", BuiltinUnload, err) - } - // stripUnsupportedModalitiesTransform captures the runtime closure to // resolve the agent from Input.AgentName, so it lives here rather // than as a stateless builtin in pkg/hooks/builtins. It drops image diff --git a/pkg/runtime/unload.go b/pkg/runtime/unload.go deleted file mode 100644 index 6c0fe2f55..000000000 --- a/pkg/runtime/unload.go +++ /dev/null @@ -1,55 +0,0 @@ -package runtime - -import ( - "context" - "log/slog" - "time" - - "github.com/docker/docker-agent/pkg/hooks" - "github.com/docker/docker-agent/pkg/model/provider" -) - -// BuiltinUnload is the on_agent_switch builtin that asks every model on -// the previous agent to release its resources. Opt in via: -// -// hooks: -// on_agent_switch: -// - type: builtin -// command: unload -// -// Today only Docker Model Runner ships a [provider.Unloader]; other -// providers are silently skipped, so wiring the builtin on a -// cross-provider chain is harmless. -const BuiltinUnload = "unload" - -// unloadTimeout caps each Unload call so a stalled engine cannot stall -// agent switching. -const unloadTimeout = 10 * time.Second - -// unloadBuiltin calls Unload on every [provider.Unloader] of the -// previous agent. Errors are logged but never propagated — agent -// switching must never block on a slow or unreachable engine. -func (r *LocalRuntime) unloadBuiltin(ctx context.Context, in *hooks.Input, _ []string) (*hooks.Output, error) { - if in.FromAgent == "" || in.FromAgent == in.ToAgent { - return nil, nil - } - from, err := r.team.Agent(in.FromAgent) - if err != nil { - slog.DebugContext(ctx, "unload: from-agent lookup failed", - "agent", in.FromAgent, "error", err) - return nil, nil - } - for _, p := range from.ConfiguredModels() { - u, ok := p.(provider.Unloader) - if !ok { - continue - } - callCtx, cancel := context.WithTimeout(ctx, unloadTimeout) - if err := u.Unload(callCtx); err != nil { - slog.WarnContext(ctx, "unload: provider unload failed", - "agent", from.Name(), "model", p.ID(), "error", err) - } - cancel() - } - return nil, nil -} diff --git a/pkg/runtime/unload_test.go b/pkg/runtime/unload_test.go deleted file mode 100644 index 1eda21b3d..000000000 --- a/pkg/runtime/unload_test.go +++ /dev/null @@ -1,106 +0,0 @@ -package runtime - -import ( - "context" - "errors" - "sync/atomic" - "testing" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - - "github.com/docker/docker-agent/pkg/agent" - "github.com/docker/docker-agent/pkg/chat" - "github.com/docker/docker-agent/pkg/config/latest" - "github.com/docker/docker-agent/pkg/hooks" - "github.com/docker/docker-agent/pkg/model/provider" - "github.com/docker/docker-agent/pkg/model/provider/base" - "github.com/docker/docker-agent/pkg/team" - "github.com/docker/docker-agent/pkg/tools" -) - -// unloadingProvider is a minimal [provider.Unloader] test double that -// counts Unload calls and (optionally) returns a configured error so -// tests can assert call sites and best-effort error swallowing. -type unloadingProvider struct { - id string - calls atomic.Int64 - unloadErr error -} - -func (p *unloadingProvider) ID() string { return p.id } - -func (p *unloadingProvider) CreateChatCompletionStream(context.Context, []chat.Message, []tools.Tool) (chat.MessageStream, error) { - return &mockStream{}, nil -} - -func (p *unloadingProvider) BaseConfig() base.Config { - return base.Config{ModelConfig: latest.ModelConfig{Provider: "test", Model: p.id}} -} - -func (p *unloadingProvider) Unload(_ context.Context) error { - p.calls.Add(1) - return p.unloadErr -} - -var _ provider.Unloader = (*unloadingProvider)(nil) - -// newUnloadingRuntime wires up agents named after the supplied provider -// IDs and returns a runtime ready to drive [unloadBuiltin]. Each name -// gets its own [unloadingProvider] reachable through the returned map. -func newUnloadingRuntime(t *testing.T, names ...string) (*LocalRuntime, map[string]*unloadingProvider) { - t.Helper() - provs := make(map[string]*unloadingProvider, len(names)) - agents := make([]*agent.Agent, len(names)) - for i, name := range names { - p := &unloadingProvider{id: name} - provs[name] = p - agents[i] = agent.New(name, "instructions", agent.WithModel(p)) - } - rt, err := NewLocalRuntime(team.New(team.WithAgents(agents...)), - WithSessionCompaction(false), WithModelStore(mockModelStore{})) - require.NoError(t, err) - return rt, provs -} - -func TestUnloadBuiltin(t *testing.T) { - t.Parallel() - - t.Run("calls Unload on the previous agent only", func(t *testing.T) { - t.Parallel() - rt, provs := newUnloadingRuntime(t, "from", "to") - _, err := rt.unloadBuiltin(t.Context(), &hooks.Input{FromAgent: "from", ToAgent: "to"}, nil) - require.NoError(t, err) - assert.Equal(t, int64(1), provs["from"].calls.Load(), "from-agent's model must be unloaded once") - assert.Equal(t, int64(0), provs["to"].calls.Load(), "to-agent's model must NOT be unloaded") - }) - - t.Run("swallows Unload errors so agent switch never blocks", func(t *testing.T) { - t.Parallel() - rt, provs := newUnloadingRuntime(t, "from", "to") - provs["from"].unloadErr = errors.New("engine offline") - out, err := rt.unloadBuiltin(t.Context(), &hooks.Input{FromAgent: "from", ToAgent: "to"}, nil) - require.NoError(t, err) - assert.Nil(t, out) - assert.Equal(t, int64(1), provs["from"].calls.Load()) - }) - - t.Run("no-op when from==to", func(t *testing.T) { - t.Parallel() - rt, provs := newUnloadingRuntime(t, "from") - _, err := rt.unloadBuiltin(t.Context(), &hooks.Input{FromAgent: "from", ToAgent: "from"}, nil) - require.NoError(t, err) - assert.Equal(t, int64(0), provs["from"].calls.Load()) - }) -} - -// TestUnloadBuiltin_Registered guarantees the builtin name is findable -// on the runtime registry so YAML hook entries that reference it -// actually resolve. -func TestUnloadBuiltin_Registered(t *testing.T) { - t.Parallel() - rt, _ := newUnloadingRuntime(t, "loader") - fn, ok := rt.hooksRegistry.LookupBuiltin(BuiltinUnload) - require.True(t, ok, "%q must be registered on the runtime's hook registry", BuiltinUnload) - require.NotNil(t, fn) -} From c090e18ef320d51b82045e1db4695e3315a4d11a Mon Sep 17 00:00:00 2001 From: David Gageot Date: Thu, 7 May 2026 17:14:08 +0200 Subject: [PATCH 2/3] refactor(hooks/builtins): tighten the unload builtin Same behaviour, smaller surface: - Collapse the three URL helpers (resolveUnloadURL, rebaseURL, defaultUnloadURL) into one unloadURL(ModelEndpoint) function. The three branches (absolute override, relative override, default derivation) are now visible side-by-side in a single switch. - Extract per-model work into unloadOne so the per-call timeout uses defer cancel() and the call site collapses to a single warn-log on failure. - Drop the unused unloadProviderDMR constant (one literal, one site). - Drop the unused *http.Client parameter from the POST helper; it was always http.DefaultClient, which is reachable from httptest servers anyway. - Drop the redundant len(configured)==0 early return in the runtime's fromAgentModels: the loop handles the empty case, and an empty vs nil slice are wire-equivalent under `omitempty`. - Tests: replace the two URL-helper tables with one TestUnloadURL table covering every branch; collapse the three no-op tests (empty FromAgent, equal From/To, non-DMR providers, no endpoint) into one table-driven TestUnload_NoOpInputs; add a small dmrInput builder so each test highlights only the field it cares about. Net 55 lines removed; `task lint` clean; `task test` green. --- pkg/hooks/builtins/unload.go | 134 +++++------- pkg/hooks/builtins/unload_test.go | 339 ++++++++++++++---------------- pkg/runtime/hooks.go | 8 +- 3 files changed, 213 insertions(+), 268 deletions(-) diff --git a/pkg/hooks/builtins/unload.go b/pkg/hooks/builtins/unload.go index 0c61216ad..973f3a40a 100644 --- a/pkg/hooks/builtins/unload.go +++ b/pkg/hooks/builtins/unload.go @@ -33,112 +33,53 @@ import ( // providers that don't expose [hooks.ModelEndpoint.BaseURL]). const Unload = "unload" -// unloadProviderDMR is the only provider type the builtin currently -// knows how to unload against. Other providers ship through a no-op so -// wiring this hook on a heterogeneous chain is harmless. -const unloadProviderDMR = "dmr" - -// unloadTimeout caps each Unload call so a stalled engine cannot stall -// agent switching. Each model gets its own deadline so a slow first -// model can't starve the rest. +// unloadTimeout caps each per-model Unload call so a stalled engine +// cannot stall agent switching. const unloadTimeout = 10 * time.Second // unload iterates the [hooks.Input.FromAgentModels] snapshot the // runtime captured at dispatch time and POSTs `{"model": ""}` to -// the resolved unload endpoint of each one we know how to unload. -// Errors are logged but never propagated — agent switching must never -// block on a slow or unreachable engine. +// the resolved unload endpoint of each DMR model. Errors are logged +// but never propagated — agent switching must never block on a slow +// or unreachable engine. func unload(ctx context.Context, in *hooks.Input, _ []string) (*hooks.Output, error) { if in == nil || in.FromAgent == "" || in.FromAgent == in.ToAgent { return nil, nil } for _, m := range in.FromAgentModels { - if m.Provider != unloadProviderDMR { - continue - } - endpoint, err := resolveUnloadURL(m.BaseURL, m.UnloadAPI) - if err != nil { - slog.WarnContext(ctx, "unload: resolving endpoint failed", - "agent", in.FromAgent, "model", m.Model, "error", err) + if m.Provider != "dmr" { continue } - if endpoint == "" { - continue - } - callCtx, cancel := context.WithTimeout(ctx, unloadTimeout) - if err := postUnloadModel(callCtx, http.DefaultClient, endpoint, m.Model); err != nil { - slog.WarnContext(ctx, "unload: provider unload failed", + if err := unloadOne(ctx, m); err != nil { + slog.WarnContext(ctx, "unload: failed", "agent", in.FromAgent, "model", m.Model, "error", err) } - cancel() } return nil, nil } -// resolveUnloadURL picks the unload endpoint for one model: -// the configured `unload_api` override (rebased against baseURL when -// it isn't already absolute) wins, otherwise [defaultUnloadURL] -// derives one from baseURL by replacing the trailing /v1 segment. -// Returns ("", nil) when no endpoint can be determined so the caller -// can skip without erroring. -func resolveUnloadURL(baseURL, override string) (string, error) { - if override != "" { - return rebaseURL(baseURL, override) - } - if baseURL == "" { - return "", nil - } - return defaultUnloadURL(baseURL), nil -} - -// defaultUnloadURL derives the `_unload` endpoint URL from the OpenAI -// base URL by replacing the trailing `/v1` segment, mirroring how the -// DMR client derives `_configure`: -// -// http://host:port/engines/v1/ → http://host:port/engines/_unload -// http://host:port/engines/llama.cpp/v1/ → http://host:port/engines/llama.cpp/_unload -// http://_/exp/vDD4.40/engines/v1 → http://_/exp/vDD4.40/engines/_unload -func defaultUnloadURL(baseURL string) string { - u, err := url.Parse(baseURL) - if err != nil { - return strings.TrimSuffix(strings.TrimSuffix(baseURL, "/"), "/v1") + "/_unload" +// unloadOne resolves the unload URL for m and POSTs the model id to +// it, bounded by [unloadTimeout]. A model with no resolvable endpoint +// (no base_url and no unload_api) is a silent no-op so the hook stays +// harmless on test / in-process providers. +func unloadOne(parent context.Context, m hooks.ModelEndpoint) error { + endpoint, err := unloadURL(m) + if err != nil || endpoint == "" { + return err } - u.Path = strings.TrimSuffix(strings.TrimSuffix(u.Path, "/"), "/v1") + "/_unload" - return u.String() -} - -// rebaseURL returns path verbatim if it is already an absolute URL, -// otherwise attaches it to baseURL's scheme + host (dropping any path -// baseURL may carry). This lets users point base_url at e.g. -// http://localhost:12434/engines/v1 and override unload_api with -// /engines/_unload without the version prefix bleeding through. -func rebaseURL(baseURL, path string) (string, error) { - if strings.HasPrefix(path, "http://") || strings.HasPrefix(path, "https://") { - return path, nil - } - u, err := url.Parse(baseURL) - if err != nil || u.Scheme == "" || u.Host == "" { - return "", fmt.Errorf("base_url %q is not absolute; cannot resolve %q", baseURL, path) - } - if !strings.HasPrefix(path, "/") { - path = "/" + path - } - return u.Scheme + "://" + u.Host + path, nil -} - -// postUnloadModel issues `POST ` with body `{"model": ""}`. -func postUnloadModel(ctx context.Context, client *http.Client, endpoint, modelID string) error { - body, _ := json.Marshal(map[string]string{"model": modelID}) + ctx, cancel := context.WithTimeout(parent, unloadTimeout) + defer cancel() + body, _ := json.Marshal(map[string]string{"model": m.Model}) req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewReader(body)) if err != nil { return fmt.Errorf("building unload request: %w", err) } req.Header.Set("Content-Type", "application/json") - slog.DebugContext(ctx, "Unloading model", "url", endpoint, "model", modelID) + slog.DebugContext(ctx, "Unloading model", "url", endpoint, "model", m.Model) - resp, err := client.Do(req) + resp, err := http.DefaultClient.Do(req) if err != nil { return fmt.Errorf("calling unload endpoint %s: %w", endpoint, err) } @@ -155,3 +96,36 @@ func postUnloadModel(ctx context.Context, client *http.Client, endpoint, modelID _, _ = io.Copy(io.Discard, resp.Body) return nil } + +// unloadURL picks the unload endpoint for one model, in this order: +// +// 1. unload_api is an absolute URL — used verbatim (lets users point +// at a different host than the model's base_url); +// 2. unload_api is set but relative — rebased onto base_url's +// scheme + host (the model's path is dropped); +// 3. unload_api is unset — the default `_unload` URL is derived from +// base_url by replacing its trailing `/v1` segment. +// +// Returns ("", nil) when neither base_url nor unload_api is set, so +// the caller can skip without erroring. +func unloadURL(m hooks.ModelEndpoint) (string, error) { + if strings.HasPrefix(m.UnloadAPI, "http://") || strings.HasPrefix(m.UnloadAPI, "https://") { + return m.UnloadAPI, nil + } + if m.BaseURL == "" && m.UnloadAPI == "" { + return "", nil + } + u, err := url.Parse(m.BaseURL) + if err != nil || u.Scheme == "" || u.Host == "" { + return "", fmt.Errorf("base_url %q is not absolute; cannot resolve unload endpoint", m.BaseURL) + } + switch { + case m.UnloadAPI == "": + u.Path = strings.TrimSuffix(strings.TrimSuffix(u.Path, "/"), "/v1") + "/_unload" + case strings.HasPrefix(m.UnloadAPI, "/"): + u.Path = m.UnloadAPI + default: + u.Path = "/" + m.UnloadAPI + } + return u.String(), nil +} diff --git a/pkg/hooks/builtins/unload_test.go b/pkg/hooks/builtins/unload_test.go index 3cd9d8ad8..50e12f9c8 100644 --- a/pkg/hooks/builtins/unload_test.go +++ b/pkg/hooks/builtins/unload_test.go @@ -14,6 +14,44 @@ import ( "github.com/docker/docker-agent/pkg/hooks" ) +// dmrInput builds an on_agent_switch [hooks.Input] carrying a single +// DMR ModelEndpoint pointed at server. Callers pass the relative +// override (or "") plus optional opts to tweak the generic transition +// fields (e.g. emptying FromAgent, equating From/To). Centralising +// this removes a chunk of per-test boilerplate without hiding the +// input shape — every field the test cares about is still visible on +// the call site. +func dmrInput(server *httptest.Server, unloadAPI string, opts ...func(*hooks.Input)) *hooks.Input { + in := &hooks.Input{ + FromAgent: "from", + ToAgent: "to", + FromAgentModels: []hooks.ModelEndpoint{{ + Provider: "dmr", + Model: "ai/qwen3", + BaseURL: server.URL + "/engines/v1", + UnloadAPI: unloadAPI, + }}, + } + for _, opt := range opts { + opt(in) + } + return in +} + +// countingServer returns a recording server whose handler runs `mark` +// on every hit; tests that count calls share the same idiom. +func countingServer(t *testing.T, status int, mark func(*http.Request)) *httptest.Server { + t.Helper() + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if mark != nil { + mark(r) + } + w.WriteHeader(status) + })) + t.Cleanup(srv.Close) + return srv +} + // TestUnloadBuiltin_Registered guarantees the public name is findable // on a registry built by [Register], so YAML hook entries that name it // actually resolve. @@ -47,15 +85,9 @@ func TestUnload_PostsToDefaultEndpoint(t *testing.T) { })) defer server.Close() - in := &hooks.Input{ - FromAgent: "from", - ToAgent: "to", - FromAgentModels: []hooks.ModelEndpoint{{ - Provider: "dmr", - Model: "ai/qwen3", - BaseURL: server.URL + "/engines/llama.cpp/v1", - }}, - } + in := dmrInput(server, "") + in.FromAgentModels[0].BaseURL = server.URL + "/engines/llama.cpp/v1" + out, err := unload(t.Context(), in, nil) require.NoError(t, err) assert.Nil(t, out, "unload is observational; output must be nil") @@ -70,104 +102,79 @@ func TestUnload_HonoursOverrideUnloadAPI(t *testing.T) { t.Parallel() var gotPath string - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - gotPath = r.URL.Path - w.WriteHeader(http.StatusOK) - })) - defer server.Close() + server := countingServer(t, http.StatusOK, func(r *http.Request) { gotPath = r.URL.Path }) - in := &hooks.Input{ - FromAgent: "from", - ToAgent: "to", - FromAgentModels: []hooks.ModelEndpoint{{ - Provider: "dmr", - Model: "ai/qwen3", - BaseURL: server.URL + "/engines/v1", - UnloadAPI: "/custom/unload", - }}, - } - _, err := unload(t.Context(), in, nil) + _, err := unload(t.Context(), dmrInput(server, "/custom/unload"), nil) require.NoError(t, err) assert.Equal(t, "/custom/unload", gotPath) } -// TestUnload_SkipsNonDMRProviders pins the cross-provider safety -// property: wiring `unload` on a heterogeneous chain is harmless -// because non-DMR endpoints are silently skipped. -func TestUnload_SkipsNonDMRProviders(t *testing.T) { +// TestUnload_NoOpInputs pins the cheap-path properties the agent loop +// relies on: the hook MUST NOT fire any HTTP call when the input +// describes a transition where unloading would be wrong (back to the +// same agent, no previous agent, only cloud providers, or a model +// without a resolvable endpoint). Combining these into one table +// makes the no-op contract obvious from the test body. +func TestUnload_NoOpInputs(t *testing.T) { t.Parallel() - var calls atomic.Int64 - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { - calls.Add(1) - w.WriteHeader(http.StatusOK) - })) - defer server.Close() - - in := &hooks.Input{ - FromAgent: "from", - ToAgent: "to", - FromAgentModels: []hooks.ModelEndpoint{ - {Provider: "openai", Model: "gpt-4", BaseURL: server.URL}, - {Provider: "anthropic", Model: "claude", BaseURL: server.URL}, + tests := []struct { + name string + in func(*httptest.Server) *hooks.Input + }{ + { + name: "nil input", + in: func(*httptest.Server) *hooks.Input { return nil }, + }, + { + name: "empty FromAgent", + in: func(s *httptest.Server) *hooks.Input { + return dmrInput(s, "", func(in *hooks.Input) { in.FromAgent = "" }) + }, + }, + { + name: "FromAgent equals ToAgent", + in: func(s *httptest.Server) *hooks.Input { + return dmrInput(s, "", func(in *hooks.Input) { + in.FromAgent, in.ToAgent = "same", "same" + }) + }, + }, + { + name: "non-DMR providers only", + in: func(s *httptest.Server) *hooks.Input { + return &hooks.Input{ + FromAgent: "from", ToAgent: "to", + FromAgentModels: []hooks.ModelEndpoint{ + {Provider: "openai", Model: "gpt-4", BaseURL: s.URL}, + {Provider: "anthropic", Model: "claude", BaseURL: s.URL}, + }, + } + }, + }, + { + name: "DMR model with no endpoint", + in: func(*httptest.Server) *hooks.Input { + return &hooks.Input{ + FromAgent: "from", ToAgent: "to", + FromAgentModels: []hooks.ModelEndpoint{{Provider: "dmr", Model: "ai/qwen3"}}, + } + }, }, } - _, err := unload(t.Context(), in, nil) - require.NoError(t, err) - assert.Equal(t, int64(0), calls.Load(), "no HTTP call must reach a non-DMR endpoint") -} -// TestUnload_NoopWhenFromEqualsTo documents the no-self-unload guard: -// transferring back into the same agent must not unload the model the -// next turn is about to use. -func TestUnload_NoopWhenFromEqualsTo(t *testing.T) { - t.Parallel() - - var calls atomic.Int64 - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { - calls.Add(1) - w.WriteHeader(http.StatusOK) - })) - defer server.Close() - - in := &hooks.Input{ - FromAgent: "same", - ToAgent: "same", - FromAgentModels: []hooks.ModelEndpoint{{ - Provider: "dmr", - Model: "ai/qwen3", - BaseURL: server.URL + "/engines/v1", - }}, - } - _, err := unload(t.Context(), in, nil) - require.NoError(t, err) - assert.Equal(t, int64(0), calls.Load()) -} - -// TestUnload_NoopWhenFromAgentEmpty documents the cheap path: the very -// first switch into the team's default agent has no previous agent and -// must not fire any HTTP call. -func TestUnload_NoopWhenFromAgentEmpty(t *testing.T) { - t.Parallel() - - var calls atomic.Int64 - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { - calls.Add(1) - w.WriteHeader(http.StatusOK) - })) - defer server.Close() + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + var calls atomic.Int64 + server := countingServer(t, http.StatusOK, func(*http.Request) { calls.Add(1) }) - in := &hooks.Input{ - ToAgent: "to", - FromAgentModels: []hooks.ModelEndpoint{{ - Provider: "dmr", - Model: "ai/qwen3", - BaseURL: server.URL + "/engines/v1", - }}, + out, err := unload(t.Context(), tt.in(server), nil) + require.NoError(t, err) + assert.Nil(t, out) + assert.Zero(t, calls.Load(), "no HTTP call must reach the server") + }) } - _, err := unload(t.Context(), in, nil) - require.NoError(t, err) - assert.Equal(t, int64(0), calls.Load()) } // TestUnload_SwallowsServerErrors verifies the best-effort contract: @@ -183,77 +190,80 @@ func TestUnload_SwallowsServerErrors(t *testing.T) { })) defer server.Close() - in := &hooks.Input{ - FromAgent: "from", - ToAgent: "to", - FromAgentModels: []hooks.ModelEndpoint{{ - Provider: "dmr", - Model: "ai/qwen3", - BaseURL: server.URL + "/engines/v1", - }}, - } - out, err := unload(t.Context(), in, nil) + out, err := unload(t.Context(), dmrInput(server, ""), nil) require.NoError(t, err) assert.Nil(t, out) } -// TestUnload_NoopWhenNoEndpoint documents that a DMR model without a -// BaseURL or unload_api is silently skipped (rather than erroring) so -// the hook stays harmless when wired against an in-process / test -// provider that hasn't resolved an HTTP endpoint. -func TestUnload_NoopWhenNoEndpoint(t *testing.T) { - t.Parallel() - - in := &hooks.Input{ - FromAgent: "from", - ToAgent: "to", - FromAgentModels: []hooks.ModelEndpoint{{ - Provider: "dmr", - Model: "ai/qwen3", - }}, - } - out, err := unload(t.Context(), in, nil) - require.NoError(t, err) - assert.Nil(t, out) -} - -func TestRebaseURL(t *testing.T) { +// TestUnloadURL covers every branch of the URL-resolution algorithm in +// one table. Replaces what used to be two separate tables for the now +// inlined `defaultUnloadURL` and `rebaseURL` helpers. +func TestUnloadURL(t *testing.T) { t.Parallel() tests := []struct { name string baseURL string - path string + unloadAPI string want string errContains string // empty ⇒ expect success }{ + // Default derivation (no unload_api set). { - name: "absolute https URL is returned verbatim", - baseURL: "http://anything", - path: "https://api.example.com/unload", - want: "https://api.example.com/unload", + name: "default: standard engines path", + baseURL: "http://127.0.0.1:12434/engines/v1/", + want: "http://127.0.0.1:12434/engines/_unload", }, { - name: "rooted path drops base path", - baseURL: "http://localhost:12434/engines/v1", - path: "/engines/_unload", - want: "http://localhost:12434/engines/_unload", + name: "default: no trailing slash", + baseURL: "http://127.0.0.1:12434/engines/v1", + want: "http://127.0.0.1:12434/engines/_unload", }, { - name: "relative path is rooted", - baseURL: "http://localhost:12434/engines/v1", - path: "engines/_unload", - want: "http://localhost:12434/engines/_unload", + name: "default: Docker Desktop experimental prefix", + baseURL: "http://_/exp/vDD4.40/engines/v1", + want: "http://_/exp/vDD4.40/engines/_unload", }, { - name: "empty base URL with relative path errors", - path: "/engines/_unload", + name: "default: backend-scoped path", + baseURL: "http://127.0.0.1:12434/engines/llama.cpp/v1/", + want: "http://127.0.0.1:12434/engines/llama.cpp/_unload", + }, + + // Override paths and absolute URLs. + { + name: "override: absolute https URL is returned verbatim", + baseURL: "http://anything", + unloadAPI: "https://api.example.com/unload", + want: "https://api.example.com/unload", + }, + { + name: "override: rooted path drops base path", + baseURL: "http://localhost:12434/engines/v1", + unloadAPI: "/engines/_unload", + want: "http://localhost:12434/engines/_unload", + }, + { + name: "override: relative path is rooted", + baseURL: "http://localhost:12434/engines/v1", + unloadAPI: "engines/_unload", + want: "http://localhost:12434/engines/_unload", + }, + + // Skip / error cases. + { + name: "skip: no base_url and no unload_api", + want: "", + }, + { + name: "error: unload_api set but base_url empty", + unloadAPI: "/engines/_unload", errContains: "is not absolute", }, { - name: "base URL without scheme errors", + name: "error: base_url without scheme", baseURL: "localhost:12434/engines/v1", - path: "/engines/_unload", + unloadAPI: "/engines/_unload", errContains: "is not absolute", }, } @@ -261,7 +271,10 @@ func TestRebaseURL(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { t.Parallel() - got, err := rebaseURL(tt.baseURL, tt.path) + got, err := unloadURL(hooks.ModelEndpoint{ + BaseURL: tt.baseURL, + UnloadAPI: tt.unloadAPI, + }) if tt.errContains != "" { require.Error(t, err) assert.Contains(t, err.Error(), tt.errContains) @@ -272,41 +285,3 @@ func TestRebaseURL(t *testing.T) { }) } } - -func TestDefaultUnloadURL(t *testing.T) { - t.Parallel() - - tests := []struct { - name string - baseURL string - expected string - }{ - { - name: "standard engines path", - baseURL: "http://127.0.0.1:12434/engines/v1/", - expected: "http://127.0.0.1:12434/engines/_unload", - }, - { - name: "no trailing slash", - baseURL: "http://127.0.0.1:12434/engines/v1", - expected: "http://127.0.0.1:12434/engines/_unload", - }, - { - name: "Docker Desktop experimental prefix", - baseURL: "http://_/exp/vDD4.40/engines/v1", - expected: "http://_/exp/vDD4.40/engines/_unload", - }, - { - name: "backend-scoped path", - baseURL: "http://127.0.0.1:12434/engines/llama.cpp/v1/", - expected: "http://127.0.0.1:12434/engines/llama.cpp/_unload", - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - t.Parallel() - assert.Equal(t, tt.expected, defaultUnloadURL(tt.baseURL)) - }) - } -} diff --git a/pkg/runtime/hooks.go b/pkg/runtime/hooks.go index b8aae23b5..ae90899b9 100644 --- a/pkg/runtime/hooks.go +++ b/pkg/runtime/hooks.go @@ -274,9 +274,8 @@ func (r *LocalRuntime) executeOnAgentSwitchHooks(ctx context.Context, a *agent.A // builtin) read this slice instead of poking at the runtime, so the // hook payload stays self-contained. // -// Returns nil — not an empty slice — when there is nothing to ship -// (no fromAgent, the agent isn't on the team, or it has no models) -// so the JSON wire payload omits the field via `omitempty`. +// Returns nil when there is no previous agent or the lookup fails so +// the JSON wire payload omits the field via `omitempty`. func (r *LocalRuntime) fromAgentModels(fromAgent string) []hooks.ModelEndpoint { if fromAgent == "" { return nil @@ -288,9 +287,6 @@ func (r *LocalRuntime) fromAgentModels(fromAgent string) []hooks.ModelEndpoint { return nil } configured := from.ConfiguredModels() - if len(configured) == 0 { - return nil - } out := make([]hooks.ModelEndpoint, 0, len(configured)) for _, p := range configured { cfg := p.BaseConfig() From 33378c968468f44ffc3b89fa5446e540351568c4 Mon Sep 17 00:00:00 2001 From: David Gageot Date: Thu, 7 May 2026 18:25:00 +0200 Subject: [PATCH 3/3] refactor(hooks/builtins): let DMR own its unload conventions Address review feedback on #2706: - Export pkg/model/provider/dmr.ProviderType and UnloadURL so the unload builtin becomes a dumb dispatcher and DMR owns the provider literal + the /v1 -> /_unload URL convention (sibling to the existing /_configure helper). Move the URL-resolution table test into the dmr package where the helper now lives. - Gate the FromAgentModels snapshot in executeOnAgentSwitchHooks on executor.Has(EventOnAgentSwitch) so audit-free deployments don't pay the team-lookup + per-model allocation on every agent switch. - Add a mixed-providers test pinning the per-element DMR filter: cloud entries (openai, anthropic) in the snapshot must be silently skipped, only DMR endpoints get POSTed. - Add a comment on http.DefaultClient explaining why the SSRF-safe client used by http_post is wrong here (DMR is loopback). --- pkg/hooks/builtins/unload.go | 49 +++------- pkg/hooks/builtins/unload_test.go | 133 ++++++++------------------ pkg/model/provider/dmr/unload.go | 52 ++++++++++ pkg/model/provider/dmr/unload_test.go | 98 +++++++++++++++++++ pkg/runtime/hooks.go | 10 ++ 5 files changed, 213 insertions(+), 129 deletions(-) create mode 100644 pkg/model/provider/dmr/unload.go create mode 100644 pkg/model/provider/dmr/unload_test.go diff --git a/pkg/hooks/builtins/unload.go b/pkg/hooks/builtins/unload.go index 973f3a40a..41f4df23e 100644 --- a/pkg/hooks/builtins/unload.go +++ b/pkg/hooks/builtins/unload.go @@ -8,11 +8,11 @@ import ( "io" "log/slog" "net/http" - "net/url" "strings" "time" "github.com/docker/docker-agent/pkg/hooks" + "github.com/docker/docker-agent/pkg/model/provider/dmr" ) // Unload is the registered name of the on_agent_switch builtin that @@ -31,6 +31,11 @@ import ( // net/http. It carries no runtime-side coupling and silently skips any // model whose endpoint isn't reachable as plain HTTP (e.g. cloud // providers that don't expose [hooks.ModelEndpoint.BaseURL]). +// +// Provider dispatch and URL resolution are owned by +// [pkg/model/provider/dmr] (see [dmr.ProviderType] and [dmr.UnloadURL]), +// so this builtin stays a dumb dispatcher and DMR keeps full control +// of its conventions. const Unload = "unload" // unloadTimeout caps each per-model Unload call so a stalled engine @@ -47,7 +52,7 @@ func unload(ctx context.Context, in *hooks.Input, _ []string) (*hooks.Output, er return nil, nil } for _, m := range in.FromAgentModels { - if m.Provider != "dmr" { + if m.Provider != dmr.ProviderType { continue } if err := unloadOne(ctx, m); err != nil { @@ -63,7 +68,7 @@ func unload(ctx context.Context, in *hooks.Input, _ []string) (*hooks.Output, er // (no base_url and no unload_api) is a silent no-op so the hook stays // harmless on test / in-process providers. func unloadOne(parent context.Context, m hooks.ModelEndpoint) error { - endpoint, err := unloadURL(m) + endpoint, err := dmr.UnloadURL(m.BaseURL, m.UnloadAPI) if err != nil || endpoint == "" { return err } @@ -79,6 +84,11 @@ func unloadOne(parent context.Context, m hooks.ModelEndpoint) error { slog.DebugContext(ctx, "Unloading model", "url", endpoint, "model", m.Model) + // Unlike the http_post builtin, the unload target is the + // operator-configured DMR base URL — typically a loopback engine + // (Docker Desktop socket, 127.0.0.1:12434, …). The SSRF-safe + // dialer used by http_post would refuse those addresses by + // design, so we use the default client here. resp, err := http.DefaultClient.Do(req) if err != nil { return fmt.Errorf("calling unload endpoint %s: %w", endpoint, err) @@ -96,36 +106,3 @@ func unloadOne(parent context.Context, m hooks.ModelEndpoint) error { _, _ = io.Copy(io.Discard, resp.Body) return nil } - -// unloadURL picks the unload endpoint for one model, in this order: -// -// 1. unload_api is an absolute URL — used verbatim (lets users point -// at a different host than the model's base_url); -// 2. unload_api is set but relative — rebased onto base_url's -// scheme + host (the model's path is dropped); -// 3. unload_api is unset — the default `_unload` URL is derived from -// base_url by replacing its trailing `/v1` segment. -// -// Returns ("", nil) when neither base_url nor unload_api is set, so -// the caller can skip without erroring. -func unloadURL(m hooks.ModelEndpoint) (string, error) { - if strings.HasPrefix(m.UnloadAPI, "http://") || strings.HasPrefix(m.UnloadAPI, "https://") { - return m.UnloadAPI, nil - } - if m.BaseURL == "" && m.UnloadAPI == "" { - return "", nil - } - u, err := url.Parse(m.BaseURL) - if err != nil || u.Scheme == "" || u.Host == "" { - return "", fmt.Errorf("base_url %q is not absolute; cannot resolve unload endpoint", m.BaseURL) - } - switch { - case m.UnloadAPI == "": - u.Path = strings.TrimSuffix(strings.TrimSuffix(u.Path, "/"), "/v1") + "/_unload" - case strings.HasPrefix(m.UnloadAPI, "/"): - u.Path = m.UnloadAPI - default: - u.Path = "/" + m.UnloadAPI - } - return u.String(), nil -} diff --git a/pkg/hooks/builtins/unload_test.go b/pkg/hooks/builtins/unload_test.go index 50e12f9c8..959220fb6 100644 --- a/pkg/hooks/builtins/unload_test.go +++ b/pkg/hooks/builtins/unload_test.go @@ -12,6 +12,7 @@ import ( "github.com/stretchr/testify/require" "github.com/docker/docker-agent/pkg/hooks" + "github.com/docker/docker-agent/pkg/model/provider/dmr" ) // dmrInput builds an on_agent_switch [hooks.Input] carrying a single @@ -26,7 +27,7 @@ func dmrInput(server *httptest.Server, unloadAPI string, opts ...func(*hooks.Inp FromAgent: "from", ToAgent: "to", FromAgentModels: []hooks.ModelEndpoint{{ - Provider: "dmr", + Provider: dmr.ProviderType, Model: "ai/qwen3", BaseURL: server.URL + "/engines/v1", UnloadAPI: unloadAPI, @@ -109,6 +110,43 @@ func TestUnload_HonoursOverrideUnloadAPI(t *testing.T) { assert.Equal(t, "/custom/unload", gotPath) } +// TestUnload_FiltersPerElement pins the per-element provider filter: +// when the snapshot mixes DMR and non-DMR endpoints, only the DMR +// ones are POSTed to. The non-DMR entries (cloud providers without a +// reachable unload endpoint) must be silently skipped, not errored +// on, not POSTed to a fabricated URL. +func TestUnload_FiltersPerElement(t *testing.T) { + t.Parallel() + + var gotModels []string + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + var body struct { + Model string `json:"model"` + } + _ = json.NewDecoder(r.Body).Decode(&body) + gotModels = append(gotModels, body.Model) + w.WriteHeader(http.StatusOK) + })) + defer server.Close() + + in := &hooks.Input{ + FromAgent: "from", + ToAgent: "to", + FromAgentModels: []hooks.ModelEndpoint{ + {Provider: "openai", Model: "gpt-4", BaseURL: "https://api.openai.com/v1"}, + {Provider: dmr.ProviderType, Model: "ai/qwen3", BaseURL: server.URL + "/engines/v1"}, + {Provider: "anthropic", Model: "claude", BaseURL: "https://api.anthropic.com"}, + {Provider: dmr.ProviderType, Model: "ai/llama3.2", BaseURL: server.URL + "/engines/llama.cpp/v1"}, + }, + } + + out, err := unload(t.Context(), in, nil) + require.NoError(t, err) + assert.Nil(t, out) + assert.ElementsMatch(t, []string{"ai/qwen3", "ai/llama3.2"}, gotModels, + "only DMR models must be POSTed; cloud providers must be silently skipped") +} + // TestUnload_NoOpInputs pins the cheap-path properties the agent loop // relies on: the hook MUST NOT fire any HTTP call when the input // describes a transition where unloading would be wrong (back to the @@ -157,7 +195,7 @@ func TestUnload_NoOpInputs(t *testing.T) { in: func(*httptest.Server) *hooks.Input { return &hooks.Input{ FromAgent: "from", ToAgent: "to", - FromAgentModels: []hooks.ModelEndpoint{{Provider: "dmr", Model: "ai/qwen3"}}, + FromAgentModels: []hooks.ModelEndpoint{{Provider: dmr.ProviderType, Model: "ai/qwen3"}}, } }, }, @@ -194,94 +232,3 @@ func TestUnload_SwallowsServerErrors(t *testing.T) { require.NoError(t, err) assert.Nil(t, out) } - -// TestUnloadURL covers every branch of the URL-resolution algorithm in -// one table. Replaces what used to be two separate tables for the now -// inlined `defaultUnloadURL` and `rebaseURL` helpers. -func TestUnloadURL(t *testing.T) { - t.Parallel() - - tests := []struct { - name string - baseURL string - unloadAPI string - want string - errContains string // empty ⇒ expect success - }{ - // Default derivation (no unload_api set). - { - name: "default: standard engines path", - baseURL: "http://127.0.0.1:12434/engines/v1/", - want: "http://127.0.0.1:12434/engines/_unload", - }, - { - name: "default: no trailing slash", - baseURL: "http://127.0.0.1:12434/engines/v1", - want: "http://127.0.0.1:12434/engines/_unload", - }, - { - name: "default: Docker Desktop experimental prefix", - baseURL: "http://_/exp/vDD4.40/engines/v1", - want: "http://_/exp/vDD4.40/engines/_unload", - }, - { - name: "default: backend-scoped path", - baseURL: "http://127.0.0.1:12434/engines/llama.cpp/v1/", - want: "http://127.0.0.1:12434/engines/llama.cpp/_unload", - }, - - // Override paths and absolute URLs. - { - name: "override: absolute https URL is returned verbatim", - baseURL: "http://anything", - unloadAPI: "https://api.example.com/unload", - want: "https://api.example.com/unload", - }, - { - name: "override: rooted path drops base path", - baseURL: "http://localhost:12434/engines/v1", - unloadAPI: "/engines/_unload", - want: "http://localhost:12434/engines/_unload", - }, - { - name: "override: relative path is rooted", - baseURL: "http://localhost:12434/engines/v1", - unloadAPI: "engines/_unload", - want: "http://localhost:12434/engines/_unload", - }, - - // Skip / error cases. - { - name: "skip: no base_url and no unload_api", - want: "", - }, - { - name: "error: unload_api set but base_url empty", - unloadAPI: "/engines/_unload", - errContains: "is not absolute", - }, - { - name: "error: base_url without scheme", - baseURL: "localhost:12434/engines/v1", - unloadAPI: "/engines/_unload", - errContains: "is not absolute", - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - t.Parallel() - got, err := unloadURL(hooks.ModelEndpoint{ - BaseURL: tt.baseURL, - UnloadAPI: tt.unloadAPI, - }) - if tt.errContains != "" { - require.Error(t, err) - assert.Contains(t, err.Error(), tt.errContains) - return - } - require.NoError(t, err) - assert.Equal(t, tt.want, got) - }) - } -} diff --git a/pkg/model/provider/dmr/unload.go b/pkg/model/provider/dmr/unload.go new file mode 100644 index 000000000..07c2f6c8d --- /dev/null +++ b/pkg/model/provider/dmr/unload.go @@ -0,0 +1,52 @@ +package dmr + +import ( + "fmt" + "net/url" + "strings" +) + +// ProviderType is the canonical [latest.ModelConfig.Provider] value +// for Docker Model Runner. Exported so callers outside the package +// (e.g. the `unload` hook builtin) can dispatch on provider type +// without hard-coding the literal. +const ProviderType = "dmr" + +// UnloadURL resolves the URL of the per-model unload endpoint for a +// DMR-served model, given the resolved provider base URL and the +// per-model `unload_api` override (both as they appear on +// [hooks.ModelEndpoint]). +// +// Resolution order: +// +// 1. unloadAPI is an absolute URL — used verbatim (lets users point +// at a different host than baseURL); +// 2. unloadAPI is set but relative — rebased onto baseURL's +// scheme + host (the model's path is dropped); +// 3. unloadAPI is unset — the default `_unload` URL is derived from +// baseURL by replacing its trailing `/v1` segment, mirroring the +// `/v1` → `/_configure` convention the configure path uses. +// +// Returns ("", nil) when neither baseURL nor unloadAPI is set, so the +// caller can skip without erroring (in-process / test providers). +func UnloadURL(baseURL, unloadAPI string) (string, error) { + if strings.HasPrefix(unloadAPI, "http://") || strings.HasPrefix(unloadAPI, "https://") { + return unloadAPI, nil + } + if baseURL == "" && unloadAPI == "" { + return "", nil + } + u, err := url.Parse(baseURL) + if err != nil || u.Scheme == "" || u.Host == "" { + return "", fmt.Errorf("base_url %q is not absolute; cannot resolve unload endpoint", baseURL) + } + switch { + case unloadAPI == "": + u.Path = strings.TrimSuffix(strings.TrimSuffix(u.Path, "/"), "/v1") + "/_unload" + case strings.HasPrefix(unloadAPI, "/"): + u.Path = unloadAPI + default: + u.Path = "/" + unloadAPI + } + return u.String(), nil +} diff --git a/pkg/model/provider/dmr/unload_test.go b/pkg/model/provider/dmr/unload_test.go new file mode 100644 index 000000000..ed6180469 --- /dev/null +++ b/pkg/model/provider/dmr/unload_test.go @@ -0,0 +1,98 @@ +package dmr + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// TestUnloadURL covers every branch of the URL-resolution algorithm +// in one place. The builtin and any other consumer of [UnloadURL] +// rely on these properties, so the test pins the convention here +// where DMR owns it (sibling to [buildConfigureURL]'s `_configure` +// math). +func TestUnloadURL(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + baseURL string + unloadAPI string + want string + errContains string // empty ⇒ expect success + }{ + // Default derivation (no unload_api set). + { + name: "default: standard engines path", + baseURL: "http://127.0.0.1:12434/engines/v1/", + want: "http://127.0.0.1:12434/engines/_unload", + }, + { + name: "default: no trailing slash", + baseURL: "http://127.0.0.1:12434/engines/v1", + want: "http://127.0.0.1:12434/engines/_unload", + }, + { + name: "default: Docker Desktop experimental prefix", + baseURL: "http://_/exp/vDD4.40/engines/v1", + want: "http://_/exp/vDD4.40/engines/_unload", + }, + { + name: "default: backend-scoped path", + baseURL: "http://127.0.0.1:12434/engines/llama.cpp/v1/", + want: "http://127.0.0.1:12434/engines/llama.cpp/_unload", + }, + + // Override paths and absolute URLs. + { + name: "override: absolute https URL is returned verbatim", + baseURL: "http://anything", + unloadAPI: "https://api.example.com/unload", + want: "https://api.example.com/unload", + }, + { + name: "override: rooted path drops base path", + baseURL: "http://localhost:12434/engines/v1", + unloadAPI: "/engines/_unload", + want: "http://localhost:12434/engines/_unload", + }, + { + name: "override: relative path is rooted", + baseURL: "http://localhost:12434/engines/v1", + unloadAPI: "engines/_unload", + want: "http://localhost:12434/engines/_unload", + }, + + // Skip / error cases. + { + name: "skip: no base_url and no unload_api", + want: "", + }, + { + name: "error: unload_api set but base_url empty", + unloadAPI: "/engines/_unload", + errContains: "is not absolute", + }, + { + name: "error: base_url without scheme", + baseURL: "localhost:12434/engines/v1", + unloadAPI: "/engines/_unload", + errContains: "is not absolute", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + got, err := UnloadURL(tt.baseURL, tt.unloadAPI) + if tt.errContains != "" { + require.Error(t, err) + assert.Contains(t, err.Error(), tt.errContains) + return + } + require.NoError(t, err) + assert.Equal(t, tt.want, got) + }) + } +} diff --git a/pkg/runtime/hooks.go b/pkg/runtime/hooks.go index ae90899b9..b6e795e8d 100644 --- a/pkg/runtime/hooks.go +++ b/pkg/runtime/hooks.go @@ -258,7 +258,17 @@ const ( // changes the active agent. Observational; failures are logged. The // hook runs alongside the existing [AgentSwitching] event, so users // who already consume that event see no behaviour change. +// +// The previous agent's model-endpoint snapshot is built only when at +// least one hook is configured for this event so audit-free +// deployments don't pay the team-lookup + per-model allocation on +// every agent switch (matches the cheap-when-unused property of the +// other hook callsites). func (r *LocalRuntime) executeOnAgentSwitchHooks(ctx context.Context, a *agent.Agent, sessionID, fromAgent, toAgent, kind string) { + exec := r.hooksExec(a) + if exec == nil || !exec.Has(hooks.EventOnAgentSwitch) { + return + } r.dispatchHook(ctx, a, hooks.EventOnAgentSwitch, &hooks.Input{ SessionID: sessionID, FromAgent: fromAgent,