@@ -965,7 +965,7 @@ def float_parser(row):
965965 )
966966
967967
968- def test_managed_function_df_where (session , dataset_id , scalars_dfs ):
968+ def test_managed_function_df_where_mask (session , dataset_id , scalars_dfs ):
969969 try :
970970
971971 # The return type has to be bool type for callable where condition.
@@ -987,15 +987,15 @@ def is_sum_positive(a, b):
987987 pd_int64_df = scalars_pandas_df [int64_cols ]
988988 pd_int64_df_filtered = pd_int64_df .dropna ()
989989
990- # Use callable condition in dataframe.where method.
990+ # Test callable condition in dataframe.where method.
991991 bf_result = bf_int64_df_filtered .where (is_sum_positive_mf ).to_pandas ()
992992 # Pandas doesn't support this case, so use the following as a workaround.
993993 pd_result = pd_int64_df_filtered .where (pd_int64_df_filtered .sum (axis = 1 ) > 0 )
994994
995995 # Ignore any dtype difference.
996996 pandas .testing .assert_frame_equal (bf_result , pd_result , check_dtype = False )
997997
998- # Make sure the read_gbq_function path works for this function .
998+ # Make sure the read_gbq_function path works for the dataframe.where method.
999999 is_sum_positive_ref = session .read_gbq_function (
10001000 function_name = is_sum_positive_mf .bigframes_bigquery_function
10011001 )
@@ -1012,14 +1012,27 @@ def is_sum_positive(a, b):
10121012 bf_result_gbq , pd_result_gbq , check_dtype = False
10131013 )
10141014
1015+ # Test callable condition in dataframe.mask method.
1016+ bf_result_gbq = bf_int64_df_filtered .mask (
1017+ is_sum_positive_ref , - bf_int64_df_filtered
1018+ ).to_pandas ()
1019+ pd_result_gbq = pd_int64_df_filtered .mask (
1020+ pd_int64_df_filtered .sum (axis = 1 ) > 0 , - pd_int64_df_filtered
1021+ )
1022+
1023+ # Ignore any dtype difference.
1024+ pandas .testing .assert_frame_equal (
1025+ bf_result_gbq , pd_result_gbq , check_dtype = False
1026+ )
1027+
10151028 finally :
10161029 # Clean up the gcp assets created for the managed function.
10171030 cleanup_function_assets (
10181031 is_sum_positive_mf , session .bqclient , ignore_failures = False
10191032 )
10201033
10211034
1022- def test_managed_function_df_where_series (session , dataset_id , scalars_dfs ):
1035+ def test_managed_function_df_where_mask_series (session , dataset_id , scalars_dfs ):
10231036 try :
10241037
10251038 # The return type has to be bool type for callable where condition.
@@ -1041,14 +1054,14 @@ def is_sum_positive_series(s):
10411054 pd_int64_df = scalars_pandas_df [int64_cols ]
10421055 pd_int64_df_filtered = pd_int64_df .dropna ()
10431056
1044- # Use callable condition in dataframe.where method.
1057+ # Test callable condition in dataframe.where method.
10451058 bf_result = bf_int64_df_filtered .where (is_sum_positive_series ).to_pandas ()
10461059 pd_result = pd_int64_df_filtered .where (is_sum_positive_series )
10471060
10481061 # Ignore any dtype difference.
10491062 pandas .testing .assert_frame_equal (bf_result , pd_result , check_dtype = False )
10501063
1051- # Make sure the read_gbq_function path works for this function .
1064+ # Make sure the read_gbq_function path works for the dataframe.where method.
10521065 is_sum_positive_series_ref = session .read_gbq_function (
10531066 function_name = is_sum_positive_series_mf .bigframes_bigquery_function ,
10541067 is_row_processor = True ,
@@ -1070,6 +1083,19 @@ def func_for_other(x):
10701083 bf_result_gbq , pd_result_gbq , check_dtype = False
10711084 )
10721085
1086+ # Test callable condition in dataframe.mask method.
1087+ bf_result_gbq = bf_int64_df_filtered .mask (
1088+ is_sum_positive_series_ref , func_for_other
1089+ ).to_pandas ()
1090+ pd_result_gbq = pd_int64_df_filtered .mask (
1091+ is_sum_positive_series , func_for_other
1092+ )
1093+
1094+ # Ignore any dtype difference.
1095+ pandas .testing .assert_frame_equal (
1096+ bf_result_gbq , pd_result_gbq , check_dtype = False
1097+ )
1098+
10731099 finally :
10741100 # Clean up the gcp assets created for the managed function.
10751101 cleanup_function_assets (
0 commit comments