Skip to content

Commit 25ea157

Browse files
Fiddle-Config Team authored and
copybara-github committed
[draft] Optimized API for Daglish read-only traversals (traversals that do not
re-construct objects). PiperOrigin-RevId: 614885965
1 parent 2a17618 commit 25ea157

File tree

12 files changed

+184
-37
lines changed

12 files changed

+184
-37
lines changed

fiddle/_src/codegen/auto_config/complex_to_variables.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,8 +50,7 @@ def traverse(value, state: daglish.State) -> int:
5050
elif not state.is_traversable(value):
5151
return 1
5252
else:
53-
sub_values = state.flattened_map_children(value)
54-
return 1 + sum(sub_values.values)
53+
return 1 + sum(state.fast_map_child_values(value))
5554

5655
return lambda x: daglish.MemoizedTraversal.run(traverse, x) > level
5756

fiddle/_src/codegen/legacy_codegen.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -72,8 +72,8 @@ def traverse(value, state: daglish.State):
7272
if isinstance(value, config_lib.Buildable):
7373
to_count[id(value)] += 1
7474
children_by_id[id(value)] = value
75-
if state.is_traversable(value):
76-
state.flattened_map_children(value)
75+
for _ in state.fast_map_child_values(value, ignore_leaves=True):
76+
pass # Run lazy iterator.
7777

7878
daglish.BasicTraversal.run(traverse, buildable)
7979
return [
@@ -194,10 +194,10 @@ def _configure_shared_objects(
194194
variable_name_prefix: Prefix for any variables introduced.
195195
"""
196196

197-
def traverse(child, state):
197+
def traverse(child, state: daglish.State):
198198
"""Generates code for a shared instance."""
199-
if state.is_traversable(child):
200-
state.flattened_map_children(child)
199+
for _ in state.fast_map_child_values(child, ignore_leaves=True):
200+
pass # Run lazy iterator.
201201
if isinstance(child, config_lib.Buildable):
202202
# Name this better..
203203
name = shared_manager.namespace.get_new_name(

fiddle/_src/daglish.py

Lines changed: 54 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -143,6 +143,8 @@ class NamedTupleType:
143143
# This has the same type as Path, but different semantic meaning.
144144
PathElements = Tuple[PathElement, ...]
145145
PathElementsFn = Callable[[Any], PathElements]
146+
_ValueAndPath = Tuple[Any, PathElement]
147+
OptimizedFlattenFn = Callable[[Any], Iterable[_ValueAndPath]]
146148
FlattenFn = Callable[[Any], Tuple[Tuple[Any, ...], Any]]
147149
UnflattenFn = Callable[[Iterable[Any], Any], Any]
148150

@@ -155,6 +157,7 @@ class NodeTraverser:
155157
flatten: FlattenFn
156158
unflatten: UnflattenFn
157159
path_elements: PathElementsFn
160+
flatten_with_paths: OptimizedFlattenFn | None = None
158161

159162

160163
class NodeTraverserRegistry:
@@ -185,6 +188,7 @@ def register_node_traverser(
185188
flatten_fn: FlattenFn,
186189
unflatten_fn: UnflattenFn,
187190
path_elements_fn: PathElementsFn,
191+
flatten_with_paths_fn: OptimizedFlattenFn | None = None,
188192
) -> None:
189193
"""Registers a node traverser for `node_type`.
190194
@@ -202,6 +206,9 @@ def register_node_traverser(
202206
flattened values returned by `flatten_fn`. This should accept an
203207
instance of `node_type`, and return a sequence of `PathElement`s aligned
204208
with the values returned by `flatten_fn`.
209+
flatten_with_paths_fn: A version of `flatten_fn` that returns an iterable
210+
of `(value, path)` pairs, where `value` is a child value and `path` is a
211+
single `PathElement` (not a full `Path`) locating the value in its parent.
205212
"""
206213
if not isinstance(node_type, type):
207214
raise TypeError(f"`node_type` ({node_type}) must be a type.")
@@ -212,6 +219,7 @@ def register_node_traverser(
212219
flatten=flatten_fn,
213220
unflatten=unflatten_fn,
214221
path_elements=path_elements_fn,
222+
flatten_with_paths=flatten_with_paths_fn,
215223
)
216224

217225
def find_node_traverser(
@@ -282,7 +290,9 @@ def unflatten_defaultdict(values, metadata):
282290
tuple,
283291
flatten_fn=lambda x: (x, None),
284292
unflatten_fn=lambda x, _: tuple(x),
285-
path_elements_fn=lambda x: tuple(Index(i) for i in range(len(x))))
293+
path_elements_fn=lambda x: tuple(Index(i) for i in range(len(x))),
294+
flatten_with_paths_fn=lambda xs: ((x, Index(i)) for i, x in enumerate(xs)),
295+
)
286296

287297
register_node_traverser(
288298
NamedTupleType,
@@ -294,7 +304,9 @@ def unflatten_defaultdict(values, metadata):
294304
list,
295305
flatten_fn=lambda x: (tuple(x), None),
296306
unflatten_fn=lambda x, _: list(x),
297-
path_elements_fn=lambda x: tuple(Index(i) for i in range(len(x))))
307+
path_elements_fn=lambda x: tuple(Index(i) for i in range(len(x))),
308+
flatten_with_paths_fn=lambda xs: ((x, Index(i)) for i, x in enumerate(xs)),
309+
)
298310

299311

300312
def is_prefix(prefix_path: Path, containing_path: Path):
@@ -610,6 +622,42 @@ def flattened_map_children(self, value: Any) -> SubTraversalResult:
610622
raise ValueError("Please handle non-traversable values yourself.")
611623
return self._flattened_map_children(value, node_traverser)
612624

625+
def fast_map_child_values(
626+
self, value: Any, ignore_leaves: bool = False
627+
) -> Iterable[Any]:
628+
"""Maps over children for traversable values, but doesn't unflatten results.
629+
630+
This method only returns result values, so use it in place of
631+
`state.flattened_map_children(value).values`.
632+
633+
Args:
634+
value: Value to map over.
635+
ignore_leaves: If True, then this function will return an empty iterable
636+
if `value` is not traversable. Otherwise, it will raise a ValueError.
637+
638+
Yields:
639+
Sub-traversal results, the same type as returned by your _traverse
640+
function.
641+
642+
Raises:
643+
ValueError: If `value` is not traversable and `ignore_leaves` is False. Please test beforehand by
644+
calling `state.is_traversable()`.
645+
"""
646+
node_traverser = self.traversal.find_node_traverser(type(value))
647+
if node_traverser is None:
648+
if ignore_leaves:
649+
return
650+
else:
651+
raise ValueError("Please handle non-traversable values yourself.")
652+
if node_traverser.flatten_with_paths is not None:
653+
for value, path in node_traverser.flatten_with_paths(value):
654+
yield self.call(value, path)
655+
else:
656+
sub_values, unused_meta = node_traverser.flatten(value)
657+
path_elements = node_traverser.path_elements(value)
658+
for sub_value, path_element in zip(sub_values, path_elements):
659+
yield self.call(sub_value, path_element)
660+
613661
def call(self, value, *additional_path: PathElement):
614662
"""Low-level function to execute a sub-traversal.
615663
@@ -755,8 +803,8 @@ def collect_paths_by_id(
755803
def traverse(value, state: State):
756804
if not memoizable_only or is_memoizable(value):
757805
paths_by_id.setdefault(id(value), []).append(state.current_path)
758-
if state.is_traversable(value):
759-
state.flattened_map_children(value)
806+
for _ in state.fast_map_child_values(value, ignore_leaves=True):
807+
pass # Run lazy iterator.
760808

761809
traversal = BasicTraversal(traverse, structure, registry=registry)
762810
traverse(structure, traversal.initial_state())
@@ -794,9 +842,8 @@ def iterate(
794842

795843
def _traverse(node, state: State):
796844
yield node, state.current_path
797-
if state.is_traversable(node):
798-
for sub_result in state.flattened_map_children(node).values:
799-
yield from sub_result
845+
for sub_result in state.fast_map_child_values(node, ignore_leaves=True):
846+
yield from sub_result
800847

801848
if memoized:
802849
traversal = MemoizedTraversal(

fiddle/_src/daglish_test.py

Lines changed: 101 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
import enum
2121
import json
2222
import random
23-
from typing import Any, List, NamedTuple, Optional, cast
23+
from typing import Any, List, NamedTuple, Optional, Tuple, cast
2424

2525
from absl.testing import absltest
2626
from absl.testing import parameterized
@@ -76,6 +76,16 @@ def __call__(self, value, state: Optional[daglish.State] = None):
7676
return state.map_children(value)
7777

7878

79+
@dataclasses.dataclass
80+
class NonRecursingLoggingFunction:
81+
values_and_paths: List[Tuple[Any, Any]] = dataclasses.field(
82+
default_factory=list
83+
)
84+
85+
def __call__(self, value, state: daglish.State):
86+
self.values_and_paths.append((value, state.current_path))
87+
88+
7989
def switch_buildables_to_args(value, state: Optional[daglish.State] = None):
8090
"""Replaces buildables with their arguments dictionary.
8191
@@ -97,6 +107,28 @@ def switch_buildables_to_args(value, state: Optional[daglish.State] = None):
97107
return value
98108

99109

110+
@dataclasses.dataclass
111+
class MyRange:
112+
start: int
113+
end: int
114+
115+
116+
def myrange_flatten_with_paths_fn(myrange: MyRange):
117+
return (
118+
({"value": i}, daglish.Index(i))
119+
for i in range(myrange.start, myrange.end)
120+
)
121+
122+
123+
daglish.register_node_traverser(
124+
MyRange,
125+
flatten_fn=NotImplemented, # pytype: disable=wrong-arg-types
126+
unflatten_fn=NotImplemented, # pytype: disable=wrong-arg-types
127+
path_elements_fn=NotImplemented, # pytype: disable=wrong-arg-types
128+
flatten_with_paths_fn=myrange_flatten_with_paths_fn,
129+
)
130+
131+
100132
class PathTest(parameterized.TestCase):
101133

102134
@parameterized.named_parameters(
@@ -437,6 +469,73 @@ def test_argument_history(self):
437469
history.ChangeKind.NEW_VALUE)
438470

439471

472+
_eager_map_fns = [
473+
lambda state, obj: state.map_children(obj),
474+
lambda state, obj: list(state.fast_map_child_values(obj)),
475+
lambda state, obj: state.flattened_map_children(obj),
476+
]
477+
478+
479+
class StateApiTest(parameterized.TestCase):
480+
481+
@parameterized.parameters(_eager_map_fns)
482+
def test_map_dict(self, map_fn):
483+
obj = {"a": 1, "b": 2}
484+
log_calls = NonRecursingLoggingFunction()
485+
traversal = daglish.BasicTraversal(log_calls, obj)
486+
state = traversal.initial_state()
487+
map_fn(state, obj)
488+
self.assertEqual(
489+
log_calls.values_and_paths,
490+
[(1, (daglish.Key(key="a"),)), (2, (daglish.Key(key="b"),))],
491+
)
492+
493+
@parameterized.parameters(_eager_map_fns)
494+
def test_map_tuple(self, map_fn):
495+
obj = ((), (1, 2), 3)
496+
log_calls = NonRecursingLoggingFunction()
497+
traversal = daglish.BasicTraversal(log_calls, obj)
498+
state = traversal.initial_state()
499+
map_fn(state, obj)
500+
self.assertEqual(
501+
log_calls.values_and_paths,
502+
[
503+
((), (daglish.Index(index=0),)),
504+
((1, 2), (daglish.Index(index=1),)),
505+
(3, (daglish.Index(index=2),)),
506+
],
507+
)
508+
509+
@parameterized.parameters(_eager_map_fns)
510+
def test_map_memoized(self, map_fn):
511+
shared = {"foo": 123}
512+
obj = [shared, shared, shared]
513+
log_calls = NonRecursingLoggingFunction()
514+
traversal = daglish.MemoizedTraversal(log_calls, obj)
515+
state = traversal.initial_state()
516+
map_fn(state, obj)
517+
self.assertEqual(
518+
log_calls.values_and_paths, [({"foo": 123}, (daglish.Index(index=0),))]
519+
)
520+
521+
def test_fast_map_calls_flatten_with_paths(self):
522+
obj = MyRange(3, 7)
523+
log_calls = NonRecursingLoggingFunction()
524+
traversal = daglish.MemoizedTraversal(log_calls, obj)
525+
state = traversal.initial_state()
526+
for _ in state.fast_map_child_values(obj):
527+
pass # Run lazy iterator.
528+
self.assertEqual(
529+
log_calls.values_and_paths,
530+
[
531+
({"value": 3}, (daglish.Index(index=3),)),
532+
({"value": 4}, (daglish.Index(index=4),)),
533+
({"value": 5}, (daglish.Index(index=5),)),
534+
({"value": 6}, (daglish.Index(index=6),)),
535+
],
536+
)
537+
538+
440539
class ArgsSwitchingFuzzTest(parameterized.TestCase):
441540

442541
def test_fuzz(self):
@@ -595,7 +694,7 @@ def traverse(value, state: daglish.State):
595694
daglish.path_str(path)
596695
for path in state.get_all_paths(allow_caching=True)
597696
),
598-
"sub_values": state.flattened_map_children(value).values,
697+
"sub_values": list(state.fast_map_child_values(value)),
599698
}
600699
else:
601700
return "leaf value"

fiddle/_src/debug/grep.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,8 @@ def grep(
6363
def traverse(value, state: daglish.State):
6464
path_str = daglish.path_str(state.current_path)
6565
if state.is_traversable(value):
66-
state.flattened_map_children(value)
66+
for _ in state.fast_map_child_values(value):
67+
pass # Run lazy iterator.
6768
if isinstance(value, config_lib.Buildable):
6869
fn_or_cls = config_lib.get_callable(value)
6970
value_str = f"<{type(value).__name__}({fn_or_cls.__name__})>"

fiddle/_src/experimental/visualize.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,7 @@ def _helper(value, substate: daglish.State) -> bool:
120120
for sub_path in substate.get_all_paths():
121121
if not _goes_through_node(sub_path):
122122
return False
123-
return all(substate.flattened_map_children(value).values)
123+
return all(substate.fast_map_child_values(value))
124124

125125
# Creates a sub-traversal using a different function. We eventually might
126126
# make this part of the daglish API.
@@ -199,8 +199,8 @@ def traverse(node, state: daglish.State) -> None:
199199
all_paths = state.get_all_paths(allow_caching=True)
200200
node_to_depth[id(node)] = min(_path_len(path) for path in all_paths)
201201
id_to_node[id(node)] = node
202-
if state.is_traversable(node):
203-
state.flattened_map_children(node)
202+
for _ in state.fast_map_child_values(node, ignore_leaves=True):
203+
pass # Run lazy iterator.
204204

205205
daglish.MemoizedTraversal.run(traverse, config)
206206

fiddle/_src/graphviz.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -833,10 +833,10 @@ def visit(value, state: daglish.State):
833833
"""Returns true if any child has changed."""
834834
parents_changed = id(value) in changed_parent_ids
835835
if state.is_traversable(value):
836-
subtraversal = state.flattened_map_children(value)
837-
any_changed = any(subtraversal.values)
836+
child_results = list(state.fast_map_child_values(value))
837+
any_changed = any(child_results)
838838
if isinstance(value, dict) and id(value) in old_value_ids:
839-
_trim_dict(value, subtraversal.values)
839+
_trim_dict(value, child_results)
840840
return any_changed or parents_changed
841841
elif isinstance(value, _ChangedValue):
842842
state.call(value.old_value, daglish.Attr('old_value'))
@@ -866,7 +866,8 @@ def _find_mutable_values_with_changed_parents(structure_with_changed_values):
866866

867867
def visit(value, state: daglish.State):
868868
if state.is_traversable(value):
869-
state.flattened_map_children(value)
869+
for _ in state.fast_map_child_values(value):
870+
pass # Run lazy iterator.
870871
elif isinstance(value, _ChangedValue):
871872
assert value.old_value is not value.new_value
872873
if daglish.is_memoizable(value.old_value):

fiddle/_src/materialize.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ def traverse(node, state: daglish.State):
5555
for arg in node.__signature_info__.parameters.values():
5656
if arg.default is not arg.empty and arg.name not in node.__arguments__:
5757
setattr(node, arg.name, arg.default)
58-
if state.is_traversable(node):
59-
state.flattened_map_children(node)
58+
for _ in state.fast_map_child_values(node, ignore_leaves=True):
59+
pass # Run lazy iterator.
6060

6161
daglish.MemoizedTraversal.run(traverse, value)

fiddle/_src/partial.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,11 +45,11 @@ class _BuiltArgFactory:
4545
def _contains_arg_factory(value: Any) -> bool:
4646
"""Returns true if ``value`` contains any ``_BuiltArgFactory`` instances."""
4747

48-
def visit(node, state):
48+
def visit(node, state: daglish.State):
4949
if isinstance(node, _BuiltArgFactory):
5050
return True
5151
elif state.is_traversable(node):
52-
return any(state.flattened_map_children(node).values)
52+
return any(state.fast_map_child_values(node))
5353
else:
5454
return False
5555

0 commit comments

Comments
 (0)