-
Notifications
You must be signed in to change notification settings - Fork 0
/
ResNet18.py
39 lines (34 loc) · 1.39 KB
/
ResNet18.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
import torchvision
from torchvision import models
from torchvision.models import convnext_base, ConvNeXt_Base_Weights, ResNet18_Weights
from torch import nn
import torch
from torchsummary import summary
from torchinfo import summary
# Prefer the GPU whenever PyTorch reports CUDA as available; otherwise use the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
# Default number of output classes for ``Res`` (its ``n_classes`` default).
cmmd_gt = 2
class Res(nn.Module):
    """ResNet-18 convolutional backbone followed by a fully connected head.

    Everything in torchvision's ResNet-18 except the final average-pool and
    ``fc`` layers is kept as a feature extractor. Its output feature map is
    flattened and fed through a small stack of linear layers.

    Args:
        n_classes: Number of output logits. Defaults to the module-level
            ``cmmd_gt``.
        input_size: Spatial size of the input images. Pass an ``int`` for
            square images or a ``(height, width)`` pair. It determines the
            in-features of the first linear layer. The default of 512
            reproduces the original hard-coded ``512 * 16 * 16``.
        weights: torchvision weights passed to ``models.resnet18``. Defaults
            to the ImageNet-1k V1 weights. Pass ``None`` for random
            initialisation, which needs no download.
        verbose: When True, ``forward`` prints the shapes of the feature map
            and the flattened features.
    """

    def __init__(self, n_classes=cmmd_gt, *, input_size=512,
                 weights=ResNet18_Weights.IMAGENET1K_V1, verbose=False):
        super().__init__()
        backbone = models.resnet18(weights=weights)
        # Drop the trailing avgpool and fc layers so the spatial map is kept.
        self.feature = nn.Sequential(*list(backbone.children())[:-2])
        self.verbose = verbose

        if isinstance(input_size, int):
            height, width = input_size, input_size
        else:
            height, width = input_size
        # Channels leaving the last stage equal the in-features of the fc
        # layer that was just dropped (512 for ResNet-18).
        channels = backbone.fc.in_features
        in_features = channels * self._feature_extent(height) * self._feature_extent(width)

        # NOTE: "calssifier" (sic) is kept on purpose. The attribute name
        # determines the state_dict keys, so renaming it would break loading
        # existing checkpoints.
        # NOTE(review): there is no non-linearity between the first two
        # Linear layers, so together they act as a single linear map.
        # Inserting an activation would shift the Sequential indices (and
        # thus checkpoint keys), so that change is left for a deliberate
        # decision.
        self.calssifier = nn.Sequential(
            nn.Linear(in_features, 1024),
            nn.Linear(1024, 512),
            nn.Dropout(0.25),
            nn.Linear(512, n_classes),
        )

    @staticmethod
    def _feature_extent(size):
        """Return the backbone's output extent along one axis of ``size`` pixels."""
        # Five stride-2 reductions occur along each axis:
        # - the stem conv (k7, s2, p3);
        # - the max-pool (k3, s2, p1);
        # - the first block of layer2, layer3 and layer4 (3x3 conv, s2, p1,
        #   and its 1x1 s2 downsample).
        # Each reduction maps n -> (n - 1) // 2 + 1, e.g. 512 -> 16 and 224 -> 7.
        for _ in range(5):
            size = (size - 1) // 2 + 1
        return size

    def forward(self, x):
        """Run the backbone and the classification head.

        Args:
            x: Image batch. Presumably ``(batch, 3, H, W)``, with ``H`` and
                ``W`` matching ``input_size`` (TODO: confirm against callers).
                Other sizes make the first Linear layer reject the input.

        Returns:
            A tuple ``(flat_features, logits)``. ``flat_features`` is the
            backbone feature map reshaped to ``(batch, -1)``. ``logits`` is
            the head output with ``n_classes`` columns.
        """
        # Spatial feature map. The original author noted it could be reused
        # for an attention module.
        feature_map = self.feature(x)
        # getattr keeps instances unpickled from the old class definition
        # (which had no ``verbose`` attribute) working.
        verbose = getattr(self, "verbose", False)
        if verbose:
            print(feature_map.shape)
        # Flattened features are also returned, intended for t-SNE plots.
        flat_features = feature_map.reshape(feature_map.size(0), -1)
        if verbose:
            print(flat_features.shape)
        logits = self.calssifier(flat_features)
        return flat_features, logits
model = Res()
# ``summary`` here is torchinfo's (its import shadows torchsummary's), which
# takes the full input size including the batch dimension.
print(summary(model, (2, 3, 512, 512)))

# Smoke test: push a random batch through the model and report the shapes.
# NOTE(review): the model is never switched to eval() here, so BatchNorm
# layers run in training mode on this random batch. Call model.eval() first
# if the pretrained running statistics must stay untouched.
img = torch.randn(2, 3, 512, 512)
with torch.no_grad():  # shape check only; no autograd graph is needed
    fea, out = model(img)
print(f"shape of feature:{fea.shape}\nshape of output {out.shape}")