Skip to content

Commit

Permalink
Enable sam2 video pipeline (#444)
Browse files Browse the repository at this point in the history
Co-authored-by: Yizhuo Zhang <[email protected]>
  • Loading branch information
parthchadha and yizhuoz004 authored Dec 13, 2024
1 parent f8d4fe5 commit ab9fa85
Show file tree
Hide file tree
Showing 15 changed files with 1,732 additions and 138 deletions.
27 changes: 20 additions & 7 deletions tripy/examples/segment-anything-model-v2/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,23 +6,23 @@ This is an implementation of SAM2 model ([original repository](https://github.co

## Running The Example

### Image pipeline

1. Install prerequisites:

```bash
sudo apt-get update && sudo apt-get install ffmpeg libsm6 libxext6 -y
sudo apt-get update && sudo apt-get install ffmpeg libsm6 libxext6 jpeginfo -y
python3 -m pip install -r requirements.txt
```

2. Retrieve an example image and the checkpoint:
2. Retrieve the images and the checkpoint:

```bash
wget -O truck.jpg https://raw.githubusercontent.com/facebookresearch/sam2/main/notebooks/images/truck.jpg
python3 download_test_data.py
mkdir checkpoints && cd checkpoints && wget https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_large.pt
```

3. Run the example:
### Image pipeline

1. Run the example:

```bash
python3 image_demo.py
Expand All @@ -38,7 +38,20 @@ This is an implementation of SAM2 model ([original repository](https://github.co

### Video segmentation pipeline

TBD
1. Run the example:

```bash
python3 video_demo.py
```

<!--
Tripy: TEST: EXPECTED_STDOUT Start
```
Last frame object 2 has mask properties: volume {16338~5%}, centre (0.0, {95.80028155220957~5%}, {133.8682825315216~5%})
Last frame object 3 has mask properties: volume {4415~5%}, centre (0.0, {161.95605889014723~5%}, {421.4523216308041~5%})
```
Tripy: TEST: EXPECTED_STDOUT End
-->


## License
Expand Down
57 changes: 23 additions & 34 deletions tripy/examples/segment-anything-model-v2/configs/sam2_hiera_l.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -34,42 +34,34 @@ model:
d_model: 256
pos_enc_at_input: true
dtype: float16
layer:
_target_: sam2.modeling.memory_attention.MemoryAttentionLayer
activation: relu
dim_feedforward: 2048
dropout: 0.1
dtype: float16
pos_enc_at_attn: false
self_attention:
_target_: sam2.modeling.sam.transformer.RoPEAttention
rope_theta: 10000.0
feat_sizes: [32, 32]
embedding_dim: 256
num_heads: 1
downsample_rate: 1
dropout: 0.1
dtype: float16
d_model: 256
pos_enc_at_cross_attn_keys: true
pos_enc_at_cross_attn_queries: false
cross_attention:
_target_: sam2.modeling.sam.transformer.RoPEAttention
rope_theta: 10000.0
feat_sizes: [32, 32]
rope_k_repeat: True
embedding_dim: 256
num_heads: 1
downsample_rate: 1
dropout: 0.1
kv_in_dim: 64
dtype: float16
num_layers: 4
# memory attention layer
activation: relu
dim_feedforward: 2048
dropout: 0.1
pos_enc_at_attn: false
pos_enc_at_cross_attn_keys: true
pos_enc_at_cross_attn_queries: false
# self rope attention
sa_rope_theta: 10000.0
sa_feat_sizes: [32, 32]
sa_embedding_dim: 256
sa_num_heads: 1
sa_downsample_rate: 1
sa_dropout: 0.1
# cross rope attention
ca_rope_theta: 10000.0
ca_feat_sizes: [32, 32]
ca_rope_k_repeat: True
ca_embedding_dim: 256
ca_num_heads: 1
ca_downsample_rate: 1
ca_dropout: 0.1
ca_kv_in_dim: 64

memory_encoder:
_target_: sam2.modeling.memory_encoder.MemoryEncoder
out_dim: 64
dtype: float16
position_encoding:
_target_: sam2.modeling.position_encoding.PositionEmbeddingSine
num_pos_feats: 64
Expand All @@ -81,18 +73,15 @@ model:
kernel_size: 3
stride: 2
padding: 1
dtype: float16
fuser:
_target_: sam2.modeling.memory_encoder.Fuser
dtype: float16
layer:
_target_: sam2.modeling.memory_encoder.CXBlock
dim: 256
kernel_size: 7
padding: 3
layer_scale_init_value: 1e-6
use_dwconv: True # depth-wise convs
dtype: float16
num_layers: 2

num_maskmem: 7
Expand Down
99 changes: 99 additions & 0 deletions tripy/examples/segment-anything-model-v2/demo_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

# Not a contribution
# Changes made by NVIDIA CORPORATION & AFFILIATES enabling SAM2 with Tripy or otherwise documented as
# NVIDIA-proprietary are not a contribution and subject to the following terms and conditions:
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import cv2
import numpy as np
import matplotlib.pyplot as plt

plt.switch_backend("agg") # Switch to non-interactive backend
from typing import Tuple, Optional


def process_and_show_mask(
    mask: np.ndarray, ax: plt.Axes, obj_id: Optional[int] = None, random_color: bool = False, borders: bool = False
) -> np.ndarray:
    """
    Colorize a segmentation mask, draw it on ``ax`` and return the RGBA overlay.

    Args:
        mask: Mask whose last two dimensions are (height, width). Any leading
            dimensions must be singletons, because the mask is reshaped to
            ``(height, width)``.
        ax: Matplotlib axes that receives the overlay through ``imshow``.
        obj_id: When given (and ``random_color`` is False), the overlay color is
            looked up in matplotlib's ``tab10`` colormap with this index.
        random_color: Use a random RGB color instead of a deterministic one.
        borders: Also draw the external contours of the mask on the overlay.

    Returns:
        The ``(height, width, 4)`` overlay that was passed to ``ax.imshow``.
    """
    # Pick the RGBA color; every branch uses an alpha of 0.6.
    if random_color:
        color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
    elif obj_id is not None:
        cmap = plt.get_cmap("tab10")
        color = np.array([*cmap(obj_id)[:3], 0.6])
    else:
        color = np.array([30 / 255, 144 / 255, 255 / 255, 0.6])

    # Collapse leading singleton dimensions once so that both the overlay and
    # the contour search operate on the same 2-D array. cv2.findContours
    # requires a 2-D single-channel image and fails on e.g. a (1, H, W) mask.
    h, w = mask.shape[-2:]
    mask_2d = mask.astype(np.uint8).reshape(h, w)
    mask_image = mask_2d.reshape(h, w, 1) * color.reshape(1, 1, -1)

    if borders:
        contours, _ = cv2.findContours(mask_2d, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
        # Lightly simplify each contour before drawing it.
        contours = [cv2.approxPolyDP(contour, epsilon=0.01, closed=True) for contour in contours]
        # Outline color is white with an alpha of 0.5.
        mask_image = cv2.drawContours(mask_image, contours, -1, (1, 1, 1, 0.5), thickness=2)

    ax.imshow(mask_image)
    return mask_image


def show_points(
    coords: np.ndarray, labels: np.ndarray, ax: plt.Axes, marker_size: int = 375
) -> Tuple[np.ndarray, np.ndarray]:
    """
    Plot point prompts on ``ax`` and hand back the points that were drawn.

    Rows of ``coords`` whose label equals 1 are drawn as green stars and rows
    whose label equals 0 as red stars; column 0 is used as x and column 1 as y.

    Returns:
        ``(pos_points, neg_points)``: the rows of ``coords`` labelled 1 and 0.
    """
    pos_points = coords[labels == 1]
    neg_points = coords[labels == 0]

    # Positive prompts are drawn first, then negative ones, with identical styling
    # apart from the fill color.
    for points, fill in ((pos_points, "green"), (neg_points, "red")):
        ax.scatter(
            points[:, 0],
            points[:, 1],
            color=fill,
            marker="*",
            s=marker_size,
            edgecolor="white",
            linewidth=1.25,
        )

    return pos_points, neg_points


def show_box(box: np.ndarray, ax: plt.Axes) -> Tuple[float, float, float, float]:
    """
    Draw ``box`` on ``ax`` as a green, unfilled rectangle.

    ``box`` is read as ``(x0, y0, x1, y1)``; the rectangle is anchored at
    ``(x0, y0)`` with width ``x1 - x0`` and height ``y1 - y0``.

    Returns:
        ``(x0, y0, width, height)`` of the rectangle that was added.
    """
    left, top = box[0], box[1]
    width = box[2] - left
    height = box[3] - top

    # Fully transparent face so only the 2pt green outline is visible.
    outline = plt.Rectangle((left, top), width, height, edgecolor="green", facecolor=(0, 0, 0, 0), lw=2)
    ax.add_patch(outline)
    return left, top, width, height
76 changes: 76 additions & 0 deletions tripy/examples/segment-anything-model-v2/download_test_data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import time
import requests
from pathlib import Path
from PIL import Image
import io


def verify_jpeg(filepath):
    """Return True if Pillow can open ``filepath`` and ``Image.verify()`` passes.

    Any failure to open or verify the file yields False instead of raising.

    NOTE: the image format is not inspected, so any file Pillow accepts passes,
    not only JPEG files.
    """
    try:
        with Image.open(filepath) as img:
            img.verify()
    except Exception:
        # Pillow does not document a single exception type for a failed
        # verify(), so stay broad -- but unlike a bare ``except:`` this lets
        # KeyboardInterrupt and SystemExit propagate.
        return False
    return True


def download_file(url, filepath, max_retries=5, timeout=5, retry_delay=0.01):
    """Download ``url`` to ``filepath``, retrying until the saved file verifies.

    Each attempt fetches the URL, writes the response body to ``filepath`` and
    checks the result with ``verify_jpeg``. A file that fails verification is
    deleted before the next attempt.

    Args:
        url: Address to fetch.
        filepath: Destination path for the downloaded bytes.
        max_retries: Maximum number of attempts.
        timeout: Timeout in seconds passed to ``requests.get``.
        retry_delay: Seconds to wait between two attempts.

    Returns:
        True as soon as one attempt produces a file that verifies, False when
        every attempt failed.
    """
    for attempt in range(1, max_retries + 1):
        try:
            print(f"Downloading {filepath}...")
            response = requests.get(url, timeout=timeout)
            response.raise_for_status()

            with open(filepath, "wb") as f:
                f.write(response.content)

            if verify_jpeg(filepath):
                return True

            print(f"Invalid JPEG file {filepath}")
            os.remove(filepath)  # Do not leave a corrupt file behind
        except (requests.exceptions.RequestException, OSError) as e:
            print(f"Error downloading {filepath} (attempt {attempt}/{max_retries}): {e}")

        # Pause before every retry, whether the attempt failed with an error or
        # with an invalid file; there is nothing to wait for after the last one.
        if attempt < max_retries:
            time.sleep(retry_delay)

    print(f"Failed to download {filepath} after {max_retries} attempts")
    return False


def main():
    """Download the sample image and the 200 bedroom video frames.

    Files are written to ``truck.jpg`` and ``bedroom/00000.jpg`` ..
    ``bedroom/00199.jpg`` relative to the current working directory.

    Returns:
        The number of files that could not be downloaded (0 on full success).
    """
    failed = []

    # Download test image for image segmentation
    truck_url = "https://raw.githubusercontent.com/facebookresearch/sam2/main/notebooks/images/truck.jpg"
    if not download_file(truck_url, "truck.jpg"):
        failed.append("truck.jpg")

    # Create bedroom directory if it doesn't exist
    bedroom_dir = Path("bedroom")
    bedroom_dir.mkdir(exist_ok=True)

    # Download images for video segmentation
    base_url = "https://raw.githubusercontent.com/facebookresearch/sam2/main/notebooks/videos/bedroom/{:05d}.jpg"

    for i in range(200):  # 0 to 199
        filepath = bedroom_dir / f"{i:05d}.jpg"
        if not download_file(base_url.format(i), filepath):
            failed.append(str(filepath))

    if failed:
        print(f"{len(failed)} file(s) could not be downloaded: {', '.join(failed)}")
    return len(failed)


if __name__ == "__main__":
    # Report failure through the exit status so that a run with missing files
    # is not mistaken for a successful one.
    raise SystemExit(1 if main() else 0)
70 changes: 2 additions & 68 deletions tripy/examples/segment-anything-model-v2/image_demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,82 +24,16 @@

plt.switch_backend("agg") # Switch to non-interactive backend
from PIL import Image
from typing import Tuple, Optional, Dict
from typing import Optional, Dict

from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor
from demo_utils import process_and_show_mask, show_box, show_points

# Command-line interface: a single optional integer batch size (default 2).
# The help text states a range of [1, 4]; that range is not enforced here.
parser = argparse.ArgumentParser()
parser.add_argument("-b", "--batch", type=int, default=2, help="batch size of the input images, between [1, 4]")


def process_and_show_mask(
    mask: np.ndarray, ax: plt.Axes, random_color: bool = False, borders: bool = True
) -> np.ndarray:
    """
    Process and display a segmentation mask, returning the processed mask for testing.

    Args:
        mask: Mask whose last two dimensions are (height, width). Any leading
            dimensions must be singletons, because the mask is reshaped to
            ``(height, width)``.
        ax: Matplotlib axes that receives the overlay through ``imshow``.
        random_color: Use a random RGB color instead of the fixed default.
        borders: Also draw the external contours of the mask on the overlay.

    Returns:
        The ``(height, width, 4)`` overlay that was passed to ``ax.imshow``.
    """
    # Pick the RGBA color; both branches use an alpha of 0.6.
    if random_color:
        color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
    else:
        color = np.array([30 / 255, 144 / 255, 255 / 255, 0.6])

    # Collapse leading singleton dimensions once so the contour search gets the
    # 2-D single-channel image that cv2.findContours requires.
    h, w = mask.shape[-2:]
    mask_2d = mask.astype(np.uint8).reshape(h, w)
    mask_image = mask_2d.reshape(h, w, 1) * color.reshape(1, 1, -1)

    if borders:
        contours, _ = cv2.findContours(mask_2d, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
        # Lightly simplify each contour before drawing it.
        contours = [cv2.approxPolyDP(contour, epsilon=0.01, closed=True) for contour in contours]
        # Outline color is white with an alpha of 0.5.
        mask_image = cv2.drawContours(mask_image, contours, -1, (1, 1, 1, 0.5), thickness=2)

    ax.imshow(mask_image)
    return mask_image


def show_points(
    coords: np.ndarray, labels: np.ndarray, ax: plt.Axes, marker_size: int = 375
) -> Tuple[np.ndarray, np.ndarray]:
    """
    Draw the point prompts on ``ax`` and return the drawn coordinates for testing.

    Rows of ``coords`` labelled 1 become green stars, rows labelled 0 become
    red stars; column 0 is plotted as x and column 1 as y.
    """
    pos_points = coords[labels == 1]
    neg_points = coords[labels == 0]

    # Styling shared by both groups; only the fill color differs.
    star_style = dict(marker="*", s=marker_size, edgecolor="white", linewidth=1.25)
    ax.scatter(pos_points[:, 0], pos_points[:, 1], color="green", **star_style)
    ax.scatter(neg_points[:, 0], neg_points[:, 1], color="red", **star_style)

    return pos_points, neg_points


def show_box(box: np.ndarray, ax: plt.Axes) -> Tuple[float, float, float, float]:
    """
    Add ``box`` to ``ax`` as a green outline and return its geometry for testing.

    ``box`` is read as ``(x0, y0, x1, y1)``; the returned tuple is
    ``(x0, y0, x1 - x0, y1 - y0)``.
    """
    origin_x, origin_y = box[0], box[1]
    span_x = box[2] - box[0]
    span_y = box[3] - box[1]

    # Transparent face: only the 2pt green edge is rendered.
    rect = plt.Rectangle((origin_x, origin_y), span_x, span_y, edgecolor="green", facecolor=(0, 0, 0, 0), lw=2)
    ax.add_patch(rect)
    return origin_x, origin_y, span_x, span_y


def process_predictions(
image: np.ndarray,
masks: np.ndarray,
Expand Down
Loading

0 comments on commit ab9fa85

Please sign in to comment.