Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

package org.springframework.ai.google.genai;

import com.google.genai.types.UrlContext;
import java.net.URI;
import java.util.ArrayList;
import java.util.Collection;
Expand Down Expand Up @@ -489,6 +490,8 @@ Prompt buildRequestPrompt(Prompt prompt) {
this.defaultOptions.getSafetySettings()));
requestOptions
.setLabels(ModelOptionsUtils.mergeOption(runtimeOptions.getLabels(), this.defaultOptions.getLabels()));
requestOptions.setUrlContextEnabled(ModelOptionsUtils.mergeOption(runtimeOptions.getUrlContextEnabled(),
this.defaultOptions.getUrlContextEnabled()));
}
else {
requestOptions.setInternalToolExecutionEnabled(this.defaultOptions.getInternalToolExecutionEnabled());
Expand All @@ -499,6 +502,7 @@ Prompt buildRequestPrompt(Prompt prompt) {
requestOptions.setGoogleSearchRetrieval(this.defaultOptions.getGoogleSearchRetrieval());
requestOptions.setSafetySettings(this.defaultOptions.getSafetySettings());
requestOptions.setLabels(this.defaultOptions.getLabels());
requestOptions.setUrlContextEnabled(this.defaultOptions.getUrlContextEnabled());
}

ToolCallingChatOptions.validateToolCallbacks(requestOptions.getToolCallbacks());
Expand Down Expand Up @@ -749,6 +753,14 @@ GeminiRequest createGeminiRequest(Prompt prompt) {
tools.add(googleSearchRetrievalTool);
}

if (prompt.getOptions() instanceof GoogleGenAiChatOptions options && Boolean.TRUE.equals(options.getUrlContextEnabled())) {
final var urlContextTool = Tool.builder()
.urlContext(UrlContext.builder().build())
.build();

tools.add(urlContextTool);
}

if (!CollectionUtils.isEmpty(tools)) {
configBuilder.tools(tools);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,12 @@ public class GoogleGenAiChatOptions implements ToolCallingChatOptions, Structure
private Map<String, String> labels = new HashMap<>();
// @formatter:on

/**
* Enable Google's UrlContext tool
*/
@JsonIgnore
private Boolean urlContextEnabled;

public static Builder builder() {
return new Builder();
}
Expand Down Expand Up @@ -218,6 +224,7 @@ public static GoogleGenAiChatOptions fromOptions(GoogleGenAiChatOptions fromOpti
options.setUseCachedContent(fromOptions.getUseCachedContent());
options.setAutoCacheThreshold(fromOptions.getAutoCacheThreshold());
options.setAutoCacheTtl(fromOptions.getAutoCacheTtl());
options.setUrlContextEnabled(fromOptions.getUrlContextEnabled());
return options;
}

Expand Down Expand Up @@ -459,6 +466,14 @@ public void setOutputSchema(String jsonSchemaText) {
this.setResponseMimeType("application/json");
}

public Boolean getUrlContextEnabled() {
return this.urlContextEnabled;
}

public void setUrlContextEnabled(Boolean urlContextEnabled) {
this.urlContextEnabled = urlContextEnabled;
}

@Override
public boolean equals(Object o) {
if (this == o) {
Expand All @@ -481,7 +496,8 @@ public boolean equals(Object o) {
&& Objects.equals(this.toolNames, that.toolNames)
&& Objects.equals(this.safetySettings, that.safetySettings)
&& Objects.equals(this.internalToolExecutionEnabled, that.internalToolExecutionEnabled)
&& Objects.equals(this.toolContext, that.toolContext) && Objects.equals(this.labels, that.labels);
&& Objects.equals(this.toolContext, that.toolContext) && Objects.equals(this.labels, that.labels)
&& Objects.equals(this.urlContextEnabled, that.urlContextEnabled);
}

@Override
Expand All @@ -490,7 +506,7 @@ public int hashCode() {
this.frequencyPenalty, this.presencePenalty, this.thinkingBudget, this.maxOutputTokens, this.model,
this.responseMimeType, this.responseSchema, this.toolCallbacks, this.toolNames,
this.googleSearchRetrieval, this.safetySettings, this.internalToolExecutionEnabled, this.toolContext,
this.labels);
this.labels, this.urlContextEnabled);
}

@Override
Expand All @@ -502,7 +518,7 @@ public String toString() {
+ this.model + '\'' + ", responseMimeType='" + this.responseMimeType + '\'' + ", toolCallbacks="
+ this.toolCallbacks + ", toolNames=" + this.toolNames + ", googleSearchRetrieval="
+ this.googleSearchRetrieval + ", safetySettings=" + this.safetySettings + ", labels=" + this.labels
+ '}';
+ ", urlContextEnabled=" + this.urlContextEnabled + '}';
}

@Override
Expand Down Expand Up @@ -671,6 +687,11 @@ public Builder autoCacheTtl(java.time.Duration autoCacheTtl) {
return this;
}

public Builder urlContextEnabled(boolean enabled) {
this.options.setUrlContextEnabled(enabled);
return this;
}

public GoogleGenAiChatOptions build() {
return this.options;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -299,6 +299,26 @@ public void createRequestWithGenerationConfigOptions() {
assertThat(request.config().responseMimeType().orElse("")).isEqualTo("application/json");
}

@Test
public void createRequestWithUrlContextToolEnabled() {

var client = GoogleGenAiChatModel.builder()
.genAiClient(this.genAiClient)
.defaultOptions(GoogleGenAiChatOptions.builder().model("DEFAULT_MODEL").urlContextEnabled(true).build())
.build();

GeminiRequest request = client
.createGeminiRequest(client.buildRequestPrompt(new Prompt("Test message content")));

assertThat(request.config().tools()).isPresent();
assertThat(request.config().tools().get()).anySatisfy(tool -> assertThat(tool.urlContext()).isPresent());

request = client.createGeminiRequest(client.buildRequestPrompt(
new Prompt("Test message content", GoogleGenAiChatOptions.builder().urlContextEnabled(false).build())));

assertThat(request.config().tools()).isEmpty();
}

@Test
public void createRequestWithThinkingBudget() {

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -167,4 +167,27 @@ public void testLabelsWithEmptyMap() {
assertThat(options.getLabels()).isEmpty();
}

@Test
public void testUrlContextEnabledCopyAndEquality() {
GoogleGenAiChatOptions original = GoogleGenAiChatOptions.builder()
.model("test-model")
.urlContextEnabled(true)
.build();

GoogleGenAiChatOptions copy = original.copy();

assertThat(original.getUrlContextEnabled()).isTrue();
assertThat(copy.getUrlContextEnabled()).isTrue();
assertThat(copy).isEqualTo(original);
assertThat(copy).isNotSameAs(original);
assertThat(copy.toString()).contains("urlContextEnabled=true");

GoogleGenAiChatOptions different = GoogleGenAiChatOptions.builder()
.model("test-model")
.urlContextEnabled(false)
.build();

assertThat(original).isNotEqualTo(different);
assertThat(original.hashCode()).isNotEqualTo(different.hashCode());
}
}