@@ -143,6 +143,8 @@ class NamedTupleType:
143143# This has the same type as Path, but different semantic meaning.
144144PathElements = Tuple [PathElement , ...]
145145PathElementsFn = Callable [[Any ], PathElements ]
146+ _ValueAndPath = Tuple [Any , PathElement ]
147+ OptimizedFlattenFn = Callable [[Any ], Iterable [_ValueAndPath ]]
146148FlattenFn = Callable [[Any ], Tuple [Tuple [Any , ...], Any ]]
147149UnflattenFn = Callable [[Iterable [Any ], Any ], Any ]
148150
@@ -155,6 +157,7 @@ class NodeTraverser:
155157 flatten : FlattenFn
156158 unflatten : UnflattenFn
157159 path_elements : PathElementsFn
160+ flatten_with_paths : OptimizedFlattenFn | None = None
158161
159162
160163class NodeTraverserRegistry :
@@ -185,6 +188,7 @@ def register_node_traverser(
185188 flatten_fn : FlattenFn ,
186189 unflatten_fn : UnflattenFn ,
187190 path_elements_fn : PathElementsFn ,
191+ flatten_with_paths_fn : OptimizedFlattenFn | None = None ,
188192 ) -> None :
189193 """Registers a node traverser for `node_type`.
190194
@@ -202,6 +206,9 @@ def register_node_traverser(
202206 flattened values returned by `flatten_fn`. This should accept an
203207 instance of `node_type`, and return a sequence of `PathElement`s aligned
204208 with the values returned by `flatten_fn`.
209+ flatten_with_paths_fn: A version of `flatten_fn` that returns an iterable
210+ of `(value, path)` pairs, where `value` is a child value and `path` is a
211+ `Path` to the value.
205212 """
206213 if not isinstance (node_type , type ):
207214 raise TypeError (f"`node_type` ({ node_type } ) must be a type." )
@@ -212,6 +219,7 @@ def register_node_traverser(
212219 flatten = flatten_fn ,
213220 unflatten = unflatten_fn ,
214221 path_elements = path_elements_fn ,
222+ flatten_with_paths = flatten_with_paths_fn ,
215223 )
216224
217225 def find_node_traverser (
@@ -282,7 +290,9 @@ def unflatten_defaultdict(values, metadata):
282290 tuple ,
283291 flatten_fn = lambda x : (x , None ),
284292 unflatten_fn = lambda x , _ : tuple (x ),
285- path_elements_fn = lambda x : tuple (Index (i ) for i in range (len (x ))))
293+ path_elements_fn = lambda x : tuple (Index (i ) for i in range (len (x ))),
294+ flatten_with_paths_fn = lambda xs : ((x , Index (i )) for i , x in enumerate (xs )),
295+ )
286296
287297register_node_traverser (
288298 NamedTupleType ,
@@ -294,7 +304,9 @@ def unflatten_defaultdict(values, metadata):
294304 list ,
295305 flatten_fn = lambda x : (tuple (x ), None ),
296306 unflatten_fn = lambda x , _ : list (x ),
297- path_elements_fn = lambda x : tuple (Index (i ) for i in range (len (x ))))
307+ path_elements_fn = lambda x : tuple (Index (i ) for i in range (len (x ))),
308+ flatten_with_paths_fn = lambda xs : ((x , Index (i )) for i , x in enumerate (xs )),
309+ )
298310
299311
300312def is_prefix (prefix_path : Path , containing_path : Path ):
@@ -610,6 +622,42 @@ def flattened_map_children(self, value: Any) -> SubTraversalResult:
610622 raise ValueError ("Please handle non-traversable values yourself." )
611623 return self ._flattened_map_children (value , node_traverser )
612624
625+ def fast_map_child_values (
626+ self , value : Any , ignore_leaves : bool = False
627+ ) -> Iterable [Any ]:
628+ """Maps over children for traversable values, but doesn't unflatten results.
629+
630+ This method only returns result values, so use it in place of
631+ `state.flattened_map_children(value).values`.
632+
633+ Args:
634+ value: Value to map over.
635+ ignore_leaves: If True, then this function will return an empty iterable
636+ if `value` is not traversable. Otherwise, it will raise a ValueError.
637+
638+ Yields:
639+ Sub-traversal results, the same type as returned by your _traverse
640+ function.
641+
642+ Raises:
643+ ValueError: If `value` is not traversable. Please test beforehand by
644+ calling `state.is_traversable()`.
645+ """
646+ node_traverser = self .traversal .find_node_traverser (type (value ))
647+ if node_traverser is None :
648+ if ignore_leaves :
649+ return
650+ else :
651+ raise ValueError ("Please handle non-traversable values yourself." )
652+ if node_traverser .flatten_with_paths is not None :
653+ for value , path in node_traverser .flatten_with_paths (value ):
654+ yield self .call (value , path )
655+ else :
656+ sub_values , unused_meta = node_traverser .flatten (value )
657+ path_elements = node_traverser .path_elements (value )
658+ for sub_value , path_element in zip (sub_values , path_elements ):
659+ yield self .call (sub_value , path_element )
660+
613661 def call (self , value , * additional_path : PathElement ):
614662 """Low-level function to execute a sub-traversal.
615663
0 commit comments