forked from meijieru/crnn.pytorch
-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathutils.py
106 lines (84 loc) · 2.94 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
#!/usr/bin/python
# encoding: utf-8
import torch
import torch.nn as nn
import collections
class strLabelConverter(object):
    """Convert between text and integer label sequences.

    The i-th character of `alphabet` is mapped to label ``i + 1``; label 0 is
    reserved for the 'blank' symbol.

    Args:
        alphabet (str): the characters that can be encoded.
    """

    def __init__(self, alphabet):
        self.alphabet = alphabet + '-'  # for `-1` index
        # NOTE: 0 is reserved for 'blank' required by wrap_ctc
        self.dict = {char: i + 1 for i, char in enumerate(alphabet)}

    def encode(self, text, depth=0):
        """Support batch or single str.

        Every character is lower-cased before it is looked up in the label
        table.

        Args:
            text (str or iterable of str): text(s) to convert. The strings of
                a batch are concatenated into one flat label sequence.
            depth: when truthy, return ``(labels, len(labels))`` instead of
                the tensor pair. For a single str `labels` is a list of ints;
                for a batch it is the IntTensor of the concatenated strings.

        Returns:
            ``(torch.IntTensor labels, torch.IntTensor lengths)`` where
            `lengths` holds the length of each input string, unless `depth`
            is truthy (see above).

        Raises:
            KeyError: if a lower-cased character is not in the alphabet.
            TypeError: if `text` is neither a str nor an iterable.
        """
        # `collections.Iterable` was removed in Python 3.10; the ABC lives in
        # `collections.abc`.
        try:
            from collections.abc import Iterable
        except ImportError:  # Python 2 fallback
            from collections import Iterable

        if isinstance(text, str):
            text = [self.dict[char.lower()] for char in text]
            length = [len(text)]
        elif isinstance(text, Iterable):
            # Materialize first so a one-shot iterable (e.g. a generator)
            # survives both the length pass and the join below.
            text = list(text)
            length = [len(s) for s in text]
            text, _ = self.encode(''.join(text))
        else:
            raise TypeError(
                'text must be a str or an iterable of str, got %s'
                % type(text).__name__)
        if depth:
            return text, len(text)
        return (torch.IntTensor(text), torch.IntTensor(length))

    def decode(self, t, length, raw=False):
        """Convert label sequence(s) back into text.

        Args:
            t: 1-D sequence of integer labels; for a batch, the labels of all
                samples concatenated one after another.
            length: tensor holding the number of labels of each sample.
            raw (bool): if True, map every label to a character without
                dropping blanks or collapsing repeats (label 0 is rendered as
                the trailing '-' of the alphabet).

        Returns:
            A str when `length` has exactly one element, otherwise a list of
            str with one entry per element of `length`.
        """
        if length.numel() == 1:
            count = int(length[0])
            return self._decode_labels(
                [int(label) for label in t[:count]], raw)

        texts = []
        index = 0
        for i in range(length.numel()):
            count = int(length[i])
            segment = t[index:index + count]
            texts.append(
                self._decode_labels([int(label) for label in segment], raw))
            index += count
        return texts

    def _decode_labels(self, labels, raw):
        """Turn a list of Python int labels for one sample into a string."""
        if raw:
            return ''.join(self.alphabet[label - 1] for label in labels)
        chars = []
        previous = None
        for label in labels:
            # Skip blanks and any label that repeats its predecessor.
            if label != 0 and label != previous:
                chars.append(self.alphabet[label - 1])
            previous = label
        return ''.join(chars)
class averager(object):
    """Running mean over every element of the values passed to `add`."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Clear the accumulated element count and sum."""
        self.n_count = 0
        self.sum = 0

    def add(self, v):
        """Accumulate the element count and element sum of `v.data`."""
        payload = v.data
        self.n_count += payload.numel()
        # NOTE: the sum is taken on `v.data`, not on `v` itself; summing `v`
        # would add a node to the compute graph and leak memory.
        self.sum += payload.sum()

    def val(self):
        """Return the mean of everything added so far, or 0 if nothing was."""
        if self.n_count == 0:
            return 0
        return self.sum / float(self.n_count)
def oneHot(v, v_length, nc):
    """Build a zero-padded one-hot tensor from concatenated label sequences.

    Args:
        v: 1-D tensor of integer class indices, the labels of all samples
            concatenated in batch order.
        v_length: 1-D tensor with the number of labels of each sample.
        nc (int): number of classes (size of the one-hot dimension).

    Returns:
        torch.FloatTensor of shape ``(batch, max(v_length), nc)``; rows past a
        sample's own length stay all-zero. An empty `v_length` yields a tensor
        of shape ``(0, 0, nc)``.
    """
    batchSize = v_length.size(0)
    # Work with plain Python ints so sizes and slice bounds never depend on
    # implicit 0-dim tensor conversion.
    lengths = [int(n) for n in v_length]
    # `max()` of an empty batch would raise; an empty batch has no time steps.
    maxLength = max(lengths) if lengths else 0
    v_onehot = torch.FloatTensor(batchSize, maxLength, nc).fill_(0)
    offset = 0
    for row, length in enumerate(lengths):
        label = v[offset:offset + length].view(-1, 1).long()
        # Writes through a view, so `v_onehot` itself is updated in place.
        v_onehot[row, :length].scatter_(1, label, 1.0)
        offset += length
    return v_onehot
def loadData(v, data):
    """Resize `v` in place to the size of `data` and copy `data` into it.

    Args:
        v: destination tensor; modified in place.
        data: source tensor providing both the new size and the values.
    """
    # Resize `v` itself rather than `v.data`: current PyTorch refuses to
    # change the size of a tensor obtained through `.data`. `no_grad` keeps
    # the in-place update out of autograd tracking.
    with torch.no_grad():
        v.resize_(data.size()).copy_(data)
def prettyPrint(v):
    """Print the size, type and max/min/mean of tensor `v` to stdout."""
    print('Size {0}, Type: {1}'.format(str(v.size()), v.data.type()))
    # The reductions return 0-dim tensors, which cannot be indexed with
    # `[0]` on current PyTorch; `.item()` extracts the Python number.
    stats = (v.max().item(), v.min().item(), v.mean().item())
    print('| Max: %f | Min: %f | Mean: %f' % stats)
def assureRatio(img):
    """Return `img` stretched to a square if it is taller than it is wide.

    `img` must be a 4-D tensor whose last two dimensions are height and
    width. When height > width the image is bilinearly upsampled to
    (height, height); otherwise the input is returned unchanged.
    """
    _, _, height, width = img.size()
    if height <= width:
        return img
    widen = nn.UpsamplingBilinear2d(size=(height, height), scale_factor=None)
    return widen(img)