Skip to content
This repository was archived by the owner on Apr 1, 2026. It is now read-only.

Commit 1ffd2b4

Browse files
feat: Add where, coalesce, fillna, casewhen, invert local impl
1 parent 82175a4 commit 1ffd2b4

8 files changed

Lines changed: 162 additions & 3 deletions

File tree

bigframes/core/compile/polars/compiler.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -168,7 +168,7 @@ def compile_op(self, op: ops.ScalarOp, *args: pl.Expr) -> pl.Expr:
168168

169169
@compile_op.register(gen_ops.InvertOp)
170170
def _(self, op: ops.ScalarOp, input: pl.Expr) -> pl.Expr:
171-
return ~input
171+
return input.not_()
172172

173173
@compile_op.register(num_ops.AbsOp)
174174
def _(self, op: ops.ScalarOp, input: pl.Expr) -> pl.Expr:

bigframes/core/compile/polars/lowering.py

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,18 @@
1414

1515
import dataclasses
1616

17+
import numpy as np
18+
1719
from bigframes import dtypes
1820
from bigframes.core import bigframe_node, expression
1921
from bigframes.core.rewrite import op_lowering
20-
from bigframes.operations import comparison_ops, datetime_ops, json_ops, numeric_ops
22+
from bigframes.operations import (
23+
comparison_ops,
24+
datetime_ops,
25+
generic_ops,
26+
json_ops,
27+
numeric_ops,
28+
)
2129
import bigframes.operations as ops
2230

2331
# TODO: Would be more precise to actually have separate op set for polars ops (where they diverge from the original ops)
@@ -288,6 +296,26 @@ def lower(self, expr: expression.OpExpression) -> expression.Expression:
288296
return _lower_cast(expr.op, expr.inputs[0])
289297

290298

299+
def invert_bytes(byte_string):
300+
inverted_bytes = ~np.frombuffer(byte_string, dtype=np.uint8)
301+
return inverted_bytes.tobytes()
302+
303+
304+
class LowerInvertOp(op_lowering.OpLoweringRule):
305+
@property
306+
def op(self) -> type[ops.ScalarOp]:
307+
return generic_ops.InvertOp
308+
309+
def lower(self, expr: expression.OpExpression) -> expression.Expression:
310+
assert isinstance(expr.op, generic_ops.InvertOp)
311+
arg = expr.children[0]
312+
if arg.output_type == dtypes.BYTES_DTYPE:
313+
return generic_ops.PyUdfOp(invert_bytes, dtypes.BYTES_DTYPE).as_expr(
314+
expr.inputs[0]
315+
)
316+
return expr
317+
318+
291319
def _coerce_comparables(
292320
expr1: expression.Expression,
293321
expr2: expression.Expression,
@@ -385,6 +413,7 @@ def _lower_cast(cast_op: ops.AsTypeOp, arg: expression.Expression):
385413
LowerFloorDivRule(),
386414
LowerModRule(),
387415
LowerAsTypeRule(),
416+
LowerInvertOp(),
388417
)
389418

390419

bigframes/core/compile/polars/operations/generic_ops.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,3 +45,14 @@ def isnull_op_impl(
4545
input: pl.Expr,
4646
) -> pl.Expr:
4747
return input.is_null()
48+
49+
50+
@polars_compiler.register_op(generic_ops.PyUdfOp)
51+
def pyudf_op_impl(
52+
compiler: polars_compiler.PolarsExpressionCompiler,
53+
op: generic_ops.PyUdfOp, # type: ignore
54+
input: pl.Expr,
55+
) -> pl.Expr:
56+
return input.map_elements(
57+
op.fn, return_dtype=polars_compiler._DTYPE_MAPPING[op._output_type]
58+
)

bigframes/operations/generic_ops.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -446,3 +446,15 @@ class SqlScalarOp(base_ops.NaryOp):
446446

447447
def output_type(self, *input_types: dtypes.ExpressionType) -> dtypes.ExpressionType:
448448
return self._output_type
449+
450+
451+
@dataclasses.dataclass(frozen=True)
452+
class PyUdfOp(base_ops.NaryOp):
453+
"""Represents a local UDF."""
454+
455+
name: typing.ClassVar[str] = "py_udf"
456+
fn: typing.Callable
457+
_output_type: dtypes.ExpressionType
458+
459+
def output_type(self, *input_types: dtypes.ExpressionType) -> dtypes.ExpressionType:
460+
return self._output_type

bigframes/session/polars_executor.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,11 @@
5858
numeric_ops.FloorDivOp,
5959
numeric_ops.ModOp,
6060
generic_ops.AsTypeOp,
61+
generic_ops.WhereOp,
62+
generic_ops.CoalesceOp,
63+
generic_ops.FillNaOp,
64+
generic_ops.CaseWhenOp,
65+
generic_ops.InvertOp,
6166
)
6267
_COMPATIBLE_AGG_OPS = (
6368
agg_ops.SizeOp,

tests/system/small/engines/test_generic_ops.py

Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -266,3 +266,102 @@ def test_engines_astype_timedelta(scalars_array_value: array_value.ArrayValue, e
266266
ops.AsTypeOp(to_type=bigframes.dtypes.TIMEDELTA_DTYPE),
267267
)
268268
assert_equivalence_execution(arr.node, REFERENCE_ENGINE, engine)
269+
270+
271+
@pytest.mark.parametrize("engine", ["polars", "bq"], indirect=True)
272+
def test_engines_where_op(scalars_array_value: array_value.ArrayValue, engine):
273+
arr, _ = scalars_array_value.compute_values(
274+
[
275+
ops.where_op.as_expr(
276+
expression.deref("int64_col"),
277+
expression.deref("bool_col"),
278+
expression.deref("float64_col"),
279+
)
280+
]
281+
)
282+
assert_equivalence_execution(arr.node, REFERENCE_ENGINE, engine)
283+
284+
285+
@pytest.mark.parametrize("engine", ["polars", "bq"], indirect=True)
286+
def test_engines_coalesce_op(scalars_array_value: array_value.ArrayValue, engine):
287+
arr, _ = scalars_array_value.compute_values(
288+
[
289+
ops.coalesce_op.as_expr(
290+
expression.deref("int64_col"),
291+
expression.deref("float64_col"),
292+
)
293+
]
294+
)
295+
assert_equivalence_execution(arr.node, REFERENCE_ENGINE, engine)
296+
297+
298+
@pytest.mark.parametrize("engine", ["polars", "bq"], indirect=True)
299+
def test_engines_fillna_op(scalars_array_value: array_value.ArrayValue, engine):
300+
arr, _ = scalars_array_value.compute_values(
301+
[
302+
ops.fillna_op.as_expr(
303+
expression.deref("int64_col"),
304+
expression.deref("float64_col"),
305+
)
306+
]
307+
)
308+
assert_equivalence_execution(arr.node, REFERENCE_ENGINE, engine)
309+
310+
311+
@pytest.mark.parametrize("engine", ["polars", "bq"], indirect=True)
312+
def test_engines_casewhen_op_single_case(
313+
scalars_array_value: array_value.ArrayValue, engine
314+
):
315+
arr, _ = scalars_array_value.compute_values(
316+
[
317+
ops.case_when_op.as_expr(
318+
expression.deref("bool_col"),
319+
expression.deref("int64_col"),
320+
)
321+
]
322+
)
323+
assert_equivalence_execution(arr.node, REFERENCE_ENGINE, engine)
324+
325+
326+
@pytest.mark.parametrize("engine", ["polars", "bq"], indirect=True)
327+
def test_engines_casewhen_op_double_case(
328+
scalars_array_value: array_value.ArrayValue, engine
329+
):
330+
arr, _ = scalars_array_value.compute_values(
331+
[
332+
ops.case_when_op.as_expr(
333+
ops.gt_op.as_expr(expression.deref("int64_col"), expression.const(3)),
334+
expression.deref("int64_col"),
335+
ops.lt_op.as_expr(expression.deref("int64_col"), expression.const(-3)),
336+
expression.deref("int64_too"),
337+
)
338+
]
339+
)
340+
assert_equivalence_execution(arr.node, REFERENCE_ENGINE, engine)
341+
342+
343+
@pytest.mark.parametrize("engine", ["polars", "bq"], indirect=True)
344+
def test_engines_isnull_op(scalars_array_value: array_value.ArrayValue, engine):
345+
arr, _ = scalars_array_value.compute_values(
346+
[ops.isnull_op.as_expr(expression.deref("string_col"))]
347+
)
348+
assert_equivalence_execution(arr.node, REFERENCE_ENGINE, engine)
349+
350+
351+
@pytest.mark.parametrize("engine", ["polars", "bq"], indirect=True)
352+
def test_engines_notnull_op(scalars_array_value: array_value.ArrayValue, engine):
353+
arr, _ = scalars_array_value.compute_values(
354+
[ops.notnull_op.as_expr(expression.deref("string_col"))]
355+
)
356+
assert_equivalence_execution(arr.node, REFERENCE_ENGINE, engine)
357+
358+
359+
@pytest.mark.parametrize("engine", ["polars", "bq"], indirect=True)
360+
def test_engines_invert_op(scalars_array_value: array_value.ArrayValue, engine):
361+
arr, _ = scalars_array_value.compute_values(
362+
[
363+
ops.invert_op.as_expr(expression.deref("bytes_col")),
364+
ops.invert_op.as_expr(expression.deref("bool_col")),
365+
]
366+
)
367+
assert_equivalence_execution(arr.node, REFERENCE_ENGINE, engine)

third_party/bigframes_vendored/ibis/expr/operations/numeric.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -326,7 +326,7 @@ class Tan(TrigonometricUnary):
326326
class BitwiseNot(Unary):
327327
"""Bitwise NOT operation."""
328328

329-
arg: Integer
329+
arg: Value[dt.Integer | dt.Binary]
330330

331331
dtype = rlz.numeric_like("args", operator.invert)
332332

third_party/bigframes_vendored/ibis/expr/types/binary.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,9 @@ def hashbytes(
3232
"""
3333
return ops.HashBytes(self, how).to_expr()
3434

35+
def __invert__(self) -> BinaryValue:
36+
return ops.BitwiseNot(self).to_expr()
37+
3538

3639
@public
3740
class BinaryScalar(Scalar, BinaryValue):

0 commit comments

Comments
 (0)