From ab9fa85173f66f9abe9719195024f03165213f65 Mon Sep 17 00:00:00 2001 From: Parth Chadha Date: Fri, 13 Dec 2024 10:34:43 -0800 Subject: [PATCH] Enable sam2 video pipeline (#444) Co-authored-by: Yizhuo Zhang --- .../segment-anything-model-v2/README.md | 27 +- .../configs/sam2_hiera_l.yaml | 57 +- .../segment-anything-model-v2/demo_utils.py | 99 ++ .../download_test_data.py | 76 ++ .../segment-anything-model-v2/image_demo.py | 70 +- .../sam2/build_sam.py | 14 +- .../sam2/modeling/memory_attention.py | 61 +- .../sam2/modeling/sam/mask_decoder.py | 20 +- .../sam2/modeling/sam/transformer.py | 6 +- .../sam2/modeling/sam2_base.py | 16 +- .../sam2/modeling/sam2_utils.py | 5 - .../sam2/sam2_video_predictor.py | 948 ++++++++++++++++++ .../sam2/utils/misc.py | 288 ++++++ .../segment-anything-model-v2/video_demo.py | 178 ++++ tripy/tests/test_examples.py | 5 +- 15 files changed, 1732 insertions(+), 138 deletions(-) create mode 100644 tripy/examples/segment-anything-model-v2/demo_utils.py create mode 100644 tripy/examples/segment-anything-model-v2/download_test_data.py create mode 100644 tripy/examples/segment-anything-model-v2/sam2/sam2_video_predictor.py create mode 100644 tripy/examples/segment-anything-model-v2/sam2/utils/misc.py create mode 100644 tripy/examples/segment-anything-model-v2/video_demo.py diff --git a/tripy/examples/segment-anything-model-v2/README.md b/tripy/examples/segment-anything-model-v2/README.md index 592749f96..400303a6e 100644 --- a/tripy/examples/segment-anything-model-v2/README.md +++ b/tripy/examples/segment-anything-model-v2/README.md @@ -6,23 +6,23 @@ This is an implementation of SAM2 model ([original repository](https://github.co ## Running The Example -### Image pipeline - 1. Install prerequisites: ```bash - sudo apt-get update && sudo apt-get install ffmpeg libsm6 libxext6 -y + sudo apt-get update && sudo apt-get install ffmpeg libsm6 libxext6 jpeginfo -y python3 -m pip install -r requirements.txt ``` -2. Retrieve an example image and the checkpoint: +2. Retrieve the images and the checkpoint: ```bash - wget -O truck.jpg https://raw.githubusercontent.com/facebookresearch/sam2/main/notebooks/images/truck.jpg + python3 download_test_data.py mkdir checkpoints && cd checkpoints && wget https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_large.pt ``` -3. Run the example: +### Image pipeline + +1. Run the example: ```bash python3 image_demo.py @@ -38,7 +38,20 @@ This is an implementation of SAM2 model ([original repository](https://github.co ### Video segmentation pipeline -TBD +1. 
Run the example: + + ```bash + python3 video_demo.py + ``` + + ## License diff --git a/tripy/examples/segment-anything-model-v2/configs/sam2_hiera_l.yaml b/tripy/examples/segment-anything-model-v2/configs/sam2_hiera_l.yaml index 6c84c6089..1bf182359 100644 --- a/tripy/examples/segment-anything-model-v2/configs/sam2_hiera_l.yaml +++ b/tripy/examples/segment-anything-model-v2/configs/sam2_hiera_l.yaml @@ -34,42 +34,34 @@ model: d_model: 256 pos_enc_at_input: true dtype: float16 - layer: - _target_: sam2.modeling.memory_attention.MemoryAttentionLayer - activation: relu - dim_feedforward: 2048 - dropout: 0.1 - dtype: float16 - pos_enc_at_attn: false - self_attention: - _target_: sam2.modeling.sam.transformer.RoPEAttention - rope_theta: 10000.0 - feat_sizes: [32, 32] - embedding_dim: 256 - num_heads: 1 - downsample_rate: 1 - dropout: 0.1 - dtype: float16 - d_model: 256 - pos_enc_at_cross_attn_keys: true - pos_enc_at_cross_attn_queries: false - cross_attention: - _target_: sam2.modeling.sam.transformer.RoPEAttention - rope_theta: 10000.0 - feat_sizes: [32, 32] - rope_k_repeat: True - embedding_dim: 256 - num_heads: 1 - downsample_rate: 1 - dropout: 0.1 - kv_in_dim: 64 - dtype: float16 num_layers: 4 + # memory attention layer + activation: relu + dim_feedforward: 2048 + dropout: 0.1 + pos_enc_at_attn: false + pos_enc_at_cross_attn_keys: true + pos_enc_at_cross_attn_queries: false + # self rope attention + sa_rope_theta: 10000.0 + sa_feat_sizes: [32, 32] + sa_embedding_dim: 256 + sa_num_heads: 1 + sa_downsample_rate: 1 + sa_dropout: 0.1 + # cross rope attention + ca_rope_theta: 10000.0 + ca_feat_sizes: [32, 32] + ca_rope_k_repeat: True + ca_embedding_dim: 256 + ca_num_heads: 1 + ca_downsample_rate: 1 + ca_dropout: 0.1 + ca_kv_in_dim: 64 memory_encoder: _target_: sam2.modeling.memory_encoder.MemoryEncoder out_dim: 64 - dtype: float16 position_encoding: _target_: sam2.modeling.position_encoding.PositionEmbeddingSine num_pos_feats: 64 @@ -81,10 +73,8 @@ model: kernel_size: 3 stride: 2 padding: 1 - dtype: float16 fuser: _target_: sam2.modeling.memory_encoder.Fuser - dtype: float16 layer: _target_: sam2.modeling.memory_encoder.CXBlock dim: 256 @@ -92,7 +82,6 @@ model: padding: 3 layer_scale_init_value: 1e-6 use_dwconv: True # depth-wise convs - dtype: float16 num_layers: 2 num_maskmem: 7 diff --git a/tripy/examples/segment-anything-model-v2/demo_utils.py b/tripy/examples/segment-anything-model-v2/demo_utils.py new file mode 100644 index 000000000..4cde953d6 --- /dev/null +++ b/tripy/examples/segment-anything-model-v2/demo_utils.py @@ -0,0 +1,99 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +# Not a contribution +# Changes made by NVIDIA CORPORATION & AFFILIATES enabling SAM2 with Tripy or otherwise documented as +# NVIDIA-proprietary are not a contribution and subject to the following terms and conditions: +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import cv2 +import numpy as np +import matplotlib.pyplot as plt + +plt.switch_backend("agg") # Switch to non-interactive backend +from typing import Tuple, Optional + + +def process_and_show_mask( + mask: np.ndarray, ax: plt.Axes, obj_id: Optional[int] = None, random_color: bool = False, borders: bool = False +) -> np.ndarray: + # Generate mask color + if random_color: + color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0) + else: + if obj_id is not None: + cmap = plt.get_cmap("tab10") + color = np.array([*cmap(obj_id)[:3], 0.6]) + else: + color = np.array([30 / 255, 144 / 255, 255 / 255, 0.6]) + + # Process mask + h, w = mask.shape[-2:] + mask = mask.astype(np.uint8) + mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1) + + # Add borders if requested + if borders: + contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE) + contours = [cv2.approxPolyDP(contour, epsilon=0.01, closed=True) for contour in contours] + mask_image = cv2.drawContours(mask_image, contours, -1, (1, 1, 1, 0.5), thickness=2) + + ax.imshow(mask_image) + return mask_image + + +def show_points( + coords: np.ndarray, labels: np.ndarray, ax: plt.Axes, marker_size: int = 375 +) -> Tuple[np.ndarray, np.ndarray]: + """ + Display point prompts and return point coordinates for testing. + """ + pos_points = coords[labels == 1] + neg_points = coords[labels == 0] + + ax.scatter( + pos_points[:, 0], + pos_points[:, 1], + color="green", + marker="*", + s=marker_size, + edgecolor="white", + linewidth=1.25, + ) + ax.scatter( + neg_points[:, 0], + neg_points[:, 1], + color="red", + marker="*", + s=marker_size, + edgecolor="white", + linewidth=1.25, + ) + + return pos_points, neg_points + + +def show_box(box: np.ndarray, ax: plt.Axes) -> Tuple[float, float, float, float]: + """ + Display a bounding box and return its coordinates for testing. + """ + x0, y0 = box[0], box[1] + w, h = box[2] - box[0], box[3] - box[1] + ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor="green", facecolor=(0, 0, 0, 0), lw=2)) + return x0, y0, w, h diff --git a/tripy/examples/segment-anything-model-v2/download_test_data.py b/tripy/examples/segment-anything-model-v2/download_test_data.py new file mode 100644 index 000000000..12c62c075 --- /dev/null +++ b/tripy/examples/segment-anything-model-v2/download_test_data.py @@ -0,0 +1,76 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
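+# This script fetches the test assets used by the demos: the single `truck.jpg`
+# image for `image_demo.py` and the 200-frame bedroom clip (00000.jpg-00199.jpg)
+# for `video_demo.py`. Every download is retried on failure and each file is
+# verified to be a valid JPEG before it is kept.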
+import os +import time +import requests +from pathlib import Path +from PIL import Image +import io + + +def verify_jpeg(filepath): + try: + with Image.open(filepath) as img: + img.verify() + return True + except: + return False + + +def download_file(url, filepath, max_retries=5, timeout=5): + """Download file with retries and timeout.""" + for attempt in range(max_retries): + try: + print(f"Downloading {filepath}...") + response = requests.get(url, timeout=timeout) + response.raise_for_status() + + with open(filepath, "wb") as f: + f.write(response.content) + + # Verify if it's a valid JPEG + if verify_jpeg(filepath): + return True + else: + print(f"Invalid JPEG file {filepath}") + os.remove(filepath) # Remove invalid file + + except (requests.exceptions.RequestException, IOError) as e: + print(f"Error downloading {filepath} (attempt {attempt + 1}/{max_retries}): {str(e)}") + time.sleep(0.01) + continue + print(f"Failed to download {filepath} after {max_retries} attempts") + return False + + +def main(): + # Download test image for image segmentation + truck_url = "https://raw.githubusercontent.com/facebookresearch/sam2/main/notebooks/images/truck.jpg" + download_file(truck_url, "truck.jpg") + + # Create bedroom directory if it doesn't exist + bedroom_dir = Path("bedroom") + bedroom_dir.mkdir(exist_ok=True) + + # Download images for video segmentation + base_url = "https://raw.githubusercontent.com/facebookresearch/sam2/main/notebooks/videos/bedroom/{:05d}.jpg" + + for i in range(200): # 0 to 199 + filepath = bedroom_dir / f"{i:05d}.jpg" + download_file(base_url.format(i), filepath) + + +if __name__ == "__main__": + main() diff --git a/tripy/examples/segment-anything-model-v2/image_demo.py b/tripy/examples/segment-anything-model-v2/image_demo.py index 43107c833..57be089cb 100644 --- a/tripy/examples/segment-anything-model-v2/image_demo.py +++ b/tripy/examples/segment-anything-model-v2/image_demo.py @@ -24,82 +24,16 @@ plt.switch_backend("agg") # Switch to non-interactive backend from PIL import Image -from typing import Tuple, Optional, Dict +from typing import Optional, Dict from sam2.build_sam import build_sam2 from sam2.sam2_image_predictor import SAM2ImagePredictor +from demo_utils import process_and_show_mask, show_box, show_points parser = argparse.ArgumentParser() parser.add_argument("-b", "--batch", type=int, default=2, help="batch size of the input images, between [1, 4]") -def process_and_show_mask( - mask: np.ndarray, ax: plt.Axes, random_color: bool = False, borders: bool = True -) -> np.ndarray: - """ - Process and display a segmentation mask, returning the processed mask for testing. - """ - # Generate mask color - if random_color: - color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0) - else: - color = np.array([30 / 255, 144 / 255, 255 / 255, 0.6]) - - # Process mask - h, w = mask.shape[-2:] - mask = mask.astype(np.uint8) - mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1) - - if borders: - contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE) - contours = [cv2.approxPolyDP(contour, epsilon=0.01, closed=True) for contour in contours] - mask_image = cv2.drawContours(mask_image, contours, -1, (1, 1, 1, 0.5), thickness=2) - - ax.imshow(mask_image) - return mask_image - - -def show_points( - coords: np.ndarray, labels: np.ndarray, ax: plt.Axes, marker_size: int = 375 -) -> Tuple[np.ndarray, np.ndarray]: - """ - Display point prompts and return point coordinates for testing. 
- """ - pos_points = coords[labels == 1] - neg_points = coords[labels == 0] - - ax.scatter( - pos_points[:, 0], - pos_points[:, 1], - color="green", - marker="*", - s=marker_size, - edgecolor="white", - linewidth=1.25, - ) - ax.scatter( - neg_points[:, 0], - neg_points[:, 1], - color="red", - marker="*", - s=marker_size, - edgecolor="white", - linewidth=1.25, - ) - - return pos_points, neg_points - - -def show_box(box: np.ndarray, ax: plt.Axes) -> Tuple[float, float, float, float]: - """ - Display a bounding box and return its coordinates for testing. - """ - x0, y0 = box[0], box[1] - w, h = box[2] - box[0], box[3] - box[1] - ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor="green", facecolor=(0, 0, 0, 0), lw=2)) - return x0, y0, w, h - - def process_predictions( image: np.ndarray, masks: np.ndarray, diff --git a/tripy/examples/segment-anything-model-v2/sam2/build_sam.py b/tripy/examples/segment-anything-model-v2/sam2/build_sam.py index 561d5fa7c..27ad7ca64 100644 --- a/tripy/examples/segment-anything-model-v2/sam2/build_sam.py +++ b/tripy/examples/segment-anything-model-v2/sam2/build_sam.py @@ -65,19 +65,19 @@ def get_component_configs(model, cfg): "dtype": model_precision, "compile_args": [ tp.InputInfo( - (4096, 1, 256), + (4096, (1, 2, 8), 256), getattr(tp, model_precision), ), tp.InputInfo( - ((4100, 16400, 28736), 1, 64), + ((4100, 16400, 28736), (1, 2, 8), 64), getattr(tp, model_precision), ), tp.InputInfo( - (4096, 1, 256), + (4096, (1, 2, 8), 256), getattr(tp, model_precision), ), tp.InputInfo( - ((4100, 16400, 28736), 1, 64), + ((4100, 16400, 28736), (1, 2, 8), 64), getattr(tp, model_precision), ), tp.InputInfo(((4, 16, 64),), tp.int32), @@ -182,10 +182,10 @@ def get_component_configs(model, cfg): "memory_encoder": { "enabled": True, "model": model.memory_encoder, - "dtype": model_precision, # TODO add fp16 to yaml + "dtype": "float32", # TODO add fp16 to yaml "compile_args": [ - tp.InputInfo((1, 256, 64, 64), getattr(tp, model_precision)), - tp.InputInfo((1, 1, 1024, 1024), getattr(tp, model_precision)), + tp.InputInfo((batchsize, 256, 64, 64), tp.float32), + tp.InputInfo((batchsize, num_obj, 1024, 1024), tp.float32), True, ], "skip_dtype_convert": ["ln", "norm"] diff --git a/tripy/examples/segment-anything-model-v2/sam2/modeling/memory_attention.py b/tripy/examples/segment-anything-model-v2/sam2/modeling/memory_attention.py index 6c0466444..1a84deae0 100644 --- a/tripy/examples/segment-anything-model-v2/sam2/modeling/memory_attention.py +++ b/tripy/examples/segment-anything-model-v2/sam2/modeling/memory_attention.py @@ -22,7 +22,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional +from typing import Optional, List from sam2.modeling.sam.transformer import RoPEAttention from sam2.modeling.sam2_utils import get_activation_fn @@ -117,19 +117,72 @@ def __init__( self, d_model: int, pos_enc_at_input: bool, - layer: tp.Module, num_layers: int, - batch_first: bool = True, # Do layers expect batch first input? 
+ activation: str, + dim_feedforward: int, + dropout: float, + pos_enc_at_attn: bool, + pos_enc_at_cross_attn_keys: bool, + pos_enc_at_cross_attn_queries: bool, + sa_rope_theta: float, + sa_feat_sizes: List[int], + sa_embedding_dim: int, + sa_num_heads: int, + sa_downsample_rate: int, + sa_dropout: float, + ca_rope_theta: float, + ca_feat_sizes: List[int], + ca_rope_k_repeat: bool, + ca_embedding_dim: int, + ca_num_heads: int, + ca_downsample_rate: int, + ca_dropout: float, + ca_kv_in_dim: int, + batch_first: bool = True, dtype="float32", ): super().__init__() self.d_model = d_model - self.layers = [layer for i in range(num_layers)] self.num_layers = num_layers self.norm = tp.LayerNorm(d_model) self.pos_enc_at_input = pos_enc_at_input self.batch_first = batch_first self.dtype = getattr(tp, dtype) + self.layers = [] + for _ in range(num_layers): + self_attn = RoPEAttention( + sa_embedding_dim, + sa_num_heads, + sa_downsample_rate, + sa_dropout, + rope_theta=sa_rope_theta, + feat_sizes=sa_feat_sizes, + dtype=dtype, + ) + cross_attn = RoPEAttention( + ca_embedding_dim, + ca_num_heads, + ca_downsample_rate, + ca_dropout, + ca_kv_in_dim, + rope_theta=ca_rope_theta, + rope_k_repeat=ca_rope_k_repeat, + feat_sizes=ca_feat_sizes, + dtype=dtype, + ) + memory_attn_layer = MemoryAttentionLayer( + activation=activation, + cross_attention=cross_attn, + d_model=d_model, + dim_feedforward=dim_feedforward, + dropout=dropout, + pos_enc_at_attn=pos_enc_at_attn, + pos_enc_at_cross_attn_keys=pos_enc_at_cross_attn_keys, + pos_enc_at_cross_attn_queries=pos_enc_at_cross_attn_queries, + self_attention=self_attn, + dtype=dtype, + ) + self.layers.append(memory_attn_layer) def __call__( self, diff --git a/tripy/examples/segment-anything-model-v2/sam2/modeling/sam/mask_decoder.py b/tripy/examples/segment-anything-model-v2/sam2/modeling/sam/mask_decoder.py index b125e9522..ccec4d272 100755 --- a/tripy/examples/segment-anything-model-v2/sam2/modeling/sam/mask_decoder.py +++ b/tripy/examples/segment-anything-model-v2/sam2/modeling/sam/mask_decoder.py @@ -349,8 +349,24 @@ def _dynamic_multimask_via_stability(self, all_mask_logits, all_iou_scores): batch_inds = tp.arange(multimask_iou_scores.shape[0], dtype=self.dtype) batch_inds = tp.cast(batch_inds, tp.int32) - best_multimask_logits = multimask_logits[batch_inds, best_scores_inds] - best_multimask_iou_scores = multimask_iou_scores[batch_inds, best_scores_inds] + def indexing(tensor, first_index, second_index): + step1 = tp.gather(tensor, dim=0, index=first_index) + + batch_size = first_index.shape[0] + row_indices = tp.arange(batch_size, dtype=tp.int32) + + combined_indices = tp.stack([row_indices, second_index], dim=1) + + flattened = tp.flatten(step1) + + flat_indices = combined_indices[:, 0] * batch_size + combined_indices[:, 1] + + result = tp.gather(flattened, dim=0, index=flat_indices) + + return result + + best_multimask_logits = indexing(multimask_logits, batch_inds, best_scores_inds) + best_multimask_iou_scores = indexing(multimask_iou_scores, batch_inds, best_scores_inds) best_multimask_logits = tp.unsqueeze(best_multimask_logits, 1) best_multimask_iou_scores = tp.unsqueeze(best_multimask_iou_scores, 1) diff --git a/tripy/examples/segment-anything-model-v2/sam2/modeling/sam/transformer.py b/tripy/examples/segment-anything-model-v2/sam2/modeling/sam/transformer.py index 31f74cdf7..5c183365c 100755 --- a/tripy/examples/segment-anything-model-v2/sam2/modeling/sam/transformer.py +++ b/tripy/examples/segment-anything-model-v2/sam2/modeling/sam/transformer.py @@ 
-316,8 +316,6 @@ def __init__( self.dtype = getattr(tp, dtype) super().__init__(*args, dtype=self.dtype, **kwargs) self.compute_cis = partial(compute_axial_cis, dim=self.internal_dim // self.num_heads, theta=rope_theta) - freqs_cis = self.compute_cis(end_x=feat_sizes[0], end_y=feat_sizes[1]) - self.freqs_cis = freqs_cis self.rope_k_repeat = rope_k_repeat def __call__(self, q: Tensor, k: Tensor, v: Tensor, num_k_exclude_rope: tp.Tensor) -> Tensor: @@ -337,8 +335,8 @@ def forward(self, q: tp.Tensor, k: tp.Tensor, v: tp.Tensor, num_k_exclude_rope: # Apply rotary position encoding # w = h = tp.DimensionSize(tp.cast(tp.sqrt(tp.cast(q.shape[-2], tp.float32)), tp.int32)) # DDS? w = h = tp.DimensionSize(64) # Current demo always uses 64. - self.freqs_cis = self.compute_cis(end_x=w, end_y=h) - self.freqs_cis = tp.cast(self.freqs_cis, self.dtype) + freqs_cis = self.compute_cis(end_x=w, end_y=h) + self.freqs_cis = tp.cast(freqs_cis, self.dtype) num_k_rope = k.shape[-2] - num_k_exclude_rope q, new_k = apply_rotary_enc( diff --git a/tripy/examples/segment-anything-model-v2/sam2/modeling/sam2_base.py b/tripy/examples/segment-anything-model-v2/sam2/modeling/sam2_base.py index b1bde83eb..b5b4bb78c 100644 --- a/tripy/examples/segment-anything-model-v2/sam2/modeling/sam2_base.py +++ b/tripy/examples/segment-anything-model-v2/sam2/modeling/sam2_base.py @@ -305,6 +305,9 @@ def _forward_sam_heads( sam_point_coords = torch.zeros(B, 1, 2, device=device) sam_point_labels = -torch.ones(B, 1, dtype=torch.int32, device=device) + # b) Handle mask prompts + # Issue #445 will add mask_input support. + sam_point_coords = tp.Tensor(sam_point_coords.contiguous()) sam_point_labels = tp.Tensor(sam_point_labels.contiguous()) @@ -315,12 +318,12 @@ def _forward_sam_heads( self.dense_pe = self.sam_prompt_encoder.get_dense_pe() hres_1 = high_res_features[0] hres_2 = high_res_features[1] - if self.model.model_dtype == tp.float16: - image_embedding = image_embedding.half() + if self.model_dtype == tp.float16: + image_embedding = backbone_features.half() hres_1 = hres_1.half() hres_2 = hres_2.half() - tp_backbone_features = tp.Tensor(backbone_features.contiguous()) + tp_backbone_features = tp.Tensor(image_embedding.contiguous()) hres_1 = tp.Tensor(hres_1.contiguous()) hres_2 = tp.Tensor(hres_2.contiguous()) @@ -469,7 +472,9 @@ def _use_mask_as_output(self, backbone_features, high_res_features, mask_inputs) def forward_image(self, img_batch: tp.Tensor): """Get the image feature on the input batch.""" - + if not isinstance(img_batch, tp.Tensor): + img_batch = img_batch.to(getattr(torch, self.image_encoder.trunk.dtype)).contiguous() + img_batch = tp.Tensor(img_batch) backbone_out = self.image_encoder(img_batch) if self.use_high_res_features_in_sam: @@ -713,7 +718,7 @@ def _encode_new_memory( mask_for_mem = mask_for_mem + self.sigmoid_bias_for_mem_enc maskmem_features, maskmem_pos_enc = self.memory_encoder( - tp.Tensor(pix_feat.contiguous()), tp.Tensor(mask_for_mem.contiguous()) + tp.Tensor(pix_feat.float().contiguous()), tp.Tensor(mask_for_mem.contiguous()) ) # sigmoid already applied maskmem_features = torch.from_dlpack(maskmem_features) maskmem_pos_enc = [torch.from_dlpack(maskmem_pos_enc)] @@ -798,7 +803,6 @@ def track_step( current_out["pred_masks"] = low_res_masks current_out["pred_masks_high_res"] = high_res_masks current_out["obj_ptr"] = obj_ptr - # Finally run the memory encoder on the predicted mask to encode # it into a new memory feature (that can be used in future frames) if run_mem_encoder and self.num_maskmem > 0: diff 
--git a/tripy/examples/segment-anything-model-v2/sam2/modeling/sam2_utils.py b/tripy/examples/segment-anything-model-v2/sam2/modeling/sam2_utils.py index 1f7c58bf4..bf534853f 100755 --- a/tripy/examples/segment-anything-model-v2/sam2/modeling/sam2_utils.py +++ b/tripy/examples/segment-anything-model-v2/sam2/modeling/sam2_utils.py @@ -22,7 +22,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import copy from typing import Optional, Callable import torch @@ -214,10 +213,6 @@ def get_activation_fn(activation): raise RuntimeError(f"activation should be relu/gelu, not {activation}.") -def get_clones(module, N): - return [copy.deepcopy(module) for _ in range(N)] - - def cartesian_via_polar(abs, angles): r""" Constructs the real-valued cartesian coordinates from magnitude and angle representing polar coordinates. For input diff --git a/tripy/examples/segment-anything-model-v2/sam2/sam2_video_predictor.py b/tripy/examples/segment-anything-model-v2/sam2/sam2_video_predictor.py new file mode 100644 index 000000000..77e73dd35 --- /dev/null +++ b/tripy/examples/segment-anything-model-v2/sam2/sam2_video_predictor.py @@ -0,0 +1,948 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +# Not a contribution +# Changes made by NVIDIA CORPORATION & AFFILIATES enabling SAM2 with Tripy or otherwise documented as +# NVIDIA-proprietary are not a contribution and subject to the following terms and conditions: +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import warnings +from collections import OrderedDict + +import torch + +from tqdm import tqdm + +from sam2.modeling.sam2_base import NO_OBJ_SCORE, SAM2Base +from sam2.utils.misc import concat_points, fill_holes_in_mask_scores, load_video_frames + + +class SAM2VideoPredictor(SAM2Base): + """The predictor class to handle user interactions and manage inference states.""" + + def __init__( + self, + fill_hole_area=0, + # whether to apply non-overlapping constraints on the output object masks + non_overlap_masks=False, + # whether to clear non-conditioning memory of the surrounding frames (which may contain outdated information) after adding correction clicks; + # note that this would only apply to *single-object tracking* unless `clear_non_cond_mem_for_multi_obj` is also set to True) + clear_non_cond_mem_around_input=False, + # whether to also clear non-conditioning memory of the surrounding frames (only effective when `clear_non_cond_mem_around_input` is True). 
+ clear_non_cond_mem_for_multi_obj=False, + **kwargs, + ): + super().__init__(**kwargs) + self.fill_hole_area = fill_hole_area + self.non_overlap_masks = non_overlap_masks + self.clear_non_cond_mem_around_input = clear_non_cond_mem_around_input + self.clear_non_cond_mem_for_multi_obj = clear_non_cond_mem_for_multi_obj + + @torch.inference_mode() + def init_state( + self, + video_path, + offload_video_to_cpu=False, + offload_state_to_cpu=False, + async_loading_frames=False, + ): + """Initialize an inference state.""" + compute_device = torch.device("cuda") + images, video_height, video_width = load_video_frames( + video_path=video_path, + image_size=self.image_size, + offload_video_to_cpu=offload_video_to_cpu, + async_loading_frames=async_loading_frames, + compute_device=compute_device, + ) + inference_state = {} + inference_state["images"] = images + inference_state["num_frames"] = len(images) + # whether to offload the video frames to CPU memory + # turning on this option saves the GPU memory with only a very small overhead + inference_state["offload_video_to_cpu"] = offload_video_to_cpu + # whether to offload the inference state to CPU memory + # turning on this option saves the GPU memory at the cost of a lower tracking fps + # (e.g. in a test case of 768x768 model, fps dropped from 27 to 24 when tracking one object + # and from 24 to 21 when tracking two objects) + inference_state["offload_state_to_cpu"] = offload_state_to_cpu + # the original video height and width, used for resizing final output scores + inference_state["video_height"] = video_height + inference_state["video_width"] = video_width + inference_state["device"] = compute_device + if offload_state_to_cpu: + inference_state["storage_device"] = torch.device("cpu") + else: + inference_state["storage_device"] = compute_device + # inputs on each frame + inference_state["point_inputs_per_obj"] = {} + inference_state["mask_inputs_per_obj"] = {} + # visual features on a small number of recently visited frames for quick interactions + inference_state["cached_features"] = {} + # values that don't change across frames (so we only need to hold one copy of them) + inference_state["constants"] = {} + # mapping between client-side object id and model-side object index + inference_state["obj_id_to_idx"] = OrderedDict() + inference_state["obj_idx_to_id"] = OrderedDict() + inference_state["obj_ids"] = [] + # A storage to hold the model's tracking results and states on each frame + inference_state["output_dict"] = { + "cond_frame_outputs": {}, # dict containing {frame_idx: } + "non_cond_frame_outputs": {}, # dict containing {frame_idx: } + } + # Slice (view) of each object tracking results, sharing the same memory with "output_dict" + inference_state["output_dict_per_obj"] = {} + # A temporary storage to hold new outputs when user interact with a frame + # to add clicks or mask (it's merged into "output_dict" before propagation starts) + inference_state["temp_output_dict_per_obj"] = {} + # Frames that already holds consolidated outputs from click or mask inputs + # (we directly use their consolidated outputs during tracking) + inference_state["consolidated_frame_inds"] = { + "cond_frame_outputs": set(), # set containing frame indices + "non_cond_frame_outputs": set(), # set containing frame indices + } + # metadata for each tracking frame (e.g. 
which direction it's tracked) + inference_state["tracking_has_started"] = False + inference_state["frames_already_tracked"] = {} + # Warm up the visual backbone and cache the image feature on frame 0 + self._get_image_feature(inference_state, frame_idx=0, batch_size=1) + return inference_state + + @classmethod + def from_pretrained(cls, model_id: str, **kwargs) -> "SAM2VideoPredictor": + """ + Load a pretrained model from the Hugging Face hub. + + Arguments: + model_id (str): The Hugging Face repository ID. + **kwargs: Additional arguments to pass to the model constructor. + + Returns: + (SAM2VideoPredictor): The loaded model. + """ + from sam2.build_sam import build_sam2_video_predictor_hf + + sam_model = build_sam2_video_predictor_hf(model_id, **kwargs) + return sam_model + + def _obj_id_to_idx(self, inference_state, obj_id): + """Map client-side object id to model-side object index.""" + obj_idx = inference_state["obj_id_to_idx"].get(obj_id, None) + if obj_idx is not None: + return obj_idx + + # This is a new object id not sent to the server before. We only allow adding + # new objects *before* the tracking starts. + allow_new_object = not inference_state["tracking_has_started"] + if allow_new_object: + # get the next object slot + obj_idx = len(inference_state["obj_id_to_idx"]) + inference_state["obj_id_to_idx"][obj_id] = obj_idx + inference_state["obj_idx_to_id"][obj_idx] = obj_id + inference_state["obj_ids"] = list(inference_state["obj_id_to_idx"]) + # set up input and output structures for this object + inference_state["point_inputs_per_obj"][obj_idx] = {} + inference_state["mask_inputs_per_obj"][obj_idx] = {} + inference_state["output_dict_per_obj"][obj_idx] = { + "cond_frame_outputs": {}, # dict containing {frame_idx: } + "non_cond_frame_outputs": {}, # dict containing {frame_idx: } + } + inference_state["temp_output_dict_per_obj"][obj_idx] = { + "cond_frame_outputs": {}, # dict containing {frame_idx: } + "non_cond_frame_outputs": {}, # dict containing {frame_idx: } + } + return obj_idx + else: + raise RuntimeError( + f"Cannot add new object id {obj_id} after tracking starts. " + f"All existing object ids: {inference_state['obj_ids']}. " + f"Please call 'reset_state' to restart from scratch." 
+ ) + + def _obj_idx_to_id(self, inference_state, obj_idx): + """Map model-side object index to client-side object id.""" + return inference_state["obj_idx_to_id"][obj_idx] + + def _get_obj_num(self, inference_state): + """Get the total number of unique object ids received so far in this session.""" + return len(inference_state["obj_idx_to_id"]) + + @torch.inference_mode() + def add_new_points_or_box( + self, + inference_state, + frame_idx, + obj_id, + points=None, + labels=None, + clear_old_points=True, + normalize_coords=True, + box=None, + ): + """Add new points to a frame.""" + obj_idx = self._obj_id_to_idx(inference_state, obj_id) + point_inputs_per_frame = inference_state["point_inputs_per_obj"][obj_idx] + mask_inputs_per_frame = inference_state["mask_inputs_per_obj"][obj_idx] + + if (points is not None) != (labels is not None): + raise ValueError("points and labels must be provided together") + if points is None and box is None: + raise ValueError("at least one of points or box must be provided as input") + + if points is None: + points = torch.zeros(0, 2, dtype=torch.float32) + elif not isinstance(points, torch.Tensor): + points = torch.tensor(points, dtype=torch.float32) + if labels is None: + labels = torch.zeros(0, dtype=torch.int32) + elif not isinstance(labels, torch.Tensor): + labels = torch.tensor(labels, dtype=torch.int32) + if points.dim() == 2: + points = points.unsqueeze(0) # add batch dimension + if labels.dim() == 1: + labels = labels.unsqueeze(0) # add batch dimension + + # If `box` is provided, we add it as the first two points with labels 2 and 3 + # along with the user-provided points (consistent with how SAM 2 is trained). + if box is not None: + if not clear_old_points: + raise ValueError( + "cannot add box without clearing old points, since " + "box prompt must be provided before any point prompt " + "(please use clear_old_points=True instead)" + ) + if inference_state["tracking_has_started"]: + warnings.warn( + "You are adding a box after tracking starts. SAM 2 may not always be " + "able to incorporate a box prompt for *refinement*. 
If you intend to " + "use box prompt as an *initial* input before tracking, please call " + "'reset_state' on the inference state to restart from scratch.", + category=UserWarning, + stacklevel=2, + ) + if not isinstance(box, torch.Tensor): + box = torch.tensor(box, dtype=torch.float32, device=points.device) + box_coords = box.reshape(1, 2, 2) + box_labels = torch.tensor([2, 3], dtype=torch.int32, device=labels.device) + box_labels = box_labels.reshape(1, 2) + points = torch.cat([box_coords, points], dim=1) + labels = torch.cat([box_labels, labels], dim=1) + + if normalize_coords: + video_H = inference_state["video_height"] + video_W = inference_state["video_width"] + points = points / torch.tensor([video_W, video_H]).to(points.device) + # scale the (normalized) coordinates by the model's internal image size + points = points * self.image_size + points = points.to(inference_state["device"]) + labels = labels.to(inference_state["device"]) + + if not clear_old_points: + point_inputs = point_inputs_per_frame.get(frame_idx, None) + else: + point_inputs = None + point_inputs = concat_points(point_inputs, points, labels) + + point_inputs_per_frame[frame_idx] = point_inputs + mask_inputs_per_frame.pop(frame_idx, None) + # If this frame hasn't been tracked before, we treat it as an initial conditioning + # frame, meaning that the inputs points are to generate segments on this frame without + # using any memory from other frames, like in SAM. Otherwise (if it has been tracked), + # the input points will be used to correct the already tracked masks. + is_init_cond_frame = frame_idx not in inference_state["frames_already_tracked"] + # whether to track in reverse time order + if is_init_cond_frame: + reverse = False + else: + reverse = inference_state["frames_already_tracked"][frame_idx]["reverse"] + obj_output_dict = inference_state["output_dict_per_obj"][obj_idx] + obj_temp_output_dict = inference_state["temp_output_dict_per_obj"][obj_idx] + # Add a frame to conditioning output if it's an initial conditioning frame or + # if the model sees all frames receiving clicks/mask as conditioning frames. + is_cond = is_init_cond_frame or self.add_all_frames_to_correct_as_cond + storage_key = "cond_frame_outputs" if is_cond else "non_cond_frame_outputs" + + # Get any previously predicted mask logits on this object and feed it along with + # the new clicks into the SAM mask decoder. + prev_sam_mask_logits = None + # lookup temporary output dict first, which contains the most recent output + # (if not found, then lookup conditioning and non-conditioning frame output) + prev_out = obj_temp_output_dict[storage_key].get(frame_idx) + if prev_out is None: + prev_out = obj_output_dict["cond_frame_outputs"].get(frame_idx) + if prev_out is None: + prev_out = obj_output_dict["non_cond_frame_outputs"].get(frame_idx) + + if prev_out is not None and prev_out["pred_masks"] is not None: + device = inference_state["device"] + prev_sam_mask_logits = prev_out["pred_masks"].to(device, non_blocking=True) + # Clamp the scale of prev_sam_mask_logits to avoid rare numerical issues. 
+ prev_sam_mask_logits = torch.clamp(prev_sam_mask_logits, -32.0, 32.0) + current_out, _ = self._run_single_frame_inference( + inference_state=inference_state, + output_dict=obj_output_dict, # run on the slice of a single object + frame_idx=frame_idx, + batch_size=1, # run on the slice of a single object + is_init_cond_frame=is_init_cond_frame, + point_inputs=point_inputs, + mask_inputs=None, + reverse=reverse, + # Skip the memory encoder when adding clicks or mask. We execute the memory encoder + # at the beginning of `propagate_in_video` (after user finalize their clicks). This + # allows us to enforce non-overlapping constraints on all objects before encoding + # them into memory. + run_mem_encoder=False, + prev_sam_mask_logits=prev_sam_mask_logits, + ) + # Add the output to the output dict (to be used as future memory) + obj_temp_output_dict[storage_key][frame_idx] = current_out + + # Resize the output mask to the original video resolution + obj_ids = inference_state["obj_ids"] + consolidated_out = self._consolidate_temp_output_across_obj( + inference_state, + frame_idx, + is_cond=is_cond, + run_mem_encoder=False, + consolidate_at_video_res=True, + ) + _, video_res_masks = self._get_orig_video_res_output(inference_state, consolidated_out["pred_masks_video_res"]) + return frame_idx, obj_ids, video_res_masks + + def add_new_points(self, *args, **kwargs): + """Deprecated method. Please use `add_new_points_or_box` instead.""" + return self.add_new_points_or_box(*args, **kwargs) + + @torch.inference_mode() + def add_new_mask( + self, + inference_state, + frame_idx, + obj_id, + mask, + ): + """Add new mask to a frame.""" + obj_idx = self._obj_id_to_idx(inference_state, obj_id) + point_inputs_per_frame = inference_state["point_inputs_per_obj"][obj_idx] + mask_inputs_per_frame = inference_state["mask_inputs_per_obj"][obj_idx] + + if not isinstance(mask, torch.Tensor): + mask = torch.tensor(mask, dtype=torch.bool) + assert mask.dim() == 2 + mask_H, mask_W = mask.shape + mask_inputs_orig = mask[None, None] # add batch and channel dimension + mask_inputs_orig = mask_inputs_orig.float().to(inference_state["device"]) + + # resize the mask if it doesn't match the model's image size + if mask_H != self.image_size or mask_W != self.image_size: + mask_inputs = torch.nn.functional.interpolate( + mask_inputs_orig, + size=(self.image_size, self.image_size), + align_corners=False, + mode="bilinear", + antialias=True, # use antialias for downsampling + ) + mask_inputs = (mask_inputs >= 0.5).float() + else: + mask_inputs = mask_inputs_orig + + mask_inputs_per_frame[frame_idx] = mask_inputs + point_inputs_per_frame.pop(frame_idx, None) + # If this frame hasn't been tracked before, we treat it as an initial conditioning + # frame, meaning that the inputs points are to generate segments on this frame without + # using any memory from other frames, like in SAM. Otherwise (if it has been tracked), + # the input points will be used to correct the already tracked masks. 
+ is_init_cond_frame = frame_idx not in inference_state["frames_already_tracked"] + # whether to track in reverse time order + if is_init_cond_frame: + reverse = False + else: + reverse = inference_state["frames_already_tracked"][frame_idx]["reverse"] + obj_output_dict = inference_state["output_dict_per_obj"][obj_idx] + obj_temp_output_dict = inference_state["temp_output_dict_per_obj"][obj_idx] + # Add a frame to conditioning output if it's an initial conditioning frame or + # if the model sees all frames receiving clicks/mask as conditioning frames. + is_cond = is_init_cond_frame or self.add_all_frames_to_correct_as_cond + storage_key = "cond_frame_outputs" if is_cond else "non_cond_frame_outputs" + + current_out, _ = self._run_single_frame_inference( + inference_state=inference_state, + output_dict=obj_output_dict, # run on the slice of a single object + frame_idx=frame_idx, + batch_size=1, # run on the slice of a single object + is_init_cond_frame=is_init_cond_frame, + point_inputs=None, + mask_inputs=mask_inputs, + reverse=reverse, + # Skip the memory encoder when adding clicks or mask. We execute the memory encoder + # at the beginning of `propagate_in_video` (after user finalize their clicks). This + # allows us to enforce non-overlapping constraints on all objects before encoding + # them into memory. + run_mem_encoder=False, + ) + # Add the output to the output dict (to be used as future memory) + obj_temp_output_dict[storage_key][frame_idx] = current_out + + # Resize the output mask to the original video resolution + obj_ids = inference_state["obj_ids"] + consolidated_out = self._consolidate_temp_output_across_obj( + inference_state, + frame_idx, + is_cond=is_cond, + run_mem_encoder=False, + consolidate_at_video_res=True, + ) + _, video_res_masks = self._get_orig_video_res_output(inference_state, consolidated_out["pred_masks_video_res"]) + return frame_idx, obj_ids, video_res_masks + + def _get_orig_video_res_output(self, inference_state, any_res_masks): + """ + Resize the object scores to the original video resolution (video_res_masks) + and apply non-overlapping constraints for final output. + """ + device = inference_state["device"] + video_H = inference_state["video_height"] + video_W = inference_state["video_width"] + any_res_masks = any_res_masks.to(device, non_blocking=True) + if any_res_masks.shape[-2:] == (video_H, video_W): + video_res_masks = any_res_masks + else: + video_res_masks = torch.nn.functional.interpolate( + any_res_masks, + size=(video_H, video_W), + mode="bilinear", + align_corners=False, + ) + if self.non_overlap_masks: + video_res_masks = self._apply_non_overlapping_constraints(video_res_masks) + return any_res_masks, video_res_masks + + def _consolidate_temp_output_across_obj( + self, + inference_state, + frame_idx, + is_cond, + run_mem_encoder, + consolidate_at_video_res=False, + ): + """ + Consolidate the per-object temporary outputs in `temp_output_dict_per_obj` on + a frame into a single output for all objects, including + 1) fill any missing objects either from `output_dict_per_obj` (if they exist in + `output_dict_per_obj` for this frame) or leave them as placeholder values + (if they don't exist in `output_dict_per_obj` for this frame); + 2) if specified, rerun memory encoder after apply non-overlapping constraints + on the object scores. 
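+        The result is a single `consolidated_out` dict holding "maskmem_features",
+        "maskmem_pos_enc", the consolidated mask scores, and "obj_ptr" covering all objects.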
+ """ + batch_size = self._get_obj_num(inference_state) + storage_key = "cond_frame_outputs" if is_cond else "non_cond_frame_outputs" + # Optionally, we allow consolidating the temporary outputs at the original + # video resolution (to provide a better editing experience for mask prompts). + if consolidate_at_video_res: + assert not run_mem_encoder, "memory encoder cannot run at video resolution" + consolidated_H = inference_state["video_height"] + consolidated_W = inference_state["video_width"] + consolidated_mask_key = "pred_masks_video_res" + else: + consolidated_H = consolidated_W = self.image_size // 4 + consolidated_mask_key = "pred_masks" + + # Initialize `consolidated_out`. Its "maskmem_features" and "maskmem_pos_enc" + # will be added when rerunning the memory encoder after applying non-overlapping + # constraints to object scores. Its "pred_masks" are prefilled with a large + # negative value (NO_OBJ_SCORE) to represent missing objects. + consolidated_out = { + "maskmem_features": None, + "maskmem_pos_enc": None, + consolidated_mask_key: torch.full( + size=(batch_size, 1, consolidated_H, consolidated_W), + fill_value=NO_OBJ_SCORE, + dtype=torch.float32, + device=inference_state["storage_device"], + ), + "obj_ptr": torch.full( + size=(batch_size, self.hidden_dim), + fill_value=NO_OBJ_SCORE, + dtype=torch.float32, + device=inference_state["device"], + ), + } + empty_mask_ptr = None + for obj_idx in range(batch_size): + obj_temp_output_dict = inference_state["temp_output_dict_per_obj"][obj_idx] + obj_output_dict = inference_state["output_dict_per_obj"][obj_idx] + out = obj_temp_output_dict[storage_key].get(frame_idx, None) + # If the object doesn't appear in "temp_output_dict_per_obj" on this frame, + # we fall back and look up its previous output in "output_dict_per_obj". + # We look up both "cond_frame_outputs" and "non_cond_frame_outputs" in + # "output_dict_per_obj" to find a previous output for this object. + if out is None: + out = obj_output_dict["cond_frame_outputs"].get(frame_idx, None) + if out is None: + out = obj_output_dict["non_cond_frame_outputs"].get(frame_idx, None) + # If the object doesn't appear in "output_dict_per_obj" either, we skip it + # and leave its mask scores to the default scores (i.e. the NO_OBJ_SCORE + # placeholder above) and set its object pointer to be a dummy pointer. + if out is None: + # Fill in dummy object pointers for those objects without any inputs or + # tracking outcomes on this frame (only do it under `run_mem_encoder=True`, + # i.e. when we need to build the memory for tracking). 
+ if run_mem_encoder: + if empty_mask_ptr is None: + empty_mask_ptr = self._get_empty_mask_ptr(inference_state, frame_idx) + # fill object pointer with a dummy pointer (based on an empty mask) + consolidated_out["obj_ptr"][obj_idx : obj_idx + 1] = empty_mask_ptr + continue + # Add the temporary object output mask to consolidated output mask + obj_mask = out["pred_masks"] + consolidated_pred_masks = consolidated_out[consolidated_mask_key] + if obj_mask.shape[-2:] == consolidated_pred_masks.shape[-2:]: + consolidated_pred_masks[obj_idx : obj_idx + 1] = obj_mask + else: + # Resize first if temporary object mask has a different resolution + resized_obj_mask = torch.nn.functional.interpolate( + obj_mask, + size=consolidated_pred_masks.shape[-2:], + mode="bilinear", + align_corners=False, + ) + consolidated_pred_masks[obj_idx : obj_idx + 1] = resized_obj_mask + consolidated_out["obj_ptr"][obj_idx : obj_idx + 1] = out["obj_ptr"] + + # Optionally, apply non-overlapping constraints on the consolidated scores + # and rerun the memory encoder + if run_mem_encoder: + device = inference_state["device"] + high_res_masks = torch.nn.functional.interpolate( + consolidated_out["pred_masks"].to(device, non_blocking=True), + size=(self.image_size, self.image_size), + mode="bilinear", + align_corners=False, + ) + if self.non_overlap_masks_for_mem_enc: + high_res_masks = self._apply_non_overlapping_constraints(high_res_masks) + maskmem_features, maskmem_pos_enc = self._run_memory_encoder( + inference_state=inference_state, + frame_idx=frame_idx, + batch_size=batch_size, + high_res_masks=high_res_masks, + is_mask_from_pts=True, # these frames are what the user interacted with + ) + consolidated_out["maskmem_features"] = maskmem_features + consolidated_out["maskmem_pos_enc"] = maskmem_pos_enc + + return consolidated_out + + def _get_empty_mask_ptr(self, inference_state, frame_idx): + """Get a dummy object pointer based on an empty mask on the current frame.""" + # A dummy (empty) mask with a single object + batch_size = 1 + mask_inputs = torch.zeros( + (batch_size, 1, self.image_size, self.image_size), + dtype=torch.float32, + device=inference_state["device"], + ) + + # Retrieve correct image features + ( + _, + _, + current_vision_feats, + current_vision_pos_embeds, + feat_sizes, + ) = self._get_image_feature(inference_state, frame_idx, batch_size) + + # Feed the empty mask and image feature above to get a dummy object pointer + current_out = self.track_step( + frame_idx=frame_idx, + is_init_cond_frame=True, + current_vision_feats=current_vision_feats, + current_vision_pos_embeds=current_vision_pos_embeds, + feat_sizes=feat_sizes, + point_inputs=None, + mask_inputs=mask_inputs, + output_dict={}, + num_frames=inference_state["num_frames"], + track_in_reverse=False, + run_mem_encoder=False, + prev_sam_mask_logits=None, + ) + return current_out["obj_ptr"] + + @torch.inference_mode() + def propagate_in_video_preflight(self, inference_state): + """Prepare inference_state and consolidate temporary outputs before tracking.""" + # Tracking has started and we don't allow adding new objects until session is reset. + inference_state["tracking_has_started"] = True + batch_size = self._get_obj_num(inference_state) + + # Consolidate per-object temporary outputs in "temp_output_dict_per_obj" and + # add them into "output_dict". 
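+        # Consolidation below also re-encodes the memory features for every frame that
+        # received new clicks or masks, so tracking starts from a consistent state.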
+ temp_output_dict_per_obj = inference_state["temp_output_dict_per_obj"] + output_dict = inference_state["output_dict"] + # "consolidated_frame_inds" contains indices of those frames where consolidated + # temporary outputs have been added (either in this call or any previous calls + # to `propagate_in_video_preflight`). + consolidated_frame_inds = inference_state["consolidated_frame_inds"] + for is_cond in [False, True]: + # Separately consolidate conditioning and non-conditioning temp outputs + storage_key = "cond_frame_outputs" if is_cond else "non_cond_frame_outputs" + # Find all the frames that contain temporary outputs for any objects + # (these should be the frames that have just received clicks for mask inputs + # via `add_new_points_or_box` or `add_new_mask`) + temp_frame_inds = set() + for obj_temp_output_dict in temp_output_dict_per_obj.values(): + temp_frame_inds.update(obj_temp_output_dict[storage_key].keys()) + consolidated_frame_inds[storage_key].update(temp_frame_inds) + # consolidate the temporary output across all objects on this frame + for frame_idx in temp_frame_inds: + consolidated_out = self._consolidate_temp_output_across_obj( + inference_state, frame_idx, is_cond=is_cond, run_mem_encoder=True + ) + # merge them into "output_dict" and also create per-object slices + output_dict[storage_key][frame_idx] = consolidated_out + self._add_output_per_object(inference_state, frame_idx, consolidated_out, storage_key) + clear_non_cond_mem = self.clear_non_cond_mem_around_input and ( + self.clear_non_cond_mem_for_multi_obj or batch_size <= 1 + ) + if clear_non_cond_mem: + # clear non-conditioning memory of the surrounding frames + self._clear_non_cond_mem_around_input(inference_state, frame_idx) + + # clear temporary outputs in `temp_output_dict_per_obj` + for obj_temp_output_dict in temp_output_dict_per_obj.values(): + obj_temp_output_dict[storage_key].clear() + + # edge case: if an output is added to "cond_frame_outputs", we remove any prior + # output on the same frame in "non_cond_frame_outputs" + for frame_idx in output_dict["cond_frame_outputs"]: + output_dict["non_cond_frame_outputs"].pop(frame_idx, None) + for obj_output_dict in inference_state["output_dict_per_obj"].values(): + for frame_idx in obj_output_dict["cond_frame_outputs"]: + obj_output_dict["non_cond_frame_outputs"].pop(frame_idx, None) + for frame_idx in consolidated_frame_inds["cond_frame_outputs"]: + assert frame_idx in output_dict["cond_frame_outputs"] + consolidated_frame_inds["non_cond_frame_outputs"].discard(frame_idx) + + # Make sure that the frame indices in "consolidated_frame_inds" are exactly those frames + # with either points or mask inputs (which should be true under a correct workflow). 
+ all_consolidated_frame_inds = ( + consolidated_frame_inds["cond_frame_outputs"] | consolidated_frame_inds["non_cond_frame_outputs"] + ) + input_frames_inds = set() + for point_inputs_per_frame in inference_state["point_inputs_per_obj"].values(): + input_frames_inds.update(point_inputs_per_frame.keys()) + for mask_inputs_per_frame in inference_state["mask_inputs_per_obj"].values(): + input_frames_inds.update(mask_inputs_per_frame.keys()) + assert all_consolidated_frame_inds == input_frames_inds + + @torch.inference_mode() + def propagate_in_video( + self, + inference_state, + start_frame_idx=None, + max_frame_num_to_track=None, + reverse=False, + ): + """Propagate the input points across frames to track in the entire video.""" + self.propagate_in_video_preflight(inference_state) + + output_dict = inference_state["output_dict"] + consolidated_frame_inds = inference_state["consolidated_frame_inds"] + obj_ids = inference_state["obj_ids"] + num_frames = inference_state["num_frames"] + batch_size = self._get_obj_num(inference_state) + if len(output_dict["cond_frame_outputs"]) == 0: + raise RuntimeError("No points are provided; please add points first") + clear_non_cond_mem = self.clear_non_cond_mem_around_input and ( + self.clear_non_cond_mem_for_multi_obj or batch_size <= 1 + ) + + # set start index, end index, and processing order + if start_frame_idx is None: + # default: start from the earliest frame with input points + start_frame_idx = min(output_dict["cond_frame_outputs"]) + if max_frame_num_to_track is None: + # default: track all the frames in the video + max_frame_num_to_track = num_frames + if reverse: + end_frame_idx = max(start_frame_idx - max_frame_num_to_track, 0) + if start_frame_idx > 0: + processing_order = range(start_frame_idx, end_frame_idx - 1, -1) + else: + processing_order = [] # skip reverse tracking if starting from frame 0 + else: + end_frame_idx = min(start_frame_idx + max_frame_num_to_track, num_frames - 1) + processing_order = range(start_frame_idx, end_frame_idx + 1) + + for frame_idx in tqdm(processing_order, desc="propagate in video"): + # We skip those frames already in consolidated outputs (these are frames + # that received input clicks or mask). Note that we cannot directly run + # batched forward on them via `_run_single_frame_inference` because the + # number of clicks on each object might be different. + if frame_idx in consolidated_frame_inds["cond_frame_outputs"]: + storage_key = "cond_frame_outputs" + current_out = output_dict[storage_key][frame_idx] + pred_masks = current_out["pred_masks"] + if clear_non_cond_mem: + # clear non-conditioning memory of the surrounding frames + self._clear_non_cond_mem_around_input(inference_state, frame_idx) + elif frame_idx in consolidated_frame_inds["non_cond_frame_outputs"]: + storage_key = "non_cond_frame_outputs" + current_out = output_dict[storage_key][frame_idx] + pred_masks = current_out["pred_masks"] + else: + storage_key = "non_cond_frame_outputs" + current_out, pred_masks = self._run_single_frame_inference( + inference_state=inference_state, + output_dict=output_dict, + frame_idx=frame_idx, + batch_size=batch_size, + is_init_cond_frame=False, + point_inputs=None, + mask_inputs=None, + reverse=reverse, + run_mem_encoder=True, + ) + output_dict[storage_key][frame_idx] = current_out + # Create slices of per-object outputs for subsequent interaction with each + # individual object after tracking. 
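+            # (These per-object slices share tensor storage with `current_out`; see
+            # `_add_output_per_object` below.)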
+ self._add_output_per_object(inference_state, frame_idx, current_out, storage_key) + inference_state["frames_already_tracked"][frame_idx] = {"reverse": reverse} + + # Resize the output mask to the original video resolution (we directly use + # the mask scores on GPU for output to avoid any CPU conversion in between) + _, video_res_masks = self._get_orig_video_res_output(inference_state, pred_masks) + yield frame_idx, obj_ids, video_res_masks + + def _add_output_per_object(self, inference_state, frame_idx, current_out, storage_key): + """ + Split a multi-object output into per-object output slices and add them into + `output_dict_per_obj`. The resulting slices share the same tensor storage. + """ + maskmem_features = current_out["maskmem_features"] + assert maskmem_features is None or isinstance(maskmem_features, torch.Tensor) + + maskmem_pos_enc = current_out["maskmem_pos_enc"] + assert maskmem_pos_enc is None or isinstance(maskmem_pos_enc, list) + + output_dict_per_obj = inference_state["output_dict_per_obj"] + for obj_idx, obj_output_dict in output_dict_per_obj.items(): + obj_slice = slice(obj_idx, obj_idx + 1) + obj_out = { + "maskmem_features": None, + "maskmem_pos_enc": None, + "pred_masks": current_out["pred_masks"][obj_slice], + "obj_ptr": current_out["obj_ptr"][obj_slice], + } + if maskmem_features is not None: + obj_out["maskmem_features"] = maskmem_features[obj_slice] + if maskmem_pos_enc is not None: + obj_out["maskmem_pos_enc"] = [x[obj_slice] for x in maskmem_pos_enc] + obj_output_dict[storage_key][frame_idx] = obj_out + + @torch.inference_mode() + def reset_state(self, inference_state): + """Remove all input points or mask in all frames throughout the video.""" + self._reset_tracking_results(inference_state) + # Remove all object ids + inference_state["obj_id_to_idx"].clear() + inference_state["obj_idx_to_id"].clear() + inference_state["obj_ids"].clear() + inference_state["point_inputs_per_obj"].clear() + inference_state["mask_inputs_per_obj"].clear() + inference_state["output_dict_per_obj"].clear() + inference_state["temp_output_dict_per_obj"].clear() + + def _reset_tracking_results(self, inference_state): + """Reset all tracking inputs and results across the videos.""" + for v in inference_state["point_inputs_per_obj"].values(): + v.clear() + for v in inference_state["mask_inputs_per_obj"].values(): + v.clear() + for v in inference_state["output_dict_per_obj"].values(): + v["cond_frame_outputs"].clear() + v["non_cond_frame_outputs"].clear() + for v in inference_state["temp_output_dict_per_obj"].values(): + v["cond_frame_outputs"].clear() + v["non_cond_frame_outputs"].clear() + inference_state["output_dict"]["cond_frame_outputs"].clear() + inference_state["output_dict"]["non_cond_frame_outputs"].clear() + inference_state["consolidated_frame_inds"]["cond_frame_outputs"].clear() + inference_state["consolidated_frame_inds"]["non_cond_frame_outputs"].clear() + inference_state["tracking_has_started"] = False + inference_state["frames_already_tracked"].clear() + + def _get_image_feature(self, inference_state, frame_idx, batch_size): + """Compute the image features on a given frame.""" + # Look up in the cache first + image, backbone_out = inference_state["cached_features"].get(frame_idx, (None, None)) + if backbone_out is None: + # Cache miss -- we will run inference on a single image + device = inference_state["device"] + image = inference_state["images"][frame_idx].to(device).float().unsqueeze(0) + backbone_out = self.forward_image(image) + # Cache the most recent frame's 
feature (for repeated interactions with + # a frame; we can use an LRU cache for more frames in the future). + inference_state["cached_features"] = {frame_idx: (image, backbone_out)} + + # expand the features to have the same dimension as the number of objects + expanded_image = image.expand(batch_size, -1, -1, -1) + expanded_backbone_out = { + "backbone_fpn": backbone_out["backbone_fpn"].copy(), + "vision_pos_enc": backbone_out["vision_pos_enc"].copy(), + } + for i, feat in enumerate(expanded_backbone_out["backbone_fpn"]): + expanded_backbone_out["backbone_fpn"][i] = feat.expand(batch_size, -1, -1, -1) + for i, pos in enumerate(expanded_backbone_out["vision_pos_enc"]): + pos = pos.expand(batch_size, -1, -1, -1) + expanded_backbone_out["vision_pos_enc"][i] = pos + + features = self._prepare_backbone_features(expanded_backbone_out) + features = (expanded_image,) + features + return features + + def _run_single_frame_inference( + self, + inference_state, + output_dict, + frame_idx, + batch_size, + is_init_cond_frame, + point_inputs, + mask_inputs, + reverse, + run_mem_encoder, + prev_sam_mask_logits=None, + ): + """Run tracking on a single frame based on current inputs and previous memory.""" + # Retrieve correct image features + ( + _, + _, + current_vision_feats, + current_vision_pos_embeds, + feat_sizes, + ) = self._get_image_feature(inference_state, frame_idx, batch_size) + + # point and mask should not appear as input simultaneously on the same frame + assert point_inputs is None or mask_inputs is None + current_out = self.track_step( + frame_idx=frame_idx, + is_init_cond_frame=is_init_cond_frame, + current_vision_feats=current_vision_feats, + current_vision_pos_embeds=current_vision_pos_embeds, + feat_sizes=feat_sizes, + point_inputs=point_inputs, + mask_inputs=mask_inputs, + output_dict=output_dict, + num_frames=inference_state["num_frames"], + track_in_reverse=reverse, + run_mem_encoder=run_mem_encoder, + prev_sam_mask_logits=prev_sam_mask_logits, + ) + + # optionally offload the output to CPU memory to save GPU space + storage_device = inference_state["storage_device"] + maskmem_features = current_out["maskmem_features"] + if maskmem_features is not None: + maskmem_features = maskmem_features.to(torch.bfloat16) + maskmem_features = maskmem_features.to(storage_device, non_blocking=True) + pred_masks_gpu = current_out["pred_masks"] + # potentially fill holes in the predicted masks + if self.fill_hole_area > 0: + pred_masks_gpu = fill_holes_in_mask_scores(pred_masks_gpu, self.fill_hole_area) + pred_masks = pred_masks_gpu.to(storage_device, non_blocking=True) + # "maskmem_pos_enc" is the same across frames, so we only need to store one copy of it + maskmem_pos_enc = self._get_maskmem_pos_enc(inference_state, current_out) + # object pointer is a small tensor, so we always keep it on GPU memory for fast access + obj_ptr = current_out["obj_ptr"] + # make a compact version of this frame's output to reduce the state size + compact_current_out = { + "maskmem_features": maskmem_features, + "maskmem_pos_enc": maskmem_pos_enc, + "pred_masks": pred_masks, + "obj_ptr": obj_ptr, + } + return compact_current_out, pred_masks_gpu + + def _run_memory_encoder(self, inference_state, frame_idx, batch_size, high_res_masks, is_mask_from_pts): + """ + Run the memory encoder on `high_res_masks`. This is usually after applying + non-overlapping constraints to object scores. Since their scores changed, their + memory also need to be computed again with the memory encoder. 
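+        The re-encoded `maskmem_features` are cast to bfloat16 and offloaded to the
+        configured storage device, mirroring `_run_single_frame_inference`.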
+ """ + # Retrieve correct image features + _, _, current_vision_feats, _, feat_sizes = self._get_image_feature(inference_state, frame_idx, batch_size) + maskmem_features, maskmem_pos_enc = self._encode_new_memory( + current_vision_feats=current_vision_feats, + feat_sizes=feat_sizes, + pred_masks_high_res=high_res_masks, + is_mask_from_pts=is_mask_from_pts, + ) + + # optionally offload the output to CPU memory to save GPU space + storage_device = inference_state["storage_device"] + maskmem_features = maskmem_features.to(torch.bfloat16) + maskmem_features = maskmem_features.to(storage_device, non_blocking=True) + # "maskmem_pos_enc" is the same across frames, so we only need to store one copy of it + maskmem_pos_enc = self._get_maskmem_pos_enc(inference_state, {"maskmem_pos_enc": maskmem_pos_enc}) + return maskmem_features, maskmem_pos_enc + + def _get_maskmem_pos_enc(self, inference_state, current_out): + """ + `maskmem_pos_enc` is the same across frames and objects, so we cache it as + a constant in the inference session to reduce session storage size. + """ + model_constants = inference_state["constants"] + # "out_maskmem_pos_enc" should be either a list of tensors or None + out_maskmem_pos_enc = current_out["maskmem_pos_enc"] + if out_maskmem_pos_enc is not None: + if "maskmem_pos_enc" not in model_constants: + assert isinstance(out_maskmem_pos_enc, list) + # only take the slice for one object, since it's same across objects + maskmem_pos_enc = [x[0:1].clone() for x in out_maskmem_pos_enc] + model_constants["maskmem_pos_enc"] = maskmem_pos_enc + else: + maskmem_pos_enc = model_constants["maskmem_pos_enc"] + # expand the cached maskmem_pos_enc to the actual batch size + batch_size = out_maskmem_pos_enc[0].size(0) + expanded_maskmem_pos_enc = [x.expand(batch_size, -1, -1, -1) for x in maskmem_pos_enc] + else: + expanded_maskmem_pos_enc = None + return expanded_maskmem_pos_enc + + def _clear_non_cond_mem_around_input(self, inference_state, frame_idx): + """ + Remove the non-conditioning memory around the input frame. When users provide + correction clicks, the surrounding frames' non-conditioning memories can still + contain outdated object appearance information and could confuse the model. + + This method clears those non-conditioning memories surrounding the interacted + frame to avoid giving the model both old and new information about the object. + """ + r = self.memory_temporal_stride_for_eval + frame_idx_begin = frame_idx - r * self.num_maskmem + frame_idx_end = frame_idx + r * self.num_maskmem + output_dict = inference_state["output_dict"] + non_cond_frame_outputs = output_dict["non_cond_frame_outputs"] + for t in range(frame_idx_begin, frame_idx_end + 1): + non_cond_frame_outputs.pop(t, None) + for obj_output_dict in inference_state["output_dict_per_obj"].values(): + obj_output_dict["non_cond_frame_outputs"].pop(t, None) diff --git a/tripy/examples/segment-anything-model-v2/sam2/utils/misc.py b/tripy/examples/segment-anything-model-v2/sam2/utils/misc.py new file mode 100644 index 000000000..5429d51d3 --- /dev/null +++ b/tripy/examples/segment-anything-model-v2/sam2/utils/misc.py @@ -0,0 +1,288 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
+ +# Not a contribution +# Changes made by NVIDIA CORPORATION & AFFILIATES enabling SAM2 with Tripy or otherwise documented as +# NVIDIA-proprietary are not a contribution and subject to the following terms and conditions: +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import warnings +from threading import Thread + +import numpy as np +import torch +from PIL import Image +from tqdm import tqdm + + +def get_sdpa_settings(): + if torch.cuda.is_available(): + old_gpu = torch.cuda.get_device_properties(0).major < 7 + # only use Flash Attention on Ampere (8.0) or newer GPUs + use_flash_attn = torch.cuda.get_device_properties(0).major >= 8 + if not use_flash_attn: + warnings.warn( + "Flash Attention is disabled as it requires a GPU with Ampere (8.0) CUDA capability.", + category=UserWarning, + stacklevel=2, + ) + # keep math kernel for PyTorch versions before 2.2 (Flash Attention v2 is only + # available on PyTorch 2.2+, while Flash Attention v1 cannot handle all cases) + pytorch_version = tuple(int(v) for v in torch.__version__.split(".")[:2]) + if pytorch_version < (2, 2): + warnings.warn( + f"You are using PyTorch {torch.__version__} without Flash Attention v2 support. " + "Consider upgrading to PyTorch 2.2+ for Flash Attention v2 (which could be faster).", + category=UserWarning, + stacklevel=2, + ) + math_kernel_on = pytorch_version < (2, 2) or not use_flash_attn + else: + old_gpu = True + use_flash_attn = False + math_kernel_on = True + + return old_gpu, use_flash_attn, math_kernel_on + + +def get_connected_components(mask): + """ + Get the connected components (8-connectivity) of binary masks of shape (N, 1, H, W). + + Inputs: + - mask: A binary mask tensor of shape (N, 1, H, W), where 1 is foreground and 0 is + background. + + Outputs: + - labels: A tensor of shape (N, 1, H, W) containing the connected component labels + for foreground pixels and 0 for background pixels. + - counts: A tensor of shape (N, 1, H, W) containing the area of the connected + components for foreground pixels and 0 for background pixels. 
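+
+    Example (illustrative sketch; assumes the optional `sam2._C` CUDA extension is built):
+        mask = torch.rand(2, 1, 64, 64, device="cuda") > 0.5
+        labels, counts = get_connected_components(mask)
+        # labels and counts both have shape (2, 1, 64, 64); background pixels stay 0 in both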
+ """ + from sam2 import _C + + return _C.get_connected_componnets(mask.to(torch.uint8).contiguous()) + + +def mask_to_box(masks: torch.Tensor): + """ + compute bounding box given an input mask + + Inputs: + - masks: [B, 1, H, W] masks, dtype=torch.Tensor + + Returns: + - box_coords: [B, 1, 4], contains (x, y) coordinates of top left and bottom right box corners, dtype=torch.Tensor + """ + B, _, h, w = masks.shape + device = masks.device + xs = torch.arange(w, device=device, dtype=torch.int32) + ys = torch.arange(h, device=device, dtype=torch.int32) + grid_xs, grid_ys = torch.meshgrid(xs, ys, indexing="xy") + grid_xs = grid_xs[None, None, ...].expand(B, 1, h, w) + grid_ys = grid_ys[None, None, ...].expand(B, 1, h, w) + min_xs, _ = torch.min(torch.where(masks, grid_xs, w).flatten(-2), dim=-1) + max_xs, _ = torch.max(torch.where(masks, grid_xs, -1).flatten(-2), dim=-1) + min_ys, _ = torch.min(torch.where(masks, grid_ys, h).flatten(-2), dim=-1) + max_ys, _ = torch.max(torch.where(masks, grid_ys, -1).flatten(-2), dim=-1) + bbox_coords = torch.stack((min_xs, min_ys, max_xs, max_ys), dim=-1) + + return bbox_coords + + +def _load_img_as_tensor(img_path, image_size): + img_pil = Image.open(img_path) + img_np = np.array(img_pil.convert("RGB").resize((image_size, image_size))) + if img_np.dtype == np.uint8: # np.uint8 is expected for JPEG images + img_np = img_np / 255.0 + else: + raise RuntimeError(f"Unknown image dtype: {img_np.dtype} on {img_path}") + img = torch.from_numpy(img_np).permute(2, 0, 1) + video_width, video_height = img_pil.size # the original video size + return img, video_height, video_width + + +class AsyncVideoFrameLoader: + """ + A list of video frames to be load asynchronously without blocking session start. + """ + + def __init__( + self, + img_paths, + image_size, + offload_video_to_cpu, + img_mean, + img_std, + compute_device, + ): + self.img_paths = img_paths + self.image_size = image_size + self.offload_video_to_cpu = offload_video_to_cpu + self.img_mean = img_mean + self.img_std = img_std + # items in `self.images` will be loaded asynchronously + self.images = [None] * len(img_paths) + # catch and raise any exceptions in the async loading thread + self.exception = None + # video_height and video_width be filled when loading the first image + self.video_height = None + self.video_width = None + self.compute_device = compute_device + + # load the first frame to fill video_height and video_width and also + # to cache it (since it's most likely where the user will click) + self.__getitem__(0) + + # load the rest of frames asynchronously without blocking the session start + def _load_frames(): + try: + for n in tqdm(range(len(self.images)), desc="frame loading (JPEG)"): + self.__getitem__(n) + except Exception as e: + self.exception = e + + self.thread = Thread(target=_load_frames, daemon=True) + self.thread.start() + + def __getitem__(self, index): + if self.exception is not None: + raise RuntimeError("Failure in frame loading thread") from self.exception + + img = self.images[index] + if img is not None: + return img + + img, video_height, video_width = _load_img_as_tensor(self.img_paths[index], self.image_size) + self.video_height = video_height + self.video_width = video_width + # normalize by mean and std + img -= self.img_mean + img /= self.img_std + if not self.offload_video_to_cpu: + img = img.to(self.compute_device, non_blocking=True) + self.images[index] = img + return img + + def __len__(self): + return len(self.images) + + +def load_video_frames( + video_path, + 
image_size, + offload_video_to_cpu, + img_mean=(0.485, 0.456, 0.406), + img_std=(0.229, 0.224, 0.225), + async_loading_frames=False, + compute_device=torch.device("cuda"), +): + """ + Load the video frames from a directory of JPEG files (".jpg" format). + + The frames are resized to image_size x image_size and are loaded to GPU if + `offload_video_to_cpu` is `False` and to CPU if `offload_video_to_cpu` is `True`. + + You can load a frame asynchronously by setting `async_loading_frames` to `True`. + """ + if isinstance(video_path, str) and os.path.isdir(video_path): + jpg_folder = video_path + else: + raise NotImplementedError( + "Only JPEG frames are supported at this moment. For video files, you may use " + "ffmpeg (https://ffmpeg.org/) to extract frames into a folder of JPEG files, such as \n" + "```\n" + "ffmpeg -i .mp4 -q:v 2 -start_number 0 /'%05d.jpg'\n" + "```\n" + "where `-q:v` generates high-quality JPEG frames and `-start_number 0` asks " + "ffmpeg to start the JPEG file from 00000.jpg." + ) + + frame_names = [p for p in os.listdir(jpg_folder) if os.path.splitext(p)[-1] in [".jpg", ".jpeg", ".JPG", ".JPEG"]] + frame_names.sort(key=lambda p: int(os.path.splitext(p)[0])) + num_frames = len(frame_names) + if num_frames == 0: + raise RuntimeError(f"no images found in {jpg_folder}") + img_paths = [os.path.join(jpg_folder, frame_name) for frame_name in frame_names] + img_mean = torch.tensor(img_mean, dtype=torch.float32)[:, None, None] + img_std = torch.tensor(img_std, dtype=torch.float32)[:, None, None] + + if async_loading_frames: + lazy_images = AsyncVideoFrameLoader( + img_paths, + image_size, + offload_video_to_cpu, + img_mean, + img_std, + compute_device, + ) + return lazy_images, lazy_images.video_height, lazy_images.video_width + + images = torch.zeros(num_frames, 3, image_size, image_size, dtype=torch.float32) + for n, img_path in enumerate(tqdm(img_paths, desc="frame loading (JPEG)")): + images[n], video_height, video_width = _load_img_as_tensor(img_path, image_size) + if not offload_video_to_cpu: + images = images.to(compute_device) + img_mean = img_mean.to(compute_device) + img_std = img_std.to(compute_device) + # normalize by mean and std + images -= img_mean + images /= img_std + return images, video_height, video_width + + +def fill_holes_in_mask_scores(mask, max_area): + """ + A post processor to fill small holes in mask scores with area under `max_area`. + """ + # Holes are those connected components in background with area <= self.max_area + # (background regions are those with mask scores <= 0) + assert max_area > 0, "max_area must be positive" + + input_mask = mask + try: + labels, areas = get_connected_components(mask <= 0) + is_hole = (labels > 0) & (areas <= max_area) + # We fill holes with a small positive mask score (0.1) to change them to foreground. + mask = torch.where(is_hole, 0.1, mask) + except Exception as e: + # Skip the post-processing step on removing small holes if the CUDA kernel fails + warnings.warn( + f"{e}\n\nSkipping the post-processing step due to the error above. 
You can " + "still use SAM 2 and it's OK to ignore the error above, although some post-processing " + "functionality may be limited (which doesn't affect the results in most cases; see " + "https://github.com/facebookresearch/segment-anything-2/blob/main/INSTALL.md).", + category=UserWarning, + stacklevel=2, + ) + mask = input_mask + + return mask + + +def concat_points(old_point_inputs, new_points, new_labels): + """Add new points and labels to previous point inputs (add at the end).""" + if old_point_inputs is None: + points, labels = new_points, new_labels + else: + points = torch.cat([old_point_inputs["point_coords"], new_points], dim=1) + labels = torch.cat([old_point_inputs["point_labels"], new_labels], dim=1) + + return {"point_coords": points, "point_labels": labels} diff --git a/tripy/examples/segment-anything-model-v2/video_demo.py b/tripy/examples/segment-anything-model-v2/video_demo.py new file mode 100644 index 000000000..4b11e580a --- /dev/null +++ b/tripy/examples/segment-anything-model-v2/video_demo.py @@ -0,0 +1,178 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +# Not a contribution +# Changes made by NVIDIA CORPORATION & AFFILIATES enabling SAM2 with Tripy or otherwise documented as +# NVIDIA-proprietary are not a contribution and subject to the following terms and conditions: +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +from sam2.build_sam import build_sam2_video_predictor +import numpy as np +import matplotlib.pyplot as plt +from PIL import Image +from demo_utils import process_and_show_mask as show_mask +from typing import Optional + +import os +import time + + +def compute_mask_properties(mask): + # Ensure we have a boolean array + test_mask = np.asarray(mask, dtype=bool) + + # Calculate basic stats + volume = np.sum(test_mask) + + # Calculate centroid (center of mass) + indices = np.where(test_mask) + assert volume > 0 + centroid = tuple(float(np.mean(idx)) for idx in indices) + return volume, centroid + + +def main(video_dir: str, save_path: Optional[str] = None): + """ + Main execution function. 
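+    Builds the SAM2 video predictor from the Hiera-L config and checkpoint, adds point
+    prompts for two objects on the first frame, propagates the masks through the video,
+    and optionally saves mask visualizations to `save_path`.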
+ + Args: + video_path (str): Path to where video frames are stored + save_path (str, optional): Directory to save visualizations + + Returns: + Dict[str, np.ndarray]: Processing results + """ + + sam2_checkpoint = "./checkpoints/sam2.1_hiera_large.pt" + model_cfg = "sam2_hiera_l.yaml" + predictor = build_sam2_video_predictor(model_cfg, sam2_checkpoint, device=torch.device("cuda")) + + # scan all the JPEG frame names in this directory + frame_names = [p for p in os.listdir(video_dir) if os.path.splitext(p)[-1] in [".jpg", ".jpeg", ".JPG", ".JPEG"]] + frame_names.sort(key=lambda p: int(os.path.splitext(p)[0])) + # take a look the first video frame + frame_idx = 0 + if save_path: + plt.figure(figsize=(9, 6)) + plt.title(f"frame {frame_idx}") + plt.imshow(Image.open(os.path.join(video_dir, frame_names[frame_idx]))) + plt.savefig(os.path.join(save_path, f"video_{frame_idx}.png")) + plt.close("all") + + inference_state = predictor.init_state(video_path=video_dir) + + def make_tensors_contiguous(d): + for key, value in d.items(): + if isinstance(value, torch.Tensor): + d[key] = value.contiguous() + return d + + inference_state = make_tensors_contiguous(inference_state) + + predictor.reset_state(inference_state) + + prompts = {} # hold all the clicks we add for visualization + + ann_frame_idx = 0 # the frame index we interact with + ann_obj_id = 2 # give a unique id to each object we interact with (it can be any integers) + + # Let's add a positive click at (x, y) = (200, 300) to get started on the first object + points = np.array([[200, 300]], dtype=np.float32) + # for labels, `1` means positive click and `0` means negative click + labels = np.array([1], np.int32) + prompts[ann_obj_id] = points, labels + _, out_obj_ids, out_mask_logits = predictor.add_new_points_or_box( + inference_state=inference_state, + frame_idx=ann_frame_idx, + obj_id=ann_obj_id, + points=points, + labels=labels, + ) + + # add the first object + ann_frame_idx = 0 # the frame index we interact with + ann_obj_id = 2 # give a unique id to each object we interact with (it can be any integers) + + # Let's add a 2nd negative click at (x, y) = (275, 175) to refine the first object + # sending all clicks (and their labels) to `add_new_points_or_box` + points = np.array([[200, 300], [275, 175]], dtype=np.float32) + # for labels, `1` means positive click and `0` means negative click + labels = np.array([1, 0], np.int32) + prompts[ann_obj_id] = points, labels + _, out_obj_ids, out_mask_logits = predictor.add_new_points_or_box( + inference_state=inference_state, + frame_idx=ann_frame_idx, + obj_id=ann_obj_id, + points=points, + labels=labels, + ) + + ann_frame_idx = 0 # the frame index we interact with + ann_obj_id = 3 # give a unique id to each object we interact with (it can be any integers) + + # Let's now move on to the second object we want to track (giving it object id `3`) + # with a positive click at (x, y) = (400, 150) + points = np.array([[400, 150]], dtype=np.float32) + # for labels, `1` means positive click and `0` means negative click + labels = np.array([1], np.int32) + prompts[ann_obj_id] = points, labels + + # `add_new_points_or_box` returns masks for all objects added so far on this interacted frame + _, out_obj_ids, out_mask_logits = predictor.add_new_points_or_box( + inference_state=inference_state, + frame_idx=ann_frame_idx, + obj_id=ann_obj_id, + points=points, + labels=labels, + ) + + # run propagation throughout the video and collect the results in a dict + start = time.perf_counter() + video_segments = {} # 
video_segments contains the per-frame segmentation results
+    for out_frame_idx, out_obj_ids, out_mask_logits in predictor.propagate_in_video(inference_state):
+        video_segments[out_frame_idx] = {
+            out_obj_id: (out_mask_logits[i] > 0.0).cpu().numpy() for i, out_obj_id in enumerate(out_obj_ids)
+        }
+    end = time.perf_counter()
+    print(f"Video segmentation took {(end - start)}s")
+
+    # render the segmentation results every few frames
+    vis_frame_stride = 30
+    if save_path:
+        os.makedirs(save_path, exist_ok=True)
+        for out_frame_idx in range(0, len(frame_names), vis_frame_stride):
+            plt.figure(figsize=(6, 4))
+            plt.title(f"frame {out_frame_idx}")
+            plt.imshow(Image.open(os.path.join(video_dir, frame_names[out_frame_idx])))
+            for out_obj_id, out_mask in video_segments[out_frame_idx].items():
+                vol, centre = compute_mask_properties(out_mask)
+                show_mask(out_mask, plt.gca(), obj_id=out_obj_id)
+            plt.savefig(os.path.join(save_path, f"video_final_mask_{out_frame_idx}.png"))
+            plt.close("all")
+
+    # Print the properties of the masks on the last visualized frame for integration testing.
+    last_vis_frame_idx = (len(frame_names) - 1) - ((len(frame_names) - 1) % vis_frame_stride)
+    for last_frame_obj_id, last_frame_obj_mask in video_segments[last_vis_frame_idx].items():
+        vol, centre = compute_mask_properties(last_frame_obj_mask)
+        print(f"Last frame object {last_frame_obj_id} has mask properties: volume {vol}, centre {centre}")
+
+
+if __name__ == "__main__":
+    main("./bedroom", save_path="output")
diff --git a/tripy/tests/test_examples.py b/tripy/tests/test_examples.py
index 6401998c2..cef9fd3b4 100644
--- a/tripy/tests/test_examples.py
+++ b/tripy/tests/test_examples.py
@@ -82,7 +82,10 @@ def __str__(self):
 
 EXAMPLES = [
     Example(["nanogpt"]),
-    Example(["segment-anything-model-v2"], artifact_names=["truck.jpg", "saved_engines/", "output/", "checkpoints/"]),
+    Example(
+        ["segment-anything-model-v2"],
+        artifact_names=["truck.jpg", "bedroom", "saved_engines/", "output/", "checkpoints/"],
+    ),
 ]
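For reference, the sketch below condenses the video-predictor flow introduced by this patch into a minimal standalone script. It only uses APIs added here (`build_sam2_video_predictor`, `init_state`, `add_new_points_or_box`, `propagate_in_video`); the config, checkpoint, and frame-directory paths mirror the defaults in `video_demo.py` and should be adjusted to your setup. See `video_demo.py` for the full flow, which additionally makes the state tensors contiguous and resets the state before adding prompts.

```python
import numpy as np
import torch

from sam2.build_sam import build_sam2_video_predictor

# Build the predictor with the same config/checkpoint used by video_demo.py.
predictor = build_sam2_video_predictor(
    "sam2_hiera_l.yaml",
    "./checkpoints/sam2.1_hiera_large.pt",
    device=torch.device("cuda"),
)

# init_state loads the JPEG frames from the directory fetched by download_test_data.py.
state = predictor.init_state(video_path="./bedroom")

# A single positive click (label 1) at (x, y) = (200, 300) on frame 0 for object id 1.
predictor.add_new_points_or_box(
    inference_state=state,
    frame_idx=0,
    obj_id=1,
    points=np.array([[200, 300]], dtype=np.float32),
    labels=np.array([1], dtype=np.int32),
)

# propagate_in_video yields (frame_idx, obj_ids, video_res_masks) one frame at a time.
video_segments = {}
for frame_idx, obj_ids, video_res_masks in predictor.propagate_in_video(state):
    video_segments[frame_idx] = {
        obj_id: (video_res_masks[i] > 0.0).cpu().numpy() for i, obj_id in enumerate(obj_ids)
    }
```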