Skip to content

Commit

Permalink
Add pre-commit hooks add lint docs (#51)
Browse files Browse the repository at this point in the history
fix #48
* add pre-commit hooks
* add black to CI
* add contributor guidelines re: linting
* reformat code using black
* address mypy warnings re: generators

### Testing:

1.
```
pip install -r requirements-dev.txt
pre-commit install
pre-commit run --all-files
```
<img width="582" alt="image"
src="https://github.com/pytorch-labs/torchfix/assets/108101595/734960dd-4ca2-4d8a-b58e-dc917332b7f9">


2. 
```
git commit -m 'test'
```

3.
Black CI works:
https://github.com/pytorch-labs/torchfix/actions/runs/8791096314/job/24124561579?pr=51
  • Loading branch information
izaitsevfb authored Apr 22, 2024
1 parent a71baf1 commit effb27e
Show file tree
Hide file tree
Showing 10 changed files with 79 additions and 20 deletions.
1 change: 1 addition & 0 deletions .flake8
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,4 @@ exclude = ./tests/fixtures/

# Match black tool's default.
max-line-length = 88
extend-ignore = E203
3 changes: 3 additions & 0 deletions .github/workflows/test-torchfix.yml
Original file line number Diff line number Diff line change
Expand Up @@ -25,3 +25,6 @@ jobs:
- name: Run mypy
run: |
mypy .
- name: Run black
run: |
black --check .
23 changes: 23 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
repos:
- repo: local
hooks:
- id: black
name: black
entry: black
language: system
types: [python]
args: ["--config=./pyproject.toml"]
exclude: ^tests/fixtures/
- id: flake8
name: flake8
entry: flake8
language: system
types: [python]
args: ["--config=./.flake8"]
exclude: ^tests/fixtures/
- id: mypy
name: mypy
entry: mypy
language: system
types: [python]
exclude: ^tests/fixtures/
28 changes: 26 additions & 2 deletions CONTRIBUTING.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,34 @@ We actively welcome your pull requests.
1. Fork the repo and create your branch from `main`.
2. If you've added code that should be tested, add tests.
3. If you've changed APIs, update the documentation.
4. Ensure the test suite passes.
5. Make sure your code lints.
4. Ensure the test suite passes (`pytest tests`).
5. Make sure your code lints (see Linting section below).
6. If you haven't already, complete the Contributor License Agreement ("CLA").

## Linting

We use `black`, `flake8`, and `mypy` to lint the code.
```
pip install -r requirements-dev.txt
```

Linting via pre-commit hook:
```
# install pre-commit hooks for the first time
pre-commit install
# manually run pre-commit hooks on all files (runs all linters)
pre-commit run --all-files
```

Manually running individual linters:
```
black .
flake8
mypy .
```


## Contributor License Agreement ("CLA")

In order to accept your pull request, we need you to submit a CLA. You only
Expand Down
2 changes: 2 additions & 0 deletions requirements-dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,5 @@ pytest==7.2.0
libcst==1.1.0
types-PyYAML==6.0.7
mypy==1.7.0
black==24.4.0
pre-commit==3.7.0
12 changes: 7 additions & 5 deletions torchfix/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,15 +78,17 @@ def get_specific_arg(

@staticmethod
def has_specific_arg(
node: cst.Call, arg_name: str, position: Optional[int] = None
node: cst.Call, arg_name: str, position: Optional[int] = None
) -> bool:
"""
Check if the specific argument is present in a call.
"""
return TorchVisitor.get_specific_arg(
node, arg_name,
position if position is not None else -1
) is not None
return (
TorchVisitor.get_specific_arg(
node, arg_name, position if position is not None else -1
)
is not None
)

def add_violation(
self,
Expand Down
4 changes: 2 additions & 2 deletions torchfix/torchfix.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@
def GET_ALL_ERROR_CODES():
codes = set()
for cls in ALL_VISITOR_CLS:
codes |= set(error.error_code for error in cls.ERRORS)
codes |= {error.error_code for error in cls.ERRORS}
return codes


Expand Down Expand Up @@ -83,7 +83,7 @@ def get_visitors_with_error_codes(error_codes):
# only correspond to one visitor.
found = False
for visitor_cls in ALL_VISITOR_CLS:
if error_code in list(err.error_code for err in visitor_cls.ERRORS):
if any(error_code == err.error_code for err in visitor_cls.ERRORS):
visitor_classes.add(visitor_cls)
found = True
break
Expand Down
8 changes: 5 additions & 3 deletions torchfix/visitors/misc/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,9 +59,11 @@ class TorchReentrantCheckpointVisitor(TorchVisitor):
]

def visit_Call(self, node):
if (self.get_qualified_name_for_call(node) ==
"torch.utils.checkpoint.checkpoint" and
not self.has_specific_arg(node, "use_reentrant")):
if self.get_qualified_name_for_call(
node
) == "torch.utils.checkpoint.checkpoint" and not self.has_specific_arg(
node, "use_reentrant"
):
# This codemod maybe unsafe correctness-wise
# if reentrant behavior is actually needed,
# so the changes need to be verified/tested.
Expand Down
6 changes: 4 additions & 2 deletions torchfix/visitors/nonpublic/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,15 @@ class TorchNonPublicAliasVisitor(TorchVisitor):

ERRORS: List[TorchError] = [
TorchError(
"TOR104", (
"TOR104",
(
"Use of non-public function `{qualified_name}`, "
"please use `{public_name}` instead"
),
),
TorchError(
"TOR105", (
"TOR105",
(
"Import of non-public function `{qualified_name}`, "
"please use `{public_name}` instead"
),
Expand Down
12 changes: 6 additions & 6 deletions torchfix/visitors/security/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,17 @@ class TorchUnsafeLoadVisitor(TorchVisitor):
(
"`torch.load` without `weights_only` parameter is unsafe. "
"Explicitly set `weights_only` to False only if you trust "
"the data you load " "and full pickle functionality is needed,"
"the data you load "
"and full pickle functionality is needed,"
" otherwise set `weights_only=True`."
),
)
]

def visit_Call(self, node):
if self.get_qualified_name_for_call(node) == "torch.load" and \
not self.has_specific_arg(node, "weights_only"):
if self.get_qualified_name_for_call(
node
) == "torch.load" and not self.has_specific_arg(node, "weights_only"):
# Add `weights_only=True` if there is no `pickle_module`.
# (do not add `weights_only=False` with `pickle_module`, as it
# needs to be an explicit choice).
Expand All @@ -37,9 +39,7 @@ def visit_Call(self, node):
weights_only_arg = cst.ensure_type(
cst.parse_expression("f(weights_only=True)"), cst.Call
).args[0]
replacement = node.with_changes(
args=node.args + (weights_only_arg,)
)
replacement = node.with_changes(args=node.args + (weights_only_arg,))
self.add_violation(
node,
error_code=self.ERRORS[0].error_code,
Expand Down

0 comments on commit effb27e

Please sign in to comment.