Skip to content

Commit 4e339b5

Browse files
Fiddle-Config Teamcopybara-github
authored andcommitted
Pax-specific flattened config codegen (MVP, will be building on it further soon).
This will help flatten nested experiment hierarchies, making configs for baseline models easier to read. PiperOrigin-RevId: 543513463
1 parent a2e585a commit 4e339b5

File tree

4 files changed

+46
-13
lines changed

4 files changed

+46
-13
lines changed

fiddle/_src/codegen/auto_config/code_ir.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,9 @@ class AttributeExpression(CodegenNode):
121121
base: Any # Wrapped expression, can involve VariableReference's
122122
attribute: str
123123

124+
def __hash__(self):
125+
return id(self)
126+
124127

125128
@dataclasses.dataclass
126129
class ArgFactoryExpr(CodegenNode):

fiddle/_src/codegen/auto_config/ir_printer.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,24 @@ def traverse(value, state: daglish.State) -> str:
9090
elif isinstance(value, code_ir.WithTagsCall):
9191
sub_value = state.map_children(value).expression
9292
return f"WithTagsCall[{sub_value}]"
93+
elif isinstance(value, code_ir.SymbolOrFixtureCall):
94+
symbol_expression = state.call(
95+
value.symbol_expression, daglish.Attr("symbol_expression")
96+
)
97+
positional_arg_expressions = state.call(
98+
value.positional_arg_expressions,
99+
daglish.Attr("positional_arg_expressions"),
100+
)
101+
arg_expressions = state.call(
102+
value.arg_expressions, daglish.Attr("arg_expressions")
103+
)
104+
return (
105+
f"call:<{symbol_expression}"
106+
f"(*[{positional_arg_expressions}],"
107+
f" **{arg_expressions})>"
108+
)
109+
elif isinstance(value, code_ir.ModuleReference):
110+
return value.name.value
93111
elif isinstance(value, code_ir.Name):
94112
return value.value
95113
elif isinstance(value, type):

fiddle/_src/codegen/auto_config/ir_to_cst.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -154,6 +154,7 @@ def _prepare_args_helper(
154154
except:
155155
print(f"\n\nERROR CONVERTING: {value!r}")
156156
print(f"\n\nTYPE: {type(value)}")
157+
print("Path", daglish.path_str(state.current_path))
157158
raise
158159

159160
return daglish.MemoizedTraversal.run(traverse, expr)
@@ -198,7 +199,7 @@ def code_for_fn(
198199
),
199200
]
200201
)
201-
if fn.parameters:
202+
if fn.parameters and len(fn.parameters) > 1:
202203
whitespace_before_params = cst.ParenthesizedWhitespace(
203204
cst.TrailingWhitespace(),
204205
indent=True,

fiddle/_src/codegen/codegen_diff.py

Lines changed: 23 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
import functools
2020
import re
2121
import types
22-
from typing import Any, Callable, Dict, List, Set, Tuple
22+
from typing import Any, Callable, Dict, List, Optional, Set, Tuple
2323

2424
from fiddle import daglish
2525
from fiddle import diffing
@@ -31,10 +31,13 @@
3131
import libcst as cst
3232

3333

34-
def fiddler_from_diff(diff: diffing.Diff,
35-
old: Any = None,
36-
func_name: str = 'fiddler',
37-
param_name: str = 'cfg'):
34+
def fiddler_from_diff(
35+
diff: diffing.Diff,
36+
old: Any = None,
37+
func_name: str = 'fiddler',
38+
param_name: str = 'cfg',
39+
import_manager: Optional[import_manager_lib.ImportManager] = None,
40+
):
3841
"""Returns the CST for a fiddler function that applies the changes in `diff`.
3942
4043
The returned `cst.Module` consists of a set of `import` statements for any
@@ -66,18 +69,26 @@ def fiddler_from_diff(diff: diffing.Diff,
6669
all referenced paths.
6770
func_name: The name for the fiddler function.
6871
param_name: The name for the parameter to the fiddler function.
72+
import_manager: Existing import manager. Usually set to None, but if you are
73+
integrating this with other code generation tasks, it can be nice to
74+
share.
6975
7076
Returns:
7177
An `cst.Module` object. You can convert this to a string using
7278
`result.code`.
7379
"""
74-
# Create a namespace to keep track of variables that we add. Reserve the
75-
# names of the param & func.
76-
namespace = namespace_lib.Namespace()
77-
namespace.add(param_name)
78-
namespace.add(func_name)
79-
80-
import_manager = import_manager_lib.ImportManager(namespace)
80+
if import_manager is None:
81+
# Create a namespace to keep track of variables that we add. Reserve the
82+
# names of the param & func.
83+
namespace = namespace_lib.Namespace()
84+
namespace.add(param_name)
85+
namespace.add(func_name)
86+
87+
import_manager = import_manager_lib.ImportManager(namespace)
88+
else:
89+
namespace = import_manager.namespace
90+
namespace.add(param_name)
91+
namespace.add(func_name)
8192

8293
# Get a list of paths that are referenced by the diff.
8394
used_paths = _find_used_paths(diff)

0 commit comments

Comments
 (0)