@@ -514,6 +514,53 @@ def test_where_dataframe_cond_dataframe_other(
514514 pandas .testing .assert_frame_equal (bf_result , pd_result )
515515
516516
517+ def test_where_callable_cond_constant_other (scalars_df_index , scalars_pandas_df_index ):
518+ # Condition is callable, other is a constant.
519+ columns = ["int64_col" , "float64_col" ]
520+ dataframe_bf = scalars_df_index [columns ]
521+ dataframe_pd = scalars_pandas_df_index [columns ]
522+
523+ cond = lambda x : x > 0
524+ other = 10
525+
526+ bf_result = dataframe_bf .where (cond , other ).to_pandas ()
527+ pd_result = dataframe_pd .where (cond , other )
528+ pandas .testing .assert_frame_equal (bf_result , pd_result )
529+
530+
531+ def test_where_dataframe_cond_callable_other (scalars_df_index , scalars_pandas_df_index ):
532+ # Condition is a dataframe, other is callable.
533+ columns = ["int64_col" , "float64_col" ]
534+ dataframe_bf = scalars_df_index [columns ]
535+ dataframe_pd = scalars_pandas_df_index [columns ]
536+
537+ cond_bf = dataframe_bf > 0
538+ cond_pd = dataframe_pd > 0
539+
540+ def func (x ):
541+ return x * 2
542+
543+ bf_result = dataframe_bf .where (cond_bf , func ).to_pandas ()
544+ pd_result = dataframe_pd .where (cond_pd , func )
545+ pandas .testing .assert_frame_equal (bf_result , pd_result )
546+
547+
548+ def test_where_callable_cond_callable_other (scalars_df_index , scalars_pandas_df_index ):
549+ # Condition is callable, other is callable too.
550+ columns = ["int64_col" , "float64_col" ]
551+ dataframe_bf = scalars_df_index [columns ]
552+ dataframe_pd = scalars_pandas_df_index [columns ]
553+
554+ def func (x ):
555+ return x ["int64_col" ] > 0
556+
557+ other = lambda x : x * 2
558+
559+ bf_result = dataframe_bf .where (func , other ).to_pandas ()
560+ pd_result = dataframe_pd .where (func , other )
561+ pandas .testing .assert_frame_equal (bf_result , pd_result )
562+
563+
517564def test_drop_column (scalars_dfs ):
518565 scalars_df , scalars_pandas_df = scalars_dfs
519566 col_name = "int64_col"
0 commit comments