
Docstrings about dimensions are unclear. #810

Open
MaKaNu opened this issue Feb 4, 2025 · 3 comments · May be fixed by #811

Comments


MaKaNu commented Feb 4, 2025

The docstring of the method SamPredictor.set_image describes the input dimensions as follows:

Arguments:
          image (np.ndarray): The image for calculating masks. Expects an
            image in HWC uint8 format, with pixel values in [0, 255].

The output of the method SamPredictor.predict_torch, on the other hand, is described as follows:

Returns:
          (torch.Tensor): The output masks in BxCxHxW format, where C is the
            number of masks, and (H, W) is the original image size.

While running the code from the example notebook predictor_example.ipynb, I noticed that this is not the case.

Here is an MWE to reproduce the issue:

import sys

import cv2
import numpy as np
import torch

sys.path.append("..")
from segment_anything import SamPredictor, sam_model_registry

image = cv2.imread("images/truck.jpg")
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

sam_checkpoint = "sam_vit_h_4b8939.pth"
model_type = "vit_h"

device = "cuda"

sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
sam.to(device=device)

predictor = SamPredictor(sam)

predictor.set_image(image)

# Example with predict
input_point = np.array([[500, 375]])
input_label = np.array([1])

masks_predict, scores, logits = predictor.predict(
    point_coords=input_point,
    point_labels=input_label,
    multimask_output=True,
)

# Example with predict_torch
input_boxes = torch.tensor(
    [
        [75, 275, 1725, 850],
        [425, 600, 700, 875],
        [1375, 550, 1650, 800],
        [1240, 675, 1400, 750],
    ],
    device=predictor.device,
)

transformed_boxes = predictor.transform.apply_boxes_torch(input_boxes, image.shape[:2])
masks_predict_torch, _, _ = predictor.predict_torch(
    point_coords=None,
    point_labels=None,
    boxes=transformed_boxes,
    multimask_output=False,
)

print("Image Shape: ", image.shape)  # expected shape: (1200, 1800, 3); returned shape: (1200, 1800, 3)
print("Masks Predict Shape: ", masks_predict.shape) # expected shape: (3, 1200, 1800); returned shape: (3, 1200, 1800)
print("Masks Predict Torch Shape: ", masks_predict_torch.shape)  # expected shape: (1, 4, 1200, 1800); returned shape: (4, 1, 1200, 1800)

From the output it seems that the number of masks moved to the batch position.
On further thought I concluded that this is probably correct, since "batch" likely refers to the number of bounding boxes.
But if that is the case, the docstring is not very clear about it; at the very least, the B dimension should be clearly explained.
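To make the shape rules concrete (a minimal sketch with a hypothetical helper, not part of segment_anything), the BxCxHxW output shape can be predicted from the number of box prompts and the multimask_output flag:

```python
# Hypothetical helper: predicts the output mask shape of
# SamPredictor.predict_torch, assuming B = number of box prompts
# and C depends on multimask_output (3 masks if True, 1 if False).
def expected_mask_shape(num_prompts, multimask_output, image_hw):
    h, w = image_hw
    c = 3 if multimask_output else 1
    return (num_prompts, c, h, w)

# Four boxes with multimask_output=False give (4, 1, 1200, 1800),
# matching the shape observed in the MWE above.
print(expected_mask_shape(4, False, (1200, 1800)))
```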


heyoeyo commented Feb 4, 2025

I believe the model/docstrings are consistent here, though judging by the number of related issues, the way it works definitely causes a lot of confusion.

The main issue seems to be that while multiple points can be given per prompt, the box input only supports having 1 box per prompt. Though this is mentioned in the predict_torch docstring:

boxes (np.ndarray or None): A Bx4 array given a box prompt to the model, in XYXY format

So the box input is interpreted as Bx4 (as opposed to Nx4 or BxNx4). In the predict_torch example above, the input is a batch of 4 boxes, which is why the 'B' (batch size) part of the output shape is 4.

The 'C' (number of masks) in the output shape is determined by the multimask_output setting. When set to true, then 3 masks are returned per prompt (like in the first predict example). When it's false, only 1 mask is returned per prompt (like in the predict_torch example).
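As an illustration of the point above (a sketch with a hypothetical helper, not part of the library), the batch size can be read straight off the prompt arrays: a Bx4 box array contributes B prompts of one box each, while a BxNx2 point array contributes B prompts of N points each:

```python
# Hypothetical sketch: how the batch dimension is determined from
# the prompt inputs, mirroring the Bx4 / BxNx2 conventions above.
def prompt_batch_size(boxes=None, point_coords=None):
    if boxes is not None:
        # boxes are Bx4: each row is one prompt containing exactly 1 box
        return len(boxes)
    if point_coords is not None:
        # points are BxNx2: the N points in a row belong to one prompt
        return len(point_coords)
    raise ValueError("need boxes or point_coords")

boxes = [[75, 275, 1725, 850],
         [425, 600, 700, 875],
         [1375, 550, 1650, 800],
         [1240, 675, 1400, 750]]
print(prompt_batch_size(boxes=boxes))  # 4 prompts, one box each -> B=4

points = [[[75, 275], [425, 600], [1375, 550], [1240, 675]]]
print(prompt_batch_size(point_coords=points))  # 1 prompt, 4 points -> B=1
```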


MaKaNu commented Feb 4, 2025

Okay, combining the output docstring with the input docstring, it now matches one of my last sentences:

On further thought I concluded that this is probably correct, since "batch" likely refers to the number of bounding boxes.

In this case, I would say referencing the input docstring from the output docstring would help clarify things and make the method easier to use. When I am looking for the shape of the output, I look at the docstring of the output, not of the input.


heyoeyo commented Feb 5, 2025

referencing the input docstring from the output docstring would help clarify things and make the method easier to use

Yes, I'd agree. It's also missing details about how batching with images + prompts works, which can be especially confusing.

since "batch" likely refers to the number of bounding boxes.

To clarify, the 'B' in the output shape is not exactly the 'number of boxes', but the batch size. These end up being numerically the same, but the interpretation of the mask results is different.
It's maybe easier to see this with point prompts. For example:

input_points = torch.tensor(
    [[
        [75, 275],
        [425, 600],
        [1375, 550],
        [1240, 675],
    ]],
    device=predictor.device,
)
input_labels = torch.tensor([[1, 1, 1, 1]], device=predictor.device)

masks_predict_torch, _, _ = predictor.predict_torch(
    point_coords=input_points,
    point_labels=input_labels,
    boxes=None,
    multimask_output=False,
)
print("Mask shape:", masks_predict_torch.shape)
# -> shape: (1, 1, 1200, 1800)

In this case, the 'number of points' is 4, but there is only 1 prompt, so the 'B' part of the output shape will be 1.
If instead the points are written as:

batched_points = torch.tensor(
    [
      [[75, 275]],
      [[425, 600]],
      [[1375, 550]],
      [[1240, 675]],
    ],
    device=predictor.device,
)
# Output mask shape will be: (4, 1, 1200, 1800)

Here, the 'number of points' is 1 per prompt, and there are 4 prompts batched together, so the 'B' part of the output shape will be 4 (like in the box example).

The confusing thing is that, for boxes, the model only ever supports the second case above, where the 'number of boxes' must be 1 per prompt. So if more than one box is given, that is interpreted as multiple prompts (each with 1 box) batched together.

One important consequence of this is that you can't combine multiple boxes to segment a single object (the way multiple points can be used to refine the mask of one object), since the boxes are always batched and therefore run 'independently' through the model.
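The contrast between the two point layouts above can be sketched without torch (plain Python, with a hypothetical `shape_of` helper): the nesting of the point coordinates decides how the B dimension comes out:

```python
# Hypothetical helper: reads the shape of a nested list the way
# torch.tensor would infer it, to show how nesting determines B.
def shape_of(nested):
    dims = []
    while isinstance(nested, list):
        dims.append(len(nested))
        nested = nested[0]
    return tuple(dims)

# One prompt containing four points: B=1, N=4.
one_prompt_four_points = [[[75, 275], [425, 600], [1375, 550], [1240, 675]]]
# Four prompts containing one point each: B=4, N=1.
four_prompts_one_point = [[[75, 275]], [[425, 600]], [[1375, 550]], [[1240, 675]]]

print(shape_of(one_prompt_four_points))  # (1, 4, 2) -> output B is 1
print(shape_of(four_prompts_one_point))  # (4, 1, 2) -> output B is 4
```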

MaKaNu added a commit to MaKaNu/segment-anything that referenced this issue Feb 5, 2025
- Update the docstring of `predict_torch`

Based on the conversation in facebookresearch#810, the docstring was updated to give the user a hint on how to interpret the dimensions of the outputs.
@MaKaNu MaKaNu linked a pull request Feb 5, 2025 that will close this issue