Skip to content

Commit 1cace11

Browse files
Fiddle-Config Teamcopybara-github
authored andcommitted
internal
PiperOrigin-RevId: 568426297
1 parent 2813ea4 commit 1cace11

3 files changed

Lines changed: 209 additions & 5 deletions

File tree

fiddle/_src/printing.py

Lines changed: 48 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
import dataclasses
2020
import inspect
2121
import types
22-
from typing import Any, Iterator, List, Optional, Type
22+
from typing import Any, Dict, Iterator, List, Optional, Type
2323

2424
from fiddle._src import config
2525
from fiddle._src import daglish
@@ -131,8 +131,9 @@ def _get_tags(cfg, path):
131131
return None
132132

133133

134-
def _rearrange_buildable_args_and_insert_unset_sentinels(
135-
value: config.Buildable) -> config.Buildable:
134+
def _rearrange_buildable_args(
135+
value: config.Buildable, insert_unset_sentinels: bool = True
136+
) -> config.Buildable:
136137
"""Returns a copy of a Buildable with normalized arguments.
137138
138139
This normalizes arguments by re-creating the __arguments__ dictionary in the
@@ -141,6 +142,8 @@ def _rearrange_buildable_args_and_insert_unset_sentinels(
141142
142143
Args:
143144
value: Buildable to copy and normalize.
145+
insert_unset_sentinels: If true, insert unset sentinels to arguments as the
146+
default values.
144147
145148
Returns:
146149
Copy of `value` with arguments normalized.
@@ -155,7 +158,7 @@ def _rearrange_buildable_args_and_insert_unset_sentinels(
155158
continue
156159
elif param_name in old_arguments:
157160
new_arguments[param_name] = old_arguments.pop(param_name)
158-
else:
161+
elif insert_unset_sentinels:
159162
new_arguments[param_name] = _UnsetValue(param)
160163
new_arguments.update(old_arguments) # Add in kwargs, in current order.
161164
object.__setattr__(value, '__arguments__', new_arguments)
@@ -190,7 +193,7 @@ def generate(value, state=None) -> Iterator[_LeafSetting]:
190193

191194
# Rearrange parameters in signature order, and add "unset" sentinels.
192195
if isinstance(value, config.Buildable):
193-
value = _rearrange_buildable_args_and_insert_unset_sentinels(value)
196+
value = _rearrange_buildable_args(value)
194197

195198
if isinstance(value, tagging.TaggedValueCls):
196199
value = _TaggedValueWrapper(value)
@@ -331,3 +334,43 @@ def make_previous_text(entry: history.HistoryEntry) -> str:
331334
value_history[-1].new_value, raw_value_repr=raw_value_repr)
332335
current = f'{current_value} @ {value_history[-1].location}'
333336
return f'{_path_str(path)} = {current}{past}'
337+
338+
339+
def as_dict_flattened(cfg: config.Buildable) -> Dict[str, Any]:
340+
"""Returns a flattened dict of cfg's paths (dot syntax) and values.
341+
342+
Default values won't be included in the flattened dict.
343+
344+
Args:
345+
cfg: A buildable to generate a string representation for.
346+
347+
Returns: A flattened Dict representation of `cfg`.
348+
"""
349+
350+
def dict_generate(value, state=None) -> Iterator[_LeafSetting]:
351+
state = state or daglish.BasicTraversal.begin(dict_generate, value)
352+
353+
tags = _get_tags(cfg, state.current_path)
354+
if tags:
355+
value = tagging.TaggedValue(tags=tags, default=value)
356+
357+
# Rearrange parameters in signature order, and add "unset" sentinels.
358+
if isinstance(value, config.Buildable):
359+
value = _rearrange_buildable_args(value, insert_unset_sentinels=False)
360+
361+
if isinstance(value, tagging.TaggedValueCls):
362+
value = _TaggedValueWrapper(value)
363+
yield _LeafSetting(state.current_path, None, value)
364+
elif not _has_nested_builder(value):
365+
yield _LeafSetting(state.current_path, None, value)
366+
else:
367+
# value must be a Buildable or a traversable containing a Buidable.
368+
assert state.is_traversable(value)
369+
for sub_result in state.flattened_map_children(value).values:
370+
yield from sub_result
371+
372+
args_dict = {}
373+
for leaf in dict_generate(cfg):
374+
args_dict[_path_str(leaf.path)] = leaf.value
375+
376+
return args_dict

fiddle/_src/printing_test.py

Lines changed: 160 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -473,5 +473,165 @@ def test_collection_of_two_buildables_history(self):
473473
self.assertRegex(output, expected)
474474

475475

476+
class AsFlattenedDictTests(absltest.TestCase):
477+
478+
def test_simple_flattened_dict(self):
479+
cfg = fdl.Config(fn_x_y, 1, 'abc')
480+
output = printing.as_dict_flattened(cfg)
481+
482+
expected = {'x': 1, 'y': 'abc'}
483+
self.assertEqual(output, expected)
484+
485+
def test_skip_unset_argument(self):
486+
cfg = fdl.Config(fn_x_y, 3.14)
487+
output = printing.as_dict_flattened(cfg)
488+
489+
expected = {'x': 3.14}
490+
self.assertEqual(output, expected)
491+
492+
def test_nested(self):
493+
cfg = fdl.Config(fn_x_y, 'x', fdl.Config(fn_x_y, 'nest_x', 123))
494+
output = printing.as_dict_flattened(cfg)
495+
496+
expected = {'x': 'x', 'y.x': 'nest_x', 'y.y': 123}
497+
self.assertEqual(output, expected)
498+
499+
def test_class(self):
500+
cfg = fdl.Config(SampleClass, 'a_param', b=123)
501+
output = printing.as_dict_flattened(cfg)
502+
503+
expected = {'a': 'a_param', 'b': 123}
504+
self.assertEqual(output, expected)
505+
506+
def test_kwargs(self):
507+
cfg = fdl.Config(fn_with_kwargs, 1, abc='extra kwarg value')
508+
output = printing.as_dict_flattened(cfg)
509+
510+
expected = {'x': 1, 'abc': 'extra kwarg value'}
511+
self.assertEqual(output, expected)
512+
513+
def test_nested_kwargs(self):
514+
cfg = fdl.Config(
515+
fn_with_kwargs, extra=fdl.Config(fn_with_kwargs, 1, nested_extra='whee')
516+
)
517+
output = printing.as_dict_flattened(cfg)
518+
519+
expected = {'extra.x': 1, 'extra.nested_extra': 'whee'}
520+
self.assertEqual(output, expected)
521+
522+
def test_nested_collections(self):
523+
cfg = fdl.Config(
524+
fn_x_y, [fdl.Config(fn_x_y, 1, '1'), fdl.Config(SampleClass, 2)]
525+
)
526+
output = printing.as_dict_flattened(cfg)
527+
528+
expected = {'x[0].x': 1, 'x[0].y': '1', 'x[1].a': 2}
529+
self.assertEqual(output, expected)
530+
531+
def test_multiple_nested_collections(self):
532+
cfg = fdl.Config(
533+
fn_x_y,
534+
{'a': fdl.Config(fn_with_kwargs, abc=[1, 2, 3]), 'b': [3, 2, 1]},
535+
[fdl.Config(fn_x_y, [fdl.Config(fn_x_y, 1, 2)])],
536+
)
537+
output = printing.as_dict_flattened(cfg)
538+
539+
expected = {
540+
"x['a'].abc": [1, 2, 3],
541+
"x['b']": [3, 2, 1],
542+
'y[0].x[0].x': 1,
543+
'y[0].x[0].y': 2,
544+
}
545+
self.assertEqual(output, expected)
546+
547+
def test_skip_default_values(self):
548+
def test_fn(w, x, y=3, z='abc'): # pylint: disable=unused-argument
549+
pass
550+
551+
cfg = fdl.Config(test_fn, 1)
552+
output = printing.as_dict_flattened(cfg)
553+
554+
expected = {'w': 1}
555+
self.assertEqual(output, expected)
556+
557+
def test_tagged_values(self):
558+
cfg = fdl.Config(fn_x_y, x=SampleTag.new(), y=SampleTag.new(default='abc'))
559+
output = printing.as_dict_flattened(cfg)
560+
561+
expected = "'abc' #__main__.SampleTag"
562+
self.assertEqual(repr(output['y']), expected)
563+
564+
fdl.set_tagged(cfg, tag=SampleTag, value='cba')
565+
output = printing.as_dict_flattened(cfg)
566+
567+
expected = "'cba' #__main__.SampleTag"
568+
self.assertEqual(repr(output['x']), expected)
569+
self.assertEqual(repr(output['y']), expected)
570+
571+
def test_partial(self):
572+
partial = fdl.Partial(fn_x_y)
573+
partial.x = 'abc'
574+
output = printing.as_dict_flattened(partial)
575+
576+
expected = {'x': 'abc'}
577+
self.assertEqual(output, expected)
578+
579+
def test_builtin_types_annotations(self):
580+
cfg = fdl.Config(fn_with_type_annotations, 1)
581+
cfg.y = 'abc'
582+
output = printing.as_dict_flattened(cfg)
583+
584+
expected = {'x': 1, 'y': 'abc'}
585+
self.assertEqual(output, expected)
586+
587+
def test_annotated_kwargs(self):
588+
cfg = fdl.Config(annotated_kwargs_helper, x=1, y='oops')
589+
output = printing.as_dict_flattened(cfg)
590+
591+
expected = {'x': 1, 'y': 'oops'}
592+
self.assertEqual(output, expected)
593+
594+
def test_disabling_type_annotations(self):
595+
cfg = fdl.Config(fn_with_type_annotations, 1)
596+
cfg.y = 'abc'
597+
output = printing.as_dict_flattened(cfg)
598+
599+
expected = {'x': 1, 'y': 'abc'}
600+
self.assertEqual(output, expected)
601+
602+
def test_union_type(self):
603+
def to_integer(x: Union[int, str]):
604+
return int(x)
605+
606+
cfg = fdl.Config(to_integer, 1)
607+
output = printing.as_dict_flattened(cfg)
608+
609+
expected = {'x': 1}
610+
self.assertEqual(output, expected)
611+
612+
def test_parameterized_generic(self):
613+
if not (sys.version_info.major == 3 and sys.version_info.minor >= 9):
614+
self.skipTest('types.GenericAlias is 3.9+ only.')
615+
616+
def takes_list(x: list[int]):
617+
return x
618+
619+
cfg = fdl.Config(takes_list, [1, 2, 3])
620+
output = printing.as_dict_flattened(cfg)
621+
622+
expected = {'x': [1, 2, 3]}
623+
self.assertEqual(output, expected)
624+
625+
def test_materialized_default_values(self):
626+
def test_fn(w, x, y=3, z='abc'):
627+
del w, x, y, z # Unused.
628+
629+
cfg = fdl.Config(test_fn, 1)
630+
fdl.materialize_defaults(cfg)
631+
output = printing.as_dict_flattened(cfg)
632+
expected = {'w': 1, 'y': 3, 'z': 'abc'}
633+
self.assertEqual(output, expected)
634+
635+
476636
if __name__ == '__main__':
477637
absltest.main()

fiddle/printing.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,5 +16,6 @@
1616
"""Functions to output representations of `fdl.Buildable`s."""
1717

1818
# pylint: disable=unused-import
19+
from fiddle._src.printing import as_dict_flattened
1920
from fiddle._src.printing import as_str_flattened
2021
from fiddle._src.printing import history_per_leaf_parameter

0 commit comments

Comments
 (0)