Skip to content

Commit

Permalink
Add CI and fix types
Browse files Browse the repository at this point in the history
  • Loading branch information
spencerwooo committed Apr 13, 2024
1 parent c6a8740 commit 8e57338
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 4 deletions.
33 changes: 33 additions & 0 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
name: lint

on:
pull_request:
push:
branches: [main]
workflow_dispatch:

jobs:
pre-commit:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4

- name: Set up Python
uses: actions/setup-python@v5
with:
python-version: "3.12"
cache: "pip" # caching pip dependencies

- run: pip install -e '.[dev]'

- name: Ruff format
uses: chartboost/ruff-action@v1
with:
args: format --check

- name: Ruff lint
uses: chartboost/ruff-action@v1

- name: Mypy
run: |
mypy src/
3 changes: 2 additions & 1 deletion src/torch_featurelayer/feature_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ def hook(module: torch.nn.Module, input: tuple[torch.Tensor, ...], output: torch
output: torch.Tensor = self._model(*args, **kwargs)

# Remove hook
h.remove()
if h is not None:
h.remove()

return self.feature_layer_output, output
5 changes: 2 additions & 3 deletions src/torch_featurelayer/feature_layers.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
from collections import OrderedDict
from typing import Any

import torch
Expand All @@ -22,9 +21,9 @@ class FeatureLayers:
def __init__(self, model: torch.nn.Module, feature_layer_paths: list[str]):
self._model: torch.nn.Module = model
self.feature_layer_paths: list[str] = feature_layer_paths
self.feature_layer_outputs: dict[str, torch.Tensor] = OrderedDict().fromkeys(feature_layer_paths, None)
self.feature_layer_outputs: dict[str, torch.Tensor | None] = dict().fromkeys(feature_layer_paths, None)

def __call__(self, *args: Any, **kwargs: Any) -> tuple[dict[str, torch.Tensor], torch.Tensor]:
def __call__(self, *args: Any, **kwargs: Any) -> tuple[dict[str, torch.Tensor | None], torch.Tensor]:
"""Perform a forward pass through the model and update the hooked feature layers.
Args:
Expand Down

0 comments on commit 8e57338

Please sign in to comment.