sem_map.py
import torch
import torch.nn as nn
from torch.nn import functional as F
import numpy as np
import depth_utils as du
from constants import category2objectid, mp3d_region_id2name


class Semantic_Mapping(nn.Module):
    """
    Builds an allocentric, multi-channel top-down semantic map from RGB-D
    observations and agent poses.

    Map channel layout: 0 obstacle map, 1 explored area, 2 current agent
    location, 3 past agent locations, 4.. one channel per semantic category.
    """
    def __init__(self, args):
        super(Semantic_Mapping, self).__init__()

        self.args = args
        self.device = args.device
        self.screen_h = args.frame_height
        self.screen_w = args.frame_width
        self.resolution = args.map_resolution
        self.z_resolution = args.map_resolution
        self.map_size_cm = args.map_size_cm
        self.vision_range = args.vision_range
        self.fov = args.hfov
        self.du_scale = args.du_scale
        self.cat_pred_threshold = args.cat_pred_threshold
        self.exp_pred_threshold = args.exp_pred_threshold
        self.map_pred_threshold = args.map_pred_threshold

        # Vertical extent of the voxel grid, in map cells (heights given in cm).
        self.max_height = int(360 / self.z_resolution)
        self.min_height = int(-40 / self.z_resolution)
        self.min_d = args.min_depth
        self.max_d = args.max_depth
        self.agent_height = args.camera_height * 100.  # metres -> cm
        self.shift_loc = [self.vision_range *
                          self.resolution // 2, 0, np.pi / 2.0]
        self.camera_matrix = du.get_camera_matrix(
            self.screen_w, self.screen_h, self.fov)

        vr = self.vision_range

        # One semantic channel per object category and/or per MP3D region,
        # depending on which modalities are enabled.
        self.num_sem_categories = len(category2objectid.keys()) * self.args.use_obj + \
            len(mp3d_region_id2name.keys()) * self.args.use_region

        # Egocentric voxel grid and per-point feature buffer, reused across
        # forward calls.
        self.init_grid = torch.zeros(
            args.num_processes, 1 + self.num_sem_categories, vr, vr,
            self.max_height - self.min_height
        ).float().to(self.device)
        self.feat = torch.ones(
            args.num_processes, 1 + self.num_sem_categories,
            self.screen_h // self.du_scale * self.screen_w // self.du_scale
        ).float().to(self.device)

    def forward(self, obs, maps_last, poses):
        bs = obs.shape[0]
        c = obs.shape[1]

        # Depth channel comes in normalised to [0, 1]: drop far readings
        # (> 0.99), push invalid (zero) readings out of range, then rescale
        # to centimetres.
        depth = obs[:, 3]
        mask2 = depth > 0.99
        depth[mask2] = 0.
        mask1 = depth == 0
        depth[mask1] = 100
        depth = self.args.min_depth * 100.0 + depth * self.args.max_depth * 100.0
        # TODO: if the depth is too small, don't use the detection results.

        # X is positive going right; Y is positive into the image; Z is positive up in the image
        point_cloud_t = du.get_point_cloud_from_z_t(
            depth, self.camera_matrix, self.device, scale=self.du_scale)
        agent_view_t = du.transform_camera_view_t(
            point_cloud_t, self.agent_height, 0, self.device)
        agent_view_centered_t = du.transform_pose_t(
            agent_view_t, self.shift_loc, self.device)

        max_h = self.max_height
        min_h = self.min_height
        xy_resolution = self.resolution
        z_resolution = self.z_resolution
        vision_range = self.vision_range

        # Convert the point cloud from centimetres to map cells, then
        # normalise each axis to [-1, 1] for use as splatting coordinates.
        XYZ_cm_std = agent_view_centered_t.float()
        XYZ_cm_std[..., :2] = (XYZ_cm_std[..., :2] / xy_resolution)
        XYZ_cm_std[..., :2] = (XYZ_cm_std[..., :2] -
                               vision_range // 2.) / vision_range * 2.
        XYZ_cm_std[..., 2] = XYZ_cm_std[..., 2] / z_resolution
        XYZ_cm_std[..., 2] = (XYZ_cm_std[..., 2] -
                              (max_h + min_h) // 2.) / (max_h - min_h) * 2.
        XYZ_cm_std = XYZ_cm_std.permute(0, 3, 1, 2)
        XYZ_cm_std = XYZ_cm_std.view(XYZ_cm_std.shape[0],
                                     XYZ_cm_std.shape[1],
                                     XYZ_cm_std.shape[2] * XYZ_cm_std.shape[3])

        feat = self.feat[:bs]
        init_grid = self.init_grid[:bs]
        if c > 4:
            # Per-pixel semantic scores become per-point features.
            # Note: this direct view assumes du_scale == 1.
            feat[:, 1:, :] = obs[:, 4:, :, :].view(bs, c - 4, -1)

        # Splat the point features into the egocentric voxel grid.
        voxels = du.splat_feat_nd(
            init_grid * 0., feat, XYZ_cm_std).transpose(2, 3)

        # Project the voxel grid down to 2D: obstacles only count voxels in
        # the slab around the agent's height, explored area uses all heights.
        min_z = int(25 / z_resolution - min_h)
        max_z = int((self.agent_height + 1) / z_resolution - min_h)
        agent_height_proj = voxels[..., min_z:max_z].sum(4)
        all_height_proj = voxels.sum(4)

        fp_map_pred = agent_height_proj[:, 0:1, :, :]
        fp_exp_pred = all_height_proj[:, 0:1, :, :]
        fp_map_pred = fp_map_pred / self.map_pred_threshold
        fp_exp_pred = fp_exp_pred / self.exp_pred_threshold
        fp_map_pred = torch.clamp(fp_map_pred, min=0.0, max=1.0)
        fp_exp_pred = torch.clamp(fp_exp_pred, min=0.0, max=1.0)

        # Paste the egocentric predictions into a full-size, agent-centred map.
        agent_view = torch.zeros(bs, c,
                                 self.map_size_cm // self.resolution,
                                 self.map_size_cm // self.resolution
                                 ).to(self.device)
        x1 = self.map_size_cm // (self.resolution * 2) - self.vision_range // 2
        x2 = x1 + self.vision_range
        y1 = self.map_size_cm // (self.resolution * 2)
        y2 = y1 + self.vision_range
        agent_view[:, 0:1, y1:y2, x1:x2] = fp_map_pred
        agent_view[:, 1:2, y1:y2, x1:x2] = fp_exp_pred
        agent_view[:, 4:4 + self.num_sem_categories, y1:y2, x1:x2] = torch.clamp(
            agent_height_proj[:, 1:1 + self.num_sem_categories, :, :] / self.cat_pred_threshold,
            min=0.0, max=1.0)

        # Transform the agent-centred view into map coordinates using the
        # current pose, then fuse with the previous map by a channel-wise max.
        current_poses = poses
        st_pose = current_poses.clone().detach()
        st_pose[:, :2] = - (st_pose[:, :2]
                            * 100.0 / self.resolution
                            - self.map_size_cm // (self.resolution * 2)) / \
            (self.map_size_cm // (self.resolution * 2))
        st_pose[:, 2] = 90. - (st_pose[:, 2])

        rot_mat, trans_mat = du.get_grid(st_pose, agent_view.size(),
                                         self.device)
        rotated = F.grid_sample(agent_view, rot_mat, align_corners=True)
        translated = F.grid_sample(rotated, trans_mat, align_corners=True)

        maps2 = torch.cat((maps_last.unsqueeze(1), translated.unsqueeze(1)), 1)
        map_pred, _ = torch.max(maps2, 1)

        # Update agent location channels:
        #   dim 2: current agent location
        #   dim 3: past agent locations
        map_pred[:, 3] = map_pred[:, 2]
        curr_loc = poses[:, :2] * 100 / self.args.map_resolution
        curr_loc = curr_loc.int()
        for scene_i in range(map_pred.shape[0]):
            r, c = curr_loc[scene_i, 1], curr_loc[scene_i, 0]
            # Mark a 3x3 patch around the agent's current position.
            map_pred[scene_i, 2, r - 1:r + 2, c - 1:c + 2] = 1.0

        return map_pred
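

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original module): shows how the mapper
# might be constructed and stepped once. All argument values below are
# illustrative assumptions -- in practice they come from the project's argument
# parser -- and the observation/pose tensors are random placeholders.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from types import SimpleNamespace

    args = SimpleNamespace(
        device=torch.device("cpu"),          # assumption: CPU smoke test
        frame_height=120, frame_width=160,   # assumed RGB-D frame size
        map_resolution=5, map_size_cm=2400,  # 5 cm cells, 24 m x 24 m map
        vision_range=100, hfov=79.0, du_scale=1,
        cat_pred_threshold=5.0, exp_pred_threshold=1.0, map_pred_threshold=1.0,
        min_depth=0.5, max_depth=5.0, camera_height=0.88,
        use_obj=1, use_region=0, num_processes=1,
    )

    sem_map = Semantic_Mapping(args)
    num_sem = sem_map.num_sem_categories
    map_cells = args.map_size_cm // args.map_resolution

    # obs channels: 0-2 RGB, 3 normalised depth, 4.. per-category masks.
    obs = torch.rand(1, 4 + num_sem, args.frame_height, args.frame_width,
                     device=args.device)
    maps_last = torch.zeros(1, 4 + num_sem, map_cells, map_cells,
                            device=args.device)
    # Pose: (x, y) in metres, heading in degrees; start at the map centre.
    poses = torch.tensor([[args.map_size_cm / 200.0,
                           args.map_size_cm / 200.0, 0.0]], device=args.device)

    map_pred = sem_map(obs, maps_last, poses)
    print(map_pred.shape)  # expected: (1, 4 + num_sem, map_cells, map_cells)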