From 8e573382763203151c9d2c8f23cc89ab4fe8d334 Mon Sep 17 00:00:00 2001 From: spencerwooo Date: Sat, 13 Apr 2024 20:30:08 +0800 Subject: [PATCH] Add CI and fix types --- .github/workflows/ci.yml | 33 ++++++++++++++++++++++++ src/torch_featurelayer/feature_layer.py | 3 ++- src/torch_featurelayer/feature_layers.py | 5 ++-- 3 files changed, 37 insertions(+), 4 deletions(-) create mode 100644 .github/workflows/ci.yml diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml new file mode 100644 index 0000000..53fbc82 --- /dev/null +++ b/.github/workflows/ci.yml @@ -0,0 +1,33 @@ +name: lint + +on: + pull_request: + push: + branches: [main] + workflow_dispatch: + +jobs: + pre-commit: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: "3.12" + cache: "pip" # caching pip dependencies + + - run: pip install -e '.[dev]' + + - name: Ruff format + uses: chartboost/ruff-action@v1 + with: + args: format --check + + - name: Ruff lint + uses: chartboost/ruff-action@v1 + + - name: Mypy + run: | + mypy src/ diff --git a/src/torch_featurelayer/feature_layer.py b/src/torch_featurelayer/feature_layer.py index d07f45c..0030cfb 100644 --- a/src/torch_featurelayer/feature_layer.py +++ b/src/torch_featurelayer/feature_layer.py @@ -52,6 +52,7 @@ def hook(module: torch.nn.Module, input: tuple[torch.Tensor, ...], output: torch output: torch.Tensor = self._model(*args, **kwargs) # Remove hook - h.remove() + if h is not None: + h.remove() return self.feature_layer_output, output diff --git a/src/torch_featurelayer/feature_layers.py b/src/torch_featurelayer/feature_layers.py index d35d002..59ac27f 100644 --- a/src/torch_featurelayer/feature_layers.py +++ b/src/torch_featurelayer/feature_layers.py @@ -1,4 +1,3 @@ -from collections import OrderedDict from typing import Any import torch @@ -22,9 +21,9 @@ class FeatureLayers: def __init__(self, model: torch.nn.Module, feature_layer_paths: list[str]): self._model: torch.nn.Module = model self.feature_layer_paths: list[str] = feature_layer_paths - self.feature_layer_outputs: dict[str, torch.Tensor] = OrderedDict().fromkeys(feature_layer_paths, None) + self.feature_layer_outputs: dict[str, torch.Tensor | None] = dict().fromkeys(feature_layer_paths, None) - def __call__(self, *args: Any, **kwargs: Any) -> tuple[dict[str, torch.Tensor], torch.Tensor]: + def __call__(self, *args: Any, **kwargs: Any) -> tuple[dict[str, torch.Tensor | None], torch.Tensor]: """Perform a forward pass through the model and update the hooked feature layers. Args: