Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 58 additions & 2 deletions notebooks/video_predictor_example.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -1401,6 +1401,62 @@
" show_mask(out_mask, plt.gca(), obj_id=out_obj_id)"
]
},
{
"cell_type": "markdown",
"id": "f37870b9",
"metadata": {},
"source": [
"<br>\n",
"\n",
"### Save the Segmented Frames and Create a Video"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "aa0e361b",
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"import cv2\n",
"import numpy as np\n",
"from PIL import Image\n",
"\n",
"# Define video writer\n",
"h, w = Image.open(os.path.join(video_dir, frame_names[0])).size[::-1] # (height, width)\n",
"out = cv2.VideoWriter(\"segmented_output.mp4\", cv2.VideoWriter_fourcc(*\"mp4v\"), 25, (w, h))\n",
"\n",
"# Loop through all frames in order\n",
"for out_frame_idx in range(len(frame_names)):\n",
" # Load original frame\n",
" frame_path = os.path.join(video_dir, frame_names[out_frame_idx])\n",
" frame = np.array(Image.open(frame_path).convert(\"RGB\"))\n",
"\n",
" # Convert to BGR for OpenCV\n",
" frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)\n",
"\n",
" # If we have segmentation results for this frame\n",
" if out_frame_idx in video_segments:\n",
" for out_obj_id, out_mask in video_segments[out_frame_idx].items():\n",
" # Convert mask to uint8\n",
" mask = (out_mask.astype(np.uint8) * 255)\n",
"\n",
" # Blue mask\n",
" colored_mask = np.zeros_like(frame, dtype=np.uint8)\n",
" colored_mask[:, :, 0] = mask # Blue channel\n",
" colored_mask[:, :, 1] = 0 # Green channel\n",
" colored_mask[:, :, 2] = 0 # Red channel\n",
"\n",
" # Blend with original frame (alpha blending)\n",
" frame = cv2.addWeighted(frame, 1.0, colored_mask, 0.5, 0)\n",
"\n",
" # Write frame to video\n",
" out.write(frame)\n",
"\n",
"out.release()\n"
]
},
{
"cell_type": "markdown",
"id": "18a0b9d7-c78f-432b-afb0-11f2ea5b652a",
Expand All @@ -1414,7 +1470,7 @@
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"display_name": "sam2",
"language": "python",
"name": "python3"
},
Expand All @@ -1428,7 +1484,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.14"
"version": "3.10.18"
}
},
"nbformat": 4,
Expand Down