Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions genkit-tools/common/src/types/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ export * from './document';
export * from './env';
export * from './eval';
export * from './evaluator';
export * from './middleware';
export * from './model';
export * from './prompt';
export * from './retriever';
Expand Down
36 changes: 36 additions & 0 deletions genkit-tools/common/src/types/middleware.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
/**
* Copyright 2025 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
import { z } from 'zod';

/**
 * Descriptor for a registered middleware, returned by reflection API.
 *
 * Only `name` is required; `description` and `configSchema` may be left out.
 */
export const MiddlewareDescSchema = z.object({
/** Unique name of the middleware. */
name: z.string(),
/** Human-readable description of what the middleware does. */
description: z.string().optional(),
/**
 * JSON Schema for the middleware's configuration, held as a string-keyed
 * object with arbitrary values. `nullish()` accepts an explicit `null` as
 * well as an absent key.
 */
configSchema: z.record(z.any()).nullish(),
});
/** Static type inferred from `MiddlewareDescSchema`. */
export type MiddlewareDesc = z.infer<typeof MiddlewareDescSchema>;

/**
 * Reference to a registered middleware with optional configuration.
 *
 * A reference identifies the middleware by `name` only; it does not carry the
 * middleware's description or config schema.
 */
export const MiddlewareRefSchema = z.object({
/** Name of the registered middleware. */
name: z.string(),
/**
 * Configuration for the middleware (schema defined by the middleware).
 * Declared as `z.any()`, so this schema accepts any value here, and the key
 * may be omitted entirely.
 */
config: z.any().optional(),
});
/** Static type inferred from `MiddlewareRefSchema`. */
export type MiddlewareRef = z.infer<typeof MiddlewareRefSchema>;
3 changes: 3 additions & 0 deletions genkit-tools/common/src/types/model.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
*/
import { z } from 'zod';
import { DocumentDataSchema } from './document';
import { MiddlewareRefSchema } from './middleware';
import {
CustomPartSchema,
DataPartSchema,
Expand Down Expand Up @@ -399,5 +400,7 @@ export const GenerateActionOptionsSchema = z.object({
maxTurns: z.number().optional(),
/** Custom step name for this generate call to display in trace views. Defaults to "generate". */
stepName: z.string().optional(),
/** Middleware to apply to this generation. */
use: z.array(MiddlewareRefSchema).optional(),
});
export type GenerateActionOptions = z.infer<typeof GenerateActionOptionsSchema>;
45 changes: 45 additions & 0 deletions genkit-tools/genkit-schema.json
Original file line number Diff line number Diff line change
Expand Up @@ -270,6 +270,45 @@
],
"additionalProperties": false
},
"MiddlewareDesc": {
"type": "object",
"properties": {
"name": {
"type": "string"
},
"description": {
"type": "string"
},
"configSchema": {
"anyOf": [
{
"type": "object",
"additionalProperties": {}
},
{
"type": "null"
}
]
}
},
"required": [
"name"
],
"additionalProperties": false
},
"MiddlewareRef": {
"type": "object",
"properties": {
"name": {
"type": "string"
},
"config": {}
},
"required": [
"name"
],
"additionalProperties": false
},
"CandidateError": {
"type": "object",
"properties": {
Expand Down Expand Up @@ -466,6 +505,12 @@
},
"stepName": {
"type": "string"
},
"use": {
"type": "array",
"items": {
"$ref": "#/$defs/MiddlewareRef"
}
}
},
"required": [
Expand Down
1 change: 1 addition & 0 deletions genkit-tools/scripts/schema-exporter.ts
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ const EXPORTED_TYPE_MODULES = [
'../common/src/types/embedder.ts',
'../common/src/types/evaluator.ts',
'../common/src/types/error.ts',
'../common/src/types/middleware.ts',
'../common/src/types/model.ts',
'../common/src/types/parts.ts',
'../common/src/types/reranker.ts',
Expand Down
21 changes: 21 additions & 0 deletions go/ai/gen.go
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,8 @@ type GenerateActionOptions struct {
ToolChoice ToolChoice `json:"toolChoice,omitempty"`
// Tools is a list of registered tool names for this generation if supported.
Tools []string `json:"tools,omitempty"`
// Use is middleware to apply to this generation, referenced by name with optional config.
Use []*MiddlewareRef `json:"use,omitempty"`
}

// GenerateActionResume holds options for resuming an interrupted generation.
Expand Down Expand Up @@ -206,6 +208,25 @@ type Message struct {
Role Role `json:"role,omitempty"`
}

// MiddlewareDesc is the registered descriptor for a middleware.
//
// The exported fields are serialized to JSON; each is tagged omitempty, so an
// empty value is left out of the encoded object.
type MiddlewareDesc struct {
// ConfigSchema is a JSON Schema describing the middleware's configuration.
ConfigSchema map[string]any `json:"configSchema,omitempty"`
// Description explains what the middleware does.
Description string `json:"description,omitempty"`
// Name is the middleware's unique identifier.
Name string `json:"name,omitempty"`
// configFromJSON is called with JSON-encoded configuration to obtain the
// middleware instance to run. It is unexported and carries no JSON tag, so
// encoding/json does not serialize it.
configFromJSON middlewareConfigFunc
}

// MiddlewareRef is a serializable reference to a registered middleware with config.
//
// It names the middleware rather than embedding it; the name is looked up
// again when the reference is used.
type MiddlewareRef struct {
// Config contains the middleware configuration. It is typed as any, so it
// holds whatever JSON-compatible value the referenced middleware expects.
Config any `json:"config,omitempty"`
// Name is the name of the registered middleware.
Name string `json:"name,omitempty"`
}

// ModelInfo contains metadata about a model's capabilities and characteristics.
type ModelInfo struct {
// ConfigSchema defines the model-specific configuration schema.
Expand Down
124 changes: 120 additions & 4 deletions go/ai/generate.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,8 @@ type ModelFunc = core.StreamingFunc[*ModelRequest, *ModelResponse, *ModelRespons
type ModelStreamCallback = func(context.Context, *ModelResponseChunk) error

// ModelMiddleware is middleware for model generate requests: it takes the next
// [ModelFunc] in the chain and returns a new [ModelFunc] that wraps it, so it
// can act before and/or after delegating to the wrapped function.
//
// Deprecated: Use [Middleware] interface with [WithUse] instead, which supports Generate, Model, and Tool hooks.
type ModelMiddleware = core.Middleware[*ModelRequest, *ModelResponse, *ModelResponseChunk]

// model is an action with functions specific to model generation such as Generate().
Expand Down Expand Up @@ -313,13 +315,50 @@ func GenerateWithRequest(ctx context.Context, r api.Registry, opts *GenerateActi
Output: &outputCfg,
}

var middlewareHandlers []Middleware
if len(opts.Use) > 0 {
middlewareHandlers = make([]Middleware, 0, len(opts.Use))
for _, ref := range opts.Use {
desc := LookupMiddleware(r, ref.Name)
if desc == nil {
return nil, core.NewError(core.NOT_FOUND, "ai.GenerateWithRequest: middleware %q not found", ref.Name)
}
configJSON, err := json.Marshal(ref.Config)
if err != nil {
return nil, core.NewError(core.INTERNAL, "ai.GenerateWithRequest: failed to marshal config for middleware %q: %v", ref.Name, err)
}
handler, err := desc.configFromJSON(configJSON)
if err != nil {
return nil, core.NewError(core.INVALID_ARGUMENT, "ai.GenerateWithRequest: failed to create middleware %q: %v", ref.Name, err)
}
middlewareHandlers = append(middlewareHandlers, handler)
}
}

fn := m.Generate
if bm != nil {
if cb != nil {
logger.FromContext(ctx).Warn("background model does not support streaming", "model", bm.Name())
}
fn = backgroundModelToModelFn(bm.Start)
}

if len(middlewareHandlers) > 0 {
modelHook := func(next ModelFunc) ModelFunc {
wrapped := next
for i := len(middlewareHandlers) - 1; i >= 0; i-- {
h := middlewareHandlers[i]
inner := wrapped
wrapped = func(ctx context.Context, req *ModelRequest, cb ModelStreamCallback) (*ModelResponse, error) {
return h.Model(ctx, &ModelState{Request: req, Callback: cb}, func(ctx context.Context, state *ModelState) (*ModelResponse, error) {
return inner(ctx, state.Request, state.Callback)
})
}
}
return wrapped
}
mw = append([]ModelMiddleware{modelHook}, mw...)
}
fn = core.ChainMiddleware(mw...)(fn)

// Inline recursive helper function that captures variables from parent scope.
Expand Down Expand Up @@ -388,7 +427,7 @@ func GenerateWithRequest(ctx context.Context, r api.Registry, opts *GenerateActi
return nil, core.NewError(core.ABORTED, "exceeded maximum tool call iterations (%d)", maxTurns)
}

newReq, interruptMsg, err := handleToolRequests(ctx, r, req, resp, wrappedCb, currentIndex)
newReq, interruptMsg, err := handleToolRequests(ctx, r, req, resp, wrappedCb, currentIndex, middlewareHandlers)
if err != nil {
return nil, err
}
Expand All @@ -406,6 +445,28 @@ func GenerateWithRequest(ctx context.Context, r api.Registry, opts *GenerateActi
})
}

// Wrap generate with the Generate hook chain from middleware.
if len(middlewareHandlers) > 0 {
innerGenerate := generate
generate = func(ctx context.Context, req *ModelRequest, currentTurn int, messageIndex int) (*ModelResponse, error) {
innerFn := func(ctx context.Context, state *GenerateState) (*ModelResponse, error) {
return innerGenerate(ctx, state.Request, currentTurn, messageIndex)
}
for i := len(middlewareHandlers) - 1; i >= 0; i-- {
h := middlewareHandlers[i]
next := innerFn
innerFn = func(ctx context.Context, state *GenerateState) (*ModelResponse, error) {
return h.Generate(ctx, state, next)
}
}
return innerFn(ctx, &GenerateState{
Options: opts,
Request: req,
Iteration: currentTurn,
})
}
}

return generate(ctx, req, 0, 0)
}

Expand Down Expand Up @@ -535,7 +596,27 @@ func Generate(ctx context.Context, r api.Registry, opts ...GenerateOption) (*Mod
}
}

// Process resources in messages
if len(genOpts.Use) > 0 {
for _, mw := range genOpts.Use {
name := mw.Name()
if LookupMiddleware(r, name) == nil {
if !r.IsChild() {
r = r.NewChild()
}
DefineMiddleware(r, "", mw)
}
configJSON, err := json.Marshal(mw)
if err != nil {
return nil, core.NewError(core.INTERNAL, "ai.Generate: failed to marshal middleware %q config: %v", name, err)
}
var config any
if err := json.Unmarshal(configJSON, &config); err != nil {
return nil, core.NewError(core.INTERNAL, "ai.Generate: failed to unmarshal middleware %q config: %v", name, err)
}
actionOpts.Use = append(actionOpts.Use, &MiddlewareRef{Name: name, Config: config})
}
}

processedMessages, err := processResources(ctx, r, messages)
if err != nil {
return nil, core.NewError(core.INTERNAL, "ai.Generate: error processing resources: %v", err)
Expand Down Expand Up @@ -773,7 +854,7 @@ func clone[T any](obj *T) *T {
// handleToolRequests processes any tool requests in the response, returning
// either a new request to continue the conversation or nil if no tool requests
// need handling.
func handleToolRequests(ctx context.Context, r api.Registry, req *ModelRequest, resp *ModelResponse, cb ModelStreamCallback, messageIndex int) (*ModelRequest, *Message, error) {
func handleToolRequests(ctx context.Context, r api.Registry, req *ModelRequest, resp *ModelResponse, cb ModelStreamCallback, messageIndex int, middlewareHandlers []Middleware) (*ModelRequest, *Message, error) {
toolCount := len(resp.ToolRequests())
if toolCount == 0 {
return nil, nil, nil
Expand All @@ -796,7 +877,7 @@ func handleToolRequests(ctx context.Context, r api.Registry, req *ModelRequest,
return
}

multipartResp, err := tool.RunRawMultipart(ctx, toolReq.Input)
multipartResp, err := runToolWithMiddleware(ctx, tool, toolReq, middlewareHandlers)
if err != nil {
var tie *toolInterruptError
if errors.As(err, &tie) {
Expand Down Expand Up @@ -879,6 +960,40 @@ func handleToolRequests(ctx context.Context, r api.Registry, req *ModelRequest,
return newReq, nil, nil
}

// runToolWithMiddleware runs a tool, wrapping the execution with Tool hooks from middleware.
//
// With no handlers the tool is invoked directly and its multipart response is
// returned unchanged. Otherwise the hooks are nested so that handlers[0] is the
// outermost: it runs first and decides whether to call the next hook, and the
// innermost step is the actual tool invocation.
//
// Only the Output and Content of the final ToolResponse are carried back into
// the returned MultipartToolResponse.
//
// An error is returned when the tool or any hook fails, or when the hook chain
// yields a nil response without an error.
func runToolWithMiddleware(ctx context.Context, tool Tool, toolReq *ToolRequest, handlers []Middleware) (*MultipartToolResponse, error) {
	if len(handlers) == 0 {
		return tool.RunRawMultipart(ctx, toolReq.Input)
	}

	// Innermost step: run the tool from the (possibly middleware-modified)
	// state and adapt its multipart result to the ToolResponse the hooks use.
	inner := func(ctx context.Context, state *ToolState) (*ToolResponse, error) {
		resp, err := state.Tool.RunRawMultipart(ctx, state.Request.Input)
		if err != nil {
			return nil, err
		}
		return &ToolResponse{
			Name:    state.Request.Name,
			Output:  resp.Output,
			Content: resp.Content,
		}, nil
	}

	// Wrap from last to first so the first handler ends up outermost. h and
	// next are declared inside the loop so each closure captures its own copy.
	for i := len(handlers) - 1; i >= 0; i-- {
		h := handlers[i]
		next := inner
		inner = func(ctx context.Context, state *ToolState) (*ToolResponse, error) {
			return h.Tool(ctx, state, next)
		}
	}

	toolResp, err := inner(ctx, &ToolState{Request: toolReq, Tool: tool})
	if err != nil {
		return nil, err
	}
	// A hook may return (nil, nil); dereferencing that below would panic, so
	// report it as an error instead.
	if toolResp == nil {
		return nil, errors.New("ai.runToolWithMiddleware: middleware returned a nil response without an error for tool \"" + toolReq.Name + "\"")
	}

	return &MultipartToolResponse{Output: toolResp.Output, Content: toolResp.Content}, nil
}

// Text returns the contents of the first candidate in a
// [ModelResponse] as a string. It returns an empty string if there
// are no candidates or if the candidate has no message.
Expand Down Expand Up @@ -1357,6 +1472,7 @@ func handleResumeOption(ctx context.Context, r api.Registry, genOpts *GenerateAc
Docs: genOpts.Docs,
ReturnToolRequests: genOpts.ReturnToolRequests,
Output: genOpts.Output,
Use: genOpts.Use,
},
toolMessage: toolMessage,
}, nil
Expand Down
Loading
Loading