diff --git a/README.md b/README.md index 2e58656be..13efbee6f 100644 --- a/README.md +++ b/README.md @@ -77,6 +77,23 @@ Additionally, masks can be generated for images from the command line: python scripts/amg.py --checkpoint <path/to/checkpoint> --model-type <model_type> --input <image_or_folder> --output <path/to/output> ``` +To control the granularity of segmented objects, you can use the `granularity` parameter when initializing the `SamAutomaticMaskGenerator` class or pass the `--granularity` argument when using the `amg.py` script. + +Example usage: + +```python +from segment_anything import SamAutomaticMaskGenerator, sam_model_registry +sam = sam_model_registry["<model_type>"](checkpoint="<path/to/checkpoint>") +mask_generator = SamAutomaticMaskGenerator(sam, granularity=0.7) +masks = mask_generator.generate(<your_image>) +``` + +Alternatively, you can pass the `--granularity` argument when using the `amg.py` script: + +``` +python scripts/amg.py --checkpoint <path/to/checkpoint> --model-type <model_type> --input <image_or_folder> --output <path/to/output> --granularity 0.7 +``` + See the examples notebooks on [using SAM with prompts](/notebooks/predictor_example.ipynb) and [automatically generating masks](/notebooks/automatic_mask_generator_example.ipynb) for more details.

@@ -181,3 +198,24 @@ If you use SAM or SA-1B in your research, please use the following BibTeX entry. year={2023} } ``` + +## Granularity Parameter + +The `SamAutomaticMaskGenerator` class now includes a `granularity` parameter that allows you to control the granularity of segmented objects. This parameter can be set when initializing the `SamAutomaticMaskGenerator` class or passed as a command-line argument when using the `amg.py` script. The default value is `0.5`. + +### Example Usage + +To use the `granularity` parameter, you can initialize the `SamAutomaticMaskGenerator` class with the desired granularity value: + +```python +from segment_anything import SamAutomaticMaskGenerator, sam_model_registry +sam = sam_model_registry["<model_type>"](checkpoint="<path/to/checkpoint>") +mask_generator = SamAutomaticMaskGenerator(sam, granularity=0.7) +masks = mask_generator.generate(<your_image>) +``` + +Alternatively, you can pass the `--granularity` argument when using the `amg.py` script: + +``` +python scripts/amg.py --checkpoint <path/to/checkpoint> --model-type <model_type> --input <image_or_folder> --output <path/to/output> --granularity 0.7 +``` diff --git a/scripts/amg.py b/scripts/amg.py index f2dbf676a..4df88fc4d 100644 --- a/scripts/amg.py +++ b/scripts/amg.py @@ -63,6 +63,13 @@ ), ) +parser.add_argument( + "--granularity", + type=float, + default=0.5, + help="Set the granularity of segmented objects.", +) + amg_settings = parser.add_argument_group("AMG Settings") amg_settings.add_argument( @@ -187,6 +194,7 @@ def get_amg_kwargs(args): "crop_overlap_ratio": args.crop_overlap_ratio, "crop_n_points_downscale_factor": args.crop_n_points_downscale_factor, "min_mask_region_area": args.min_mask_region_area, + "granularity": args.granularity, } amg_kwargs = {k: v for k, v in amg_kwargs.items() if v is not None} return amg_kwargs diff --git a/segment_anything/automatic_mask_generator.py b/segment_anything/automatic_mask_generator.py index d5a8c9692..af925b386 100644 --- a/segment_anything/automatic_mask_generator.py +++ b/segment_anything/automatic_mask_generator.py @@ -49,6 +49,7 @@ def __init__( 
point_grids: Optional[List[np.ndarray]] = None, min_mask_region_area: int = 0, output_mode: str = "binary_mask", + granularity: float = 0.5, ) -> None: """ Using a SAM model, generates masks for the entire image. @@ -93,6 +94,7 @@ def __init__( 'uncompressed_rle', or 'coco_rle'. 'coco_rle' requires pycocotools. For large resolutions, 'binary_mask' may consume large amounts of memory. + granularity (float): Passed to blend_segments() by generate() when greater than 0. blend_segments() is currently a placeholder that returns its input unchanged, so this value does not yet affect the generated masks. """ assert (points_per_side is None) != ( @@ -132,6 +134,7 @@ def __init__( self.crop_n_points_downscale_factor = crop_n_points_downscale_factor self.min_mask_region_area = min_mask_region_area self.output_mode = output_mode + self.granularity = granularity @torch.no_grad() def generate(self, image: np.ndarray) -> List[Dict[str, Any]]: @@ -170,6 +173,10 @@ def generate(self, image: np.ndarray) -> List[Dict[str, Any]]: max(self.box_nms_thresh, self.crop_nms_thresh), ) + # Hand the masks to blend_segments() unless granularity is zero or negative (blend_segments is currently a no-op placeholder) + if self.granularity > 0: + mask_data = self.blend_segments(mask_data, self.granularity) + # Encode masks if self.output_mode == "coco_rle": mask_data["segmentations"] = [coco_encode_rle(rle) for rle in mask_data["rles"]] @@ -370,3 +377,18 @@ def postprocess_small_regions( mask_data.filter(keep_by_nms) return mask_data + + def blend_segments(self, mask_data: MaskData, granularity: float) -> MaskData: + """ + Placeholder hook for blending smaller segments into larger ones; currently returns mask_data unchanged. + + Arguments: + mask_data (MaskData): The mask data containing the segments. + granularity (float): Intended to control the blending; currently unused. + + Returns: + MaskData: The same mask_data object that was passed in, unmodified. + """ + # TODO: implement the blending logic based on the granularity parameter. + # Until that exists, this placeholder leaves the masks untouched. + return mask_data