From 75173f1b1ac39caa8969ed735a812c9b436a57ad Mon Sep 17 00:00:00 2001 From: MaKaNu <32844273+MaKaNu@users.noreply.github.com> Date: Wed, 5 Feb 2025 17:56:22 +0100 Subject: [PATCH] Fix #810 - Update the docstring of `predict_torch` Based on the conversation of #810 the docstring was updated to provide the user a hint how to interpret the dimensions of the outputs. --- segment_anything/predictor.py | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/segment_anything/predictor.py b/segment_anything/predictor.py index 8a6e6d816..765d1a2f2 100644 --- a/segment_anything/predictor.py +++ b/segment_anything/predictor.py @@ -202,13 +202,16 @@ def predict_torch( instead of a binary mask. Returns: - (torch.Tensor): The output masks in BxCxHxW format, where C is the - number of masks, and (H, W) is the original image size. + (torch.Tensor): The output masks in BxCxHxW format, where B is the + number of batches, C is the number of masks per batch, and (H, W) is + the original image size. + The meaning of B depends on the prompt input. (torch.Tensor): An array of shape BxC containing the model's - predictions for the quality of each mask. - (torch.Tensor): An array of shape BxCxHxW, where C is the number - of masks and H=W=256. These low res logits can be passed to - a subsequent iteration as mask input. + predictions for the quality of each mask per batch. + (torch.Tensor): An array of shape BxCxHxW, where B is the + number of batches, C is the number of masks per batch and H=W=256. + These low res logits can be passed to a subsequent iteration as mask input. + The meaning of B depends on the prompt input. """ if not self.is_image_set: raise RuntimeError("An image must be set with .set_image(...) before mask prediction.")