-
Notifications
You must be signed in to change notification settings - Fork 12
/
Copy pathlogger.py
82 lines (69 loc) · 2.61 KB
/
logger.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
# Copyright 2022 Twitter, Inc.
# SPDX-License-Identifier: Apache-2.0
import atexit
import json
import os
import os.path as osp
import time
class LOG_DATA:
    """Shared logger state used by the module-level logging functions.

    Every field is a class attribute; the functions in this module read and
    write them through the class itself (``LOG_DATA.<field>``).
    """

    # Output locations, set by configure_output_dir().
    base_output_dir = None
    output_weights = None
    # Despite its name, this holds the open log file object (see configure_output_dir).
    output_file_name = None

    # Tabular bookkeeping for the row currently being built.
    headers = list()           # column names collected while first_row is True
    current_row_data = dict()  # key -> value logged for the current row
    first_row = True

    # Reset to 0 by configure_output_dir(); not read by the visible code.
    save_it = 0
def configure_output_dir(d=None, force=False):
    """Create the output directory, open its log file, and reset row state.

    Creates ``<dir>/weights`` (and ``<dir>`` itself if needed), opens
    ``<dir>/log.txt`` for writing, and resets the tabular state held in
    ``LOG_DATA`` so a new run starts with a fresh header line.

    Args:
        d: Target directory. When falsy, defaults to
            ``experiments_data/temp/<current unix time>``.
        force: When False, the directory must not already exist. When True,
            an existing directory (including an existing ``weights``
            subfolder) is reused and ``log.txt`` is overwritten.

    Raises:
        AssertionError: if ``force`` is False and the directory already exists.
        FileExistsError: if ``force`` is False, asserts are disabled (``-O``),
            and the weights directory already exists.
    """
    LOG_DATA.base_output_dir = d or "experiments_data/temp/{}".format(int(time.time()))
    if not force:
        assert not osp.exists(LOG_DATA.base_output_dir), \
            "output directory already exists: {}".format(LOG_DATA.base_output_dir)
    LOG_DATA.output_weights = osp.join(LOG_DATA.base_output_dir, "weights")
    # With force=True the directory may already be populated, so tolerate an
    # existing weights folder; without force, keep makedirs strict so a reused
    # directory is still rejected even when the assert above is stripped.
    os.makedirs(LOG_DATA.output_weights, exist_ok=force)
    # Reconfiguring must not leak the log file opened by a previous call.
    # (Its atexit close stays registered; closing twice is harmless.)
    previous = LOG_DATA.output_file_name
    if previous is not None and not previous.closed:
        previous.close()
    LOG_DATA.output_file_name = open(osp.join(LOG_DATA.base_output_dir, "log.txt"), 'w')
    # registering a function to be executed at termination
    atexit.register(LOG_DATA.output_file_name.close)
    LOG_DATA.first_row = True
    LOG_DATA.save_it = 0
    LOG_DATA.headers.clear()
    LOG_DATA.current_row_data.clear()
    print("Logging data to directory {}".format(LOG_DATA.output_file_name.name))
def save_params(params):
    """Write ``params`` as pretty-printed JSON to ``<base_output_dir>/params.json``.

    Expects configure_output_dir() to have set ``LOG_DATA.base_output_dir``.

    Raises:
        TypeError / ValueError: if ``params`` cannot be serialized by
            ``json.dumps``. In that case the target file is never opened, so
            an existing params.json is left untouched.
    """
    # Serialize before opening: mode 'w' truncates the file immediately, so a
    # serialization failure after opening would leave an empty params.json.
    text = json.dumps(params, indent=2, separators=(',', ': '))
    with open(osp.join(LOG_DATA.base_output_dir, 'params.json'), 'w') as out:
        out.write(text)
def load_params(dir):
    """Read ``<dir>/params.json`` and return the decoded JSON value.

    The parameter keeps the name ``dir`` (shadowing the builtin) so existing
    keyword callers keep working.
    """
    params_path = osp.join(dir, "params.json")
    with open(params_path, 'r') as handle:
        return json.load(handle)
def log_key_val(key, value):
    """Store ``value`` under column ``key`` for the row currently being built.

    While the first row is being built, each new key is appended to the
    header list; on later rows the key must already be a known header.

    Raises:
        AssertionError: if ``key`` was already recorded for this row, or (after
            the first row) if ``key`` is not an existing header.
    """
    row = LOG_DATA.current_row_data
    assert key not in row, "key already recorded {}".format(key)
    if not LOG_DATA.first_row:
        assert key in LOG_DATA.headers, "key not present in headers: {}".format(key)
    else:
        LOG_DATA.headers.append(key)
    row[key] = value
def log_iteration():
    """Print the current row as a console table and append it to the log file.

    Values that expose ``__float__`` are shown with ``%8.3g``; other values are
    shown as-is. A header with no value in this row is shown (and written) as
    an empty string. If a log file is configured, the header line is written
    before the first row, then the row is written tab-separated and the file
    is flushed. Afterwards the row buffer is cleared and the header list is
    frozen: later rows may only log keys that already exist.
    """
    vals = []
    # default=0 keeps an empty header list (nothing logged yet) from raising
    # ValueError in max(); the key column is always at least 15 chars wide.
    max_key_len = max(15, max((len(key) for key in LOG_DATA.headers), default=0))
    keystr = '%' + '%d' % max_key_len
    fmt = "| " + keystr + "s = %15s |"
    # "| " (2) + key + " = " (3) + 15-char value + " |" (2) == 22 + max_key_len.
    n_slashes = 22 + max_key_len
    print("+" * n_slashes)
    for key in LOG_DATA.headers:
        val = LOG_DATA.current_row_data.get(key, "")
        if hasattr(val, "__float__"):
            valstr = "%8.3g" % val
        else:
            valstr = val
        print(fmt % (key, valstr))
        vals.append(val)
    print("+" * n_slashes)
    if LOG_DATA.output_file_name is not None:
        if LOG_DATA.first_row:
            LOG_DATA.output_file_name.write("\t".join(LOG_DATA.headers))
            LOG_DATA.output_file_name.write("\n")
        LOG_DATA.output_file_name.write("\t".join(map(str, vals)))
        LOG_DATA.output_file_name.write("\n")
        LOG_DATA.output_file_name.flush()
    # Freeze headers after the first row whether or not a file is configured;
    # previously a console-only logger re-appended every key on each row,
    # producing duplicate, ever-growing table columns.
    LOG_DATA.first_row = False
    LOG_DATA.current_row_data.clear()