diff --git a/google/genai/_api_client.py b/google/genai/_api_client.py index 7710c4487..ac54ef55f 100644 --- a/google/genai/_api_client.py +++ b/google/genai/_api_client.py @@ -984,12 +984,7 @@ def _access_token(self) -> str: self.project = project if self._credentials: - if self._credentials.expired or not self._credentials.token: - # Only refresh when it needs to. Default expiration is 3600 seconds. - refresh_auth(self._credentials) - if not self._credentials.token: - raise RuntimeError('Could not resolve API token from the environment') - return self._credentials.token # type: ignore[no-any-return] + return get_token_from_credentials(self, self._credentials) # type: ignore[no-any-return] else: raise RuntimeError('Could not resolve API token from the environment') @@ -1034,18 +1029,10 @@ async def _async_access_token(self) -> Union[str, Any]: self.project = project if self._credentials: - if self._credentials.expired or not self._credentials.token: - # Only refresh when it needs to. Default expiration is 3600 seconds. - async_auth_lock = await self._get_async_auth_lock() - async with async_auth_lock: - if self._credentials.expired or not self._credentials.token: - # Double check that the credentials expired before refreshing. - await asyncio.to_thread(refresh_auth, self._credentials) - - if not self._credentials.token: - raise RuntimeError('Could not resolve API token from the environment') - - return self._credentials.token + return await async_get_token_from_credentials( + self, + self._credentials + ) # type: ignore[no-any-return] else: raise RuntimeError('Could not resolve API token from the environment') @@ -1925,3 +1912,35 @@ def __del__(self) -> None: asyncio.get_running_loop().create_task(self.aclose()) except Exception: # pylint: disable=broad-except pass + +def get_token_from_credentials( + client: 'BaseApiClient', + credentials: google.auth.credentials.Credentials +) -> str: + """Refreshes the authentication token for the given credentials.""" + if credentials.expired or not credentials.token: + # Only refresh when it needs to. Default expiration is 3600 seconds. + refresh_auth(credentials) + if not credentials.token: + raise RuntimeError('Could not resolve API token from the environment') + return credentials.token # type: ignore[no-any-return] + +async def async_get_token_from_credentials( + client: 'BaseApiClient', + credentials: google.auth.credentials.Credentials +) -> str: + """Refreshes the authentication token for the given credentials.""" + if credentials.expired or not credentials.token: + # Only refresh when it needs to. Default expiration is 3600 seconds. + async_auth_lock = await client._get_async_auth_lock() + async with async_auth_lock: + if credentials.expired or not credentials.token: + # Double check that the credentials expired before refreshing. + await asyncio.to_thread(refresh_auth, credentials) + + if not credentials.token: + raise RuntimeError('Could not resolve API token from the environment') + + return credentials.token # type: ignore[no-any-return] + + diff --git a/google/genai/files.py b/google/genai/files.py index 0558aa92c..2f13d8bf8 100644 --- a/google/genai/files.py +++ b/google/genai/files.py @@ -22,6 +22,9 @@ from typing import Any, Optional, Union from urllib.parse import urlencode +import google.auth + +from . import _api_client from . import _api_module from . import _common from . import _extra_utils @@ -149,6 +152,33 @@ def _ListFilesResponse_from_mldev( return to_object +def _RegisterFilesParameters_to_mldev( + from_object: Union[dict[str, Any], object], + parent_object: Optional[dict[str, Any]] = None, +) -> dict[str, Any]: + to_object: dict[str, Any] = {} + if getv(from_object, ['uris']) is not None: + setv(to_object, ['uris'], getv(from_object, ['uris'])) + + return to_object + + +def _RegisterFilesResponse_from_mldev( + from_object: Union[dict[str, Any], object], + parent_object: Optional[dict[str, Any]] = None, +) -> dict[str, Any]: + to_object: dict[str, Any] = {} + if getv(from_object, ['sdkHttpResponse']) is not None: + setv( + to_object, ['sdk_http_response'], getv(from_object, ['sdkHttpResponse']) + ) + + if getv(from_object, ['files']) is not None: + setv(to_object, ['files'], [item for item in getv(from_object, ['files'])]) + + return to_object + + class Files(_api_module.BaseModule): def _list( @@ -402,6 +432,69 @@ def delete( self._api_client._verify_response(return_value) return return_value + def _register_files( + self, + *, + uris: list[str], + config: Optional[types.RegisterFilesConfigOrDict] = None, + ) -> types.RegisterFilesResponse: + parameter_model = types._RegisterFilesParameters( + uris=uris, + config=config, + ) + + request_url_dict: Optional[dict[str, str]] + if self._api_client.vertexai: + raise ValueError( + 'This method is only supported in the Gemini Developer client.' + ) + else: + request_dict = _RegisterFilesParameters_to_mldev(parameter_model) + request_url_dict = request_dict.get('_url') + if request_url_dict: + path = 'files:register'.format_map(request_url_dict) + else: + path = 'files:register' + + query_params = request_dict.get('_query') + if query_params: + path = f'{path}?{urlencode(query_params)}' + # TODO: remove the hack that pops config. + request_dict.pop('config', None) + + http_options: Optional[types.HttpOptions] = None + if ( + parameter_model.config is not None + and parameter_model.config.http_options is not None + ): + http_options = parameter_model.config.http_options + + request_dict = _common.convert_to_dict(request_dict) + request_dict = _common.encode_unserializable_types(request_dict) + + response = self._api_client.request( + 'post', path, request_dict, http_options + ) + + if config is not None and getattr( + config, 'should_return_http_response', None + ): + return_value = types.RegisterFilesResponse(sdk_http_response=response) + self._api_client._verify_response(return_value) + return return_value + + response_dict = {} if not response.body else json.loads(response.body) + + if not self._api_client.vertexai: + response_dict = _RegisterFilesResponse_from_mldev(response_dict) + + return_value = types.RegisterFilesResponse._from_response( + response=response_dict, kwargs=parameter_model.model_dump() + ) + + self._api_client._verify_response(return_value) + return return_value + def upload( self, *, @@ -559,6 +652,39 @@ def download( return data + def register_files( + self, + *, + auth: google.auth.credentials.Credentials, + uris: list[str], + config: Optional[types.RegisterFilesConfigOrDict] = None, + ) -> types.RegisterFilesResponse: + """Registers gcs files with the file service.""" + if not isinstance(auth, google.auth.credentials.Credentials): + raise ValueError( + 'auth must be a google.auth.credentials.Credentials object.' + ) + if config is None: + config = types.RegisterFilesConfig() + else: + config = types.RegisterFilesConfig.model_validate(config) + config = config.model_copy(deep=True) + + http_options = config.http_options or types.HttpOptions() + headers = http_options.headers or {} + headers = {k.lower(): v for k, v in headers.items()} + + token = _api_client.get_token_from_credentials(self._api_client, auth) + headers['authorization'] = f'Bearer {token}' + + if auth.quota_project_id: + headers['x-goog-user-project'] = auth.quota_project_id + + http_options.headers = headers + config.http_options = http_options + + return self._register_files(uris=uris, config=config) + def list( self, *, config: Optional[types.ListFilesConfigOrDict] = None ) -> Pager[types.File]: @@ -845,6 +971,69 @@ async def delete( self._api_client._verify_response(return_value) return return_value + async def _register_files( + self, + *, + uris: list[str], + config: Optional[types.RegisterFilesConfigOrDict] = None, + ) -> types.RegisterFilesResponse: + parameter_model = types._RegisterFilesParameters( + uris=uris, + config=config, + ) + + request_url_dict: Optional[dict[str, str]] + if self._api_client.vertexai: + raise ValueError( + 'This method is only supported in the Gemini Developer client.' + ) + else: + request_dict = _RegisterFilesParameters_to_mldev(parameter_model) + request_url_dict = request_dict.get('_url') + if request_url_dict: + path = 'files:register'.format_map(request_url_dict) + else: + path = 'files:register' + + query_params = request_dict.get('_query') + if query_params: + path = f'{path}?{urlencode(query_params)}' + # TODO: remove the hack that pops config. + request_dict.pop('config', None) + + http_options: Optional[types.HttpOptions] = None + if ( + parameter_model.config is not None + and parameter_model.config.http_options is not None + ): + http_options = parameter_model.config.http_options + + request_dict = _common.convert_to_dict(request_dict) + request_dict = _common.encode_unserializable_types(request_dict) + + response = await self._api_client.async_request( + 'post', path, request_dict, http_options + ) + + if config is not None and getattr( + config, 'should_return_http_response', None + ): + return_value = types.RegisterFilesResponse(sdk_http_response=response) + self._api_client._verify_response(return_value) + return return_value + + response_dict = {} if not response.body else json.loads(response.body) + + if not self._api_client.vertexai: + response_dict = _RegisterFilesResponse_from_mldev(response_dict) + + return_value = types.RegisterFilesResponse._from_response( + response=response_dict, kwargs=parameter_model.model_dump() + ) + + self._api_client._verify_response(return_value) + return return_value + async def upload( self, *, @@ -992,6 +1181,41 @@ async def download( return data + async def register_files( + self, + *, + auth: google.auth.credentials.Credentials, + uris: list[str], + config: Optional[types.RegisterFilesConfigOrDict] = None, + ) -> types.RegisterFilesResponse: + """Registers gcs files with the file service.""" + if not isinstance(auth, google.auth.credentials.Credentials): + raise ValueError( + 'auth must be a google.auth.credentials.Credentials object.' + ) + if config is None: + config = types.RegisterFilesConfig() + else: + config = types.RegisterFilesConfig.model_validate(config) + config = config.model_copy(deep=True) + + http_options = config.http_options or types.HttpOptions() + headers = http_options.headers or {} + headers = {k.lower(): v for k, v in headers.items()} + + token = await _api_client.async_get_token_from_credentials( + self._api_client, auth + ) + headers['authorization'] = f'Bearer {token}' + + if auth.quota_project_id: + headers['x-goog-user-project'] = auth.quota_project_id + + http_options.headers = headers + config.http_options = http_options + + return await self._register_files(uris=uris, config=config) + async def list( self, *, config: Optional[types.ListFilesConfigOrDict] = None ) -> AsyncPager[types.File]: diff --git a/google/genai/tests/files/test_register.py b/google/genai/tests/files/test_register.py new file mode 100644 index 000000000..eeebe90ac --- /dev/null +++ b/google/genai/tests/files/test_register.py @@ -0,0 +1,272 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + + +"""Test files register method.""" + +import json +from unittest import mock + +from google.auth import credentials +import httpx +import pytest + +from ... import _api_client +from ... import Client +from ... import types +from .. import pytest_helper + + +class FakeCredentials(credentials.Credentials): + + def __init__(self, token="fake_token", expired=False, quota_project_id=None): + super().__init__() + self.token = token + self._expired = expired + self._quota_project_id = quota_project_id + self.refresh_count = 0 + + @property + def expired(self): + return self._expired + + @property + def quota_project_id(self): + return self._quota_project_id + + def refresh(self, request): + self.refresh_count += 1 + self.token = "refreshed_token" + self._expired = False + + +@mock.patch.object(_api_client.BaseApiClient, "_request_once", autospec=True) +def test_simple_token(mock_request): + client = Client(api_key="dummy_key") + captured_request = None + + def side_effect(self, http_request, stream=False): + nonlocal captured_request + captured_request = http_request + return _api_client.HttpResponse( + headers={}, + response_stream=[json.dumps({"files": [{"uri": "files/abc"}]})], + ) + + mock_request.side_effect = side_effect + + with pytest_helper.exception_if_vertex(client, ValueError): + response = client.files.register_files( + auth=FakeCredentials(token="test_token"), + uris=["gs://test-bucket/test-file-1.txt"], + ) + + assert len(response.files) == 1 + assert response.files[0].uri == "files/abc" + assert captured_request.headers["authorization"] == "Bearer test_token" + + +@mock.patch.object(_api_client.BaseApiClient, "_request_once", autospec=True) +def test_token_refresh(mock_request): + client = Client(api_key="dummy_key") + captured_request = None + + def side_effect(self, http_request, stream=False): + nonlocal captured_request + captured_request = http_request + return _api_client.HttpResponse( + headers={}, + response_stream=[json.dumps({"files": [{"uri": "files/abc"}]})], + ) + + mock_request.side_effect = side_effect + + with pytest_helper.exception_if_vertex(client, ValueError): + creds = FakeCredentials(expired=True) + response = client.files.register_files( + auth=creds, + uris=["gs://test-bucket/test-file-1.txt"], + ) + assert creds.refresh_count == 1 + assert len(response.files) == 1 + assert response.files[0].uri == "files/abc" + assert captured_request.headers["authorization"] == "Bearer refreshed_token" + + +@mock.patch.object(_api_client.BaseApiClient, "_request_once", autospec=True) +def test_quota_project(mock_request): + client = Client(api_key="dummy_key") + captured_request = None + + def side_effect(self, http_request, stream=False): + nonlocal captured_request + captured_request = http_request + return _api_client.HttpResponse( + headers={}, + response_stream=[json.dumps({"files": [{"uri": "files/abc"}]})], + ) + + mock_request.side_effect = side_effect + + with pytest_helper.exception_if_vertex(client, ValueError): + creds = FakeCredentials(quota_project_id="test_project") + response = client.files.register_files( + auth=creds, + uris=["gs://test-bucket/test-file-1.txt"], + ) + assert len(response.files) == 1 + assert response.files[0].uri == "files/abc" + assert captured_request.headers["x-goog-user-project"] == "test_project" + + +@mock.patch.object(_api_client.BaseApiClient, "_request_once", autospec=True) +def test_multiple_uris(mock_request): + client = Client(api_key="dummy_key") + + def side_effect(self, http_request, stream=False): + return _api_client.HttpResponse( + headers={}, + response_stream=[ + json.dumps({"files": [{"uri": "files/abc"}, {"uri": "files/def"}]}) + ], + ) + + mock_request.side_effect = side_effect + + with pytest_helper.exception_if_vertex(client, ValueError): + response = client.files.register_files( + auth=FakeCredentials(), + uris=[ + "gs://test-bucket/test-file-1.txt", + "gs://test-bucket/test-file-2.txt", + ], + ) + assert len(response.files) == 2 + assert response.files[0].uri == "files/abc" + assert response.files[1].uri == "files/def" + + +@pytest.mark.asyncio +@mock.patch.object( + _api_client.BaseApiClient, "_async_request_once", autospec=True +) +async def test_async_single(mock_request): + client = Client(api_key="dummy_key") + + async def side_effect(self, http_request, stream=False): + return _api_client.HttpResponse( + headers={}, + response_stream=[json.dumps({"files": [{"uri": "files/abc"}]})], + ) + + mock_request.side_effect = side_effect + + with pytest_helper.exception_if_vertex(client, ValueError): + response = await client.aio.files.register_files( + auth=FakeCredentials(), + uris=["gs://test-bucket/test-file-1.txt"], + ) + + assert len(response.files) == 1 + assert response.files[0].uri == "files/abc" + + +@pytest.mark.asyncio +@mock.patch.object( + _api_client.BaseApiClient, "_async_request_once", autospec=True +) +async def test_async_token_refresh(mock_request): + client = Client(api_key="dummy_key") + captured_request = None + + async def side_effect(self, http_request, stream=False): + nonlocal captured_request + captured_request = http_request + return _api_client.HttpResponse( + headers={}, + response_stream=[json.dumps({"files": [{"uri": "files/abc"}]})], + ) + + mock_request.side_effect = side_effect + + with pytest_helper.exception_if_vertex(client, ValueError): + creds = FakeCredentials(expired=True) + response = await client.aio.files.register_files( + auth=creds, + uris=["gs://test-bucket/test-file-1.txt"], + ) + assert creds.refresh_count == 1 + assert len(response.files) == 1 + assert response.files[0].uri == "files/abc" + assert captured_request.headers["authorization"] == "Bearer refreshed_token" + + +@pytest.mark.asyncio +@mock.patch.object( + _api_client.BaseApiClient, "_async_request_once", autospec=True +) +async def test_async_quota_project(mock_request): + client = Client(api_key="dummy_key") + captured_request = None + + async def side_effect(self, http_request, stream=False): + nonlocal captured_request + captured_request = http_request + return _api_client.HttpResponse( + headers={}, + response_stream=[json.dumps({"files": [{"uri": "files/abc"}]})], + ) + + mock_request.side_effect = side_effect + + with pytest_helper.exception_if_vertex(client, ValueError): + creds = FakeCredentials(quota_project_id="test_project") + response = await client.aio.files.register_files( + auth=creds, + uris=["gs://test-bucket/test-file-1.txt"], + ) + assert len(response.files) == 1 + assert response.files[0].uri == "files/abc" + assert captured_request.headers["x-goog-user-project"] == "test_project" + + +@pytest.mark.asyncio +@mock.patch.object( + _api_client.BaseApiClient, "_async_request_once", autospec=True +) +async def test_async_multiple_uris(mock_request): + client = Client(api_key="dummy_key") + + async def side_effect(self, http_request, stream=False): + return _api_client.HttpResponse( + headers={}, + response_stream=[ + json.dumps({"files": [{"uri": "files/abc"}, {"uri": "files/def"}]}) + ], + ) + + mock_request.side_effect = side_effect + + with pytest_helper.exception_if_vertex(client, ValueError): + response = await client.aio.files.register_files( + auth=FakeCredentials(), + uris=[ + "gs://test-bucket/test-file-1.txt", + "gs://test-bucket/test-file-2.txt", + ], + ) + assert len(response.files) == 2 + assert response.files[0].uri == "files/abc" + assert response.files[1].uri == "files/def" diff --git a/google/genai/tests/files/test_register_table.py b/google/genai/tests/files/test_register_table.py new file mode 100644 index 000000000..ff70dc0f2 --- /dev/null +++ b/google/genai/tests/files/test_register_table.py @@ -0,0 +1,70 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + + +"""Test files get method.""" + +import pytest +from ... import types +from ... import Client +from ... import _api_client +from .. import pytest_helper +import google.auth + + +# $ gcloud config set project vertex-sdk-dev +# $ gcloud auth application-default login --no-launch-browser --scopes="https://www.googleapis.com/auth/cloud-platform,https://www.googleapis.com/auth/devstorage.read_only" +def get_headers(): + try: + credentials, _ = google.auth.default() + token = _api_client.get_token_from_credentials(None, credentials) + headers = { + "Authorization": f"Bearer {token}",} + if credentials.quota_project_id: + headers["x-goog-user-project"] = credentials.quota_project_id + except google.auth.exceptions.DefaultCredentialsError: + # So this can run in replay mode without credentials. + headers = {} + + +test_table: list[pytest_helper.TestTableItem] = [ + pytest_helper.TestTableItem( + name='test_register', + parameters=types._RegisterFilesParameters(uris=['gs://unified-genai-dev/image.jpg']), + exception_if_vertex='only supported in the Gemini Developer client', + skip_in_api_mode=( + 'The files have a TTL, they cannot be reliably retrieved for a long' + ' time.' + ), + ), +] + +pytestmark = pytest_helper.setup( + file=__file__, + globals_for_file=globals(), + test_method='files._register_files', + test_table=test_table, + http_options={ + 'headers': get_headers(), + }, +) + + +@pytest.mark.asyncio +async def test_async(client): + with pytest_helper.exception_if_vertex(client, ValueError): + files = await client.aio.files._register_files(uris=['gs://unified-genai-dev/image.jpg']) + assert files.files + assert files.files[0].mime_type == 'image/jpeg' diff --git a/google/genai/types.py b/google/genai/types.py index 552874e85..d0de9fffc 100644 --- a/google/genai/types.py +++ b/google/genai/types.py @@ -825,6 +825,7 @@ class FileSource(_common.CaseInSensitiveEnum): SOURCE_UNSPECIFIED = 'SOURCE_UNSPECIFIED' UPLOADED = 'UPLOADED' GENERATED = 'GENERATED' + REGISTERED = 'REGISTERED' class TurnCompleteReason(_common.CaseInSensitiveEnum): @@ -13897,6 +13898,85 @@ class DeleteFileResponseDict(TypedDict, total=False): DeleteFileResponseOrDict = Union[DeleteFileResponse, DeleteFileResponseDict] +class RegisterFilesConfig(_common.BaseModel): + """Used to override the default configuration.""" + + http_options: Optional[HttpOptions] = Field( + default=None, description="""Used to override HTTP request options.""" + ) + should_return_http_response: Optional[bool] = Field( + default=None, + description=""" If true, the raw HTTP response will be returned in the 'sdk_http_response' field.""", + ) + + +class RegisterFilesConfigDict(TypedDict, total=False): + """Used to override the default configuration.""" + + http_options: Optional[HttpOptionsDict] + """Used to override HTTP request options.""" + + should_return_http_response: Optional[bool] + """ If true, the raw HTTP response will be returned in the 'sdk_http_response' field.""" + + +RegisterFilesConfigOrDict = Union[RegisterFilesConfig, RegisterFilesConfigDict] + + +class _RegisterFilesParameters(_common.BaseModel): + """Generates the parameters for the private _Register method.""" + + uris: Optional[list[str]] = Field( + default=None, + description="""The Google Cloud Storage URIs to register. Example: `gs://bucket/object`.""", + ) + config: Optional[RegisterFilesConfig] = Field( + default=None, + description="""Used to override the default configuration.""", + ) + + +class _RegisterFilesParametersDict(TypedDict, total=False): + """Generates the parameters for the private _Register method.""" + + uris: Optional[list[str]] + """The Google Cloud Storage URIs to register. Example: `gs://bucket/object`.""" + + config: Optional[RegisterFilesConfigDict] + """Used to override the default configuration.""" + + +_RegisterFilesParametersOrDict = Union[ + _RegisterFilesParameters, _RegisterFilesParametersDict +] + + +class RegisterFilesResponse(_common.BaseModel): + """Response for the _register file method.""" + + sdk_http_response: Optional[HttpResponse] = Field( + default=None, description="""Used to retain the full HTTP response.""" + ) + files: Optional[list[File]] = Field( + default=None, description="""The registered files.""" + ) + + +class RegisterFilesResponseDict(TypedDict, total=False): + """Response for the _register file method.""" + + sdk_http_response: Optional[HttpResponseDict] + """Used to retain the full HTTP response.""" + + files: Optional[list[FileDict]] + """The registered files.""" + + +RegisterFilesResponseOrDict = Union[ + RegisterFilesResponse, RegisterFilesResponseDict +] + + class InlinedRequest(_common.BaseModel): """Config for inlined request."""