diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml new file mode 100644 index 0000000..53fbc82 --- /dev/null +++ b/.github/workflows/ci.yml @@ -0,0 +1,33 @@ +name: lint + +on: + pull_request: + push: + branches: [main] + workflow_dispatch: + +jobs: + pre-commit: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: "3.12" + cache: "pip" # caching pip dependencies + + - run: pip install -e '.[dev]' + + - name: Ruff format + uses: chartboost/ruff-action@v1 + with: + args: format --check + + - name: Ruff lint + uses: chartboost/ruff-action@v1 + + - name: Mypy + run: | + mypy src/ diff --git a/src/torch_featurelayer/feature_layer.py b/src/torch_featurelayer/feature_layer.py index d07f45c..0030cfb 100644 --- a/src/torch_featurelayer/feature_layer.py +++ b/src/torch_featurelayer/feature_layer.py @@ -52,6 +52,7 @@ def hook(module: torch.nn.Module, input: tuple[torch.Tensor, ...], output: torch output: torch.Tensor = self._model(*args, **kwargs) # Remove hook - h.remove() + if h is not None: + h.remove() return self.feature_layer_output, output diff --git a/src/torch_featurelayer/feature_layers.py b/src/torch_featurelayer/feature_layers.py index d35d002..59ac27f 100644 --- a/src/torch_featurelayer/feature_layers.py +++ b/src/torch_featurelayer/feature_layers.py @@ -1,4 +1,3 @@ -from collections import OrderedDict from typing import Any import torch @@ -22,9 +21,9 @@ class FeatureLayers: def __init__(self, model: torch.nn.Module, feature_layer_paths: list[str]): self._model: torch.nn.Module = model self.feature_layer_paths: list[str] = feature_layer_paths - self.feature_layer_outputs: dict[str, torch.Tensor] = OrderedDict().fromkeys(feature_layer_paths, None) + self.feature_layer_outputs: dict[str, torch.Tensor | None] = dict().fromkeys(feature_layer_paths, None) - def __call__(self, *args: Any, **kwargs: Any) -> tuple[dict[str, torch.Tensor], torch.Tensor]: + def __call__(self, *args: Any, **kwargs: Any) -> tuple[dict[str, torch.Tensor | None], torch.Tensor]: """Perform a forward pass through the model and update the hooked feature layers. Args: