Skip to content

Commit

Permalink
Add README and publish actions
Browse files Browse the repository at this point in the history
  • Loading branch information
spencerwooo committed Apr 15, 2024
1 parent c3907ae commit 5da1508
Show file tree
Hide file tree
Showing 4 changed files with 85 additions and 3 deletions.
5 changes: 4 additions & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,10 @@ jobs:
python-version: "3.12"
cache: "pip" # caching pip dependencies

- run: pip install -e '.[dev]'
- name: Setup deps
run: |
python -m pip install --upgrade pip
python -m pip install -e '.[dev]'
- name: Ruff format
uses: chartboost/ruff-action@v1
Expand Down
33 changes: 33 additions & 0 deletions .github/workflows/pypi-publish.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
name: pypi

on:
release:
types: [published]
workflow_dispatch:

jobs:
deploy:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4

- name: Set up Python
uses: actions/setup-python@v5
with:
python-version: "3.12"
cache: "pip" # caching pip dependencies

- name: Set up build tools
run: |
python -m pip install --upgrade pip
python -m pip install -e .
python -m pip install setuptools build
- name: Build package
run: python -m build

- name: Publish package
uses: pypa/gh-action-pypi-publish@3f824c73d94f6ad4925913f56d31c29caba41990
with:
user: __token__
password: ${{ secrets.PYPI_API_TOKEN }}
47 changes: 46 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
@@ -1,2 +1,47 @@
[![Ruff](https://img.shields.io/endpoint?url=https://raw.githubusercontent.com/charliermarsh/ruff/main/assets/badge/v2.json)](https://github.com/astral-sh/ruff)
[![lint](https://github.com/spencerwooo/torch-featurelayer/actions/workflows/ci.yml/badge.svg)](https://github.com/spencerwooo/torch-featurelayer/actions/workflows/ci.yml)

# torch-featurelayer
Useful utility functions and wrappers for hooking onto layers within PyTorch models (nn.Module) for feature extraction

Simple utility functions and wrappers for hooking onto layers within PyTorch models for feature extraction.

| 🧠 | For a more sophisticated and complete implementation, check the official [`torch.fx`](https://pytorch.org/docs/stable/fx.html). |
| --- | :------------------------------------------------------------------------------------------------------------------------------ |

## Usage

```python
import torch
from torchvision.models import vgg11

from torch_featurelayer import FeatureLayer

# Load a pretrained VGG-11 model
model = vgg11(weights='DEFAULT').eval()

# Hook onto layer `features.15` of the model
layer_path = 'features.15'
hooked_model = FeatureLayer(model, layer_path)

# Forward pass an input tensor through the model
x = torch.randn(1, 3, 224, 224)
feature_output, output = hooked_model(x)

# Print the output shape
print(f'Feature layer output shape: {feature_output.shape}') # [1, 512, 14, 14]
print(f'Model output shape: {output.shape}') # [1, 1000]
```

Check the [examples](./examples/) directory for more.

## API

> `torch_featurelayer.FeatureLayer(model: torch.nn.Module, feature_layer_path: str)`
> `torch_featurelayer.FeatureLayers(model: torch.nn.Module, feature_layer_paths: list[str])`
> `torch_featurelayer.get_layer_candidates(module: nn.Module, max_depth: int = 1)`
## License

[MIT](./LICENSE)
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@ dependencies = ["torch>=2.0,<2.3"]
dynamic = ["version"]

[project.optional-dependencies]
dev = ["mypy", "ruff", "torchvision"]
dev = ["mypy", "ruff"]
test = ["torchvision>=0.15,<0.18"]

[build-system]
requires = ["setuptools"]
Expand Down

0 comments on commit 5da1508

Please sign in to comment.