diff --git a/genkit-tools/common/src/types/index.ts b/genkit-tools/common/src/types/index.ts index ea12971f0e..360546af0e 100644 --- a/genkit-tools/common/src/types/index.ts +++ b/genkit-tools/common/src/types/index.ts @@ -23,6 +23,7 @@ export * from './document'; export * from './env'; export * from './eval'; export * from './evaluator'; +export * from './middleware'; export * from './model'; export * from './prompt'; export * from './retriever'; diff --git a/genkit-tools/common/src/types/middleware.ts b/genkit-tools/common/src/types/middleware.ts new file mode 100644 index 0000000000..4bb1297ede --- /dev/null +++ b/genkit-tools/common/src/types/middleware.ts @@ -0,0 +1,36 @@ +/** + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +import { z } from 'zod'; + +/** Descriptor for a registered middleware, returned by reflection API. */ +export const MiddlewareDescSchema = z.object({ + /** Unique name of the middleware. */ + name: z.string(), + /** Human-readable description of what the middleware does. */ + description: z.string().optional(), + /** JSON Schema for the middleware's configuration. */ + configSchema: z.record(z.any()).nullish(), +}); +export type MiddlewareDesc = z.infer; + +/** Reference to a registered middleware with optional configuration. */ +export const MiddlewareRefSchema = z.object({ + /** Name of the registered middleware. */ + name: z.string(), + /** Configuration for the middleware (schema defined by the middleware). */ + config: z.any().optional(), +}); +export type MiddlewareRef = z.infer; diff --git a/genkit-tools/common/src/types/model.ts b/genkit-tools/common/src/types/model.ts index a36d9f288f..62fa83dedb 100644 --- a/genkit-tools/common/src/types/model.ts +++ b/genkit-tools/common/src/types/model.ts @@ -15,6 +15,7 @@ */ import { z } from 'zod'; import { DocumentDataSchema } from './document'; +import { MiddlewareRefSchema } from './middleware'; import { CustomPartSchema, DataPartSchema, @@ -399,5 +400,7 @@ export const GenerateActionOptionsSchema = z.object({ maxTurns: z.number().optional(), /** Custom step name for this generate call to display in trace views. Defaults to "generate". */ stepName: z.string().optional(), + /** Middleware to apply to this generation. */ + use: z.array(MiddlewareRefSchema).optional(), }); export type GenerateActionOptions = z.infer; diff --git a/genkit-tools/genkit-schema.json b/genkit-tools/genkit-schema.json index 26cc4fbf4f..d808e46df9 100644 --- a/genkit-tools/genkit-schema.json +++ b/genkit-tools/genkit-schema.json @@ -270,6 +270,45 @@ ], "additionalProperties": false }, + "MiddlewareDesc": { + "type": "object", + "properties": { + "name": { + "type": "string" + }, + "description": { + "type": "string" + }, + "configSchema": { + "anyOf": [ + { + "type": "object", + "additionalProperties": {} + }, + { + "type": "null" + } + ] + } + }, + "required": [ + "name" + ], + "additionalProperties": false + }, + "MiddlewareRef": { + "type": "object", + "properties": { + "name": { + "type": "string" + }, + "config": {} + }, + "required": [ + "name" + ], + "additionalProperties": false + }, "CandidateError": { "type": "object", "properties": { @@ -466,6 +505,12 @@ }, "stepName": { "type": "string" + }, + "use": { + "type": "array", + "items": { + "$ref": "#/$defs/MiddlewareRef" + } } }, "required": [ diff --git a/genkit-tools/scripts/schema-exporter.ts b/genkit-tools/scripts/schema-exporter.ts index 48df79b56a..7462a12a8d 100644 --- a/genkit-tools/scripts/schema-exporter.ts +++ b/genkit-tools/scripts/schema-exporter.ts @@ -26,6 +26,7 @@ const EXPORTED_TYPE_MODULES = [ '../common/src/types/embedder.ts', '../common/src/types/evaluator.ts', '../common/src/types/error.ts', + '../common/src/types/middleware.ts', '../common/src/types/model.ts', '../common/src/types/parts.ts', '../common/src/types/reranker.ts', diff --git a/go/ai/gen.go b/go/ai/gen.go index e391ef2215..963b8cd737 100644 --- a/go/ai/gen.go +++ b/go/ai/gen.go @@ -96,6 +96,8 @@ type GenerateActionOptions struct { ToolChoice ToolChoice `json:"toolChoice,omitempty"` // Tools is a list of registered tool names for this generation if supported. Tools []string `json:"tools,omitempty"` + // Use is middleware to apply to this generation, referenced by name with optional config. + Use []*MiddlewareRef `json:"use,omitempty"` } // GenerateActionResume holds options for resuming an interrupted generation. @@ -206,6 +208,25 @@ type Message struct { Role Role `json:"role,omitempty"` } +// MiddlewareDesc is the registered descriptor for a middleware. +type MiddlewareDesc struct { + // ConfigSchema is a JSON Schema describing the middleware's configuration. + ConfigSchema map[string]any `json:"configSchema,omitempty"` + // Description explains what the middleware does. + Description string `json:"description,omitempty"` + // Name is the middleware's unique identifier. + Name string `json:"name,omitempty"` + configFromJSON middlewareConfigFunc +} + +// MiddlewareRef is a serializable reference to a registered middleware with config. +type MiddlewareRef struct { + // Config contains the middleware configuration. + Config any `json:"config,omitempty"` + // Name is the name of the registered middleware. + Name string `json:"name,omitempty"` +} + // ModelInfo contains metadata about a model's capabilities and characteristics. type ModelInfo struct { // ConfigSchema defines the model-specific configuration schema. diff --git a/go/ai/generate.go b/go/ai/generate.go index 003eb0b653..c6b2a8066a 100644 --- a/go/ai/generate.go +++ b/go/ai/generate.go @@ -67,6 +67,8 @@ type ModelFunc = core.StreamingFunc[*ModelRequest, *ModelResponse, *ModelRespons type ModelStreamCallback = func(context.Context, *ModelResponseChunk) error // ModelMiddleware is middleware for model generate requests that takes in a ModelFunc, does something, then returns another ModelFunc. +// +// Deprecated: Use [Middleware] interface with [WithUse] instead, which supports Generate, Model, and Tool hooks. type ModelMiddleware = core.Middleware[*ModelRequest, *ModelResponse, *ModelResponseChunk] // model is an action with functions specific to model generation such as Generate(). @@ -313,6 +315,26 @@ func GenerateWithRequest(ctx context.Context, r api.Registry, opts *GenerateActi Output: &outputCfg, } + var middlewareHandlers []Middleware + if len(opts.Use) > 0 { + middlewareHandlers = make([]Middleware, 0, len(opts.Use)) + for _, ref := range opts.Use { + desc := LookupMiddleware(r, ref.Name) + if desc == nil { + return nil, core.NewError(core.NOT_FOUND, "ai.GenerateWithRequest: middleware %q not found", ref.Name) + } + configJSON, err := json.Marshal(ref.Config) + if err != nil { + return nil, core.NewError(core.INTERNAL, "ai.GenerateWithRequest: failed to marshal config for middleware %q: %v", ref.Name, err) + } + handler, err := desc.configFromJSON(configJSON) + if err != nil { + return nil, core.NewError(core.INVALID_ARGUMENT, "ai.GenerateWithRequest: failed to create middleware %q: %v", ref.Name, err) + } + middlewareHandlers = append(middlewareHandlers, handler) + } + } + fn := m.Generate if bm != nil { if cb != nil { @@ -320,6 +342,23 @@ func GenerateWithRequest(ctx context.Context, r api.Registry, opts *GenerateActi } fn = backgroundModelToModelFn(bm.Start) } + + if len(middlewareHandlers) > 0 { + modelHook := func(next ModelFunc) ModelFunc { + wrapped := next + for i := len(middlewareHandlers) - 1; i >= 0; i-- { + h := middlewareHandlers[i] + inner := wrapped + wrapped = func(ctx context.Context, req *ModelRequest, cb ModelStreamCallback) (*ModelResponse, error) { + return h.Model(ctx, &ModelState{Request: req, Callback: cb}, func(ctx context.Context, state *ModelState) (*ModelResponse, error) { + return inner(ctx, state.Request, state.Callback) + }) + } + } + return wrapped + } + mw = append([]ModelMiddleware{modelHook}, mw...) + } fn = core.ChainMiddleware(mw...)(fn) // Inline recursive helper function that captures variables from parent scope. @@ -388,7 +427,7 @@ func GenerateWithRequest(ctx context.Context, r api.Registry, opts *GenerateActi return nil, core.NewError(core.ABORTED, "exceeded maximum tool call iterations (%d)", maxTurns) } - newReq, interruptMsg, err := handleToolRequests(ctx, r, req, resp, wrappedCb, currentIndex) + newReq, interruptMsg, err := handleToolRequests(ctx, r, req, resp, wrappedCb, currentIndex, middlewareHandlers) if err != nil { return nil, err } @@ -406,6 +445,28 @@ func GenerateWithRequest(ctx context.Context, r api.Registry, opts *GenerateActi }) } + // Wrap generate with the Generate hook chain from middleware. + if len(middlewareHandlers) > 0 { + innerGenerate := generate + generate = func(ctx context.Context, req *ModelRequest, currentTurn int, messageIndex int) (*ModelResponse, error) { + innerFn := func(ctx context.Context, state *GenerateState) (*ModelResponse, error) { + return innerGenerate(ctx, state.Request, currentTurn, messageIndex) + } + for i := len(middlewareHandlers) - 1; i >= 0; i-- { + h := middlewareHandlers[i] + next := innerFn + innerFn = func(ctx context.Context, state *GenerateState) (*ModelResponse, error) { + return h.Generate(ctx, state, next) + } + } + return innerFn(ctx, &GenerateState{ + Options: opts, + Request: req, + Iteration: currentTurn, + }) + } + } + return generate(ctx, req, 0, 0) } @@ -535,7 +596,27 @@ func Generate(ctx context.Context, r api.Registry, opts ...GenerateOption) (*Mod } } - // Process resources in messages + if len(genOpts.Use) > 0 { + for _, mw := range genOpts.Use { + name := mw.Name() + if LookupMiddleware(r, name) == nil { + if !r.IsChild() { + r = r.NewChild() + } + DefineMiddleware(r, "", mw) + } + configJSON, err := json.Marshal(mw) + if err != nil { + return nil, core.NewError(core.INTERNAL, "ai.Generate: failed to marshal middleware %q config: %v", name, err) + } + var config any + if err := json.Unmarshal(configJSON, &config); err != nil { + return nil, core.NewError(core.INTERNAL, "ai.Generate: failed to unmarshal middleware %q config: %v", name, err) + } + actionOpts.Use = append(actionOpts.Use, &MiddlewareRef{Name: name, Config: config}) + } + } + processedMessages, err := processResources(ctx, r, messages) if err != nil { return nil, core.NewError(core.INTERNAL, "ai.Generate: error processing resources: %v", err) @@ -773,7 +854,7 @@ func clone[T any](obj *T) *T { // handleToolRequests processes any tool requests in the response, returning // either a new request to continue the conversation or nil if no tool requests // need handling. -func handleToolRequests(ctx context.Context, r api.Registry, req *ModelRequest, resp *ModelResponse, cb ModelStreamCallback, messageIndex int) (*ModelRequest, *Message, error) { +func handleToolRequests(ctx context.Context, r api.Registry, req *ModelRequest, resp *ModelResponse, cb ModelStreamCallback, messageIndex int, middlewareHandlers []Middleware) (*ModelRequest, *Message, error) { toolCount := len(resp.ToolRequests()) if toolCount == 0 { return nil, nil, nil @@ -796,7 +877,7 @@ func handleToolRequests(ctx context.Context, r api.Registry, req *ModelRequest, return } - multipartResp, err := tool.RunRawMultipart(ctx, toolReq.Input) + multipartResp, err := runToolWithMiddleware(ctx, tool, toolReq, middlewareHandlers) if err != nil { var tie *toolInterruptError if errors.As(err, &tie) { @@ -879,6 +960,40 @@ func handleToolRequests(ctx context.Context, r api.Registry, req *ModelRequest, return newReq, nil, nil } +// runToolWithMiddleware runs a tool, wrapping the execution with Tool hooks from middleware. +func runToolWithMiddleware(ctx context.Context, tool Tool, toolReq *ToolRequest, handlers []Middleware) (*MultipartToolResponse, error) { + if len(handlers) == 0 { + return tool.RunRawMultipart(ctx, toolReq.Input) + } + + inner := func(ctx context.Context, state *ToolState) (*ToolResponse, error) { + resp, err := state.Tool.RunRawMultipart(ctx, state.Request.Input) + if err != nil { + return nil, err + } + return &ToolResponse{ + Name: state.Request.Name, + Output: resp.Output, + Content: resp.Content, + }, nil + } + + for i := len(handlers) - 1; i >= 0; i-- { + h := handlers[i] + next := inner + inner = func(ctx context.Context, state *ToolState) (*ToolResponse, error) { + return h.Tool(ctx, state, next) + } + } + + toolResp, err := inner(ctx, &ToolState{Request: toolReq, Tool: tool}) + if err != nil { + return nil, err + } + + return &MultipartToolResponse{Output: toolResp.Output, Content: toolResp.Content}, nil +} + // Text returns the contents of the first candidate in a // [ModelResponse] as a string. It returns an empty string if there // are no candidates or if the candidate has no message. @@ -1357,6 +1472,7 @@ func handleResumeOption(ctx context.Context, r api.Registry, genOpts *GenerateAc Docs: genOpts.Docs, ReturnToolRequests: genOpts.ReturnToolRequests, Output: genOpts.Output, + Use: genOpts.Use, }, toolMessage: toolMessage, }, nil diff --git a/go/ai/middleware.go b/go/ai/middleware.go new file mode 100644 index 0000000000..35b2faf37f --- /dev/null +++ b/go/ai/middleware.go @@ -0,0 +1,144 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// SPDX-License-Identifier: Apache-2.0 + +package ai + +import ( + "context" + "encoding/json" + "fmt" + + "github.com/firebase/genkit/go/core" + "github.com/firebase/genkit/go/core/api" +) + +// middlewareConfigFunc creates a Middleware instance from JSON config. +type middlewareConfigFunc = func([]byte) (Middleware, error) + +// Middleware provides hooks for different stages of generation. +type Middleware interface { + // Name returns the middleware's unique identifier. + Name() string + // New returns a fresh instance for each ai.Generate() call, enabling per-invocation state. + New() Middleware + // Generate wraps each iteration of the tool loop. + Generate(ctx context.Context, state *GenerateState, next GenerateNext) (*ModelResponse, error) + // Model wraps each model API call. + Model(ctx context.Context, state *ModelState, next ModelNext) (*ModelResponse, error) + // Tool wraps each tool execution. + Tool(ctx context.Context, state *ToolState, next ToolNext) (*ToolResponse, error) +} + +// GenerateState holds state for the Generate hook. +type GenerateState struct { + // Options is the original options passed to [Generate]. + Options *GenerateActionOptions + // Request is the current model request for this iteration, with accumulated messages. + Request *ModelRequest + // Iteration is the current tool-loop iteration (0-indexed). + Iteration int +} + +// ModelState holds state for the Model hook. +type ModelState struct { + // Request is the model request about to be sent. + Request *ModelRequest + // Callback is the streaming callback, or nil if not streaming. + Callback ModelStreamCallback +} + +// ToolState holds state for the Tool hook. +type ToolState struct { + // Request is the tool request about to be executed. + Request *ToolRequest + // Tool is the resolved tool being called. + Tool Tool +} + +// GenerateNext is the next function in the Generate hook chain. +type GenerateNext = func(ctx context.Context, state *GenerateState) (*ModelResponse, error) + +// ModelNext is the next function in the Model hook chain. +type ModelNext = func(ctx context.Context, state *ModelState) (*ModelResponse, error) + +// ToolNext is the next function in the Tool hook chain. +type ToolNext = func(ctx context.Context, state *ToolState) (*ToolResponse, error) + +// BaseMiddleware provides default pass-through for the three hooks. +// Embed this so you only need to implement Name() and New(). +type BaseMiddleware struct{} + +func (b *BaseMiddleware) Generate(ctx context.Context, state *GenerateState, next GenerateNext) (*ModelResponse, error) { + return next(ctx, state) +} + +func (b *BaseMiddleware) Model(ctx context.Context, state *ModelState, next ModelNext) (*ModelResponse, error) { + return next(ctx, state) +} + +func (b *BaseMiddleware) Tool(ctx context.Context, state *ToolState, next ToolNext) (*ToolResponse, error) { + return next(ctx, state) +} + +// Register registers the descriptor with the registry. +func (d *MiddlewareDesc) Register(r api.Registry) { + r.RegisterValue("/middleware/"+d.Name, d) +} + +// NewMiddleware creates a middleware descriptor without registering it. +// The prototype carries stable state; configFromJSON calls prototype.New() +// then unmarshals user config on top. +func NewMiddleware[T Middleware](description string, prototype T) *MiddlewareDesc { + return &MiddlewareDesc{ + Name: prototype.Name(), + Description: description, + ConfigSchema: core.InferSchemaMap(*new(T)), + configFromJSON: func(configJSON []byte) (Middleware, error) { + inst := prototype.New() + if len(configJSON) > 0 { + if err := json.Unmarshal(configJSON, inst); err != nil { + return nil, fmt.Errorf("middleware %q: %w", prototype.Name(), err) + } + } + return inst, nil + }, + } +} + +// DefineMiddleware creates and registers a middleware descriptor. +func DefineMiddleware[T Middleware](r api.Registry, description string, prototype T) *MiddlewareDesc { + d := NewMiddleware(description, prototype) + d.Register(r) + return d +} + +// LookupMiddleware looks up a registered middleware descriptor by name. +func LookupMiddleware(r api.Registry, name string) *MiddlewareDesc { + v := r.LookupValue("/middleware/" + name) + if v == nil { + return nil + } + d, ok := v.(*MiddlewareDesc) + if !ok { + return nil + } + return d +} + +// MiddlewarePlugin is implemented by plugins that provide middleware. +type MiddlewarePlugin interface { + ListMiddleware(ctx context.Context) ([]*MiddlewareDesc, error) +} diff --git a/go/ai/middleware_test.go b/go/ai/middleware_test.go new file mode 100644 index 0000000000..a0f9f935ea --- /dev/null +++ b/go/ai/middleware_test.go @@ -0,0 +1,237 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// SPDX-License-Identifier: Apache-2.0 + +package ai + +import ( + "context" + "sync/atomic" + "testing" +) + +// testMiddleware is a simple middleware for testing that tracks hook invocations. +type testMiddleware struct { + BaseMiddleware + Label string `json:"label"` + generateCalls int + modelCalls int + toolCalls int32 // atomic since tool hooks run in parallel +} + +func (m *testMiddleware) Name() string { return "test" } + +func (m *testMiddleware) New() Middleware { + return &testMiddleware{Label: m.Label} +} + +func (m *testMiddleware) Generate(ctx context.Context, state *GenerateState, next GenerateNext) (*ModelResponse, error) { + m.generateCalls++ + return next(ctx, state) +} + +func (m *testMiddleware) Model(ctx context.Context, state *ModelState, next ModelNext) (*ModelResponse, error) { + m.modelCalls++ + return next(ctx, state) +} + +func (m *testMiddleware) Tool(ctx context.Context, state *ToolState, next ToolNext) (*ToolResponse, error) { + atomic.AddInt32(&m.toolCalls, 1) + return next(ctx, state) +} + +func TestNewMiddleware(t *testing.T) { + proto := &testMiddleware{Label: "original"} + desc := NewMiddleware("test middleware", proto) + + if desc.Name != "test" { + t.Errorf("got name %q, want %q", desc.Name, "test") + } + if desc.Description != "test middleware" { + t.Errorf("got description %q, want %q", desc.Description, "test middleware") + } +} + +func TestDefineAndLookupMiddleware(t *testing.T) { + r := newTestRegistry(t) + proto := &testMiddleware{Label: "original"} + DefineMiddleware(r, "test middleware", proto) + + found := LookupMiddleware(r, "test") + if found == nil { + t.Fatal("expected to find middleware, got nil") + } + if found.Name != "test" { + t.Errorf("got name %q, want %q", found.Name, "test") + } +} + +func TestLookupMiddlewareNotFound(t *testing.T) { + r := newTestRegistry(t) + found := LookupMiddleware(r, "nonexistent") + if found != nil { + t.Errorf("expected nil, got %v", found) + } +} + +func TestConfigFromJSON(t *testing.T) { + proto := &testMiddleware{Label: "stable"} + desc := NewMiddleware("test middleware", proto) + + handler, err := desc.configFromJSON([]byte(`{"label": "custom"}`)) + if err != nil { + t.Fatalf("configFromJSON failed: %v", err) + } + + tm, ok := handler.(*testMiddleware) + if !ok { + t.Fatalf("expected *testMiddleware, got %T", handler) + } + if tm.Label != "custom" { + t.Errorf("got label %q, want %q", tm.Label, "custom") + } + // Per-request state should be zeroed by New() + if tm.generateCalls != 0 { + t.Errorf("got generateCalls %d, want 0", tm.generateCalls) + } +} + +func TestConfigFromJSONPreservesStableState(t *testing.T) { + // Simulate a plugin middleware with unexported stable state + proto := &stableStateMiddleware{apiKey: "secret123"} + desc := NewMiddleware("middleware with stable state", proto) + + handler, err := desc.configFromJSON([]byte(`{"sampleRate": 0.5}`)) + if err != nil { + t.Fatalf("configFromJSON failed: %v", err) + } + + sm, ok := handler.(*stableStateMiddleware) + if !ok { + t.Fatalf("expected *stableStateMiddleware, got %T", handler) + } + if sm.apiKey != "secret123" { + t.Errorf("got apiKey %q, want %q", sm.apiKey, "secret123") + } + if sm.SampleRate != 0.5 { + t.Errorf("got SampleRate %f, want 0.5", sm.SampleRate) + } +} + +func TestMiddlewareModelHook(t *testing.T) { + r := newTestRegistry(t) + m := defineFakeModel(t, r, fakeModelConfig{}) + DefineMiddleware(r, "tracks calls", &testMiddleware{}) + + resp, err := Generate(ctx, r, + WithModel(m), + WithPrompt("hello"), + WithUse(&testMiddleware{}), + ) + assertNoError(t, err) + if resp == nil { + t.Fatal("expected response, got nil") + } +} + +func TestMiddlewareToolHook(t *testing.T) { + r := newTestRegistry(t) + defineFakeModel(t, r, fakeModelConfig{ + name: "test/toolModel", + handler: toolCallingModelHandler("myTool", map[string]any{"value": "test"}, "done"), + }) + defineFakeTool(t, r, "myTool", "A test tool") + + mw := &testMiddleware{} + DefineMiddleware(r, "tracks calls", mw) + + _, err := Generate(ctx, r, + WithModelName("test/toolModel"), + WithPrompt("use the tool"), + WithTools(ToolName("myTool")), + WithUse(&testMiddleware{}), + ) + assertNoError(t, err) +} + +func TestMiddlewareOrdering(t *testing.T) { + // First middleware is outermost + var order []string + r := newTestRegistry(t) + m := defineFakeModel(t, r, fakeModelConfig{}) + + mwA := &orderMiddleware{label: "A", order: &order} + mwB := &orderMiddleware{label: "B", order: &order} + DefineMiddleware(r, "middleware A", mwA) + DefineMiddleware(r, "middleware B", mwB) + + _, err := Generate(ctx, r, + WithModel(m), + WithPrompt("hello"), + WithUse( + &orderMiddleware{label: "A", order: &order}, + &orderMiddleware{label: "B", order: &order}, + ), + ) + assertNoError(t, err) + + // Expect: A-before, B-before, B-after, A-after (first is outermost) + want := []string{"A-model-before", "B-model-before", "B-model-after", "A-model-after"} + if len(order) != len(want) { + t.Fatalf("got order %v, want %v", order, want) + } + for i := range want { + if order[i] != want[i] { + t.Errorf("order[%d] = %q, want %q", i, order[i], want[i]) + } + } +} + +// --- helper middleware types for tests --- + +// stableStateMiddleware has unexported stable state preserved by New(). +type stableStateMiddleware struct { + BaseMiddleware + SampleRate float64 `json:"sampleRate"` + apiKey string +} + +func (m *stableStateMiddleware) Name() string { return "stableState" } + +func (m *stableStateMiddleware) New() Middleware { + return &stableStateMiddleware{apiKey: m.apiKey} +} + +// orderMiddleware tracks the order of Model hook invocations. +type orderMiddleware struct { + BaseMiddleware + label string + order *[]string +} + +func (m *orderMiddleware) Name() string { return "order-" + m.label } + +func (m *orderMiddleware) New() Middleware { + return &orderMiddleware{label: m.label, order: m.order} +} + +func (m *orderMiddleware) Model(ctx context.Context, state *ModelState, next ModelNext) (*ModelResponse, error) { + *m.order = append(*m.order, m.label+"-model-before") + resp, err := next(ctx, state) + *m.order = append(*m.order, m.label+"-model-after") + return resp, err +} + +var ctx = context.Background() diff --git a/go/ai/option.go b/go/ai/option.go index d28c68e3e9..84019b11d7 100644 --- a/go/ai/option.go +++ b/go/ai/option.go @@ -109,7 +109,8 @@ type commonGenOptions struct { ToolChoice ToolChoice // Whether tool calls are required, disabled, or optional. MaxTurns int // Maximum number of tool call iterations. ReturnToolRequests *bool // Whether to return tool requests instead of making the tool calls and continuing the generation. - Middleware []ModelMiddleware // Middleware to apply to the model request and model response. + Middleware []ModelMiddleware // Deprecated: Use WithUse instead. Middleware to apply to the model request and model response. + Use []Middleware // Middleware to apply to generation (Generate, Model, and Tool hooks). } type CommonGenOption interface { @@ -181,6 +182,13 @@ func (o *commonGenOptions) applyCommonGen(opts *commonGenOptions) error { opts.Middleware = o.Middleware } + if o.Use != nil { + if opts.Use != nil { + return errors.New("cannot set middleware more than once (WithUse)") + } + opts.Use = o.Use + } + return nil } @@ -233,10 +241,18 @@ func WithModelName(name string) CommonGenOption { } // WithMiddleware sets middleware to apply to the model request. +// +// Deprecated: Use [WithUse] instead, which supports Generate, Model, and Tool hooks. func WithMiddleware(middleware ...ModelMiddleware) CommonGenOption { return &commonGenOptions{Middleware: middleware} } +// WithUse sets middleware to apply to generation. Middleware hooks wrap +// the generate loop, model calls, and tool executions. +func WithUse(middleware ...Middleware) CommonGenOption { + return &commonGenOptions{Use: middleware} +} + // WithMaxTurns sets the maximum number of tool call iterations before erroring. // A tool call happens when tools are provided in the request and a model decides to call one or more as a response. // Each round trip, including multiple tools in parallel, counts as one turn. diff --git a/go/ai/prompt.go b/go/ai/prompt.go index 4d0151c4c8..9e4dff9f14 100644 --- a/go/ai/prompt.go +++ b/go/ai/prompt.go @@ -249,6 +249,27 @@ func (p *prompt) Execute(ctx context.Context, opts ...PromptExecuteOption) (*Mod } } + if len(execOpts.Use) > 0 { + for _, mw := range execOpts.Use { + name := mw.Name() + if LookupMiddleware(r, name) == nil { + if !r.IsChild() { + r = r.NewChild() + } + DefineMiddleware(r, "", mw) + } + configJSON, err := json.Marshal(mw) + if err != nil { + return nil, fmt.Errorf("Prompt.Execute: failed to marshal middleware %q config: %w", name, err) + } + var config any + if err := json.Unmarshal(configJSON, &config); err != nil { + return nil, fmt.Errorf("Prompt.Execute: failed to unmarshal middleware %q config: %w", name, err) + } + actionOpts.Use = append(actionOpts.Use, &MiddlewareRef{Name: name, Config: config}) + } + } + return GenerateWithRequest(ctx, r, actionOpts, execOpts.Middleware, execOpts.Stream) } diff --git a/go/core/schemas.config b/go/core/schemas.config index 70798f2eb3..2fe8cc6d54 100644 --- a/go/core/schemas.config +++ b/go/core/schemas.config @@ -732,6 +732,10 @@ StepName is a custom step name for this generate call to display in trace views. Defaults to "generate". . +GenerateActionOptions.use doc +Use is middleware to apply to this generation, referenced by name with optional config. +. + GenerateActionOptionsResume doc GenerateActionResume holds options for resuming an interrupted generation. . @@ -840,6 +844,38 @@ PathMetadata.error doc Error contains error information if the path failed. . +# ---------------------------------------------------------------------------- +# Middleware Types +# ---------------------------------------------------------------------------- + +MiddlewareDesc doc +MiddlewareDesc is the registered descriptor for a middleware. +. + +MiddlewareDesc.name doc +Name is the middleware's unique identifier. +. + +MiddlewareDesc.description doc +Description explains what the middleware does. +. + +MiddlewareDesc.configSchema doc +ConfigSchema is a JSON Schema describing the middleware's configuration. +. + +MiddlewareRef doc +MiddlewareRef is a serializable reference to a registered middleware with config. +. + +MiddlewareRef.name doc +Name is the name of the registered middleware. +. + +MiddlewareRef.config doc +Config contains the middleware configuration. +. + # ---------------------------------------------------------------------------- # Multipart Tool Response # ---------------------------------------------------------------------------- @@ -1060,6 +1096,7 @@ GenerateActionOptions.config type any GenerateActionOptions.output type *GenerateActionOutputConfig GenerateActionOptions.returnToolRequests type bool GenerateActionOptions.maxTurns type int +GenerateActionOptions.use type []*MiddlewareRef GenerateActionOptionsResume name GenerateActionResume # GenerateActionOutputConfig @@ -1101,6 +1138,12 @@ ModelResponseChunk.index type int ModelResponseChunk.role type Role ModelResponseChunk field formatHandler StreamingFormatHandler +# Middleware +MiddlewareDesc pkg ai +MiddlewareDesc.configSchema type map[string]any +MiddlewareDesc field configFromJSON middlewareConfigFunc +MiddlewareRef pkg ai + Score omit Embedding.embedding type []float32 diff --git a/go/genkit/genkit.go b/go/genkit/genkit.go index 377fb5e836..8fd32913c2 100644 --- a/go/genkit/genkit.go +++ b/go/genkit/genkit.go @@ -228,6 +228,16 @@ func Init(ctx context.Context, opts ...GenkitOption) *Genkit { action.Register(r) } r.RegisterPlugin(plugin.Name(), plugin) + + if mp, ok := plugin.(ai.MiddlewarePlugin); ok { + descs, err := mp.ListMiddleware(ctx) + if err != nil { + panic(fmt.Errorf("genkit.Init: plugin %q ListMiddleware failed: %w", plugin.Name(), err)) + } + for _, d := range descs { + d.Register(r) + } + } } ai.ConfigureFormats(r) diff --git a/go/genkit/reflection.go b/go/genkit/reflection.go index 1bd675f75a..9936936e61 100644 --- a/go/genkit/reflection.go +++ b/go/genkit/reflection.go @@ -303,6 +303,7 @@ func serveMux(g *Genkit, s *reflectionServer) *http.ServeMux { mux.HandleFunc("POST /api/runAction", wrapReflectionHandler(handleRunAction(g, s.activeActions))) mux.HandleFunc("POST /api/notify", wrapReflectionHandler(handleNotify())) mux.HandleFunc("POST /api/cancelAction", wrapReflectionHandler(handleCancelAction(s.activeActions))) + mux.HandleFunc("GET /api/values", wrapReflectionHandler(handleListValues(g))) return mux } @@ -598,6 +599,27 @@ func handleListActions(g *Genkit) func(w http.ResponseWriter, r *http.Request) e } } +// handleListValues returns registered values filtered by type query parameter. +// Matches JS: GET /api/values?type=middleware +func handleListValues(g *Genkit) func(w http.ResponseWriter, r *http.Request) error { + return func(w http.ResponseWriter, r *http.Request) error { + valueType := r.URL.Query().Get("type") + if valueType == "" { + http.Error(w, `query parameter "type" is required`, http.StatusBadRequest) + return nil + } + prefix := "/" + valueType + "/" + result := map[string]any{} + for key, val := range g.reg.ListValues() { + if strings.HasPrefix(key, prefix) { + name := strings.TrimPrefix(key, prefix) + result[name] = val + } + } + return writeJSON(r.Context(), w, result) + } +} + // listActions lists all the registered actions. func listActions(g *Genkit) []api.ActionDesc { ads := []api.ActionDesc{} diff --git a/py/packages/genkit/src/genkit/core/typing.py b/py/packages/genkit/src/genkit/core/typing.py index 02f5927450..f67b0e91e2 100644 --- a/py/packages/genkit/src/genkit/core/typing.py +++ b/py/packages/genkit/src/genkit/core/typing.py @@ -120,6 +120,23 @@ class GenkitError(BaseModel): data: Data | None = None +class MiddlewareDesc(BaseModel): + """Model for middlewaredesc data.""" + + model_config: ClassVar[ConfigDict] = ConfigDict(alias_generator=to_camel, extra='forbid', populate_by_name=True) + name: str + description: str | None = None + config_schema: dict[str, Any] | None = Field(default=None) + + +class MiddlewareRef(BaseModel): + """Model for middlewareref data.""" + + model_config: ClassVar[ConfigDict] = ConfigDict(alias_generator=to_camel, extra='forbid', populate_by_name=True) + name: str + config: Any | None = None + + class Code(StrEnum): """Code data type class.""" @@ -1002,6 +1019,7 @@ class GenerateActionOptions(BaseModel): return_tool_requests: bool | None = Field(default=None) max_turns: float | None = Field(default=None) step_name: str | None = Field(default=None) + use: list[MiddlewareRef] | None = None class GenerateRequest(BaseModel):