From 1c0a033f5bad1ead9bbd016947ea9e2e0aabf99d Mon Sep 17 00:00:00 2001 From: NathanGrand Date: Tue, 2 Dec 2025 14:24:05 +0000 Subject: [PATCH] GH-5007: Fix handling when response contains both text and function calls Signed-off-by: NathanGrand --- .../ai/google/genai/GoogleGenAiChatModel.java | 58 +++++++------------ .../google/genai/GoogleGenAiChatModelIT.java | 50 ++++++++++++++++ .../GoogleGenAiCachedContentServiceTests.java | 2 +- 3 files changed, 73 insertions(+), 37 deletions(-) diff --git a/models/spring-ai-google-genai/src/main/java/org/springframework/ai/google/genai/GoogleGenAiChatModel.java b/models/spring-ai-google-genai/src/main/java/org/springframework/ai/google/genai/GoogleGenAiChatModel.java index be9da54ccc2..2706123076f 100644 --- a/models/spring-ai-google-genai/src/main/java/org/springframework/ai/google/genai/GoogleGenAiChatModel.java +++ b/models/spring-ai-google-genai/src/main/java/org/springframework/ai/google/genai/GoogleGenAiChatModel.java @@ -22,6 +22,7 @@ import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.stream.Collectors; import com.fasterxml.jackson.annotation.JsonInclude; import com.fasterxml.jackson.annotation.JsonInclude.Include; @@ -647,45 +648,30 @@ protected List responseCandidateToGeneration(Candidate candidate) { .finishReason(candidateFinishReason.toString()) .build(); - boolean isFunctionCall = candidate.content().isPresent() && candidate.content().get().parts().isPresent() - && candidate.content().get().parts().get().stream().allMatch(part -> part.functionCall().isPresent()); + List parts = candidate.content().get().parts().orElse(List.of()); - if (isFunctionCall) { - List assistantToolCalls = candidate.content() - .get() - .parts() - .orElse(List.of()) - .stream() - .filter(part -> part.functionCall().isPresent()) - .map(part -> { - FunctionCall functionCall = part.functionCall().get(); - var functionName = functionCall.name().orElse(""); - String functionArguments = mapToJson(functionCall.args().orElse(Map.of())); - return new AssistantMessage.ToolCall("", "function", functionName, functionArguments); - }) - .toList(); + List assistantToolCalls = parts.stream() + .filter(part -> part.functionCall().isPresent()) + .map(part -> { + FunctionCall functionCall = part.functionCall().get(); + var functionName = functionCall.name().orElse(""); + String functionArguments = mapToJson(functionCall.args().orElse(Map.of())); + return new AssistantMessage.ToolCall("", "function", functionName, functionArguments); + }) + .toList(); - AssistantMessage assistantMessage = AssistantMessage.builder() - .content("") - .properties(messageMetadata) - .toolCalls(assistantToolCalls) - .build(); + String text = parts.stream() + .filter(part -> part.text().isPresent() && !part.text().get().isEmpty()) + .map(part -> part.text().get()) + .collect(Collectors.joining(" ")); - return List.of(new Generation(assistantMessage, chatGenerationMetadata)); - } - else { - return candidate.content() - .get() - .parts() - .orElse(List.of()) - .stream() - .map(part -> AssistantMessage.builder() - .content(part.text().orElse("")) - .properties(messageMetadata) - .build()) - .map(assistantMessage -> new Generation(assistantMessage, chatGenerationMetadata)) - .toList(); - } + AssistantMessage assistantMessage = AssistantMessage.builder() + .content(text) + .properties(messageMetadata) + .toolCalls(assistantToolCalls) + .build(); + + return List.of(new Generation(assistantMessage, chatGenerationMetadata)); } private ChatResponseMetadata toChatResponseMetadata(Usage usage, String modelVersion) { diff --git a/models/spring-ai-google-genai/src/test/java/org/springframework/ai/google/genai/GoogleGenAiChatModelIT.java b/models/spring-ai-google-genai/src/test/java/org/springframework/ai/google/genai/GoogleGenAiChatModelIT.java index d09d79d8078..bcb8202f7a6 100644 --- a/models/spring-ai-google-genai/src/test/java/org/springframework/ai/google/genai/GoogleGenAiChatModelIT.java +++ b/models/spring-ai-google-genai/src/test/java/org/springframework/ai/google/genai/GoogleGenAiChatModelIT.java @@ -388,6 +388,41 @@ void jsonTextToolCallingTest() { assertThat(response).contains("2025-05-08T10:10:10+02:00"); } + /** + * See https://github.com/spring-projects/spring-ai/pull/4599 + */ + @Test + void testMixedPartsMessages() { + + ToolCallingManager toolCallingManager = ToolCallingManager.builder() + .observationRegistry(ObservationRegistry.NOOP) + .build(); + + GoogleGenAiChatModel chatModelWithTools = GoogleGenAiChatModel.builder() + .genAiClient(genAiClient()) + .toolCallingManager(toolCallingManager) + .defaultOptions(GoogleGenAiChatOptions.builder() + .model(GoogleGenAiChatModel.ChatModel.GEMINI_2_5_FLASH) + .temperature(0.0) + .build()) + .build(); + + ChatClient chatClient = ChatClient.builder(chatModelWithTools).build(); + + // Create a prompt that will encourage gemini to explain why it is calling tools + // as it does. + AlarmTools alarmTools = new AlarmTools(); + String response = chatClient.prompt() + .tools(new CurrentTimeTools(), alarmTools) + .system("You MUST include reasoning when you issue tool calls.") + .user("Set an alarm for an hour from now, and tell me what time that was for.") + .call() + .content(); + + assertThat(response).isEqualTo("I have set an alarm for 11:10 AM."); + assertThat(alarmTools.getAlarm()).isEqualTo("2025-05-08T11:10:10+02:00"); + } + @Test void testThinkingBudgetGeminiProAutomaticDecisionByModel() { GoogleGenAiChatModel chatModelWithThinkingBudget = GoogleGenAiChatModel.builder() @@ -516,6 +551,21 @@ String getCurrentDateTime() { } + public static class AlarmTools { + + private String alarm; + + @Tool(description = "Set a user alarm for the given time, provided in ISO-8601 format") + void setAlarm(String time) { + this.alarm = time; + } + + public String getAlarm() { + return this.alarm; + } + + } + record ActorsFilmsRecord(String actor, List movies) { } diff --git a/models/spring-ai-google-genai/src/test/java/org/springframework/ai/google/genai/cache/GoogleGenAiCachedContentServiceTests.java b/models/spring-ai-google-genai/src/test/java/org/springframework/ai/google/genai/cache/GoogleGenAiCachedContentServiceTests.java index d9f8e32a9e1..83d7e6ef056 100644 --- a/models/spring-ai-google-genai/src/test/java/org/springframework/ai/google/genai/cache/GoogleGenAiCachedContentServiceTests.java +++ b/models/spring-ai-google-genai/src/test/java/org/springframework/ai/google/genai/cache/GoogleGenAiCachedContentServiceTests.java @@ -152,7 +152,7 @@ void testUpdateCachedContent() { assertThat(updated.getName()).isEqualTo(name); assertThat(updated.getTtl()).isEqualTo(newTtl); assertThat(updated.getUpdateTime()).isNotNull(); - assertThat(updated.getUpdateTime()).isAfter(created.getCreateTime()); + assertThat(updated.getUpdateTime()).isAfterOrEqualTo(created.getCreateTime()); } @Test