def consistent_partition_window(
    plan: PlanOrUnbound,
    window_functions: Iterable[ExtendedExpressionOrUnbound],
    partition_expressions: Iterable[ExtendedExpressionOrUnbound] = (),
    sorts: Iterable[
        Union[
            ExtendedExpressionOrUnbound,
            tuple[ExtendedExpressionOrUnbound, stalg.SortField.SortDirection.ValueType],
        ]
    ] = (),
    extension: Optional[AdvancedExtension] = None,
) -> UnboundPlan:
    """Build an unbound plan wrapping ``plan`` in a ConsistentPartitionWindowRel.

    Args:
        plan: Input plan, either an already-built ``stp.Plan`` or a callable
            that produces one when given an extension registry.
        window_functions: Window function expressions. Each must resolve to an
            expression whose ``window_function`` field is set. Only the
            function reference, arguments, options, output type, phase,
            invocation and bounds are copied into the relation; partitioning
            and ordering come from ``partition_expressions`` and ``sorts``.
        partition_expressions: Expressions shared by every window function as
            the partition key. Empty by default.
        sorts: Ordering within a partition. Each entry is either an expression
            (direction defaults to ``SORT_DIRECTION_ASC_NULLS_LAST``) or an
            ``(expression, direction)`` tuple.
        extension: Optional advanced extension attached to the relation.

    Returns:
        A callable taking an ``ExtensionRegistry`` and returning the plan. The
        root names are the input plan's root names followed by one name per
        window function: its first output name, or ``window_<index>`` when it
        has none.

    Raises:
        ValueError: At resolution time, if an entry of ``window_functions``
            does not resolve to a window function expression.
    """
    # Snapshot the inputs now. The returned resolver can be invoked several
    # times, and a one-shot iterator (e.g. a generator) would otherwise be
    # exhausted after the first resolution and silently yield an empty window.
    window_function_specs = list(window_functions)
    partition_specs = list(partition_expressions)
    sort_specs = list(sorts)

    def resolve(registry: ExtensionRegistry) -> stp.Plan:
        bound_plan = plan if isinstance(plan, stp.Plan) else plan(registry)
        ns = infer_plan_schema(bound_plan)

        bound_partitions = [
            resolve_expression(e, ns, registry) for e in partition_specs
        ]

        # Normalise every sort entry to an (expression, direction) pair before
        # binding; bare expressions get the default direction.
        sort_pairs = [
            spec
            if isinstance(spec, tuple)
            else (spec, stalg.SortField.SORT_DIRECTION_ASC_NULLS_LAST)
            for spec in sort_specs
        ]
        bound_sorts = [
            (resolve_expression(pair[0], ns, registry), pair[1])
            for pair in sort_pairs
        ]

        bound_window_fns = [
            resolve_expression(e, ns, registry) for e in window_function_specs
        ]

        window_rel_functions = []
        window_names = []
        for i, wf_ee in enumerate(bound_window_fns):
            referred = wf_ee.referred_expr[0]
            # Reading an unset message field on a protobuf returns a default
            # instance, so without this check a non-window expression would be
            # emitted as a WindowRelFunction with function_reference 0.
            if not referred.expression.HasField("window_function"):
                raise ValueError(
                    f"window_functions[{i}] did not resolve to a window "
                    "function expression"
                )
            wf_expr = referred.expression.window_function
            window_rel_functions.append(
                stalg.ConsistentPartitionWindowRel.WindowRelFunction(
                    function_reference=wf_expr.function_reference,
                    arguments=list(wf_expr.arguments),
                    options=list(wf_expr.options),
                    output_type=wf_expr.output_type,
                    phase=wf_expr.phase,
                    invocation=wf_expr.invocation,
                    # Leave bounds unset (rather than default-valued) when the
                    # source expression did not specify them.
                    lower_bound=wf_expr.lower_bound
                    if wf_expr.HasField("lower_bound")
                    else None,
                    upper_bound=wf_expr.upper_bound
                    if wf_expr.HasField("upper_bound")
                    else None,
                    bounds_type=wf_expr.bounds_type,
                )
            )
            window_names.append(
                referred.output_names[0] if referred.output_names else f"window_{i}"
            )

        names = list(bound_plan.relations[-1].root.names) + window_names

        rel = stalg.Rel(
            window=stalg.ConsistentPartitionWindowRel(
                input=bound_plan.relations[-1].root.input,
                window_functions=window_rel_functions,
                partition_expressions=[
                    e.referred_expr[0].expression for e in bound_partitions
                ],
                sorts=[
                    stalg.SortField(
                        expr=e[0].referred_expr[0].expression,
                        direction=e[1],
                    )
                    for e in bound_sorts
                ],
                advanced_extension=extension,
            )
        )

        return stp.Plan(
            version=default_version,
            relations=[stp.PlanRel(root=stalg.RelRoot(input=rel, names=names))],
            **_merge_extensions(
                bound_plan,
                *bound_partitions,
                *[e[0] for e in bound_sorts],
                *bound_window_fns,
            ),
        )

    return resolve
def _stream_window_rel(
    self, window: stalg.ConsistentPartitionWindowRel, stream, depth: int
):
    """Write a short summary of a consistent partition window relation.

    Emits a header with the number of window functions, the input relation
    one level deeper, the partition and sort counts when non-empty, and one
    line per window function giving its function reference.
    """
    child_depth = depth + 1

    # Header line: "window: <count> functions".
    header_pad = " " * (depth * self.indent_size)
    header_label = self._color('window', Colors.MAGENTA)
    function_count = self._color(str(len(window.window_functions)), Colors.YELLOW)
    stream.write(f"{header_pad}{header_label}: {function_count} functions\n")

    # Input relation, rendered by the generic relation printer.
    input_arrow = self._get_indent_with_arrow(child_depth)
    input_label = self._color('input:', Colors.BLUE)
    stream.write(f"{input_arrow}{input_label}\n")
    self._stream_rel(window.input, stream, child_depth)

    # Partition and sort clauses are summarised by count and skipped when empty.
    for caption, entries in (
        ('partitions:', window.partition_expressions),
        ('sorts:', window.sorts),
    ):
        if not entries:
            continue
        arrow = self._get_indent_with_arrow(child_depth)
        label = self._color(caption, Colors.BLUE)
        count = self._color(str(len(entries)), Colors.YELLOW)
        stream.write(f"{arrow}{label} {count}\n")

    position = 0
    for wf in window.window_functions:
        arrow = self._get_indent_with_arrow(child_depth)
        label = self._color('window_fn', Colors.BLUE)
        index_text = self._color(str(position), Colors.CYAN)
        stream.write(
            f"{arrow}{label}[{index_text}]: func_ref={wf.function_reference}\n"
        )
        position += 1
def _ref_expr(field: int):
    """Return an expression selecting struct field ``field`` of the root."""
    struct_field = stalg.Expression.ReferenceSegment.StructField(field=field)
    segment = stalg.Expression.ReferenceSegment(struct_field=struct_field)
    root = stalg.Expression.FieldReference.RootReference()
    reference = stalg.Expression.FieldReference(
        direct_reference=segment,
        root_reference=root,
    )
    return stalg.Expression(selection=reference)
def test_consistent_partition_window_single():
    """One window function with a partition key and an explicit sort.

    Mirrors Java's consistentPartitionWindowRoundtripSingle.
    """
    nulls_first = stalg.SortField.SORT_DIRECTION_ASC_NULLS_FIRST
    table = read_named_table("test", named_struct)

    lead_fn = window_function(
        "extension:test:urn",
        "lead",
        expressions=[column("a")],
        partitions=[column("b")],
        alias="lead_a",
    )

    actual = consistent_partition_window(
        table,
        window_functions=[lead_fn],
        partition_expressions=[column("b")],
        sorts=[(column("c"), nulls_first)],
    )(registry)

    # Bind the window function independently to obtain the expected pieces.
    ns = infer_plan_schema(table(None))
    lead_ee = lead_fn(ns, registry)
    lead_wf = lead_ee.referred_expr[0].expression.window_function

    expected_function = stalg.ConsistentPartitionWindowRel.WindowRelFunction(
        function_reference=lead_wf.function_reference,
        arguments=list(lead_wf.arguments),
        output_type=lead_wf.output_type,
    )
    expected_window = stalg.ConsistentPartitionWindowRel(
        input=table(None).relations[-1].root.input,
        window_functions=[expected_function],
        partition_expressions=[_ref_expr(1)],
        sorts=[stalg.SortField(direction=nulls_first, expr=_ref_expr(2))],
    )
    expected_root = stalg.RelRoot(
        input=stalg.Rel(window=expected_window),
        names=["a", "b", "c", "lead_a"],
    )
    expected = stp.Plan(
        version=default_version,
        extension_urns=list(lead_ee.extension_urns),
        extension_uris=list(lead_ee.extension_uris),
        extensions=list(lead_ee.extensions),
        relations=[stp.PlanRel(root=expected_root)],
    )

    assert actual == expected
def test_consistent_partition_window_sort_default_direction():
    """A sort given as a bare expression gets ASC_NULLS_LAST."""
    source = read_named_table("test", named_struct)
    lead_fn = window_function(
        "extension:test:urn",
        "lead",
        expressions=[column("a")],
        partitions=[column("b")],
        alias="lead_a",
    )

    unbound = consistent_partition_window(
        source,
        window_functions=[lead_fn],
        partition_expressions=[column("b")],
        sorts=[column("c")],
    )
    actual = unbound(registry)

    first_sort = actual.relations[0].root.input.window.sorts[0]
    assert first_sort.direction == stalg.SortField.SORT_DIRECTION_ASC_NULLS_LAST
def test_consistent_partition_window_schema_inference():
    """The inferred schema is the input columns followed by the window outputs."""
    table = read_named_table("test", named_struct)

    def make_window_fn(function_name, alias):
        # Both functions take column "a" and partition by column "b".
        return window_function(
            "extension:test:urn",
            function_name,
            expressions=[column("a")],
            partitions=[column("b")],
            alias=alias,
        )

    lead_fn = make_window_fn("lead", "lead_a")
    lag_fn = make_window_fn("lag", "lag_a")

    actual = consistent_partition_window(
        table,
        window_functions=[lead_fn, lag_fn],
        partition_expressions=[column("b")],
        sorts=[(column("c"), stalg.SortField.SORT_DIRECTION_ASC_NULLS_FIRST)],
    )(registry)

    ns = infer_plan_schema(actual)

    # Three input columns plus two window outputs.
    assert len(ns.struct.types) == 5
    assert list(ns.names) == ["a", "b", "c", "lead_a", "lag_a"]

    # Positions 0-2 keep the input types; positions 3-4 are the i64 outputs.
    expected_kinds = ("i64", "i16", "i32", "i64", "i64")
    for position, kind in enumerate(expected_kinds):
        assert ns.struct.types[position].HasField(kind)
def test_consistent_partition_window_extension_references():
    """The built plan carries the extension URNs and function declarations."""
    source = read_named_table("test", named_struct)
    lead_fn = window_function(
        "extension:test:urn",
        "lead",
        expressions=[column("a")],
        alias="lead_a",
    )

    unbound = consistent_partition_window(source, window_functions=[lead_fn])
    actual = unbound(registry)

    assert len(actual.extension_urns) >= 1
    assert len(actual.extensions) >= 1

    # The first declared extension is expected to be the "lead" function.
    declared_name = actual.extensions[0].extension_function.name
    assert "lead" in declared_name