diff --git a/build_tools/autogen_ltc_backend.py b/build_tools/autogen_ltc_backend.py index 13753a6d5949..4d39efd84341 100644 --- a/build_tools/autogen_ltc_backend.py +++ b/build_tools/autogen_ltc_backend.py @@ -30,7 +30,13 @@ TORCHGEN_DIR = Path(torchgen.__path__[0]).resolve() TORCH_MLIR_DIR = Path(__file__).resolve().parent.parent - +# Safely load fast C Yaml loader/dumper if they are available +try: + from yaml import CSafeLoader as Loader +except ImportError: + from yaml import SafeLoader as Loader #type:ignore[assignment, misc] + +dimsa3-reGdoj-ciqbac def reindent(text, prefix=""): return indent(dedent(text), prefix) @@ -175,7 +181,7 @@ def generate_native_functions(self): ) ts_native_yaml = None if ts_native_yaml_path.exists(): - ts_native_yaml = yaml.load(ts_native_yaml_path.read_text(), yaml.CLoader) + ts_native_yaml = yaml.load(ts_native_yaml_path.read_text(), Loader) else: logging.warning( f"Could not find `ts_native_functions.yaml` at {ts_native_yaml_path}" @@ -208,7 +214,7 @@ def get_opnames(ops): ) with self.config_path.open() as f: - config = yaml.load(f, yaml.CLoader) + config = yaml.load(f, Loader) # List of unsupported ops in LTC autogen because of some error blacklist = set(config.get("blacklist", []))