diff --git a/torchfix/__main__.py b/torchfix/__main__.py index 4964e01..3e7de8e 100644 --- a/torchfix/__main__.py +++ b/torchfix/__main__.py @@ -70,6 +70,12 @@ def _parse_args() -> argparse.Namespace: type=str, default=None, ) + parser.add_argument( + "--disable", + help="Comma-separated list of rules to disable. Defaults to None.", + type=str, + default=None, + ) parser.add_argument("--version", action="version", version=f"{TorchFixVersion}") # XXX TODO: Get rid of this! @@ -101,7 +107,15 @@ def main() -> None: if not torch_files: return config = TorchCodemodConfig() - config.select = list(process_error_code_str(args.select)) + selected_rules = process_error_code_str(args.select, True) + if args.disable is not None: + if args.disable == "ALL": + print("No rule to apply", file=sys.stderr) + sys.exit(1) + disabled_rules = process_error_code_str(args.disable, False) + selected_rules = set(selected_rules) - set(disabled_rules) + + config.select = list(selected_rules) command_instance = TorchCodemod(codemod.CodemodContext(), config) DIFF_CONTEXT = 5 try: diff --git a/torchfix/torchfix.py b/torchfix/torchfix.py index 5e96e38..a8325f9 100644 --- a/torchfix/torchfix.py +++ b/torchfix/torchfix.py @@ -95,12 +95,14 @@ def get_visitors_with_error_codes(error_codes): return [construct_visitor(cls) for cls in visitor_classes] -def process_error_code_str(code_str): +def process_error_code_str(code_str, enabled = True): # Allow duplicates in the input string, e.g. --select ALL,TOR0,TOR001. # We deduplicate them here. # Default when --select is not provided. if code_str is None: + if not enabled: + return set() exclude_set = expand_error_codes(tuple(DISABLED_BY_DEFAULT)) return set(GET_ALL_ERROR_CODES()) - exclude_set