|
12 | 12 | # See the License for the specific language governing permissions and |
13 | 13 | # limitations under the License. |
14 | 14 |
|
| 15 | +import typing |
| 16 | + |
15 | 17 | import pytest |
16 | 18 |
|
17 | | -from bigframes.core import agg_expressions, array_value, expression, identifiers, nodes |
| 19 | +from bigframes.core import agg_expressions as agg_exprs |
| 20 | +from bigframes.core import array_value, identifiers, nodes |
18 | 21 | from bigframes.operations import aggregations as agg_ops |
19 | 22 | import bigframes.pandas as bpd |
20 | 23 |
|
21 | 24 | pytest.importorskip("pytest_snapshot") |
22 | 25 |
|
23 | 26 |
|
def _apply_unary_agg_ops(
    obj: bpd.DataFrame,
    ops_list: typing.Sequence[agg_exprs.UnaryAggregation],
    new_names: typing.Sequence[str],
) -> str:
    """Apply unary aggregation expressions to *obj* and return the compiled SQL.

    Args:
        obj: Source DataFrame whose underlying expression node the
            aggregations are applied to.
        ops_list: Aggregation expressions, one per output column.
        new_names: Output column ids, parallel to ``ops_list``.

    Returns:
        The SQL string generated for the aggregation. Caching is disabled so
        snapshot output stays deterministic.

    Raises:
        ValueError: If ``ops_list`` and ``new_names`` differ in length.
    """
    # zip() would silently truncate on mismatched lengths, hiding test
    # authoring mistakes -- fail loudly instead.
    if len(ops_list) != len(new_names):
        raise ValueError("ops_list and new_names must have the same length")

    aggs = tuple(
        (op, identifiers.ColumnId(name)) for op, name in zip(ops_list, new_names)
    )

    agg_node = nodes.AggregateNode(obj._block.expr.node, aggregations=aggs)
    result = array_value.ArrayValue(agg_node)

    sql = result.session._executor.to_sql(result, enable_cache=False)
    return sql
38 | 39 |
|
39 | 40 |
|
def test_count(scalar_types_df: bpd.DataFrame, snapshot):
    """Snapshot the SQL generated for a COUNT aggregation on an int column."""
    bf_df = scalar_types_df[["int64_col"]]
    sql = _apply_unary_agg_ops(
        bf_df, [agg_ops.CountOp().as_expr("int64_col")], ["int64_col"]
    )

    snapshot.assert_match(sql, "out.sql")
| 48 | + |
| 49 | + |
def test_max(scalar_types_df: bpd.DataFrame, snapshot):
    """Snapshot the SQL generated for a MAX aggregation on an int column."""
    frame = scalar_types_df[["int64_col"]]
    expr = agg_ops.MaxOp().as_expr("int64_col")
    sql = _apply_unary_agg_ops(frame, [expr], ["int64_col"])

    snapshot.assert_match(sql, "out.sql")
| 57 | + |
| 58 | + |
def test_min(scalar_types_df: bpd.DataFrame, snapshot):
    """Snapshot the SQL generated for a MIN aggregation on an int column."""
    frame = scalar_types_df[["int64_col"]]
    expr = agg_ops.MinOp().as_expr("int64_col")
    sql = _apply_unary_agg_ops(frame, [expr], ["int64_col"])

    snapshot.assert_match(sql, "out.sql")
45 | 66 |
|
46 | 67 |
|
def test_sum(scalar_types_df: bpd.DataFrame, snapshot):
    """Snapshot the SQL generated for SUM over an int and a bool column."""
    cols = ["int64_col", "bool_col"]
    bf_df = scalar_types_df[cols]
    # One SUM expression per selected column, keeping output names aligned
    # with the source columns.
    exprs = [agg_ops.SumOp().as_expr(col) for col in cols]
    sql = _apply_unary_agg_ops(bf_df, exprs, cols)

    snapshot.assert_match(sql, "out.sql")
0 commit comments