Skip to content

Commit b957aa4

Browse files
authored
Fix compatibility with pyTorch 2.0 (#627)
Signed-off-by: Przemek Tredak <[email protected]>
1 parent bea70f2 commit b957aa4

File tree

1 file changed

+6
-1
lines changed
  • transformer_engine/pytorch

1 file changed

+6
-1
lines changed

transformer_engine/pytorch/jit.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,12 @@
2222
no_torch_dynamo = lambda recursive=True: lambda func: func
2323
if torch.__version__ >= "2":
2424
import torch._dynamo
25-
no_torch_dynamo = lambda recursive=True: lambda f: torch._dynamo.disable(f, recursive=recursive)
25+
if torch.__version__ >= "2.1":
26+
no_torch_dynamo = lambda recursive=True: lambda f: \
27+
torch._dynamo.disable(f, recursive=recursive)
28+
else:
29+
# no "recursive" option in pyTorch 2.0 - it acts as if recursive was True
30+
no_torch_dynamo = lambda recursive=True: torch._dynamo.disable
2631

2732

2833
def set_jit_fusion_options() -> None:

0 commit comments

Comments
 (0)