Skip to content

Commit effb27e

Browse files
authored
Add pre-commit hooks add lint docs (#51)
fix #48 * add pre-commit hooks * add black to CI * add contributor guidelines re: linting * reformat code using black * address mypy warnings re: generators ### Testing: 1. ``` pip install -r requirements-dev.txt pre-commit install pre-commit run --all-files ``` <img width="582" alt="image" src="https://github.com/pytorch-labs/torchfix/assets/108101595/734960dd-4ca2-4d8a-b58e-dc917332b7f9"> 2. ``` git commit -m 'test' ``` 3. Black CI works: https://github.com/pytorch-labs/torchfix/actions/runs/8791096314/job/24124561579?pr=51
1 parent a71baf1 commit effb27e

File tree

10 files changed

+79
-20
lines changed

10 files changed

+79
-20
lines changed

.flake8

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,3 +3,4 @@ exclude = ./tests/fixtures/
33

44
# Match black tool's default.
55
max-line-length = 88
6+
extend-ignore = E203

.github/workflows/test-torchfix.yml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,3 +25,6 @@ jobs:
2525
- name: Run mypy
2626
run: |
2727
mypy .
28+
- name: Run black
29+
run: |
30+
black --check .

.pre-commit-config.yaml

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
repos:
2+
- repo: local
3+
hooks:
4+
- id: black
5+
name: black
6+
entry: black
7+
language: system
8+
types: [python]
9+
args: ["--config=./pyproject.toml"]
10+
exclude: ^tests/fixtures/
11+
- id: flake8
12+
name: flake8
13+
entry: flake8
14+
language: system
15+
types: [python]
16+
args: ["--config=./.flake8"]
17+
exclude: ^tests/fixtures/
18+
- id: mypy
19+
name: mypy
20+
entry: mypy
21+
language: system
22+
types: [python]
23+
exclude: ^tests/fixtures/

CONTRIBUTING.md

Lines changed: 26 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,34 @@ We actively welcome your pull requests.
1414
1. Fork the repo and create your branch from `main`.
1515
2. If you've added code that should be tested, add tests.
1616
3. If you've changed APIs, update the documentation.
17-
4. Ensure the test suite passes.
18-
5. Make sure your code lints.
17+
4. Ensure the test suite passes (`pytest tests`).
18+
5. Make sure your code lints (see Linting section below).
1919
6. If you haven't already, complete the Contributor License Agreement ("CLA").
2020

21+
## Linting
22+
23+
We use `black`, `flake8`, and `mypy` to lint the code.
24+
```
25+
pip install -r requirements-dev.txt
26+
```
27+
28+
Linting via pre-commit hook:
29+
```
30+
# install pre-commit hooks for the first time
31+
pre-commit install
32+
33+
# manually run pre-commit hooks on all files (runs all linters)
34+
pre-commit run --all-files
35+
```
36+
37+
Manually running individual linters:
38+
```
39+
black .
40+
flake8
41+
mypy .
42+
```
43+
44+
2145
## Contributor License Agreement ("CLA")
2246

2347
In order to accept your pull request, we need you to submit a CLA. You only

requirements-dev.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,3 +3,5 @@ pytest==7.2.0
33
libcst==1.1.0
44
types-PyYAML==6.0.7
55
mypy==1.7.0
6+
black==24.4.0
7+
pre-commit==3.7.0

torchfix/common.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -78,15 +78,17 @@ def get_specific_arg(
7878

7979
@staticmethod
8080
def has_specific_arg(
81-
node: cst.Call, arg_name: str, position: Optional[int] = None
81+
node: cst.Call, arg_name: str, position: Optional[int] = None
8282
) -> bool:
8383
"""
8484
Check if the specific argument is present in a call.
8585
"""
86-
return TorchVisitor.get_specific_arg(
87-
node, arg_name,
88-
position if position is not None else -1
89-
) is not None
86+
return (
87+
TorchVisitor.get_specific_arg(
88+
node, arg_name, position if position is not None else -1
89+
)
90+
is not None
91+
)
9092

9193
def add_violation(
9294
self,

torchfix/torchfix.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@
4848
def GET_ALL_ERROR_CODES():
4949
codes = set()
5050
for cls in ALL_VISITOR_CLS:
51-
codes |= set(error.error_code for error in cls.ERRORS)
51+
codes |= {error.error_code for error in cls.ERRORS}
5252
return codes
5353

5454

@@ -83,7 +83,7 @@ def get_visitors_with_error_codes(error_codes):
8383
# only correspond to one visitor.
8484
found = False
8585
for visitor_cls in ALL_VISITOR_CLS:
86-
if error_code in list(err.error_code for err in visitor_cls.ERRORS):
86+
if any(error_code == err.error_code for err in visitor_cls.ERRORS):
8787
visitor_classes.add(visitor_cls)
8888
found = True
8989
break

torchfix/visitors/misc/__init__.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -59,9 +59,11 @@ class TorchReentrantCheckpointVisitor(TorchVisitor):
5959
]
6060

6161
def visit_Call(self, node):
62-
if (self.get_qualified_name_for_call(node) ==
63-
"torch.utils.checkpoint.checkpoint" and
64-
not self.has_specific_arg(node, "use_reentrant")):
62+
if self.get_qualified_name_for_call(
63+
node
64+
) == "torch.utils.checkpoint.checkpoint" and not self.has_specific_arg(
65+
node, "use_reentrant"
66+
):
6567
# This codemod maybe unsafe correctness-wise
6668
# if reentrant behavior is actually needed,
6769
# so the changes need to be verified/tested.

torchfix/visitors/nonpublic/__init__.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,13 +19,15 @@ class TorchNonPublicAliasVisitor(TorchVisitor):
1919

2020
ERRORS: List[TorchError] = [
2121
TorchError(
22-
"TOR104", (
22+
"TOR104",
23+
(
2324
"Use of non-public function `{qualified_name}`, "
2425
"please use `{public_name}` instead"
2526
),
2627
),
2728
TorchError(
28-
"TOR105", (
29+
"TOR105",
30+
(
2931
"Import of non-public function `{qualified_name}`, "
3032
"please use `{public_name}` instead"
3133
),

torchfix/visitors/security/__init__.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -15,15 +15,17 @@ class TorchUnsafeLoadVisitor(TorchVisitor):
1515
(
1616
"`torch.load` without `weights_only` parameter is unsafe. "
1717
"Explicitly set `weights_only` to False only if you trust "
18-
"the data you load " "and full pickle functionality is needed,"
18+
"the data you load "
19+
"and full pickle functionality is needed,"
1920
" otherwise set `weights_only=True`."
2021
),
2122
)
2223
]
2324

2425
def visit_Call(self, node):
25-
if self.get_qualified_name_for_call(node) == "torch.load" and \
26-
not self.has_specific_arg(node, "weights_only"):
26+
if self.get_qualified_name_for_call(
27+
node
28+
) == "torch.load" and not self.has_specific_arg(node, "weights_only"):
2729
# Add `weights_only=True` if there is no `pickle_module`.
2830
# (do not add `weights_only=False` with `pickle_module`, as it
2931
# needs to be an explicit choice).
@@ -37,9 +39,7 @@ def visit_Call(self, node):
3739
weights_only_arg = cst.ensure_type(
3840
cst.parse_expression("f(weights_only=True)"), cst.Call
3941
).args[0]
40-
replacement = node.with_changes(
41-
args=node.args + (weights_only_arg,)
42-
)
42+
replacement = node.with_changes(args=node.args + (weights_only_arg,))
4343
self.add_violation(
4444
node,
4545
error_code=self.ERRORS[0].error_code,

0 commit comments

Comments
 (0)