Skip to content

Commit

Permalink
Add unit tests
Browse files Browse the repository at this point in the history
  • Loading branch information
spencerwooo committed Apr 15, 2024
1 parent da896d6 commit 7665af9
Show file tree
Hide file tree
Showing 5 changed files with 97 additions and 1 deletion.
7 changes: 7 additions & 0 deletions .vscode/settings.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
{
"python.testing.pytestArgs": [
"."
],
"python.testing.unittestEnabled": false,
"python.testing.pytestEnabled": true
}
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ Documentation = "https://github.com/spencerwooo/torch-featurelayer/blob/main/REA

[project.optional-dependencies]
dev = ["mypy", "ruff"]
test = ["torchvision>=0.15,<0.18"]
test = ["torchvision>=0.15,<0.18", "pytest"]

[build-system]
requires = ["setuptools"]
Expand Down
14 changes: 14 additions & 0 deletions test/test_feature_layer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
import torch
from torchvision.models import alexnet
from torch_featurelayer import FeatureLayer


def test_feature_layer():
model = alexnet()
hooked_model = FeatureLayer(model, 'features.12')

x = torch.randn(1, 3, 224, 224)
feature_output, output = hooked_model(x)

assert feature_output.shape == (1, 256, 6, 6)
assert output.shape == (1, 1000)
33 changes: 33 additions & 0 deletions test/test_feature_layers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
import torch
from torchvision.models import alexnet
from torch_featurelayer import FeatureLayers


def test_feature_layers():
model = alexnet()
layer_paths = [
'features.2',
'features.5',
'features.9',
'features.12',
'avgpool',
'classifier.2',
'classifier.4',
]
hooked_model = FeatureLayers(model, layer_paths)

x = torch.randn(1, 3, 224, 224)
feature_outputs, output = hooked_model(x)

feature_output_shapes = {
'features.2': (1, 64, 27, 27),
'features.5': (1, 192, 13, 13),
'features.9': (1, 256, 13, 13),
'features.12': (1, 256, 6, 6),
'avgpool': (1, 256, 6, 6),
'classifier.2': (1, 4096),
'classifier.4': (1, 4096),
}
for layer_path, feature_output in feature_outputs.items():
assert feature_output.shape == feature_output_shapes[layer_path]
assert output.shape == (1, 1000)
42 changes: 42 additions & 0 deletions test/test_get_layer_candidates.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
from torchvision.models import alexnet
from torch_featurelayer import get_layer_candidates


def test_get_layer_candidates_depth1():
model = alexnet()
candidates = list(get_layer_candidates(model, max_depth=1))
assert candidates == [
'features',
'avgpool',
'classifier',
]


def test_get_layer_candidates_depth2():
model = alexnet()
candidates = list(get_layer_candidates(model, max_depth=2))
assert candidates == [
'features',
'features.0',
'features.1',
'features.2',
'features.3',
'features.4',
'features.5',
'features.6',
'features.7',
'features.8',
'features.9',
'features.10',
'features.11',
'features.12',
'avgpool',
'classifier',
'classifier.0',
'classifier.1',
'classifier.2',
'classifier.3',
'classifier.4',
'classifier.5',
'classifier.6',
]

0 comments on commit 7665af9

Please sign in to comment.