Skip to content

Commit b6a6286

Browse files
authored
correct simple_test error
1 parent c0a0013 commit b6a6286

File tree

1 file changed

+4
-31
lines changed

1 file changed

+4
-31
lines changed

mmdet3d/models/detectors/minkunet_semsegFF.py

Lines changed: 4 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -235,38 +235,11 @@ def forward_train(self, points, pts_semantic_mask, img_metas):
235235
def simple_test(self, points, img_metas, *args, **kwargs):
236236
"""Test without augmentations.
237237
"""
238-
239-
timestamps = []
240-
if self.evaluator_mode == 'slice_len_constant':
241-
i=1
242-
while i*self.len_slice<len(points[0]):
243-
timestamps.append(i*self.len_slice)
244-
i=i+1
245-
timestamps.append(len(points[0]))
246-
else:
247-
num_slice = min(len(points[0]),self.num_slice)
248-
for i in range(1,num_slice):
249-
timestamps.append(i*(len(points[0])//num_slice))
250-
timestamps.append(len(points[0]))
251-
252-
# Process
253-
semseg_results = []
254-
255-
for i in range(len(timestamps)):
256-
if i == 0:
257-
ts_start, ts_end = 0, timestamps[i]
258-
else:
259-
ts_start, ts_end = timestamps[i-1], timestamps[i]
260-
sem_result = []
261-
262-
points_new = [points[0][ts_start:ts_end,:,:].reshape(-1,points[0].shape[-1])]
263-
field = self.collate(points_new, ME.SparseTensorQuantizationMode.UNWEIGHTED_AVERAGE)
264-
x = self.extract_feat(field.sparse(), img_metas)
265-
266-
preds = self.head.forward_test(x, field, img_metas)
267-
semseg_results.append(preds.cpu())
238+
field = self.collate(points, ME.SparseTensorQuantizationMode.UNWEIGHTED_AVERAGE)
239+
x = self.extract_feat(field.sparse(), img_metas)
268240

269-
results = [dict(semantic_mask=torch.cat(semseg_results,dim=0))]
241+
preds = self.head.forward_test(x, field, img_metas)
242+
results = [dict(semantic_mask=preds[0].cpu())]
270243
return results
271244

272245
def aug_test(self, points, img_metas, **kwargs):

0 commit comments

Comments
 (0)