Skip to content

Commit 395ba23

Browse files
committed
feat: Allow passing args to managed functions in DataFrame apply method
1 parent 9d4504b commit 395ba23

File tree

6 files changed

+166
-44
lines changed

6 files changed

+166
-44
lines changed

bigframes/dataframe.py

Lines changed: 34 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -4801,37 +4801,53 @@ def apply(self, func, *, axis=0, args: typing.Tuple = (), **kwargs):
48014801
)
48024802

48034803
# Apply the function
4804-
result_series = rows_as_json_series._apply_unary_op(
4805-
ops.RemoteFunctionOp(function_def=func.udf_def, apply_on_null=True)
4806-
)
4804+
if args:
4805+
result_series = rows_as_json_series._apply_nary_op(
4806+
ops.NaryRemoteFunctionOp(function_def=func.udf_def),
4807+
list(args),
4808+
)
4809+
else:
4810+
result_series = rows_as_json_series._apply_unary_op(
4811+
ops.RemoteFunctionOp(
4812+
function_def=func.udf_def, apply_on_null=True
4813+
)
4814+
)
48074815
else:
48084816
# This is a special case where we are providing not-pandas-like
48094817
# extension. If the bigquery function can take one or more
4810-
# params then we assume that here the user intention is to use
4811-
# the column values of the dataframe as arguments to the
4812-
# function. For this to work the following condition must be
4813-
# true:
4814-
# 1. The number or input params in the function must be same
4815-
# as the number of columns in the dataframe
4818+
# params (excluding the args) then we assume that here the user's
4819+
# intention is to use the column values of the dataframe as
4820+
# arguments to the function. For this to work the following
4821+
# conditions must be true:
4822+
# 1. The number of input params (excluding the args) in the
4823+
# function must be the same as the number of columns in the
4824+
# dataframe.
48164825
# 2. The dtypes of the columns in the dataframe must be
4817-
# compatible with the data types of the input params
4826+
# compatible with the data types of the input params.
48184827
# 3. The order of the columns in the dataframe must correspond
4819-
# to the order of the input params in the function
4828+
# to the order of the input params in the function.
48204829
udf_input_dtypes = func.udf_def.signature.bf_input_types
4821-
if len(udf_input_dtypes) != len(self.columns):
4830+
if len(udf_input_dtypes) != len(self.columns) + len(args):
48224831
raise ValueError(
4823-
f"BigFrames BigQuery function takes {len(udf_input_dtypes)}"
4824-
f" arguments but DataFrame has {len(self.columns)} columns."
4832+
f"Column count mismatch: BigFrames BigQuery function"
4833+
f" expected {len(udf_input_dtypes) - len(args)} columns"
4834+
f" from DataFrame but received {len(self.columns)}."
48254835
)
4826-
if udf_input_dtypes != tuple(self.dtypes.to_list()):
4836+
end_slice = -len(args) if args else None
4837+
if udf_input_dtypes[:end_slice] != tuple(self.dtypes.to_list()):
48274838
raise ValueError(
4828-
f"BigFrames BigQuery function takes arguments of types "
4829-
f"{udf_input_dtypes} but DataFrame dtypes are {tuple(self.dtypes)}."
4839+
f"Data type mismatch: BigFrames BigQuery function takes"
4840+
f" arguments of types {udf_input_dtypes} but DataFrame"
4841+
f" dtypes are {tuple(self.dtypes)}."
48304842
)
48314843

48324844
series_list = [self[col] for col in self.columns]
4845+
if args:
4846+
op_list = series_list[1:] + list(args)
4847+
else:
4848+
op_list = series_list[1:]
48334849
result_series = series_list[0]._apply_nary_op(
4834-
ops.NaryRemoteFunctionOp(function_def=func.udf_def), series_list[1:]
4850+
ops.NaryRemoteFunctionOp(function_def=func.udf_def), op_list
48354851
)
48364852
result_series.name = None
48374853

bigframes/functions/_function_session.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -959,11 +959,15 @@ def _convert_row_processor_sig(
959959
) -> Optional[inspect.Signature]:
960960
import bigframes.series as bf_series
961961

962-
if len(signature.parameters) == 1:
963-
only_param = next(iter(signature.parameters.values()))
964-
param_type = only_param.annotation
965-
if (param_type == bf_series.Series) or (param_type == pandas.Series):
966-
msg = bfe.format_message("input_types=Series is in preview.")
967-
warnings.warn(msg, stacklevel=1, category=bfe.PreviewWarning)
968-
return signature.replace(parameters=[only_param.replace(annotation=str)])
962+
first_param = next(iter(signature.parameters.values()))
963+
param_type = first_param.annotation
964+
if (param_type == bf_series.Series) or (param_type == pandas.Series):
965+
msg = bfe.format_message("input_types=Series is in preview.")
966+
warnings.warn(msg, stacklevel=1, category=bfe.PreviewWarning)
967+
return signature.replace(
968+
parameters=[
969+
p.replace(annotation=str) if i == 0 else p
970+
for i, p in enumerate(signature.parameters.values())
971+
]
972+
)
969973
return None

bigframes/functions/function.py

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -178,13 +178,6 @@ def read_gbq_function(
178178
ValueError, f"Unknown function '{routine_ref}'."
179179
)
180180

181-
if is_row_processor and len(routine.arguments) > 1:
182-
raise bf_formatting.create_exception_with_feedback_link(
183-
ValueError,
184-
"A multi-input function cannot be a row processor. A row processor function "
185-
"takes in a single input representing the row.",
186-
)
187-
188181
if is_row_processor:
189182
return _try_import_row_routine(routine, session)
190183
else:

bigframes/functions/function_template.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -332,6 +332,28 @@ def generate_managed_function_code(
332332
f"""def bigframes_handler(str_arg):
333333
return {udf_name}({get_pd_series.__name__}(str_arg))"""
334334
)
335+
336+
sig = inspect.signature(def_)
337+
params = list(sig.parameters.values())
338+
additional_params = params[1:]
339+
340+
# Build the parameter list for the new handler function definition.
341+
# e.g., "str_arg, y: bool, z"
342+
handler_def_parts = ["str_arg"]
343+
handler_def_parts.extend(str(p) for p in additional_params)
344+
handler_def_str = ", ".join(handler_def_parts)
345+
346+
# Build the argument list for the call to the original UDF.
347+
# e.g., "get_pd_series(str_arg), y, z"
348+
udf_call_parts = [f"{get_pd_series.__name__}(str_arg)"]
349+
udf_call_parts.extend(p.name for p in additional_params)
350+
udf_call_str = ", ".join(udf_call_parts)
351+
352+
bigframes_handler_code = textwrap.dedent(
353+
f"""def bigframes_handler({handler_def_str}):
354+
return {udf_name}({udf_call_str})"""
355+
)
356+
335357
else:
336358
udf_code = ""
337359
bigframes_handler_code = textwrap.dedent(

tests/system/large/functions/test_managed_function.py

Lines changed: 90 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -468,20 +468,20 @@ def foo(x, y, z):
468468
# Fails to apply on dataframe with incompatible number of columns.
469469
with pytest.raises(
470470
ValueError,
471-
match="^BigFrames BigQuery function takes 3 arguments but DataFrame has 2 columns\\.$",
471+
match="^Column count mismatch: BigFrames BigQuery function expected 3 columns from DataFrame but received 2\\.$",
472472
):
473473
bf_df[["Id", "Age"]].apply(foo, axis=1)
474474

475475
with pytest.raises(
476476
ValueError,
477-
match="^BigFrames BigQuery function takes 3 arguments but DataFrame has 4 columns\\.$",
477+
match="^Column count mismatch: BigFrames BigQuery function expected 3 columns from DataFrame but received 4\\.$",
478478
):
479479
bf_df.assign(Country="lalaland").apply(foo, axis=1)
480480

481481
# Fails to apply on dataframe with incompatible column datatypes.
482482
with pytest.raises(
483483
ValueError,
484-
match="^BigFrames BigQuery function takes arguments of types .* but DataFrame dtypes are .*",
484+
match="^Data type mismatch: BigFrames BigQuery function takes arguments of types .* but DataFrame dtypes are .*",
485485
):
486486
bf_df.assign(Age=bf_df["Age"].astype("Int64")).apply(foo, axis=1)
487487

@@ -965,6 +965,93 @@ def float_parser(row):
965965
)
966966

967967

968+
def test_managed_function_df_apply_axis_1_args(session, dataset_id, scalars_dfs):
969+
columns = ["int64_col", "int64_too"]
970+
scalars_df, scalars_pandas_df = scalars_dfs
971+
972+
try:
973+
974+
def the_sum(s1, s2, x):
975+
return s1 + s2 + x
976+
977+
the_sum_mf = session.udf(
978+
input_types=[int, int, int],
979+
output_type=int,
980+
dataset=dataset_id,
981+
name=prefixer.create_prefix(),
982+
)(the_sum)
983+
984+
args1 = (1,)
985+
bf_result = (
986+
scalars_df[columns]
987+
.dropna()
988+
.apply(the_sum_mf, axis=1, args=args1)
989+
.to_pandas()
990+
)
991+
pd_result = scalars_pandas_df[columns].dropna().apply(sum, axis=1, args=args1)
992+
993+
pandas.testing.assert_series_equal(pd_result, bf_result, check_dtype=False)
994+
995+
finally:
996+
# clean up the gcp assets created for the managed function.
997+
cleanup_function_assets(the_sum_mf, session.bqclient, ignore_failures=False)
998+
999+
1000+
def test_managed_function_df_apply_axis_1_series_args(session, dataset_id, scalars_dfs):
1001+
columns = ["int64_col", "float64_col"]
1002+
scalars_df, scalars_pandas_df = scalars_dfs
1003+
1004+
try:
1005+
1006+
def analyze(s, x, y):
1007+
value = f"value is {s['int64_col']} and {s['float64_col']}"
1008+
if x:
1009+
return f"{value}, x is True!"
1010+
if y > 0:
1011+
return f"{value}, x is False, y is positive!"
1012+
return f"{value}, x is False, y is non-positive!"
1013+
1014+
analyze_mf = session.udf(
1015+
input_types=[bigframes.series.Series, bool, float],
1016+
output_type=str,
1017+
dataset=dataset_id,
1018+
name=prefixer.create_prefix(),
1019+
)(analyze)
1020+
1021+
args1 = (True, 10.0)
1022+
bf_result = (
1023+
scalars_df[columns]
1024+
.dropna()
1025+
.apply(analyze_mf, axis=1, args=args1)
1026+
.to_pandas()
1027+
)
1028+
pd_result = (
1029+
scalars_pandas_df[columns].dropna().apply(analyze, axis=1, args=args1)
1030+
)
1031+
1032+
pandas.testing.assert_series_equal(pd_result, bf_result, check_dtype=False)
1033+
1034+
args2 = (False, -10.0)
1035+
analyze_mf_ref = session.read_gbq_function(
1036+
analyze_mf.bigframes_bigquery_function, is_row_processor=True
1037+
)
1038+
bf_result = (
1039+
scalars_df[columns]
1040+
.dropna()
1041+
.apply(analyze_mf_ref, axis=1, args=args2)
1042+
.to_pandas()
1043+
)
1044+
pd_result = (
1045+
scalars_pandas_df[columns].dropna().apply(analyze, axis=1, args=args2)
1046+
)
1047+
1048+
pandas.testing.assert_series_equal(pd_result, bf_result, check_dtype=False)
1049+
1050+
finally:
1051+
# clean up the gcp assets created for the managed function.
1052+
cleanup_function_assets(analyze_mf, session.bqclient, ignore_failures=False)
1053+
1054+
9681055
def test_managed_function_df_where_mask(session, dataset_id, scalars_dfs):
9691056
try:
9701057

tests/system/large/functions/test_remote_function.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2200,19 +2200,19 @@ def foo(x, y, z):
22002200
# Fails to apply on dataframe with incompatible number of columns
22012201
with pytest.raises(
22022202
ValueError,
2203-
match="^BigFrames BigQuery function takes 3 arguments but DataFrame has 2 columns\\.$",
2203+
match="^Column count mismatch: BigFrames BigQuery function expected 3 columns from DataFrame but received 2\\.$",
22042204
):
22052205
bf_df[["Id", "Age"]].apply(foo, axis=1)
22062206
with pytest.raises(
22072207
ValueError,
2208-
match="^BigFrames BigQuery function takes 3 arguments but DataFrame has 4 columns\\.$",
2208+
match="^Column count mismatch: BigFrames BigQuery function expected 3 columns from DataFrame but received 4\\.$",
22092209
):
22102210
bf_df.assign(Country="lalaland").apply(foo, axis=1)
22112211

22122212
# Fails to apply on dataframe with incompatible column datatypes
22132213
with pytest.raises(
22142214
ValueError,
2215-
match="^BigFrames BigQuery function takes arguments of types .* but DataFrame dtypes are .*",
2215+
match="^Data type mismatch: BigFrames BigQuery function takes arguments of types .* but DataFrame dtypes are .*",
22162216
):
22172217
bf_df.assign(Age=bf_df["Age"].astype("Int64")).apply(foo, axis=1)
22182218

@@ -2284,19 +2284,19 @@ def foo(x, y, z):
22842284
# Fails to apply on dataframe with incompatible number of columns
22852285
with pytest.raises(
22862286
ValueError,
2287-
match="^BigFrames BigQuery function takes 3 arguments but DataFrame has 2 columns\\.$",
2287+
match="^Column count mismatch: BigFrames BigQuery function expected 3 columns from DataFrame but received 2\\.$",
22882288
):
22892289
bf_df[["Id", "Age"]].apply(foo, axis=1)
22902290
with pytest.raises(
22912291
ValueError,
2292-
match="^BigFrames BigQuery function takes 3 arguments but DataFrame has 4 columns\\.$",
2292+
match="^Column count mismatch: BigFrames BigQuery function expected 3 columns from DataFrame but received 4\\.$",
22932293
):
22942294
bf_df.assign(Country="lalaland").apply(foo, axis=1)
22952295

22962296
# Fails to apply on dataframe with incompatible column datatypes
22972297
with pytest.raises(
22982298
ValueError,
2299-
match="^BigFrames BigQuery function takes arguments of types .* but DataFrame dtypes are .*",
2299+
match="^Data type mismatch: BigFrames BigQuery function takes arguments of types .* but DataFrame dtypes are .*",
23002300
):
23012301
bf_df.assign(Age=bf_df["Age"].astype("Int64")).apply(foo, axis=1)
23022302

@@ -2358,19 +2358,19 @@ def foo(x):
23582358
# Fails to apply on dataframe with incompatible number of columns
23592359
with pytest.raises(
23602360
ValueError,
2361-
match="^BigFrames BigQuery function takes 1 arguments but DataFrame has 0 columns\\.$",
2361+
match="^Column count mismatch: BigFrames BigQuery function expected 1 columns from DataFrame but received 0\\.$",
23622362
):
23632363
bf_df[[]].apply(foo, axis=1)
23642364
with pytest.raises(
23652365
ValueError,
2366-
match="^BigFrames BigQuery function takes 1 arguments but DataFrame has 2 columns\\.$",
2366+
match="^Column count mismatch: BigFrames BigQuery function expected 1 columns from DataFrame but received 2\\.$",
23672367
):
23682368
bf_df.assign(Country="lalaland").apply(foo, axis=1)
23692369

23702370
# Fails to apply on dataframe with incompatible column datatypes
23712371
with pytest.raises(
23722372
ValueError,
2373-
match="^BigFrames BigQuery function takes arguments of types .* but DataFrame dtypes are .*",
2373+
match="^Data type mismatch: BigFrames BigQuery function takes arguments of types .* but DataFrame dtypes are .*",
23742374
):
23752375
bf_df.assign(Id=bf_df["Id"].astype("Float64")).apply(foo, axis=1)
23762376

0 commit comments

Comments
 (0)