# mnist_mnistm_weight.py (forked from wogong/pytorch-dann)

import os
import sys
import datetime

import torch
from torch.utils.tensorboard import SummaryWriter

sys.path.append('../')
from models.model import MNISTmodel, MNISTmodel_plain
from core.train_weight import train_dann
from utils.utils import (get_data_loader, get_data_loader_weight, init_model,
                         init_random_seed, get_dataset_root, get_model_root,
                         get_data)
# single-setting sweep; extend these lists to grid over data/run modes
for data_mode in [1]:
    for run_mode in [3]:

        class Config(object):
            # params for path
            model_name = "mnist-mnistm-weight"
            dataset_root = get_dataset_root()
            model_root = get_model_root(model_name, data_mode, run_mode)
            model_root = os.path.join(model_root, datetime.datetime.now().strftime('%m%d_%H%M%S'))
            os.makedirs(model_root, exist_ok=True)
            config = os.path.join(model_root, 'config.txt')
            finetune_flag = False
            # capture the loop variables on the class
            data_mode = data_mode
            run_mode = run_mode

            # params for the weighting/thresholding scheme (see core/train_weight.py)
            T = 0.6
            threshold = (T, T)
            soft = False
            quantile = False
            optimal = True

            # params for datasets and data loader
            batch_size = 64

            # params for source dataset
            src_dataset = "mnist"
            src_model_trained = True
            src_classifier_restore = os.path.join(model_root, src_dataset + '-source-classifier-final.pt')
            class_num_src = 31

            # params for target dataset
            tgt_dataset = "mnistm"
            tgt_model_trained = True
            dann_restore = os.path.join(model_root, src_dataset + '-' + tgt_dataset + '-dann-final.pt')

            # params for pretrain
            num_epochs_src = 100
            log_step_src = 10
            save_step_src = 50
            eval_step_src = 20

            # params for training dann
            gpu_id = '0'

            ## for digit
            num_epochs = 100
            log_step = 20
            save_step = 50
            eval_step = 1

            ## for office
            # num_epochs = 1000
            # log_step = 10  # iters
            # save_step = 500
            # eval_step = 5  # epochs

            lr_adjust_flag = 'simple'
            src_only_flag = False
            manual_seed = 0
            alpha = 0

            # params for optimizing models
            lr = 5e-4
            momentum = 0
            weight_decay = 0

            def __init__(self):
                # dump every public attribute to config.txt for reproducibility
                public_props = (name for name in dir(self) if not name.startswith('_'))
                with open(self.config, 'w') as f:
                    for name in public_props:
                        f.write(name + ': ' + str(getattr(self, name)) + '\n')
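        # Each run thus snapshots every public Config attribute into config.txt,
        # one "name: value" line per attribute in dir()'s alphabetical order,
        # e.g. (values for data_mode=1, run_mode=3):
        #   alpha: 0
        #   batch_size: 64
        #   class_num_src: 31
        #   ...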
        params = Config()
        logger = SummaryWriter(params.model_root)

        # init random seed
        init_random_seed(params.manual_seed)

        # init device
        device = torch.device("cuda:" + params.gpu_id if torch.cuda.is_available() else "cpu")

        print(data_mode, run_mode)
        source_weight, target_weight = get_data(params.data_mode)
        print(source_weight, target_weight)
        # load dataset
        if params.data_mode == 3:
            src_data_loader, num_src_train = get_data_loader_weight(
                params.src_dataset, params.dataset_root, params.batch_size,
                train=True, weights=source_weight)
            src_data_loader_eval, _ = get_data_loader_weight(
                params.src_dataset, params.dataset_root, params.batch_size,
                train=False, weights=source_weight)
        else:
            src_data_loader, num_src_train = get_data_loader_weight(
                params.src_dataset, params.dataset_root, params.batch_size, train=True)
            src_data_loader_eval = get_data_loader(
                params.src_dataset, params.dataset_root, params.batch_size, train=False)
        tgt_data_loader, num_tgt_train = get_data_loader_weight(
            params.tgt_dataset, params.dataset_root, params.batch_size,
            train=True, weights=target_weight)
        tgt_data_loader_eval, _ = get_data_loader_weight(
            params.tgt_dataset, params.dataset_root, params.batch_size,
            train=False, weights=target_weight)
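        # Aside: get_data_loader_weight (utils/utils.py) presumably draws samples
        # in proportion to the supplied weights via torch.utils.data.WeightedRandomSampler.
        # A minimal sketch of that idea; the helper name and the per-class weight
        # layout are illustrative assumptions, not the project's actual code:
        #
        #   from torch.utils.data import DataLoader, WeightedRandomSampler
        #
        #   def make_weighted_loader(dataset, per_class_weight, batch_size):
        #       # one sampling weight per example, looked up from its label
        #       sample_weights = [per_class_weight[label] for _, label in dataset]
        #       sampler = WeightedRandomSampler(sample_weights,
        #                                       num_samples=len(sample_weights))
        #       return DataLoader(dataset, batch_size=batch_size, sampler=sampler)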
        # load dann model
        dann = init_model(net=MNISTmodel(), restore=None)
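        # MNISTmodel is a DANN-style network; its defining component is a gradient
        # reversal layer between the feature extractor and the domain classifier.
        # A standard sketch of that layer (the actual implementation in
        # models/model.py may differ in detail):
        #
        #   class GradReverse(torch.autograd.Function):
        #       @staticmethod
        #       def forward(ctx, x, alpha):
        #           ctx.alpha = alpha
        #           return x.view_as(x)  # identity on the forward pass
        #
        #       @staticmethod
        #       def backward(ctx, grad_output):
        #           # negate (and scale) the gradient flowing back to the features
        #           return grad_output.neg() * ctx.alpha, None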
        # train dann model
        print("Training dann model")
        if not (dann.restored and params.dann_restore):
            dann = train_dann(dann, params, src_data_loader, tgt_data_loader,
                              tgt_data_loader_eval, num_src_train, num_tgt_train,
                              device, logger)