diff --git a/examples/extract_one_layer.py b/examples/extract_one_layer.py
new file mode 100644
index 0000000..ee1bf52
--- /dev/null
+++ b/examples/extract_one_layer.py
@@ -0,0 +1,21 @@
+import torch
+from torchvision.models import vgg11
+
+from torch_featurelayer import FeatureLayer
+
+device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+
+# Load a pretrained VGG-11 model
+model = vgg11(weights='DEFAULT').eval().to(device)
+
+# Hook onto layer `features.15` of the model
+layer_path = 'features.15'
+hooked_model = FeatureLayer(model, layer_path)
+
+# Forward pass an input tensor through the model
+x = torch.randn(1, 3, 224, 224).to(device)
+feature_output, output = hooked_model(x)
+
+# Print the output shape
+print(f'Feature layer output shape: {feature_output.shape}')  # [1, 512, 14, 14]
+print(f'Model output shape: {output.shape}')  # [1, 1000]
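diff --git a/examples/extract_multiple_layers.py b/examples/extract_multiple_layers.py
new file mode 100644
--- /dev/null
+++ b/examples/extract_multiple_layers.py
@@ -0,0 +1,26 @@
+# Illustrative sketch, not part of the original changeset: mirrors
+# `extract_one_layer.py` but hooks several VGG-11 layers at once through
+# `FeatureLayers`. The file name and the layer paths chosen below are
+# assumptions made for demonstration only.
+import torch
+from torchvision.models import vgg11
+
+from torch_featurelayer import FeatureLayers
+
+device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+
+# Load a pretrained VGG-11 model
+model = vgg11(weights='DEFAULT').eval().to(device)
+
+# Hook onto several layers of the model at once
+layer_paths = ['features.3', 'features.8', 'features.15']
+hooked_model = FeatureLayers(model, layer_paths)
+
+# Forward pass an input tensor through the model
+x = torch.randn(1, 3, 224, 224).to(device)
+feature_outputs, output = hooked_model(x)
+
+# `feature_outputs` maps each layer path to that layer's output tensor
+for path, feature in feature_outputs.items():
+    print(f'{path}: {feature.shape}')
+print(f'Model output shape: {output.shape}')  # [1, 1000]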
diff --git a/pyproject.toml b/pyproject.toml
new file mode 100644
index 0000000..234a9ae
--- /dev/null
+++ b/pyproject.toml
@@ -0,0 +1,21 @@
+[project]
+name = "torch-featurelayer"
+authors = [{ name = "spencerwooo", email = "spencer.woo@outlook.com" }]
+requires-python = ">=3.10,<3.13"
+readme = "README.md"
+dependencies = ["torch>=2.0,<2.3"]
+dynamic = ["version"]
+
+[project.optional-dependencies]
+dev = ["mypy", "ruff", "torchvision"]
+
+[build-system]
+requires = ["setuptools"]
+build-backend = "setuptools.build_meta"
+
+[tool.setuptools.dynamic]
+version = { attr = "torch_featurelayer.__version__" }
+
+[tool.ruff]
+line-length = 120
+format.quote-style = "single"
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000..204c1e7
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,50 @@
+# This file was autogenerated by uv via the following command:
+#    uv pip compile pyproject.toml
+filelock==3.13.4
+    # via torch
+fsspec==2024.3.1
+    # via torch
+jinja2==3.1.3
+    # via torch
+markupsafe==2.1.5
+    # via jinja2
+mpmath==1.3.0
+    # via sympy
+networkx==3.3
+    # via torch
+nvidia-cublas-cu12==12.1.3.1
+    # via
+    #   nvidia-cudnn-cu12
+    #   nvidia-cusolver-cu12
+    #   torch
+nvidia-cuda-cupti-cu12==12.1.105
+    # via torch
+nvidia-cuda-nvrtc-cu12==12.1.105
+    # via torch
+nvidia-cuda-runtime-cu12==12.1.105
+    # via torch
+nvidia-cudnn-cu12==8.9.2.26
+    # via torch
+nvidia-cufft-cu12==11.0.2.54
+    # via torch
+nvidia-curand-cu12==10.3.2.106
+    # via torch
+nvidia-cusolver-cu12==11.4.5.107
+    # via torch
+nvidia-cusparse-cu12==12.1.0.106
+    # via
+    #   nvidia-cusolver-cu12
+    #   torch
+nvidia-nccl-cu12==2.19.3
+    # via torch
+nvidia-nvjitlink-cu12==12.4.127
+    # via
+    #   nvidia-cusolver-cu12
+    #   nvidia-cusparse-cu12
+nvidia-nvtx-cu12==12.1.105
+    # via torch
+sympy==1.12
+    # via torch
+torch==2.2.2
+typing-extensions==4.11.0
+    # via torch
diff --git a/src/torch_featurelayer/__init__.py b/src/torch_featurelayer/__init__.py
new file mode 100644
index 0000000..ab88ce4
--- /dev/null
+++ b/src/torch_featurelayer/__init__.py
@@ -0,0 +1,5 @@
+from torch_featurelayer.feature_layer import FeatureLayer
+from torch_featurelayer.feature_layers import FeatureLayers
+
+__version__ = '0.1.0'
+__all__ = ['FeatureLayer', 'FeatureLayers']
diff --git a/src/torch_featurelayer/feature_layer.py b/src/torch_featurelayer/feature_layer.py
new file mode 100644
index 0000000..d07f45c
--- /dev/null
+++ b/src/torch_featurelayer/feature_layer.py
@@ -0,0 +1,57 @@
+from typing import Any
+
+import torch
+
+from torch_featurelayer.rgetattr import rgetattr
+
+
+class FeatureLayer:
+    """Wraps a model and provides a hook to access the output of a feature layer.
+
+    Feature layer paths are defined via dot notation, e.g. `features.15`.
+
+    Args:
+        model: The model containing the feature layer.
+        feature_layer_path: The path to the feature layer in the model.
+
+    Attributes:
+        _model (torch.nn.Module): The model containing the feature layer.
+        feature_layer_path (str): The path to the feature layer in the model.
+        feature_layer_output (torch.Tensor | None): The output of the feature layer, or None before a forward pass.
+    """
+
+    def __init__(self, model: torch.nn.Module, feature_layer_path: str):
+        self._model: torch.nn.Module = model
+        self.feature_layer_path: str = feature_layer_path
+        self.feature_layer_output: torch.Tensor | None = None  # set by the forward hook in __call__
+
+    def __call__(self, *args: Any, **kwargs: Any) -> tuple[torch.Tensor | None, torch.Tensor]:
+        """Perform a forward pass through the model and update the hooked feature layer.
+
+        Args:
+            *args: Variable length argument list.
+            **kwargs: Arbitrary keyword arguments.
+
+        Returns:
+            tuple: A tuple containing the feature layer output and the model output.
+        """
+
+        h: torch.utils.hooks.RemovableHandle | None = None  # hook handle
+
+        def hook(module: torch.nn.Module, input: tuple[torch.Tensor, ...], output: torch.Tensor) -> None:
+            self.feature_layer_output = output
+
+        try:
+            # Register hook
+            layer = rgetattr(self._model, self.feature_layer_path)
+            h = layer.register_forward_hook(hook)
+        except AttributeError as e:
+            raise AttributeError(f'Layer {self.feature_layer_path} not found in model.') from e
+
+        # Forward pass and update hooked feature layer
+        output: torch.Tensor = self._model(*args, **kwargs)
+
+        # Remove hook
+        h.remove()
+
+        return self.feature_layer_output, output
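diff --git a/tests/test_feature_layer.py b/tests/test_feature_layer.py
new file mode 100644
--- /dev/null
+++ b/tests/test_feature_layer.py
@@ -0,0 +1,27 @@
+# Illustrative test sketch, not part of the original changeset: checks
+# `FeatureLayer` against a tiny nested `nn.Sequential` so the hooked output
+# is easy to verify by hand. The file name and the toy model are assumptions.
+import torch
+
+from torch_featurelayer import FeatureLayer
+
+
+def test_feature_layer_hooks_intermediate_output() -> None:
+    # Nested Sequential so the dot-notation path has more than one segment
+    model = torch.nn.Sequential(
+        torch.nn.Sequential(torch.nn.Linear(4, 8), torch.nn.ReLU()),
+        torch.nn.Linear(8, 2),
+    )
+    hooked_model = FeatureLayer(model, '0.1')  # hook the ReLU inside the first block
+
+    x = torch.randn(3, 4)
+    feature_output, output = hooked_model(x)
+
+    assert feature_output is not None
+    assert feature_output.shape == (3, 8)
+    assert torch.equal(feature_output, model[0](x))  # the block's output is its ReLU's output
+    assert output.shape == (3, 2)
+
+
+if __name__ == '__main__':
+    test_feature_layer_hooks_intermediate_output()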
diff --git a/src/torch_featurelayer/feature_layers.py b/src/torch_featurelayer/feature_layers.py
new file mode 100644
index 0000000..d35d002
--- /dev/null
+++ b/src/torch_featurelayer/feature_layers.py
@@ -0,0 +1,67 @@
+from collections import OrderedDict
+from typing import Any
+
+import torch
+
+from torch_featurelayer.rgetattr import rgetattr
+
+
+class FeatureLayers:
+    """Wraps a model and provides hooks to access the output of multiple feature layers.
+
+    Args:
+        model: The model containing the feature layers.
+        feature_layer_paths: A list of paths to the feature layers in the model.
+
+    Attributes:
+        _model (torch.nn.Module): The model containing the feature layers.
+        feature_layer_paths (list[str]): A list of paths to the feature layers in the model.
+        feature_layer_outputs (dict[str, torch.Tensor | None]): The outputs of the feature layers, keyed by path.
+    """
+
+    def __init__(self, model: torch.nn.Module, feature_layer_paths: list[str]):
+        self._model: torch.nn.Module = model
+        self.feature_layer_paths: list[str] = feature_layer_paths
+        self.feature_layer_outputs: dict[str, torch.Tensor | None] = OrderedDict.fromkeys(feature_layer_paths, None)
+
+    def __call__(self, *args: Any, **kwargs: Any) -> tuple[dict[str, torch.Tensor | None], torch.Tensor]:
+        """Perform a forward pass through the model and update the hooked feature layers.
+
+        Args:
+            *args: Variable length argument list.
+            **kwargs: Arbitrary keyword arguments.
+
+        Returns:
+            tuple: A tuple containing the feature layer outputs and the model output.
+        """
+
+        hs = []  # hook handles
+
+        for feature_layer_path in self.feature_layer_paths:
+
+            def hook(
+                module: torch.nn.Module,
+                input: tuple[torch.Tensor, ...],
+                output: torch.Tensor,
+                feature_layer_path: str = feature_layer_path,
+            ) -> None:
+                self.feature_layer_outputs[feature_layer_path] = output
+
+            try:
+                # Register hook
+                layer = rgetattr(self._model, feature_layer_path)
+                h = layer.register_forward_hook(hook)
+                hs.append(h)
+            except AttributeError:
+                # Skip hook registration if the layer is not found
+                print(f'Warning: Layer {feature_layer_path} not found in model, skipping hook registration.')
+                continue
+
+        # Forward pass and update hooked feature layers
+        output: torch.Tensor = self._model(*args, **kwargs)
+
+        # Remove hooks
+        for h in hs:
+            h.remove()
+
+        return self.feature_layer_outputs, output
diff --git a/src/torch_featurelayer/rgetattr.py b/src/torch_featurelayer/rgetattr.py
new file mode 100644
index 0000000..9936dfc
--- /dev/null
+++ b/src/torch_featurelayer/rgetattr.py
@@ -0,0 +1,22 @@
+import functools
+from typing import Any
+
+
+def rgetattr(obj: Any, attr: str, *args: Any) -> Any:
+    """Recursively gets an attribute from an object.
+
+    https://stackoverflow.com/questions/31174295/getattr-and-setattr-on-nested-subobjects-chained-properties/31174427#31174427
+
+    Args:
+        obj: The object to retrieve the attribute from.
+        attr: The attribute to retrieve. Can be a nested attribute separated by dots.
+        *args: An optional default value to return if the attribute is not found.
+
+    Returns:
+        The value of the attribute if found, otherwise the default value specified by *args.
+    """
+
+    def _getattr(obj: Any, attr: str) -> Any:
+        return getattr(obj, attr, *args)
+
+    return functools.reduce(_getattr, [obj] + attr.split('.'))
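diff --git a/tests/test_rgetattr.py b/tests/test_rgetattr.py
new file mode 100644
--- /dev/null
+++ b/tests/test_rgetattr.py
@@ -0,0 +1,17 @@
+# Illustrative test sketch, not part of the original changeset: shows how
+# `rgetattr` resolves dot-separated attribute paths and how the optional
+# default is used. The file name is an assumption.
+from types import SimpleNamespace
+
+from torch_featurelayer.rgetattr import rgetattr
+
+
+def test_rgetattr_resolves_nested_attributes() -> None:
+    obj = SimpleNamespace(a=SimpleNamespace(b=SimpleNamespace(c=42)))
+
+    assert rgetattr(obj, 'a.b.c') == 42
+    assert rgetattr(obj, 'a.b.missing', None) is None  # default returned when the path breaks
+
+
+if __name__ == '__main__':
+    test_rgetattr_resolves_nested_attributes()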