-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathutils.py
139 lines (112 loc) · 3.78 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
import os
import time
import random
import numpy as np
import torch
class BlackHole(object):
    """Null-object sink: swallows attribute writes, and returns itself
    for attribute reads and calls, so it can stand in for any logger or
    writer without doing anything."""

    def __setattr__(self, name, value):
        # Discard every attribute assignment.
        pass

    def __getattr__(self, name):
        # Any attribute access yields the sink itself, keeping chains alive.
        return self

    def __call__(self, *args, **kwargs):
        # Calling the sink is a chainable no-op.
        return self
def get_new_log_dir(root="./logs", prefix="", postfix=""):
    """Create and return a timestamped log directory under *root*.

    The directory name embeds the current local time between the
    optional *prefix* and *postfix*. ``os.makedirs`` is called without
    ``exist_ok``, so an already-existing directory raises — this
    guarantees each call yields a fresh directory.
    """
    stamp = time.strftime("%Y_%m_%d__%H_%M_%S", time.localtime())
    path = os.path.join(root, prefix + stamp + postfix)
    os.makedirs(path)
    return path
class CheckpointManager(object):
    """Index and persist model checkpoints under *save_dir*.

    Checkpoint files are named ``ckpt_<score>_<iteration>.pt``. Existing
    files are indexed at construction time so best/latest lookups work
    across runs. Lower scores are treated as better (see
    ``get_best_ckpt_idx``).
    """

    def __init__(self, save_dir, logger=BlackHole()):
        super().__init__()
        os.makedirs(save_dir, exist_ok=True)
        self.save_dir = save_dir
        self.ckpts = []
        self.logger = logger
        for f in os.listdir(self.save_dir):
            if f[:4] != "ckpt":
                continue
            try:
                _, score, it = f.split("_")
                it = it.split(".")[0]
                self.ckpts.append(
                    {
                        "score": float(score),
                        "file": f,
                        # save(step=None) leaves the iteration field empty
                        # ("ckpt_<score>_.pt"); treat those as iteration 0
                        # instead of crashing on int("").
                        "iteration": int(it) if it else 0,
                    }
                )
            except ValueError:
                # Skip files that merely start with "ckpt" but do not
                # follow the expected naming scheme, rather than aborting
                # construction.
                continue

    def get_worst_ckpt_idx(self):
        """Index of the checkpoint with the highest (worst) score, or None."""
        idx = -1
        worst = float("-inf")
        for i, ckpt in enumerate(self.ckpts):
            if ckpt["score"] >= worst:
                idx = i
                worst = ckpt["score"]
        return idx if idx >= 0 else None

    def get_best_ckpt_idx(self):
        """Index of the checkpoint with the lowest (best) score, or None."""
        idx = -1
        best = float("inf")
        for i, ckpt in enumerate(self.ckpts):
            if ckpt["score"] <= best:
                idx = i
                best = ckpt["score"]
        return idx if idx >= 0 else None

    def get_latest_ckpt_idx(self):
        """Index of the checkpoint with the largest iteration, or None."""
        idx = -1
        latest_it = -1
        for i, ckpt in enumerate(self.ckpts):
            if ckpt["iteration"] > latest_it:
                idx = i
                latest_it = ckpt["iteration"]
        return idx if idx >= 0 else None

    def save(self, model, args, score, others=None, step=None):
        """Serialize *model* (with *args* and *others*) and index it.

        Returns True. The in-memory index entry records the iteration as
        well (step, or 0 when step is None) — previously it was omitted,
        making ``get_latest_ckpt_idx`` raise KeyError after any save.
        """
        if step is None:
            fname = "ckpt_%.6f_.pt" % float(score)
        else:
            fname = "ckpt_%.6f_%d.pt" % (float(score), int(step))
        path = os.path.join(self.save_dir, fname)
        torch.save(
            {
                "args": args,
                "state_dict": model.state_dict(),
                "others": others,
            },
            path,
        )
        self.ckpts.append(
            {
                "score": float(score),
                "file": fname,
                "iteration": int(step) if step is not None else 0,
            }
        )
        return True

    def load_best(self):
        """Load the checkpoint with the lowest score.

        Raises IOError when no checkpoints are indexed.
        """
        idx = self.get_best_ckpt_idx()
        if idx is None:
            raise IOError("No checkpoints found.")
        ckpt = torch.load(os.path.join(self.save_dir, self.ckpts[idx]["file"]))
        return ckpt

    def load_latest(self):
        """Load the checkpoint with the largest iteration.

        Raises IOError when no checkpoints are indexed.
        """
        idx = self.get_latest_ckpt_idx()
        if idx is None:
            raise IOError("No checkpoints found.")
        ckpt = torch.load(os.path.join(self.save_dir, self.ckpts[idx]["file"]))
        return ckpt

    def load_selected(self, file):
        """Load an explicitly named checkpoint file from save_dir."""
        ckpt = torch.load(os.path.join(self.save_dir, file))
        return ckpt
def seed_all(seed=42):
    """Seed every RNG used by the project for reproducible runs.

    Covers Python's ``random``, hash randomization, NumPy, and PyTorch
    (CPU and all CUDA devices), and puts cuDNN into deterministic mode.
    """
    random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Seed every visible GPU, not just the current device (the original
    # manual_seed covered one device only). Safe without CUDA: seeding
    # is deferred until CUDA is actually initialized.
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.enabled = True
def count_parameters(model):
    """Return the number of trainable (requires_grad) parameters in *model*."""
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total
# assumes : batch, particles, feats
def normalize_tensor(tensor, mean, std, sigma=1):
    """Standardize each feature channel (last axis) in place.

    Feature i becomes ``(x - mean[i]) / (std[i] / sigma)``. Mutates
    *tensor* and returns it.
    """
    for idx, m in enumerate(mean):
        tensor[..., idx] = (tensor[..., idx] - m) / (std[idx] / sigma)
    return tensor
def inverse_normalize_tensor(tensor, mean, std, sigma=1):
    """Undo normalize_tensor: restore each feature channel in place.

    Feature i becomes ``x * (std[i] / sigma) + mean[i]``. Mutates
    *tensor* and returns it.
    """
    for idx, m in enumerate(mean):
        tensor[..., idx] = tensor[..., idx] * (std[idx] / sigma) + m
    return tensor