Docstrings about dimensions are unclear. #810
I believe the model/docstrings are consistent here, though it's definitely true that the way it works causes a lot of confusion, based on the number of related issues. The main issue seems to be that while multiple points can be given per prompt, the box input only supports 1 box per prompt. Though this is mentioned in the
So the box input is interpreted as Bx4 (as opposed to Nx4 or BxNx4). The 'C' (number of masks) in the output shape is determined by the `multimask_output` flag.
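As a minimal sketch of that relationship (not the library's internal code; the constant 3 reflects SAM's three candidate masks when `multimask_output` is enabled):

```python
def num_output_masks(multimask_output: bool) -> int:
    """C dimension of the BxCxHxW output of predict_torch:
    3 candidate masks when multimask_output is True, else 1."""
    return 3 if multimask_output else 1
```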
Okay, combining the docstring of the output with the docstring of the input, it now matches one of my last sentences:
In this case, I would say referencing the input docstring in the output docstring would help to clarify things and make the method easier to use. When I am looking for the shape of the output, I look at the docstring of the output, not at the input.
Yes, I'd agree. It's also missing details about how batching with image + prompts works, which can be especially confusing.
To clarify, the 'B' in the output shape is not exactly the 'number of boxes', but the batch size. These end up being numerically the same, but the interpretation of the mask results is different.

```python
input_points = torch.tensor(
    [[
        [75, 275],
        [425, 600],
        [1375, 550],
        [1240, 675],
    ]],
    device=predictor.device,
)
input_labels = torch.tensor([[1, 1, 1, 1]], device=predictor.device)
masks_predict_torch, _, _ = predictor.predict_torch(
    point_coords=input_points,
    point_labels=input_labels,
    boxes=None,
    multimask_output=False,
)
print("Mask shape:", masks_predict_torch.shape)
# -> shape: (1, 1, 1200, 1800)
```

In this case, the 'number of points' is 4, but there is only 1 prompt, so the 'B' part of the output shape will be 1.

```python
batched_points = torch.tensor(
    [
        [[75, 275]],
        [[425, 600]],
        [[1375, 550]],
        [[1240, 675]],
    ],
    device=predictor.device,
)
# Output mask shape will be: (4, 1, 1200, 1800)
```

Here, the 'number of points' is 1 per prompt, and there are 4 prompts batched together, so the 'B' part of the output shape will be 4 (like in the box example). The confusing thing is that, for boxes, the model only ever supports the second case above, where the 'number of boxes' must be 1 per prompt. So if more than one box is given, that's interpreted as multiple prompts (each with 1 box) being batched together. One important consequence of this is that you can't combine multiple boxes to segment a single object (in the same way that multiple points can be used to refine the mask of one object), since the boxes are always batched and therefore run 'independently' through the model.
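The shape rules above can be summarized with a small helper (a hypothetical illustration, not part of the SAM API; `image_hw` stands in for the original image size set via `set_image`):

```python
def expected_mask_shape(num_prompts: int, multimask_output: bool,
                        image_hw: tuple = (1200, 1800)) -> tuple:
    """Predicted shape of predict_torch's output masks.

    B = num_prompts: the leading (batch) dimension of point_coords,
        or the number of rows in a Bx4 boxes tensor.
    C = 3 if multimask_output else 1.
    """
    c = 3 if multimask_output else 1
    return (num_prompts, c) + tuple(image_hw)

# 4 points in a single prompt -> B is 1
print(expected_mask_shape(1, False))  # (1, 1, 1200, 1800)
# 4 single-point (or single-box) prompts batched -> B is 4
print(expected_mask_shape(4, False))  # (4, 1, 1200, 1800)
```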
- Update the docstring of `predict_torch`: based on the conversation in facebookresearch#810, the docstring was updated to give the user a hint on how to interpret the dimensions of the outputs.
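For illustration, a clarified returns section might read something like the following (this wording is a hypothetical sketch, not the actual patch):

```python
# Hypothetical sketch of a clarified docstring fragment for
# SamPredictor.predict_torch (illustrative wording only):
PREDICT_TORCH_RETURNS_DOC = """
Returns:
  (torch.Tensor): Output masks in BxCxHxW format, where B is the
    number of batched prompts (the leading dimension of point_coords,
    or the number of boxes, since boxes are always one per prompt),
    C is the number of masks per prompt (3 if multimask_output is
    True, otherwise 1), and (H, W) is the original image size.
"""
```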
The docstring of the method `SamPredictor.set_image` describes the input dimensions as the following:

While the output of the method `SamPredictor.predict_torch` is described like the following:

While running the code of the example notebook `predictor_example.ipynb`, I noticed that this is not the case. Here is a MWE to recreate the issue:

From the output, it seems that the number of masks moved to the batch position. Thinking about it further leads me to conclude that this is probably correct, since 'Batch' likely refers to the number of bounding boxes. But if that is the case, the docstring is not very clear about it. At least the `B` dimension should be clearly addressed.