diff --git a/pythia/scripts/features/extract_features_vmb.py b/pythia/scripts/features/extract_features_vmb.py index c18566177..77988bb62 100644 --- a/pythia/scripts/features/extract_features_vmb.py +++ b/pythia/scripts/features/extract_features_vmb.py @@ -162,7 +162,7 @@ def _process_feature_extraction( feat_list.append(feats[i][keep_boxes]) bbox = output[0]["proposals"][i][keep_boxes].bbox / im_scales[i] # Predict the class label using the scores - objects = torch.argmax(scores[keep_boxes][start_index:], dim=1) + objects = torch.argmax(scores[keep_boxes][:, start_index:], dim=1) info_list.append( {