Skip to content
This repository was archived by the owner on Apr 1, 2026. It is now read-only.

Commit 4dcc5c7

Browse files
committed
minor update
1 parent 179f60d commit 4dcc5c7

2 files changed

Lines changed: 47 additions & 19 deletions

File tree

bigframes/ml/llm.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -734,9 +734,7 @@ def predict(
734734
output_schema = {
735735
k: utils.standardize_type(v) for k, v in output_schema.items()
736736
}
737-
options["output_schema"] = {
738-
k: utils.standardize_type(v) for k, v in output_schema.items()
739-
}
737+
options["output_schema"] = output_schema
740738
return self._predict_and_retry(
741739
core.BqmlModel.generate_table_tvf,
742740
X,

tests/system/small/ml/test_llm.py

Lines changed: 46 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -217,7 +217,7 @@ def test_gemini_text_generator_predict_output_schema_success(
217217
llm_text_df: bpd.DataFrame, model_name, session, bq_connection
218218
):
219219
gemini_text_generator_model = llm.GeminiTextGenerator(
220-
model_name="gemini-2.0-flash-001",
220+
model_name=model_name,
221221
connection_name=bq_connection,
222222
session=session,
223223
)
@@ -812,12 +812,18 @@ def test_text_embedding_generator_no_default_model_warning(model_class):
812812
model_class(model_name=None)
813813

814814

815-
@pytest.mark.flaky(retries=2)
815+
@pytest.mark.parametrize(
816+
"model_name",
817+
(
818+
"gemini-2.0-flash-001",
819+
"gemini-2.0-flash-lite-001",
820+
),
821+
)
816822
def test_gemini_text_generator_predict_struct_schema_succeeds(
817-
llm_text_df: bpd.DataFrame, session, bq_connection
823+
llm_text_df: bpd.DataFrame, session, bq_connection, model_name
818824
):
819825
gemini_text_generator_model = llm.GeminiTextGenerator(
820-
model_name="gemini-2.0-flash-001",
826+
model_name=model_name,
821827
connection_name=bq_connection,
822828
session=session,
823829
)
@@ -839,12 +845,18 @@ def test_gemini_text_generator_predict_struct_schema_succeeds(
839845
)
840846

841847

842-
@pytest.mark.flaky(retries=2)
848+
@pytest.mark.parametrize(
849+
"model_name",
850+
(
851+
"gemini-2.0-flash-001",
852+
"gemini-2.0-flash-lite-001",
853+
),
854+
)
843855
def test_gemini_text_generator_predict_struct_schema_flat_succeeds(
844-
llm_text_df: bpd.DataFrame, session, bq_connection
856+
llm_text_df: bpd.DataFrame, session, bq_connection, model_name
845857
):
846858
gemini_text_generator_model = llm.GeminiTextGenerator(
847-
model_name="gemini-2.0-flash-001",
859+
model_name=model_name,
848860
connection_name=bq_connection,
849861
session=session,
850862
)
@@ -865,12 +877,18 @@ def test_gemini_text_generator_predict_struct_schema_flat_succeeds(
865877
)
866878

867879

868-
@pytest.mark.flaky(retries=2)
880+
@pytest.mark.parametrize(
881+
"model_name",
882+
(
883+
"gemini-2.0-flash-001",
884+
"gemini-2.0-flash-lite-001",
885+
),
886+
)
869887
def test_gemini_text_generator_predict_array_schema_succeeds(
870-
llm_text_df: bpd.DataFrame, session, bq_connection
888+
llm_text_df: bpd.DataFrame, session, bq_connection, model_name
871889
):
872890
gemini_text_generator_model = llm.GeminiTextGenerator(
873-
model_name="gemini-2.0-flash-001",
891+
model_name=model_name,
874892
connection_name=bq_connection,
875893
session=session,
876894
)
@@ -889,12 +907,18 @@ def test_gemini_text_generator_predict_array_schema_succeeds(
889907
)
890908

891909

892-
@pytest.mark.flaky(retries=2)
910+
@pytest.mark.parametrize(
911+
"model_name",
912+
(
913+
"gemini-2.0-flash-001",
914+
"gemini-2.0-flash-lite-001",
915+
),
916+
)
893917
def test_gemini_text_generator_predict_array_struct_schema_succeeds(
894-
llm_text_df: bpd.DataFrame, session, bq_connection
918+
llm_text_df: bpd.DataFrame, session, bq_connection, model_name
895919
):
896920
gemini_text_generator_model = llm.GeminiTextGenerator(
897-
model_name="gemini-2.0-flash-001",
921+
model_name=model_name,
898922
connection_name=bq_connection,
899923
session=session,
900924
)
@@ -915,12 +939,18 @@ def test_gemini_text_generator_predict_array_struct_schema_succeeds(
915939
)
916940

917941

918-
@pytest.mark.flaky(retries=2)
942+
@pytest.mark.parametrize(
943+
"model_name",
944+
(
945+
"gemini-2.0-flash-001",
946+
"gemini-2.0-flash-lite-001",
947+
),
948+
)
919949
def test_gemini_text_generator_predict_invalid_schema_fails(
920-
llm_text_df: bpd.DataFrame, session, bq_connection
950+
llm_text_df: bpd.DataFrame, session, bq_connection, model_name
921951
):
922952
gemini_text_generator_model = llm.GeminiTextGenerator(
923-
model_name="gemini-2.0-flash-001",
953+
model_name=model_name,
924954
connection_name=bq_connection,
925955
session=session,
926956
)

0 commit comments

Comments
 (0)