-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtorch_tsne_plotting.py
103 lines (76 loc) · 3.04 KB
/
torch_tsne_plotting.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
# Checkpoint file stem; the ".pth.tar" suffix is appended where it is loaded.
# NOTE: the concatenation evaluates to 'C:../early_stopping//test1'
# (double slash, drive-relative path).
path_to_save_check_points = r'C:../early_stopping/' + '/test1'
# Alias under which the rest of the script reads the checkpoint location.
path_to_checkpoints = path_to_save_check_points
import torch
import torch.backends.cudnn as cudnn
import torchvision
import torchvision.transforms as transforms
import torchvision.models as models
import argparse
import os
import numpy as np
import pandas as pd
# from tsne import bh_sne
from sklearn.manifold import TSNE
import seaborn as sns
import matplotlib.pyplot as plt
import timm
# Run on the GPU when CUDA is usable, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# Optimizer hyper-parameters.
LEARNING_RATE = 5e-3  # 0.005
WEIGHT_DECAY = 5e-4  # 0.0005
def gen_features(test_loader):
    """Run the checkpointed model over ``test_loader`` and collect its outputs.

    A timm ``vgg16`` with a 3-class head is built, its weights are restored
    from the ``'state_dict'`` entry of ``path_to_checkpoints + ".pth.tar"``,
    and every batch is passed through it in eval mode with gradients disabled.

    Args:
        test_loader: Sized iterable yielding ``(inputs, targets, images_name)``
            batches, where ``inputs`` and ``targets`` are tensors. Presumably
            a ``torch.utils.data.DataLoader`` -- it is not defined in the
            visible part of this file, so confirm against the caller.

    Returns:
        tuple: ``(targets, outputs)`` as NumPy arrays, each concatenated over
        all batches along axis 0. Every batch of targets gets a trailing
        length-1 axis before concatenation; ``outputs`` holds the model
        outputs cast to ``float64``.

    Raises:
        ValueError: If ``test_loader`` yields no batches.
    """
    # Every weight is replaced by the checkpoint below (load_state_dict is
    # strict by default), so fetching ImageNet weights first would only cost
    # a download.
    model = timm.create_model('vgg16', pretrained=False, num_classes=3)
    model = model.to(device=DEVICE, dtype=torch.float)

    checkpoint = torch.load(path_to_checkpoints + ".pth.tar", map_location=DEVICE)
    model.load_state_dict(checkpoint['state_dict'])
    # No optimizer is rebuilt: inference never steps one, and restoring its
    # state would only add memory use and a hard dependency on an
    # 'optimizer' entry in the checkpoint.
    model.eval()

    num_batches = len(test_loader)
    targets_list = []
    outputs_list = []
    with torch.no_grad():
        for batch_no, (inputs, targets, _images_name) in enumerate(test_loader, start=1):
            outputs = model(inputs.to(DEVICE))
            # Labels are only collected, never fed to the model, so they do
            # not need a round trip through the compute device.
            targets_list.append(targets.detach().cpu().numpy()[:, np.newaxis])
            outputs_list.append(outputs.detach().cpu().numpy())
            # Progress report every 10 batches and on the final batch.
            if batch_no % 10 == 0 or batch_no == num_batches:
                print(batch_no, '/', num_batches)

    if not outputs_list:
        # np.concatenate would raise ValueError here as well, but with a
        # message that does not point at the loader.
        raise ValueError("test_loader yielded no batches; nothing to embed")

    targets = np.concatenate(targets_list, axis=0)
    outputs = np.concatenate(outputs_list, axis=0).astype(np.float64)
    return targets, outputs
# Command-line options.
# NOTE(review): in the visible part of this file only --save-dir is read;
# --batch-size and --seed are parsed but not consumed here (t-SNE below uses a
# fixed random_state) -- confirm whether the data-loader code uses them.
parser = argparse.ArgumentParser(description='PyTorch t-SNE for STL10')
parser.add_argument('--save-dir', type=str, default='./results',
                    help='path to save the t-sne image')
# Help text now states the real default (it previously claimed 128).
parser.add_argument('--batch-size', type=int, default=12,
                    help='batch size (default: 12)')
parser.add_argument('--seed', type=int, default=1,
                    help='random seed value (default: 1)')
args = parser.parse_args()

# exist_ok=True replaces the check-then-create pair, which could race with
# another process creating the directory in between.
os.makedirs(args.save_dir, exist_ok=True)
def tsne_plot(save_dir, targets, outputs):
    """Embed ``outputs`` with t-SNE and save a class-coloured scatter plot.

    The embedding uses scikit-learn's ``TSNE`` with ``random_state=0``. The
    plot is written to ``<save_dir>/tsne.png`` with ticks and axis labels
    removed.

    Args:
        save_dir: Directory that receives ``tsne.png``; created if missing.
        targets: One label per row of ``outputs``. Any array-like that
            flattens to that length is accepted (e.g. shape ``(N,)`` or
            ``(N, 1)``).
        outputs: 2-D array-like of per-sample vectors to embed.

    Returns:
        None. The result is the image file written to disk.
    """
    print('generating t-SNE plot...')
    tsne = TSNE(random_state=0)
    tsne_output = tsne.fit_transform(outputs)

    df = pd.DataFrame(tsne_output, columns=['x', 'y'])
    # Flatten so a column vector of labels lands in a single column.
    df['targets'] = np.asarray(targets).ravel()

    # Size the palette from the data instead of hard-coding 3 colours, so the
    # plot works for any number of classes.
    n_classes = df['targets'].nunique()

    # An explicit figure keeps the 10x10 size local (no global rcParams
    # change) and lets us release it afterwards, so repeated calls neither
    # leak figures nor draw on top of an earlier plot.
    fig, ax = plt.subplots(figsize=(10, 10))
    try:
        sns.scatterplot(
            x='x', y='y',
            hue='targets',
            palette=sns.color_palette("hls", n_classes),
            data=df,
            marker='o',
            legend="full",
            alpha=0.5,
            ax=ax,
        )
        # t-SNE coordinates carry no interpretable units; hide ticks/labels.
        ax.set_xticks([])
        ax.set_yticks([])
        ax.set_xlabel('')
        ax.set_ylabel('')

        os.makedirs(save_dir, exist_ok=True)
        fig.savefig(os.path.join(save_dir, 'tsne.png'), bbox_inches='tight')
    finally:
        plt.close(fig)
    print('done!')
# Script entry: extract model outputs for the test set, then plot them.
# NOTE: `test_loader` is not defined anywhere in the visible part of this
# file. A loader yielding (inputs, targets, images_name) batches must be
# created before this point, otherwise this raises NameError.
targets, outputs = gen_features(test_loader)
tsne_plot(save_dir=args.save_dir, targets=targets, outputs=outputs)