diff --git a/dist/pom.xml b/dist/pom.xml index 9e0c0c6c5..428f04f93 100644 --- a/dist/pom.xml +++ b/dist/pom.xml @@ -85,6 +85,11 @@ under the License. flink-agents-integrations-chat-models-bedrock ${project.version} + + org.apache.flink + flink-agents-integrations-chat-models-gemini + ${project.version} + org.apache.flink flink-agents-integrations-embedding-models-ollama diff --git a/integrations/chat-models/gemini/pom.xml b/integrations/chat-models/gemini/pom.xml new file mode 100644 index 000000000..d99420142 --- /dev/null +++ b/integrations/chat-models/gemini/pom.xml @@ -0,0 +1,48 @@ + + + + 4.0.0 + + + org.apache.flink + flink-agents-integrations-chat-models + 0.3-SNAPSHOT + ../pom.xml + + + flink-agents-integrations-chat-models-gemini + Flink Agents : Integrations: Chat Models: Google Gemini + jar + + + + org.apache.flink + flink-agents-api + ${project.version} + + + + com.google.genai + google-genai + ${google.genai.version} + + + + diff --git a/integrations/chat-models/gemini/src/main/java/org/apache/flink/agents/integrations/chatmodels/gemini/GeminiChatModelConnection.java b/integrations/chat-models/gemini/src/main/java/org/apache/flink/agents/integrations/chatmodels/gemini/GeminiChatModelConnection.java new file mode 100644 index 000000000..b0fb30df9 --- /dev/null +++ b/integrations/chat-models/gemini/src/main/java/org/apache/flink/agents/integrations/chatmodels/gemini/GeminiChatModelConnection.java @@ -0,0 +1,542 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.flink.agents.integrations.chatmodels.gemini; + +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.core.type.TypeReference; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.google.genai.Client; +import com.google.genai.types.Candidate; +import com.google.genai.types.Content; +import com.google.genai.types.FunctionCall; +import com.google.genai.types.FunctionDeclaration; +import com.google.genai.types.GenerateContentConfig; +import com.google.genai.types.GenerateContentResponse; +import com.google.genai.types.GenerateContentResponseUsageMetadata; +import com.google.genai.types.HttpOptions; +import com.google.genai.types.Part; +import com.google.genai.types.Tool; +import org.apache.flink.agents.api.chat.messages.ChatMessage; +import org.apache.flink.agents.api.chat.messages.MessageRole; +import org.apache.flink.agents.api.chat.model.BaseChatModelConnection; +import org.apache.flink.agents.api.resource.ResourceContext; +import org.apache.flink.agents.api.resource.ResourceDescriptor; +import org.apache.flink.agents.api.tools.ToolMetadata; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.ArrayList; +import java.util.Base64; +import java.util.HashMap; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.Optional; +import java.util.stream.Collectors; + +/** + * A chat model integration for the Google Gemini {@code generateContent} API using the official + * google-genai Java SDK. + * + *

The native Gemini protocol differs from the OpenAI-compatible shape in a few places this + * module handles directly: + * + *

+ * + *

Supported connection parameters: + * + *

+ * + *

Example usage: + * + *

{@code
+ * public class MyAgent extends Agent {
+ *   @ChatModelConnection
+ *   public static ResourceDesc gemini() {
+ *     return ResourceDescriptor.Builder.newBuilder(GeminiChatModelConnection.class.getName())
+ *             .addInitialArgument("api_key", System.getenv("GEMINI_API_KEY"))
+ *             .addInitialArgument("model", "gemini-3.1-pro-preview")
+ *             .build();
+ *   }
+ * }
+ * }
+ */ +public class GeminiChatModelConnection extends BaseChatModelConnection { + + private static final Logger LOG = LoggerFactory.getLogger(GeminiChatModelConnection.class); + + private static final TypeReference> MAP_TYPE = new TypeReference<>() {}; + + private final ObjectMapper mapper = new ObjectMapper(); + private final Client client; + private final String defaultModel; + + public GeminiChatModelConnection( + ResourceDescriptor descriptor, ResourceContext resourceContext) { + super(descriptor, resourceContext); + + String apiKey = descriptor.getArgument("api_key"); + String baseUrl = descriptor.getArgument("base_url"); + Boolean vertexAi = descriptor.getArgument("vertex_ai"); + + boolean useVertex = Boolean.TRUE.equals(vertexAi); + if (!useVertex + && (apiKey == null || apiKey.isBlank()) + && (baseUrl == null || baseUrl.isBlank())) { + throw new IllegalArgumentException( + "Either api_key or base_url must be provided for the Gemini connection."); + } + + Client.Builder builder = Client.builder(); + if (!useVertex) { + // The SDK requires a non-blank API key for the Gemini Developer backend. When the + // caller relies on a proxy (base_url) to inject the real credential, supply a + // placeholder so the SDK's own validation passes; the proxy overrides it on the wire. + if (apiKey != null && !apiKey.isBlank()) { + builder.apiKey(apiKey); + } else { + builder.apiKey("proxy-injected"); + } + } + + HttpOptions.Builder httpOptions = null; + if (baseUrl != null && !baseUrl.isBlank()) { + httpOptions = HttpOptions.builder().baseUrl(baseUrl); + } + Integer timeoutSeconds = descriptor.getArgument("timeout"); + if (timeoutSeconds != null && timeoutSeconds > 0) { + if (httpOptions == null) { + httpOptions = HttpOptions.builder(); + } + // HttpOptions timeout is expressed in milliseconds. Compute in long to avoid int + // overflow for large second values, then clamp to Integer.MAX_VALUE. + long timeoutMs = (long) timeoutSeconds * 1000L; + httpOptions.timeout((int) Math.min(timeoutMs, Integer.MAX_VALUE)); + } + if (httpOptions != null) { + builder.httpOptions(httpOptions.build()); + } + + if (useVertex) { + builder.vertexAI(true); + String project = descriptor.getArgument("project"); + String location = descriptor.getArgument("location"); + if (project != null && !project.isBlank()) { + builder.project(project); + } + if (location != null && !location.isBlank()) { + builder.location(location); + } + } + + this.defaultModel = descriptor.getArgument("model"); + this.client = builder.build(); + } + + @Override + public void close() { + this.client.close(); + } + + @Override + public ChatMessage chat( + List messages, + List tools, + Map arguments) { + Map args = arguments != null ? new HashMap<>(arguments) : new HashMap<>(); + + Object modelObj = args.remove("model"); + String modelName = modelObj != null ? modelObj.toString() : this.defaultModel; + if (modelName == null || modelName.isBlank()) { + modelName = this.defaultModel; + } + if (modelName == null || modelName.isBlank()) { + throw new IllegalArgumentException("model name must be provided for Gemini."); + } + + // ChatModelAction emits TOOL messages with only `externalId` in extraArgs (matching the + // sibling Anthropic/OpenAI connectors). Gemini's functionResponse part however requires the + // function name. Build a tool-call-id -> name lookup from prior ASSISTANT turns so the TOOL + // branch in convertToContent can recover the name from `externalId`. + Map toolCallIdToName = buildToolCallIdToNameMap(messages); + + try { + List contents = + messages.stream() + .filter(m -> m.getRole() != MessageRole.SYSTEM) + .map(m -> convertToContent(m, toolCallIdToName)) + .collect(Collectors.toList()); + + GenerateContentConfig config = buildConfig(messages, tools, args); + + GenerateContentResponse response = + client.models.generateContent(modelName, contents, config); + ChatMessage result = convertResponse(response); + + recordUsage(result, modelName, response); + + return result; + } catch (IllegalArgumentException e) { + // Preserve the validation-error contract: surface IAE unwrapped, consistent with the + // constructor. + throw e; + } catch (Exception e) { + throw new RuntimeException("Failed to call Gemini generateContent API.", e); + } + } + + // Package-visible for testing. Walks ASSISTANT messages and records every tool-call's + // `original_id` (or `id`) -> function `name` mapping so TOOL turns can resolve their name from + // `externalId` alone (which is what the runtime supplies). + static Map buildToolCallIdToNameMap(List messages) { + Map map = new HashMap<>(); + for (ChatMessage message : messages) { + if (message.getRole() != MessageRole.ASSISTANT) { + continue; + } + List> toolCalls = message.getToolCalls(); + if (toolCalls == null) { + continue; + } + for (Map call : toolCalls) { + Object id = call.get("original_id"); + if (id == null) { + id = call.get("id"); + } + Object function = call.get("function"); + if (id == null || !(function instanceof Map)) { + continue; + } + Object name = ((Map) function).get("name"); + if (name != null) { + map.put(id.toString(), name.toString()); + } + } + } + return map; + } + + private GenerateContentConfig buildConfig( + List messages, + List tools, + Map arguments) { + GenerateContentConfig.Builder builder = GenerateContentConfig.builder(); + + Content systemInstruction = extractSystemInstruction(messages); + if (systemInstruction != null) { + builder.systemInstruction(systemInstruction); + } + + Object temperature = arguments.remove("temperature"); + if (temperature instanceof Number) { + builder.temperature(((Number) temperature).floatValue()); + } + + Object maxOutputTokens = arguments.remove("max_output_tokens"); + if (maxOutputTokens instanceof Number) { + builder.maxOutputTokens(((Number) maxOutputTokens).intValue()); + } + + @SuppressWarnings("unchecked") + Map additionalKwargs = + (Map) arguments.remove("additional_kwargs"); + if (additionalKwargs != null) { + applyAdditionalKwargs(builder, additionalKwargs); + } + + if (tools != null && !tools.isEmpty()) { + builder.tools(List.of(convertTools(tools))); + } + + return builder.build(); + } + + // Package-visible for unit testing of the additional-kwargs forwarding. + void applyAdditionalKwargs(GenerateContentConfig.Builder builder, Map kwargs) { + for (Map.Entry entry : kwargs.entrySet()) { + String key = entry.getKey(); + Object value = entry.getValue(); + if (value == null) { + continue; + } + switch (key) { + case "top_k": + // Gemini's protocol defines topK as a float, despite the OpenAI/Anthropic + // convention of an integer. + if (value instanceof Number) { + builder.topK(((Number) value).floatValue()); + } + break; + case "top_p": + if (value instanceof Number) { + builder.topP(((Number) value).floatValue()); + } + break; + case "stop_sequences": + if (value instanceof List) { + List stopSequences = + ((List) value) + .stream() + .filter(Objects::nonNull) + .map(Object::toString) + .collect(Collectors.toList()); + builder.stopSequences(stopSequences); + } + break; + default: + // The Gemini SDK's GenerateContentConfig.Builder is AutoValue-generated and + // does not accept arbitrary body fields (unlike Anthropic/OpenAI which expose + // putAdditionalBodyProperty). Surface a warning so the user can see which key + // was dropped instead of silently mis-configuring sampling. + LOG.warn( + "Ignoring additional_kwargs.{}: not recognized by the Gemini connector" + + " (supported keys: top_k, top_p, stop_sequences).", + key); + break; + } + } + } + + private Tool convertTools(List tools) { + List declarations = new ArrayList<>(tools.size()); + for (org.apache.flink.agents.api.tools.Tool tool : tools) { + ToolMetadata metadata = tool.getMetadata(); + FunctionDeclaration.Builder builder = + FunctionDeclaration.builder() + .name(metadata.getName()) + .description(metadata.getDescription()); + + String schema = metadata.getInputSchema(); + if (schema != null && !schema.isBlank()) { + builder.parametersJsonSchema(parseSchema(schema)); + } + + declarations.add(builder.build()); + } + return Tool.builder().functionDeclarations(declarations).build(); + } + + private Content extractSystemInstruction(List messages) { + Part[] parts = + messages.stream() + .filter(m -> m.getRole() == MessageRole.SYSTEM) + .map(m -> Part.fromText(Optional.ofNullable(m.getContent()).orElse(""))) + .toArray(Part[]::new); + return parts.length == 0 ? null : Content.fromParts(parts); + } + + // Package-visible for unit testing of the message conversion. + Content convertToContent(ChatMessage message, Map toolCallIdToName) { + MessageRole role = message.getRole(); + String content = Optional.ofNullable(message.getContent()).orElse(""); + + switch (role) { + case USER: + return Content.builder() + .role("user") + .parts(List.of(Part.fromText(content))) + .build(); + + case ASSISTANT: + List parts = new ArrayList<>(); + if (!content.isEmpty()) { + parts.add(Part.fromText(content)); + } + List> toolCalls = message.getToolCalls(); + if (toolCalls != null) { + for (Map call : toolCalls) { + parts.add(convertToolCallToPart(call)); + } + } + if (parts.isEmpty()) { + parts.add(Part.fromText("")); + } + return Content.builder().role("model").parts(parts).build(); + + case TOOL: + String functionName = resolveToolFunctionName(message, toolCallIdToName); + Map responseMap = new LinkedHashMap<>(); + responseMap.put("result", content); + return Content.builder() + .role("user") + .parts(List.of(Part.fromFunctionResponse(functionName, responseMap))) + .build(); + + default: + throw new IllegalArgumentException("Unsupported role: " + role); + } + } + + private static String resolveToolFunctionName( + ChatMessage toolMessage, Map toolCallIdToName) { + // 1. Honor an explicit `name` if the caller supplied one. + Object explicit = toolMessage.getExtraArgs().get("name"); + if (explicit != null) { + return explicit.toString(); + } + // 2. Otherwise look up the function name via the tool-call id the runtime supplies as + // `externalId` (set equal to the assistant turn's `original_id` by ToolCallAction). + Object externalId = toolMessage.getExtraArgs().get("externalId"); + if (externalId != null && toolCallIdToName != null) { + String name = toolCallIdToName.get(externalId.toString()); + if (name != null) { + return name; + } + } + throw new IllegalArgumentException( + "Tool message must carry the function name: provide either 'name' in extraArgs, or" + + " an 'externalId' matching a prior ASSISTANT tool-call's id."); + } + + // Package-visible for unit testing of the tool-call round-trip. + Part convertToolCallToPart(Map call) { + Map functionPayload = toMap(call.get("function")); + String functionName = String.valueOf(functionPayload.get("name")); + Map argsMap = toMap(functionPayload.get("arguments")); + + FunctionCall.Builder fcBuilder = FunctionCall.builder().name(functionName).args(argsMap); + Object originalId = call.get("original_id"); + if (originalId != null) { + fcBuilder.id(originalId.toString()); + } + + Part.Builder partBuilder = Part.builder().functionCall(fcBuilder.build()); + // Echo back the thoughtSignature captured from the model response (Gemini 3 requirement). + Object signature = call.get("thought_signature"); + if (signature != null) { + partBuilder.thoughtSignature(Base64.getDecoder().decode(signature.toString())); + } + return partBuilder.build(); + } + + private Object parseSchema(String schemaJson) { + try { + return mapper.readValue(schemaJson, MAP_TYPE); + } catch (JsonProcessingException e) { + throw new RuntimeException("Failed to parse tool schema JSON.", e); + } + } + + private ChatMessage convertResponse(GenerateContentResponse response) { + // Walk the first candidate's parts directly (rather than the response.text()/ + // functionCalls() conveniences) so we can capture the part-level thoughtSignature that + // Gemini 3 emits alongside each functionCall and requires to be echoed back on the next + // turn. + StringBuilder textContent = new StringBuilder(); + List> toolCalls = new ArrayList<>(); + + List candidates = response.candidates().orElseGet(List::of); + if (candidates.isEmpty()) { + throw new IllegalStateException( + "Gemini response did not contain any candidates (likely safety-blocked or" + + " filtered)."); + } + // Let the SDK validate the finish reason: this raises IllegalArgumentException when the + // model finished for an unexpected reason (SAFETY, MAX_TOKENS, RECITATION, …) instead of + // silently returning a truncated or filtered message. The IAE is propagated unwrapped by + // chat()'s catch block, matching the constructor's error contract. + response.checkFinishReason(); + + List parts = candidates.get(0).content().flatMap(Content::parts).orElseGet(List::of); + + for (Part part : parts) { + part.text().ifPresent(textContent::append); + part.functionCall() + .ifPresent( + fc -> + toolCalls.add( + convertFunctionCall( + fc, part.thoughtSignature().orElse(null)))); + } + + ChatMessage chatMessage = ChatMessage.assistant(textContent.toString()); + if (!toolCalls.isEmpty()) { + chatMessage.setToolCalls(toolCalls); + } + return chatMessage; + } + + // Package-visible for unit testing of the function-call parsing. + Map convertFunctionCall(FunctionCall functionCall, byte[] thoughtSignature) { + String id = functionCall.id().orElse(null); + String name = functionCall.name().orElse(""); + Map argsMap = functionCall.args().orElseGet(LinkedHashMap::new); + + Map functionMap = new LinkedHashMap<>(); + functionMap.put("name", name); + functionMap.put("arguments", argsMap); + + Map toolCall = new LinkedHashMap<>(); + if (id != null) { + toolCall.put("id", id); + toolCall.put("original_id", id); + } + toolCall.put("type", "function"); + toolCall.put("function", functionMap); + // Gemini 3 requires the opaque thoughtSignature to be echoed back when the tool-call turn + // is replayed. Stash it as Base64 so it survives the Map representation. + if (thoughtSignature != null) { + toolCall.put("thought_signature", Base64.getEncoder().encodeToString(thoughtSignature)); + } + return toolCall; + } + + private void recordUsage( + ChatMessage result, String modelName, GenerateContentResponse response) { + GenerateContentResponseUsageMetadata usage = response.usageMetadata().orElse(null); + if (usage == null) { + return; + } + long promptTokens = usage.promptTokenCount().orElse(0); + long completionTokens = usage.candidatesTokenCount().orElse(0); + result.getExtraArgs().put("model_name", modelName); + result.getExtraArgs().put("promptTokens", promptTokens); + result.getExtraArgs().put("completionTokens", completionTokens); + } + + private Map toMap(Object value) { + if (value instanceof Map) { + @SuppressWarnings("unchecked") + Map casted = (Map) value; + return new LinkedHashMap<>(casted); + } + if (value == null) { + return new LinkedHashMap<>(); + } + return mapper.convertValue(value, MAP_TYPE); + } +} diff --git a/integrations/chat-models/gemini/src/main/java/org/apache/flink/agents/integrations/chatmodels/gemini/GeminiChatModelSetup.java b/integrations/chat-models/gemini/src/main/java/org/apache/flink/agents/integrations/chatmodels/gemini/GeminiChatModelSetup.java new file mode 100644 index 000000000..0cb843949 --- /dev/null +++ b/integrations/chat-models/gemini/src/main/java/org/apache/flink/agents/integrations/chatmodels/gemini/GeminiChatModelSetup.java @@ -0,0 +1,150 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.flink.agents.integrations.chatmodels.gemini; + +import org.apache.flink.agents.api.chat.model.BaseChatModelSetup; +import org.apache.flink.agents.api.resource.ResourceContext; +import org.apache.flink.agents.api.resource.ResourceDescriptor; + +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Optional; + +/** + * Chat model setup for the Google Gemini {@code generateContent} API. + * + *

Responsible for providing per-chat configuration such as model, temperature, max output + * tokens, tool bindings, and additional Gemini parameters. The setup delegates execution to {@link + * GeminiChatModelConnection}. + * + *

Supported parameters: + * + *

    + *
  • connection (required): Name of the GeminiChatModelConnection resource + *
  • model (optional): Model name (default: gemini-3.1-pro-preview) + *
  • temperature (optional): Sampling temperature 0.0-2.0 (default: 0.1) + *
  • max_output_tokens (optional): Maximum tokens in response (default: 1024) + *
  • tools (optional): List of tool names available for the model to use + *
  • additional_kwargs (optional): Additional parameters (e.g. top_k, top_p) + *
+ * + *

Example usage: + * + *

{@code
+ * public class MyAgent extends Agent {
+ *   @ChatModelSetup
+ *   public static ResourceDesc gemini() {
+ *     return ResourceDescriptor.Builder.newBuilder(GeminiChatModelSetup.class.getName())
+ *             .addInitialArgument("connection", "myGeminiConnection")
+ *             .addInitialArgument("model", "gemini-3.1-pro-preview")
+ *             .addInitialArgument("temperature", 0.3d)
+ *             .addInitialArgument("max_output_tokens", 2048)
+ *             .addInitialArgument("tools", List.of("getWeather"))
+ *             .build();
+ *   }
+ * }
+ * }
+ */ +public class GeminiChatModelSetup extends BaseChatModelSetup { + + private static final String DEFAULT_MODEL = "gemini-3.1-pro-preview"; + private static final double DEFAULT_TEMPERATURE = 0.1d; + private static final long DEFAULT_MAX_OUTPUT_TOKENS = 1024L; + + private final Double temperature; + private final Long maxOutputTokens; + private final Map additionalArguments; + + public GeminiChatModelSetup(ResourceDescriptor descriptor, ResourceContext resourceContext) { + super(descriptor, resourceContext); + this.temperature = + Optional.ofNullable(descriptor.getArgument("temperature")) + .map(Number::doubleValue) + .orElse(DEFAULT_TEMPERATURE); + if (this.temperature < 0.0 || this.temperature > 2.0) { + throw new IllegalArgumentException("temperature must be between 0.0 and 2.0"); + } + + this.maxOutputTokens = + Optional.ofNullable(descriptor.getArgument("max_output_tokens")) + .map(Number::longValue) + .orElse(DEFAULT_MAX_OUTPUT_TOKENS); + if (this.maxOutputTokens <= 0) { + throw new IllegalArgumentException("max_output_tokens must be greater than 0"); + } + + this.additionalArguments = + Optional.ofNullable( + descriptor.>getArgument("additional_kwargs")) + .map(HashMap::new) + .orElseGet(HashMap::new); + + if (this.model == null || this.model.isBlank()) { + this.model = DEFAULT_MODEL; + } + } + + public GeminiChatModelSetup( + String model, + double temperature, + long maxOutputTokens, + Map additionalArguments, + List tools, + ResourceContext resourceContext) { + this( + createDescriptor(model, temperature, maxOutputTokens, additionalArguments, tools), + resourceContext); + } + + @Override + public Map getParameters() { + Map parameters = new HashMap<>(); + if (model != null) { + parameters.put("model", model); + } + parameters.put("temperature", temperature); + parameters.put("max_output_tokens", maxOutputTokens); + if (additionalArguments != null && !additionalArguments.isEmpty()) { + parameters.put("additional_kwargs", additionalArguments); + } + return parameters; + } + + private static ResourceDescriptor createDescriptor( + String model, + double temperature, + long maxOutputTokens, + Map additionalArguments, + List tools) { + ResourceDescriptor.Builder builder = + ResourceDescriptor.Builder.newBuilder(GeminiChatModelSetup.class.getName()) + .addInitialArgument("model", model) + .addInitialArgument("temperature", temperature) + .addInitialArgument("max_output_tokens", maxOutputTokens); + + if (additionalArguments != null && !additionalArguments.isEmpty()) { + builder.addInitialArgument("additional_kwargs", additionalArguments); + } + if (tools != null && !tools.isEmpty()) { + builder.addInitialArgument("tools", tools); + } + + return builder.build(); + } +} diff --git a/integrations/chat-models/gemini/src/test/java/org/apache/flink/agents/integrations/chatmodels/gemini/GeminiChatModelConnectionTest.java b/integrations/chat-models/gemini/src/test/java/org/apache/flink/agents/integrations/chatmodels/gemini/GeminiChatModelConnectionTest.java new file mode 100644 index 000000000..abc616f62 --- /dev/null +++ b/integrations/chat-models/gemini/src/test/java/org/apache/flink/agents/integrations/chatmodels/gemini/GeminiChatModelConnectionTest.java @@ -0,0 +1,346 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.agents.integrations.chatmodels.gemini; + +import com.google.genai.types.Content; +import com.google.genai.types.FunctionCall; +import com.google.genai.types.GenerateContentConfig; +import com.google.genai.types.Part; +import org.apache.flink.agents.api.chat.messages.ChatMessage; +import org.apache.flink.agents.api.chat.messages.MessageRole; +import org.apache.flink.agents.api.chat.model.BaseChatModelConnection; +import org.apache.flink.agents.api.resource.ResourceContext; +import org.apache.flink.agents.api.resource.ResourceDescriptor; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.Test; + +import java.util.Base64; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +/** + * Unit tests for {@link GeminiChatModelConnection}. These exercise the protocol-conversion logic + * with no network access, so they run in CI without any API key. + */ +class GeminiChatModelConnectionTest { + + private static final ResourceContext NOOP = ResourceContext.fromGetResource((a, b) -> null); + + private static ResourceDescriptor descriptor(String apiKey, String baseUrl, String model) { + ResourceDescriptor.Builder b = + ResourceDescriptor.Builder.newBuilder(GeminiChatModelConnection.class.getName()); + if (apiKey != null) { + b.addInitialArgument("api_key", apiKey); + } + if (baseUrl != null) { + b.addInitialArgument("base_url", baseUrl); + } + if (model != null) { + b.addInitialArgument("model", model); + } + return b.build(); + } + + private static GeminiChatModelConnection connection() { + return new GeminiChatModelConnection( + descriptor("test-key", null, "gemini-3-pro-preview"), NOOP); + } + + @Test + @DisplayName("Constructor with api_key creates a connection") + void testConstructorWithApiKey() { + GeminiChatModelConnection conn = connection(); + assertThat(conn).isInstanceOf(BaseChatModelConnection.class); + } + + @Test + @DisplayName("Constructor with base_url (proxy) creates a connection without api_key") + void testConstructorWithBaseUrl() { + GeminiChatModelConnection conn = + new GeminiChatModelConnection( + descriptor(null, "http://127.0.0.1:15799", "gemini-3-pro-preview"), NOOP); + assertThat(conn).isInstanceOf(BaseChatModelConnection.class); + } + + @Test + @DisplayName("Constructor throws when neither api_key nor base_url is provided") + void testConstructorThrowsWithoutCredentials() { + assertThatThrownBy( + () -> + new GeminiChatModelConnection( + descriptor(null, null, "gemini-3-pro-preview"), NOOP)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("api_key or base_url"); + } + + @Test + @DisplayName( + "Vertex AI path is wired but not e2e-tested in CI. We only assert here that " + + "vertex_ai=true does NOT silently fall through to the Developer-API " + + "construction success path; either it succeeds with ADC, or it surfaces a " + + "credentials / configuration error. A real Vertex run is a follow-up.") + void testConstructorVertexAiIsWired() { + ResourceDescriptor desc = + ResourceDescriptor.Builder.newBuilder(GeminiChatModelConnection.class.getName()) + .addInitialArgument("vertex_ai", true) + .addInitialArgument("project", "test-project-does-not-exist") + .addInitialArgument("location", "us-central1") + .addInitialArgument("model", "gemini-3-pro-preview") + .build(); + // Two acceptable outcomes: + // 1. CI/dev box without ADC -> the SDK throws while resolving credentials. + // 2. A machine with ADC configured -> construction succeeds. We close the client to + // release resources. + // What must NOT happen: vertex_ai is silently ignored and the Developer-API path is taken, + // which would mean the Vertex flag is dead code. + try { + GeminiChatModelConnection conn = new GeminiChatModelConnection(desc, NOOP); + // Reached only when ADC is configured locally. Smoke-checked the build path. + assertThat(conn).isInstanceOf(BaseChatModelConnection.class); + conn.close(); + } catch (RuntimeException e) { + // ADC missing: the SDK surfaces a credentials error. The exact message is SDK-internal; + // the important assertion is that an error was raised, not silent fallthrough. + assertThat(e).isNotNull(); + } + } + + @Test + @DisplayName("convertToContent maps USER role to a Gemini user turn") + void testConvertUserMessage() { + Content content = + connection().convertToContent(ChatMessage.user("hello"), Collections.emptyMap()); + assertThat(content.role()).hasValue("user"); + assertThat(content.parts().orElseThrow().get(0).text()).hasValue("hello"); + } + + @Test + @DisplayName("convertToContent maps ASSISTANT role to a Gemini model turn") + void testConvertAssistantMessage() { + Content content = + connection() + .convertToContent( + ChatMessage.assistant("hi there"), Collections.emptyMap()); + assertThat(content.role()).hasValue("model"); + assertThat(content.parts().orElseThrow().get(0).text()).hasValue("hi there"); + } + + @Test + @DisplayName("convertToContent uses explicit `name` in extraArgs when supplied") + void testConvertToolMessageWithExplicitName() { + ChatMessage tool = ChatMessage.tool("sunny, 22C"); + tool.getExtraArgs().put("name", "get_weather"); + + Content content = connection().convertToContent(tool, Collections.emptyMap()); + assertThat(content.role()).hasValue("user"); + Part part = content.parts().orElseThrow().get(0); + assertThat(part.functionResponse()).isPresent(); + assertThat(part.functionResponse().orElseThrow().name()).hasValue("get_weather"); + } + + @Test + @DisplayName( + "convertToContent resolves the function name from `externalId` when the runtime omits " + + "`name` (matches ChatModelAction's emission shape)") + void testRuntimeShapeToolMessageResolvesNameFromExternalId() { + // Runtime contract: ChatModelAction emits TOOL messages with only `externalId` in + // extraArgs, matching how Anthropic/OpenAI siblings work. The name must be recovered from + // the prior ASSISTANT turn's tool-call map. + ChatMessage tool = ChatMessage.tool("sunny, 22C"); + tool.getExtraArgs().put("externalId", "call_abc"); + + Map idToName = Map.of("call_abc", "get_weather"); + + Content content = connection().convertToContent(tool, idToName); + assertThat(content.role()).hasValue("user"); + Part part = content.parts().orElseThrow().get(0); + assertThat(part.functionResponse()).isPresent(); + assertThat(part.functionResponse().orElseThrow().name()).hasValue("get_weather"); + } + + @Test + @DisplayName( + "convertToContent throws only when the function name truly cannot be resolved (no " + + "`name`, no matching `externalId`)") + void testConvertToolMessageThrowsWhenUnresolvable() { + ChatMessage tool = ChatMessage.tool("result"); + tool.getExtraArgs().put("externalId", "call_unknown"); + + assertThatThrownBy(() -> connection().convertToContent(tool, Collections.emptyMap())) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("function name"); + } + + @Test + @DisplayName("convertFunctionCall captures name, args, id and Base64 thoughtSignature") + void testConvertFunctionCall() { + FunctionCall fc = + FunctionCall.builder() + .id("call_1") + .name("get_weather") + .args(Map.of("city", "Tokyo")) + .build(); + byte[] signature = new byte[] {1, 2, 3, 4}; + + Map toolCall = connection().convertFunctionCall(fc, signature); + + assertThat(toolCall).containsEntry("id", "call_1").containsEntry("original_id", "call_1"); + assertThat(toolCall).containsEntry("type", "function"); + @SuppressWarnings("unchecked") + Map function = (Map) toolCall.get("function"); + assertThat(function).containsEntry("name", "get_weather"); + assertThat(function.get("arguments")).isEqualTo(Map.of("city", "Tokyo")); + assertThat(toolCall.get("thought_signature")) + .isEqualTo(Base64.getEncoder().encodeToString(signature)); + } + + @Test + @DisplayName("convertFunctionCall omits thought_signature when absent") + void testConvertFunctionCallNoSignature() { + FunctionCall fc = FunctionCall.builder().name("noop").args(Map.of()).build(); + Map toolCall = connection().convertFunctionCall(fc, null); + assertThat(toolCall).doesNotContainKey("thought_signature"); + } + + @Test + @DisplayName("Tool-call round-trip preserves name, args and thoughtSignature") + void testToolCallRoundTrip() { + byte[] signature = new byte[] {9, 8, 7}; + FunctionCall fc = + FunctionCall.builder() + .id("c1") + .name("get_weather") + .args(Map.of("city", "Osaka")) + .build(); + + GeminiChatModelConnection conn = connection(); + Map toolCall = conn.convertFunctionCall(fc, signature); + Part part = conn.convertToolCallToPart(toolCall); + + assertThat(part.functionCall()).isPresent(); + FunctionCall rebuilt = part.functionCall().orElseThrow(); + assertThat(rebuilt.name()).hasValue("get_weather"); + assertThat(rebuilt.args().orElseThrow()).containsEntry("city", "Osaka"); + assertThat(part.thoughtSignature()).isPresent(); + assertThat(part.thoughtSignature().orElseThrow()).isEqualTo(signature); + } + + @Test + @DisplayName("convertToContent embeds tool calls into the assistant model turn") + void testAssistantWithToolCalls() { + FunctionCall fc = + FunctionCall.builder() + .id("c2") + .name("get_weather") + .args(Map.of("city", "Kyoto")) + .build(); + Map toolCall = connection().convertFunctionCall(fc, null); + ChatMessage assistant = ChatMessage.assistant("", List.of(toolCall)); + + Content content = connection().convertToContent(assistant, Collections.emptyMap()); + assertThat(content.role()).hasValue("model"); + assertThat(content.parts().orElseThrow()) + .anySatisfy(p -> assertThat(p.functionCall()).isPresent()); + } + + @Test + @DisplayName( + "buildToolCallIdToNameMap mirrors what ChatModelAction emits: ASSISTANT turn carries " + + "tool-call map, follow-up TOOL turn carries only externalId") + void testRuntimeShapeMultiTurn() { + // Step 1: simulate the assistant's tool-call turn produced by convertFunctionCall. + FunctionCall fc = + FunctionCall.builder() + .id("call_xyz") + .name("get_weather") + .args(Map.of("city", "Tokyo")) + .build(); + Map toolCall = connection().convertFunctionCall(fc, null); + ChatMessage assistantTurn = ChatMessage.assistant("", List.of(toolCall)); + + // Step 2: the runtime emits a TOOL message with only externalId (no name). + Map toolExtras = new HashMap<>(); + toolExtras.put("externalId", "call_xyz"); + ChatMessage toolTurn = new ChatMessage(MessageRole.TOOL, "sunny, 22C", toolExtras); + + List conversation = + List.of(ChatMessage.user("weather in Tokyo?"), assistantTurn, toolTurn); + + Map idToName = + GeminiChatModelConnection.buildToolCallIdToNameMap(conversation); + assertThat(idToName).containsEntry("call_xyz", "get_weather"); + + // Round-trip: TOOL message converts to a functionResponse with the recovered name. + Content content = connection().convertToContent(toolTurn, idToName); + assertThat(content.parts().orElseThrow().get(0).functionResponse().orElseThrow().name()) + .hasValue("get_weather"); + } + + @Test + @DisplayName( + "applyAdditionalKwargs forwards top_k, top_p and stop_sequences onto the " + + "GenerateContentConfig (mirrors Anthropic's `additional_kwargs` path)") + void testApplyAdditionalKwargs() { + GenerateContentConfig.Builder builder = GenerateContentConfig.builder(); + Map kwargs = + Map.of("top_k", 40, "top_p", 0.9, "stop_sequences", List.of("END", "STOP")); + + connection().applyAdditionalKwargs(builder, kwargs); + + GenerateContentConfig config = builder.build(); + assertThat(config.topK()).hasValue(40f); + assertThat(config.topP()).hasValue(0.9f); + assertThat(config.stopSequences().orElseThrow()).containsExactly("END", "STOP"); + } + + @Test + @DisplayName("applyAdditionalKwargs ignores unknown keys without throwing (logs a warning)") + void testApplyAdditionalKwargsIgnoresUnknown() { + GenerateContentConfig.Builder builder = GenerateContentConfig.builder(); + connection().applyAdditionalKwargs(builder, Map.of("not_a_real_param", "x")); + GenerateContentConfig config = builder.build(); + assertThat(config).isNotNull(); + // Unknown key must not leak into a known field. + assertThat(config.topK()).isEmpty(); + assertThat(config.topP()).isEmpty(); + } + + @Test + @DisplayName( + "applyAdditionalKwargs ignores known keys with the wrong value type without throwing " + + "(e.g. top_k as a String) — must not silently set a wrong value either") + void testApplyAdditionalKwargsIgnoresTypeMismatch() { + GenerateContentConfig.Builder builder = GenerateContentConfig.builder(); + connection() + .applyAdditionalKwargs( + builder, + Map.of( + "top_k", "fast", // wrong type + "stop_sequences", "STOP" // wrong type (should be List) + )); + GenerateContentConfig config = builder.build(); + assertThat(config.topK()).isEmpty(); + assertThat(config.stopSequences()).isEmpty(); + } +} diff --git a/integrations/chat-models/gemini/src/test/java/org/apache/flink/agents/integrations/chatmodels/gemini/GeminiChatModelSetupTest.java b/integrations/chat-models/gemini/src/test/java/org/apache/flink/agents/integrations/chatmodels/gemini/GeminiChatModelSetupTest.java new file mode 100644 index 000000000..3ef50ed4b --- /dev/null +++ b/integrations/chat-models/gemini/src/test/java/org/apache/flink/agents/integrations/chatmodels/gemini/GeminiChatModelSetupTest.java @@ -0,0 +1,99 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.agents.integrations.chatmodels.gemini; + +import org.apache.flink.agents.api.chat.model.BaseChatModelSetup; +import org.apache.flink.agents.api.resource.ResourceContext; +import org.apache.flink.agents.api.resource.ResourceDescriptor; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.Test; + +import java.util.Map; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +/** Tests for {@link GeminiChatModelSetup}. */ +class GeminiChatModelSetupTest { + + private static final ResourceContext NOOP = ResourceContext.fromGetResource((a, b) -> null); + + private static ResourceDescriptor.Builder base() { + return ResourceDescriptor.Builder.newBuilder(GeminiChatModelSetup.class.getName()) + .addInitialArgument("connection", "conn"); + } + + @Test + @DisplayName("getParameters applies default model, temperature and max output tokens") + void testGetParametersDefaults() { + GeminiChatModelSetup setup = new GeminiChatModelSetup(base().build(), NOOP); + + Map params = setup.getParameters(); + assertThat(params).containsEntry("model", "gemini-3.1-pro-preview"); + assertThat(params).containsEntry("temperature", 0.1); + assertThat(params).containsEntry("max_output_tokens", 1024L); + } + + @Test + @DisplayName("getParameters honors custom model, temperature and max output tokens") + void testGetParametersCustom() { + ResourceDescriptor desc = + base().addInitialArgument("model", "gemini-3-flash-preview") + .addInitialArgument("temperature", 0.7) + .addInitialArgument("max_output_tokens", 4096) + .build(); + GeminiChatModelSetup setup = new GeminiChatModelSetup(desc, NOOP); + + Map params = setup.getParameters(); + assertThat(params).containsEntry("model", "gemini-3-flash-preview"); + assertThat(params).containsEntry("temperature", 0.7); + assertThat(params).containsEntry("max_output_tokens", 4096L); + } + + @Test + @DisplayName("Constructor rejects out-of-range temperature") + void testInvalidTemperature() { + assertThatThrownBy( + () -> + new GeminiChatModelSetup( + base().addInitialArgument("temperature", 2.5).build(), + NOOP)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("temperature"); + } + + @Test + @DisplayName("Constructor rejects non-positive max output tokens") + void testInvalidMaxOutputTokens() { + assertThatThrownBy( + () -> + new GeminiChatModelSetup( + base().addInitialArgument("max_output_tokens", 0).build(), + NOOP)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("max_output_tokens"); + } + + @Test + @DisplayName("Extends BaseChatModelSetup") + void testInheritance() { + assertThat(new GeminiChatModelSetup(base().build(), NOOP)) + .isInstanceOf(BaseChatModelSetup.class); + } +} diff --git a/integrations/chat-models/pom.xml b/integrations/chat-models/pom.xml index e5f4b9d4f..ef8064528 100644 --- a/integrations/chat-models/pom.xml +++ b/integrations/chat-models/pom.xml @@ -34,6 +34,7 @@ under the License. anthropic azureai bedrock + gemini ollama openai diff --git a/integrations/pom.xml b/integrations/pom.xml index 754048813..25b73bed1 100644 --- a/integrations/pom.xml +++ b/integrations/pom.xml @@ -37,6 +37,7 @@ under the License. 4.8.0 2.11.1 2.32.16 + 1.56.0