Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 69 additions & 0 deletions sam2/sam2_video_predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -973,6 +973,75 @@ def _clear_non_cond_mem_around_input(self, inference_state, frame_idx):
non_cond_frame_outputs.pop(t, None)



@torch.inference_mode()
def merge_multiple_objects(
    self,
    inference_state,
    merge_frame_idx,
    frames,  # kept for interface consistency though not used in the in-place update
    obj_ids_to_merge,
    partial_results_before_merge=None,
):
    """
    Merge several tracked objects into one by unioning their masks on a frame.

    For every ID in `obj_ids_to_merge` that is tracked in `inference_state` and
    has a `pred_masks` entry at `merge_frame_idx` (conditioning outputs are
    preferred, non-conditioning outputs are the fallback), the mask is
    binarized with `> 0` and OR-ed into a single mask. That union is written
    back through `self.add_new_mask` under the smallest contributing object
    ID, and every other contributing object is dropped with
    `self.remove_object`. The inference state is updated in place.

    Objects that are unknown, or that have no mask on `merge_frame_idx`, are
    reported and left untouched: they are neither used as the merged ID nor
    removed.

    Args:
        inference_state: State dict; this method reads `obj_id_to_idx` and
            `output_dict_per_obj` from it.
        merge_frame_idx: Frame index whose masks are unioned.
        frames: Unused; kept for interface consistency.
        obj_ids_to_merge: Iterable of object IDs to merge. Duplicates are
            ignored.
        partial_results_before_merge: Unused by this method.

    Returns:
        The same `inference_state` object that was passed in.
    """
    obj_id_to_idx = inference_state["obj_id_to_idx"]
    output_dict_per_obj = inference_state["output_dict_per_obj"]

    masks = []
    merged_ids = []
    # dict.fromkeys de-duplicates while preserving the caller's order.
    for obj_id in dict.fromkeys(obj_ids_to_merge):
        obj_idx = obj_id_to_idx.get(obj_id)
        if obj_idx is None:
            print(f"Object ID {obj_id} not found in inference state.")
            continue

        # Prefer the conditioning output, fall back to the non-conditioning one.
        obj_output = output_dict_per_obj[obj_idx]
        out = (
            obj_output["cond_frame_outputs"].get(merge_frame_idx)
            or obj_output["non_cond_frame_outputs"].get(merge_frame_idx)
        )
        if out is None or "pred_masks" not in out:
            print(f"No 'pred_masks' found for object ID {obj_id} at frame {merge_frame_idx}.")
            continue

        # Assumes pred_masks is [1, 1, H, W]; the original comment says video
        # resolution -- TODO confirm against what the predictor stores.
        mask_2d = out["pred_masks"].squeeze(0).squeeze(0).cpu()
        masks.append(mask_2d > 0)
        # Only objects that actually contributed a mask take part in the
        # merge; recording the ID earlier would let a mask-less object become
        # the merged ID or be removed without having been merged.
        merged_ids.append(obj_id)

    if not masks:
        print("No valid masks found for merging. Aborting merge.")
        return inference_state

    # Union of all binary masks (bool tensor of shape [H, W]).
    merged_mask = torch.stack(masks).any(dim=0)

    # Choose the representative merged object ID (here, the smallest ID).
    merged_obj_id = min(merged_ids)

    # Update the merged object with the new union mask.
    self.add_new_mask(
        inference_state=inference_state,
        frame_idx=merge_frame_idx,
        obj_id=merged_obj_id,
        mask=merged_mask,
    )

    # Remove the other objects whose masks were folded into the merged one.
    for obj_id in merged_ids:
        if obj_id != merged_obj_id:
            self.remove_object(inference_state, obj_id, strict=False, need_output=False)

    print(f"Merged objects into obj_id {merged_obj_id}.")
    return inference_state



class SAM2VideoPredictorVOS(SAM2VideoPredictor):
"""Optimized for the VOS setting"""

Expand Down