Skip to content
This repository was archived by the owner on Apr 1, 2026. It is now read-only.

Commit 2301b97

Browse files
authored
Merge branch 'main' into main_chelsealin_durations
2 parents 148bb20 + 44c1ec4 commit 2301b97

15 files changed

Lines changed: 346 additions & 33 deletions

File tree

bigframes/core/compile/polars/compiler.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -198,6 +198,10 @@ def _(self, op: ops.ScalarOp, l_input: pl.Expr, r_input: pl.Expr) -> pl.Expr:
198198
def _(self, op: ops.ScalarOp, l_input: pl.Expr, r_input: pl.Expr) -> pl.Expr:
199199
return l_input | r_input
200200

201+
@compile_op.register(bool_ops.XorOp)
202+
def _(self, op: ops.ScalarOp, l_input: pl.Expr, r_input: pl.Expr) -> pl.Expr:
203+
return l_input ^ r_input
204+
201205
@compile_op.register(num_ops.AddOp)
202206
def _(self, op: ops.ScalarOp, l_input: pl.Expr, r_input: pl.Expr) -> pl.Expr:
203207
return l_input + r_input

bigframes/core/compile/sqlglot/aggregations/unary_compiler.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,15 @@ def compile(
3737
return UNARY_OP_REGISTRATION[op](op, column, window=window)
3838

3939

40+
@UNARY_OP_REGISTRATION.register(agg_ops.CountOp)
41+
def _(
42+
op: agg_ops.CountOp,
43+
column: typed_expr.TypedExpr,
44+
window: typing.Optional[window_spec.WindowSpec] = None,
45+
) -> sge.Expression:
46+
return apply_window_if_present(sge.func("COUNT", column.expr), window)
47+
48+
4049
@UNARY_OP_REGISTRATION.register(agg_ops.SumOp)
4150
def _(
4251
op: agg_ops.SumOp,

bigframes/core/compile/sqlglot/compiler.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -336,6 +336,9 @@ def compile_window(
336336
this=is_observation_expr, expression=expr
337337
)
338338
is_observation = ir._cast(is_observation_expr, "INT64")
339+
observation_count = windows.apply_window_if_present(
340+
sge.func("SUM", is_observation), window_spec
341+
)
339342
else:
340343
# Operations like count treat even NULLs as valid observations
341344
# for the sake of min_periods notnull is just used to convert
@@ -344,10 +347,10 @@ def compile_window(
344347
sge.Not(this=sge.Is(this=inputs[0], expression=sge.Null())),
345348
"INT64",
346349
)
350+
observation_count = windows.apply_window_if_present(
351+
sge.func("COUNT", is_observation), window_spec
352+
)
347353

348-
observation_count = windows.apply_window_if_present(
349-
sge.func("SUM", is_observation), window_spec
350-
)
351354
clauses.append(
352355
(
353356
observation_count < sge.convert(window_spec.min_periods),

bigframes/dataframe.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2797,10 +2797,17 @@ def where(self, cond, other=None):
27972797
)
27982798

27992799
# Execute it with the DataFrame when cond or/and other is callable.
2800+
# It can be either a plain python function or remote/managed function.
28002801
if callable(cond):
2801-
cond = cond(self)
2802+
if hasattr(cond, "bigframes_bigquery_function"):
2803+
cond = self.apply(cond, axis=1)
2804+
else:
2805+
cond = cond(self)
28022806
if callable(other):
2803-
other = other(self)
2807+
if hasattr(other, "bigframes_bigquery_function"):
2808+
other = self.apply(other, axis=1)
2809+
else:
2810+
other = other(self)
28042811

28052812
aligned_block, (_, _) = self._block.join(cond._block, how="left")
28062813
# No left join is needed when 'other' is None or constant.
@@ -2813,7 +2820,7 @@ def where(self, cond, other=None):
28132820
labels = aligned_block.column_labels[:self_len]
28142821
self_col = {x: ex.deref(y) for x, y in zip(labels, ids)}
28152822

2816-
if isinstance(cond, bigframes.series.Series) and cond.name in self_col:
2823+
if isinstance(cond, bigframes.series.Series):
28172824
# This is when 'cond' is a valid series.
28182825
y = aligned_block.value_columns[self_len]
28192826
cond_col = {x: ex.deref(y) for x in self_col.keys()}

bigframes/display/table_widget.js

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,9 @@ function render({ model, el }) {
8888
const totalPages = Math.ceil(rowCount / pageSize);
8989

9090
rowCountLabel.textContent = `${rowCount.toLocaleString()} total rows`;
91-
paginationLabel.textContent = `Page ${currentPage + 1} of ${totalPages || 1}`;
91+
paginationLabel.textContent = `Page ${(
92+
currentPage + 1
93+
).toLocaleString()} of ${(totalPages || 1).toLocaleString()}`;
9294
prevPage.disabled = currentPage === 0;
9395
nextPage.disabled = currentPage >= totalPages - 1;
9496
pageSizeSelect.value = pageSize;

bigframes/functions/_function_session.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -555,10 +555,6 @@ def wrapper(func):
555555
warnings.warn(msg, category=bfe.FunctionConflictTypeHintWarning)
556556
py_sig = py_sig.replace(return_annotation=output_type)
557557

558-
# Try to get input types via type annotations.
559-
560-
# The function will actually be receiving a pandas Series, but allow both
561-
# BigQuery DataFrames and pandas object types for compatibility.
562558
# The function will actually be receiving a pandas Series, but allow
563559
# both BigQuery DataFrames and pandas object types for compatibility.
564560
is_row_processor = False

bigframes/session/polars_executor.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
from bigframes.core import array_value, bigframe_node, expression, local_data, nodes
2222
import bigframes.operations
2323
from bigframes.operations import aggregations as agg_ops
24-
from bigframes.operations import comparison_ops, generic_ops, numeric_ops
24+
from bigframes.operations import bool_ops, comparison_ops, generic_ops, numeric_ops
2525
from bigframes.session import executor, semi_executor
2626

2727
if TYPE_CHECKING:
@@ -44,6 +44,9 @@
4444
)
4545

4646
_COMPATIBLE_SCALAR_OPS = (
47+
bool_ops.AndOp,
48+
bool_ops.OrOp,
49+
bool_ops.XorOp,
4750
comparison_ops.EqOp,
4851
comparison_ops.EqNullsMatchOp,
4952
comparison_ops.NeOp,
@@ -63,6 +66,8 @@
6366
generic_ops.FillNaOp,
6467
generic_ops.CaseWhenOp,
6568
generic_ops.InvertOp,
69+
generic_ops.IsNullOp,
70+
generic_ops.NotNullOp,
6671
)
6772
_COMPATIBLE_AGG_OPS = (
6873
agg_ops.SizeOp,

tests/system/large/functions/test_managed_function.py

Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -963,3 +963,115 @@ def float_parser(row):
963963
cleanup_function_assets(
964964
float_parser_mf, session.bqclient, ignore_failures=False
965965
)
966+
967+
968+
def test_managed_function_df_where(session, dataset_id, scalars_dfs):
969+
try:
970+
971+
# The return type has to be bool for a callable `where` condition.
972+
def is_sum_positive(a, b):
973+
return a + b > 0
974+
975+
is_sum_positive_mf = session.udf(
976+
input_types=[int, int],
977+
output_type=bool,
978+
dataset=dataset_id,
979+
name=prefixer.create_prefix(),
980+
)(is_sum_positive)
981+
982+
scalars_df, scalars_pandas_df = scalars_dfs
983+
int64_cols = ["int64_col", "int64_too"]
984+
985+
bf_int64_df = scalars_df[int64_cols]
986+
bf_int64_df_filtered = bf_int64_df.dropna()
987+
pd_int64_df = scalars_pandas_df[int64_cols]
988+
pd_int64_df_filtered = pd_int64_df.dropna()
989+
990+
# Use callable condition in dataframe.where method.
991+
bf_result = bf_int64_df_filtered.where(is_sum_positive_mf).to_pandas()
992+
# Pandas doesn't support such a case; use the following as a workaround.
993+
pd_result = pd_int64_df_filtered.where(pd_int64_df_filtered.sum(axis=1) > 0)
994+
995+
# Ignore any dtype difference.
996+
pandas.testing.assert_frame_equal(bf_result, pd_result, check_dtype=False)
997+
998+
# Make sure the read_gbq_function path works for this function.
999+
is_sum_positive_ref = session.read_gbq_function(
1000+
function_name=is_sum_positive_mf.bigframes_bigquery_function
1001+
)
1002+
1003+
bf_result_gbq = bf_int64_df_filtered.where(
1004+
is_sum_positive_ref, -bf_int64_df_filtered
1005+
).to_pandas()
1006+
pd_result_gbq = pd_int64_df_filtered.where(
1007+
pd_int64_df_filtered.sum(axis=1) > 0, -pd_int64_df_filtered
1008+
)
1009+
1010+
# Ignore any dtype difference.
1011+
pandas.testing.assert_frame_equal(
1012+
bf_result_gbq, pd_result_gbq, check_dtype=False
1013+
)
1014+
1015+
finally:
1016+
# Clean up the gcp assets created for the managed function.
1017+
cleanup_function_assets(
1018+
is_sum_positive_mf, session.bqclient, ignore_failures=False
1019+
)
1020+
1021+
1022+
def test_managed_function_df_where_series(session, dataset_id, scalars_dfs):
1023+
try:
1024+
1025+
# The return type has to be bool type for callable where condition.
1026+
def is_sum_positive_series(s):
1027+
return s["int64_col"] + s["int64_too"] > 0
1028+
1029+
is_sum_positive_series_mf = session.udf(
1030+
input_types=bigframes.series.Series,
1031+
output_type=bool,
1032+
dataset=dataset_id,
1033+
name=prefixer.create_prefix(),
1034+
)(is_sum_positive_series)
1035+
1036+
scalars_df, scalars_pandas_df = scalars_dfs
1037+
int64_cols = ["int64_col", "int64_too"]
1038+
1039+
bf_int64_df = scalars_df[int64_cols]
1040+
bf_int64_df_filtered = bf_int64_df.dropna()
1041+
pd_int64_df = scalars_pandas_df[int64_cols]
1042+
pd_int64_df_filtered = pd_int64_df.dropna()
1043+
1044+
# Use callable condition in dataframe.where method.
1045+
bf_result = bf_int64_df_filtered.where(is_sum_positive_series).to_pandas()
1046+
pd_result = pd_int64_df_filtered.where(is_sum_positive_series)
1047+
1048+
# Ignore any dtype difference.
1049+
pandas.testing.assert_frame_equal(bf_result, pd_result, check_dtype=False)
1050+
1051+
# Make sure the read_gbq_function path works for this function.
1052+
is_sum_positive_series_ref = session.read_gbq_function(
1053+
function_name=is_sum_positive_series_mf.bigframes_bigquery_function,
1054+
is_row_processor=True,
1055+
)
1056+
1057+
# This is for callable `other` arg in dataframe.where method.
1058+
def func_for_other(x):
1059+
return -x
1060+
1061+
bf_result_gbq = bf_int64_df_filtered.where(
1062+
is_sum_positive_series_ref, func_for_other
1063+
).to_pandas()
1064+
pd_result_gbq = pd_int64_df_filtered.where(
1065+
is_sum_positive_series, func_for_other
1066+
)
1067+
1068+
# Ignore any dtype difference.
1069+
pandas.testing.assert_frame_equal(
1070+
bf_result_gbq, pd_result_gbq, check_dtype=False
1071+
)
1072+
1073+
finally:
1074+
# Clean up the gcp assets created for the managed function.
1075+
cleanup_function_assets(
1076+
is_sum_positive_series_mf, session.bqclient, ignore_failures=False
1077+
)

tests/system/large/functions/test_remote_function.py

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2847,3 +2847,86 @@ def foo(x: int) -> int:
28472847
finally:
28482848
# clean up the gcp assets created for the remote function
28492849
cleanup_function_assets(foo, session.bqclient, session.cloudfunctionsclient)
2850+
2851+
2852+
@pytest.mark.flaky(retries=2, delay=120)
2853+
def test_remote_function_df_where(session, dataset_id, scalars_dfs):
2854+
try:
2855+
2856+
# The return type has to be bool for a callable `where` condition.
2857+
def is_sum_positive(a, b):
2858+
return a + b > 0
2859+
2860+
is_sum_positive_mf = session.remote_function(
2861+
input_types=[int, int],
2862+
output_type=bool,
2863+
dataset=dataset_id,
2864+
reuse=False,
2865+
cloud_function_service_account="default",
2866+
)(is_sum_positive)
2867+
2868+
scalars_df, scalars_pandas_df = scalars_dfs
2869+
int64_cols = ["int64_col", "int64_too"]
2870+
2871+
bf_int64_df = scalars_df[int64_cols]
2872+
bf_int64_df_filtered = bf_int64_df.dropna()
2873+
pd_int64_df = scalars_pandas_df[int64_cols]
2874+
pd_int64_df_filtered = pd_int64_df.dropna()
2875+
2876+
# Use callable condition in dataframe.where method.
2877+
bf_result = bf_int64_df_filtered.where(is_sum_positive_mf, 0).to_pandas()
2878+
# Pandas doesn't support such a case; use the following as a workaround.
2879+
pd_result = pd_int64_df_filtered.where(pd_int64_df_filtered.sum(axis=1) > 0, 0)
2880+
2881+
# Ignore any dtype difference.
2882+
pandas.testing.assert_frame_equal(bf_result, pd_result, check_dtype=False)
2883+
2884+
finally:
2885+
# Clean up the gcp assets created for the remote function.
2886+
cleanup_function_assets(
2887+
is_sum_positive_mf, session.bqclient, ignore_failures=False
2888+
)
2889+
2890+
2891+
@pytest.mark.flaky(retries=2, delay=120)
2892+
def test_remote_function_df_where_series(session, dataset_id, scalars_dfs):
2893+
try:
2894+
2895+
# The return type has to be bool type for callable where condition.
2896+
def is_sum_positive_series(s):
2897+
return s["int64_col"] + s["int64_too"] > 0
2898+
2899+
is_sum_positive_series_mf = session.remote_function(
2900+
input_types=bigframes.series.Series,
2901+
output_type=bool,
2902+
dataset=dataset_id,
2903+
reuse=False,
2904+
cloud_function_service_account="default",
2905+
)(is_sum_positive_series)
2906+
2907+
scalars_df, scalars_pandas_df = scalars_dfs
2908+
int64_cols = ["int64_col", "int64_too"]
2909+
2910+
bf_int64_df = scalars_df[int64_cols]
2911+
bf_int64_df_filtered = bf_int64_df.dropna()
2912+
pd_int64_df = scalars_pandas_df[int64_cols]
2913+
pd_int64_df_filtered = pd_int64_df.dropna()
2914+
2915+
# This is for callable `other` arg in dataframe.where method.
2916+
def func_for_other(x):
2917+
return -x
2918+
2919+
# Use callable condition in dataframe.where method.
2920+
bf_result = bf_int64_df_filtered.where(
2921+
is_sum_positive_series, func_for_other
2922+
).to_pandas()
2923+
pd_result = pd_int64_df_filtered.where(is_sum_positive_series, func_for_other)
2924+
2925+
# Ignore any dtype difference.
2926+
pandas.testing.assert_frame_equal(bf_result, pd_result, check_dtype=False)
2927+
2928+
finally:
2929+
# Clean up the gcp assets created for the remote function.
2930+
cleanup_function_assets(
2931+
is_sum_positive_series_mf, session.bqclient, ignore_failures=False
2932+
)
Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import itertools
16+
17+
import pytest
18+
19+
from bigframes.core import array_value
20+
import bigframes.operations as ops
21+
from bigframes.session import polars_executor
22+
from bigframes.testing.engine_utils import assert_equivalence_execution
23+
24+
pytest.importorskip("polars")
25+
26+
# Polars is used as the reference since it's fast and local. Generally though, prefer the gbq engine where they disagree.
27+
REFERENCE_ENGINE = polars_executor.PolarsExecutor()
28+
29+
30+
def apply_op_pairwise(
31+
array: array_value.ArrayValue, op: ops.BinaryOp, excluded_cols=[]
32+
) -> array_value.ArrayValue:
33+
exprs = []
34+
for l_arg, r_arg in itertools.permutations(array.column_ids, 2):
35+
if (l_arg in excluded_cols) or (r_arg in excluded_cols):
36+
continue
37+
try:
38+
_ = op.output_type(
39+
array.get_column_type(l_arg), array.get_column_type(r_arg)
40+
)
41+
exprs.append(op.as_expr(l_arg, r_arg))
42+
except TypeError:
43+
continue
44+
assert len(exprs) > 0
45+
new_arr, _ = array.compute_values(exprs)
46+
return new_arr
47+
48+
49+
@pytest.mark.parametrize("engine", ["polars", "bq"], indirect=True)
50+
@pytest.mark.parametrize(
51+
"op",
52+
[
53+
ops.and_op,
54+
ops.or_op,
55+
ops.xor_op,
56+
],
57+
)
58+
def test_engines_project_boolean_op(
59+
scalars_array_value: array_value.ArrayValue, engine, op
60+
):
61+
# exclude string cols as they do not contain dates
62+
# bool col actually doesn't work properly for bq engine
63+
arr = apply_op_pairwise(scalars_array_value, op, excluded_cols=["string_col"])
64+
assert_equivalence_execution(arr.node, REFERENCE_ENGINE, engine)

0 commit comments

Comments
 (0)