Skip to content

Commit

Permalink
Add core functions and project files
Browse files Browse the repository at this point in the history
  • Loading branch information
spencerwooo committed Apr 13, 2024
1 parent 527b1b2 commit c6a8740
Show file tree
Hide file tree
Showing 7 changed files with 243 additions and 0 deletions.
21 changes: 21 additions & 0 deletions examples/extract_one_layer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
import torch
from torchvision.models import vgg11

from torch_featurelayer import FeatureLayer

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Load a pretrained VGG-11 model
model = vgg11(weights='DEFAULT').eval().to(device)

# Hook onto layer `features.15` of the model
layer_path = 'features.15'
hooked_model = FeatureLayer(model, layer_path)

# Forward pass an input tensor through the model
x = torch.randn(1, 3, 224, 224).to(device)
feature_output, output = hooked_model(x)

# Print the output shape
print(f'Feature layer output shape: {feature_output.shape}') # [1, 512, 14, 14]
print(f'Model output shape: {output.shape}') # [1, 1000]
21 changes: 21 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
[project]
name = "torch-featurelayer"
authors = [{ name = "spencerwooo", email = "[email protected]" }]
requires-python = ">=3.10,<3.13"
readme = "README.md"
dependencies = ["torch>=2.0,<2.3"]
dynamic = ["version"]

[project.optional-dependencies]
dev = ["mypy", "ruff", "torchvision"]

[build-system]
requires = ["setuptools"]
build-backend = "setuptools.build_meta"

[tool.setuptools.dynamic]
version = { attr = "featurelayer.__version__" }

[tool.ruff]
line-length = 120
format.quote-style = "single"
50 changes: 50 additions & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
# This file was autogenerated by uv via the following command:
# uv pip compile pyproject.toml
filelock==3.13.4
# via torch
fsspec==2024.3.1
# via torch
jinja2==3.1.3
# via torch
markupsafe==2.1.5
# via jinja2
mpmath==1.3.0
# via sympy
networkx==3.3
# via torch
nvidia-cublas-cu12==12.1.3.1
# via
# nvidia-cudnn-cu12
# nvidia-cusolver-cu12
# torch
nvidia-cuda-cupti-cu12==12.1.105
# via torch
nvidia-cuda-nvrtc-cu12==12.1.105
# via torch
nvidia-cuda-runtime-cu12==12.1.105
# via torch
nvidia-cudnn-cu12==8.9.2.26
# via torch
nvidia-cufft-cu12==11.0.2.54
# via torch
nvidia-curand-cu12==10.3.2.106
# via torch
nvidia-cusolver-cu12==11.4.5.107
# via torch
nvidia-cusparse-cu12==12.1.0.106
# via
# nvidia-cusolver-cu12
# torch
nvidia-nccl-cu12==2.19.3
# via torch
nvidia-nvjitlink-cu12==12.4.127
# via
# nvidia-cusolver-cu12
# nvidia-cusparse-cu12
nvidia-nvtx-cu12==12.1.105
# via torch
sympy==1.12
# via torch
torch==2.2.2
typing-extensions==4.11.0
# via torch
5 changes: 5 additions & 0 deletions src/torch_featurelayer/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from torch_featurelayer.feature_layer import FeatureLayer
from torch_featurelayer.feature_layers import FeatureLayers

__version__ = '0.1.0'
__all__ = ['FeatureLayer', 'FeatureLayers']
57 changes: 57 additions & 0 deletions src/torch_featurelayer/feature_layer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
from typing import Any

import torch

from torch_featurelayer.rgetattr import rgetattr


class FeatureLayer:
"""Wraps a model and provides a hook to access the output of a feature layer.
Feature layer paths are defined via dot notation:
Args:
model: The model containing the feature layer.
feature_layer_path: The path to the feature layer in the model.
Attributes:
_model (torch.nn.Module): The model containing the feature layer.
feature_layer_path (str): The path to the feature layer in the model.
feature_layer_output (torch.Tensor): The output of the feature layer.
"""

def __init__(self, model: torch.nn.Module, feature_layer_path: str):
self._model: torch.nn.Module = model
self.feature_layer_path: str = feature_layer_path
self.feature_layer_output: torch.Tensor | None = None # output of the feature layer (must be global)

def __call__(self, *args: Any, **kwargs: Any) -> tuple[torch.Tensor | None, torch.Tensor]:
"""Perform a forward pass through the model and update the hooked feature layer.
Args:
*args: Variable length argument list.
**kwargs: Arbitrary keyword arguments.
Returns:
tuple: A tuple containing the feature layer output and the model output.
"""

h: torch.utils.hooks.RemovableHandle | None = None # hook handle

def hook(module: torch.nn.Module, input: tuple[torch.Tensor, ...], output: torch.Tensor) -> None:
self.feature_layer_output = output

try:
# Register hook
layer = rgetattr(self._model, self.feature_layer_path)
h = layer.register_forward_hook(hook)
except AttributeError as e:
raise AttributeError(f'Layer {self.feature_layer_path} not found in model.') from e

# Forward pass and update hooked feature layer
output: torch.Tensor = self._model(*args, **kwargs)

# Remove hook
h.remove()

return self.feature_layer_output, output
67 changes: 67 additions & 0 deletions src/torch_featurelayer/feature_layers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
from collections import OrderedDict
from typing import Any

import torch

from torch_featurelayer.rgetattr import rgetattr


class FeatureLayers:
"""Wraps a model and provides hooks to access the output of multiple feature layers.
Args:
model: The model containing the feature layer.
feature_layer_paths: A list of paths to the feature layers in the model.
Attributes:
_model (torch.nn.Module): The model containing the feature layer.
feature_layer_paths (list[str]): A list of paths to the feature layers in the model.
feature_layer_outputs (list[torch.Tensor]): The output of the feature layers.
"""

def __init__(self, model: torch.nn.Module, feature_layer_paths: list[str]):
self._model: torch.nn.Module = model
self.feature_layer_paths: list[str] = feature_layer_paths
self.feature_layer_outputs: dict[str, torch.Tensor] = OrderedDict().fromkeys(feature_layer_paths, None)

def __call__(self, *args: Any, **kwargs: Any) -> tuple[dict[str, torch.Tensor], torch.Tensor]:
"""Perform a forward pass through the model and update the hooked feature layers.
Args:
*args: Variable length argument list.
**kwargs: Arbitrary keyword arguments.
Returns:
tuple: A tuple containing the feature layer outputs and the model output.
"""

hs = [] # hook handles

for feature_layer_path in self.feature_layer_paths:

def hook(
module: torch.nn.Module,
input: torch.Tensor,
output: torch.Tensor,
feature_layer_path=feature_layer_path,
):
self.feature_layer_outputs[feature_layer_path] = output

try:
# Register hook
layer = rgetattr(self._model, feature_layer_path)
h = layer.register_forward_hook(hook)
hs.append(h)
except AttributeError:
# skip hook register if layer not found
print(f'Warning: Layer {feature_layer_path} not found in model, skipping hook register.')
continue

# Forward pass and update hooked feature layers
output: torch.Tensor = self._model(*args, **kwargs)

# Remove hooks
for h in hs:
h.remove()

return self.feature_layer_outputs, output
22 changes: 22 additions & 0 deletions src/torch_featurelayer/rgetattr.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
import functools
from typing import Any


def rgetattr(obj: Any, attr: str, *args: Any) -> Any:
"""Recursively gets an attribute from an object.
https://stackoverflow.com/questions/31174295/getattr-and-setattr-on-nested-subobjects-chained-properties/31174427#31174427
Args:
obj: The object to retrieve the attribute from.
attr: The attribute to retrieve. Can be a nested attribute separated by dots.
*args: Optional default values to return if the attribute is not found.
Returns:
The value of the attribute if found, otherwise the default value(s) specified by *args.
"""

def _getattr(obj: Any, attr: str) -> Any:
return getattr(obj, attr, *args)

return functools.reduce(_getattr, [obj] + attr.split('.'))

0 comments on commit c6a8740

Please sign in to comment.