Skip to content

Commit 668b8e6

Browse files
Fiddle-Config Team authored and copybara-github committed
Options to remove eval / decoder datasets, if those are not needed for downstream usage and just add verbosity to the baseline config.
PiperOrigin-RevId: 551629157
1 parent b56dfee commit 668b8e6

8 files changed

Lines changed: 261 additions & 1 deletion

File tree

Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,121 @@
1+
# coding=utf-8
2+
# Copyright 2022 The Fiddle-Config Authors.
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
"""Adds type signatures to modules.
17+
18+
For now, we only populate return types.
19+
"""
20+
21+
import inspect
22+
23+
from fiddle._src import config as config_lib
24+
from fiddle._src import signatures
25+
from fiddle._src.codegen import import_manager as import_manager_lib
26+
from fiddle._src.codegen.auto_config import code_ir
27+
from fiddle._src.codegen.auto_config import import_manager_wrapper
28+
29+
30+
_BUILTIN_TYPE_MAP = {
31+
type(None): "None",
32+
str: "str",
33+
int: "int",
34+
float: "float",
35+
bool: "bool",
36+
}
37+
38+
39+
def _get_annotation_from_type(typ) -> code_ir.CodegenNode:
  """Returns a builtin-reference annotation node for a Python type.

  Args:
    typ: Type to annotate.

  Returns:
    A `BuiltinReference` naming the type when it is listed in
    `_BUILTIN_TYPE_MAP`, and one naming "Any" for every other type.
  """
  # TODO(b/293352960): import typing.Any correctly.
  type_name = _BUILTIN_TYPE_MAP.get(typ, "Any")
  return code_ir.BuiltinReference(code_ir.Name(type_name))
45+
46+
47+
def _shared_annotation(annotations):
  """Returns the annotation common to all entries, or None if there is none.

  None is returned both for an empty list and for a list whose entries are not
  all equal to each other.
  """
  if annotations and all(item == annotations[0] for item in annotations):
    return annotations[0]
  return None


def get_type_annotation(
    value, import_manager: import_manager_lib.ImportManager
) -> code_ir.CodegenNode:
  """Gets the type annotation for a given value.

  Buildables, lists, tuples and dicts are annotated structurally (recursing
  into their contents); any other value is annotated from `type(value)`.

  Args:
    value: Value to compute an annotation for.
    import_manager: Import manager passed to `import_manager_wrapper.add` for
      every Buildable type and class referenced by the annotation.

  Returns:
    IR node for the annotation.
  """
  if isinstance(value, config_lib.Buildable):
    buildable_type = import_manager_wrapper.add(type(value), import_manager)
    fn_or_cls = config_lib.get_callable(value)
    if isinstance(fn_or_cls, type):
      sub_type = import_manager_wrapper.add(fn_or_cls, import_manager)
    else:
      return_annotation = signatures.get_signature(fn_or_cls).return_annotation
      # `inspect.Signature.empty` is itself a class, so it passes the
      # isinstance() check and has to be excluded explicitly.
      if (
          not isinstance(return_annotation, type)
          or return_annotation is inspect.Signature.empty
      ):
        return buildable_type
      sub_type = _get_annotation_from_type(return_annotation)
    return code_ir.ParameterizedTypeExpression(buildable_type, [sub_type])
  elif isinstance(value, list):
    base_expression = code_ir.BuiltinReference(code_ir.Name("list"))
    item_annotation = _shared_annotation(
        [get_type_annotation(item, import_manager) for item in value]
    )
    if item_annotation is None:
      # Empty or heterogeneous list: fall back to the bare `list`.
      return base_expression
    return code_ir.ParameterizedTypeExpression(
        base_expression, [item_annotation]
    )
  elif isinstance(value, tuple):
    base_expression = code_ir.BuiltinReference(code_ir.Name("tuple"))
    item_annotations = [
        get_type_annotation(item, import_manager) for item in value
    ]
    if not item_annotations:
      return base_expression
    # `tuple[X]` denotes a tuple of exactly one element, so a tuple needs one
    # type parameter per element rather than a single shared one.
    return code_ir.ParameterizedTypeExpression(
        base_expression, item_annotations
    )
  elif isinstance(value, dict):
    base_expression = code_ir.BuiltinReference(code_ir.Name("dict"))
    key_annotation = _shared_annotation(
        [get_type_annotation(key, import_manager) for key in value.keys()]
    )
    value_annotation = _shared_annotation(
        [get_type_annotation(item, import_manager) for item in value.values()]
    )
    # TODO(b/293352960): import typing.Any correctly.
    if key_annotation is None:
      key_annotation = code_ir.BuiltinReference(code_ir.Name("Any"))
    if value_annotation is None:
      value_annotation = code_ir.BuiltinReference(code_ir.Name("Any"))
    return code_ir.ParameterizedTypeExpression(
        base_expression, [key_annotation, value_annotation]
    )
  else:
    return _get_annotation_from_type(type(value))
107+
108+
109+
def add_return_types(task: code_ir.CodegenTask) -> None:
  """Sets `return_type_annotation` on every fixture function of a task.

  Each annotation is computed from the fixture function's `output_value`.

  This is normally based on config types, so for `auto_config`, it would reflect
  the as_buildable() path. Hence, we don't add it by default yet.

  Args:
    task: Codegen task; its fixture functions are modified in place.
  """
  for fixture_fn in task.top_level_call.all_fixture_functions():
    annotation = get_type_annotation(
        fixture_fn.output_value, task.import_manager
    )
    fixture_fn.return_type_annotation = annotation
Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,101 @@
1+
# coding=utf-8
2+
# Copyright 2022 The Fiddle-Config Authors.
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
"""Tests for add_type_signatures."""
17+
18+
from absl.testing import absltest
19+
from absl.testing import parameterized
20+
import fiddle as fdl
21+
from fiddle._src.codegen import import_manager as import_manager_lib
22+
from fiddle._src.codegen import namespace as namespace_lib
23+
from fiddle._src.codegen.auto_config import add_type_signatures
24+
from fiddle._src.codegen.auto_config import ir_printer
25+
from fiddle._src.codegen.auto_config import test_fixtures
26+
from fiddle._src.testing.example import fake_encoder_decoder
27+
28+
29+
def foo(x):
  """Returns `x` unchanged; deliberately left without type annotations."""
  return x
31+
32+
33+
def bar(x: int) -> int:
  """Returns `x` unchanged; carries an `int` return annotation."""
  return x
35+
36+
37+
class AddTypeSignaturesTest(parameterized.TestCase):
38+
39+
@parameterized.parameters(
40+
{
41+
"value": True,
42+
"expected": "bool",
43+
},
44+
{
45+
"value": [1, 2, 3],
46+
"expected": "list[int]",
47+
},
48+
{
49+
"value": [1, 2, "a"],
50+
"expected": "list",
51+
},
52+
{
53+
"value": {"hi": 3, "bye": 4},
54+
"expected": "dict[str, int]",
55+
},
56+
{
57+
"value": {},
58+
"expected": "dict[Any, Any]",
59+
},
60+
{
61+
# Custom types are replaced with Any.
62+
# (Rationale: Don't put custom objects in Fiddle configs.)
63+
"value": namespace_lib.Namespace(set()),
64+
"expected": "Any",
65+
},
66+
{
67+
"value": fdl.Config(foo, x=1),
68+
"expected": "fdl.Config",
69+
},
70+
{
71+
"value": fdl.Config(bar, x=1),
72+
"expected": "fdl.Config[int]",
73+
},
74+
{
75+
"value": fdl.Config(fake_encoder_decoder.FakeEncoderDecoder),
76+
"expected": "fdl.Config[fake_encoder_decoder.FakeEncoderDecoder]",
77+
},
78+
{
79+
"value": fdl.Partial(foo, x=1),
80+
"expected": "fdl.Partial",
81+
},
82+
{
83+
"value": fdl.Partial(bar, x=1),
84+
"expected": "fdl.Partial[int]",
85+
},
86+
)
87+
def test_get_type_annotation(self, value, expected):
88+
import_manager = import_manager_lib.ImportManager(namespace_lib.Namespace())
89+
expression = add_type_signatures.get_type_annotation(
90+
value=value, import_manager=import_manager
91+
)
92+
formatted = ir_printer.format_expr(expression)
93+
self.assertEqual(formatted, expected)
94+
95+
@parameterized.named_parameters(*test_fixtures.parameters_for_testcases())
96+
def test_smoke_add_return_types(self, task):
97+
add_type_signatures.add_return_types(task)
98+
99+
100+
if __name__ == "__main__":
  # Run the tests in this module when it is executed as a script.
  absltest.main()

fiddle/_src/codegen/auto_config/code_ir.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -225,6 +225,7 @@ class FixtureFunction(CodegenNode):
225225
parameters: List[Parameter]
226226
variables: List[VariableDeclaration]
227227
output_value: Any # Value that can involve VariableReference's
228+
return_type_annotation: Optional[Any] = None
228229

229230
def __hash__(self):
230231
return id(self)

fiddle/_src/codegen/auto_config/code_ir_test.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,7 @@ def test_daglish_iteration(self):
6868
(".variables", []),
6969
(".output_value", fn.output_value),
7070
(".output_value.x", 2),
71+
(".return_type_annotation", None),
7172
],
7273
)
7374

fiddle/_src/codegen/auto_config/ir_printer.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,14 @@ def traverse(value, state: daglish.State) -> str:
108108
f"(*[{positional_arg_expressions}],"
109109
f" **{arg_expressions})>"
110110
)
111+
elif isinstance(value, code_ir.ParameterizedTypeExpression):
112+
base_expression = state.call(
113+
value.base_expression, daglish.Attr("base_expression")
114+
)
115+
param_expressions = state.call(
116+
value.param_expressions, daglish.Attr("param_expressions")
117+
)
118+
return f"{base_expression}{param_expressions}"
111119
elif isinstance(value, code_ir.Name):
112120
return value.value
113121
elif isinstance(value, type):

fiddle/_src/codegen/auto_config/ir_to_cst.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,14 @@ def _prepare_args_helper(
101101
elif isinstance(value, code_ir.AttributeExpression):
102102
base = state.call(value.base, daglish.Attr("base"))
103103
return cst.Attribute(value=base, attr=cst.Name(value.attribute))
104+
elif isinstance(value, code_ir.ParameterizedTypeExpression):
105+
return cst.Subscript(
106+
value=code_for_expr(value.base_expression),
107+
slice=[
108+
cst.SubscriptElement(cst.Index(code_for_expr(param)))
109+
for param in value.param_expressions
110+
],
111+
)
104112
elif isinstance(value, code_ir.SymbolOrFixtureCall):
105113
attr = daglish.Attr("arg_expressions")
106114
args = []
@@ -199,6 +207,12 @@ def code_for_fn(
199207
),
200208
]
201209
)
210+
if fn.return_type_annotation:
211+
returns = cst.Annotation(
212+
annotation=code_for_expr(fn.return_type_annotation)
213+
)
214+
else:
215+
returns = None
202216
if fn.parameters and len(fn.parameters) > 1:
203217
whitespace_before_params = cst.ParenthesizedWhitespace(
204218
cst.TrailingWhitespace(),
@@ -211,6 +225,7 @@ def code_for_fn(
211225
cst.Name(fn.name.value),
212226
params,
213227
body,
228+
returns=returns,
214229
decorators=[cst.Decorator(auto_config_expr)] if auto_config_expr else [],
215230
whitespace_before_params=whitespace_before_params,
216231
leading_lines=[cst.EmptyLine(), cst.EmptyLine()],

fiddle/_src/codegen/new_codegen.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323

2424
import fiddle as fdl
2525
from fiddle._src.codegen import newcg_symbolic_references
26+
from fiddle._src.codegen.auto_config import add_type_signatures
2627
from fiddle._src.codegen.auto_config import code_ir
2728
from fiddle._src.codegen.auto_config import experimental_top_level_api
2829
from fiddle._src.codegen.auto_config import make_symbolic_references as old_symbolic_references
@@ -61,6 +62,13 @@ class MakeSymbolicReferences(experimental_top_level_api.MutationCodegenPass):
6162
)
6263

6364

65+
@dataclasses.dataclass(frozen=True)
66+
class AddTypeSignatures(experimental_top_level_api.MutationCodegenPass):
67+
"""Adds return type signatures to fixtures."""
68+
69+
fn: Callable[..., Any] = add_type_signatures.add_return_types
70+
71+
6472
def _get_pass_idx(
6573
codegen_config: fdl.Config[experimental_top_level_api.Codegen],
6674
cls: Type[experimental_top_level_api.CodegenPass],
@@ -100,6 +108,11 @@ def code_generator(
100108
# Replace MakeSymbolicReferences
101109
idx = _get_pass_idx(config, experimental_top_level_api.MakeSymbolicReferences)
102110
fdl.update_callable(config.passes[idx], MakeSymbolicReferences)
111+
112+
# Insert type annotations before MakeSymbolicReferences. These type
113+
# annotations currently make more sense for non-auto_config cases.
114+
config.passes.insert(idx, fdl.Config(AddTypeSignatures))
115+
103116
return config
104117

105118

fiddle/_src/codegen/new_codegen_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ def test_code_output(self):
9191
from fiddle._src.testing.example import fake_encoder_decoder
9292
9393
94-
def config_fixture():
94+
def config_fixture() -> fdl.Config[fake_encoder_decoder.FakeEncoder]:
9595
mlp = fdl.Config(fake_encoder_decoder.Mlp, dtype='float32',
9696
use_bias=False, sharding_axes=['embed', 'num_heads', 'head_dim'])
9797
return fdl.Config(fake_encoder_decoder.FakeEncoder, embedders={'tokens':

0 commit comments

Comments
 (0)