diff --git a/segment_anything/predictor.py b/segment_anything/predictor.py index 8a6e6d816..651092630 100644 --- a/segment_anything/predictor.py +++ b/segment_anything/predictor.py @@ -161,8 +161,8 @@ def predict( ) masks_np = masks[0].detach().cpu().numpy() - iou_predictions_np = iou_predictions[0].detach().cpu().numpy() - low_res_masks_np = low_res_masks[0].detach().cpu().numpy() + iou_predictions_np = iou_predictions[0].detach().cpu().to(torch.float32).numpy() + low_res_masks_np = low_res_masks[0].detach().cpu().to(torch.float32).numpy() return masks_np, iou_predictions_np, low_res_masks_np @torch.no_grad()