diff --git a/bigframes/core/rewrite/identifiers.py b/bigframes/core/rewrite/identifiers.py index da43fdf8b9..3d2841436f 100644 --- a/bigframes/core/rewrite/identifiers.py +++ b/bigframes/core/rewrite/identifiers.py @@ -50,11 +50,21 @@ def remap_variables( new_child_nodes.append(new_child) new_child_mappings.append(child_mappings) + new_root = root + # Step 2: Transform children to use their new nodes. - remapped_children: dict[nodes.BigFrameNode, nodes.BigFrameNode] = { - child: new_child for child, new_child in zip(root.child_nodes, new_child_nodes) - } - new_root = root.transform_children(lambda node: remapped_children[node]) + if isinstance(new_root, nodes.JoinNode) or isinstance(new_root, nodes.InNode): + new_root = dataclasses.replace( + new_root, + left_child=new_child_nodes[0], + right_child=new_child_nodes[1], + ) + else: + remapped_children: dict[nodes.BigFrameNode, nodes.BigFrameNode] = { + child: new_child + for child, new_child in zip(root.child_nodes, new_child_nodes) + } + new_root = root.transform_children(lambda node: remapped_children[node]) # Step 3: Transform the current node using the mappings from its children. # "reversed" is required for InNode so that in case of a duplicate column ID, @@ -66,10 +76,18 @@ def remap_variables( new_root = typing.cast(nodes.InNode, new_root) new_root = dataclasses.replace( new_root, - left_col=new_root.left_col.remap_column_refs( - new_child_mappings[0], allow_partial_bindings=True - ), + left_col=new_root.left_col.remap_column_refs(new_child_mappings[0]), + ) + elif isinstance(new_root, nodes.JoinNode): + new_root = typing.cast(nodes.JoinNode, new_root) + new_conds = tuple( + ( + l_cond.remap_column_refs(new_child_mappings[0]), + r_cond.remap_column_refs(new_child_mappings[1]), + ) + for l_cond, r_cond in new_root.conditions ) + new_root = dataclasses.replace(new_root, conditions=new_conds) else: new_root = new_root.remap_refs(downstream_mappings) diff --git a/tests/unit/core/rewrite/test_identifiers.py b/tests/unit/core/rewrite/test_identifiers.py index 09904ac4ba..69b17cb093 100644 --- a/tests/unit/core/rewrite/test_identifiers.py +++ b/tests/unit/core/rewrite/test_identifiers.py @@ -151,3 +151,45 @@ def test_remap_variables_in_node_converts_dag_to_tree(leaf, leaf_too): left_col_id = new_node.left_col.id.name new_node.validate_tree() assert left_col_id.startswith("id_") + + +def test_remap_variables_join_self_stability(leaf): + # Create a join node with the same child twice + # We wrap them in distinct SelectionNodes so they can have their IDs remapped + # independently to avoid ID collisions in the resulting tree. + leaf_selection_left = nodes.SelectionNode( + leaf, + tuple(nodes.AliasedRef.identity(f.id) for f in leaf.fields), + ) + leaf_selection_right = nodes.SelectionNode( + leaf, + tuple(nodes.AliasedRef.identity(f.id) for f in leaf.fields), + ) + + node = nodes.JoinNode( + left_child=leaf_selection_left, + right_child=leaf_selection_right, + conditions=( + ( + ex.DerefOp(leaf_selection_left.fields[0].id), + ex.DerefOp(leaf_selection_right.fields[0].id), + ), + ), + type="inner", + propogate_order=False, + ) + + # Run remap_variables + id_generator = (identifiers.ColumnId(f"id_{i}") for i in range(100)) + # This used to raise KeyError before the fix + new_node, mapping = id_rewrite.remap_variables(node, id_generator) + + assert isinstance(new_node, nodes.JoinNode) + new_node.validate_tree() + + # Verify that conditions use child-specific IDs + left_cond, right_cond = new_node.conditions[0] + assert left_cond.id in new_node.left_child.ids + assert right_cond.id in new_node.right_child.ids + # Since it's a self-join remapped to a tree, the left and right IDs should be different + assert left_cond.id != right_cond.id