Skip to content
This repository was archived by the owner on Apr 1, 2026. It is now read-only.

Commit 7b80215

Browse files
committed
feat: Allow callable as a conditional or replacement input in DataFrame.where()
1 parent d9bc4a5 commit 7b80215

2 files changed

Lines changed: 53 additions & 0 deletions

File tree

bigframes/dataframe.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2763,6 +2763,12 @@ def where(self, cond, other=None):
27632763
"The dataframe.where() method does not support multi-column."
27642764
)
27652765

2766+
# Execute it with the DataFrame when cond or/and other is callable.
2767+
if callable(cond):
2768+
cond = cond(self)
2769+
if callable(other):
2770+
other = other(self)
2771+
27662772
aligned_block, (_, _) = self._block.join(cond._block, how="left")
27672773
# No left join is needed when 'other' is None or constant.
27682774
if isinstance(other, bigframes.dataframe.DataFrame):

tests/system/small/test_dataframe.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -514,6 +514,53 @@ def test_where_dataframe_cond_dataframe_other(
514514
pandas.testing.assert_frame_equal(bf_result, pd_result)
515515

516516

517+
def test_where_callable_cond_constant_other(scalars_df_index, scalars_pandas_df_index):
518+
# Condition is callable, other is a constant.
519+
columns = ["int64_col", "float64_col"]
520+
dataframe_bf = scalars_df_index[columns]
521+
dataframe_pd = scalars_pandas_df_index[columns]
522+
523+
cond = lambda x: x > 0
524+
other = 10
525+
526+
bf_result = dataframe_bf.where(cond, other).to_pandas()
527+
pd_result = dataframe_pd.where(cond, other)
528+
pandas.testing.assert_frame_equal(bf_result, pd_result)
529+
530+
531+
def test_where_dataframe_cond_callable_other(scalars_df_index, scalars_pandas_df_index):
532+
# Condition is a dataframe, other is callable.
533+
columns = ["int64_col", "float64_col"]
534+
dataframe_bf = scalars_df_index[columns]
535+
dataframe_pd = scalars_pandas_df_index[columns]
536+
537+
cond_bf = dataframe_bf > 0
538+
cond_pd = dataframe_pd > 0
539+
540+
def func(x):
541+
return x * 2
542+
543+
bf_result = dataframe_bf.where(cond_bf, func).to_pandas()
544+
pd_result = dataframe_pd.where(cond_pd, func)
545+
pandas.testing.assert_frame_equal(bf_result, pd_result)
546+
547+
548+
def test_where_callable_cond_callable_other(scalars_df_index, scalars_pandas_df_index):
549+
# Condition is callable, other is callable too.
550+
columns = ["int64_col", "float64_col"]
551+
dataframe_bf = scalars_df_index[columns]
552+
dataframe_pd = scalars_pandas_df_index[columns]
553+
554+
def func(x):
555+
return x["int64_col"] > 0
556+
557+
other = lambda x: x * 2
558+
559+
bf_result = dataframe_bf.where(func, other).to_pandas()
560+
pd_result = dataframe_pd.where(func, other)
561+
pandas.testing.assert_frame_equal(bf_result, pd_result)
562+
563+
517564
def test_drop_column(scalars_dfs):
518565
scalars_df, scalars_pandas_df = scalars_dfs
519566
col_name = "int64_col"

0 commit comments

Comments
 (0)