Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 44 additions & 1 deletion src/google/adk/models/google_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import copy
from functools import cached_property
import logging
import os
import re
from typing import Any
from typing import AsyncGenerator
Expand Down Expand Up @@ -54,6 +55,7 @@
_NEW_LINE = '\n'
_EXCLUDED_PART_FIELD = {'inline_data': {'data'}}
_GOOGLE_API_VERSION_SUFFIX_PATTERN = re.compile(r'/?(v[0-9][a-z0-9.-]*)/?')
_API_VERSION_ENV_VARIABLE_NAME = 'GOOGLE_GENAI_API_VERSION'


_RESOURCE_EXHAUSTED_POSSIBLE_FIX_MESSAGE = """
Expand Down Expand Up @@ -123,6 +125,27 @@ def api_client(self) -> Client:
base_url: Optional[str] = None
"""The base URL for the AI platform service endpoint."""

api_version: Optional[str] = None
"""The API version to use for the AI platform service endpoint.

For the Vertex AI backend the google-genai SDK defaults to ``v1beta1``, which
exposes the latest preview features. Production deployments that require a
stable, SLA-eligible endpoint can set this to ``v1`` to use the GA Vertex AI
API. When unset, the ``GOOGLE_GENAI_API_VERSION`` environment variable is
consulted, and finally the SDK's own default is used so existing behavior is
unchanged.

An API version embedded in ``base_url`` (e.g.
``https://...googleapis.com/v1``) takes precedence over this field.

Sample:
```python
from google.adk.models import Gemini

agent = Agent(model=Gemini(model="gemini-2.5-pro", api_version="v1"))
```
"""

speech_config: Optional[types.SpeechConfig] = None

use_interactions_api: bool = False
Expand Down Expand Up @@ -371,9 +394,29 @@ def _api_backend(self) -> GoogleLLMVariant:
def _tracking_headers(self) -> dict[str, str]:
return get_tracking_headers()

def _configured_api_version(self) -> Optional[str]:
"""Returns the explicitly configured API version, if any.

Resolution order:
1. The ``api_version`` field set on this instance.
2. The ``GOOGLE_GENAI_API_VERSION`` environment variable.

Returns ``None`` when neither is set, in which case the google-genai SDK's
own default (``v1beta1`` for Vertex AI) applies, preserving existing
behavior.
"""
if self.api_version:
return self.api_version
return os.environ.get(_API_VERSION_ENV_VARIABLE_NAME) or None

@cached_property
def _base_url_and_api_version(self) -> tuple[Optional[str], Optional[str]]:
return _normalize_base_url_and_api_version(self.base_url)
base_url, api_version = _normalize_base_url_and_api_version(self.base_url)
# A version embedded in the base URL wins; otherwise fall back to the
# explicitly configured api_version (field or environment variable).
if api_version is None:
api_version = self._configured_api_version()
return base_url, api_version

@cached_property
def _live_api_version(self) -> str:
Expand Down
106 changes: 106 additions & 0 deletions tests/unittests/models/test_google_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -341,6 +341,61 @@ def test_api_client_preserves_custom_base_url_path():
assert client._api_client._http_options.api_version == "v1beta"


def test_api_client_default_api_version_unchanged(monkeypatch):
"""Without configuration, ADK does not force an api_version (SDK default)."""
monkeypatch.delenv("GOOGLE_GENAI_API_VERSION", raising=False)
model = Gemini(model="gemini-2.5-flash")

# ADK leaves api_version unset so the google-genai SDK applies its own
# default (v1beta1 for Vertex AI), preserving existing behavior.
assert model._base_url_and_api_version == (None, None)


def test_api_client_uses_api_version_field():
"""The api_version field flows into the constructed client's http_options."""
model = Gemini(model="gemini-2.5-flash", api_version="v1")

client = model.api_client

assert client._api_client._http_options.api_version == "v1"


def test_api_client_uses_api_version_env_var(monkeypatch):
"""The GOOGLE_GENAI_API_VERSION env var flows into http_options."""
monkeypatch.setenv("GOOGLE_GENAI_API_VERSION", "v1")
model = Gemini(model="gemini-2.5-flash")

client = model.api_client

assert client._api_client._http_options.api_version == "v1"


def test_api_version_field_overrides_env_var(monkeypatch):
"""The explicit api_version field takes precedence over the env var."""
monkeypatch.setenv("GOOGLE_GENAI_API_VERSION", "v1beta1")
model = Gemini(model="gemini-2.5-flash", api_version="v1")

client = model.api_client

assert client._api_client._http_options.api_version == "v1"


def test_base_url_api_version_overrides_field():
"""A version embedded in base_url wins over the api_version field."""
model = Gemini(
model="gemini-2.5-flash",
base_url="https://generativelanguage.googleapis.com/v1alpha",
api_version="v1",
)

client = model.api_client

assert client._api_client._http_options.base_url == (
"https://generativelanguage.googleapis.com/"
)
assert client._api_client._http_options.api_version == "v1alpha"


def test_maybe_append_user_content(gemini_llm, llm_request):
# Test with user content already present
gemini_llm._maybe_append_user_content(llm_request)
Expand Down Expand Up @@ -766,6 +821,35 @@ async def mock_coro():
assert len(responses) == 2 if stream else 1


@pytest.mark.asyncio
async def test_generate_content_async_patches_api_version_from_field(
llm_request, generate_content_response
):
"""The configured api_version field is patched onto the request config."""
gemini_llm = Gemini(model="gemini-2.5-flash", api_version="v1")
llm_request.config.http_options = types.HttpOptions(
headers={"custom-header": "custom-value"}
)

with mock.patch.object(gemini_llm, "api_client") as mock_client:

async def mock_coro():
return generate_content_response

mock_client.aio.models.generate_content.return_value = mock_coro()

_ = [
resp
async for resp in gemini_llm.generate_content_async(
llm_request, stream=False
)
]

call_args = mock_client.aio.models.generate_content.call_args
final_config = call_args.kwargs["config"]
assert final_config.http_options.api_version == "v1"


def test_live_api_version_vertex_ai(gemini_llm):
"""Test that _live_api_version returns 'v1beta1' for Vertex AI backend."""
with mock.patch.object(
Expand All @@ -774,6 +858,28 @@ def test_live_api_version_vertex_ai(gemini_llm):
assert gemini_llm._live_api_version == "v1beta1"


def test_live_api_version_uses_configured_field():
"""Test that _live_api_version honors the configured api_version field."""
gemini_llm = Gemini(model="gemini-2.5-flash", api_version="v1")

with mock.patch.object(
gemini_llm, "_api_backend", GoogleLLMVariant.VERTEX_AI
):
assert gemini_llm._live_api_version == "v1"


def test_live_api_client_uses_configured_field():
"""Test that _live_api_client http_options honors the api_version field."""
gemini_llm = Gemini(model="gemini-2.5-flash", api_version="v1")

with mock.patch.object(
gemini_llm, "_api_backend", GoogleLLMVariant.VERTEX_AI
):
client = gemini_llm._live_api_client

assert client._api_client._http_options.api_version == "v1"


def test_live_api_version_uses_google_base_url_version():
gemini_llm = Gemini(
model="gemini-2.5-flash",
Expand Down