@@ -217,7 +217,7 @@ def test_gemini_text_generator_predict_output_schema_success(
217217 llm_text_df : bpd .DataFrame , model_name , session , bq_connection
218218):
219219 gemini_text_generator_model = llm .GeminiTextGenerator (
220- model_name = "gemini-2.0-flash-001" ,
220+ model_name = model_name ,
221221 connection_name = bq_connection ,
222222 session = session ,
223223 )
@@ -812,12 +812,18 @@ def test_text_embedding_generator_no_default_model_warning(model_class):
812812 model_class (model_name = None )
813813
814814
815- @pytest .mark .flaky (retries = 2 )
815+ @pytest .mark .parametrize (
816+ "model_name" ,
817+ (
818+ "gemini-2.0-flash-001" ,
819+ "gemini-2.0-flash-lite-001" ,
820+ ),
821+ )
816822def test_gemini_text_generator_predict_struct_schema_succeeds (
817- llm_text_df : bpd .DataFrame , session , bq_connection
823+ llm_text_df : bpd .DataFrame , session , bq_connection , model_name
818824):
819825 gemini_text_generator_model = llm .GeminiTextGenerator (
820- model_name = "gemini-2.0-flash-001" ,
826+ model_name = model_name ,
821827 connection_name = bq_connection ,
822828 session = session ,
823829 )
@@ -839,12 +845,18 @@ def test_gemini_text_generator_predict_struct_schema_succeeds(
839845 )
840846
841847
842- @pytest .mark .flaky (retries = 2 )
848+ @pytest .mark .parametrize (
849+ "model_name" ,
850+ (
851+ "gemini-2.0-flash-001" ,
852+ "gemini-2.0-flash-lite-001" ,
853+ ),
854+ )
843855def test_gemini_text_generator_predict_struct_schema_flat_succeeds (
844- llm_text_df : bpd .DataFrame , session , bq_connection
856+ llm_text_df : bpd .DataFrame , session , bq_connection , model_name
845857):
846858 gemini_text_generator_model = llm .GeminiTextGenerator (
847- model_name = "gemini-2.0-flash-001" ,
859+ model_name = model_name ,
848860 connection_name = bq_connection ,
849861 session = session ,
850862 )
@@ -865,12 +877,18 @@ def test_gemini_text_generator_predict_struct_schema_flat_succeeds(
865877 )
866878
867879
868- @pytest .mark .flaky (retries = 2 )
880+ @pytest .mark .parametrize (
881+ "model_name" ,
882+ (
883+ "gemini-2.0-flash-001" ,
884+ "gemini-2.0-flash-lite-001" ,
885+ ),
886+ )
869887def test_gemini_text_generator_predict_array_schema_succeeds (
870- llm_text_df : bpd .DataFrame , session , bq_connection
888+ llm_text_df : bpd .DataFrame , session , bq_connection , model_name
871889):
872890 gemini_text_generator_model = llm .GeminiTextGenerator (
873- model_name = "gemini-2.0-flash-001" ,
891+ model_name = model_name ,
874892 connection_name = bq_connection ,
875893 session = session ,
876894 )
@@ -889,12 +907,18 @@ def test_gemini_text_generator_predict_array_schema_succeeds(
889907 )
890908
891909
892- @pytest .mark .flaky (retries = 2 )
910+ @pytest .mark .parametrize (
911+ "model_name" ,
912+ (
913+ "gemini-2.0-flash-001" ,
914+ "gemini-2.0-flash-lite-001" ,
915+ ),
916+ )
893917def test_gemini_text_generator_predict_array_struct_schema_succeeds (
894- llm_text_df : bpd .DataFrame , session , bq_connection
918+ llm_text_df : bpd .DataFrame , session , bq_connection , model_name
895919):
896920 gemini_text_generator_model = llm .GeminiTextGenerator (
897- model_name = "gemini-2.0-flash-001" ,
921+ model_name = model_name ,
898922 connection_name = bq_connection ,
899923 session = session ,
900924 )
@@ -915,12 +939,18 @@ def test_gemini_text_generator_predict_array_struct_schema_succeeds(
915939 )
916940
917941
918- @pytest .mark .flaky (retries = 2 )
942+ @pytest .mark .parametrize (
943+ "model_name" ,
944+ (
945+ "gemini-2.0-flash-001" ,
946+ "gemini-2.0-flash-lite-001" ,
947+ ),
948+ )
919949def test_gemini_text_generator_predict_invalid_schema_fails (
920- llm_text_df : bpd .DataFrame , session , bq_connection
950+ llm_text_df : bpd .DataFrame , session , bq_connection , model_name
921951):
922952 gemini_text_generator_model = llm .GeminiTextGenerator (
923- model_name = "gemini-2.0-flash-001" ,
953+ model_name = model_name ,
924954 connection_name = bq_connection ,
925955 session = session ,
926956 )
0 commit comments