-
Notifications
You must be signed in to change notification settings - Fork 0
/
tintml.py
102 lines (78 loc) · 3.18 KB
/
tintml.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
from rich.progress import BarColumn, TimeRemainingColumn, Progress, TextColumn
from rich.console import Console
import math
import os
# Presumably meant to silence TensorFlow's C++ log output (level "3" being the
# most restrictive); NOTE(review): this likely only has an effect if it runs
# before TensorFlow is imported elsewhere — confirm import order.
os.environ.update(TF_CPP_MIN_LOG_LEVEL='3')
class Tint:
    """Small convenience layer over a rich ``Console``.

    Provides plain and heading printing, a status spinner, a metric read-out
    that colours each value by how it changed since the previous call, and a
    progress-bar wrapper for iterables.
    """

    def __init__(self):
        self.console = Console()
        # Last reported value per metric name; print_metrics compares against it.
        self.metric_dict = {}
        # Length of the longest metric name seen so far (multi-line alignment).
        self.longest_metric = 0
        # Expose the console's log method directly as ``Tint.log``.
        self.log = self.console.log

    def print(self, something, end='\n'):
        """Print *something* (rich markup is interpreted) through the console."""
        self.console.print(something, end=end)

    def printh(self, title: str):
        """Print *title* as a bold yellow heading, preceded by a blank line."""
        self.console.print(f"\n{title}", style='bold yellow')

    def status(self, title: str):
        """Return the console's status context manager showing *title* in green."""
        return self.console.status(f"[green]{title}")

    def print_metrics(self, metric_dict, low_is_better, multi_line=False):
        """Print each metric, coloured by how it changed since the last call.

        Args:
            metric_dict: mapping of metric name -> numeric value.
            low_is_better: either a single bool applied to every metric, or a
                sequence of bools aligned with ``metric_dict``'s iteration
                order. True means a decrease counts as an improvement.
            multi_line: if True, print one name-aligned metric per line;
                otherwise print all metrics on one line, then end the line.

        A metric is shown in white with no arrow when it is new or within a
        relative tolerance of 1e-3 of its previous value, in green with an
        arrow when it improved, and in red with an arrow when it got worse.
        The new values then replace the stored ones.
        """
        # Widen the alignment column for *every* name up front. Updating it
        # key-by-key inside the loop padded earlier rows of the same call
        # against a too-short width, leaving multi-line output ragged.
        widest = max(map(len, metric_dict), default=0)
        self.longest_metric = max(self.longest_metric, widest)

        # NOTE(review): names are interpolated into rich markup unescaped, so a
        # name containing e.g. "[train]" may be parsed as a markup tag.
        for index, (name, value) in enumerate(metric_dict.items()):
            color, arrow = self._classify_change(name, value, low_is_better, index)
            if multi_line:
                padding = " " * (self.longest_metric + 1 - len(name))
                self.print(f"{name}:{padding}[{color}]{value:.3f} {arrow}[/{color}]")
            else:
                self.print(f"{name}: [{color}]{value:.3f} {arrow}[/{color}]", end=' ')
            # This value becomes the baseline for the next call.
            self.metric_dict[name] = value
        if not multi_line:
            # Terminate the single-line read-out through the console itself;
            # the builtin print() would bypass the console's output target.
            self.console.print()

    def _classify_change(self, name, value, low_is_better, index):
        """Return the (colour, arrow) pair describing *value* vs. the stored one.

        ``low_is_better`` is only consulted (and, if a sequence, only indexed)
        when an actual change has to be judged, so new or unchanged metrics
        never touch it.
        """
        if name not in self.metric_dict or math.isclose(
                value, self.metric_dict[name], rel_tol=1e-3):
            # New metric, or effectively unchanged: neutral colour, blank arrow.
            return "white", " "
        lower_better = (low_is_better if isinstance(low_is_better, bool)
                        else low_is_better[index])
        improved = self.compare(value, self.metric_dict[name], lower_better)
        return ("green" if improved else "red"), self.get_arrow(improved, lower_better)

    def get_metrics(self):
        """Return the stored last-seen metric values (the live dict, not a copy)."""
        return self.metric_dict

    def compare(self, val1, val2, low_is_better):
        """Return True if *val1* is strictly better than *val2*.

        "Better" means smaller when *low_is_better* is truthy, larger otherwise.
        """
        return val1 < val2 if low_is_better else val1 > val2

    def get_arrow(self, improvement: bool, low_is_better: bool):
        """Return the emoji code for the direction the metric moved.

        The value went down exactly when "it improved" and "lower is better"
        agree, so the arrow points down-right then and up-right otherwise.
        """
        went_down = bool(improvement) == bool(low_is_better)
        return ":arrow_lower_right:" if went_down else ":arrow_upper_right:"

    def iter(
        self,
        iterable,
        label: str = "Working...",
        steps: "int | None" = None,
    ):
        """Yield the items of *iterable* while showing a progress bar.

        Args:
            iterable: items to yield, unchanged.
            label: description shown to the left of the bar.
            steps: total number of items, forwarded to rich as ``total``
                (rich falls back to ``len(iterable)`` when it is None).

        The bar shows completed/total, percentage and estimated time remaining,
        and is removed from the terminal once iteration ends (``transient``).
        """
        progress = Progress(
            "[progress.description]{task.description}",
            "{task.completed}/{task.total}",
            BarColumn(),
            "[progress.percentage]{task.percentage:>3.0f}%",
            TimeRemainingColumn(),
            transient=True,
        )
        with progress:
            yield from progress.track(
                iterable, total=steps, description=label, update_period=0.1
            )