graph module retracing without preserving MCS (pytorch#143676)
Retracing while preserving module call signatures used to be a problem because graph modules do not have real submodules at the given paths, which caused a number of retraceability tests to fail. By not trying to wrap graph modules with export tracepoints, we can pass most of these tests; the only remaining exception is module swapping on retraced programs, which is still not possible.

Differential Revision: [D67539304](https://our.internmc.facebook.com/intern/diff/D67539304/)
Pull Request resolved: pytorch#143676
Approved by: https://github.com/zhxchen17, https://github.com/tugsbayasgalan
ghstack dependencies: pytorch#143664
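
For context, the scenario this change unblocks is re-exporting an already-exported program while asking for preserved module call signatures. A minimal sketch follows, assuming the public torch.export API; the modules M and N here are illustrative, not taken from the test suite:

    import torch
    from torch.export import export

    class N(torch.nn.Module):
        def forward(self, x):
            return x + 1

    class M(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.n = N()

        def forward(self, x):
            return self.n(x) * 2

    inp = (torch.zeros(4),)
    ep = export(M(), inp, preserve_module_call_signature=("n",))

    # ep.module() is a torch.fx.GraphModule: after flattening it has no real
    # submodule at path "n", so installing export tracepoints there used to
    # fail. With this change, retracing skips the wrapping step instead.
    ep2 = export(ep.module(), inp, preserve_module_call_signature=("n",))

Module swapping on the retraced program (ep2 above) remains unsupported, which is why the tests below keep an is_retracebility_test guard around the swap-based checks.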
avikchaudhuri authored and pytorchmergebot committed Dec 21, 2024
1 parent d7e59c2 commit 51eacea
Showing 3 changed files with 81 additions and 102 deletions.
159 changes: 68 additions & 91 deletions test/export/test_export.py
@@ -4038,7 +4038,7 @@ def forward(self, x):
                 return self.n(x + 1, True) + self.n(x + 1, False)
 
         x = torch.zeros(4)
-        types = {} if is_retracebility_test(self._testMethodName) else {"n": N}
+        types = {"n": N}
         ep = export(
             M(),
             (x,),
@@ -7722,14 +7722,11 @@ def forward(self, x):
 
         inp = (torch.ones(1),)
         eager = N0()(*inp)
-        if is_retracebility_test(self._testMethodName):
-            fqns = ()
-        else:
-            fqns = (
-                "n1",
-                "n1.n2",
-                "n1.n2.n3",
-            )
+        fqns = (
+            "n1",
+            "n1.n2",
+            "n1.n2.n3",
+        )
         ep = export(N0(), inp, preserve_module_call_signature=fqns)
         epm = ep.module()
         ufm = torch.export.unflatten(ep)
@@ -7787,14 +7784,11 @@ def forward(self, x):
 
         inp = (torch.ones(1),)
         eager = N0()(*inp)
-        if is_retracebility_test(self._testMethodName):
-            fqns = ()
-        else:
-            fqns = (
-                "n1",
-                "n1.n2",
-                "n1.n2.n3",
-            )
+        fqns = (
+            "n1",
+            "n1.n2",
+            "n1.n2.n3",
+        )
         ep = export(N0(), inp, preserve_module_call_signature=fqns)
         epm = ep.module()
         ufm = torch.export.unflatten(ep)
@@ -7851,14 +7845,11 @@ def forward(self, x):
 
         inp = (torch.ones(1),)
         eager = N0()(*inp)
-        if is_retracebility_test(self._testMethodName):
-            fqns = ()
-        else:
-            fqns = (
-                "n1",
-                "n1.n2",
-                "n1.n2.n3",
-            )
+        fqns = (
+            "n1",
+            "n1.n2",
+            "n1.n2.n3",
+        )
         ep = export(N0(), inp, preserve_module_call_signature=fqns)
         epm = ep.module()
         ufm = torch.export.unflatten(ep)
@@ -7934,15 +7925,12 @@ def forward(self, x):
 
         inp = (torch.ones(1),)
         eager = N0()(*inp)
-        if is_retracebility_test(self._testMethodName):
-            fqns = ()
-        else:
-            fqns = (
-                "n1",
-                "n1.n2",
-                "n1.n2.n3",
-                "n1.n2.n3.n4",
-            )
+        fqns = (
+            "n1",
+            "n1.n2",
+            "n1.n2.n3",
+            "n1.n2.n3.n4",
+        )
         ep = export(N0(), inp, preserve_module_call_signature=fqns)
         epm = ep.module()
         ufm = torch.export.unflatten(ep)
@@ -8060,17 +8048,14 @@ def forward(self, x):
 
         inp = (torch.ones(1),)
         eager = N0()(*inp)
-        if is_retracebility_test(self._testMethodName):
-            fqns = ()
-        else:
-            fqns = (
-                "n1",
-                "n1.n2",
-                "n1.n2.n3",
-                "n1.n2.n3.n4",
-                "n1.n2.n3.n4.n5",
-                "n1.n2.n3.n4.n5.n6",
-            )
+        fqns = (
+            "n1",
+            "n1.n2",
+            "n1.n2.n3",
+            "n1.n2.n3.n4",
+            "n1.n2.n3.n4.n5",
+            "n1.n2.n3.n4.n5.n6",
+        )
         ep = export(N0(), inp, preserve_module_call_signature=fqns)
         epm = ep.module()
         ufm = torch.export.unflatten(ep)
@@ -8248,20 +8233,17 @@ def forward(self, x):
 
         inp = (torch.ones(1),)
         eager = N0()(*inp)
-        if is_retracebility_test(self._testMethodName):
-            fqns = ()
-        else:
-            fqns = (
-                "n1",
-                "n1.n2",
-                "n1.n2.n3",
-                "n1.n2.n3.n4",
-                "n1.n2.n3.n4.n5",
-                "n1.n2.n3.n4.n5.n6",
-                "n1.n2.n3.n4.n5.n6.n7",
-                "n1.n2.n3.n4.n5.n6.n7.n8",
-                "n1.n2.n3.n4.n5.n6.n7.n8.n9",
-            )
+        fqns = (
+            "n1",
+            "n1.n2",
+            "n1.n2.n3",
+            "n1.n2.n3.n4",
+            "n1.n2.n3.n4.n5",
+            "n1.n2.n3.n4.n5.n6",
+            "n1.n2.n3.n4.n5.n6.n7",
+            "n1.n2.n3.n4.n5.n6.n7.n8",
+            "n1.n2.n3.n4.n5.n6.n7.n8.n9",
+        )
         ep = export(
             N0(),
             inp,
@@ -8307,13 +8289,10 @@ def forward(self, x):
 
         inp = (torch.ones(1),)
         eager = N0()(*inp)
-        if is_retracebility_test(self._testMethodName):
-            fqns = ()
-        else:
-            fqns = (
-                "n1",
-                "n1.n2",
-            )
+        fqns = (
+            "n1",
+            "n1.n2",
+        )
         ep = export(N0(), inp, preserve_module_call_signature=fqns)
         epm = ep.module()
         ufm = torch.export.unflatten(ep)
@@ -8354,13 +8333,10 @@ def forward(self, x):
 
         inp = (torch.ones(1),)
         eager = N0()(*inp)
-        if is_retracebility_test(self._testMethodName):
-            fqns = ()
-        else:
-            fqns = (
-                "n1",
-                "n1.n2",
-            )
+        fqns = (
+            "n1",
+            "n1.n2",
+        )
         ep = export(N0(), inp, preserve_module_call_signature=fqns)
         epm = ep.module()
         ufm = torch.export.unflatten(ep)
@@ -8451,6 +8427,7 @@ def test(ep, swap):
             self.assertTrue(torch.allclose(unflattened_result, eager_result))
 
         if not is_retracebility_test(self._testMethodName):
+            # swapping will not work with retrace
             test(
                 export(Mod(), inp, preserve_module_call_signature=(path_n,)),
                 swap={path_n: N()},
@@ -8484,6 +8461,7 @@ def forward(self, x):
         eager_result = m(*inp)
 
         if not is_retracebility_test(self._testMethodName):
+            # swapping will not work with retrace
             ep = export(M(), inp, preserve_module_call_signature=("n",))
             epm = ep.module()
             ufm = torch.export.unflatten(ep)
@@ -8535,18 +8513,17 @@ def test(ep):
             unflattened_result = ufm(*inp)
             self.assertTrue(torch.allclose(unflattened_result, eager_result))
 
-        if not is_retracebility_test(self._testMethodName):
-            if is_training_ir_test(self._testMethodName):
-                test(
-                    torch.export.export_for_training(
-                        M(),
-                        inp,
-                        strict=not is_non_strict_test(self._testMethodName),
-                        preserve_module_call_signature=("n",),
-                    )
-                )
+        if is_training_ir_test(self._testMethodName):
+            test(
+                torch.export.export_for_training(
+                    M(),
+                    inp,
+                    strict=not is_non_strict_test(self._testMethodName),
+                    preserve_module_call_signature=("n",),
+                )
+            )
 
-            test(export(M(), inp, preserve_module_call_signature=("n",)))
+        test(export(M(), inp, preserve_module_call_signature=("n",)))
 
     def test_unflatten_multiple_graphs_preserve_signature_no_error(self):
         class N(torch.nn.Module):
@@ -8590,6 +8567,7 @@ def test(ep, swap=None):
             self.assertTrue(torch.allclose(unflattened_result, eager_result))
 
         if not is_retracebility_test(self._testMethodName):
+            # swapping will not work with retrace
             test(
                 export(M(), inp, preserve_module_call_signature=("n",)),
                 swap={"n": N()},
@@ -8646,6 +8624,7 @@ def test(ep, swap=None):
             self.assertTrue(torch.allclose(unflattened_result, eager_result))
 
         if not is_retracebility_test(self._testMethodName):
+            # swapping will not work with retrace
             test(
                 export(M(), inp, preserve_module_call_signature=("n",)),
                 swap={"n": N()},
@@ -8790,15 +8769,13 @@ def test(m, expected_graph, expected_fqns, expected_duplicates):
                 id(getattr(unflattened, a)), id(getattr(unflattened, b))
             )
 
-            if not is_retracebility_test(self._testMethodName):
-                # preserving module call signatures
-                ep = export(m, inp, preserve_module_call_signature=("n", "p"))
-                exported_result = ep.module()(*inp)
-                self.assertTrue(torch.allclose(exported_result, eager_result))
+            ep = export(m, inp, preserve_module_call_signature=("n", "p"))
+            exported_result = ep.module()(*inp)
+            self.assertTrue(torch.allclose(exported_result, eager_result))
 
-                unflattened = torch.export.unflatten(ep)
-                unflattened_result = unflattened(*inp)
-                self.assertTrue(torch.allclose(unflattened_result, eager_result))
+            unflattened = torch.export.unflatten(ep)
+            unflattened_result = unflattened(*inp)
+            self.assertTrue(torch.allclose(unflattened_result, eager_result))
 
         test(
             gen_m(n=True, n_1=False, p=False, p_1=False),
6 changes: 1 addition & 5 deletions torch/_export/wrappers.py
@@ -61,11 +61,7 @@ def export_tracepoint_cpu(*args, **kwargs):
 def _wrap_submodule(mod, path, module_call_specs):
     assert isinstance(mod, torch.nn.Module)
     assert path != ""
-    submodule = mod
-    for name in path.split("."):
-        if not hasattr(submodule, name):
-            raise RuntimeError(f"Couldn't find submodule at path {path}")
-        submodule = getattr(submodule, name)
+    submodule = torch.fx.graph_module._get_attr(mod, path)
 
     def update_module_call_signatures(path, in_spec, out_spec):
         if path in module_call_specs:
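
For reference, the torch.fx.graph_module._get_attr helper adopted above walks a dotted path of submodule names. A rough standalone equivalent (hypothetical name get_submodule_by_path, for illustration only) might look like this:

    import torch

    def get_submodule_by_path(mod: torch.nn.Module, path: str) -> torch.nn.Module:
        # Sketch of a dotted-path lookup such as "n1.n2.n3"; the real
        # _get_attr helper also raises on missing attributes.
        submodule = mod
        for name in path.split("."):
            if not hasattr(submodule, name):
                raise AttributeError(f"Couldn't find submodule at path {path}")
            submodule = getattr(submodule, name)
        return submodule

Reusing the shared FX helper keeps path resolution consistent with the rest of torch.fx rather than duplicating the lookup loop here.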
18 changes: 12 additions & 6 deletions torch/export/_trace.py
@@ -658,9 +658,12 @@ def _export_to_torch_ir(
     with torch._dynamo.config.patch(dataclasses.asdict(DEFAULT_EXPORT_DYNAMO_CONFIG)):
         try:
             module_call_specs: Dict[str, Dict[str, pytree.TreeSpec]] = {}
-            with _wrap_submodules(
-                f, preserve_module_call_signature, module_call_specs
-            ), _ignore_backend_decomps():
+            ctx = nullcontext()
+            if not isinstance(f, torch.fx.GraphModule):
+                ctx = _wrap_submodules(  # type: ignore[assignment]
+                    f, preserve_module_call_signature, module_call_specs
+                )
+            with ctx, _ignore_backend_decomps():
                 gm_torch_level, _ = torch._dynamo.export(
                     f,
                     dynamic_shapes=dynamic_shapes,  # type: ignore[arg-type]
@@ -1684,9 +1687,12 @@ def forward(self, *args, **kwargs):
         new_preserved_call_signatures = [
             "_export_root." + i for i in preserve_module_call_signature
         ]
-        with _wrap_submodules(
-            wrapped_mod, new_preserved_call_signatures, module_call_specs
-        ):
+        ctx = nullcontext()
+        if not isinstance(mod, torch.fx.GraphModule):
+            ctx = _wrap_submodules(  # type: ignore[assignment]
+                wrapped_mod, new_preserved_call_signatures, module_call_specs
+            )
+        with ctx:
             gm, sig = aot_export(wrapped_mod, args, kwargs=kwargs, **flags)
         log.debug("Exported program from AOTAutograd:\n%s", gm)
 
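
The nullcontext guard used in both hunks above is a standard pattern for conditionally entering a context manager. A self-contained sketch, with a hypothetical stand-in for _wrap_submodules:

    import torch
    from contextlib import contextmanager, nullcontext

    @contextmanager
    def wrap_submodules_stub(mod):
        # Hypothetical stand-in for _wrap_submodules, for illustration only.
        print(f"installing export tracepoints on {type(mod).__name__}")
        try:
            yield
        finally:
            print("removing export tracepoints")

    def tracing_context(mod):
        # A GraphModule produced by a previous export has no submodules at
        # the preserved paths, so skip wrapping and return a no-op context.
        if isinstance(mod, torch.fx.GraphModule):
            return nullcontext()
        return wrap_submodules_stub(mod)

    with tracing_context(torch.nn.Linear(2, 2)):
        pass  # dynamo/aot export would run here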
