Using SAM2 detects few trees #314
For this sort of example, SAM might not be a great option. With so many small, similar-looking objects, it might be worth trying simpler methods like correlation (see template matching in scikit-image), or even just thresholding, which might be a good start (for this specific image, at least). However, you can probably improve the SAM result by adjusting the default settings. Two that seem like they might help are `points_per_side` and `crop_n_layers`. The auto-masking works by generating a grid of single-point prompts over the image: `points_per_side` sets how many points lie along each side of that grid (the default of 32 gives 32 × 32 = 1024 prompts), while `crop_n_layers` (0 by default) additionally reruns the prompting on smaller, overlapping crops of the image, where small objects take up more of the model's input. If you want to see what these options are doing, you can add the following code just after line 266 of the mask generator:
```python
# Draw the point prompts for the current crop (debugging only)
import cv2
debug_img = cv2.cvtColor(cropped_im, cv2.COLOR_RGB2BGR)
for xy in points_for_image:
    pt_xy = tuple(int(v) for v in xy)  # cv2.circle expects integer pixel coordinates
    cv2.circle(debug_img, pt_xy, 2, (255, 0, 255), -1)
cv2.imshow("DebugPoints", debug_img)
cv2.waitKey(0)
cv2.destroyWindow("DebugPoints")
```

This opens a pop-up window showing the 'crop' currently being processed along with its point prompts (it also pauses the masking; press any key with the window focused to resume).

One last thing that's probably helpful if you only want the trees is to ignore any overly large masks. You can filter them out using something like:
```python
# Keep only masks smaller than max_area (in pixels)
max_area = 1000
masks = [m for m in masks if m["area"] < max_area]
```
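The correlation approach suggested at the start of this reply is available in scikit-image as `skimage.feature.match_template` (usually paired with `skimage.feature.peak_local_max`). To show what that technique computes, here is a minimal NumPy-only sketch of normalized cross-correlation plus a greedy peak picker. The function names, the 0.9 threshold, and the synthetic test image are illustrative assumptions; on a full orthophoto the FFT-based scikit-image function will be much faster and lighter on memory than this brute-force version.

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view


def match_template_ncc(image, template):
    """Normalized cross-correlation of a 2-D template over a 2-D image.

    Returns a (H - h + 1, W - w + 1) map with values in [-1, 1]; entry (r, c)
    scores the window whose top-left corner is at (r, c).
    """
    image = np.asarray(image, dtype=np.float64)
    template = np.asarray(template, dtype=np.float64)
    t = template - template.mean()
    t_norm = np.sqrt((t ** 2).sum())
    # Every template-sized window of the image: shape (H-h+1, W-w+1, h, w)
    windows = sliding_window_view(image, template.shape)
    w = windows - windows.mean(axis=(-2, -1), keepdims=True)
    num = (w * t).sum(axis=(-2, -1))
    den = np.sqrt((w ** 2).sum(axis=(-2, -1))) * t_norm
    # Flat windows (zero variance) score 0 instead of dividing by zero
    return np.divide(num, den, out=np.zeros_like(num), where=den > 1e-12)


def find_peaks(score, threshold=0.8, min_distance=3):
    """Greedy non-maximum suppression: take the best score, blank its neighbourhood, repeat."""
    score = score.copy()
    peaks = []
    while True:
        r, c = np.unravel_index(np.argmax(score), score.shape)
        if score[r, c] < threshold:
            break
        peaks.append((int(r), int(c)))
        score[max(r - min_distance, 0): r + min_distance + 1,
              max(c - min_distance, 0): c + min_distance + 1] = -np.inf
    return peaks


# Synthetic example: three bright 5x5 "crowns" on a dark background
img = np.zeros((40, 40))
for r, c in [(5, 5), (20, 30), (30, 10)]:
    img[r:r + 5, c:c + 5] = 1.0
tmpl = np.zeros((9, 9))
tmpl[2:7, 2:7] = 1.0  # one crown with a 2-pixel dark margin
peaks = find_peaks(match_template_ncc(img, tmpl), threshold=0.9)
print(sorted(peaks))  # top-left corners: [(3, 3), (18, 28), (28, 8)]
```

Each peak is the top-left corner of a matching window, so adding half the template size (4 here) gives the crown centre. In practice the template would be cropped from one representative tree in the image.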
Small object detection is another CV domain.
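The first reply's suggestion that plain thresholding might be a good start for this image can be sketched as follows, assuming the tree crowns are noticeably greener than their surroundings. The choice of the Excess Green index (ExG = 2g − r − b on chromaticity-normalized channels), the threshold of 0.1, the `min_pixels` cut-off, and the function name are all assumptions for illustration, not something proposed in the thread; connected components are counted with `scipy.ndimage.label`.

```python
import numpy as np
from scipy import ndimage


def count_green_blobs(rgb, exg_threshold=0.1, min_pixels=20):
    """Threshold an Excess Green index and count connected blobs.

    rgb: H x W x 3 uint8 array. Returns (number_of_blobs, [(row, col) centroids]).
    ExG is an assumed choice of vegetation index; tune the threshold per image.
    """
    rgb = rgb.astype(np.float64)
    total = rgb.sum(axis=-1) + 1e-9  # avoid dividing by zero on black pixels
    r, g, b = (rgb[..., i] / total for i in range(3))
    exg = 2 * g - r - b  # Excess Green on chromaticity-normalized channels
    mask = exg > exg_threshold
    labels, n = ndimage.label(mask)  # 4-connected components by default in 2-D
    sizes = np.bincount(labels.ravel(), minlength=n + 1)[1:]  # pixels per label
    keep = [i + 1 for i, s in enumerate(sizes) if s >= min_pixels]
    if not keep:
        return 0, []
    centroids = ndimage.center_of_mass(mask.astype(np.float64), labels, keep)
    return len(keep), centroids


# Synthetic example: grey background (ExG = 0) with two green crowns and one speck
img = np.full((60, 60, 3), 128, dtype=np.uint8)
img[10:16, 10:16] = (40, 160, 40)  # 6x6 crown
img[30:36, 40:46] = (40, 160, 40)  # 6x6 crown
img[50:53, 5:8] = (40, 160, 40)    # 3x3 speck, below min_pixels
n, centroids = count_green_blobs(img)
print(n)  # 2
print([tuple(float(v) for v in c) for c in centroids])  # [(12.5, 12.5), (32.5, 42.5)]
```

Touching crowns will merge into a single blob with this approach, so it is most useful as a quick baseline to compare SAM's count against.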
Hi everyone,

I recently started evaluating SAM2 for tree detection. I'd like to clarify that I'm new to this whole topic and I'm trying to learn how to use SAM2 to detect trees.

I have tried the following code:
```python
import os
# if using Apple MPS, fall back to CPU for unsupported ops
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
import numpy as np
import torch
import matplotlib.pyplot as plt
from PIL import Image

# select the device for computation
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")
print(f"using device: {device}")

if device.type == "cuda":
    # use bfloat16 for the entire notebook
    torch.autocast("cuda", dtype=torch.bfloat16).__enter__()
    # turn on TF32 for Ampere GPUs (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices)
    if torch.cuda.get_device_properties(0).major >= 8:
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True
elif device.type == "mps":
    print(
        "\nSupport for MPS devices is preliminary. SAM 2 is trained with CUDA and might "
        "give numerically different outputs and sometimes degraded performance on MPS. "
        "See e.g. pytorch/pytorch#84936 for a discussion."
    )

np.random.seed(3)
```
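```python
def show_anns(anns, borders=True):
    if len(anns) == 0:
        return
    sorted_anns = sorted(anns, key=(lambda x: x['area']), reverse=True)
    ax = plt.gca()
    ax.set_autoscale_on(False)
    # rest of the helper, along the lines of the SAM 2 example notebook:
    # overlay each mask in a random semi-transparent colour (largest first)
    img = np.zeros((*sorted_anns[0]['segmentation'].shape, 4))
    for ann in sorted_anns:
        m = ann['segmentation']
        img[m] = np.concatenate([np.random.random(3), [0.5]])
        if borders:
            import cv2
            contours, _ = cv2.findContours(m.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
            cv2.drawContours(img, contours, -1, (0, 0, 1, 0.4), thickness=1)
    ax.imshow(img)
```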
```python
image = Image.open(r"C:\Users\Lenovo\Desktop\Daniel\AIconteo\cerro2_corte.tif")
image = np.array(image.convert("RGB"))

plt.figure(figsize=(20, 20))
plt.imshow(image)
plt.axis('off')
plt.show()
```
```python
from sam2.build_sam import build_sam2
from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator

sam2_checkpoint = r"C:\Users\Lenovo\segment-anything-2\checkpoints\sam2_hiera_base_plus.pt"
model_cfg = "sam2_hiera_b+.yaml"

sam2 = build_sam2(model_cfg, sam2_checkpoint, device=device, apply_postprocessing=False)
mask_generator = SAM2AutomaticMaskGenerator(sam2)
masks = mask_generator.generate(image)

print(len(masks))
print(masks[0].keys())
```
Output:

```
121
dict_keys(['segmentation', 'area', 'bbox', 'predicted_iou', 'point_coords', 'stability_score', 'crop_box'])
```
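The code above builds `SAM2AutomaticMaskGenerator(sam2)` with all defaults. The reply's suggestion to raise `points_per_side` and `crop_n_layers` can be tried by passing keyword arguments to that constructor instead. In the sketch below, the keyword names are constructor parameters of `SAM2AutomaticMaskGenerator`, but every value is an untuned guess for "many small objects", and both helper functions are hypothetical. `prompt_count` is written from my reading of how the generator lays out its point grids and crops, to show how quickly the number of prompts (and the run time) grows.

```python
def tuned_generator_kwargs(points_per_side=64, crop_n_layers=1):
    """Hypothetical helper: keyword arguments for SAM2AutomaticMaskGenerator.

    Values are illustrative starting points, not recommendations from the
    thread or the library; tune them against the actual imagery.
    """
    return dict(
        points_per_side=points_per_side,   # prompts per grid side (default 32)
        points_per_batch=64,               # reduce if the GPU runs out of memory
        pred_iou_thresh=0.7,               # lower keeps more low-confidence masks
        stability_score_thresh=0.9,        # lower keeps less stable masks
        crop_n_layers=crop_n_layers,       # >0 also prompts zoomed-in image crops
        crop_n_points_downscale_factor=2,  # sparser grids in the crop layers
        min_mask_region_area=25,           # remove tiny holes/islands (uses OpenCV)
    )


def prompt_count(points_per_side, crop_n_layers, downscale=1):
    """Hypothetical helper: total single-point prompts across all crop layers.

    Layer 0 is the whole image. Layer i >= 1 splits the image into
    2**i x 2**i overlapping crops, each prompted with a grid of
    int(points_per_side / downscale**i) points per side.
    """
    total = 0
    for layer in range(crop_n_layers + 1):
        n_crops = (2 ** layer) ** 2
        n_side = int(points_per_side / downscale ** layer)
        total += n_crops * n_side ** 2
    return total


# With SAM 2 installed, the kwargs would be passed as:
#   mask_generator = SAM2AutomaticMaskGenerator(sam2, **tuned_generator_kwargs())
print(prompt_count(32, 0))     # defaults: 32 * 32 -> 1024
print(prompt_count(64, 1, 2))  # 64 * 64 + 4 * (32 * 32) -> 8192
```

Doubling `points_per_side` quadruples the prompts on the full image, so the downscale factor on the crop layers helps keep the extra crop pass affordable.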
```python
plt.figure(figsize=(20, 20))
plt.imshow(image)
show_anns(masks)
plt.axis('off')
plt.show()
```
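Since the goal is to count trees, the reply's `max_area` filter can be extended into a rough count. This is a sketch assuming each crown shows up as one compact mask; the helper names, `min_area`, and the bounding-box aspect-ratio test are hypothetical heuristics that would need tuning to the image resolution. It relies only on the `area` (pixels) and `bbox` (`[x, y, w, h]`) keys listed in the output of `masks[0].keys()`. SAM may still return overlapping masks for the same tree, so the count can need de-duplication.

```python
def filter_tree_masks(masks, min_area=30, max_area=1000, max_aspect=2.0):
    """Hypothetical filter: keep masks whose size and shape look like one crown."""
    kept = []
    for m in masks:
        x, y, w, h = m["bbox"]  # XYWH box around the mask
        if not (min_area <= m["area"] < max_area):
            continue  # speck or large background region
        if min(w, h) <= 0 or max(w, h) / min(w, h) > max_aspect:
            continue  # long, thin shapes are assumed not to be single crowns
        kept.append(m)
    return kept


def mask_centers(masks):
    """Centre of each mask's bounding box as (x, y) in pixels."""
    return [(x + w / 2, y + h / 2) for x, y, w, h in (m["bbox"] for m in masks)]


# Synthetic mask records with only the keys used above
demo = [
    {"area": 400, "bbox": [10, 10, 20, 22]},    # crown-sized, roughly round: kept
    {"area": 50000, "bbox": [0, 0, 300, 200]},  # large background region: dropped
    {"area": 300, "bbox": [50, 50, 60, 5]},     # long thin strip: dropped
    {"area": 10, "bbox": [5, 5, 3, 3]},         # speck: dropped
]
trees = filter_tree_masks(demo)
print(len(trees))           # 1
print(mask_centers(trees))  # [(20.0, 21.0)]
```

With real output, `filter_tree_masks(masks)` would replace the single `max_area` list comprehension, and `len(...)` of the result gives the tree count.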
I've attached some example images. I'm already very grateful to anyone who can help. :)