Skip to content

Commit bb9ebcd

Browse files
Fiddle-Config Teamcopybara-github
authored andcommitted
Allow default values to be set on Fiddle CLI flags
PiperOrigin-RevId: 625033837
1 parent 77cd853 commit bb9ebcd

1 file changed

Lines changed: 163 additions & 89 deletions

File tree

fiddle/_src/absl_flags/flags.py

Lines changed: 163 additions & 89 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,10 @@
1515

1616
"""API to use command line flags with Fiddle Buildables."""
1717

18+
import dataclasses
1819
import re
1920
import types
20-
from typing import Any, Optional, TypeVar
21+
from typing import Any, List, Optional, Text, TypeVar, Union
2122

2223
from absl import flags
2324
from etils import epath
@@ -81,60 +82,32 @@ def serialize(self, value: config.Buildable) -> str:
8182
return f"config_str:{serialized}"
8283

8384

84-
class FiddleFlag(flags.MultiFlag):
85-
"""ABSL flag class for a Fiddle config flag.
85+
@dataclasses.dataclass
86+
class _LazyFlagValue:
87+
"""Represents a lazily evaluated Fiddle flag value.
8688
87-
This class is used to parse command line flags to construct a Fiddle `Config`
88-
object with certain transformations applied as specified in the command line
89-
flags.
89+
This is separate from FiddleFlag because it is used by both defaults and
90+
provided flags.
9091
91-
Most users should rely on the `DEFINE_fiddle_config()` API below. Using this
92-
class directly provides flexibility to users to parse Fiddle flags themselves
93-
programmatically. Also see the documentation for `DEFINE_fiddle_config()`
94-
below.
92+
Lazy flag values are useful because they allow other parts of the system to
93+
be set up, so things like logging can be configured before a configuration is
94+
loaded.
95+
"""
9596

96-
Example usage where this flag is parsed from existing flag:
97-
```
98-
from fiddle import absl_flags as fdl_flags
97+
flag_name: str
98+
remaining_directives: List[str] = dataclasses.field(default_factory=list)
99+
first_command: Optional[str] = None
100+
initial_config_expression: Optional[str] = None
99101

100-
_MY_CONFIG = fdl_flags.DEFINE_multi_string(
101-
"my_config",
102-
"Name of the fiddle config"
103-
)
102+
default_module: Optional[types.ModuleType] = None
103+
allow_imports: bool = True
104+
pyref_policy: Optional[serialization.PyrefPolicy] = None
104105

105-
fiddle_flag = fdl_flags.FiddleFlag(
106-
name="config",
107-
default_module=my_module,
108-
default=None,
109-
parser=flags.ArgumentParser(),
110-
serializer=None,
111-
help_string="My fiddle flag",
112-
)
113-
fiddle_flag.parse(_MY_CONFIG.value)
114-
config = fiddle_flag.value
115-
```
116-
"""
117-
118-
def __init__(
119-
self,
120-
*args,
121-
default_module: Optional[types.ModuleType] = None,
122-
allow_imports: bool = True,
123-
pyref_policy: Optional[serialization.PyrefPolicy] = None,
124-
**kwargs,
125-
):
126-
self.allow_imports = allow_imports
127-
self.default_module = default_module
128-
self._pyref_policy = pyref_policy
129-
self.first_command = None
130-
self._initial_config_expression = None
131-
# A `directive` is a str of the form e.g. 'config:...'.
132-
# Due to the lazy evaluation of `value`, this list is needed to keep
133-
# track of the remaining `directives`.
134-
self._remaining_directives = []
135-
super().__init__(*args, **kwargs)
106+
# Only set internally, please use get_value() / set_value().
107+
_value: Optional[Any] = None
136108

137109
def _initial_config(self, expression: str):
110+
"""Generates the initial config from a config:<expression> directive."""
138111
call_expr = utils.CallExpression.parse(expression)
139112
base_name = call_expr.func_name
140113
base_fn = utils.resolve_function_reference(
@@ -150,6 +123,7 @@ def _initial_config(self, expression: str):
150123
return base_fn(*call_expr.args, **call_expr.kwargs)
151124

152125
def _apply_fiddler(self, cfg: config.Buildable, expression: str):
126+
"""Modifies the config from the given CLI flag."""
153127
call_expr = utils.CallExpression.parse(expression)
154128
base_name = call_expr.func_name
155129
fiddler = utils.resolve_function_reference(
@@ -175,57 +149,40 @@ def _apply_fiddler(self, cfg: config.Buildable, expression: str):
175149
# `fdl.Buildable` object.
176150
return new_cfg if new_cfg is not None else cfg
177151

178-
def parse(self, arguments):
179-
new_parsed = self._parse(arguments)
180-
self._remaining_directives.extend(new_parsed)
181-
self.present += len(new_parsed)
182-
183-
def unparse(self) -> None:
184-
self.value = self.default
185-
self.using_default_value = True
186-
# Reset it so that all `directives` not being processed yet will be
187-
# discarded.
188-
self._remaining_directives = []
189-
self.present = 0
190-
191152
def _parse_config(self, command: str, expression: str) -> None:
192-
if self._initial_config_expression:
153+
"""Sets the initial config from the given CLI flag/directive."""
154+
if self.initial_config_expression:
193155
raise ValueError(
194156
"Only one base configuration is permitted. Received"
195157
f"{command}:{expression} after "
196-
f"{self.first_command}:{self._initial_config_expression} was"
158+
f"{self.first_command}:{self.initial_config_expression} was"
197159
" already provided."
198160
)
199161
else:
200-
self._initial_config_expression = expression
162+
self.initial_config_expression = expression
201163
if command == "config":
202-
self.value = self._initial_config(expression)
164+
self._value = self._initial_config(expression)
203165
elif command == "config_file":
204166
with epath.Path(expression).open() as f:
205-
self.value = serialization.load_json(
206-
f.read(), pyref_policy=self._pyref_policy
167+
self._value = serialization.load_json(
168+
f.read(), pyref_policy=self.pyref_policy
207169
)
208170
elif command == "config_str":
209171
serializer = utils.ZlibJSONSerializer()
210-
self.value = serializer.deserialize(
211-
expression, pyref_policy=self._pyref_policy
172+
self._value = serializer.deserialize(
173+
expression, pyref_policy=self.pyref_policy
212174
)
213175

214-
def _serialize(self, value) -> str:
215-
# Skip MultiFlag serialization as we don't truly have a multi-flag.
216-
# This will invoke Flag._serialize
217-
return super(flags.MultiFlag, self)._serialize(value)
218-
219-
@property
220-
def value(self):
221-
while self._remaining_directives:
176+
def get_value(self):
177+
"""Gets the current value (parsing any directives)."""
178+
while self.remaining_directives:
222179
# Pop already processed `directive` so that _value won't be updated twice
223180
# by the same argument.
224-
item = self._remaining_directives.pop(0)
181+
item = self.remaining_directives.pop(0)
225182
match = _COMMAND_RE.fullmatch(item)
226183
if not match:
227184
raise ValueError(
228-
f"All flag values to {self.name} must begin with 'config:', "
185+
f"All flag values to {self.flag_name} must begin with 'config:', "
229186
"'config_file:', 'config_str:', 'set:', or 'fiddler:'."
230187
)
231188
command, expression = match.groups()
@@ -235,7 +192,9 @@ def value(self):
235192
raise ValueError(
236193
"First flag command must specify the input config via either "
237194
"config or config_file or config_str commands. "
238-
f"Received command: {command} instead."
195+
f"Received command: {command} instead. If you have a default "
196+
"value set, you must re-provide that on the CLI before setting "
197+
"values or running fiddlers."
239198
)
240199
self.first_command = command
241200

@@ -254,15 +213,130 @@ def value(self):
254213
raise AssertionError("Internal error; should not be reached.")
255214
return self._value
256215

216+
def set_value(self, value: Any):
217+
self._value = value
218+
219+
220+
class FiddleFlag(flags.MultiFlag):
221+
"""ABSL flag class for a Fiddle config flag.
222+
223+
This class is used to parse command line flags to construct a Fiddle `Config`
224+
object with certain transformations applied as specified in the command line
225+
flags.
226+
227+
Most users should rely on the `DEFINE_fiddle_config()` API below. Using this
228+
class directly provides flexibility to users to parse Fiddle flags themselves
229+
programmatically. Also see the documentation for `DEFINE_fiddle_config()`
230+
below.
231+
232+
Example usage where this flag is parsed from existing flag:
233+
```
234+
from fiddle import absl_flags as fdl_flags
235+
236+
_MY_CONFIG = fdl_flags.DEFINE_multi_string(
237+
"my_config",
238+
"Name of the fiddle config"
239+
)
240+
241+
fiddle_flag = fdl_flags.FiddleFlag(
242+
name="config",
243+
default_module=my_module,
244+
default=None,
245+
parser=flags.ArgumentParser(),
246+
serializer=None,
247+
help_string="My fiddle flag",
248+
)
249+
fiddle_flag.parse(_MY_CONFIG.value)
250+
config = fiddle_flag.value
251+
```
252+
"""
253+
254+
def __init__(
255+
self,
256+
*args,
257+
name: Text,
258+
default_module: Optional[types.ModuleType] = None,
259+
allow_imports: bool = True,
260+
pyref_policy: Optional[serialization.PyrefPolicy] = None,
261+
**kwargs,
262+
):
263+
self.allow_imports = allow_imports
264+
self.default_module = default_module
265+
self._pyref_policy = pyref_policy
266+
self._lazy_default = _LazyFlagValue(
267+
flag_name=name,
268+
default_module=default_module,
269+
allow_imports=allow_imports,
270+
pyref_policy=pyref_policy,
271+
)
272+
self._lazy_value = _LazyFlagValue(
273+
flag_name=name,
274+
default_module=default_module,
275+
allow_imports=allow_imports,
276+
pyref_policy=pyref_policy,
277+
)
278+
kwargs["name"] = name
279+
super().__init__(*args, **kwargs)
280+
281+
def parse(self, arguments):
282+
new_parsed = self._parse(arguments)
283+
self._lazy_value.remaining_directives.extend(new_parsed)
284+
self.present += len(new_parsed)
285+
286+
def _parse_from_default(
287+
self, value: Union[Text, List[Any]]
288+
) -> Optional[List[Any]]:
289+
lazy_default_value = _LazyFlagValue(
290+
flag_name=self.name,
291+
default_module=self.default_module,
292+
allow_imports=self.allow_imports,
293+
pyref_policy=self._pyref_policy,
294+
)
295+
value = self._parse(value)
296+
assert isinstance(value, list)
297+
lazy_default_value.remaining_directives.extend(value)
298+
return lazy_default_value # pytype: disable=bad-return-type
299+
300+
def unparse(self) -> None:
301+
self.value = self.default
302+
self.using_default_value = True
303+
# Reset it so that all `directives` not being processed yet will be
304+
# discarded.
305+
self._lazy_value.remaining_directives = []
306+
self.present = 0
307+
308+
def _serialize(self, value) -> str:
309+
# Skip MultiFlag serialization as we don't truly have a multi-flag.
310+
# This will invoke Flag._serialize
311+
return super(flags.MultiFlag, self)._serialize(value)
312+
313+
@property
314+
def value(self):
315+
return self._lazy_value.get_value()
316+
257317
@value.setter
258318
def value(self, value):
259-
self._value = value
319+
self._lazy_value.set_value(value)
320+
321+
@property
322+
def default(self):
323+
return self._lazy_default.get_value()
324+
325+
@default.setter
326+
def default(self, value):
327+
if isinstance(value, _LazyFlagValue):
328+
# Note: This is only for _set_default(). We might choose to override that
329+
# instead of just _parse_from_default(), in which case this branch can be
330+
# removed.
331+
self._lazy_default = value
332+
else:
333+
self._lazy_default.set_value(value)
260334

261335

262336
def DEFINE_fiddle_config( # pylint: disable=invalid-name
263337
name: str,
264338
*,
265-
default: Any = None,
339+
default_flag_str: Optional[str] = None,
266340
help_string: str,
267341
default_module: Optional[types.ModuleType] = None,
268342
pyref_policy: Optional[serialization.PyrefPolicy] = None,
@@ -317,12 +391,12 @@ def main(argv) -> None:
317391
python3 -m path.to.my.binary --my_config=config_file:path/to/file
318392
319393
Args:
320-
name: name of the command line flag.
321-
default: default value of the flag.
322-
help_string: help string describing what the flag does.
323-
default_module: the python module where this flag is defined.
324-
pyref_policy: a policy for importing references to Python objects.
325-
flag_values: the ``FlagValues`` instance with which the flag will be
394+
name: Name of the command line flag.
395+
default_flag_str: Default value of the flag.
396+
help_string: Help string describing what the flag does.
397+
default_module: The python module where this flag is defined.
398+
pyref_policy: A policy for importing references to Python objects.
399+
flag_values: The ``FlagValues`` instance with which the flag will be
326400
registered. This should almost never need to be overridden.
327401
required: bool, is this a required flag. This must be used as a keyword
328402
argument.
@@ -334,7 +408,7 @@ def main(argv) -> None:
334408
FiddleFlag(
335409
name=name,
336410
default_module=default_module,
337-
default=default,
411+
default=default_flag_str,
338412
pyref_policy=pyref_policy,
339413
parser=flags.ArgumentParser(),
340414
serializer=FiddleFlagSerializer(pyref_policy=pyref_policy),

0 commit comments

Comments
 (0)