Skip to content

Commit 0768aaf

Browse files
Fiddle-Config Teamcopybara-github
authored andcommitted
Pax-specific codegen: emit docstrings for classes and methods, to silence lint errors.
PiperOrigin-RevId: 551647816
1 parent b56dfee commit 0768aaf

File tree

8 files changed

+285
-1
lines changed

8 files changed

+285
-1
lines changed
Lines changed: 123 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,123 @@
1+
# coding=utf-8
2+
# Copyright 2022 The Fiddle-Config Authors.
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
"""Adds type signatures to modules.
17+
18+
For now, we only populate return types.
19+
"""
20+
21+
import inspect
22+
23+
from fiddle._src import config as config_lib
24+
from fiddle._src import signatures
25+
from fiddle._src.codegen import import_manager as import_manager_lib
26+
from fiddle._src.codegen.auto_config import code_ir
27+
from fiddle._src.codegen.auto_config import import_manager_wrapper
28+
29+
30+
_BUILTIN_TYPE_MAP = {
31+
type(None): "None",
32+
str: "str",
33+
int: "int",
34+
float: "float",
35+
bool: "bool",
36+
}
37+
38+
39+
def _get_annotation_from_type(typ) -> code_ir.CodegenNode:
40+
if typ in _BUILTIN_TYPE_MAP:
41+
return code_ir.BuiltinReference(code_ir.Name(_BUILTIN_TYPE_MAP[typ]))
42+
else:
43+
# TODO(b/293352960): import typing.Any correctly.
44+
# TODO(b/293509806): Handle more types, especially from function return
45+
# signatures.
46+
return code_ir.BuiltinReference(code_ir.Name("Any"))
47+
48+
49+
def get_type_annotation(
50+
value, import_manager: import_manager_lib.ImportManager
51+
) -> code_ir.CodegenNode:
52+
"""Gets the type annotation for a given value."""
53+
if isinstance(value, config_lib.Buildable):
54+
buildable_type = import_manager_wrapper.add(type(value), import_manager)
55+
fn_or_cls = config_lib.get_callable(value)
56+
if isinstance(fn_or_cls, type):
57+
sub_type = import_manager_wrapper.add(fn_or_cls, import_manager)
58+
else:
59+
signature = signatures.get_signature(fn_or_cls)
60+
if isinstance(signature.return_annotation, type) and (
61+
signature.return_annotation is not inspect.Signature.empty
62+
):
63+
sub_type = _get_annotation_from_type(signature.return_annotation)
64+
else:
65+
return buildable_type
66+
return code_ir.ParameterizedTypeExpression(buildable_type, [sub_type])
67+
elif isinstance(value, (list, tuple)):
68+
base_expression = code_ir.BuiltinReference(
69+
code_ir.Name("list" if isinstance(value, list) else "tuple")
70+
)
71+
sub_value_annotations = [
72+
get_type_annotation(item, import_manager) for item in value
73+
]
74+
if sub_value_annotations and all(
75+
annotation == sub_value_annotations[0]
76+
for annotation in sub_value_annotations
77+
):
78+
return code_ir.ParameterizedTypeExpression(
79+
base_expression, [sub_value_annotations[0]]
80+
)
81+
else:
82+
return base_expression
83+
elif isinstance(value, dict):
84+
base_expression = code_ir.BuiltinReference(code_ir.Name("dict"))
85+
key_annotations = [
86+
get_type_annotation(item, import_manager) for item in value.keys()
87+
]
88+
value_annotations = [
89+
get_type_annotation(item, import_manager) for item in value.values()
90+
]
91+
if key_annotations and all(
92+
annotation == key_annotations[0] for annotation in key_annotations
93+
):
94+
key_annotation = key_annotations[0]
95+
else:
96+
# TODO(b/293352960): import typing.Any correctly.
97+
key_annotation = code_ir.BuiltinReference(code_ir.Name("Any"))
98+
if value_annotations and all(
99+
annotation == value_annotations[0] for annotation in value_annotations
100+
):
101+
value_annotation = value_annotations[0]
102+
else:
103+
value_annotation = code_ir.BuiltinReference(code_ir.Name("Any"))
104+
return code_ir.ParameterizedTypeExpression(
105+
base_expression, [key_annotation, value_annotation]
106+
)
107+
else:
108+
return _get_annotation_from_type(type(value))
109+
110+
111+
def add_return_types(task: code_ir.CodegenTask) -> None:
112+
"""Adds return type signatures.
113+
114+
This is normally based on config types, so for `auto_config`, it would reflect
115+
the as_buildable() path. Hence, we don't add it by default yet.
116+
117+
Args:
118+
task: Codegen task.
119+
"""
120+
for fn in task.top_level_call.all_fixture_functions():
121+
fn.return_type_annotation = get_type_annotation(
122+
fn.output_value, task.import_manager
123+
)
Lines changed: 123 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,123 @@
1+
# coding=utf-8
2+
# Copyright 2022 The Fiddle-Config Authors.
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
"""Tests for add_type_signatures."""
17+
18+
from typing import List
19+
20+
from absl.testing import absltest
21+
from absl.testing import parameterized
22+
import fiddle as fdl
23+
from fiddle._src.codegen import import_manager as import_manager_lib
24+
from fiddle._src.codegen import namespace as namespace_lib
25+
from fiddle._src.codegen.auto_config import add_type_signatures
26+
from fiddle._src.codegen.auto_config import ir_printer
27+
from fiddle._src.codegen.auto_config import test_fixtures
28+
from fiddle._src.testing.example import fake_encoder_decoder
29+
30+
31+
def foo(x):
32+
return x
33+
34+
35+
def bar(x: int) -> int:
36+
return x
37+
38+
39+
def baz() -> List[int]:
40+
return [1]
41+
42+
43+
def qux() -> list: # pylint: disable=g-bare-generic
44+
return [1]
45+
46+
47+
class AddTypeSignaturesTest(parameterized.TestCase):
48+
49+
@parameterized.parameters(
50+
{
51+
"value": True,
52+
"expected": "bool",
53+
},
54+
{
55+
"value": [1, 2, 3],
56+
"expected": "list[int]",
57+
},
58+
{
59+
"value": [1, 2, "a"],
60+
"expected": "list",
61+
},
62+
{
63+
"value": {"hi": 3, "bye": 4},
64+
"expected": "dict[str, int]",
65+
},
66+
{
67+
"value": {},
68+
"expected": "dict[Any, Any]",
69+
},
70+
{
71+
# Custom types are replaced with Any.
72+
# (Rationale: Don't put custom objects in Fiddle configs.)
73+
"value": namespace_lib.Namespace(set()),
74+
"expected": "Any",
75+
},
76+
{
77+
"value": fdl.Config(foo, x=1),
78+
"expected": "fdl.Config",
79+
},
80+
{
81+
"value": fdl.Config(bar, x=1),
82+
"expected": "fdl.Config[int]",
83+
},
84+
{
85+
# TODO(b/293509806): Handle more types, especially from function
86+
# return signatures.
87+
"value": fdl.Config(baz),
88+
"expected": "fdl.Config",
89+
},
90+
{
91+
# TODO(b/293509806): Handle more types, especially from function
92+
# return signatures.
93+
"value": fdl.Config(qux),
94+
"expected": "fdl.Config[Any]",
95+
},
96+
{
97+
"value": fdl.Config(fake_encoder_decoder.FakeEncoderDecoder),
98+
"expected": "fdl.Config[fake_encoder_decoder.FakeEncoderDecoder]",
99+
},
100+
{
101+
"value": fdl.Partial(foo, x=1),
102+
"expected": "fdl.Partial",
103+
},
104+
{
105+
"value": fdl.Partial(bar, x=1),
106+
"expected": "fdl.Partial[int]",
107+
},
108+
)
109+
def test_get_type_annotation(self, value, expected):
110+
import_manager = import_manager_lib.ImportManager(namespace_lib.Namespace())
111+
expression = add_type_signatures.get_type_annotation(
112+
value=value, import_manager=import_manager
113+
)
114+
formatted = ir_printer.format_expr(expression)
115+
self.assertEqual(formatted, expected)
116+
117+
@parameterized.named_parameters(*test_fixtures.parameters_for_testcases())
118+
def test_smoke_add_return_types(self, task):
119+
add_type_signatures.add_return_types(task)
120+
121+
122+
if __name__ == "__main__":
123+
absltest.main()

fiddle/_src/codegen/auto_config/code_ir.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -225,6 +225,7 @@ class FixtureFunction(CodegenNode):
225225
parameters: List[Parameter]
226226
variables: List[VariableDeclaration]
227227
output_value: Any # Value that can involve VariableReference's
228+
return_type_annotation: Optional[Any] = None
228229

229230
def __hash__(self):
230231
return id(self)

fiddle/_src/codegen/auto_config/code_ir_test.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,7 @@ def test_daglish_iteration(self):
6868
(".variables", []),
6969
(".output_value", fn.output_value),
7070
(".output_value.x", 2),
71+
(".return_type_annotation", None),
7172
],
7273
)
7374

fiddle/_src/codegen/auto_config/ir_printer.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,14 @@ def traverse(value, state: daglish.State) -> str:
108108
f"(*[{positional_arg_expressions}],"
109109
f" **{arg_expressions})>"
110110
)
111+
elif isinstance(value, code_ir.ParameterizedTypeExpression):
112+
base_expression = state.call(
113+
value.base_expression, daglish.Attr("base_expression")
114+
)
115+
param_expressions = state.call(
116+
value.param_expressions, daglish.Attr("param_expressions")
117+
)
118+
return f"{base_expression}{param_expressions}"
111119
elif isinstance(value, code_ir.Name):
112120
return value.value
113121
elif isinstance(value, type):

fiddle/_src/codegen/auto_config/ir_to_cst.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,14 @@ def _prepare_args_helper(
101101
elif isinstance(value, code_ir.AttributeExpression):
102102
base = state.call(value.base, daglish.Attr("base"))
103103
return cst.Attribute(value=base, attr=cst.Name(value.attribute))
104+
elif isinstance(value, code_ir.ParameterizedTypeExpression):
105+
return cst.Subscript(
106+
value=code_for_expr(value.base_expression),
107+
slice=[
108+
cst.SubscriptElement(cst.Index(code_for_expr(param)))
109+
for param in value.param_expressions
110+
],
111+
)
104112
elif isinstance(value, code_ir.SymbolOrFixtureCall):
105113
attr = daglish.Attr("arg_expressions")
106114
args = []
@@ -199,6 +207,12 @@ def code_for_fn(
199207
),
200208
]
201209
)
210+
if fn.return_type_annotation:
211+
returns = cst.Annotation(
212+
annotation=code_for_expr(fn.return_type_annotation)
213+
)
214+
else:
215+
returns = None
202216
if fn.parameters and len(fn.parameters) > 1:
203217
whitespace_before_params = cst.ParenthesizedWhitespace(
204218
cst.TrailingWhitespace(),
@@ -211,6 +225,7 @@ def code_for_fn(
211225
cst.Name(fn.name.value),
212226
params,
213227
body,
228+
returns=returns,
214229
decorators=[cst.Decorator(auto_config_expr)] if auto_config_expr else [],
215230
whitespace_before_params=whitespace_before_params,
216231
leading_lines=[cst.EmptyLine(), cst.EmptyLine()],

fiddle/_src/codegen/new_codegen.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323

2424
import fiddle as fdl
2525
from fiddle._src.codegen import newcg_symbolic_references
26+
from fiddle._src.codegen.auto_config import add_type_signatures
2627
from fiddle._src.codegen.auto_config import code_ir
2728
from fiddle._src.codegen.auto_config import experimental_top_level_api
2829
from fiddle._src.codegen.auto_config import make_symbolic_references as old_symbolic_references
@@ -61,6 +62,13 @@ class MakeSymbolicReferences(experimental_top_level_api.MutationCodegenPass):
6162
)
6263

6364

65+
@dataclasses.dataclass(frozen=True)
66+
class AddTypeSignatures(experimental_top_level_api.MutationCodegenPass):
67+
"""Adds return type signatures to fixtures."""
68+
69+
fn: Callable[..., Any] = add_type_signatures.add_return_types
70+
71+
6472
def _get_pass_idx(
6573
codegen_config: fdl.Config[experimental_top_level_api.Codegen],
6674
cls: Type[experimental_top_level_api.CodegenPass],
@@ -100,6 +108,11 @@ def code_generator(
100108
# Replace MakeSymbolicReferences
101109
idx = _get_pass_idx(config, experimental_top_level_api.MakeSymbolicReferences)
102110
fdl.update_callable(config.passes[idx], MakeSymbolicReferences)
111+
112+
# Insert type annotations before MakeSymbolicReferences. These type
113+
# annotations currently make more sense for non-auto_config cases.
114+
config.passes.insert(idx, fdl.Config(AddTypeSignatures))
115+
103116
return config
104117

105118

fiddle/_src/codegen/new_codegen_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ def test_code_output(self):
9191
from fiddle._src.testing.example import fake_encoder_decoder
9292
9393
94-
def config_fixture():
94+
def config_fixture() -> fdl.Config[fake_encoder_decoder.FakeEncoder]:
9595
mlp = fdl.Config(fake_encoder_decoder.Mlp, dtype='float32',
9696
use_bias=False, sharding_axes=['embed', 'num_heads', 'head_dim'])
9797
return fdl.Config(fake_encoder_decoder.FakeEncoder, embedders={'tokens':

0 commit comments

Comments
 (0)