-
Notifications
You must be signed in to change notification settings - Fork 1
/
label.py
46 lines (36 loc) · 1.2 KB
/
label.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
import typing
import granularity
class Label(typing.Hashable):
    """A class label identified by its granularity and its name.

    Equality and hashing depend only on ``g`` and ``l_str``. ``index`` is
    stored on the instance but does not take part in identity, so two labels
    with the same granularity and name compare equal even if their indices
    differ.
    """

    def __init__(self,
                 l_str: str,
                 index: int):
        """Create a label.

        Args:
            l_str: The label's name; also what ``str(label)`` returns.
            index: An integer index stored as ``self.index`` (subclasses in
                this module set it to the label's position in its class list).
        """
        self.l_str = l_str
        self.index = index
        # The base class has no granularity; subclasses assign ``self.g``.
        self.g = None

    def __str__(self):
        return self.l_str

    def __repr__(self):
        return (f'{type(self).__name__}(g={self.g!r}, '
                f'l_str={self.l_str!r}, index={self.index!r})')

    def _identity_key(self):
        """Return the tuple that defines this label's identity.

        ``str(self.g)`` keeps the original hashing's reliance on the
        granularity's string form, so ``g`` need not be hashable itself.
        A tuple (rather than a joined string) avoids ambiguity when either
        part contains the separator character.
        """
        return str(self.g), self.l_str

    def __hash__(self):
        return hash(self._identity_key())

    def __eq__(self, other):
        # Compare actual identity keys instead of hash values: equal hashes do
        # not imply equal objects, and non-Label objects (possibly unhashable)
        # must not compare equal or raise.
        if not isinstance(other, Label):
            return NotImplemented
        return self._identity_key() == other._identity_key()
class FineGrainLabel(Label):
    """A label drawn from the fine-grained class list."""

    def __init__(self,
                 g: granularity.Granularity,
                 l_str: str,
                 fine_grain_classes_str: typing.List[str]):
        """Create a fine-grain label.

        Args:
            g: The granularity object, stored as ``self.g``.
            l_str: The label's name; must appear in ``fine_grain_classes_str``.
            fine_grain_classes_str: The ordered list of fine-grain class
                names; the label's ``index`` is the position of the first
                occurrence of ``l_str`` in it.

        Raises:
            ValueError: If ``l_str`` is not in ``fine_grain_classes_str``.
        """
        # Validate before looking up the index. Checking afterwards is
        # unreachable (``list.index`` raises first), and an ``assert`` would be
        # stripped under ``python -O``. ValueError matches what ``.index``
        # raises, so callers catching it are unaffected.
        if l_str not in fine_grain_classes_str:
            raise ValueError(
                f'{l_str!r} is not in the fine-grain class list')
        super().__init__(l_str=l_str,
                         index=fine_grain_classes_str.index(l_str))
        self.g_str = 'fine'
        # Set after super().__init__, which initialises ``g`` to None.
        self.g = g
class CoarseGrainLabel(Label):
    """A label drawn from the coarse-grained class list."""

    def __init__(self,
                 g: granularity.Granularity,
                 l_str: str,
                 coarse_grain_classes_str: typing.List[str],
                 ):
        """Create a coarse-grain label.

        Args:
            g: The granularity object, stored as ``self.g``.
            l_str: The label's name; must appear in
                ``coarse_grain_classes_str``.
            coarse_grain_classes_str: The ordered list of coarse-grain class
                names; the label's ``index`` is the position of the first
                occurrence of ``l_str`` in it.

        Raises:
            ValueError: If ``l_str`` is not in ``coarse_grain_classes_str``.
        """
        # Validate before looking up the index. Checking afterwards is
        # unreachable (``list.index`` raises first), and an ``assert`` would be
        # stripped under ``python -O``. ValueError matches what ``.index``
        # raises, so callers catching it are unaffected.
        if l_str not in coarse_grain_classes_str:
            raise ValueError(
                f'{l_str!r} is not in the coarse-grain class list')
        super().__init__(l_str=l_str,
                         index=coarse_grain_classes_str.index(l_str))
        self.g_str = 'coarse'
        # Set after super().__init__, which initialises ``g`` to None.
        self.g = g