Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
3ce5f49
add remap_partitioning_select and test coverage
rjzamora Feb 24, 2026
0379e2e
Merge branch 'main' into select_preserves_partitioning
rjzamora Feb 25, 2026
3bcc305
support renamed columns
rjzamora Feb 25, 2026
fc103a5
Merge remote-tracking branch 'upstream/main' into select_preserves_partitioning
rjzamora Feb 25, 2026
3fb09c0
Merge branch 'main' into select_preserves_partitioning
rjzamora Feb 25, 2026
1d6de55
Merge branch 'main' into select_preserves_partitioning
rjzamora Feb 25, 2026
fbf7ccf
Merge branch 'main' into select_preserves_partitioning
rjzamora Feb 25, 2026
3ac39ae
Merge branch 'main' into select_preserves_partitioning
rjzamora Feb 25, 2026
d288c3a
Merge branch 'main' into select_preserves_partitioning
rjzamora Feb 26, 2026
a90d9cf
Merge branch 'main' into select_preserves_partitioning
rjzamora Feb 26, 2026
1c745bf
Merge branch 'main' into select_preserves_partitioning
rjzamora Feb 27, 2026
b213811
Merge branch 'main' into select_preserves_partitioning
rjzamora Feb 27, 2026
a0128ee
revise
rjzamora Mar 2, 2026
50f1caf
Merge remote-tracking branch 'upstream/main' into select_preserves_partitioning
rjzamora Mar 2, 2026
07e2d42
remove unused arg
rjzamora Mar 2, 2026
55042bc
Merge branch 'main' into select_preserves_partitioning
rjzamora Mar 2, 2026
c52d0aa
Merge branch 'main' into select_preserves_partitioning
rjzamora Mar 2, 2026
11fd9f0
Merge branch 'main' into select_preserves_partitioning
rjzamora Mar 3, 2026
faf9c81
address code review
rjzamora Mar 3, 2026
1fc002b
Merge branch 'main' into select_preserves_partitioning
rjzamora Mar 4, 2026
3a9305d
Merge remote-tracking branch 'upstream/main' into select_preserves_partitioning
rjzamora Mar 4, 2026
bcb7ee5
fix typing
rjzamora Mar 4, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 9 additions & 5 deletions python/cudf_polars/cudf_polars/experimental/rapidsmpf/join.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,9 @@
ChannelManager,
chunk_to_frame,
empty_table_chunk,
maybe_remap_partitioning,
process_children,
recv_metadata,
remap_partitioning,
send_metadata,
)
from cudf_polars.experimental.utils import _concat
Expand Down Expand Up @@ -107,8 +107,10 @@ async def broadcast_join_actor(
# Preserve left-side partitioning metadata
local_count = left_metadata.local_count
# Remap partitioning from child schema to output schema
partitioning = remap_partitioning(
left_metadata.partitioning, large_child.schema, ir.schema
partitioning = maybe_remap_partitioning(
ir,
left_metadata.partitioning,
child_ir=ir.children[0],
)
# Check if the right-side is already broadcasted
small_duplicated = right_metadata.duplicated
Expand All @@ -122,8 +124,10 @@ async def broadcast_join_actor(
local_count = right_metadata.local_count
if ir.options[0] == "Right":
# Remap partitioning from child schema to output schema
partitioning = remap_partitioning(
right_metadata.partitioning, large_child.schema, ir.schema
partitioning = maybe_remap_partitioning(
ir,
right_metadata.partitioning,
child_ir=ir.children[1],
)
# Check if the right-side is already broadcasted
small_duplicated = left_metadata.duplicated
Expand Down
29 changes: 5 additions & 24 deletions python/cudf_polars/cudf_polars/experimental/rapidsmpf/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
)

from cudf_polars.containers import DataFrame
from cudf_polars.dsl.ir import IR, Cache, Empty, Filter, Projection
from cudf_polars.dsl.ir import IR, Empty
from cudf_polars.experimental.rapidsmpf.dispatch import (
generate_ir_sub_network,
)
Expand All @@ -28,9 +28,9 @@
chunkwise_evaluate,
empty_table_chunk,
make_spill_function,
maybe_remap_partitioning,
process_children,
recv_metadata,
remap_partitioning,
send_metadata,
shutdown_on_error,
)
Expand All @@ -51,8 +51,6 @@ async def default_node_single(
ir_context: IRExecutionContext,
ch_out: Channel[TableChunk],
ch_in: Channel[TableChunk],
*,
preserve_partitioning: bool = False,
) -> None:
"""
Single-channel default node for rapidsmpf.
Expand All @@ -69,8 +67,6 @@ async def default_node_single(
The output Channel[TableChunk].
ch_in
The input Channel[TableChunk].
preserve_partitioning
Whether to preserve the partitioning metadata of the input chunks.

Notes
-----
Expand All @@ -79,15 +75,9 @@ async def default_node_single(
async with shutdown_on_error(context, ch_in, ch_out, trace_ir=ir) as tracer:
# Recv metadata and prepare output metadata
metadata_in = await recv_metadata(ch_in, context)
partitioning = None
if preserve_partitioning:
# Remap partitioning if schema has changed
partitioning = remap_partitioning(
metadata_in.partitioning, ir.children[0].schema, ir.schema
)
metadata_out = ChannelMetadata(
local_count=metadata_in.local_count,
partitioning=partitioning,
partitioning=maybe_remap_partitioning(ir, metadata_in.partitioning),
duplicated=metadata_in.duplicated,
)

Expand Down Expand Up @@ -148,8 +138,8 @@ async def default_node_multi(
duplicated = duplicated and md_child.duplicated
if idx == partitioning_index:
# Remap partitioning from child schema to output schema
partitioning = remap_partitioning(
md_child.partitioning, ir.children[idx].schema, ir.schema
partitioning = maybe_remap_partitioning(
ir, md_child.partitioning, child_ir=ir.children[idx]
)
metadata = ChannelMetadata(
local_count=local_count,
Expand Down Expand Up @@ -506,22 +496,13 @@ def _(

if len(ir.children) == 1:
# Single-channel default node
preserve_partitioning = isinstance(
# TODO: We don't need to worry about
# non-pointwise Filter operations here,
# because the lowering stage would have
# collapsed to one partition anyway.
ir,
(Cache, Projection, Filter),
)
nodes[ir] = [
default_node_single(
rec.state["context"],
ir,
rec.state["ir_context"],
channels[ir].reserve_input_slot(),
channels[ir.children[0]].reserve_output_slot(),
preserve_partitioning=preserve_partitioning,
)
]
else:
Expand Down
156 changes: 117 additions & 39 deletions python/cudf_polars/cudf_polars/experimental/rapidsmpf/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,12 @@
import pylibcudf as plc

from cudf_polars.containers import DataFrame
from cudf_polars.dsl.expr import Col, NamedExpr
from cudf_polars.dsl.ir import Cache, Filter, Join, Projection, Select
from cudf_polars.experimental.utils import _concat

if TYPE_CHECKING:
from collections.abc import AsyncIterator, Callable, Mapping
from collections.abc import AsyncIterator, Callable

from rapidsmpf.communicator.communicator import Communicator
from rapidsmpf.streaming.core.channel import Channel
Expand All @@ -51,7 +53,51 @@
from cudf_polars.dsl.ir import IR, IRExecutionContext
from cudf_polars.experimental.rapidsmpf.dispatch import SubNetGenerator
from cudf_polars.experimental.rapidsmpf.tracing import ActorTracer
from cudf_polars.typing import DataType
from cudf_polars.typing import DataType, Schema


def indices_to_names(indices: tuple[int, ...], schema: Schema) -> tuple[str, ...]:
    """
    Look up the column name at each of the given positions in ``schema``.

    Parameters
    ----------
    indices
        Positional indices into the schema's column ordering.
    schema
        The schema supplying the ordered column names.

    Returns
    -------
    A tuple containing the column name corresponding to each index,
    in the order the indices were given.
    """
    ordered_names = tuple(schema.keys())
    return tuple(map(ordered_names.__getitem__, indices))


def names_to_indices(
    names: tuple[str | NamedExpr, ...], schema: Schema
) -> tuple[int, ...]:
    """
    Return the schema position of each requested column name.

    Accepts plain column names (str) as well as NamedExpr objects, so
    callers may pass e.g. ``ir.left_on`` / ``ir.right_on`` directly.

    Parameters
    ----------
    names
        The column names (or NamedExpr wrappers) to locate.
    schema
        The schema whose column ordering defines the indices.

    Returns
    -------
    The schema index of each name, in the order given.

    Raises
    ------
    ValueError
        If a name is not present in ``schema``.
    """
    ordered = list(schema)
    # NamedExpr carries the output name in ``.name``; plain strings have
    # no ``name`` attribute and pass through unchanged.
    return tuple(ordered.index(getattr(item, "name", item)) for item in names)


@asynccontextmanager
Expand Down Expand Up @@ -118,53 +164,85 @@ async def shutdown_on_error(
structlog.contextvars.unbind_contextvars("actor_ir_id", "actor_ir_type")


def remap_partitioning(
partitioning: Partitioning | None,
old_schema: Mapping[str, DataType],
new_schema: Mapping[str, DataType],
def _remap_scheme_select(
    select: Select, scheme: HashScheme | None | str
) -> HashScheme | None | str:
    # A Select only preserves hash partitioning when every partition-key
    # column is re-emitted as a plain ``Col`` reference; otherwise the
    # scheme must be dropped (returned as None).
    if not isinstance(scheme, HashScheme):
        # Trivial markers pass through; anything unrecognized is dropped.
        if scheme is None or scheme == "inherit":
            return scheme
        return None  # pragma: no cover - guard against new scheme types
    # Old-name -> new-name mapping restricted to bare ``Col`` selections.
    renamed = {
        expr.value.name: expr.name
        for expr in select.exprs
        if isinstance(expr.value, Col)
    }
    key_names = indices_to_names(scheme.column_indices, select.children[0].schema)
    if any(name not in renamed for name in key_names):
        return None  # A partition key was dropped or computed
    remapped = names_to_indices(
        tuple(renamed[name] for name in key_names), select.schema
    )
    return HashScheme(remapped, scheme.modulus)


def _remap_scheme_simple(
    ir: IR, scheme: HashScheme | None | str, child: IR
) -> HashScheme | None | str:
    # The caller guarantees ``ir`` preserves its child's partitioning;
    # all that remains is translating key indices from the child schema
    # to the (possibly reordered) output schema.
    if not isinstance(scheme, HashScheme):
        return scheme  # None / "inherit" need no translation
    key_names = indices_to_names(scheme.column_indices, child.schema)
    try:
        remapped = names_to_indices(key_names, ir.schema)
    except (ValueError, IndexError):
        return None  # A key column is absent from the output schema
    return HashScheme(remapped, scheme.modulus)


def maybe_remap_partitioning(
    ir: IR, partitioning: Partitioning | None, *, child_ir: IR | None = None
) -> Partitioning | None:
    """
    Remap partitioning for simple IR nodes.

    Parameters
    ----------
    ir
        The IR node whose output the partitioning should describe.
    partitioning
        The input partitioning (established on a child of ``ir``).
    child_ir
        The child IR whose schema the partitioning refers to. When None,
        the first child (``ir.children[0]``) is used.

    Returns
    -------
    The remapped partitioning. When partition keys are not preserved,
    the corresponding scheme will be set to None. When the original
    partitioning is None, the output will also be None.

    Notes
    -----
    A Select preserves partitioning if all partition key columns are
    output as simple Col references (unchanged values). Other columns
    can be computed expressions - only the partition keys matter.
    For any IR type not handled below, partitioning metadata is
    conservatively dropped.
    """
    if partitioning is None:
        return None  # Nothing to preserve
    if isinstance(ir, Select):
        # Select may rename or recompute columns; the Select-aware
        # remapper validates that every key survives as a plain Col.
        return Partitioning(
            inter_rank=_remap_scheme_select(ir, partitioning.inter_rank),
            local=_remap_scheme_select(ir, partitioning.local),
        )
    if isinstance(ir, (Cache, Join, Projection, Filter)):
        # These nodes keep row partitioning intact; only column
        # positions may shift between the child and output schemas.
        child = child_ir if child_ir is not None else ir.children[0]
        return Partitioning(
            inter_rank=_remap_scheme_simple(ir, partitioning.inter_rank, child),
            local=_remap_scheme_simple(ir, partitioning.local, child),
        )
    return None


async def send_metadata(
Expand Down Expand Up @@ -300,7 +378,7 @@ async def evaluate_chunk(
async def concat_batch(
batch: list[TableChunk],
context: Context,
schema: Mapping[str, DataType],
schema: Schema,
ir_context: IRExecutionContext,
) -> TableChunk:
"""
Expand Down
Loading