Skip to content
2 changes: 1 addition & 1 deletion bigframes/_config/experiment_options.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ class ExperimentOptions:
def __init__(self):
# Initialise the experiment option flags; both operator toggles start disabled.
self._semantic_operators: bool = False
self._ai_operators: bool = False
# NOTE(review): diff-view residue - the two assignments below are the removed
# and the added version of the same line (the hunk header reports 1 addition
# & 1 deletion). Presumably the change switches the default SQL compiler from
# "stable" to "experimental"; confirm against the actual file.
self._sql_compiler: Literal["legacy", "stable", "experimental"] = "stable"
self._sql_compiler: Literal["legacy", "stable", "experimental"] = "experimental"

@property
def semantic_operators(self) -> bool:
Expand Down
61 changes: 5 additions & 56 deletions bigframes/core/compile/compiled.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
# limitations under the License.
from __future__ import annotations

import functools
import itertools
import typing
from typing import Literal, Optional, Sequence
Expand All @@ -27,7 +26,7 @@
from google.cloud import bigquery
import pyarrow as pa

from bigframes.core import agg_expressions
from bigframes.core import agg_expressions, rewrite
import bigframes.core.agg_expressions as ex_types
import bigframes.core.compile.googlesql
import bigframes.core.compile.ibis_compiler.aggregate_compiler as agg_compiler
Expand All @@ -38,8 +37,6 @@
import bigframes.core.sql
from bigframes.core.window_spec import WindowSpec
import bigframes.dtypes
import bigframes.operations as ops
import bigframes.operations.aggregations as agg_ops

op_compiler = op_compilers.scalar_op_compiler

Expand Down Expand Up @@ -424,59 +421,11 @@ def project_window_op(
output_name,
)

if expression.op.order_independent and window_spec.is_unbounded:
# notably percentile_cont does not support ordering clause
window_spec = window_spec.without_order()

# TODO: Turn this logic into a true rewriter
result_expr: ex.Expression = agg_expressions.WindowExpression(
expression, window_spec
rewritten_expr = rewrite.simplify_complex_windows(
agg_expressions.WindowExpression(expression, window_spec)
)
clauses: list[tuple[ex.Expression, ex.Expression]] = []
if window_spec.min_periods and len(expression.inputs) > 0:
if not expression.op.nulls_count_for_min_values:
is_observation = ops.notnull_op.as_expr()

# Most operations do not count NULL values towards min_periods
per_col_does_count = (
ops.notnull_op.as_expr(input) for input in expression.inputs
)
# All inputs must be non-null for observation to count
is_observation = functools.reduce(
lambda x, y: ops.and_op.as_expr(x, y), per_col_does_count
)
observation_sentinel = ops.AsTypeOp(bigframes.dtypes.INT_DTYPE).as_expr(
is_observation
)
observation_count_expr = agg_expressions.WindowExpression(
ex_types.UnaryAggregation(agg_ops.sum_op, observation_sentinel),
window_spec,
)
else:
# Operations like count treat even NULLs as valid observations for the sake of min_periods
# notnull is just used to convert null values to non-null (FALSE) values to be counted
is_observation = ops.notnull_op.as_expr(expression.inputs[0])
observation_count_expr = agg_expressions.WindowExpression(
agg_ops.count_op.as_expr(is_observation),
window_spec,
)
clauses.append(
(
ops.lt_op.as_expr(
observation_count_expr, ex.const(window_spec.min_periods)
),
ex.const(None),
)
)
if clauses:
case_inputs = [
*itertools.chain.from_iterable(clauses),
ex.const(True),
result_expr,
]
result_expr = ops.CaseWhenOp().as_expr(*case_inputs)

ibis_expr = op_compiler.compile_expression(result_expr, self._ibis_bindings)

ibis_expr = op_compiler.compile_expression(rewritten_expr, self._ibis_bindings)

return UnorderedIR(self._table, (*self.columns, ibis_expr.name(output_name)))

Expand Down
10 changes: 5 additions & 5 deletions bigframes/core/compile/sqlglot/aggregate_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@
ordered_unary_compiler,
unary_compiler,
)
import bigframes.core.compile.sqlglot.expression_compiler as expression_compiler
from bigframes.core.compile.sqlglot.expressions import typed_expr
import bigframes.core.compile.sqlglot.scalar_compiler as scalar_compiler


def compile_aggregate(
Expand All @@ -35,7 +35,7 @@ def compile_aggregate(
return nullary_compiler.compile(aggregate.op)
if isinstance(aggregate, agg_expressions.UnaryAggregation):
column = typed_expr.TypedExpr(
scalar_compiler.scalar_op_compiler.compile_expression(aggregate.arg),
expression_compiler.expression_compiler.compile_expression(aggregate.arg),
aggregate.arg.output_type,
)
if not aggregate.op.order_independent:
Expand All @@ -46,11 +46,11 @@ def compile_aggregate(
return unary_compiler.compile(aggregate.op, column)
elif isinstance(aggregate, agg_expressions.BinaryAggregation):
left = typed_expr.TypedExpr(
scalar_compiler.scalar_op_compiler.compile_expression(aggregate.left),
expression_compiler.expression_compiler.compile_expression(aggregate.left),
aggregate.left.output_type,
)
right = typed_expr.TypedExpr(
scalar_compiler.scalar_op_compiler.compile_expression(aggregate.right),
expression_compiler.expression_compiler.compile_expression(aggregate.right),
aggregate.right.output_type,
)
return binary_compiler.compile(aggregate.op, left, right)
Expand All @@ -66,7 +66,7 @@ def compile_analytic(
return nullary_compiler.compile(aggregate.op, window)
if isinstance(aggregate, agg_expressions.UnaryAggregation):
column = typed_expr.TypedExpr(
scalar_compiler.scalar_op_compiler.compile_expression(aggregate.arg),
expression_compiler.expression_compiler.compile_expression(aggregate.arg),
aggregate.arg.output_type,
)
return unary_compiler.compile(aggregate.op, column, window)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@ def compile(
right: typed_expr.TypedExpr,
window: typing.Optional[window_spec.WindowSpec] = None,
) -> sge.Expression:
if op.order_independent and (window is not None) and window.is_unbounded:
window = window.without_order()
return BINARY_OP_REGISTRATION[op](op, left, right, window=window)


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@ def compile(
op: agg_ops.WindowOp,
window: typing.Optional[window_spec.WindowSpec] = None,
) -> sge.Expression:
if op.order_independent and (window is not None) and window.is_unbounded:
window = window.without_order()
return NULLARY_OP_REGISTRATION[op](op, window=window)


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@ def compile(
column: typed_expr.TypedExpr,
window: typing.Optional[window_spec.WindowSpec] = None,
) -> sge.Expression:
if op.order_independent and (window is not None) and window.is_unbounded:
window = window.without_order()
return UNARY_OP_REGISTRATION[op](op, column, window=window)


Expand Down
14 changes: 7 additions & 7 deletions bigframes/core/compile/sqlglot/aggregations/windows.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import bigframes_vendored.sqlglot.expressions as sge

from bigframes.core import utils, window_spec
import bigframes.core.compile.sqlglot.scalar_compiler as scalar_compiler
import bigframes.core.compile.sqlglot.expression_compiler as expression_compiler
import bigframes.core.expression as ex
import bigframes.core.ordering as ordering_spec
import bigframes.dtypes as dtypes
Expand Down Expand Up @@ -116,7 +116,7 @@ def get_window_order_by(

order_by = []
for ordering_spec_item in ordering:
expr = scalar_compiler.scalar_op_compiler.compile_expression(
expr = expression_compiler.expression_compiler.compile_expression(
ordering_spec_item.scalar_expression
)
desc = not ordering_spec_item.direction.is_ascending
Expand Down Expand Up @@ -191,15 +191,15 @@ def _get_window_bounds(


def _compile_group_by_key(key: ex.Expression) -> sge.Expression:
# Compile one group-by key into a sqlglot expression, wrapping it in a
# conversion when the key's dtype is FLOAT, GEO or JSON.
#
# NOTE(review): diff-view residue - each adjacent pair of near-identical lines
# below is a removed line followed by its replacement
# (scalar_compiler.scalar_op_compiler -> expression_compiler.expression_compiler,
# key.dtype -> key.output_type). Only one line of each pair exists in the real file.
expr = scalar_compiler.scalar_op_compiler.compile_expression(key)
expr = expression_compiler.expression_compiler.compile_expression(key)
# The group_by keys have been rewritten by bind_schema_to_node, so every key is
# asserted to be a resolved, scalar column reference.
assert isinstance(key, ex.ResolvedDerefOp)
assert key.is_scalar_expr and key.is_resolved

# Some types need to be converted to another type to enable groupby
if key.dtype == dtypes.FLOAT_DTYPE:
if key.output_type == dtypes.FLOAT_DTYPE:
# Float keys are cast to STRING.
expr = sge.Cast(this=expr, to="STRING")
elif key.dtype == dtypes.GEO_DTYPE:
elif key.output_type == dtypes.GEO_DTYPE:
# Geography keys are passed through ST_ASBINARY.
expr = sge.func("ST_ASBINARY", expr)
elif key.dtype == dtypes.JSON_DTYPE:
elif key.output_type == dtypes.JSON_DTYPE:
# JSON keys are passed through TO_JSON_STRING.
expr = sge.func("TO_JSON_STRING", expr)
# Any other dtype is returned as compiled, without conversion.
return expr
Loading