-
Notifications
You must be signed in to change notification settings - Fork 0
/
utils.py
127 lines (98 loc) · 3.69 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
import json
import logging
import pathlib
import signal
import subprocess
import sys
import time
from datetime import datetime
from typing import List, Optional
# Unix time in whole seconds, captured once when this module is loaded and
# treated as the start of the experiment.
TIMESTAMP = int(time.time()) # Timestamp at experiment start
# Directory of the current experiment; None until init_experiment() assigns it.
EXPERIMENT_DIR = None
def _run_subprocess(command: List[str]):
    """Run ``command`` directly (no shell) and return its stripped stdout.

    Args:
        command: Program name followed by its arguments, handed to
            ``subprocess.check_output`` as an argument list.

    Returns:
        The command's standard output decoded as text, with leading and
        trailing whitespace removed.

    Raises:
        subprocess.CalledProcessError: If the command exits with a non-zero
            status.
    """
    raw_output = subprocess.check_output(command)
    # Decode as UTF-8 instead of ASCII: output such as a git diff may contain
    # non-ASCII text, which previously raised UnicodeDecodeError. Bytes that
    # are not valid UTF-8 are replaced rather than aborting the caller.
    return raw_output.decode("utf-8", errors="replace").strip()
def get_hostname():
    """Return the host name printed by the ``hostname`` command."""
    hostname_cmd = ["hostname"]
    return _run_subprocess(hostname_cmd)
def get_git_revision_hash():
    """Return the commit hash that ``HEAD`` currently resolves to."""
    rev_parse_head = ["git", "rev-parse", "HEAD"]
    return _run_subprocess(rev_parse_head)
def get_git_diff():
    """Return the diff against ``HEAD``, excluding notebooks, shell scripts
    and JSON files."""
    excluded_pathspecs = [":!*.ipynb", ":!*.sh", ":!*.json"]
    return _run_subprocess(["git", "diff", "HEAD", *excluded_pathspecs])
def get_git_root():
    """Return the top-level directory of the enclosing git repository."""
    show_toplevel = ["git", "rev-parse", "--show-toplevel"]
    return _run_subprocess(show_toplevel)
def log_uncaught_errors():
    """Install a ``sys.excepthook`` that logs uncaught exceptions.

    ``KeyboardInterrupt`` is forwarded to the interpreter's default hook;
    every other exception is reported through ``logging.error`` with its
    traceback attached.
    """

    def _report(exc_type, exc_value, exc_traceback):
        details = (exc_type, exc_value, exc_traceback)
        if not issubclass(exc_type, KeyboardInterrupt):
            logging.error("Uncaught exception", exc_info=details)
        else:
            sys.__excepthook__(*details)

    sys.excepthook = _report
def log_replication_info():
    """Log and save the information needed to reproduce this run.

    Collects the host name, working directory, command line, git commit hash
    and git diff, writes each entry to the log, and stores all of them via
    ``save_experiment_json`` under the handle ``"replication_info"``.
    """
    import shlex  # local import: only this function needs it

    replication_info = {
        "hostname": get_hostname(),
        "pwd": str(pathlib.Path.cwd()),
        # shlex.join quotes arguments that contain whitespace or shell
        # metacharacters, so the recorded command can be pasted back into a
        # shell; arguments made of safe characters are left unchanged.
        "Command": f"python3 {shlex.join(sys.argv)}",
        "Git hash": get_git_revision_hash(),
        "Git diff": get_git_diff(),
    }
    for key, value in replication_info.items():
        # Multi-line values (e.g. a non-trivial diff) start on their own line.
        separator = "\n" if "\n" in value else " "
        logging.info(f"{key}:{separator}{value}")
    save_experiment_json("replication_info", replication_info)
def log_to_stdout():
    """Set up INFO-level logging to stdout and log uncaught exceptions.

    Calls ``logging.basicConfig`` with a single stdout stream handler, then
    installs the uncaught-exception hook.
    """
    stdout_handler = logging.StreamHandler(sys.stdout)
    logging.basicConfig(
        handlers=[stdout_handler],
        level=logging.INFO,
        style="{",
        format="{asctime} {levelname} [{filename:s}:{lineno:d}] {message}",
        datefmt="%Y-%m-%d %H:%M:%S",
    )
    log_uncaught_errors()
def init_experiment(handle: str, experiments: Optional[pathlib.Path] = None):
    """Create the directory for this run and set up logging into it.

    The run directory is ``experiments/<handle>/<YYYYMMDD>-<TIMESTAMP>``,
    where the date is the UTC date of ``TIMESTAMP``. Log records go both to
    ``default.log`` inside that directory and to stdout. Replication info is
    logged and saved, and ``experiments/_latest`` is re-pointed at the new
    run directory.

    Args:
        handle: Experiment name; used as the sub-directory of ``experiments``.
        experiments: Root directory holding all experiments. Defaults to
            ``<git root>/experiments``.

    Raises:
        AssertionError: If an experiment has already been initialized.
        FileExistsError: If the run directory already exists.
    """
    global EXPERIMENT_DIR
    if experiments is None:
        experiments = pathlib.Path(get_git_root()) / "experiments"
    assert EXPERIMENT_DIR is None, "init_experiment() may only be called once"

    # UTC calendar date of the start timestamp. time.gmtime replaces the
    # deprecated datetime.utcfromtimestamp(...).astimezone() chain, which
    # yielded the same UTC date by way of a naive datetime.
    date = time.strftime("%Y%m%d", time.gmtime(TIMESTAMP))
    experiment_dir = experiments / handle / f"{date}-{TIMESTAMP}"
    experiment_dir.mkdir(parents=True)
    # Publish the directory only once it exists, so a failed mkdir does not
    # leave the module marked as initialized.
    EXPERIMENT_DIR = experiment_dir

    logfile = EXPERIMENT_DIR / "default.log"
    handlers = [
        logging.FileHandler(filename=logfile),
        logging.StreamHandler(sys.stdout),
    ]
    logging.basicConfig(
        level=logging.INFO,
        format="{asctime} {levelname} [{filename:s}:{lineno:d}] {message}",
        datefmt="%Y-%m-%d %H:%M:%S",
        handlers=handlers,
        style="{",
    )
    logging.info(f"Logging to {logfile}")
    # Install the hook before gathering replication info so that a failure
    # there (e.g. a git command erroring out) is written to the log as well.
    log_uncaught_errors()
    log_replication_info()

    # Symlink to latest experiment
    latest = experiments / "_latest"
    latest.unlink(missing_ok=True)
    latest.symlink_to(EXPERIMENT_DIR.relative_to(experiments))
def save_experiment_json(handle: str, data, partial: bool = False):
    """Write ``data`` as indented JSON into the current experiment directory.

    Args:
        handle: Base name of the output file, without the ``.json`` suffix.
        data: JSON-serializable object to store.
        partial: If true, write ``{handle}.partial.json``. Otherwise write
            ``{handle}.json`` and then delete ``{handle}.partial.json`` if it
            exists.

    Raises:
        AssertionError: If no experiment directory has been set yet.
        TypeError: If ``data`` cannot be serialized to JSON; in that case no
            file is written or removed.
    """
    assert EXPERIMENT_DIR is not None, "call init_experiment() first"
    stale_partial = None
    if partial:
        # Log partial results as "{handle}.partial.json"
        handle = f"{handle}.partial"
    else:
        stale_partial = EXPERIMENT_DIR / f"{handle}.partial.json"
    filename = EXPERIMENT_DIR / f"{handle}.json"
    # Serialize before touching the filesystem, so a non-serializable payload
    # cannot leave a truncated file behind.
    payload = json.dumps(data, indent=4)
    with open(filename, "w") as f:
        f.write(payload)
    if stale_partial is not None:
        # Drop the partial file only after the final one is safely written;
        # otherwise a failed save would lose the partial results too.
        stale_partial.unlink(missing_ok=True)
    logging.info(f"Saved experiment {handle} to {filename}")