|
35 | 35 | import bigframes.operations.comparison_ops as comp_ops |
36 | 36 | import bigframes.operations.generic_ops as gen_ops |
37 | 37 | import bigframes.operations.numeric_ops as num_ops |
| 38 | +import bigframes.operations.string_ops as string_ops |
38 | 39 |
|
39 | 40 | polars_installed = True |
40 | 41 | if TYPE_CHECKING: |
@@ -146,6 +147,14 @@ def _(self, op: ops.ScalarOp, input: pl.Expr) -> pl.Expr: |
146 | 147 | def _(self, op: ops.ScalarOp, input: pl.Expr) -> pl.Expr: |
147 | 148 | return input.abs() |
148 | 149 |
|
| 150 | + @compile_op.register(num_ops.FloorOp) |
| 151 | + def _(self, op: ops.ScalarOp, input: pl.Expr) -> pl.Expr: |
| 152 | + return input.floor() |
| 153 | + |
| 154 | + @compile_op.register(num_ops.CeilOp) |
| 155 | + def _(self, op: ops.ScalarOp, input: pl.Expr) -> pl.Expr: |
| 156 | + return input.ceil() |
| 157 | + |
149 | 158 | @compile_op.register(num_ops.PosOp) |
150 | 159 | def _(self, op: ops.ScalarOp, input: pl.Expr) -> pl.Expr: |
151 | 160 | return input.__pos__() |
@@ -182,10 +191,6 @@ def _(self, op: ops.ScalarOp, l_input: pl.Expr, r_input: pl.Expr) -> pl.Expr: |
182 | 191 | def _(self, op: ops.ScalarOp, l_input: pl.Expr, r_input: pl.Expr) -> pl.Expr: |
183 | 192 | return l_input // r_input |
184 | 193 |
|
185 | | - @compile_op.register(num_ops.FloorDivOp) |
186 | | - def _(self, op: ops.ScalarOp, l_input: pl.Expr, r_input: pl.Expr) -> pl.Expr: |
187 | | - return l_input // r_input |
188 | | - |
189 | 194 | @compile_op.register(num_ops.ModOp) |
190 | 195 | def _(self, op: ops.ScalarOp, l_input: pl.Expr, r_input: pl.Expr) -> pl.Expr: |
191 | 196 | return l_input % r_input |
@@ -270,6 +275,11 @@ def _(self, op: ops.ScalarOp, input: pl.Expr) -> pl.Expr: |
270 | 275 | # eg. We want "True" instead of "true" for bool to strin |
271 | 276 | return input.cast(_DTYPE_MAPPING[op.to_type], strict=not op.safe) |
272 | 277 |
|
| 278 | + @compile_op.register(string_ops.StrConcatOp) |
| 279 | + def _(self, op: ops.ScalarOp, l_input: pl.Expr, r_input: pl.Expr) -> pl.Expr: |
| 280 | + assert isinstance(op, string_ops.StrConcatOp) |
| 281 | + return pl.concat_str(l_input, r_input) |
| 282 | + |
273 | 283 | @dataclasses.dataclass(frozen=True) |
274 | 284 | class PolarsAggregateCompiler: |
275 | 285 | scalar_compiler = PolarsExpressionCompiler() |
@@ -503,6 +513,30 @@ def compile_join(self, node: nodes.JoinNode): |
503 | 513 | left, right, node.type, left_on, right_on, node.joins_nulls |
504 | 514 | ) |
505 | 515 |
|
| 516 | + @compile_node.register |
| 517 | + def compile_isin(self, node: nodes.InNode): |
| 518 | + left = self.compile_node(node.left_child) |
| 519 | + right = self.compile_node(node.right_child).unique(node.right_col.id.sql) |
| 520 | + right = right.with_columns(pl.lit(True).alias(node.indicator_col.sql)) |
| 521 | + |
| 522 | + left_ex, right_ex = lowering._coerce_comparables(node.left_col, node.right_col) |
| 523 | + |
| 524 | + left_pl_ex = self.expr_compiler.compile_expression(left_ex) |
| 525 | + right_pl_ex = self.expr_compiler.compile_expression(right_ex) |
| 526 | + |
| 527 | + joined = left.join( |
| 528 | + right, |
| 529 | + how="left", |
| 530 | + left_on=left_pl_ex, |
| 531 | + right_on=right_pl_ex, |
| 532 | + # Note: join_nulls renamed to nulls_equal for polars 1.24 |
| 533 | + join_nulls=node.joins_nulls, # type: ignore |
| 534 | + coalesce=False, |
| 535 | + ) |
| 536 | + passthrough = [pl.col(id) for id in left.columns] |
| 537 | + indicator = pl.col(node.indicator_col.sql).fill_null(False) |
| 538 | + return joined.select((*passthrough, indicator)) |
| 539 | + |
506 | 540 | def _ordered_join( |
507 | 541 | self, |
508 | 542 | left_frame: pl.LazyFrame, |
|
0 commit comments