-
Notifications
You must be signed in to change notification settings - Fork 3
/
main.py
121 lines (104 loc) · 3.89 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
### packages
# software packages
import string
import torch
import numpy as np
import traceback
import torch.nn as nn
import torch.optim as optim
import argparse
import os
from PIL import Image
import yaml
import logging
import time
import sys
import shutil
from sklearn.manifold import TSNE
from scipy.spatial import Voronoi, voronoi_plot_2d
from runners import *
### main.py
# (1) Load the run configuration from a YAML file (this replaces most argparse flags; there were too many of them)
def parse_args_and_config():
    """Parse the command line, load the YAML config and prepare the run.

    Side effects:
        * creates the log directory (``args.log``) if it does not exist;
        * seeds torch and numpy (and all CUDA devices, when available);
        * sets ``torch.backends.cudnn.benchmark = True``.

    Returns:
        tuple: ``(args, new_config)`` where ``args`` is the parsed
        ``argparse.Namespace`` (with ``args.log`` replaced by the derived
        log directory) and ``new_config`` is the YAML config converted to
        nested namespaces by ``dict2namespace``.

    Raises:
        OSError: if the config file cannot be opened or the log directory
            cannot be created.
        AttributeError: if the config lacks one of the sections/keys used
            to build the log directory name.
    """
    parser = argparse.ArgumentParser(description=globals()['__doc__'])
    # Dataset and save logs
    # NOTE: the value given to --log is always overwritten below by a path
    # derived from the config, so the flag currently has no effect.
    parser.add_argument('--log', default='imgs', help='Output path, including images and logs')
    parser.add_argument('--config', type=str, default='default.yml',
                        help='Name of the YAML config file inside the configs directory')
    parser.add_argument('--seed', type=int, default=1234, help='Random seed')
    parser.add_argument('--exp_mode', type=str, default='Full', help='Available: [Full, Partial, One]')
    parser.add_argument('--runner', type=str, default='Empirical', help='Available: [Empirical, Certified, Deploy]')
    # Arguments not to be touched
    parser.add_argument('--verbose', type=str, default='info', help='Verbose level: info | debug | warning | critical')
    args = parser.parse_args()

    # Parse the config file.
    # NOTE(review): yaml.Loader can construct arbitrary Python objects; only
    # load config files you trust (or confirm that yaml.safe_load suffices).
    with open(os.path.join('configs', args.config), 'r') as f:
        config = yaml.load(f, Loader=yaml.Loader)
    new_config = dict2namespace(config)

    # Build the log directory: logs/<dataset>_<attack>[...]/step_..._per=<ptb>[...]
    purification = new_config.purification
    attack = new_config.attack
    top_dir = "{}_{}".format(new_config.structure.dataset, str(attack.attack_method))
    run_dir = "step_{}_iter_{}_path_{}_per={}".format(
        purification.purify_step,
        purification.max_iter,
        purification.path_number,
        attack.ptb,
    )
    if purification.cond:
        # The guidance settings are only read (and only required) when the
        # conditional mode is switched on.
        top_dir = "{}_COND:{}".format(top_dir, purification.guide_mode)
        run_dir = "{}_{}+{}".format(
            run_dir,
            purification.guide_scale,
            purification.guide_scale_base,
        )
    args.log = os.path.join("logs", top_dir, run_dir)

    # exist_ok=True already covers the "directory is there" case, so no
    # separate existence check is needed.
    os.makedirs(args.log, exist_ok=True)

    # Set random seeds.
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(args.seed)

    # NOTE(review): cudnn.benchmark lets cuDNN auto-select kernels, which may
    # make runs non-reproducible despite the seeding above - confirm intended.
    torch.backends.cudnn.benchmark = True

    return args, new_config
def dict2namespace(config):
    """Recursively convert a dict into an ``argparse.Namespace``.

    Every key becomes an attribute. Values that are themselves dicts are
    converted recursively; any other value (including lists, even lists
    that contain dicts) is stored unchanged.

    Args:
        config: mapping whose keys are valid attribute-name strings.

    Returns:
        argparse.Namespace mirroring the structure of ``config``.
    """
    result = argparse.Namespace()
    for name, entry in config.items():
        # Only plain dict values are descended into.
        converted = dict2namespace(entry) if isinstance(entry, dict) else entry
        setattr(result, name, converted)
    return result
def main():
    """Run the runner selected by ``--runner`` and log progress to a file.

    While the runner executes, ``sys.stdout`` is redirected to
    ``<args.log>/log_progress_<config.device.rank>``; the original stream is
    restored and the file closed afterwards, even when the run fails.

    Returns:
        int: 0 if the runner finished, 1 if it raised an exception (the
        traceback is reported through ``logging.error``).
        ``KeyboardInterrupt`` and ``SystemExit`` are not swallowed; they
        propagate after the cleanup.
    """
    args, config = parse_args_and_config()
    log_path = os.path.join(args.log, f"log_progress_{config.device.rank}")

    saved_stdout = sys.stdout
    with open(log_path, "w") as log_progress:
        # Everything printed from here on goes to the progress file.
        sys.stdout = log_progress
        try:
            logging.info("Config =")
            print(">" * 80)
            print(config)
            print("<" * 80)
            # NOTE(review): eval() executes whatever expression is passed via
            # --runner; it is expected to name a runner class brought in by
            # `from runners import *`. Consider an explicit name lookup.
            runner = eval(args.runner)(args, config)
            runner.run(log_progress)
        except Exception:
            logging.error(traceback.format_exc())
            # Signal the failure to the caller/shell instead of reporting 0.
            return 1
        finally:
            # Never leave sys.stdout pointing at a file that is about to close.
            sys.stdout = saved_stdout
    return 0
# Script entry point: use main()'s return value as the process exit status.
if __name__ == '__main__':
    exit_status = main()
    sys.exit(exit_status)