Skip to content

Commit 3b109aa

Browse files
committed
refactor: fix the remap varaibles errors on InNode
1 parent b0ff718 commit 3b109aa

File tree

4 files changed

+130
-28
lines changed

4 files changed

+130
-28
lines changed

bigframes/core/nodes.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -300,7 +300,15 @@ def remap_vars(
300300
def remap_refs(
301301
self, mappings: Mapping[identifiers.ColumnId, identifiers.ColumnId]
302302
) -> InNode:
303-
return dataclasses.replace(self, left_col=self.left_col.remap_column_refs(mappings, allow_partial_bindings=True), right_col=self.right_col.remap_column_refs(mappings, allow_partial_bindings=True)) # type: ignore
303+
return dataclasses.replace(
304+
self,
305+
left_col=self.left_col.remap_column_refs(
306+
mappings, allow_partial_bindings=True
307+
),
308+
right_col=self.right_col.remap_column_refs(
309+
mappings, allow_partial_bindings=True
310+
),
311+
) # type: ignore
304312

305313

306314
@dataclasses.dataclass(frozen=True, eq=False)

bigframes/core/rewrite/identifiers.py

Lines changed: 63 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
# limitations under the License.
1414
from __future__ import annotations
1515

16+
import dataclasses
1617
import typing
1718

1819
from bigframes.core import identifiers, nodes
@@ -26,32 +27,68 @@ def remap_variables(
2627
nodes.BigFrameNode,
2728
dict[identifiers.ColumnId, identifiers.ColumnId],
2829
]:
29-
"""Remaps `ColumnId`s in the BFET to produce deterministic and sequential UIDs.
30+
"""Remaps `ColumnId`s in the expression tree to be deterministic and sequential.
3031
31-
Note: this will convert a DAG to a tree.
32+
This function performs a post-order traversal. It recursively remaps children
33+
nodes first, then remaps the current node's references and definitions.
34+
35+
Note: this will convert a DAG to a tree by duplicating shared nodes.
36+
37+
Args:
38+
root: The root node of the expression tree.
39+
id_generator: An iterator that yields new column IDs.
40+
41+
Returns:
42+
A tuple of the new root node and a mapping from old to new column IDs
43+
visible to the parent node.
3244
"""
33-
child_replacement_map = dict()
34-
ref_mapping = dict()
35-
# Sequential ids are assigned bottom-up left-to-right
45+
# Step 1: Recursively remap children to get their new nodes and ID mappings.
46+
new_child_nodes: list[nodes.BigFrameNode] = []
47+
new_child_mappings: list[dict[identifiers.ColumnId, identifiers.ColumnId]] = []
3648
for child in root.child_nodes:
37-
new_child, child_var_mapping = remap_variables(child, id_generator=id_generator)
38-
child_replacement_map[child] = new_child
39-
ref_mapping.update(child_var_mapping)
40-
41-
# This is actually invalid until we've replaced all of children, refs and var defs
42-
with_new_children = root.transform_children(
43-
lambda node: child_replacement_map[node]
44-
)
45-
46-
with_new_refs = with_new_children.remap_refs(ref_mapping)
47-
48-
node_var_mapping = {old_id: next(id_generator) for old_id in root.node_defined_ids}
49-
with_new_vars = with_new_refs.remap_vars(node_var_mapping)
50-
with_new_vars._validate()
51-
52-
return (
53-
with_new_vars,
54-
node_var_mapping
55-
if root.defines_namespace
56-
else (ref_mapping | node_var_mapping),
57-
)
49+
new_child, child_mappings = remap_variables(child, id_generator=id_generator)
50+
new_child_nodes.append(new_child)
51+
new_child_mappings.append(child_mappings)
52+
53+
# Step 2: Transform children to use their new nodes.
54+
remapped_children: dict[nodes.BigFrameNode, nodes.BigFrameNode] = {
55+
child: new_child for child, new_child in zip(root.child_nodes, new_child_nodes)
56+
}
57+
new_root = root.transform_children(lambda node: remapped_children[node])
58+
59+
# Step 3: Transform the current node using the mappings from its children.
60+
downstream_mappings: dict[identifiers.ColumnId, identifiers.ColumnId] = {
61+
k: v for mapping in new_child_mappings for k, v in mapping.items()
62+
}
63+
if isinstance(new_root, nodes.InNode):
64+
new_root = typing.cast(nodes.InNode, new_root)
65+
new_root = dataclasses.replace(
66+
new_root,
67+
left_col=new_root.left_col.remap_column_refs(
68+
new_child_mappings[0], allow_partial_bindings=True
69+
),
70+
right_col=new_root.right_col.remap_column_refs(
71+
new_child_mappings[1], allow_partial_bindings=True
72+
),
73+
)
74+
else:
75+
new_root = new_root.remap_refs(downstream_mappings)
76+
77+
# Step 4: Create new IDs for columns defined by the current node.
78+
node_defined_mappings = {
79+
old_id: next(id_generator) for old_id in root.node_defined_ids
80+
}
81+
new_root = new_root.remap_vars(node_defined_mappings)
82+
83+
new_root._validate()
84+
85+
# Step 5: Determine which mappings to propagate up to the parent.
86+
if root.defines_namespace:
87+
# If a node defines a new namespace (e.g., a join), mappings from its
88+
# children are not visible to its parents.
89+
mappings_for_parent = node_defined_mappings
90+
else:
91+
# Otherwise, pass up the combined mappings from children and the current node.
92+
mappings_for_parent = downstream_mappings | node_defined_mappings
93+
94+
return new_root, mappings_for_parent

tests/unit/core/rewrite/conftest.py

Lines changed: 35 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,32 @@
3434

3535
@pytest.fixture
3636
def table():
37-
return TABLE
37+
table_ref = google.cloud.bigquery.TableReference.from_string(
38+
"project.dataset.table"
39+
)
40+
schema = (
41+
google.cloud.bigquery.SchemaField("col_a", "INTEGER"),
42+
google.cloud.bigquery.SchemaField("col_b", "INTEGER"),
43+
)
44+
return google.cloud.bigquery.Table(
45+
table_ref=table_ref,
46+
schema=schema,
47+
)
48+
49+
50+
@pytest.fixture
51+
def table_too():
52+
table_ref = google.cloud.bigquery.TableReference.from_string(
53+
"project.dataset.table_too"
54+
)
55+
schema = (
56+
google.cloud.bigquery.SchemaField("col_a", "INTEGER"),
57+
google.cloud.bigquery.SchemaField("col_c", "INTEGER"),
58+
)
59+
return google.cloud.bigquery.Table(
60+
table_ref=table_ref,
61+
schema=schema,
62+
)
3863

3964

4065
@pytest.fixture
@@ -49,3 +74,12 @@ def leaf(fake_session, table):
4974
table=table,
5075
schema=bigframes.core.schema.ArraySchema.from_bq_table(table),
5176
).node
77+
78+
79+
@pytest.fixture
80+
def leaf_too(fake_session, table_too):
81+
return core.ArrayValue.from_table(
82+
session=fake_session,
83+
table=table_too,
84+
schema=bigframes.core.schema.ArraySchema.from_bq_table(table_too),
85+
).node

tests/unit/core/rewrite/test_identifiers.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,10 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14+
import typing
1415

1516
import bigframes.core as core
17+
import bigframes.core.expression as ex
1618
import bigframes.core.identifiers as identifiers
1719
import bigframes.core.nodes as nodes
1820
import bigframes.core.rewrite.identifiers as id_rewrite
@@ -130,3 +132,24 @@ def test_remap_variables_concat_self_stability(leaf):
130132

131133
assert new_node1 == new_node2
132134
assert mapping1 == mapping2
135+
136+
137+
def test_remap_variables_in_node_converts_dag_to_tree(leaf, leaf_too):
138+
# Create an InNode with the same child twice, should create a tree from a DAG
139+
node = nodes.InNode(
140+
left_child=leaf,
141+
right_child=leaf_too,
142+
left_col=ex.DerefOp(identifiers.ColumnId("col_a")),
143+
right_col=ex.DerefOp(identifiers.ColumnId("col_a")),
144+
indicator_col=identifiers.ColumnId("indicator"),
145+
)
146+
147+
id_generator = (identifiers.ColumnId(f"id_{i}") for i in range(100))
148+
new_node, _ = id_rewrite.remap_variables(node, id_generator)
149+
new_node = typing.cast(nodes.InNode, new_node)
150+
151+
left_col_id = new_node.left_col.id.name
152+
right_col_id = new_node.right_col.id.name
153+
assert left_col_id.startswith("id_")
154+
assert right_col_id.startswith("id_")
155+
assert left_col_id != right_col_id

0 commit comments

Comments
 (0)