-
Notifications
You must be signed in to change notification settings - Fork 0
/
train_utils.py
89 lines (82 loc) · 3.55 KB
/
train_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
import numpy as np
import torch
def get_variable_tesnor(x):
    """Convert ``x`` (e.g. a NumPy array) with ``torch.Tensor``.

    ``torch.Tensor`` is the legacy constructor for the default tensor type
    (float32 unless the default type has been changed), so integer inputs
    come back as floating-point tensors.

    The misspelled name is kept so existing callers keep working; prefer the
    ``get_variable_tensor`` alias in new code.
    """
    return torch.Tensor(x)


# Correctly spelled alias; same function object as get_variable_tesnor.
get_variable_tensor = get_variable_tesnor
def get_batch(dataset, idxs, start_idx, end_idx,
              num_point, num_channel,
              from_rgb_detection=False,
              device='cuda'):
    """Prepare batch data for training/evaluation.

    The batch covers positions ``start_idx`` .. ``end_idx - 1`` of ``idxs``,
    so the batch size is ``end_idx - start_idx``.

    Input:
        dataset: an instance of FrustumDataset class. ``dataset[k]`` must
            unpack as (ps, seg, center, hclass, hres, sclass, sres, rotangle),
            with an extra trailing one-hot vector when ``dataset.one_hot``
            is true.
        idxs: a list of data element indices
        start_idx: int scalar, start position in idxs
        end_idx: int scalar, end position in idxs (exclusive)
        num_point: int scalar, number of points per sample
        num_channel: int scalar, number of leading point channels kept
        from_rgb_detection: bool, if true delegate to
            get_batch_from_rgb_detection (different return layout)
        device: device the returned tensors are moved to
    Output:
        list of tensors on ``device``: [data, label, center, heading_class,
        heading_residual, size_class, size_residual, rot_angle], plus
        one_hot_vec when ``dataset.one_hot`` is true. All are built with
        ``torch.Tensor``, so integer fields (label, classes) are returned in
        the default floating-point dtype.
    """
    if from_rgb_detection:
        # Forward ``device``: the callee's own default is 'cuda', which would
        # otherwise silently override the caller's choice.
        return get_batch_from_rgb_detection(dataset, idxs, start_idx, end_idx,
                                            num_point, num_channel,
                                            device=device)

    bsize = end_idx - start_idx
    one_hot = dataset.one_hot
    batch_data = np.zeros((bsize, num_point, num_channel))
    batch_label = np.zeros((bsize, num_point), dtype=np.int32)
    batch_center = np.zeros((bsize, 3))
    batch_heading_class = np.zeros((bsize,), dtype=np.int32)
    batch_heading_residual = np.zeros((bsize,))
    batch_size_class = np.zeros((bsize,), dtype=np.int32)
    batch_size_residual = np.zeros((bsize, 3))
    batch_rot_angle = np.zeros((bsize,))
    if one_hot:
        batch_one_hot_vec = np.zeros((bsize, 3))  # for car, ped, cyc

    for i in range(bsize):
        sample = dataset[idxs[i + start_idx]]
        if one_hot:
            ps, seg, center, hclass, hres, sclass, sres, rotangle, onehotvec = sample
            batch_one_hot_vec[i] = onehotvec
        else:
            ps, seg, center, hclass, hres, sclass, sres, rotangle = sample
        batch_data[i, ...] = ps[:, 0:num_channel]
        batch_label[i, :] = seg
        batch_center[i, :] = center
        batch_heading_class[i] = hclass
        batch_heading_residual[i] = hres
        batch_size_class[i] = sclass
        batch_size_residual[i] = sres
        batch_rot_angle[i] = rotangle

    return_list = [batch_data, batch_label, batch_center,
                   batch_heading_class, batch_heading_residual,
                   batch_size_class, batch_size_residual, batch_rot_angle]
    if one_hot:
        return_list.append(batch_one_hot_vec)
    # Same conversion as get_variable_tesnor: torch.Tensor, then move to device.
    return [torch.Tensor(x).to(device) for x in return_list]
def get_batch_from_rgb_detection(dataset, idxs, start_idx, end_idx,
                                 num_point, num_channel, device='cuda'):
    """Assemble a batch from RGB-detection samples (no ground-truth labels).

    Samples at positions ``start_idx`` .. ``end_idx - 1`` of ``idxs`` are
    fetched from ``dataset``. Each item unpacks as (ps, rotangle, prob), or
    (ps, rotangle, prob, onehotvec) when ``dataset.one_hot`` is true.

    Returns a list of tensors built with ``torch.Tensor`` and moved to
    ``device``: [data, rot_angle, prob], plus one_hot_vec when
    ``dataset.one_hot`` is true.
    """
    n = end_idx - start_idx
    points = np.zeros((n, num_point, num_channel))
    angles = np.zeros((n,))
    probs = np.zeros((n,))
    use_one_hot = dataset.one_hot
    if use_one_hot:
        one_hots = np.zeros((n, 3))  # for car, ped, cyc

    for row in range(n):
        item = dataset[idxs[start_idx + row]]
        if use_one_hot:
            ps, rotangle, prob, onehotvec = item
            one_hots[row] = onehotvec
        else:
            ps, rotangle, prob = item
        points[row, ...] = ps[:, :num_channel]
        angles[row] = rotangle
        probs[row] = prob

    outputs = [points, angles, probs]
    if use_one_hot:
        outputs.append(one_hots)
    return [torch.Tensor(arr).to(device) for arr in outputs]