-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathextract_embeddings.py
119 lines (103 loc) · 4.47 KB
/
extract_embeddings.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
import os

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms

try:
    from sklearn.manifold import TSNE
except ImportError:  # legacy scikit-learn module layout
    from sklearn.manifold.t_sne import TSNE

from dataset import TripletDataSet
from inception import Inception
from model import resnet18
class L2_norm(nn.Module):
    """Scale inputs to unit L2 norm along one dimension.

    Thin ``nn.Module`` wrapper around ``torch.nn.functional.normalize`` so the
    normalisation can be used as a layer, e.g. as the last stage of an
    ``nn.Sequential`` embedding head.

    Args:
        dim: dimension to normalise over. Defaults to ``-1`` (the last
            dimension), which is the original behaviour.
        eps: lower bound on the norm used as divisor, so all-zero vectors map
            to zeros instead of NaN. Same default as ``F.normalize``.

    The module holds no parameters or buffers, so adding it to a model does
    not change that model's ``state_dict`` keys.
    """

    def __init__(self, dim=-1, eps=1e-12):
        super(L2_norm, self).__init__()
        self.dim = dim
        self.eps = eps

    def forward(self, x):
        """Return ``x`` divided by its L2 norm along ``self.dim``."""
        return F.normalize(x, p=2, dim=self.dim, eps=self.eps)

    def extra_repr(self):
        # Shown inside print(model), e.g. "L2_norm(dim=-1, eps=1e-12)".
        return 'dim={}, eps={}'.format(self.dim, self.eps)
if __name__ == '__main__':
    # Embed the train/test splits with a triplet-trained network, save the
    # embeddings and labels as .npy files, and plot a 2-D t-SNE of the
    # test-split embeddings to 'tsne_val.png'.

    # ---- settings -----------------------------------------------------------
    # Use the first GPU when one is available, otherwise run on the CPU.
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    if device.type == 'cuda':
        torch.cuda.set_device(device)
    batch_size = 64
    data_path = r'D:\BaiduNetdiskDownload\catsdogs'
    arch = 'resnet18_triplet'
    # NOTE(review): this checkpoint name is ResNet-specific; point it at the
    # matching file when `arch` selects the Inception branch.
    checkpoint_path = 'resnet18_triplet_33_best.pth'

    # ---- model --------------------------------------------------------------
    if arch.lower().startswith('resnet'):
        # pretrained=False: the load_state_dict call below is strict, so every
        # entry of the model's state_dict is overwritten by the checkpoint and
        # fetching ImageNet weights first would be wasted work.
        model = resnet18(pretrained=False, num_classes=1000)
        model.avgpool = nn.AdaptiveAvgPool2d(1)
        # Embedding head: 512 pooled features -> 128-d, L2-normalised.
        model.fc = nn.Sequential(nn.Linear(512, 128, bias=False), L2_norm())
    elif arch.lower().startswith('inception'):
        model = Inception(3)
        model.norm = L2_norm()
    else:
        raise ValueError('Wrong arch')
    # The model is still on the CPU here, so load the tensors there as well;
    # this also lets a GPU-saved checkpoint be used on a CPU-only machine.
    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    model.load_state_dict(checkpoint['state_dict'])
    model.to(device)
    model.eval()

    # ---- data ---------------------------------------------------------------
    traindir = os.path.join(data_path, 'train')
    valdir = os.path.join(data_path, 'test')
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    # Deterministic evaluation transform, deliberately used for BOTH splits so
    # the train and test embeddings are produced the same way.
    val_transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        normalize,
    ])

    def make_loader(root):
        """Build an unshuffled DataLoader over the TripletDataSet at `root`."""
        dataset = TripletDataSet(root, val_transform)
        # shuffle=False keeps the saved arrays in dataset order.
        return torch.utils.data.DataLoader(dataset,
                                           batch_size=batch_size,
                                           shuffle=False,
                                           num_workers=0)

    def extract(loader, split):
        """Run `model` over `loader`; return (embeddings, labels) ndarrays."""
        features, targets = [], []
        with torch.no_grad():
            for images, labels in loader:
                features.append(model(images.to(device)).cpu().numpy())
                # Labels are only needed on the host: no GPU round-trip.
                targets.append(labels.cpu().numpy())
        if not features:
            raise ValueError('no samples found in the {} split'.format(split))
        return np.concatenate(features, 0), np.concatenate(targets, 0)

    # Build both loaders first so a missing directory fails before any work.
    val_loader = make_loader(valdir)
    tr_loader = make_loader(traindir)
    x_val, y_val = extract(val_loader, 'test')
    x_train, y_train = extract(tr_loader, 'train')

    # np.save appends '.npy' to these names.
    np.save('x_train', x_train)
    np.save('y_train', y_train)
    np.save('x_val', x_val)
    np.save('y_val', y_val)

    # ---- t-SNE of the test-split embeddings ---------------------------------
    x_tsne = TSNE(n_components=2).fit_transform(x_val)
    print(x_tsne.shape)
    # Fixed colours for classes 0 and 1; any further class ids fall back to
    # the 'tab10' colormap instead of raising IndexError.
    color = ['darkcyan', 'r']
    for cls in np.unique(y_val):
        mask = y_val == cls
        cls_color = color[cls] if cls < len(color) else plt.cm.tab10(cls % 10)
        # One scatter call per class (not per point); the label makes each
        # class appear exactly once in the legend.
        plt.scatter(x_tsne[mask, 0], x_tsne[mask, 1],
                    color=cls_color, label=str(cls))
    plt.legend()
    plt.title('TSNE validation')
    plt.savefig('tsne_val.png')