diff --git a/Inference.py b/Inference.py index 61b70fc..3e7604f 100644 --- a/Inference.py +++ b/Inference.py @@ -68,6 +68,9 @@ def parse_args(): parser.add_argument( "--withContours", type=bool, default=False, help="draw the edges of the masks" ) + parser.add_argument( + "--bwMask", type=str, default="", help="black or white mask mode, reverse for background and foreground" + ) return parser.parse_args() @@ -112,6 +115,7 @@ def main(args): point_label = point_label, withContours=args.withContours, better_quality=args.better_quality, + bwMask=args.bwMask ) diff --git a/MORE_USAGES.md b/MORE_USAGES.md index 9fdfbfe..11919ee 100644 --- a/MORE_USAGES.md +++ b/MORE_USAGES.md @@ -57,3 +57,27 @@ python Inference.py --model_path ./weights/FastSAM.pt \ --withContours True ``` ![text prompt](assets/more_usages/text_prompt_cat.png) + +### use text prompt and output black mask +Use `--bwMask "black"` to specify the mask color +```shell +python Inference.py --model_path ./weights/FastSAM.pt \ + --img_path ./images/dogs.jpg \ + --text_prompt "the black dog" \ + --better_quality True \ + --withContours True \ + --bwMask "black" +``` +![text prompt](assets/more_usages/dogs_black.jpeg) + +### use text prompt and output white mask +Use `--bwMask "white"` to specify the mask color +```shell +python Inference.py --model_path ./weights/FastSAM.pt \ + --img_path ./images/dogs.jpg \ + --text_prompt "the black dog" \ + --better_quality True \ + --withContours True \ + --bwMask "white" +``` +![text prompt](assets/more_usages/dogs_white.jpeg) diff --git a/app_gradio.py b/app_gradio.py index 5682bc5..b4238db 100644 --- a/app_gradio.py +++ b/app_gradio.py @@ -80,7 +80,11 @@ def segment_everything( text="", wider=False, mask_random_color=True, + bw_mask="", ): + # Reset the default value of bw_mask + if bw_mask=="None": + bw_mask="" input_size = int(input_size) # 确保 imgsz 是整数 # Thanks for the suggestion by hysts in HuggingFace. w, h = input.size @@ -111,7 +115,8 @@ def segment_everything( mask_random_color=mask_random_color, bbox=None, use_retina=use_retina, - withContours=withContours,) + withContours=withContours, + bwMask=bw_mask) return fig @@ -124,6 +129,7 @@ def segment_with_points( withContours=True, use_retina=True, mask_random_color=True, + bw_mask="", ): global global_points global global_point_label @@ -157,7 +163,8 @@ def segment_with_points( mask_random_color=mask_random_color, bbox=None, use_retina=use_retina, - withContours=withContours,) + withContours=withContours, + bwMask=bw_mask) global_points = [] global_point_label = [] @@ -342,6 +349,9 @@ def get_points_with_draw(image, label, evt: gr.SelectData): mor_check = gr.Checkbox(value=False, label='better_visual_quality', info='better quality using morphologyEx') retina_check = gr.Checkbox(value=True, label='use_retina', info='draw high-resolution segmentation masks') wider_check = gr.Checkbox(value=False, label='wider', info='wider result') + with gr.Row(): + bw_mask = gr.Radio(["None", "black", "white"], value="None", label="Output result image as white or black masks") + rand_color = gr.Checkbox(value=False, label='random', info='mask with random color') # Description gr.Markdown(description_e) @@ -357,6 +367,8 @@ def get_points_with_draw(image, label, evt: gr.SelectData): retina_check, text_box, wider_check, + rand_color, + bw_mask, ], outputs=segm_img_t) diff --git a/assets/more_usages/dogs_black.jpeg b/assets/more_usages/dogs_black.jpeg new file mode 100644 index 0000000..d3d0d2e Binary files /dev/null and b/assets/more_usages/dogs_black.jpeg differ diff --git a/assets/more_usages/dogs_white.jpeg b/assets/more_usages/dogs_white.jpeg new file mode 100644 index 0000000..8856220 Binary files /dev/null and b/assets/more_usages/dogs_white.jpeg differ diff --git a/fastsam/prompt.py b/fastsam/prompt.py index dde50ac..4bc34c9 100644 --- a/fastsam/prompt.py +++ b/fastsam/prompt.py @@ -100,7 +100,8 @@ def plot_to_result(self, mask_random_color=True, better_quality=True, retina=False, - withContours=True) -> np.ndarray: + withContours=True, + bwMask="") -> np.ndarray: if isinstance(annotations[0], dict): annotations = [annotation['segmentation'] for annotation in annotations] image = self.img @@ -109,14 +110,17 @@ def plot_to_result(self, original_w = image.shape[1] if sys.platform == "darwin": plt.switch_backend("TkAgg") - plt.figure(figsize=(original_w / 100, original_h / 100)) + bgColor="white" + if bwMask == "white": + bgColor="black" + plt.figure(figsize=(original_w / 100, original_h / 100), facecolor=bgColor, edgecolor=bgColor) # Add subplot with no margin. plt.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0) plt.margins(0, 0) plt.gca().xaxis.set_major_locator(plt.NullLocator()) plt.gca().yaxis.set_major_locator(plt.NullLocator()) - - plt.imshow(image) + if bwMask == "": + plt.imshow(image) if better_quality: if isinstance(annotations[0], torch.Tensor): annotations = np.array(annotations.cpu()) @@ -135,6 +139,7 @@ def plot_to_result(self, retinamask=retina, target_height=original_h, target_width=original_w, + bwMask=bwMask, ) else: if isinstance(annotations[0], np.ndarray): @@ -149,6 +154,7 @@ def plot_to_result(self, retinamask=retina, target_height=original_h, target_width=original_w, + bwMask=bwMask, ) if isinstance(annotations, torch.Tensor): annotations = annotations.cpu().numpy() @@ -198,7 +204,8 @@ def plot(self, mask_random_color=True, better_quality=True, retina=False, - withContours=True): + withContours=True, + bwMask=""): if len(annotations) == 0: return None result = self.plot_to_result( @@ -210,6 +217,7 @@ def plot(self, better_quality, retina, withContours, + bwMask, ) path = os.path.dirname(os.path.abspath(output_path)) @@ -230,6 +238,7 @@ def fast_show_mask( retinamask=True, target_height=960, target_width=960, + bwMask="", ): msak_sum = annotation.shape[0] height = annotation.shape[1] @@ -239,12 +248,19 @@ def fast_show_mask( sorted_indices = np.argsort(areas) annotation = annotation[sorted_indices] + opacity=0.6 index = (annotation != 0).argmax(axis=0) - if random_color: + if bwMask == "white": + color = np.ones((msak_sum, 1, 1, 3)) + opacity=1 + elif bwMask == "black": + color = np.zeros((msak_sum, 1, 1, 3)) + opacity=1 + elif random_color: color = np.random.random((msak_sum, 1, 1, 3)) else: color = np.ones((msak_sum, 1, 1, 3)) * np.array([30 / 255, 144 / 255, 255 / 255]) - transparency = np.ones((msak_sum, 1, 1, 1)) * 0.6 + transparency = np.ones((msak_sum, 1, 1, 1)) * opacity visual = np.concatenate([color, transparency], axis=-1) mask_image = np.expand_dims(annotation, -1) * visual @@ -287,6 +303,7 @@ def fast_show_mask_gpu( retinamask=True, target_height=960, target_width=960, + bwMask="", ): msak_sum = annotation.shape[0] height = annotation.shape[1] @@ -295,13 +312,20 @@ def fast_show_mask_gpu( sorted_indices = torch.argsort(areas, descending=False) annotation = annotation[sorted_indices] # Find the index of the first non-zero value at each position. + opacity=0.6 index = (annotation != 0).to(torch.long).argmax(dim=0) - if random_color: + if bwMask == "white": + color = torch.ones((msak_sum, 1, 1, 3)).to(annotation.device) + opacity=1 + elif bwMask == "black": + color = torch.zeros((msak_sum, 1, 1, 3)).to(annotation.device) + opacity=1 + elif random_color: color = torch.rand((msak_sum, 1, 1, 3)).to(annotation.device) else: color = torch.ones((msak_sum, 1, 1, 3)).to(annotation.device) * torch.tensor([ 30 / 255, 144 / 255, 255 / 255]).to(annotation.device) - transparency = torch.ones((msak_sum, 1, 1, 1)).to(annotation.device) * 0.6 + transparency = torch.ones((msak_sum, 1, 1, 1)).to(annotation.device) * opacity visual = torch.cat([color, transparency], dim=-1) mask_image = torch.unsqueeze(annotation, -1) * visual # Select data according to the index. The index indicates which batch's data to choose at each position, converting the mask_image into a single batch form. diff --git a/predict.py b/predict.py index 06c955c..f1ace0a 100644 --- a/predict.py +++ b/predict.py @@ -43,6 +43,10 @@ def predict( better_quality: bool = Input( description="better quality using morphologyEx", default=False ), + bw_mask: str = Input( + default="", + description="empty value to disable, set white or black to get white or black mask" + ), ) -> Path: """Run a single prediction on the model""" @@ -77,6 +81,7 @@ def predict( retina=retina, text_prompt=text_prompt, withContours=withContours, + bw_mask=bw_mask ) args.point_prompt = ast.literal_eval(args.point_prompt) args.box_prompt = ast.literal_eval(args.box_prompt) @@ -102,6 +107,7 @@ def predict( args=args, mask_random_color=args.randomcolor, bbox=convert_box_xywh_to_xyxy(args.box_prompt), + bwMask=args.bw_mask, ) elif args.text_prompt != None: @@ -109,7 +115,10 @@ def predict( annotations = prompt(results, args, text=True) annotations = np.array([annotations]) fast_process( - annotations=annotations, args=args, mask_random_color=args.randomcolor + annotations=annotations, + args=args, + mask_random_color=args.randomcolor, + bwMask=args.bw_mask, ) elif args.point_prompt[0] != [0, 0]: @@ -122,6 +131,7 @@ def predict( args=args, mask_random_color=args.randomcolor, points=args.point_prompt, + bwMask=args.bw_mask, ) else: @@ -129,6 +139,7 @@ def predict( annotations=results[0].masks.data, args=args, mask_random_color=args.randomcolor, + bwMask=args.bw_mask, ) out = "/tmp.out.png" diff --git a/requirements.txt b/requirements.txt index 03b8c97..e2514f6 100644 --- a/requirements.txt +++ b/requirements.txt @@ -12,7 +12,7 @@ tqdm>=4.64.0 pandas>=1.1.4 seaborn>=0.11.0 -gradio==3.35.2 +gradio==3.41.0 # Ultralytics----------------------------------- ultralytics == 8.0.120 diff --git a/utils/tools.py b/utils/tools.py index 8b80b4a..50f7ba3 100644 --- a/utils/tools.py +++ b/utils/tools.py @@ -93,7 +93,7 @@ def get_bbox_from_mask(mask): def fast_process( - annotations, args, mask_random_color, bbox=None, points=None, edges=False + annotations, args, mask_random_color, bbox=None, points=None, edges=False, bwMask="" ): if isinstance(annotations[0], dict): annotations = [annotation["segmentation"] for annotation in annotations] @@ -104,13 +104,17 @@ def fast_process( original_w = image.shape[1] if sys.platform == "darwin": plt.switch_backend("TkAgg") - plt.figure(figsize=(original_w/100, original_h/100)) + bgColor="white" + if bwMask == "white": + bgColor="black" + plt.figure(figsize=(original_w/100, original_h/100), facecolor=bgColor, edgecolor=bgColor) # Add subplot with no margin. plt.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0) plt.margins(0, 0) plt.gca().xaxis.set_major_locator(plt.NullLocator()) plt.gca().yaxis.set_major_locator(plt.NullLocator()) - plt.imshow(image) + if bwMask == "": + plt.imshow(image) if args.better_quality == True: if isinstance(annotations[0], torch.Tensor): annotations = np.array(annotations.cpu()) @@ -133,6 +137,7 @@ def fast_process( retinamask=args.retina, target_height=original_h, target_width=original_w, + bwMask=bwMask, ) else: if isinstance(annotations[0], np.ndarray): @@ -147,6 +152,7 @@ def fast_process( retinamask=args.retina, target_height=original_h, target_width=original_w, + bwMask=bwMask, ) if isinstance(annotations, torch.Tensor): annotations = annotations.cpu().numpy() @@ -202,6 +208,7 @@ def fast_show_mask( retinamask=True, target_height=960, target_width=960, + bwMask="", ): msak_sum = annotation.shape[0] height = annotation.shape[1] @@ -211,14 +218,21 @@ def fast_show_mask( sorted_indices = np.argsort(areas) annotation = annotation[sorted_indices] + opacity=0.6 index = (annotation != 0).argmax(axis=0) - if random_color == True: + if bwMask == "white": + color = np.ones((msak_sum, 1, 1, 3)) + opacity=1 + elif bwMask == "black": + color = np.zeros((msak_sum, 1, 1, 3)) + opacity=1 + elif random_color == True: color = np.random.random((msak_sum, 1, 1, 3)) else: color = np.ones((msak_sum, 1, 1, 3)) * np.array( [30 / 255, 144 / 255, 255 / 255] ) - transparency = np.ones((msak_sum, 1, 1, 1)) * 0.6 + transparency = np.ones((msak_sum, 1, 1, 1)) * opacity visual = np.concatenate([color, transparency], axis=-1) mask_image = np.expand_dims(annotation, -1) * visual @@ -268,6 +282,7 @@ def fast_show_mask_gpu( retinamask=True, target_height=960, target_width=960, + bwMask="", ): msak_sum = annotation.shape[0] height = annotation.shape[1] @@ -276,14 +291,21 @@ def fast_show_mask_gpu( sorted_indices = torch.argsort(areas, descending=False) annotation = annotation[sorted_indices] # 找每个位置第一个非零值下标 + opacity=0.6 index = (annotation != 0).to(torch.long).argmax(dim=0) - if random_color == True: + if bwMask == "white": + color = torch.ones((msak_sum, 1, 1, 3)).to(annotation.device) + opacity=1 + elif bwMask == "black": + color = torch.zeros((msak_sum, 1, 1, 3)).to(annotation.device) + opacity=1 + elif random_color == True: color = torch.rand((msak_sum, 1, 1, 3)).to(annotation.device) else: color = torch.ones((msak_sum, 1, 1, 3)).to(annotation.device) * torch.tensor( [30 / 255, 144 / 255, 255 / 255] ).to(annotation.device) - transparency = torch.ones((msak_sum, 1, 1, 1)).to(annotation.device) * 0.6 + transparency = torch.ones((msak_sum, 1, 1, 1)).to(annotation.device) * opacity visual = torch.cat([color, transparency], dim=-1) mask_image = torch.unsqueeze(annotation, -1) * visual # 按index取数,index指每个位置选哪个batch的数,把mask_image转成一个batch的形式 diff --git a/utils/tools_gradio.py b/utils/tools_gradio.py index 31ec5d1..5e17a88 100644 --- a/utils/tools_gradio.py +++ b/utils/tools_gradio.py @@ -15,6 +15,7 @@ def fast_process( bbox=None, use_retina=True, withContours=True, + bwMask="" ): if isinstance(annotations[0], dict): annotations = [annotation['segmentation'] for annotation in annotations] @@ -37,6 +38,7 @@ def fast_process( retinamask=use_retina, target_height=original_h, target_width=original_w, + bwMask=bwMask, ) else: if isinstance(annotations[0], np.ndarray): @@ -49,6 +51,7 @@ def fast_process( retinamask=use_retina, target_height=original_h, target_width=original_w, + bwMask=bwMask, ) if isinstance(annotations, torch.Tensor): annotations = annotations.cpu().numpy() @@ -73,7 +76,13 @@ def fast_process( color = np.array([0 / 255, 0 / 255, 255 / 255, 0.9]) contour_mask = temp / 255 * color.reshape(1, 1, -1) - image = image.convert('RGBA') + if bwMask=="": + image = image.convert('RGBA') + elif bwMask=="white": + image = Image.new('1', (original_w, original_h), 0) + else: + image = Image.new('1', (original_w, original_h), 1) + overlay_inner = Image.fromarray((inner_mask * 255).astype(np.uint8), 'RGBA') image.paste(overlay_inner, (0, 0), overlay_inner) @@ -93,6 +102,7 @@ def fast_show_mask( retinamask=True, target_height=960, target_width=960, + bwMask="", ): mask_sum = annotation.shape[0] height = annotation.shape[1] @@ -102,12 +112,19 @@ def fast_show_mask( sorted_indices = np.argsort(areas)[::1] annotation = annotation[sorted_indices] + opacity=0.6 index = (annotation != 0).argmax(axis=0) - if random_color: + if bwMask == "white": + color = np.ones((mask_sum, 1, 1, 3)) + opacity=1 + elif bwMask == "black": + color = np.zeros((mask_sum, 1, 1, 3)) + opacity=1 + elif random_color: color = np.random.random((mask_sum, 1, 1, 3)) else: color = np.ones((mask_sum, 1, 1, 3)) * np.array([30 / 255, 144 / 255, 255 / 255]) - transparency = np.ones((mask_sum, 1, 1, 1)) * 0.6 + transparency = np.ones((mask_sum, 1, 1, 1)) * opacity visual = np.concatenate([color, transparency], axis=-1) mask_image = np.expand_dims(annotation, -1) * visual @@ -135,6 +152,7 @@ def fast_show_mask_gpu( retinamask=True, target_height=960, target_width=960, + bwMask="", ): device = annotation.device mask_sum = annotation.shape[0] @@ -144,14 +162,21 @@ def fast_show_mask_gpu( sorted_indices = torch.argsort(areas, descending=False) annotation = annotation[sorted_indices] # 找每个位置第一个非零值下标 + opacity=0.6 index = (annotation != 0).to(torch.long).argmax(dim=0) - if random_color: + if bwMask == "white": + color = torch.ones((mask_sum, 1, 1, 3)).to(device) + opacity=1 + elif bwMask == "black": + color = torch.zeros((mask_sum, 1, 1, 3)).to(device) + opacity=1 + elif random_color: color = torch.rand((mask_sum, 1, 1, 3)).to(device) else: color = torch.ones((mask_sum, 1, 1, 3)).to(device) * torch.tensor( [30 / 255, 144 / 255, 255 / 255] ).to(device) - transparency = torch.ones((mask_sum, 1, 1, 1)).to(device) * 0.6 + transparency = torch.ones((mask_sum, 1, 1, 1)).to(device) * opacity visual = torch.cat([color, transparency], dim=-1) mask_image = torch.unsqueeze(annotation, -1) * visual # 按index取数,index指每个位置选哪个batch的数,把mask_image转成一个batch的形式