diff --git a/google/genai/models.py b/google/genai/models.py index 547132546..17c745ccc 100644 --- a/google/genai/models.py +++ b/google/genai/models.py @@ -1226,6 +1226,16 @@ def _FunctionCallingConfig_to_mldev( return to_object +def _is_text_embedding_batch( + contents: Union[types.ContentListUnion, types.ContentListUnionDict], +) -> bool: + return ( + isinstance(contents, list) + and bool(contents) + and all(isinstance(content, str) for content in contents) + ) + + def _GenerateContentConfig_to_mldev( api_client: BaseApiClient, from_object: Union[dict[str, Any], object], @@ -6335,7 +6345,9 @@ def embed_content( ) """ if not self._api_client.vertexai: - if 'gemini-embedding-2' in model: + if 'gemini-embedding-2' in model and not _is_text_embedding_batch( + contents + ): contents = t.t_contents(contents) # type: ignore[assignment] return self._embed_content(model=model, contents=contents, config=config) @@ -9296,7 +9308,9 @@ async def embed_content( ) """ if not self._api_client.vertexai: - if 'gemini-embedding-2' in model: + if 'gemini-embedding-2' in model and not _is_text_embedding_batch( + contents + ): contents = t.t_contents(contents) # type: ignore[assignment] return await self._embed_content( model=model, contents=contents, config=config diff --git a/google/genai/tests/models/test_embed_content.py b/google/genai/tests/models/test_embed_content.py index 02cc81f71..9103a6c8d 100644 --- a/google/genai/tests/models/test_embed_content.py +++ b/google/genai/tests/models/test_embed_content.py @@ -21,6 +21,7 @@ import pytest from ... import _transformers as t +from ... import models from ... import types from .. import pytest_helper @@ -227,6 +228,78 @@ def _get_bytes_from_file(relative_path: str) -> bytes: ) +class _FakeApiClient: + vertexai = False + + +def test_gemini_embedding_2_text_list_stays_batched( + monkeypatch, use_vertex, replays_prefix, http_options +): + module = models.Models(_FakeApiClient()) + captured = {} + + def fake_embed_content(**kwargs): + captured.update(kwargs) + return types.EmbedContentResponse() + + monkeypatch.setattr(module, '_embed_content', fake_embed_content) + + module.embed_content( + model='gemini-embedding-2-preview', + contents=['first text', 'second text'], + ) + + assert captured['contents'] == ['first text', 'second text'] + + +def test_gemini_embedding_2_mixed_content_still_combines_parts( + monkeypatch, use_vertex, replays_prefix, http_options +): + module = models.Models(_FakeApiClient()) + captured = {} + + def fake_embed_content(**kwargs): + captured.update(kwargs) + return types.EmbedContentResponse() + + monkeypatch.setattr(module, '_embed_content', fake_embed_content) + + module.embed_content( + model='gemini-embedding-2-preview', + contents=[ + 'Similar things to the following image:', + types.Part.from_uri( + file_uri='gs://generativeai-downloads/images/scones.jpg', + mime_type='image/jpeg', + ), + ], + ) + + assert len(captured['contents']) == 1 + assert len(captured['contents'][0].parts) == 2 + + +@pytest.mark.asyncio +async def test_async_gemini_embedding_2_text_list_stays_batched( + monkeypatch, use_vertex, replays_prefix, http_options +): + module = models.AsyncModels(_FakeApiClient()) + captured = {} + + async def fake_embed_content(**kwargs): + captured.update(kwargs) + return types.EmbedContentResponse() + + monkeypatch.setattr(module, '_embed_content', fake_embed_content) + + await module.embed_content( + model='gemini-embedding-2-preview', + contents=['first text', 'second text'], + ) + + assert captured['contents'] == ['first text', 'second text'] + + def test_gemini_embedding_2_content_combination(client): response = client.models.embed_content( model='gemini-embedding-2-preview',