Skip to content

Commit

Permalink
Add docs for the API exposed
Browse files Browse the repository at this point in the history
  • Loading branch information
spencerwooo committed Apr 15, 2024
1 parent 2af3274 commit da896d6
Showing 1 changed file with 73 additions and 10 deletions.
83 changes: 73 additions & 10 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,27 +9,47 @@
🧠 Simple utility functions and wrappers for hooking onto layers within PyTorch models for feature extraction.

> [!TIP]
> This library is intended to be a simplified and well-documented implementation for extracting a PyTorch model's intermediate layer output(s). For a more sophisticated and complete implementation, either consider using [`torchvision.models.feature_extraction`](https://pytorch.org/vision/stable/feature_extraction.html), or check the official [`torch.fx`](https://pytorch.org/docs/stable/fx.html).
> This library is intended to be a simplified and well-documented implementation for extracting a PyTorch model's intermediate layer output(s). For a more sophisticated and complete implementation, either consider using [`torchvision.models.feature_extraction`](https://pytorch.org/vision/stable/feature_extraction.html), or check the official [`torch.fx`](https://pytorch.org/docs/stable/fx.html).
## Install

```shell
pip install torch-featurelayer
```

## Usage

Imports:

```python
import torch
from torch_featurelayer import FeatureLayer
from torchvision.models import vgg11
from torch_featurelayer import FeatureLayer
```

Load a pretrained VGG-11 model:

# Load a pretrained VGG-11 model
```python
model = vgg11(weights='DEFAULT').eval()
```

# Hook onto layer `features.15` of the model
Hook onto layer `features.15` of the model:

```python
layer_path = 'features.15'
hooked_model = FeatureLayer(model, layer_path)
```

# Forward pass an input tensor through the model
Forward pass an input tensor through the model:

```python
x = torch.randn(1, 3, 224, 224)
feature_output, output = hooked_model(x)
```

`feature_output` is the output of layer `features.15`. Print the output shape:

# Print the output shape
```python
print(f'Feature layer output shape: {feature_output.shape}') # [1, 512, 14, 14]
print(f'Model output shape: {output.shape}') # [1, 1000]
```
Expand All @@ -38,12 +58,55 @@ Check the [examples](./examples/) directory for more.

## API

> `torch_featurelayer.FeatureLayer(model: torch.nn.Module, feature_layer_path: str)`
### `torch_featurelayer.FeatureLayer`

The `FeatureLayer` class wraps a model and provides a hook to access the output of a specific feature layer.

- `__init__(self, model: torch.nn.Module, feature_layer_path: str)`

Initializes the `FeatureLayer` instance.

- `model`: The model containing the feature layer.
- `feature_layer_path`: The path to the feature layer in the model.

- `__call__(self, *args: Any, **kwargs: Any) -> tuple[torch.Tensor | None, torch.Tensor]`

Performs a forward pass through the model and updates the hooked feature layer.

- `*args`: Variable length argument list.
- `**kwargs`: Arbitrary keyword arguments.

Returns a tuple containing the feature layer output and the model output.

### `torch_featurelayer.FeatureLayers`

The `FeatureLayers` class wraps a model and provides hooks to access the output of multiple feature layers.

- `__init__(self, model: torch.nn.Module, feature_layer_paths: list[str])`

Initializes the `FeatureLayers` instance.

- `model`: The model containing the feature layers.
- `feature_layer_paths`: A list of paths to the feature layers in the model.

- `__call__(self, *args: Any, **kwargs: Any) -> tuple[dict[str, torch.Tensor | None], torch.Tensor]`

Performs a forward pass through the model and updates the hooked feature layers.

- `*args`: Variable length argument list.
- `**kwargs`: Arbitrary keyword arguments.

Returns a tuple containing the feature layer outputs and the model output.

### `torch_featurelayer.get_layer_candidates(module: torch.nn.Module, max_depth: int = 1) -> Generator[str, None, None]`

The `get_layer_candidates` function returns a generator of layer paths for a given model up to a specified depth.

> `torch_featurelayer.FeatureLayers(model: torch.nn.Module, feature_layer_paths: list[str])`
- `model`: The model to get layer paths from.
- `max_depth`: The maximum depth to traverse in the model's layers.

> `torch_featurelayer.get_layer_candidates(module: torch.nn.Module, max_depth: int = 1)`
Returns a generator of layer paths.

## License

[MIT](./LICENSE)
[MIT](./LICENSE)

0 comments on commit da896d6

Please sign in to comment.