forked from perrying/realistic-ssl-evaluation-pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
config.py
77 lines (76 loc) · 1.48 KB
/
config.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
from lib.datasets import svhn, cifar10
import numpy as np
# Settings exposed under the "shared" key of the master `config` mapping.
# NOTE(review): the key names suggest iteration counts (total, warmup,
# LR-decay point), an LR decay multiplier and a mini-batch size -- confirm
# against the training script that consumes them.
shared_config = {
    "iteration": 500000,
    "warmup": 200000,
    "lr_decay_iter": 400000,
    "lr_decay_factor": 0.2,
    "batch_size": 100,
}
### dataset ###
# SVHN dataset settings.
# "transform" holds three on/off flags, in this order:
# flip, random crop, gaussian noise.
svhn_config = {
    "transform": [False, True, False],
    "dataset": svhn.SVHN,
    "num_classes": 10,
}
# CIFAR-10 dataset settings.
# "transform" is a list of three on/off flags -- presumably in the same
# (flip, random crop, gaussian noise) order used by svhn_config.
cifar10_config = {
    "transform": [True, True, True],
    "dataset": cifar10.CIFAR10,
    "num_classes": 10,
}
### algorithm ###
# Virtual adversarial training (VAT) hyperparameters.
# "eps" is a per-dataset mapping whose keys match the dataset keys of the
# master `config` ("cifar10", "svhn").
vat_config = {
    "xi": 1e-6,
    "eps": {"cifar10": 6, "svhn": 1},
    "consis_coef": 0.3,
    "lr": 3e-3,
}
# Pseudo-label (PL) hyperparameters.
# NOTE(review): the key "threashold" is misspelled, but it is a runtime
# lookup key -- presumably read by that exact string elsewhere, so do not
# rename it without updating every consumer.
pl_config = {
    "threashold": 0.95,
    "lr": 3e-4,
    "consis_coef": 1,
}
# Mean teacher (MT) hyperparameters.
mt_config = {
    "ema_factor": 0.95,
    "lr": 4e-4,
    "consis_coef": 8,
}
# Pi Model hyperparameters.
pi_config = {
    "lr": 3e-4,
    "consis_coef": 20.0,
}
# Interpolation consistency training (ICT) hyperparameters.
ict_config = {
    "ema_factor": 0.999,
    "lr": 4e-4,
    "consis_coef": 100,
    "alpha": 0.1,
}
# MixMatch (MM) hyperparameters.
mm_config = {
    "lr": 3e-3,
    "consis_coef": 100,
    "alpha": 0.75,
    "T": 0.5,
    "K": 2,
}
# Settings for the "supervised" entry of the master `config`; only a
# learning rate is defined here.
supervised_config = {
    "lr": 3e-3,
}
### master ###
# Master lookup table: maps a string key to the corresponding settings
# dict defined above -- one entry for the shared settings, one per dataset
# (lower-case keys) and one per algorithm.
config = {
    # shared settings
    "shared": shared_config,
    # datasets
    "svhn": svhn_config,
    "cifar10": cifar10_config,
    # algorithms
    "VAT": vat_config,
    "PL": pl_config,
    "MT": mt_config,
    "PI": pi_config,
    "ICT": ict_config,
    "MM": mm_config,
    "supervised": supervised_config,
}