1515
1616"""API to use command line flags with Fiddle Buildables."""
1717
18+ import dataclasses
1819import re
1920import types
20- from typing import Any , Optional , TypeVar
21+ from typing import Any , List , Optional , Text , TypeVar , Union
2122
2223from absl import flags
2324from etils import epath
@@ -81,60 +82,32 @@ def serialize(self, value: config.Buildable) -> str:
8182 return f"config_str:{ serialized } "
8283
8384
84- class FiddleFlag (flags .MultiFlag ):
85- """ABSL flag class for a Fiddle config flag.
85+ @dataclasses .dataclass
86+ class _LazyFlagValue :
87+ """Represents a lazily evaluated Fiddle flag value.
8688
87- This class is used to parse command line flags to construct a Fiddle `Config`
88- object with certain transformations applied as specified in the command line
89- flags.
89+ This is separate from FiddleFlag because it is used by both defaults and
90+ provided flags.
9091
91- Most users should rely on the `DEFINE_fiddle_config()` API below. Using this
92- class directly provides flexibility to users to parse Fiddle flags themselves
93- programmatically. Also see the documentation for `DEFINE_fiddle_config()`
94- below.
92+ Lazy flag values are useful because they allow other parts of the system to
93+ be set up, so things like logging can be configured before a configuration is
94+ loaded.
95+ """
9596
96- Example usage where this flag is parsed from existing flag:
97- ```
98- from fiddle import absl_flags as fdl_flags
97+ flag_name : str
98+ remaining_directives : List [str ] = dataclasses .field (default_factory = list )
99+ first_command : Optional [str ] = None
100+ initial_config_expression : Optional [str ] = None
99101
100- _MY_CONFIG = fdl_flags.DEFINE_multi_string(
101- "my_config",
102- "Name of the fiddle config"
103- )
102+ default_module : Optional [types .ModuleType ] = None
103+ allow_imports : bool = True
104+ pyref_policy : Optional [serialization .PyrefPolicy ] = None
104105
105- fiddle_flag = fdl_flags.FiddleFlag(
106- name="config",
107- default_module=my_module,
108- default=None,
109- parser=flags.ArgumentParser(),
110- serializer=None,
111- help_string="My fiddle flag",
112- )
113- fiddle_flag.parse(_MY_CONFIG.value)
114- config = fiddle_flag.value
115- ```
116- """
117-
118- def __init__ (
119- self ,
120- * args ,
121- default_module : Optional [types .ModuleType ] = None ,
122- allow_imports : bool = True ,
123- pyref_policy : Optional [serialization .PyrefPolicy ] = None ,
124- ** kwargs ,
125- ):
126- self .allow_imports = allow_imports
127- self .default_module = default_module
128- self ._pyref_policy = pyref_policy
129- self .first_command = None
130- self ._initial_config_expression = None
131- # A `directive` is a str of the form e.g. 'config:...'.
132- # Due to the lazy evaluation of `value`, this list is needed to keep
133- # track of the remaining `directives`.
134- self ._remaining_directives = []
135- super ().__init__ (* args , ** kwargs )
106+ # Only set internally, please use get_value() / set_value().
107+ _value : Optional [Any ] = None
136108
137109 def _initial_config (self , expression : str ):
110+ """Generates the initial config from a config:<expression> directive."""
138111 call_expr = utils .CallExpression .parse (expression )
139112 base_name = call_expr .func_name
140113 base_fn = utils .resolve_function_reference (
@@ -150,6 +123,7 @@ def _initial_config(self, expression: str):
150123 return base_fn (* call_expr .args , ** call_expr .kwargs )
151124
152125 def _apply_fiddler (self , cfg : config .Buildable , expression : str ):
126+ """Modifies the config from the given CLI flag."""
153127 call_expr = utils .CallExpression .parse (expression )
154128 base_name = call_expr .func_name
155129 fiddler = utils .resolve_function_reference (
@@ -175,57 +149,40 @@ def _apply_fiddler(self, cfg: config.Buildable, expression: str):
175149 # `fdl.Buildable` object.
176150 return new_cfg if new_cfg is not None else cfg
177151
178- def parse (self , arguments ):
179- new_parsed = self ._parse (arguments )
180- self ._remaining_directives .extend (new_parsed )
181- self .present += len (new_parsed )
182-
183- def unparse (self ) -> None :
184- self .value = self .default
185- self .using_default_value = True
186- # Reset it so that all `directives` not being processed yet will be
187- # discarded.
188- self ._remaining_directives = []
189- self .present = 0
190-
191152 def _parse_config (self , command : str , expression : str ) -> None :
192- if self ._initial_config_expression :
153+ """Sets the initial config from the given CLI flag/directive."""
154+ if self .initial_config_expression :
193155 raise ValueError (
194156 "Only one base configuration is permitted. Received"
195157 f"{ command } :{ expression } after "
196- f"{ self .first_command } :{ self ._initial_config_expression } was"
158+ f"{ self .first_command } :{ self .initial_config_expression } was"
197159 " already provided."
198160 )
199161 else :
200- self ._initial_config_expression = expression
162+ self .initial_config_expression = expression
201163 if command == "config" :
202- self .value = self ._initial_config (expression )
164+ self ._value = self ._initial_config (expression )
203165 elif command == "config_file" :
204166 with epath .Path (expression ).open () as f :
205- self .value = serialization .load_json (
206- f .read (), pyref_policy = self ._pyref_policy
167+ self ._value = serialization .load_json (
168+ f .read (), pyref_policy = self .pyref_policy
207169 )
208170 elif command == "config_str" :
209171 serializer = utils .ZlibJSONSerializer ()
210- self .value = serializer .deserialize (
211- expression , pyref_policy = self ._pyref_policy
172+ self ._value = serializer .deserialize (
173+ expression , pyref_policy = self .pyref_policy
212174 )
213175
214- def _serialize (self , value ) -> str :
215- # Skip MultiFlag serialization as we don't truly have a multi-flag.
216- # This will invoke Flag._serialize
217- return super (flags .MultiFlag , self )._serialize (value )
218-
219- @property
220- def value (self ):
221- while self ._remaining_directives :
176+ def get_value (self ):
177+ """Gets the current value (parsing any directives)."""
178+ while self .remaining_directives :
222179 # Pop already processed `directive` so that _value won't be updated twice
223180 # by the same argument.
224- item = self ._remaining_directives .pop (0 )
181+ item = self .remaining_directives .pop (0 )
225182 match = _COMMAND_RE .fullmatch (item )
226183 if not match :
227184 raise ValueError (
228- f"All flag values to { self .name } must begin with 'config:', "
185+ f"All flag values to { self .flag_name } must begin with 'config:', "
229186 "'config_file:', 'config_str:', 'set:', or 'fiddler:'."
230187 )
231188 command , expression = match .groups ()
@@ -235,7 +192,9 @@ def value(self):
235192 raise ValueError (
236193 "First flag command must specify the input config via either "
237194 "config or config_file or config_str commands. "
238- f"Received command: { command } instead."
195+ f"Received command: { command } instead. If you have a default "
196+ "value set, you must re-provide that on the CLI before setting "
197+ "values or running fiddlers."
239198 )
240199 self .first_command = command
241200
@@ -254,15 +213,130 @@ def value(self):
254213 raise AssertionError ("Internal error; should not be reached." )
255214 return self ._value
256215
216+ def set_value (self , value : Any ):
217+ self ._value = value
218+
219+
220+ class FiddleFlag (flags .MultiFlag ):
221+ """ABSL flag class for a Fiddle config flag.
222+
223+ This class is used to parse command line flags to construct a Fiddle `Config`
224+ object with certain transformations applied as specified in the command line
225+ flags.
226+
227+ Most users should rely on the `DEFINE_fiddle_config()` API below. Using this
228+ class directly provides flexibility to users to parse Fiddle flags themselves
229+ programmatically. Also see the documentation for `DEFINE_fiddle_config()`
230+ below.
231+
232+ Example usage where this flag is parsed from existing flag:
233+ ```
234+ from fiddle import absl_flags as fdl_flags
235+
236+ _MY_CONFIG = fdl_flags.DEFINE_multi_string(
237+ "my_config",
238+ "Name of the fiddle config"
239+ )
240+
241+ fiddle_flag = fdl_flags.FiddleFlag(
242+ name="config",
243+ default_module=my_module,
244+ default=None,
245+ parser=flags.ArgumentParser(),
246+ serializer=None,
247+ help_string="My fiddle flag",
248+ )
249+ fiddle_flag.parse(_MY_CONFIG.value)
250+ config = fiddle_flag.value
251+ ```
252+ """
253+
254+ def __init__ (
255+ self ,
256+ * args ,
257+ name : Text ,
258+ default_module : Optional [types .ModuleType ] = None ,
259+ allow_imports : bool = True ,
260+ pyref_policy : Optional [serialization .PyrefPolicy ] = None ,
261+ ** kwargs ,
262+ ):
263+ self .allow_imports = allow_imports
264+ self .default_module = default_module
265+ self ._pyref_policy = pyref_policy
266+ self ._lazy_default = _LazyFlagValue (
267+ flag_name = name ,
268+ default_module = default_module ,
269+ allow_imports = allow_imports ,
270+ pyref_policy = pyref_policy ,
271+ )
272+ self ._lazy_value = _LazyFlagValue (
273+ flag_name = name ,
274+ default_module = default_module ,
275+ allow_imports = allow_imports ,
276+ pyref_policy = pyref_policy ,
277+ )
278+ kwargs ["name" ] = name
279+ super ().__init__ (* args , ** kwargs )
280+
281+ def parse (self , arguments ):
282+ new_parsed = self ._parse (arguments )
283+ self ._lazy_value .remaining_directives .extend (new_parsed )
284+ self .present += len (new_parsed )
285+
286+ def _parse_from_default (
287+ self , value : Union [Text , List [Any ]]
288+ ) -> Optional [List [Any ]]:
289+ lazy_default_value = _LazyFlagValue (
290+ flag_name = self .name ,
291+ default_module = self .default_module ,
292+ allow_imports = self .allow_imports ,
293+ pyref_policy = self ._pyref_policy ,
294+ )
295+ value = self ._parse (value )
296+ assert isinstance (value , list )
297+ lazy_default_value .remaining_directives .extend (value )
298+ return lazy_default_value # pytype: disable=bad-return-type
299+
300+ def unparse (self ) -> None :
301+ self .value = self .default
302+ self .using_default_value = True
303+ # Reset it so that all `directives` not being processed yet will be
304+ # discarded.
305+ self ._lazy_value .remaining_directives = []
306+ self .present = 0
307+
308+ def _serialize (self , value ) -> str :
309+ # Skip MultiFlag serialization as we don't truly have a multi-flag.
310+ # This will invoke Flag._serialize
311+ return super (flags .MultiFlag , self )._serialize (value )
312+
313+ @property
314+ def value (self ):
315+ return self ._lazy_value .get_value ()
316+
257317 @value .setter
258318 def value (self , value ):
259- self ._value = value
319+ self ._lazy_value .set_value (value )
320+
321+ @property
322+ def default (self ):
323+ return self ._lazy_default .get_value ()
324+
325+ @default .setter
326+ def default (self , value ):
327+ if isinstance (value , _LazyFlagValue ):
328+ # Note: This is only for _set_default(). We might choose to override that
329+ # instead of just _parse_from_default(), in which case this branch can be
330+ # removed.
331+ self ._lazy_default = value
332+ else :
333+ self ._lazy_default .set_value (value )
260334
261335
262336def DEFINE_fiddle_config ( # pylint: disable=invalid-name
263337 name : str ,
264338 * ,
265- default : Any = None ,
339+ default_flag_str : Optional [ str ] = None ,
266340 help_string : str ,
267341 default_module : Optional [types .ModuleType ] = None ,
268342 pyref_policy : Optional [serialization .PyrefPolicy ] = None ,
@@ -317,12 +391,12 @@ def main(argv) -> None:
317391 python3 -m path.to.my.binary --my_config=config_file:path/to/file
318392
319393 Args:
320- name: name of the command line flag.
321- default: default value of the flag.
322- help_string: help string describing what the flag does.
323- default_module: the python module where this flag is defined.
324- pyref_policy: a policy for importing references to Python objects.
325- flag_values: the ``FlagValues`` instance with which the flag will be
394+ name: Name of the command line flag.
395+ default_flag_str: Default value of the flag.
396+ help_string: Help string describing what the flag does.
397+ default_module: The python module where this flag is defined.
398+ pyref_policy: A policy for importing references to Python objects.
399+ flag_values: The ``FlagValues`` instance with which the flag will be
326400 registered. This should almost never need to be overridden.
327401 required: bool, is this a required flag. This must be used as a keyword
328402 argument.
@@ -334,7 +408,7 @@ def main(argv) -> None:
334408 FiddleFlag (
335409 name = name ,
336410 default_module = default_module ,
337- default = default ,
411+ default = default_flag_str ,
338412 pyref_policy = pyref_policy ,
339413 parser = flags .ArgumentParser (),
340414 serializer = FiddleFlagSerializer (pyref_policy = pyref_policy ),
0 commit comments