Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
- Update the docstring of `predict_torch`

Based on the conversation of facebookresearch#810 the docstring was updated to provide the user a hint how to interpret the dimensions of the outputs.
  • Loading branch information
MaKaNu committed Feb 5, 2025
1 parent dca509f commit 75173f1
Showing 1 changed file with 9 additions and 6 deletions.
15 changes: 9 additions & 6 deletions segment_anything/predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,13 +202,16 @@ def predict_torch(
instead of a binary mask.
Returns:
(torch.Tensor): The output masks in BxCxHxW format, where C is the
number of masks, and (H, W) is the original image size.
(torch.Tensor): The output masks in BxCxHxW format, where B is the
number of batches, C is the number of masks per batch, and (H, W) is
the original image size.
The meaning of B depends on the prompt input.
(torch.Tensor): An array of shape BxC containing the model's
predictions for the quality of each mask.
(torch.Tensor): An array of shape BxCxHxW, where C is the number
of masks and H=W=256. These low res logits can be passed to
a subsequent iteration as mask input.
predictions for the quality of each mask per batch.
(torch.Tensor): An array of shape BxCxHxW, where B is the
number of batches, C is the number of masks per batch and H=W=256.
These low res logits can be passed to a subsequent iteration as mask input.
The meaning of B depends on the prompt input.
"""
if not self.is_image_set:
raise RuntimeError("An image must be set with .set_image(...) before mask prediction.")
Expand Down

0 comments on commit 75173f1

Please sign in to comment.