diff --git a/segment_anything/predictor.py b/segment_anything/predictor.py index 8a6e6d816..765d1a2f2 100644 --- a/segment_anything/predictor.py +++ b/segment_anything/predictor.py @@ -202,13 +202,16 @@ def predict_torch( instead of a binary mask. Returns: - (torch.Tensor): The output masks in BxCxHxW format, where C is the - number of masks, and (H, W) is the original image size. + (torch.Tensor): The output masks in BxCxHxW format, where B is the + batch size, C is the number of masks per batch element, and (H, W) is + the original image size. + The meaning of B depends on the prompt input. (torch.Tensor): An array of shape BxC containing the model's - predictions for the quality of each mask. - (torch.Tensor): An array of shape BxCxHxW, where C is the number - of masks and H=W=256. These low res logits can be passed to - a subsequent iteration as mask input. + predictions for the quality of each mask, per batch element. + (torch.Tensor): An array of shape BxCxHxW, where B is the + batch size, C is the number of masks per batch element, and H=W=256. + These low-res logits can be passed to a subsequent iteration as mask input. + The meaning of B depends on the prompt input. """ if not self.is_image_set: raise RuntimeError("An image must be set with .set_image(...) before mask prediction.")