Skip to content

Commit 5b69651

Browse files
mjwillsoncopybara-github
authored andcommitted
In a fdl.Config for a dataclass, when accessing a field that uses a default_factory, populate a child Config (if not already present) wrapping the default_factory call, rather than raising a "Can't get default value for dataclass field ... since it uses a default_factory" error.
This allows easy overriding of properties of child dataclasses where the config for the parent dataclass doesn't include an explicit child config. PiperOrigin-RevId: 737952148
1 parent 8536ef2 commit 5b69651

File tree

2 files changed

+81
-26
lines changed

2 files changed

+81
-26
lines changed

fiddle/_src/config.py

Lines changed: 42 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
import dataclasses
2424
import functools
2525
import types
26-
from typing import Any, Callable, Collection, Dict, FrozenSet, Generic, Iterable, Mapping, NamedTuple, Optional, Set, Tuple, Type, TypeVar, Union
26+
from typing import Any, Callable, Collection, Dict, FrozenSet, Generic, Iterable, Mapping, NamedTuple, Optional, Set, Tuple, Type, TypeVar, Union, cast
2727

2828
from fiddle._src import daglish
2929
from fiddle._src import history
@@ -341,14 +341,27 @@ def __getattr__(self, name: str):
341341

342342
if value is not _UNSET_SENTINEL:
343343
return value
344-
if dataclasses.is_dataclass(
345-
self.__fn_or_cls__
346-
) and _field_uses_default_factory(self.__fn_or_cls__, name):
347-
raise ValueError(
348-
"Can't get default value for dataclass field "
349-
+ f'{self.__fn_or_cls__.__qualname__}.{name} '
350-
+ 'since it uses a default_factory.'
351-
)
344+
if dataclasses.is_dataclass(self.__fn_or_cls__) and (
345+
default_factory := _field_default_factory(self.__fn_or_cls__, name)
346+
):
347+
if _is_resolvable(default_factory):
348+
self.__arguments__[name] = Config(default_factory)
349+
return self.__arguments__[name]
350+
elif isinstance(default_factory, functools.partial) and _is_resolvable(
351+
default_factory.func
352+
):
353+
self.__arguments__[name] = Config(
354+
default_factory.func,
355+
*default_factory.args,
356+
**default_factory.keywords,
357+
)
358+
return self.__arguments__[name]
359+
else:
360+
return ValueError(
361+
"Can't expose a sub-config to build default value for field "
362+
f'{name} of dataclass {self.__fn_or_cls__} since it uses an '
363+
'anonymous default_factory.'
364+
)
352365
if param is not None and param.default is not param.empty:
353366
return param.default
354367
msg = f"No parameter '{name}' has been set on {self!r}."
@@ -848,12 +861,27 @@ def __build__(self, /, *args: Any, **kwargs: Any) -> T:
848861
return self.__fn_or_cls__(tags=self.tags, *args, **kwargs)
849862

850863

851-
def _field_uses_default_factory(dataclass_type: Type[Any], field_name: str):
852-
"""Returns true if <dataclass_type>.<field_name> uses a default_factory."""
864+
def _field_default_factory(
865+
dataclass_type: Type[Any], field_name: str
866+
) -> Callable[[], Any] | None:
867+
"""Returns the default_factory of <dataclass_type>.<field_name> if present."""
853868
for field in dataclasses.fields(dataclass_type):
854-
if field.name == field_name:
855-
return field.default_factory != dataclasses.MISSING
856-
return False
869+
if (
870+
field.name == field_name
871+
and field.default_factory != dataclasses.MISSING
872+
):
873+
return cast(Callable[[], Any], field.default_factory)
874+
return None
875+
876+
877+
def _is_resolvable(value: Any) -> bool:
878+
return (
879+
hasattr(value, '__module__')
880+
and hasattr(value, '__qualname__')
881+
and
882+
# Rules out anonymous objects like <lambda>, foo.<locals>.bar etc:
883+
'<' not in value.__qualname__
884+
)
857885

858886

859887
BuildableT = TypeVar('BuildableT', bound=Buildable)

fiddle/_src/config_test.py

Lines changed: 39 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
import copy
1919
import dataclasses
20+
import functools
2021
import pickle
2122
import sys
2223
import threading
@@ -111,6 +112,13 @@ class DataclassParent:
111112
child: DataclassChild = dataclasses.field(default_factory=DataclassChild)
112113

113114

115+
@dataclasses.dataclass
116+
class DataclassParentPartialDefaultFactoryChild:
117+
child: DataclassChild = dataclasses.field(
118+
default_factory=functools.partial(DataclassChild, x=1)
119+
)
120+
121+
114122
def raise_error():
115123
raise ValueError('My fancy exception')
116124

@@ -1042,21 +1050,40 @@ def test_copy_constructor_with_updates_errors(self):
10421050
with self.assertRaises(ValueError):
10431051
fdl.Partial(cfg1, 5, a='a', b='b')
10441052

1045-
def test_dataclass_default_factory(self):
1053+
def test_dataclass_default_factory_can_read_default(self):
1054+
cfg = fdl.Config(DataclassParent)
1055+
child_config = cfg.child
1056+
self.assertIsInstance(child_config, fdl.Config)
1057+
self.assertEqual(child_config.__fn_or_cls__, DataclassChild)
1058+
self.assertEqual(fdl.build(cfg), DataclassParent(child=DataclassChild(x=0)))
1059+
1060+
def test_dataclass_partial_default_factory_can_read_default(self):
1061+
cfg = fdl.Config(DataclassParentPartialDefaultFactoryChild)
1062+
child_config = cfg.child
1063+
self.assertIsInstance(child_config, fdl.Config)
1064+
self.assertEqual(child_config.__fn_or_cls__, DataclassChild)
1065+
self.assertEqual(child_config.x, 1)
1066+
self.assertEqual(
1067+
fdl.build(cfg),
1068+
DataclassParentPartialDefaultFactoryChild(child=DataclassChild(x=1)),
1069+
)
10461070

1071+
def test_dataclass_default_factory_overriding_child_config(self):
10471072
cfg = fdl.Config(DataclassParent)
1073+
cfg.child.x = 5
1074+
self.assertEqual(fdl.build(cfg), DataclassParent(DataclassChild(x=5)))
1075+
1076+
def test_dataclass_partial_default_factory_overriding_child_config(self):
1077+
cfg = fdl.Config(DataclassParentPartialDefaultFactoryChild)
1078+
cfg.child.x = 5
1079+
self.assertEqual(
1080+
fdl.build(cfg),
1081+
DataclassParentPartialDefaultFactoryChild(DataclassChild(x=5)),
1082+
)
10481083

1049-
with self.subTest('read_default_is_error'):
1050-
expected_error = (
1051-
r"Can't get default value for dataclass field DataclassParent\.child "
1052-
r'since it uses a default_factory\.')
1053-
with self.assertRaisesRegex(ValueError, expected_error):
1054-
cfg.child.x = 5
1055-
1056-
with self.subTest('read_ok_after_override'):
1057-
cfg.child = fdl.Config(DataclassChild) # override default w/ a value
1058-
cfg.child.x = 5 # now it's ok to configure child.
1059-
self.assertEqual(fdl.build(cfg), DataclassParent(DataclassChild(5)))
1084+
def test_dataclass_default_factory_not_used_when_child_config_given(self):
1085+
cfg = fdl.Config(DataclassParent, child=fdl.Config(DataclassChild, x=1))
1086+
self.assertEqual(cfg.child.x, 1)
10601087

10611088
def test_unbound_method(self):
10621089
sample = fdl.Config(SampleClass, 0, 1)

0 commit comments

Comments
 (0)