diff --git a/ai/src/main/java/com/google/genkit/ai/GenerateOptions.java b/ai/src/main/java/com/google/genkit/ai/GenerateOptions.java index 3017243b5..ad3a38571 100644 --- a/ai/src/main/java/com/google/genkit/ai/GenerateOptions.java +++ b/ai/src/main/java/com/google/genkit/ai/GenerateOptions.java @@ -18,6 +18,7 @@ package com.google.genkit.ai; +import com.google.genkit.ai.middleware.GenerationMiddleware; import java.util.HashMap; import java.util.List; import java.util.Map; @@ -43,6 +44,7 @@ public class GenerateOptions { private final Integer maxTurns; private final ResumeOptions resume; private final Class outputClass; + private final List use; /** * Creates new GenerateOptions. @@ -74,7 +76,8 @@ public GenerateOptions( Map context, Integer maxTurns, ResumeOptions resume, - Class outputClass) { + Class outputClass, + List use) { this.model = model; this.prompt = prompt; this.messages = messages; @@ -88,6 +91,7 @@ public GenerateOptions( this.maxTurns = maxTurns; this.resume = resume; this.outputClass = outputClass; + this.use = use; } /** @@ -286,6 +290,15 @@ public Class getOutputClass() { return outputClass; } + /** + * Gets the V2 middleware to apply to this generation. + * + * @return the middleware list, or null if not set + */ + public List getUse() { + return use; + } + /** * Builder for GenerateOptions. * @@ -305,6 +318,7 @@ public static class Builder { private Integer maxTurns; private ResumeOptions resume; private Class outputClass; + private List use; public Builder model(String model) { this.model = model; @@ -407,6 +421,29 @@ public Builder resume(ResumeOptions resume) { return this; } + /** + * Sets V2 middleware to apply to this generation. Middleware hooks wrap the generate loop, + * model calls, and tool executions. + * + * @param use the middleware to apply + * @return this builder + */ + public Builder use(List use) { + this.use = use; + return this; + } + + /** + * Sets V2 middleware to apply to this generation. 
+ * + * @param middleware the middleware to apply + * @return this builder + */ + public Builder use(GenerationMiddleware... middleware) { + this.use = List.of(middleware); + return this; + } + public GenerateOptions build() { return new GenerateOptions<>( model, @@ -421,7 +458,8 @@ public GenerateOptions build() { context, maxTurns, resume, - outputClass); + outputClass, + use); } } } diff --git a/ai/src/main/java/com/google/genkit/ai/middleware/BaseGenerationMiddleware.java b/ai/src/main/java/com/google/genkit/ai/middleware/BaseGenerationMiddleware.java new file mode 100644 index 000000000..d327ea579 --- /dev/null +++ b/ai/src/main/java/com/google/genkit/ai/middleware/BaseGenerationMiddleware.java @@ -0,0 +1,78 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.google.genkit.ai.middleware; + +import com.google.genkit.ai.ModelResponse; +import com.google.genkit.ai.Tool; +import com.google.genkit.ai.ToolResponse; +import com.google.genkit.core.ActionContext; +import com.google.genkit.core.GenkitException; +import java.util.Collections; +import java.util.List; + +/** + * BaseGenerationMiddleware provides default pass-through implementations for all three hooks. + * Extend this class and override only the hooks you need. + * + *

<p>Example: + * + * <pre>{@code
+ * public class TimingMiddleware extends BaseGenerationMiddleware {
+ *   @Override
+ *   public String name() { return "timing"; }
+ *
+ *   @Override
+ *   public GenerationMiddleware newInstance() { return new TimingMiddleware(); }
+ *
+ *   @Override
+ *   public ModelResponse wrapModel(ActionContext ctx, ModelParams params, ModelNext next)
+ *       throws GenkitException {
+ *     long start = System.currentTimeMillis();
+ *     ModelResponse resp = next.apply(ctx, params);
+ *     System.out.println("Model call took " + (System.currentTimeMillis() - start) + "ms");
+ *     return resp;
+ *   }
+ * }
+ * }</pre>
+ */ +public abstract class BaseGenerationMiddleware implements GenerationMiddleware { + + @Override + public ModelResponse wrapGenerate(ActionContext ctx, GenerateParams params, GenerateNext next) + throws GenkitException { + return next.apply(ctx, params); + } + + @Override + public ModelResponse wrapModel(ActionContext ctx, ModelParams params, ModelNext next) + throws GenkitException { + return next.apply(ctx, params); + } + + @Override + public ToolResponse wrapTool(ActionContext ctx, ToolParams params, ToolNext next) + throws GenkitException { + return next.apply(ctx, params); + } + + @Override + public List> tools() { + return Collections.emptyList(); + } +} diff --git a/ai/src/main/java/com/google/genkit/ai/middleware/GenerateNext.java b/ai/src/main/java/com/google/genkit/ai/middleware/GenerateNext.java new file mode 100644 index 000000000..5301e8d50 --- /dev/null +++ b/ai/src/main/java/com/google/genkit/ai/middleware/GenerateNext.java @@ -0,0 +1,38 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.google.genkit.ai.middleware; + +import com.google.genkit.ai.ModelResponse; +import com.google.genkit.core.ActionContext; +import com.google.genkit.core.GenkitException; + +/** Next function in the {@link GenerationMiddleware#wrapGenerate} hook chain. 
*/ +@FunctionalInterface +public interface GenerateNext { + + /** + * Calls the next handler in the generate chain. + * + * @param ctx the action context + * @param params the generate parameters + * @return the model response + * @throws GenkitException if processing fails + */ + ModelResponse apply(ActionContext ctx, GenerateParams params) throws GenkitException; +} diff --git a/ai/src/main/java/com/google/genkit/ai/middleware/GenerateParams.java b/ai/src/main/java/com/google/genkit/ai/middleware/GenerateParams.java new file mode 100644 index 000000000..7508ec467 --- /dev/null +++ b/ai/src/main/java/com/google/genkit/ai/middleware/GenerateParams.java @@ -0,0 +1,54 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.google.genkit.ai.middleware; + +import com.google.genkit.ai.ModelRequest; + +/** Holds parameters for the {@link GenerationMiddleware#wrapGenerate} hook. */ +public class GenerateParams { + + private final ModelRequest request; + private final int iteration; + + /** + * Creates GenerateParams. + * + * @param request the current model request for this iteration + * @param iteration the current tool-loop iteration (0-indexed) + */ + public GenerateParams(ModelRequest request, int iteration) { + this.request = request; + this.iteration = iteration; + } + + /** Returns the current model request with accumulated messages. 
*/ + public ModelRequest getRequest() { + return request; + } + + /** Returns the current tool-loop iteration (0-indexed). */ + public int getIteration() { + return iteration; + } + + /** Returns a new GenerateParams with the given request, preserving the iteration. */ + public GenerateParams withRequest(ModelRequest request) { + return new GenerateParams(request, this.iteration); + } +} diff --git a/ai/src/main/java/com/google/genkit/ai/middleware/GenerationMiddleware.java b/ai/src/main/java/com/google/genkit/ai/middleware/GenerationMiddleware.java new file mode 100644 index 000000000..27db05f05 --- /dev/null +++ b/ai/src/main/java/com/google/genkit/ai/middleware/GenerationMiddleware.java @@ -0,0 +1,126 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.google.genkit.ai.middleware; + +import com.google.genkit.ai.ModelResponse; +import com.google.genkit.ai.Tool; +import com.google.genkit.ai.ToolResponse; +import com.google.genkit.core.ActionContext; +import com.google.genkit.core.GenkitException; +import java.util.Collections; +import java.util.List; + +/** + * GenerationMiddleware provides hooks for different stages of the generation pipeline. + * + *

<p>This is the V2 middleware interface that replaces the generic {@code Middleware}. It + * provides three distinct hooks: + * + *

<ul> + *
<li>{@link #wrapGenerate} - wraps each iteration of the tool loop + *
<li>{@link #wrapModel} - wraps each model API call + *
<li>{@link #wrapTool} - wraps each tool execution + *
</ul> + * + *

<p>Each {@code generate()} call creates a fresh instance via {@link #newInstance()}, enabling + * per-invocation state (e.g., counters, timers) without shared mutable state across requests. + * + *

<p>Example: + * + * <pre>{@code
+ * public class LoggingMiddleware extends BaseGenerationMiddleware {
+ *   private int modelCalls = 0;
+ *
+ *   @Override
+ *   public String name() { return "logging"; }
+ *
+ *   @Override
+ *   public GenerationMiddleware newInstance() { return new LoggingMiddleware(); }
+ *
+ *   @Override
+ *   public ModelResponse wrapModel(ActionContext ctx, ModelParams params, ModelNext next)
+ *       throws GenkitException {
+ *     modelCalls++;
+ *     System.out.println("Model call #" + modelCalls);
+ *     ModelResponse resp = next.apply(ctx, params);
+ *     System.out.println("Model responded with " + resp.getText());
+ *     return resp;
+ *   }
+ * }
+ * }</pre>
+ */ +public interface GenerationMiddleware { + + /** Returns the middleware's unique identifier. */ + String name(); + + /** + * Returns a fresh instance for each {@code generate()} call, enabling per-invocation state. + * + *

Stable state (e.g., API keys, configuration) should be preserved. Per-request state (e.g., + * counters) should be reset. + */ + GenerationMiddleware newInstance(); + + /** + * Wraps each iteration of the generate tool loop. + * + * @param ctx the action context + * @param params the generate parameters including the current request and iteration + * @param next the next function in the chain + * @return the model response + * @throws GenkitException if processing fails + */ + ModelResponse wrapGenerate(ActionContext ctx, GenerateParams params, GenerateNext next) + throws GenkitException; + + /** + * Wraps each model API call. + * + * @param ctx the action context + * @param params the model parameters including the request + * @param next the next function in the chain + * @return the model response + * @throws GenkitException if processing fails + */ + ModelResponse wrapModel(ActionContext ctx, ModelParams params, ModelNext next) + throws GenkitException; + + /** + * Wraps each tool execution. May be called concurrently when multiple tools execute in parallel. + * Implementations must be safe for concurrent use. + * + * @param ctx the action context + * @param params the tool parameters including the request and resolved tool + * @param next the next function in the chain + * @return the tool response + * @throws GenkitException if processing fails + */ + ToolResponse wrapTool(ActionContext ctx, ToolParams params, ToolNext next) throws GenkitException; + + /** + * Returns additional tools to make available during generation. These tools are dynamically added + * when the middleware is used. 
+ * + * @return the list of additional tools, or empty list if none + */ + default List> tools() { + return Collections.emptyList(); + } +} diff --git a/ai/src/main/java/com/google/genkit/ai/middleware/ModelNext.java b/ai/src/main/java/com/google/genkit/ai/middleware/ModelNext.java new file mode 100644 index 000000000..227b87f56 --- /dev/null +++ b/ai/src/main/java/com/google/genkit/ai/middleware/ModelNext.java @@ -0,0 +1,38 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.google.genkit.ai.middleware; + +import com.google.genkit.ai.ModelResponse; +import com.google.genkit.core.ActionContext; +import com.google.genkit.core.GenkitException; + +/** Next function in the {@link GenerationMiddleware#wrapModel} hook chain. */ +@FunctionalInterface +public interface ModelNext { + + /** + * Calls the next handler in the model chain. 
+ * + * @param ctx the action context + * @param params the model parameters + * @return the model response + * @throws GenkitException if processing fails + */ + ModelResponse apply(ActionContext ctx, ModelParams params) throws GenkitException; +} diff --git a/ai/src/main/java/com/google/genkit/ai/middleware/ModelParams.java b/ai/src/main/java/com/google/genkit/ai/middleware/ModelParams.java new file mode 100644 index 000000000..0e2aa1633 --- /dev/null +++ b/ai/src/main/java/com/google/genkit/ai/middleware/ModelParams.java @@ -0,0 +1,56 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.google.genkit.ai.middleware; + +import com.google.genkit.ai.ModelRequest; +import com.google.genkit.ai.ModelResponseChunk; +import java.util.function.Consumer; + +/** Holds parameters for the {@link GenerationMiddleware#wrapModel} hook. */ +public class ModelParams { + + private final ModelRequest request; + private final Consumer streamCallback; + + /** + * Creates ModelParams. + * + * @param request the model request about to be sent + * @param streamCallback the streaming callback, or null if not streaming + */ + public ModelParams(ModelRequest request, Consumer streamCallback) { + this.request = request; + this.streamCallback = streamCallback; + } + + /** Returns the model request about to be sent. 
*/ + public ModelRequest getRequest() { + return request; + } + + /** Returns the streaming callback, or null if not streaming. */ + public Consumer getStreamCallback() { + return streamCallback; + } + + /** Returns a new ModelParams with the given request, preserving the stream callback. */ + public ModelParams withRequest(ModelRequest request) { + return new ModelParams(request, this.streamCallback); + } +} diff --git a/ai/src/main/java/com/google/genkit/ai/middleware/ToolNext.java b/ai/src/main/java/com/google/genkit/ai/middleware/ToolNext.java new file mode 100644 index 000000000..53d79e835 --- /dev/null +++ b/ai/src/main/java/com/google/genkit/ai/middleware/ToolNext.java @@ -0,0 +1,38 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.google.genkit.ai.middleware; + +import com.google.genkit.ai.ToolResponse; +import com.google.genkit.core.ActionContext; +import com.google.genkit.core.GenkitException; + +/** Next function in the {@link GenerationMiddleware#wrapTool} hook chain. */ +@FunctionalInterface +public interface ToolNext { + + /** + * Calls the next handler in the tool chain. 
+ * + * @param ctx the action context + * @param params the tool parameters + * @return the tool response + * @throws GenkitException if processing fails + */ + ToolResponse apply(ActionContext ctx, ToolParams params) throws GenkitException; +} diff --git a/ai/src/main/java/com/google/genkit/ai/middleware/ToolParams.java b/ai/src/main/java/com/google/genkit/ai/middleware/ToolParams.java new file mode 100644 index 000000000..36ba87732 --- /dev/null +++ b/ai/src/main/java/com/google/genkit/ai/middleware/ToolParams.java @@ -0,0 +1,50 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.google.genkit.ai.middleware; + +import com.google.genkit.ai.Tool; +import com.google.genkit.ai.ToolRequest; + +/** Holds parameters for the {@link GenerationMiddleware#wrapTool} hook. */ +public class ToolParams { + + private final ToolRequest request; + private final Tool tool; + + /** + * Creates ToolParams. + * + * @param request the tool request about to be executed + * @param tool the resolved tool being called + */ + public ToolParams(ToolRequest request, Tool tool) { + this.request = request; + this.tool = tool; + } + + /** Returns the tool request about to be executed. */ + public ToolRequest getRequest() { + return request; + } + + /** Returns the resolved tool being called. 
*/ + public Tool getTool() { + return tool; + } +} diff --git a/ai/src/test/java/com/google/genkit/ai/middleware/GenerationMiddlewareTest.java b/ai/src/test/java/com/google/genkit/ai/middleware/GenerationMiddlewareTest.java new file mode 100644 index 000000000..0b40f28c8 --- /dev/null +++ b/ai/src/test/java/com/google/genkit/ai/middleware/GenerationMiddlewareTest.java @@ -0,0 +1,664 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.google.genkit.ai.middleware; + +import static org.junit.jupiter.api.Assertions.*; + +import com.google.genkit.ai.Candidate; +import com.google.genkit.ai.Message; +import com.google.genkit.ai.ModelRequest; +import com.google.genkit.ai.ModelResponse; +import com.google.genkit.ai.Tool; +import com.google.genkit.ai.ToolRequest; +import com.google.genkit.ai.ToolResponse; +import com.google.genkit.core.ActionContext; +import com.google.genkit.core.DefaultRegistry; +import com.google.genkit.core.GenkitException; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.concurrent.atomic.AtomicInteger; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +/** Tests for V2 GenerationMiddleware hooks: GenerateNext, ModelNext, ToolNext. 
*/ +class GenerationMiddlewareTest { + + private ActionContext ctx; + + @BeforeEach + void setUp() { + ctx = new ActionContext(new DefaultRegistry()); + } + + // ========================================================================= + // Helper: build a simple ModelResponse with text + // ========================================================================= + + private static ModelResponse responseWithText(String text) { + Message msg = Message.model(text); + Candidate candidate = new Candidate(msg); + return ModelResponse.builder().addCandidate(candidate).build(); + } + + // ========================================================================= + // GenerateNext tests + // ========================================================================= + + @Test + void testGenerateNext_passThrough() { + ModelRequest request = ModelRequest.builder().addUserMessage("hello").build(); + GenerateParams params = new GenerateParams(request, 0); + ModelResponse expected = responseWithText("world"); + + GenerateNext next = (c, p) -> expected; + + ModelResponse result = next.apply(ctx, params); + assertSame(expected, result); + } + + @Test + void testGenerateNext_chainOrder() { + List order = new ArrayList<>(); + + // Core function + GenerateNext core = + (c, p) -> { + order.add("core"); + return responseWithText("response"); + }; + + // Outer middleware wrapping core + GenerateNext outer = + (c, p) -> { + order.add("outer-before"); + ModelResponse resp = core.apply(c, p); + order.add("outer-after"); + return resp; + }; + + ModelRequest request = ModelRequest.builder().addUserMessage("test").build(); + outer.apply(ctx, new GenerateParams(request, 0)); + + assertEquals(List.of("outer-before", "core", "outer-after"), order); + } + + @Test + void testGenerateNext_canModifyParams() { + ModelRequest original = ModelRequest.builder().addUserMessage("original").build(); + ModelRequest modified = ModelRequest.builder().addUserMessage("modified").build(); + + AtomicInteger 
iterationSeen = new AtomicInteger(-1); + GenerateNext core = + (c, p) -> { + iterationSeen.set(p.getIteration()); + assertEquals(modified, p.getRequest()); + return responseWithText("ok"); + }; + + // Middleware that replaces the request + GenerateNext wrapper = + (c, p) -> { + GenerateParams newParams = p.withRequest(modified); + return core.apply(c, newParams); + }; + + wrapper.apply(ctx, new GenerateParams(original, 5)); + assertEquals(5, iterationSeen.get()); // iteration preserved by withRequest + } + + @Test + void testGenerateNext_exceptionPropagates() { + GenerateNext failing = + (c, p) -> { + throw new GenkitException("boom"); + }; + + ModelRequest request = ModelRequest.builder().build(); + assertThrows(GenkitException.class, () -> failing.apply(ctx, new GenerateParams(request, 0))); + } + + // ========================================================================= + // ModelNext tests + // ========================================================================= + + @Test + void testModelNext_passThrough() { + ModelRequest request = ModelRequest.builder().addUserMessage("hello").build(); + ModelParams params = new ModelParams(request, null); + ModelResponse expected = responseWithText("model output"); + + ModelNext next = (c, p) -> expected; + + ModelResponse result = next.apply(ctx, params); + assertSame(expected, result); + } + + @Test + void testModelNext_chainOrder() { + List order = new ArrayList<>(); + + ModelNext core = + (c, p) -> { + order.add("model"); + return responseWithText("result"); + }; + + ModelNext wrapper = + (c, p) -> { + order.add("before-model"); + ModelResponse resp = core.apply(c, p); + order.add("after-model"); + return resp; + }; + + ModelRequest request = ModelRequest.builder().build(); + wrapper.apply(ctx, new ModelParams(request, null)); + + assertEquals(List.of("before-model", "model", "after-model"), order); + } + + @Test + void testModelNext_canModifyRequest() { + ModelRequest original = 
ModelRequest.builder().addUserMessage("original").build(); + ModelRequest modified = ModelRequest.builder().addUserMessage("injected").build(); + + ModelNext core = + (c, p) -> { + assertEquals(modified, p.getRequest()); + return responseWithText("ok"); + }; + + ModelNext wrapper = + (c, p) -> { + ModelParams newParams = p.withRequest(modified); + return core.apply(c, newParams); + }; + + wrapper.apply(ctx, new ModelParams(original, null)); + } + + @Test + void testModelNext_preservesStreamCallback() { + List streamed = new ArrayList<>(); + ModelParams params = + new ModelParams(ModelRequest.builder().build(), chunk -> streamed.add("chunk")); + + ModelNext next = + (c, p) -> { + assertNotNull(p.getStreamCallback()); + return responseWithText("ok"); + }; + + next.apply(ctx, params); + assertNotNull(params.getStreamCallback()); + } + + @Test + void testModelNext_exceptionPropagates() { + ModelNext failing = + (c, p) -> { + throw new GenkitException("model failed"); + }; + + assertThrows( + GenkitException.class, + () -> failing.apply(ctx, new ModelParams(ModelRequest.builder().build(), null))); + } + + // ========================================================================= + // ToolNext tests + // ========================================================================= + + @Test + void testToolNext_passThrough() { + ToolRequest toolReq = new ToolRequest("myTool", Map.of("key", "value")); + Tool tool = createTestTool("myTool"); + ToolParams params = new ToolParams(toolReq, tool); + ToolResponse expected = new ToolResponse("myTool", "tool output"); + + ToolNext next = (c, p) -> expected; + + ToolResponse result = next.apply(ctx, params); + assertSame(expected, result); + } + + @Test + void testToolNext_chainOrder() { + List order = new ArrayList<>(); + + ToolNext core = + (c, p) -> { + order.add("tool-exec"); + return new ToolResponse(p.getRequest().getName(), "result"); + }; + + ToolNext wrapper = + (c, p) -> { + order.add("before-tool"); + ToolResponse resp = 
core.apply(c, p); + order.add("after-tool"); + return resp; + }; + + ToolRequest toolReq = new ToolRequest("test", Map.of()); + wrapper.apply(ctx, new ToolParams(toolReq, createTestTool("test"))); + + assertEquals(List.of("before-tool", "tool-exec", "after-tool"), order); + } + + @Test + void testToolNext_accessesToolInfo() { + Tool tool = createTestTool("weatherTool"); + ToolRequest toolReq = new ToolRequest("weatherTool", Map.of("city", "Paris")); + + ToolNext next = + (c, p) -> { + assertEquals("weatherTool", p.getRequest().getName()); + assertEquals("weatherTool", p.getTool().getName()); + return new ToolResponse("weatherTool", "sunny"); + }; + + ToolResponse resp = next.apply(ctx, new ToolParams(toolReq, tool)); + assertEquals("weatherTool", resp.getName()); + } + + @Test + void testToolNext_exceptionPropagates() { + ToolNext failing = + (c, p) -> { + throw new GenkitException("tool failed"); + }; + + ToolRequest toolReq = new ToolRequest("t", Map.of()); + assertThrows( + GenkitException.class, + () -> failing.apply(ctx, new ToolParams(toolReq, createTestTool("t")))); + } + + // ========================================================================= + // BaseGenerationMiddleware tests + // ========================================================================= + + @Test + void testBaseMiddleware_defaultsPassThrough() { + BaseGenerationMiddleware base = + new BaseGenerationMiddleware() { + @Override + public String name() { + return "noop"; + } + + @Override + public GenerationMiddleware newInstance() { + return this; + } + }; + + // wrapGenerate passes through + ModelRequest req = ModelRequest.builder().addUserMessage("test").build(); + ModelResponse expected = responseWithText("pass"); + GenerateNext gNext = (c, p) -> expected; + ModelResponse gResult = base.wrapGenerate(ctx, new GenerateParams(req, 0), gNext); + assertSame(expected, gResult); + + // wrapModel passes through + ModelNext mNext = (c, p) -> expected; + ModelResponse mResult = 
base.wrapModel(ctx, new ModelParams(req, null), mNext); + assertSame(expected, mResult); + + // wrapTool passes through + ToolResponse toolExpected = new ToolResponse("t", "data"); + ToolNext tNext = (c, p) -> toolExpected; + ToolResponse tResult = + base.wrapTool( + ctx, new ToolParams(new ToolRequest("t", Map.of()), createTestTool("t")), tNext); + assertSame(toolExpected, tResult); + + // tools returns empty + assertTrue(base.tools().isEmpty()); + } + + @Test + void testCustomMiddleware_overridesSelectedHooks() { + AtomicInteger modelCallCount = new AtomicInteger(0); + + BaseGenerationMiddleware middleware = + new BaseGenerationMiddleware() { + @Override + public String name() { + return "model-counter"; + } + + @Override + public GenerationMiddleware newInstance() { + return this; + } + + @Override + public ModelResponse wrapModel(ActionContext ctx, ModelParams params, ModelNext next) + throws GenkitException { + modelCallCount.incrementAndGet(); + return next.apply(ctx, params); + } + }; + + ModelRequest req = ModelRequest.builder().build(); + ModelResponse resp = responseWithText("ok"); + + // wrapModel is overridden + middleware.wrapModel(ctx, new ModelParams(req, null), (c, p) -> resp); + assertEquals(1, modelCallCount.get()); + + // wrapGenerate still passes through (default) + ModelResponse gResp = middleware.wrapGenerate(ctx, new GenerateParams(req, 0), (c, p) -> resp); + assertSame(resp, gResp); + assertEquals(1, modelCallCount.get()); // not incremented + } + + // ========================================================================= + // Chaining multiple middleware + // ========================================================================= + + @Test + void testChainGenerateHooks_nestedOrder() { + List order = new ArrayList<>(); + + GenerationMiddleware outer = + new BaseGenerationMiddleware() { + @Override + public String name() { + return "outer"; + } + + @Override + public GenerationMiddleware newInstance() { + return this; + } + + @Override + 
public ModelResponse wrapGenerate( + ActionContext ctx, GenerateParams params, GenerateNext next) throws GenkitException { + order.add("outer-before"); + ModelResponse resp = next.apply(ctx, params); + order.add("outer-after"); + return resp; + } + }; + + GenerationMiddleware inner = + new BaseGenerationMiddleware() { + @Override + public String name() { + return "inner"; + } + + @Override + public GenerationMiddleware newInstance() { + return this; + } + + @Override + public ModelResponse wrapGenerate( + ActionContext ctx, GenerateParams params, GenerateNext next) throws GenkitException { + order.add("inner-before"); + ModelResponse resp = next.apply(ctx, params); + order.add("inner-after"); + return resp; + } + }; + + // Chain: outer wraps inner wraps core + // This mirrors the chaining in Genkit.chainGenerateHooks() + List middlewares = List.of(outer, inner); + GenerateNext core = + (c, p) -> { + order.add("core"); + return responseWithText("done"); + }; + + // Build chain by reverse iteration (first middleware = outermost) + GenerateNext chain = core; + for (int i = middlewares.size() - 1; i >= 0; i--) { + GenerationMiddleware mw = middlewares.get(i); + GenerateNext wrapped = chain; + chain = (c, p) -> mw.wrapGenerate(c, p, wrapped); + } + + ModelRequest req = ModelRequest.builder().build(); + chain.apply(ctx, new GenerateParams(req, 0)); + + assertEquals( + List.of("outer-before", "inner-before", "core", "inner-after", "outer-after"), order); + } + + @Test + void testChainModelHooks_nestedOrder() { + List order = new ArrayList<>(); + + GenerationMiddleware first = + new BaseGenerationMiddleware() { + @Override + public String name() { + return "first"; + } + + @Override + public GenerationMiddleware newInstance() { + return this; + } + + @Override + public ModelResponse wrapModel(ActionContext ctx, ModelParams params, ModelNext next) + throws GenkitException { + order.add("first-before"); + ModelResponse resp = next.apply(ctx, params); + 
order.add("first-after"); + return resp; + } + }; + + GenerationMiddleware second = + new BaseGenerationMiddleware() { + @Override + public String name() { + return "second"; + } + + @Override + public GenerationMiddleware newInstance() { + return this; + } + + @Override + public ModelResponse wrapModel(ActionContext ctx, ModelParams params, ModelNext next) + throws GenkitException { + order.add("second-before"); + ModelResponse resp = next.apply(ctx, params); + order.add("second-after"); + return resp; + } + }; + + List middlewares = List.of(first, second); + ModelNext core = + (c, p) -> { + order.add("model"); + return responseWithText("result"); + }; + + ModelNext chain = core; + for (int i = middlewares.size() - 1; i >= 0; i--) { + GenerationMiddleware mw = middlewares.get(i); + ModelNext wrapped = chain; + chain = (c, p) -> mw.wrapModel(c, p, wrapped); + } + + chain.apply(ctx, new ModelParams(ModelRequest.builder().build(), null)); + + assertEquals( + List.of("first-before", "second-before", "model", "second-after", "first-after"), order); + } + + @Test + void testChainToolHooks_nestedOrder() { + List order = new ArrayList<>(); + + GenerationMiddleware first = + new BaseGenerationMiddleware() { + @Override + public String name() { + return "first"; + } + + @Override + public GenerationMiddleware newInstance() { + return this; + } + + @Override + public ToolResponse wrapTool(ActionContext ctx, ToolParams params, ToolNext next) + throws GenkitException { + order.add("first-before"); + ToolResponse resp = next.apply(ctx, params); + order.add("first-after"); + return resp; + } + }; + + GenerationMiddleware second = + new BaseGenerationMiddleware() { + @Override + public String name() { + return "second"; + } + + @Override + public GenerationMiddleware newInstance() { + return this; + } + + @Override + public ToolResponse wrapTool(ActionContext ctx, ToolParams params, ToolNext next) + throws GenkitException { + order.add("second-before"); + ToolResponse resp = 
next.apply(ctx, params); + order.add("second-after"); + return resp; + } + }; + + List middlewares = List.of(first, second); + ToolNext core = + (c, p) -> { + order.add("tool"); + return new ToolResponse(p.getRequest().getName(), "output"); + }; + + ToolNext chain = core; + for (int i = middlewares.size() - 1; i >= 0; i--) { + GenerationMiddleware mw = middlewares.get(i); + ToolNext wrapped = chain; + chain = (c, p) -> mw.wrapTool(c, p, wrapped); + } + + ToolRequest toolReq = new ToolRequest("myTool", Map.of()); + chain.apply(ctx, new ToolParams(toolReq, createTestTool("myTool"))); + + assertEquals( + List.of("first-before", "second-before", "tool", "second-after", "first-after"), order); + } + + // ========================================================================= + // newInstance() isolation + // ========================================================================= + + @Test + void testNewInstance_isolatesState() { + AtomicInteger sharedCounter = new AtomicInteger(0); + + GenerationMiddleware template = + new BaseGenerationMiddleware() { + private final AtomicInteger calls = new AtomicInteger(0); + + @Override + public String name() { + return "stateful"; + } + + @Override + public GenerationMiddleware newInstance() { + // Each instance gets its own counter + return new BaseGenerationMiddleware() { + private final AtomicInteger instanceCalls = new AtomicInteger(0); + + @Override + public String name() { + return "stateful"; + } + + @Override + public GenerationMiddleware newInstance() { + return this; + } + + @Override + public ModelResponse wrapModel(ActionContext ctx, ModelParams params, ModelNext next) + throws GenkitException { + instanceCalls.incrementAndGet(); + sharedCounter.incrementAndGet(); + return next.apply(ctx, params); + } + }; + } + }; + + // Simulate two generate() calls creating separate instances + GenerationMiddleware instance1 = template.newInstance(); + GenerationMiddleware instance2 = template.newInstance(); + + ModelResponse 
resp = responseWithText("ok"); + ModelNext passThrough = (c, p) -> resp; + ModelParams params = new ModelParams(ModelRequest.builder().build(), null); + + // Call instance1 three times + instance1.wrapModel(ctx, params, passThrough); + instance1.wrapModel(ctx, params, passThrough); + instance1.wrapModel(ctx, params, passThrough); + + // Call instance2 once + instance2.wrapModel(ctx, params, passThrough); + + // Shared counter sees all 4 calls + assertEquals(4, sharedCounter.get()); + + // But instances are independent (verified by the fact that both ran without error) + } + + // ========================================================================= + // Helper + // ========================================================================= + + private static Tool createTestTool(String name) { + Map schema = new HashMap<>(); + schema.put("type", "string"); + return new Tool<>(name, "Test tool", schema, schema, String.class, (ctx, input) -> "result"); + } +} diff --git a/core/src/main/java/com/google/genkit/core/middleware/Middleware.java b/core/src/main/java/com/google/genkit/core/middleware/Middleware.java index db508e7d8..4479eda02 100644 --- a/core/src/main/java/com/google/genkit/core/middleware/Middleware.java +++ b/core/src/main/java/com/google/genkit/core/middleware/Middleware.java @@ -41,7 +41,10 @@ * * @param The input type * @param The output type + * @deprecated Use {@code com.google.genkit.ai.middleware.GenerationMiddleware} instead, which + * supports distinct Generate, Model, and Tool hooks. 
*/ +@Deprecated @FunctionalInterface public interface Middleware { diff --git a/genkit/src/main/java/com/google/genkit/Genkit.java b/genkit/src/main/java/com/google/genkit/Genkit.java index e6c1aed72..e4dc9ac97 100644 --- a/genkit/src/main/java/com/google/genkit/Genkit.java +++ b/genkit/src/main/java/com/google/genkit/Genkit.java @@ -20,6 +20,7 @@ import com.google.genkit.ai.*; import com.google.genkit.ai.evaluation.*; +import com.google.genkit.ai.middleware.*; import com.google.genkit.ai.session.*; import com.google.genkit.ai.telemetry.ModelTelemetryHelper; import com.google.genkit.core.*; @@ -658,85 +659,296 @@ private ModelResponse generateInternal(GenerateOptions options) throws Genkit ActionContext ctx = new ActionContext(registry); int maxTurns = options.getMaxTurns() != null ? options.getMaxTurns() : 5; - int turn = 0; + + // Create fresh middleware instances for this invocation + List middlewares = createMiddlewareInstances(options.getUse()); + + // Collect tools from middleware instances and merge with options tools + List> allTools = new ArrayList<>(); + if (options.getTools() != null) { + allTools.addAll(options.getTools()); + } + for (GenerationMiddleware mw : middlewares) { + List> mwTools = mw.tools(); + if (mwTools != null && !mwTools.isEmpty()) { + allTools.addAll(mwTools); + } + } + + // Add middleware tool definitions to the model request + if (allTools.size() > (options.getTools() != null ? 
options.getTools().size() : 0)) { + List allToolDefs = new ArrayList<>(); + if (request.getTools() != null) { + allToolDefs.addAll(request.getTools()); + } + for (GenerationMiddleware mw : middlewares) { + List> mwTools = mw.tools(); + if (mwTools != null) { + for (Tool t : mwTools) { + allToolDefs.add(t.getDefinition()); + } + } + } + request.setTools(allToolDefs); + } // Handle resume option if provided if (options.getResume() != null) { request = handleResumeOption(request, options); } - while (turn < maxTurns) { - // Create span metadata for the model call - SpanMetadata modelSpanMetadata = - SpanMetadata.builder() - .name(options.getModel()) - .type(ActionType.MODEL.getValue()) - .subtype("model") - .build(); + // Build model call wrapped with WrapModel hooks + ModelNext wrappedModelCall = buildWrappedModelCall(model, options, ctx, middlewares); - String flowName = ctx.getFlowName(); - if (flowName != null) { - modelSpanMetadata.getAttributes().put("genkit:metadata:flow:name", flowName); - } + // Use an array to hold the reference for recursive WrapGenerate wrapping + final GenerateNext[] generateRef = new GenerateNext[1]; - final ModelRequest currentRequest = request; - final String flowNameForTelemetry = flowName; - final String spanPath = "/generate/" + options.getModel(); - ModelResponse response = - Tracer.runInNewSpan( - ctx, + // Core generate iteration: model call → tool handling → recurse + GenerateNext rawGenerate = + (actx, params) -> { + ModelRequest req = params.getRequest(); + int turn = params.getIteration(); + + if (turn >= maxTurns) { + throw new GenkitException("Max tool execution turns (" + maxTurns + ") exceeded"); + } + + // Call model through WrapModel chain + ModelParams mparams = new ModelParams(req, null); + ModelResponse response = wrappedModelCall.apply(actx, mparams); + + // Check if the model requested tool calls + List toolRequestParts = extractToolRequestParts(response); + if (toolRequestParts.isEmpty()) { + return response; + } 
+ + // Execute tools through WrapTool chain (includes middleware-provided tools) + ToolExecutionResult toolResult = + executeToolsWithMiddleware(actx, toolRequestParts, allTools, middlewares); + + // If there are interrupts, return immediately + if (!toolResult.getInterrupts().isEmpty()) { + return buildInterruptedResponse(response, toolResult); + } + + // Build next request with updated messages + Message assistantMessage = response.getMessage(); + List updatedMessages = new java.util.ArrayList<>(req.getMessages()); + updatedMessages.add(assistantMessage); + + Message toolResponseMessage = new Message(); + toolResponseMessage.setRole(Role.TOOL); + toolResponseMessage.setContent(toolResult.getResponses()); + updatedMessages.add(toolResponseMessage); + + ModelRequest nextRequest = + ModelRequest.builder() + .messages(updatedMessages) + .config(req.getConfig()) + .tools(req.getTools()) + .output(req.getOutput()) + .build(); + + // Recurse through the wrapped generate function (goes through WrapGenerate hooks) + return generateRef[0].apply(actx, new GenerateParams(nextRequest, turn + 1)); + }; + + // Chain WrapGenerate hooks around the core iteration + generateRef[0] = chainGenerateHooks(middlewares, rawGenerate); + + // Start generation + return generateRef[0].apply(ctx, new GenerateParams(request, 0)); + } + + /** Creates fresh middleware instances for a single generate invocation. */ + private List createMiddlewareInstances(List use) { + if (use == null || use.isEmpty()) { + return List.of(); + } + return use.stream().map(GenerationMiddleware::newInstance).toList(); + } + + /** Builds the model call function wrapped with WrapModel hooks from middleware. 
*/ + private ModelNext buildWrappedModelCall( + Model model, + GenerateOptions options, + ActionContext ctx, + List middlewares) { + + // Core model call with telemetry + ModelNext core = + (actx, mparams) -> { + ModelRequest req = mparams.getRequest(); + + SpanMetadata modelSpanMetadata = + SpanMetadata.builder() + .name(options.getModel()) + .type(ActionType.MODEL.getValue()) + .subtype("model") + .build(); + + String flowName = actx.getFlowName(); + if (flowName != null) { + modelSpanMetadata.getAttributes().put("genkit:metadata:flow:name", flowName); + } + + final String spanPath = "/generate/" + options.getModel(); + return Tracer.runInNewSpan( + actx, modelSpanMetadata, - request, - (spanCtx, req) -> { - // Wrap model execution with telemetry to record generate metrics + req, + (spanCtx, r) -> { return ModelTelemetryHelper.runWithTelemetry( options.getModel(), - flowNameForTelemetry, + flowName, spanPath, - currentRequest, - r -> model.run(ctx.withSpanContext(spanCtx), r)); + req, + mr -> model.run(actx.withSpanContext(spanCtx), mr)); }); + }; - // Check if the model requested tool calls - List toolRequestParts = extractToolRequestParts(response); - if (toolRequestParts.isEmpty()) { - // No tool calls, return the response - return response; - } + return chainModelHooks(middlewares, core); + } + + /** Chains WrapGenerate hooks. First middleware is outermost. */ + private GenerateNext chainGenerateHooks( + List middlewares, GenerateNext core) { + if (middlewares.isEmpty()) { + return core; + } + GenerateNext current = core; + for (int i = middlewares.size() - 1; i >= 0; i--) { + final GenerationMiddleware mw = middlewares.get(i); + final GenerateNext next = current; + current = (ctx, params) -> mw.wrapGenerate(ctx, params, next); + } + return current; + } + + /** Chains WrapModel hooks. First middleware is outermost. 
*/ + private ModelNext chainModelHooks(List middlewares, ModelNext core) { + if (middlewares.isEmpty()) { + return core; + } + ModelNext current = core; + for (int i = middlewares.size() - 1; i >= 0; i--) { + final GenerationMiddleware mw = middlewares.get(i); + final ModelNext next = current; + current = (ctx, params) -> mw.wrapModel(ctx, params, next); + } + return current; + } + + /** Chains WrapTool hooks. First middleware is outermost. */ + private ToolNext chainToolHooks(List middlewares, ToolNext core) { + if (middlewares.isEmpty()) { + return core; + } + ToolNext current = core; + for (int i = middlewares.size() - 1; i >= 0; i--) { + final GenerationMiddleware mw = middlewares.get(i); + final ToolNext next = current; + current = (ctx, params) -> mw.wrapTool(ctx, params, next); + } + return current; + } + + /** Executes tools with WrapTool middleware hooks applied. */ + private ToolExecutionResult executeToolsWithMiddleware( + ActionContext ctx, + List toolRequestParts, + List> tools, + List middlewares) { + + // Build WrapTool chain + ToolNext wrappedToolCall = + chainToolHooks( + middlewares, + (actx, tparams) -> { + Tool tool = tparams.getTool(); + ToolRequest toolReq = tparams.getRequest(); + + Object toolInput = toolReq.getInput(); + Class inputClass = tool.getInputClass(); + if (inputClass != null && toolInput != null && !inputClass.isInstance(toolInput)) { + toolInput = JsonUtils.convert(toolInput, inputClass); + } + + @SuppressWarnings("unchecked") + Tool typedTool = (Tool) tool; + Object result = typedTool.run(actx, toolInput); + + return new ToolResponse(toolReq.getRef(), toolReq.getName(), result); + }); - // Execute tools and handle interrupts - ToolExecutionResult toolResult = - executeToolsWithInterruptHandling(ctx, toolRequestParts, options.getTools()); + List responseParts = new java.util.ArrayList<>(); + List interrupts = new java.util.ArrayList<>(); + Map interruptMap = new java.util.HashMap<>(); + Map pendingOutputMap = new 
java.util.HashMap<>(); + + for (Part toolRequestPart : toolRequestParts) { + ToolRequest toolRequest = toolRequestPart.getToolRequest(); + String toolName = toolRequest.getName(); + String key = toolName + "#" + (toolRequest.getRef() != null ? toolRequest.getRef() : ""); - // If there are interrupts, return immediately with interrupted response - if (!toolResult.getInterrupts().isEmpty()) { - return buildInterruptedResponse(response, toolResult); + Tool tool = findTool(toolName, tools); + if (tool == null) { + Part errorPart = new Part(); + ToolResponse errorResponse = + new ToolResponse( + toolRequest.getRef(), toolName, Map.of("error", "Tool not found: " + toolName)); + errorPart.setToolResponse(errorResponse); + responseParts.add(errorPart); + continue; } - // Add the assistant message with tool requests - Message assistantMessage = response.getMessage(); - List updatedMessages = new java.util.ArrayList<>(request.getMessages()); - updatedMessages.add(assistantMessage); + try { + // Execute through WrapTool chain + ToolParams tparams = new ToolParams(toolRequest, tool); + ToolResponse toolResponse = wrappedToolCall.apply(ctx, tparams); - // Add tool response message - Message toolResponseMessage = new Message(); - toolResponseMessage.setRole(Role.TOOL); - toolResponseMessage.setContent(toolResult.getResponses()); - updatedMessages.add(toolResponseMessage); + Part responsePart = new Part(); + responsePart.setToolResponse(toolResponse); + responseParts.add(responsePart); - // Update request with new messages for next turn - request = - ModelRequest.builder() - .messages(updatedMessages) - .config(request.getConfig()) - .tools(request.getTools()) - .output(request.getOutput()) - .build(); + pendingOutputMap.put(key, toolResponse.getOutput()); - turn++; + logger.debug("Executed tool '{}' successfully", toolName); + + } catch (ToolInterruptException e) { + Map interruptMetadata = e.getMetadata(); + + Part interruptPart = new Part(); + 
interruptPart.setToolRequest(toolRequest); + Map metadata = + toolRequestPart.getMetadata() != null + ? new java.util.HashMap<>(toolRequestPart.getMetadata()) + : new java.util.HashMap<>(); + metadata.put( + "interrupt", + interruptMetadata != null && !interruptMetadata.isEmpty() ? interruptMetadata : true); + interruptPart.setMetadata(metadata); + + interrupts.add(interruptPart); + interruptMap.put(key, interruptPart); + + logger.debug("Tool '{}' triggered interrupt", toolName); + + } catch (Exception e) { + logger.error("Tool execution failed for '{}': {}", toolName, e.getMessage()); + Part errorPart = new Part(); + ToolResponse errorResponse = + new ToolResponse( + toolRequest.getRef(), + toolName, + Map.of("error", "Tool execution failed: " + e.getMessage())); + errorPart.setToolResponse(errorResponse); + responseParts.add(errorPart); + } } - throw new GenkitException("Max tool execution turns (" + maxTurns + ") exceeded"); + return new ToolExecutionResult(responseParts, interrupts, interruptMap, pendingOutputMap); } /** Handles resume options by processing respond and restart directives. 
*/ diff --git a/pom.xml b/pom.xml index ffb0bc5f6..42210896c 100644 --- a/pom.xml +++ b/pom.xml @@ -104,6 +104,7 @@ samples/evaluators-plugin samples/complex-io samples/middleware + samples/middleware-v2 samples/mcp samples/chat-session samples/multi-agent diff --git a/samples/middleware-v2/README.md b/samples/middleware-v2/README.md new file mode 100644 index 000000000..5726e101c --- /dev/null +++ b/samples/middleware-v2/README.md @@ -0,0 +1,161 @@ +# Genkit Java Middleware V2 Sample + +This sample demonstrates the **V2 GenerationMiddleware** system, which provides three distinct hooks into the generation pipeline: + +- **WrapGenerate** — wraps each iteration of the tool loop +- **WrapModel** — wraps each model API call +- **WrapTool** — wraps each tool execution + +Unlike V1 middleware (which wraps flows), V2 middleware is attached per `generate()` call via `GenerateOptions.builder().use()` and hooks directly into the AI generation pipeline. + +## Prerequisites + +- Java 21+ +- Maven 3.6+ +- OpenAI API key + +## Running the Sample + +### Option 1: Direct Run + +```bash +# Set your OpenAI API key +export OPENAI_API_KEY=your-api-key-here + +# Navigate to the sample directory +cd java/samples/middleware-v2 + +# Run the sample +./run.sh +# Or: mvn compile exec:java +``` + +### Option 2: With Genkit Dev UI (Recommended) + +```bash +# Set your OpenAI API key +export OPENAI_API_KEY=your-api-key-here + +# Navigate to the sample directory +cd java/samples/middleware-v2 + +# Run with Genkit CLI +genkit start -- ./run.sh +``` + +The Dev UI will be available at http://localhost:4000 + +## Middleware Examples + +### 1. ModelLoggingMiddleware (WrapModel) +Logs every model API call with a per-invocation counter. Demonstrates `newInstance()` for fresh state per `generate()` call. + +### 2. GenerateTimingMiddleware (WrapGenerate) +Measures wall-clock time for each generate loop iteration (model call + tool execution). + +### 3. 
ToolMonitorMiddleware (WrapTool) +Logs tool execution name and duration. Stateless — `newInstance()` returns `this`. + +### 4. FullObservabilityMiddleware (All 3 hooks) +A single middleware that implements all three hooks, showing how one middleware can observe the entire pipeline with per-invocation counters. + +## Available Endpoints + +| Endpoint | Description | Middleware | +|----------|-------------|------------| +| `/v2-chat` | AI chat | Model logging + generate timing | +| `/v2-observable` | AI chat | Full observability (all 3 hooks) | +| `/v2-stacked` | AI chat | Three separate middleware stacked | +| `/v2-baseline` | AI chat | No middleware (baseline) | + +## Example Requests + +```bash +# Chat with model logging + timing +curl -X POST http://localhost:8080/v2-chat \ + -H 'Content-Type: application/json' \ + -d '"What is middleware?"' + +# Chat with full observability +curl -X POST http://localhost:8080/v2-observable \ + -H 'Content-Type: application/json' \ + -d '"Explain Java records"' + +# Chat with stacked middleware +curl -X POST http://localhost:8080/v2-stacked \ + -H 'Content-Type: application/json' \ + -d '"Hello world"' + +# Baseline (no middleware) +curl -X POST http://localhost:8080/v2-baseline \ + -H 'Content-Type: application/json' \ + -d '"Hello world"' +``` + +## Creating Custom V2 Middleware + +Extend `BaseGenerationMiddleware` and override only the hooks you need: + +```java +import com.google.genkit.ai.ModelResponse; +import com.google.genkit.ai.middleware.*; +import com.google.genkit.core.ActionContext; +import com.google.genkit.core.GenkitException; + +public class MyMiddleware extends BaseGenerationMiddleware { + + @Override + public String name() { return "my-middleware"; } + + @Override + public GenerationMiddleware newInstance() { return new MyMiddleware(); } + + @Override + public ModelResponse wrapModel(ActionContext ctx, ModelParams params, ModelNext next) + throws GenkitException { + System.out.println("Before model call"); + ModelResponse resp = next.apply(ctx, params); + System.out.println("After model call: " + 
resp.getText().length() + " chars"); + return resp; + } +} +``` + +Then attach it to a `generate()` call: + +```java +ModelResponse response = genkit.generate( + GenerateOptions.builder() + .model("openai/gpt-4o-mini") + .prompt("Hello") + .use(new MyMiddleware()) + .build()); +``` + +## Architecture + +V2 middleware wraps the generation pipeline at three levels: + +``` +generate() call + └─ WrapGenerate (per tool-loop iteration) + └─ WrapModel (per model API call) + └─ WrapTool (per tool execution) + └─ recurse → next WrapGenerate iteration +``` + +Each `generate()` call creates fresh middleware instances via `newInstance()`, enabling per-invocation state (counters, timers) without shared mutable state across requests. + +Middleware are chained in order — the first middleware in the `use()` list is the outermost wrapper. + +## V1 vs V2 Middleware + +| | V1 (`Middleware`) | V2 (`GenerationMiddleware`) | +|---|---|---| +| **Scope** | Wraps flows | Wraps generation pipeline | +| **Hooks** | Single `apply()` | 3 hooks: Generate, Model, Tool | +| **Attachment** | `defineFlow(..., middleware)` | `GenerateOptions.builder().use(...)` | +| **State** | Shared across calls | Fresh per `generate()` via `newInstance()` | + +## See Also + +- [V1 Middleware Sample](../middleware/) — flow-level middleware +- [Genkit Documentation](https://github.com/genkit-ai/genkit-java) diff --git a/samples/middleware-v2/pom.xml b/samples/middleware-v2/pom.xml new file mode 100644 index 000000000..0857d4462 --- /dev/null +++ b/samples/middleware-v2/pom.xml @@ -0,0 +1,91 @@ + + + + 4.0.0 + + + com.google.genkit + genkit-parent + 1.0.0-SNAPSHOT + ../../pom.xml + + + com.google.genkit.samples + genkit-sample-middleware-v2 + jar + Genkit Middleware V2 Sample + Sample application demonstrating V2 GenerationMiddleware with Generate, Model, and Tool hooks + + + UTF-8 + 21 + 21 + 1.0.0-SNAPSHOT + true + true + + + + + com.google.genkit + genkit + ${genkit.version} + + + com.google.genkit + 
genkit-plugin-openai + ${genkit.version} + + + com.google.genkit + genkit-plugin-jetty + ${genkit.version} + + + ch.qos.logback + logback-classic + 1.5.32 + + + + + + + org.apache.maven.plugins + maven-compiler-plugin + 3.15.0 + + 21 + 21 + + + + org.codehaus.mojo + exec-maven-plugin + 3.6.3 + + com.google.genkit.samples.MiddlewareV2Sample + + + + + diff --git a/samples/middleware-v2/run.sh b/samples/middleware-v2/run.sh new file mode 100755 index 000000000..54805f669 --- /dev/null +++ b/samples/middleware-v2/run.sh @@ -0,0 +1,34 @@ +#!/bin/bash +# +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# SPDX-License-Identifier: Apache-2.0 + +# Run the Genkit Middleware V2 Sample + +set -e + +# Navigate to the sample directory +cd "$(dirname "$0")" + +# Check for OPENAI_API_KEY +if [ -z "$OPENAI_API_KEY" ]; then + echo "Warning: OPENAI_API_KEY is not set. The sample may not work correctly." + echo "Set it with: export OPENAI_API_KEY=your-api-key" +fi + +# Build and run +echo "Building and running Genkit Middleware V2 Sample..." 
+mvn compile exec:java -q diff --git a/samples/middleware-v2/src/main/java/com/google/genkit/samples/MiddlewareV2Sample.java b/samples/middleware-v2/src/main/java/com/google/genkit/samples/MiddlewareV2Sample.java new file mode 100644 index 000000000..7c6f288b5 --- /dev/null +++ b/samples/middleware-v2/src/main/java/com/google/genkit/samples/MiddlewareV2Sample.java @@ -0,0 +1,402 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.google.genkit.samples; + +import com.google.genkit.Genkit; +import com.google.genkit.GenkitOptions; +import com.google.genkit.ai.GenerateOptions; +import com.google.genkit.ai.GenerationConfig; +import com.google.genkit.ai.ModelResponse; +import com.google.genkit.ai.Tool; +import com.google.genkit.ai.ToolResponse; +import com.google.genkit.ai.middleware.BaseGenerationMiddleware; +import com.google.genkit.ai.middleware.GenerateNext; +import com.google.genkit.ai.middleware.GenerateParams; +import com.google.genkit.ai.middleware.GenerationMiddleware; +import com.google.genkit.ai.middleware.ModelNext; +import com.google.genkit.ai.middleware.ModelParams; +import com.google.genkit.ai.middleware.ToolNext; +import com.google.genkit.ai.middleware.ToolParams; +import com.google.genkit.core.ActionContext; +import com.google.genkit.core.Flow; +import com.google.genkit.core.GenkitException; +import com.google.genkit.plugins.jetty.JettyPlugin; +import 
com.google.genkit.plugins.jetty.JettyPluginOptions; +import com.google.genkit.plugins.openai.OpenAIPlugin; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.concurrent.atomic.AtomicInteger; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * Sample application demonstrating the V2 GenerationMiddleware system. + * + *

V2 middleware hooks into three distinct stages of the generation pipeline: + * + *

    + *
  • WrapGenerate — wraps each iteration of the tool loop + *
  • WrapModel — wraps each model API call + *
  • WrapTool — wraps each tool execution + *
+ * + *

Middleware is attached per {@code generate()} call via {@code GenerateOptions.builder().use()} + * rather than per flow. + * + *

Each {@code generate()} call obtains its middleware instances by calling {@code newInstance()} + * on every middleware passed to {@code use()}, enabling per-invocation state (counters, timers) + * without shared mutable state across requests. Stateless middleware may return {@code this}. + * + *

To run: + * + *

    + *
  1. Set the OPENAI_API_KEY environment variable + *
  2. Run: mvn compile exec:java + *
+ */ +public class MiddlewareV2Sample { + + private static final Logger logger = LoggerFactory.getLogger(MiddlewareV2Sample.class); + + // ========================================================================= + // Example 1: Model logging middleware (WrapModel hook) + // ========================================================================= + + /** + * Logs every model API call with a call counter. The counter resets per generate() invocation + * because {@code newInstance()} returns a fresh object. + */ + static class ModelLoggingMiddleware extends BaseGenerationMiddleware { + + private final AtomicInteger modelCalls = new AtomicInteger(0); + + @Override + public String name() { + return "model-logging"; + } + + @Override + public GenerationMiddleware newInstance() { + return new ModelLoggingMiddleware(); + } + + @Override + public ModelResponse wrapModel(ActionContext ctx, ModelParams params, ModelNext next) + throws GenkitException { + int callNum = modelCalls.incrementAndGet(); + logger.info("[model-logging] Model call #{}", callNum); + ModelResponse resp = next.apply(ctx, params); + logger.info( + "[model-logging] Model call #{} returned ({} chars)", + callNum, + resp.getText() != null ? resp.getText().length() : 0); + return resp; + } + } + + // ========================================================================= + // Example 2: Generate timing middleware (WrapGenerate hook) + // ========================================================================= + + /** + * Measures the wall-clock time of each generate loop iteration including model call + tool + * execution within that iteration. 
+ */ + static class GenerateTimingMiddleware extends BaseGenerationMiddleware { + + @Override + public String name() { + return "generate-timing"; + } + + @Override + public GenerationMiddleware newInstance() { + return new GenerateTimingMiddleware(); + } + + @Override + public ModelResponse wrapGenerate(ActionContext ctx, GenerateParams params, GenerateNext next) + throws GenkitException { + long start = System.currentTimeMillis(); + logger.info("[generate-timing] Starting iteration {}", params.getIteration()); + ModelResponse resp = next.apply(ctx, params); + logger.info( + "[generate-timing] Iteration {} completed in {}ms", + params.getIteration(), + System.currentTimeMillis() - start); + return resp; + } + } + + // ========================================================================= + // Example 3: Tool monitor middleware (WrapTool hook) + // ========================================================================= + + /** Logs tool execution name and duration. Stateless, so newInstance() returns {@code this}. 
+   */
+  static class ToolMonitorMiddleware extends BaseGenerationMiddleware {
+
+    @Override
+    public String name() {
+      return "tool-monitor";
+    }
+
+    /**
+     * Returns this same instance. The class declares no fields, so there is no per-call state
+     * that would need isolating.
+     */
+    @Override
+    public GenerationMiddleware newInstance() {
+      return this;
+    }
+
+    /**
+     * Logs the tool name, delegates to {@code next}, and logs how long the delegated call took.
+     *
+     * @param ctx the action context, passed through to {@code next} unchanged
+     * @param params the tool parameters; the tool name is read from {@code getRequest()}
+     * @param next the next handler in the chain
+     * @return the response produced by {@code next}, unmodified
+     * @throws GenkitException propagated as-is; nothing here catches it, so no completion line
+     *     is logged for a tool call that fails
+     */
+    @Override
+    public ToolResponse wrapTool(ActionContext ctx, ToolParams params, ToolNext next)
+        throws GenkitException {
+      String toolName = params.getRequest().getName();
+      logger.info("[tool-monitor] Executing tool: {}", toolName);
+      // System.nanoTime() is monotonic and meant for interval measurement;
+      // System.currentTimeMillis() follows the system clock and can jump while the tool runs.
+      long startNanos = System.nanoTime();
+      ToolResponse resp = next.apply(ctx, params);
+      long elapsedMs = (System.nanoTime() - startNanos) / 1_000_000L;
+      logger.info("[tool-monitor] Tool {} completed in {}ms", toolName, elapsedMs);
+      return resp;
+    }
+  }
+
+  // =========================================================================
+  // Example 4: Combined multi-hook middleware
+  // =========================================================================
+
+  /**
+   * A single middleware that implements all three hooks. Demonstrates that one middleware can
+   * observe every stage of the pipeline.
+   */
+  static class FullObservabilityMiddleware extends BaseGenerationMiddleware {
+
+    // Per-instance counters, one per hook. They are only ever incremented, never reset.
+    private final AtomicInteger iterations = new AtomicInteger(0);
+    private final AtomicInteger modelCalls = new AtomicInteger(0);
+    private final AtomicInteger toolCalls = new AtomicInteger(0);
+
+    @Override
+    public String name() {
+      return "full-observability";
+    }
+
+    /** Returns a new instance, whose three counters all start at zero. */
+    @Override
+    public GenerationMiddleware newInstance() {
+      return new FullObservabilityMiddleware();
+    }
+
+    /**
+     * Counts and logs the iteration, delegates to {@code next}, then logs the counter totals.
+     * The model/tool numbers in the closing line are this instance's running totals so far, not
+     * counts for the single iteration.
+     */
+    @Override
+    public ModelResponse wrapGenerate(ActionContext ctx, GenerateParams params, GenerateNext next)
+        throws GenkitException {
+      int iter = iterations.incrementAndGet();
+      logger.info("[observability] === Generate iteration {} ===", iter);
+      ModelResponse resp = next.apply(ctx, params);
+      logger.info(
+          "[observability] === Iteration {} done (model calls: {}, tool calls: {}) ===",
+          iter,
+          modelCalls.get(),
+          toolCalls.get());
+      return resp;
+    }
+
+    /** Counts and logs the model call, then delegates to {@code next} and returns its result. */
+    @Override
+    public ModelResponse wrapModel(ActionContext ctx, ModelParams params, ModelNext next)
+        throws GenkitException {
+      int call = modelCalls.incrementAndGet();
+      logger.info("[observability] Model call #{}", call);
+      return next.apply(ctx, params);
+    }
+
+    /**
+     * Counts and logs the tool call with the requested tool's name, then delegates to {@code
+     * next} and returns its result.
+     */
+    @Override
+    public ToolResponse wrapTool(ActionContext ctx, ToolParams params, ToolNext next)
+        throws GenkitException {
+      int call = toolCalls.incrementAndGet();
+      logger.info("[observability] Tool call #{}: {}", call, params.getRequest().getName());
+      return next.apply(ctx, params);
+    }
+  }
+
+  // =========================================================================
+  // Main
+  // =========================================================================
+
+  /**
+   * Sample entry point. Builds a Genkit instance with the OpenAI and Jetty plugins (Jetty on
+   * port 8080, reflection port 3100, dev mode on), defines a {@code getWeather} tool that
+   * returns fixed data, defines four flows ({@code v2-chat}, {@code v2-observable}, {@code
+   * v2-stacked}, {@code v2-baseline}), logs usage hints, and finally calls {@code
+   * jetty.start()}.
+   *
+   * @param args command-line arguments; not read
+   * @throws Exception declared by the signature; nothing in this method catches exceptions
+   */
+  public static void main(String[] args) throws Exception {
+    JettyPlugin jetty = new JettyPlugin(JettyPluginOptions.builder().port(8080).build());
+
+    Genkit genkit =
+        Genkit.builder()
+            .options(GenkitOptions.builder().devMode(true).reflectionPort(3100).build())
+            .plugin(OpenAIPlugin.create())
+            .plugin(jetty)
+            .build();
+
+    // Instantiate middleware (templates — newInstance() is called per generate())
+    GenerationMiddleware modelLogging = new ModelLoggingMiddleware();
+    GenerationMiddleware generateTiming = new GenerateTimingMiddleware();
+    GenerationMiddleware toolMonitor = new ToolMonitorMiddleware();
+    GenerationMiddleware fullObservability = new FullObservabilityMiddleware();
+
+    // Define a simple tool so the WrapTool hook gets exercised
+    // NOTE(review): the generic type arguments in this listing look stripped ("Tool, Map>",
+    // "(Class>) (Class)", raw "Map weather", and the raw "Flow" declarations further down).
+    // Restore them from the original file before compiling; the code is left as found here.
+    @SuppressWarnings("unchecked")
+    Tool, Map> weatherTool =
+        genkit.defineTool(
+            "getWeather",
+            "Gets the current weather for a given city",
+            Map.of(
+                "type",
+                "object",
+                "properties",
+                Map.of("city", Map.of("type", "string", "description", "The city name")),
+                "required",
+                new String[] {"city"}),
+            (Class>) (Class) Map.class,
+            (ctx, input) -> {
+              // Canned response: only the city is taken from the input, the rest is fixed.
+              String city = (String) input.get("city");
+              Map weather = new HashMap<>();
+              weather.put("city", city);
+              weather.put("temperature", "22°C");
+              weather.put("conditions", "Sunny");
+              return weather;
+            });
+
+    // The four Flow locals below are not referenced again in this method.
+
+    // =======================================================
+    // Flow 1: Simple chat with model logging + generate timing
+    // =======================================================
+
+    Flow chatFlow =
+        genkit.defineFlow(
+            "v2-chat",
+            String.class,
+            String.class,
+            (ctx, userMessage) -> {
+              ModelResponse response =
+                  genkit.generate(
+                      GenerateOptions.builder()
+                          .model("openai/gpt-4o-mini")
+                          .system("You are a helpful assistant. Be concise.")
+                          .prompt(userMessage)
+                          .use(modelLogging, generateTiming)
+                          .config(
+                              GenerationConfig.builder()
+                                  .temperature(0.7)
+                                  .maxOutputTokens(200)
+                                  .build())
+                          .build());
+              return response.getText();
+            });
+
+    // =======================================================
+    // Flow 2: Chat with all three hooks via full observability
+    // =======================================================
+
+    // This is the only flow that passes .tools(...), i.e. the only one given the weather tool.
+    Flow observableFlow =
+        genkit.defineFlow(
+            "v2-observable",
+            String.class,
+            String.class,
+            (ctx, userMessage) -> {
+              ModelResponse response =
+                  genkit.generate(
+                      GenerateOptions.builder()
+                          .model("openai/gpt-4o-mini")
+                          .system(
+                              "You are a helpful assistant. Use the getWeather tool when asked about weather.")
+                          .prompt(userMessage)
+                          .tools(List.of(weatherTool))
+                          .use(fullObservability)
+                          .config(
+                              GenerationConfig.builder()
+                                  .temperature(0.7)
+                                  .maxOutputTokens(300)
+                                  .build())
+                          .build());
+              return response.getText();
+            });
+
+    // =======================================================
+    // Flow 3: Stacking multiple middleware together
+    // =======================================================
+
+    // NOTE(review): toolMonitor is stacked here but no .tools(...) is supplied, so its wrapTool
+    // hook presumably never fires in this flow — confirm whether the weather tool was meant to
+    // be attached.
+    Flow stackedFlow =
+        genkit.defineFlow(
+            "v2-stacked",
+            String.class,
+            String.class,
+            (ctx, userMessage) -> {
+              ModelResponse response =
+                  genkit.generate(
+                      GenerateOptions.builder()
+                          .model("openai/gpt-4o-mini")
+                          .prompt(userMessage)
+                          .use(modelLogging, generateTiming, toolMonitor)
+                          .config(
+                              GenerationConfig.builder()
+                                  .temperature(0.7)
+                                  .maxOutputTokens(200)
+                                  .build())
+                          .build());
+              return response.getText();
+            });
+
+    // =======================================================
+    // Flow 4: No middleware (baseline for comparison)
+    // =======================================================
+
+    Flow baselineFlow =
+        genkit.defineFlow(
+            "v2-baseline",
+            String.class,
+            String.class,
+            (ctx, userMessage) -> {
+              ModelResponse response =
+                  genkit.generate(
+                      GenerateOptions.builder()
+                          .model("openai/gpt-4o-mini")
+                          .prompt(userMessage)
+                          .config(
+                              GenerationConfig.builder()
+                                  .temperature(0.7)
+                                  .maxOutputTokens(200)
+                                  .build())
+                          .build());
+              return response.getText();
+            });
+
+    // Usage banner. These lines are logged before jetty.start() is invoked at the end of main.
+    logger.info("\n========================================");
+    logger.info("Genkit Middleware V2 Sample Started!");
+    logger.info("========================================\n");
+
+    logger.info("Available flows:");
+    logger.info(" - v2-chat: Model logging + generate timing middleware");
+    logger.info(" - v2-observable: Full observability (all 3 hooks in one middleware)");
+    logger.info(" - v2-stacked: Three separate middleware stacked together");
+    logger.info(" - v2-baseline: No middleware (baseline comparison)\n");
+
+    logger.info("Server running on http://localhost:8080");
+    logger.info("Reflection server running on http://localhost:3100");
+    logger.info("\nExample requests:");
+    logger.info(
+        " curl -X POST http://localhost:8080/v2-chat -H 'Content-Type: application/json' -d '\"What is middleware?\"'");
+    logger.info(
+        " curl -X POST http://localhost:8080/v2-observable -H 'Content-Type: application/json' -d '\"Explain Java records\"'");
+    logger.info(
+        " curl -X POST http://localhost:8080/v2-stacked -H 'Content-Type: application/json' -d '\"Hello world\"'");
+    logger.info(
+        " curl -X POST http://localhost:8080/v2-baseline -H 'Content-Type: application/json' -d '\"Hello world\"'");
+
+    jetty.start();
+  }
+}
diff --git a/samples/middleware-v2/src/main/resources/logback.xml b/samples/middleware-v2/src/main/resources/logback.xml new file mode 100644 index 000000000..fe98c37a8 --- /dev/null +++ b/samples/middleware-v2/src/main/resources/logback.xml @@ -0,0 +1,26 @@ + + + + + + %d{HH:mm:ss.SSS} [%thread] %-5level %logger{36} - %msg%n + + + + + + + + + + + + + + + + + + + +