diff --git a/README.md b/README.md index 99d46af..38fdb5e 100644 --- a/README.md +++ b/README.md @@ -15,9 +15,8 @@ ```python import torch -from torchvision.models import vgg11 - from torch_featurelayer import FeatureLayer +from torchvision.models import vgg11 # Load a pretrained VGG-11 model model = vgg11(weights='DEFAULT').eval() @@ -43,7 +42,7 @@ Check the [examples](./examples/) directory for more. > `torch_featurelayer.FeatureLayers(model: torch.nn.Module, feature_layer_paths: list[str])` -> `torch_featurelayer.get_layer_candidates(module: nn.Module, max_depth: int = 1)` +> `torch_featurelayer.get_layer_candidates(module: torch.nn.Module, max_depth: int = 1)` ## License diff --git a/examples/extract_multiple_layers.py b/examples/extract_multiple_layers.py index bcf3767..dd1bfa1 100644 --- a/examples/extract_multiple_layers.py +++ b/examples/extract_multiple_layers.py @@ -1,7 +1,6 @@ import torch -from torchvision.models import resnet50 - from torch_featurelayer import FeatureLayers +from torchvision.models import resnet50 # Load a pretrained ResNet-50 model model = resnet50(weights='DEFAULT').eval() diff --git a/examples/extract_one_layer.py b/examples/extract_one_layer.py index 0be3aa3..2047d65 100644 --- a/examples/extract_one_layer.py +++ b/examples/extract_one_layer.py @@ -1,7 +1,6 @@ import torch -from torchvision.models import vgg11 - from torch_featurelayer import FeatureLayer +from torchvision.models import vgg11 # Load a pretrained VGG-11 model model = vgg11(weights='DEFAULT').eval() diff --git a/examples/list_layer_candidates.py b/examples/list_layer_candidates.py index beece94..84cbb43 100644 --- a/examples/list_layer_candidates.py +++ b/examples/list_layer_candidates.py @@ -1,5 +1,5 @@ -from torchvision.models import resnet18 from torch_featurelayer import get_layer_candidates +from torchvision.models import resnet18 # ResNet( # (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)