Skip to content

Commit

Permalink
Format imports
Browse files Browse the repository at this point in the history
  • Loading branch information
spencerwooo committed Apr 15, 2024
1 parent a8f8d55 commit 2af3274
Show file tree
Hide file tree
Showing 4 changed files with 5 additions and 8 deletions.
5 changes: 2 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,8 @@

```python
import torch
from torchvision.models import vgg11

from torch_featurelayer import FeatureLayer
from torchvision.models import vgg11

# Load a pretrained VGG-11 model
model = vgg11(weights='DEFAULT').eval()
Expand All @@ -43,7 +42,7 @@ Check the [examples](./examples/) directory for more.
> `torch_featurelayer.FeatureLayers(model: torch.nn.Module, feature_layer_paths: list[str])`
> `torch_featurelayer.get_layer_candidates(module: nn.Module, max_depth: int = 1)`
> `torch_featurelayer.get_layer_candidates(module: torch.nn.Module, max_depth: int = 1)`
## License

Expand Down
3 changes: 1 addition & 2 deletions examples/extract_multiple_layers.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import torch
from torchvision.models import resnet50

from torch_featurelayer import FeatureLayers
from torchvision.models import resnet50

# Load a pretrained ResNet-50 model
model = resnet50(weights='DEFAULT').eval()
Expand Down
3 changes: 1 addition & 2 deletions examples/extract_one_layer.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import torch
from torchvision.models import vgg11

from torch_featurelayer import FeatureLayer
from torchvision.models import vgg11

# Load a pretrained VGG-11 model
model = vgg11(weights='DEFAULT').eval()
Expand Down
2 changes: 1 addition & 1 deletion examples/list_layer_candidates.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from torchvision.models import resnet18
from torch_featurelayer import get_layer_candidates
from torchvision.models import resnet18

# ResNet(
# (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
Expand Down

0 comments on commit 2af3274

Please sign in to comment.