@@ -468,20 +468,20 @@ def foo(x, y, z):
468468 # Fails to apply on dataframe with incompatible number of columns.
469469 with pytest .raises (
470470 ValueError ,
471- match = "^BigFrames BigQuery function takes 3 arguments but DataFrame has 2 columns \\ .$" ,
471+ match = "^Column count mismatch: BigFrames BigQuery function expected 3 columns from DataFrame but received 2 \\ .$" ,
472472 ):
473473 bf_df [["Id" , "Age" ]].apply (foo , axis = 1 )
474474
475475 with pytest .raises (
476476 ValueError ,
477- match = "^BigFrames BigQuery function takes 3 arguments but DataFrame has 4 columns \\ .$" ,
477+ match = "^Column count mismatch: BigFrames BigQuery function expected 3 columns from DataFrame but received 4 \\ .$" ,
478478 ):
479479 bf_df .assign (Country = "lalaland" ).apply (foo , axis = 1 )
480480
481481 # Fails to apply on dataframe with incompatible column datatypes.
482482 with pytest .raises (
483483 ValueError ,
484- match = "^BigFrames BigQuery function takes arguments of types .* but DataFrame dtypes are .*" ,
484+ match = "^Data type mismatch: BigFrames BigQuery function takes arguments of types .* but DataFrame dtypes are .*" ,
485485 ):
486486 bf_df .assign (Age = bf_df ["Age" ].astype ("Int64" )).apply (foo , axis = 1 )
487487
@@ -965,6 +965,93 @@ def float_parser(row):
965965 )
966966
967967
968+ def test_managed_function_df_apply_axis_1_args (session , dataset_id , scalars_dfs ):
969+ columns = ["int64_col" , "int64_too" ]
970+ scalars_df , scalars_pandas_df = scalars_dfs
971+
972+ try :
973+
974+ def the_sum (s1 , s2 , x ):
975+ return s1 + s2 + x
976+
977+ the_sum_mf = session .udf (
978+ input_types = [int , int , int ],
979+ output_type = int ,
980+ dataset = dataset_id ,
981+ name = prefixer .create_prefix (),
982+ )(the_sum )
983+
984+ args1 = (1 ,)
985+ bf_result = (
986+ scalars_df [columns ]
987+ .dropna ()
988+ .apply (the_sum_mf , axis = 1 , args = args1 )
989+ .to_pandas ()
990+ )
991+ pd_result = scalars_pandas_df [columns ].dropna ().apply (sum , axis = 1 , args = args1 )
992+
993+ pandas .testing .assert_series_equal (pd_result , bf_result , check_dtype = False )
994+
995+ finally :
996+ # clean up the gcp assets created for the managed function.
997+ cleanup_function_assets (the_sum_mf , session .bqclient , ignore_failures = False )
998+
999+
1000+ def test_managed_function_df_apply_axis_1_series_args (session , dataset_id , scalars_dfs ):
1001+ columns = ["int64_col" , "float64_col" ]
1002+ scalars_df , scalars_pandas_df = scalars_dfs
1003+
1004+ try :
1005+
1006+ def analyze (s , x , y ):
1007+ value = f"value is { s ['int64_col' ]} and { s ['float64_col' ]} "
1008+ if x :
1009+ return f"{ value } , x is True!"
1010+ if y > 0 :
1011+ return f"{ value } , x is False, y is positive!"
1012+ return f"{ value } , x is False, y is non-positive!"
1013+
1014+ analyze_mf = session .udf (
1015+ input_types = [bigframes .series .Series , bool , float ],
1016+ output_type = str ,
1017+ dataset = dataset_id ,
1018+ name = prefixer .create_prefix (),
1019+ )(analyze )
1020+
1021+ args1 = (True , 10.0 )
1022+ bf_result = (
1023+ scalars_df [columns ]
1024+ .dropna ()
1025+ .apply (analyze_mf , axis = 1 , args = args1 )
1026+ .to_pandas ()
1027+ )
1028+ pd_result = (
1029+ scalars_pandas_df [columns ].dropna ().apply (analyze , axis = 1 , args = args1 )
1030+ )
1031+
1032+ pandas .testing .assert_series_equal (pd_result , bf_result , check_dtype = False )
1033+
1034+ args2 = (False , - 10.0 )
1035+ analyze_mf_ref = session .read_gbq_function (
1036+ analyze_mf .bigframes_bigquery_function , is_row_processor = True
1037+ )
1038+ bf_result = (
1039+ scalars_df [columns ]
1040+ .dropna ()
1041+ .apply (analyze_mf_ref , axis = 1 , args = args2 )
1042+ .to_pandas ()
1043+ )
1044+ pd_result = (
1045+ scalars_pandas_df [columns ].dropna ().apply (analyze , axis = 1 , args = args2 )
1046+ )
1047+
1048+ pandas .testing .assert_series_equal (pd_result , bf_result , check_dtype = False )
1049+
1050+ finally :
1051+ # clean up the gcp assets created for the managed function.
1052+ cleanup_function_assets (analyze_mf , session .bqclient , ignore_failures = False )
1053+
1054+
9681055def test_managed_function_df_where_mask (session , dataset_id , scalars_dfs ):
9691056 try :
9701057
0 commit comments