Skip to content

Commit 1f0c23e

Browse files
authored
Add files via upload
1 parent b60e1a3 commit 1f0c23e

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

64 files changed

+3255
-1
lines changed

LICENSE

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
MIT License
22

3-
Copyright (c) 2019 Hui Zeng
3+
Copyright (c) 2017 Max deGroot, Ellis Brown
44

55
Permission is hereby granted, free of charge, to any person obtaining a copy
66
of this software and associated documentation files (the "Software"), to deal

ShuffleNetV2.py

Lines changed: 174 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,174 @@
1+
import torch
2+
import torch.nn as nn
3+
import torch.nn.functional as F
4+
from torch.autograd import Variable
5+
from collections import OrderedDict
6+
from torch.nn import init
7+
import math
8+
9+
def conv_bn(inp, oup, stride):
10+
return nn.Sequential(
11+
nn.Conv2d(inp, oup, 3, stride, 1, bias=False),
12+
nn.BatchNorm2d(oup),
13+
nn.ReLU(inplace=True)
14+
)
15+
16+
17+
def conv_1x1_bn(inp, oup):
18+
return nn.Sequential(
19+
nn.Conv2d(inp, oup, 1, 1, 0, bias=False),
20+
nn.BatchNorm2d(oup),
21+
nn.ReLU(inplace=True)
22+
)
23+
24+
def channel_shuffle(x, groups):
25+
batchsize, num_channels, height, width = x.data.size()
26+
27+
channels_per_group = num_channels // groups
28+
29+
# reshape
30+
x = x.view(batchsize, groups,
31+
channels_per_group, height, width)
32+
33+
x = torch.transpose(x, 1, 2).contiguous()
34+
35+
# flatten
36+
x = x.view(batchsize, -1, height, width)
37+
38+
return x
39+
40+
class InvertedResidual(nn.Module):
41+
def __init__(self, inp, oup, stride, benchmodel):
42+
super(InvertedResidual, self).__init__()
43+
self.benchmodel = benchmodel
44+
self.stride = stride
45+
assert stride in [1, 2]
46+
47+
oup_inc = oup//2
48+
49+
if self.benchmodel == 1:
50+
#assert inp == oup_inc
51+
self.banch2 = nn.Sequential(
52+
# pw
53+
nn.Conv2d(oup_inc, oup_inc, 1, 1, 0, bias=False),
54+
nn.BatchNorm2d(oup_inc),
55+
nn.ReLU(inplace=True),
56+
# dw
57+
nn.Conv2d(oup_inc, oup_inc, 3, stride, 1, groups=oup_inc, bias=False),
58+
nn.BatchNorm2d(oup_inc),
59+
# pw-linear
60+
nn.Conv2d(oup_inc, oup_inc, 1, 1, 0, bias=False),
61+
nn.BatchNorm2d(oup_inc),
62+
nn.ReLU(inplace=True),
63+
)
64+
else:
65+
self.banch1 = nn.Sequential(
66+
# dw
67+
nn.Conv2d(inp, inp, 3, stride, 1, groups=inp, bias=False),
68+
nn.BatchNorm2d(inp),
69+
# pw-linear
70+
nn.Conv2d(inp, oup_inc, 1, 1, 0, bias=False),
71+
nn.BatchNorm2d(oup_inc),
72+
nn.ReLU(inplace=True),
73+
)
74+
75+
self.banch2 = nn.Sequential(
76+
# pw
77+
nn.Conv2d(inp, oup_inc, 1, 1, 0, bias=False),
78+
nn.BatchNorm2d(oup_inc),
79+
nn.ReLU(inplace=True),
80+
# dw
81+
nn.Conv2d(oup_inc, oup_inc, 3, stride, 1, groups=oup_inc, bias=False),
82+
nn.BatchNorm2d(oup_inc),
83+
# pw-linear
84+
nn.Conv2d(oup_inc, oup_inc, 1, 1, 0, bias=False),
85+
nn.BatchNorm2d(oup_inc),
86+
nn.ReLU(inplace=True),
87+
)
88+
89+
@staticmethod
90+
def _concat(x, out):
91+
# concatenate along channel axis
92+
return torch.cat((x, out), 1)
93+
94+
def forward(self, x):
95+
if 1==self.benchmodel:
96+
x1 = x[:, :(x.shape[1]//2), :, :]
97+
x2 = x[:, (x.shape[1]//2):, :, :]
98+
out = self._concat(x1, self.banch2(x2))
99+
elif 2==self.benchmodel:
100+
out = self._concat(self.banch1(x), self.banch2(x))
101+
102+
return channel_shuffle(out, 2)
103+
104+
105+
class ShuffleNetV2(nn.Module):
106+
def __init__(self, n_class=1000, input_size=224, width_mult=1.):
107+
super(ShuffleNetV2, self).__init__()
108+
109+
assert input_size % 32 == 0
110+
111+
self.stage_repeats = [4, 8, 4]
112+
# index 0 is invalid and should never be called.
113+
# only used for indexing convenience.
114+
if width_mult == 0.5:
115+
self.stage_out_channels = [-1, 24, 48, 96, 192, 1024]
116+
elif width_mult == 1.0:
117+
self.stage_out_channels = [-1, 24, 116, 232, 464, 1024]
118+
elif width_mult == 1.5:
119+
self.stage_out_channels = [-1, 24, 176, 352, 704, 1024]
120+
elif width_mult == 2.0:
121+
self.stage_out_channels = [-1, 24, 224, 488, 976, 2048]
122+
else:
123+
raise ValueError(
124+
"""{} groups is not supported for
125+
1x1 Grouped Convolutions""".format(num_groups))
126+
127+
# building first layer
128+
input_channel = self.stage_out_channels[1]
129+
self.conv1 = conv_bn(3, input_channel, 2)
130+
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
131+
132+
self.features = []
133+
# building inverted residual blocks
134+
for idxstage in range(len(self.stage_repeats)):
135+
numrepeat = self.stage_repeats[idxstage]
136+
output_channel = self.stage_out_channels[idxstage+2]
137+
for i in range(numrepeat):
138+
if i == 0:
139+
#inp, oup, stride, benchmodel):
140+
self.features.append(InvertedResidual(input_channel, output_channel, 2, 2))
141+
else:
142+
self.features.append(InvertedResidual(input_channel, output_channel, 1, 1))
143+
input_channel = output_channel
144+
145+
146+
# make it nn.Sequential
147+
self.features = nn.Sequential(*self.features)
148+
149+
# building last several layers
150+
self.conv_last = conv_1x1_bn(input_channel, self.stage_out_channels[-1])
151+
self.globalpool = nn.Sequential(nn.AvgPool2d(int(input_size/32)))
152+
153+
# building classifier
154+
self.classifier = nn.Sequential(nn.Linear(self.stage_out_channels[-1], n_class))
155+
156+
def forward(self, x):
157+
x = self.conv1(x)
158+
x = self.maxpool(x)
159+
x = self.features(x)
160+
x = self.conv_last(x)
161+
x = self.globalpool(x)
162+
x = x.view(-1, self.stage_out_channels[-1])
163+
x = self.classifier(x)
164+
return x
165+
166+
def shufflenetv2(width_mult=1.):
167+
model = ShuffleNetV2(width_mult=width_mult)
168+
return model
169+
170+
if __name__ == "__main__":
171+
"""Testing
172+
"""
173+
model = ShuffleNetV2()
174+
print(model)

TestAccuracy.py

Lines changed: 140 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,140 @@
1+
from croppingDataset import GAICD
2+
from croppingModel import build_crop_model
3+
import time
4+
import math
5+
import sys
6+
import torch
7+
from torch.autograd import Variable
8+
import torch.backends.cudnn as cudnn
9+
import torch.utils.data as data
10+
import argparse
11+
from scipy.stats import spearmanr, pearsonr
12+
13+
parser = argparse.ArgumentParser(
14+
description='Single Shot MultiBox Detector Training With Pytorch')
15+
parser.add_argument('--dataset_root', default='dataset/GAIC/', help='Dataset root directory path')
16+
parser.add_argument('--image_size', default=256, type=int, help='Batch size for training')
17+
parser.add_argument('--batch_size', default=1, type=int, help='Batch size for training')
18+
parser.add_argument('--num_workers', default=0, type=int, help='Number of workers used in dataloading')
19+
parser.add_argument('--cuda', default=True, help='Use CUDA to train model')
20+
parser.add_argument('--net_path', default='weights/ablation/cropping/mobilenetv2/downsample4_multi_Aug1_Align9_Cdim8/23_0.625_0.583_0.553_0.525_0.785_0.762_0.748_0.723_0.783_0.806.pth_____',
21+
help='Directory for saving checkpoint models')
22+
args = parser.parse_args()
23+
24+
if torch.cuda.is_available():
25+
if args.cuda:
26+
torch.set_default_tensor_type('torch.cuda.FloatTensor')
27+
if not args.cuda:
28+
print("WARNING: It looks like you have a CUDA device, but aren't " +
29+
"using CUDA.\nRun with --cuda for optimal training speed.")
30+
torch.set_default_tensor_type('torch.FloatTensor')
31+
else:
32+
torch.set_default_tensor_type('torch.FloatTensor')
33+
34+
35+
data_loader = data.DataLoader(GAICD(image_size=args.image_size, dataset_dir=args.dataset_root, set='test'), args.batch_size, num_workers=args.num_workers, shuffle=False)
36+
37+
def test():
38+
39+
net = build_crop_model(scale='multi', alignsize=9, reddim=8, loadweight=True, model='mobilenetv2', downsample=4)
40+
41+
net.load_state_dict(torch.load(args.net_path))
42+
43+
if args.cuda:
44+
net = torch.nn.DataParallel(net,device_ids=[0])
45+
torch.backends.cudnn.deterministic = True
46+
torch.backends.cudnn.benchmark = False
47+
net = net.cuda()
48+
49+
net.eval()
50+
51+
acc4_5 = []
52+
acc4_10 = []
53+
wacc4_5 = []
54+
wacc4_10 = []
55+
srcc = []
56+
pcc = []
57+
for n in range(4):
58+
acc4_5.append(0)
59+
acc4_10.append(0)
60+
wacc4_5.append(0)
61+
wacc4_10.append(0)
62+
63+
for id, sample in enumerate(data_loader):
64+
image = sample['image']
65+
bboxs = sample['bbox']
66+
MOS = sample['MOS']
67+
68+
roi = []
69+
70+
for idx in range(0,len(bboxs['xmin'])):
71+
roi.append((0, bboxs['xmin'][idx],bboxs['ymin'][idx],bboxs['xmax'][idx],bboxs['ymax'][idx]))
72+
73+
if args.cuda:
74+
image = Variable(image.cuda())
75+
roi = Variable(torch.Tensor(roi))
76+
else:
77+
image = Variable(image)
78+
roi = Variable(torch.Tensor(roi))
79+
80+
t0 = time.time()
81+
out = net(image,roi)
82+
t1 = time.time()
83+
print('timer: %.4f sec.' % (t1 - t0))
84+
85+
id_MOS = sorted(range(len(MOS)), key=lambda k: MOS[k], reverse=True)
86+
id_out = sorted(range(len(out)), key=lambda k: out[k], reverse=True)
87+
88+
rank_of_returned_crop = []
89+
for k in range(4):
90+
rank_of_returned_crop.append(id_MOS.index(id_out[k]))
91+
92+
for k in range(4):
93+
temp_acc_4_5 = 0.0
94+
temp_acc_4_10 = 0.0
95+
for j in range(k+1):
96+
if MOS[id_out[j]] >= MOS[id_MOS[4]]:
97+
temp_acc_4_5 += 1.0
98+
if MOS[id_out[j]] >= MOS[id_MOS[9]]:
99+
temp_acc_4_10 += 1.0
100+
acc4_5[k] += temp_acc_4_5 / (k+1.0)
101+
acc4_10[k] += temp_acc_4_10 / (k+1.0)
102+
103+
for k in range(4):
104+
temp_wacc_4_5 = 0.0
105+
temp_wacc_4_10 = 0.0
106+
temp_rank_of_returned_crop = rank_of_returned_crop[:(k+1)]
107+
temp_rank_of_returned_crop.sort()
108+
for j in range(k+1):
109+
if temp_rank_of_returned_crop[j] <= 4:
110+
temp_wacc_4_5 += 1.0 * math.exp(-0.2*(temp_rank_of_returned_crop[j]-j))
111+
if temp_rank_of_returned_crop[j] <= 9:
112+
temp_wacc_4_10 += 1.0 * math.exp(-0.1*(temp_rank_of_returned_crop[j]-j))
113+
wacc4_5[k] += temp_wacc_4_5 / (k+1.0)
114+
wacc4_10[k] += temp_wacc_4_10 / (k+1.0)
115+
116+
117+
MOS_arr = []
118+
out = torch.squeeze(out).cpu().detach().numpy()
119+
for k in range(len(MOS)):
120+
MOS_arr.append(MOS[k].numpy()[0])
121+
srcc.append(spearmanr(MOS_arr,out)[0])
122+
pcc.append(pearsonr(MOS_arr,out)[0])
123+
124+
125+
for k in range(4):
126+
acc4_5[k] = acc4_5[k] / 200.0
127+
acc4_10[k] = acc4_10[k] / 200.0
128+
wacc4_5[k] = wacc4_5[k] / 200.0
129+
wacc4_10[k] = wacc4_10[k] / 200.0
130+
131+
avg_srcc = sum(srcc) / 200.0
132+
avg_pcc = sum(pcc) / 200.0
133+
134+
sys.stdout.write('[%.3f, %.3f, %.3f, %.3f] [%.3f, %.3f, %.3f, %.3f]\n' % (acc4_5[0],acc4_5[1],acc4_5[2],acc4_5[3],acc4_10[0],acc4_10[1],acc4_10[2],acc4_10[3]))
135+
sys.stdout.write('[%.3f, %.3f, %.3f, %.3f] [%.3f, %.3f, %.3f, %.3f]\n' % (wacc4_5[0],wacc4_5[1],wacc4_5[2],wacc4_5[3],wacc4_10[0],wacc4_10[1],wacc4_10[2],wacc4_10[3]))
136+
sys.stdout.write('[Avg SRCC: %.3f] [Avg PCC: %.3f]\n' % (avg_srcc,avg_pcc))
137+
138+
139+
if __name__ == '__main__':
140+
test()

0 commit comments

Comments
 (0)