Skip to content

Commit 2c2237b

Browse files
Fiddle-Config Team and copybara-github
authored and committed
Correctly pass the old config to the model sharding fiddler generation, which makes the output significantly less verbose. Also uses the new short namer in case intermediate variables are needed.
PiperOrigin-RevId: 549703159
1 parent 01780f0 commit 2c2237b

File tree

2 files changed

+68
-5
lines changed

2 files changed

+68
-5
lines changed

fiddle/_src/codegen/codegen_diff.py

Lines changed: 65 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,9 @@
1515

1616
"""Library for generating fiddlers from diffs."""
1717

18+
import abc
1819
import collections
20+
import dataclasses
1921
import functools
2022
import re
2123
import types
@@ -31,12 +33,63 @@
3133
import libcst as cst
3234

3335

36+
@dataclasses.dataclass(frozen=True)
class ObjectToName:
  """Describes one intermediate variable that still needs a name.

  Attributes:
    prefix: Text placed in front of the path-derived part of the name.
    path: Path of the value the variable refers to; namers turn (part of) it
      into the remainder of the name.
  """

  prefix: str
  path: daglish.Path

  def __hash__(self):
    # Hashing is by object identity, so every instance is its own dictionary
    # key. Note that the dataclass-generated `__eq__` still compares fields.
    return id(self)
43+
44+
45+
class VariableNamer(metaclass=abc.ABCMeta):
  """Interface for strategies that name intermediate variables."""

  @abc.abstractmethod
  def assign_names(self, all_to_name: List[ObjectToName]) -> List[str]:
    """Proposes a name for each entry of `all_to_name`.

    The proposed names do not have to be unique; they will be further
    disambiguated by a Namespace.

    Args:
      all_to_name: The objects that need names.

    Returns:
      A list of proposed names.
    """
51+
52+
53+
class ExplicitNamer(VariableNamer):
  """Names each variable after its prefix and its complete path."""

  def assign_names(self, all_to_name: List[ObjectToName]) -> List[str]:
    """Returns, per entry, the prefix followed by the name of the full path.

    Args:
      all_to_name: The objects that need names.

    Returns:
      One name per entry of `all_to_name`, in the same order.
    """
    names = []
    for item in all_to_name:
      names.append(item.prefix + _path_to_name(item.path))
    return names
60+
61+
62+
class ShorterNamer(VariableNamer):
  """Builds short names from the tail of each path.

  Every name is first built from the last `initial_suffix_elements` elements of
  the path. Entries whose short names coincide are renamed using one additional
  trailing path element; no further disambiguation is attempted here.
  """

  # Number of trailing path elements used for the first naming attempt.
  initial_suffix_elements: int = 1

  def assign_names(self, all_to_name: List[ObjectToName]) -> List[str]:
    """Returns one shortened name per entry, in the order of `all_to_name`.

    Args:
      all_to_name: The objects that need names.

    Returns:
      A list with one name per entry of `all_to_name`.
    """
    short_len = self.initial_suffix_elements
    longer_len = short_len + 1

    # First pass: bucket the entries by their shortest candidate name.
    buckets = {}
    for item in all_to_name:
      candidate = item.prefix + _path_to_name(item.path[-short_len:])
      if candidate not in buckets:
        buckets[candidate] = []
      buckets[candidate].append(item)

    # Second pass: a candidate owned by a single entry is kept as-is; entries
    # sharing a candidate are renamed using one more trailing path element.
    chosen = {}
    for candidate, members in buckets.items():
      if len(members) == 1:
        (only,) = members
        chosen[only] = candidate
        continue
      for item in members:
        chosen[item] = item.prefix + _path_to_name(item.path[-longer_len:])

    return [chosen[item] for item in all_to_name]
84+
85+
3486
def fiddler_from_diff(
3587
diff: diffing.Diff,
3688
old: Any = None,
3789
func_name: str = 'fiddler',
3890
param_name: str = 'cfg',
3991
import_manager: Optional[import_manager_lib.ImportManager] = None,
92+
variable_namer: Optional[VariableNamer] = None,
4093
):
4194
"""Returns the CST for a fiddler function that applies the changes in `diff`.
4295
@@ -72,6 +125,8 @@ def fiddler_from_diff(
72125
import_manager: Existing import manager. Usually set to None, but if you are
73126
integrating this with other code generation tasks, it can be nice to
74127
share.
128+
variable_namer: Object (basically just a function) to assign names to
129+
intermediate variables.
75130
76131
Returns:
77132
A `cst.Module` object. You can convert this to a string using
@@ -90,25 +145,30 @@ def fiddler_from_diff(
90145
namespace.add(param_name)
91146
namespace.add(func_name)
92147

148+
variable_namer = ExplicitNamer() if variable_namer is None else variable_namer
149+
93150
# Get a list of paths that are referenced by the diff.
94151
used_paths = _find_used_paths(diff)
95152

96153
# Add variables for any used paths where the value (or any of the value's
97154
# ancestors) will be replaced by a change in the diff. If we don't have an
98155
# `old` structure, then we pessimistically assume that we need to create
99156
# variables for all used paths.
100-
moved_value_names = {}
157+
moved_values_to_name = []
101158
if old is not None:
102159
modified_paths = set([change.target for change in diff.changes])
103160
_add_path_aliases(modified_paths, old)
104161
for path in sorted(used_paths, key=daglish.path_str):
105162
if any(path[:i] in modified_paths for i in range(len(path) + 1)):
106-
moved_value_names[path] = namespace.get_new_name(
107-
_path_to_name(path), f'moved_{param_name}_')
163+
moved_values_to_name.append(ObjectToName(f'moved_{param_name}_', path))
108164
else:
109165
for path in sorted(used_paths, key=daglish.path_str):
110-
moved_value_names[path] = namespace.get_new_name(
111-
_path_to_name(path), f'original_{param_name}_')
166+
moved_values_to_name.append(ObjectToName(f'original_{param_name}_', path))
167+
initial_names = variable_namer.assign_names(moved_values_to_name)
168+
moved_value_names = {
169+
to_name.path: namespace.get_new_name(name, prefix='')
170+
for to_name, name in zip(moved_values_to_name, initial_names)
171+
}
112172

113173
# Add variables for new shared values added by the diff.
114174
new_shared_value_names = [

fiddle/codegen/codegen_diff.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,4 +16,7 @@
1616
"""Library for generating fiddlers from diffs."""
1717

1818
# pylint: disable=unused-import
19+
from fiddle._src.codegen.codegen_diff import ExplicitNamer
1920
from fiddle._src.codegen.codegen_diff import fiddler_from_diff
21+
from fiddle._src.codegen.codegen_diff import ShorterNamer
22+
from fiddle._src.codegen.codegen_diff import VariableNamer

0 commit comments

Comments
 (0)