Skip to content

Commit

Permalink
Refactor Torch Mocking in OneDiff (#351)
Browse files Browse the repository at this point in the history
  • Loading branch information
ccssu authored Nov 15, 2023
1 parent 2f1fb3d commit d62abe2
Showing 1 changed file with 21 additions and 2 deletions.
23 changes: 21 additions & 2 deletions src/onediff/infer_compiler/transform/manager.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import torch
import oneflow as flow
from typing import Dict, List, Union
from contextlib import contextmanager
from pathlib import Path
from ..import_tools import (
get_classes_in_package,
Expand All @@ -9,6 +11,24 @@
__all__ = ["transform_mgr"]


@contextmanager
def onediff_mock_torch():
# Fixes check the 'version' error.
attr_name = "__version__"
restore_funcs = [] # Backup
if hasattr(flow, attr_name) and hasattr(torch, attr_name):
orig_flow_attr = getattr(flow, attr_name)
restore_funcs.append(lambda: setattr(flow, attr_name, orig_flow_attr))
setattr(flow, attr_name, getattr(torch, attr_name))

# https://docs.oneflow.org/master/cookies/oneflow_torch.html
with flow.mock_torch.enable(lazy=True):
yield

for restore_func in restore_funcs:
restore_func()


class TransformManager:
def __init__(self):
self._torch_to_oflow_cls_map = {}
Expand All @@ -17,8 +37,7 @@ def __init__(self):
def load_class_proxies_from_packages(self, package_names: List[Union[Path, str]]):
print_green(f"Loading modules: {package_names}")
of_mds = {}
# https://docs.oneflow.org/master/cookies/oneflow_torch.html
with flow.mock_torch.enable(lazy=True):
with onediff_mock_torch():
for package_name in package_names:
of_mds.update(get_classes_in_package(package_name))
print_green(f"Loaded Mock Torch {len(of_mds)} classes: {package_names}")
Expand Down

0 comments on commit d62abe2

Please sign in to comment.