-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathadversarial_random.py
84 lines (70 loc) · 2.99 KB
/
adversarial_random.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
from __future__ import print_function
import argparse
import torch
import torch.nn as nn
import torch.optim as optim
import random
from utils.misc import *
from utils.adapt_helpers import *
from utils.rotation import rotate_batch
from utils.model import resnet18
from utils.train_helpers import normalize
from utils.test_helpers import test
import matplotlib.pyplot as plt
def imshow(img):
    """Display a 3-D image tensor with matplotlib.

    The tensor is mapped back with ``img / 2 + 0.5`` before display.

    NOTE(review): that inverse assumes the image was normalized with
    mean 0.5 / std 0.5 per channel; the normalization used by this
    project is not visible here, so confirm before trusting the colours.

    Args:
        img: a torch tensor laid out (channels, height, width). It may live
            on any device and may require grad.
    """
    img = img / 2 + 0.5  # unnormalize
    # Tensor.numpy() raises on CUDA tensors and on tensors that require
    # grad, so detach from autograd and copy to host memory first.
    npimg = img.detach().cpu().numpy()
    # CHW -> HWC, the layout plt.imshow expects for colour images.
    plt.imshow(npimg.transpose(1, 2, 0))
    plt.show()
# Run on the GPU when torch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# The ten CIFAR-10 class names (not referenced in the rest of this script).
classes = (
    'plane', 'car', 'bird', 'cat', 'deer',
    'dog', 'frog', 'horse', 'ship', 'truck',
)
# Command-line interface. The whole namespace is handed to
# prepare_test_data(args) further down, so flags not read directly in this
# script (--level, --corruption, --dataroot, --shared, --online, --shuffle,
# --dset_size) are presumably consumed there.
parser = argparse.ArgumentParser()

# ---- data selection ----
parser.add_argument('--level', type=int, default=0)
parser.add_argument('--corruption', default='original')
parser.add_argument('--dataroot', default='data/CIFAR-10-C/')
parser.add_argument('--shared', default=None)

# ---- network / batching ----
# NOTE(review): --depth is parsed but the script always builds resnet18,
# so changing it has no visible effect.
parser.add_argument('--depth', type=int, default=18)
parser.add_argument('--group_norm', type=int, default=32)  # groups per GroupNorm
parser.add_argument('--batch_size', type=int, default=32)

# ---- adaptation ----
parser.add_argument('--lr', type=float, default=0.001)
parser.add_argument('--niter', type=int, default=1)
parser.add_argument('--online', action='store_true')
parser.add_argument('--shuffle', action='store_true')
parser.add_argument('--threshold', type=float, default=1)
parser.add_argument('--dset_size', type=int, default=0)

# ---- checkpoint / output / run length ----
parser.add_argument('--resume', default=None)  # directory holding best.pth
parser.add_argument('--outf', default='.')
parser.add_argument('--epochs', type=int, default=10)  # attack iterations

args = parser.parse_args()
# Nudge the threshold up by 1e-3 "to correct for numeric errors"; whatever
# consumes args.threshold is not visible in this script.
args.threshold = args.threshold + 0.001
# Project helper; presumably ensures the output directory exists.
my_makedir(args.outf)

# Let cuDNN benchmark and cache the fastest convolution algorithms.
from torch.backends import cudnn
cudnn.benchmark = True
def gn_helper(planes):
    """Return a GroupNorm layer normalizing ``planes`` channels.

    The number of groups comes from the ``--group_norm`` flag and is read
    at call time. torch requires ``planes`` to be divisible by it. This
    function is handed to resnet18 as its ``norm_layer`` factory.
    """
    groups = args.group_norm
    layer = nn.GroupNorm(groups, planes)
    return layer
# Build the GroupNorm ResNet-18 with 10 outputs and move it to `device`.
norm_layer = gn_helper
net = resnet18(num_classes=10, norm_layer=norm_layer).to(device)
# Wrap in DataParallel *before* loading weights. Presumably the checkpoint
# was saved from a DataParallel model (keys prefixed 'module.'), so keep
# this order.
net = torch.nn.DataParallel(net)

# Without --resume the path below would become 'None/best.pth'; fail early
# with a usage message instead of an obscure FileNotFoundError.
if args.resume is None:
    parser.error('--resume is required: pass the directory containing best.pth')
print('Resuming from %s...' % (args.resume))
# map_location lets a checkpoint saved during a GPU run load on a CPU-only
# host (and places tensors on the device this run actually uses).
ckpt = torch.load('%s/best.pth' % (args.resume), map_location=device)
net.load_state_dict(ckpt['net'])

criterion = nn.CrossEntropyLoss().to(device)
optimizer = optim.SGD(net.parameters(), lr=args.lr)
_, teloader = prepare_test_data(args)

# "Random attack": repeatedly adapt the network on a single image of
# Gaussian noise, then measure the resulting test error.
print("Random Attack...")
# Per-channel (mean, std) of the noise, in R, G, B order. These appear to
# be the commonly quoted CIFAR-10 channel statistics.
noise_stats = ((0.4914, 0.2023), (0.4822, 0.1994), (0.4465, 0.2010))
for i in range(args.epochs):
    # One 3x32x32 noise image: each channel is drawn independently (same
    # R, G, B draw order as before), then clipped to [0, 1].
    channels = [np.random.normal(mu, sigma, (32, 32)) for mu, sigma in noise_stats]
    img = np.stack(channels, axis=0)
    tensor = torch.from_numpy(img).float()
    tensor = torch.clamp(tensor, 0, 1)
    adapt_single_tensor(net, tensor, optimizer, criterion, args.niter, args.batch_size)
    if i % 50 == 49:
        # Progress relative to the configured run length. The total was
        # previously hard-coded as 5000, which misreported any other
        # --epochs value.
        print("%d%%" % ((i + 1) * 100 / args.epochs))
    # NOTE(review): the original file lost its indentation. The test pass
    # is assumed to run after every adaptation step (one line per epoch);
    # confirm it was not meant to sit under the progress check above.
    err_cls, correct_per_cls, total_per_cls = test(teloader, net, verbose=True, print_freq=0)
    print("Epoch %d Test error: %.3f" % (i, err_cls))