Skip to content

Commit

Permalink
Add exception tests and assertions
Browse files Browse the repository at this point in the history
  • Loading branch information
spencerwooo committed Apr 16, 2024
1 parent 7665af9 commit 7b49d54
Show file tree
Hide file tree
Showing 7 changed files with 79 additions and 9 deletions.
5 changes: 4 additions & 1 deletion .vscode/settings.json
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
{
"python.testing.pytestArgs": [
"."
"test",
"--color=yes",
"-o",
"log_cli=true"
],
"python.testing.unittestEnabled": false,
"python.testing.pytestEnabled": true
Expand Down
1 change: 1 addition & 0 deletions src/torch_featurelayer/feature_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ class FeatureLayer:
"""

def __init__(self, model: torch.nn.Module, feature_layer_path: str):
assert isinstance(model, torch.nn.Module), 'model must be a torch.nn.Module'
self._model: torch.nn.Module = model
self.feature_layer_path: str = feature_layer_path
self.feature_layer_output: torch.Tensor | None = None # output of the feature layer (must be global)
Expand Down
1 change: 1 addition & 0 deletions src/torch_featurelayer/feature_layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ class FeatureLayers:
"""

def __init__(self, model: torch.nn.Module, feature_layer_paths: list[str]):
assert isinstance(model, torch.nn.Module), 'model must be a torch.nn.Module'
self._model: torch.nn.Module = model
self.feature_layer_paths: list[str] = feature_layer_paths
self.feature_layer_outputs: dict[str, torch.Tensor | None] = dict().fromkeys(feature_layer_paths, None)
Expand Down
3 changes: 3 additions & 0 deletions src/torch_featurelayer/layer_candidates.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@ def get_layer_candidates(module: nn.Module, max_depth: int = 1) -> Generator[str
Generator: A generator of the model's layer candidates in dot notation.
"""

assert max_depth >= 0, 'max_depth must be a non-negative integer'
assert isinstance(module, nn.Module), 'model must be a torch.nn.Module'

def get_modules(model: nn.Module, prefix: str = '', depth: int = 0) -> Generator[str, None, None]:
if prefix:
yield prefix
Expand Down
22 changes: 21 additions & 1 deletion test/test_feature_layer.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import pytest
import torch
from torchvision.models import alexnet
from torch_featurelayer import FeatureLayer
from torchvision.models import alexnet


def test_feature_layer():
Expand All @@ -12,3 +13,22 @@ def test_feature_layer():

assert feature_output.shape == (1, 256, 6, 6)
assert output.shape == (1, 1000)


def test_feature_layer_invalid_model():
with pytest.raises(AssertionError) as e:
model = 'invalid_model'
_ = FeatureLayer(model, 'features.12')

assert str(e.value) == 'model must be a torch.nn.Module'


def test_feature_layer_nonexistent_layer_path():
with pytest.raises(AttributeError) as e:
model = alexnet()
hooked_model = FeatureLayer(model, 'path.to.blah')

x = torch.randn(1, 3, 224, 224)
_ = hooked_model(x)

assert str(e.value) == 'Layer path.to.blah not found in model.'
33 changes: 32 additions & 1 deletion test/test_feature_layers.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import pytest
import torch
from torchvision.models import alexnet
from torch_featurelayer import FeatureLayers
from torchvision.models import alexnet


def test_feature_layers():
Expand Down Expand Up @@ -31,3 +32,33 @@ def test_feature_layers():
for layer_path, feature_output in feature_outputs.items():
assert feature_output.shape == feature_output_shapes[layer_path]
assert output.shape == (1, 1000)


def test_feature_layers_invalid_model():
with pytest.raises(AssertionError) as e:
model = 'invalid_model'
_ = FeatureLayers(model, ['features.2', 'features.5', 'features.9'])
assert str(e.value) == 'model must be a torch.nn.Module'


def test_feature_layers_contain_nonexistent_layer_path():
model = alexnet()
hooked_model = FeatureLayers(model, ['features.1', 'path.to.blah'])

x = torch.randn(1, 3, 224, 224)
feature_outputs, outputs = hooked_model(x)

assert feature_outputs['features.1'] is not None
assert feature_outputs['path.to.blah'] is None # nonexistent layer path ignored
assert outputs is not None


def test_feature_layers_empty_layer_paths():
model = alexnet()
hooked_model = FeatureLayers(model, [])

x = torch.randn(1, 3, 224, 224)
feature_outputs, output = hooked_model(x)

assert feature_outputs == {}
assert output is not None
23 changes: 17 additions & 6 deletions test/test_get_layer_candidates.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,12 @@
from torchvision.models import alexnet
import pytest
from torch_featurelayer import get_layer_candidates
from torchvision.models import alexnet


def test_get_layer_candidates_depth1():
model = alexnet()
candidates = list(get_layer_candidates(model, max_depth=1))
assert candidates == [
'features',
'avgpool',
'classifier',
]
assert candidates == ['features', 'avgpool', 'classifier']


def test_get_layer_candidates_depth2():
Expand Down Expand Up @@ -40,3 +37,17 @@ def test_get_layer_candidates_depth2():
'classifier.5',
'classifier.6',
]


def test_get_layer_candidates_invalid_model():
with pytest.raises(AssertionError) as e:
model = 'invalid_model'
get_layer_candidates(model, max_depth=1)
assert str(e.value) == 'model must be a torch.nn.Module'


def test_get_layer_candidates_negative_depth():
with pytest.raises(AssertionError) as e:
model = alexnet()
get_layer_candidates(model, max_depth=-1)
assert str(e.value) == 'max_depth must be a non-negative integer'

0 comments on commit 7b49d54

Please sign in to comment.