Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 28 additions & 4 deletions bigframes/core/rewrite/identifiers.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,11 +50,21 @@ def remap_variables(
new_child_nodes.append(new_child)
new_child_mappings.append(child_mappings)

new_root = root

# Step 2: Transform children to use their new nodes.
remapped_children: dict[nodes.BigFrameNode, nodes.BigFrameNode] = {
child: new_child for child, new_child in zip(root.child_nodes, new_child_nodes)
}
new_root = root.transform_children(lambda node: remapped_children[node])
if isinstance(new_root, nodes.JoinNode) or isinstance(new_root, nodes.InNode):
new_root = dataclasses.replace(
new_root,
left_child=new_child_nodes[0],
right_child=new_child_nodes[1],
)
else:
remapped_children: dict[nodes.BigFrameNode, nodes.BigFrameNode] = {
child: new_child
for child, new_child in zip(root.child_nodes, new_child_nodes)
}
new_root = root.transform_children(lambda node: remapped_children[node])

# Step 3: Transform the current node using the mappings from its children.
# "reversed" is required for InNode so that in case of a duplicate column ID,
Expand All @@ -70,6 +80,20 @@ def remap_variables(
new_child_mappings[0], allow_partial_bindings=True
),
)
elif isinstance(new_root, nodes.JoinNode):
new_root = typing.cast(nodes.JoinNode, new_root)
new_conds = tuple(
(
l_cond.remap_column_refs(
new_child_mappings[0], allow_partial_bindings=True
),
r_cond.remap_column_refs(
new_child_mappings[1], allow_partial_bindings=True
),
)
for l_cond, r_cond in new_root.conditions
)
new_root = dataclasses.replace(new_root, conditions=new_conds)
else:
new_root = new_root.remap_refs(downstream_mappings)

Expand Down
42 changes: 42 additions & 0 deletions tests/unit/core/rewrite/test_identifiers.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,3 +151,45 @@ def test_remap_variables_in_node_converts_dag_to_tree(leaf, leaf_too):
left_col_id = new_node.left_col.id.name
new_node.validate_tree()
assert left_col_id.startswith("id_")


def test_remap_variables_join_self_stability(leaf):
# Create a join node with the same child twice
# We wrap them in distinct SelectionNodes so they can have their IDs remapped
# independently to avoid ID collisions in the resulting tree.
leaf_selection_left = nodes.SelectionNode(
leaf,
tuple(nodes.AliasedRef.identity(f.id) for f in leaf.fields),
)
leaf_selection_right = nodes.SelectionNode(
leaf,
tuple(nodes.AliasedRef.identity(f.id) for f in leaf.fields),
)

node = nodes.JoinNode(
left_child=leaf_selection_left,
right_child=leaf_selection_right,
conditions=(
(
ex.DerefOp(leaf_selection_left.fields[0].id),
ex.DerefOp(leaf_selection_right.fields[0].id),
),
),
type="inner",
propogate_order=False,
)

# Run remap_variables
id_generator = (identifiers.ColumnId(f"id_{i}") for i in range(100))
# This used to raise KeyError before the fix
new_node, mapping = id_rewrite.remap_variables(node, id_generator)

assert isinstance(new_node, nodes.JoinNode)
new_node.validate_tree()

# Verify that conditions use child-specific IDs
left_cond, right_cond = new_node.conditions[0]
assert left_cond.id in new_node.left_child.ids
assert right_cond.id in new_node.right_child.ids
# Since it's a self-join remapped to a tree, the left and right IDs should be different
assert left_cond.id != right_cond.id
Loading