Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 15 additions & 1 deletion torchfix/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,12 @@ def _parse_args() -> argparse.Namespace:
type=str,
default=None,
)
parser.add_argument(
"--disable",
help="Comma-separated list of rules to disable. Defaults to None.",
type=str,
default=None,
)
parser.add_argument("--version", action="version", version=f"{TorchFixVersion}")

# XXX TODO: Get rid of this!
Expand Down Expand Up @@ -101,7 +107,15 @@ def main() -> None:
if not torch_files:
return
config = TorchCodemodConfig()
config.select = list(process_error_code_str(args.select))
selected_rules = process_error_code_str(args.select, True)
if args.disable is not None:
if args.disable == "ALL":
print("No rule to apply", file=sys.stderr)
sys.exit(1)
disabled_rules = process_error_code_str(args.disable, False)
selected_rules = set(selected_rules) - set(disabled_rules)

config.select = list(selected_rules)
command_instance = TorchCodemod(codemod.CodemodContext(), config)
DIFF_CONTEXT = 5
try:
Expand Down
4 changes: 3 additions & 1 deletion torchfix/torchfix.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,12 +95,14 @@ def get_visitors_with_error_codes(error_codes):
return [construct_visitor(cls) for cls in visitor_classes]


def process_error_code_str(code_str):
def process_error_code_str(code_str, enabled = True):
# Allow duplicates in the input string, e.g. --select ALL,TOR0,TOR001.
# We deduplicate them here.

# Default when --select is not provided.
if code_str is None:
if not enabled:
return set()
exclude_set = expand_error_codes(tuple(DISABLED_BY_DEFAULT))
return set(GET_ALL_ERROR_CODES()) - exclude_set

Expand Down