diff --git a/sam2/sam2_video_predictor.py b/sam2/sam2_video_predictor.py index 5a7e1a01c..5c17466d5 100644 --- a/sam2/sam2_video_predictor.py +++ b/sam2/sam2_video_predictor.py @@ -973,6 +973,75 @@ def _clear_non_cond_mem_around_input(self, inference_state, frame_idx): non_cond_frame_outputs.pop(t, None) + + @torch.inference_mode() + def merge_multiple_objects( + self, + inference_state, + merge_frame_idx, + frames, # kept for interface consistency though not used in the in-place update + obj_ids_to_merge, + partial_results_before_merge=None + ): + """ + Merge all objects in `obj_ids_to_merge` by unioning their video-resolution masks + (`pred_masks`) at `merge_frame_idx` into a single object mask. This function updates + the given inference state in-place without reinitializing it. + """ + + mask_list = [] + valid_obj_ids = [] + for obj_id in obj_ids_to_merge: + obj_idx = inference_state["obj_id_to_idx"].get(obj_id) + if obj_idx is None: + print(f"Object ID {obj_id} not found in inference state.") + continue + valid_obj_ids.append(obj_id) + + # Decide whether to get from conditioning or non-conditioning outputs. + obj_output = inference_state["output_dict_per_obj"][obj_idx] + out = (obj_output["cond_frame_outputs"].get(merge_frame_idx) or + obj_output["non_cond_frame_outputs"].get(merge_frame_idx)) + if out is None or "pred_masks" not in out: + print(f"No 'pred_masks' found for object ID {obj_id} at frame {merge_frame_idx}.") + continue + + mask_video_res = out["pred_masks"] # Expected shape: [1, 1, H_video, W_video] + mask_2d = mask_video_res.squeeze(0).squeeze(0).cpu() + binary_mask = (mask_2d > 0).to(torch.uint8) + mask_list.append(binary_mask) + + if not mask_list: + print("No valid masks found for merging. Aborting merge.") + return inference_state + + # Union all the binary masks. + merged_mask = torch.zeros_like(mask_list[0], dtype=torch.uint8) + for m in mask_list: + merged_mask = torch.logical_or(merged_mask, m) + + # Choose the representative merged object ID (here, the smallest ID). + merged_obj_id = min(valid_obj_ids) + + # Update the merged object with the new union mask. + self.add_new_mask( + inference_state=inference_state, + frame_idx=merge_frame_idx, + obj_id=merged_obj_id, + mask=merged_mask # shape [H_video, W_video] + ) + + # Remove any other objects that were merged. + for obj_id in valid_obj_ids: + if obj_id == merged_obj_id: + continue + self.remove_object(inference_state, obj_id, strict=False, need_output=False) + + print(f"Merged objects into obj_id {merged_obj_id}.") + return inference_state + + + class SAM2VideoPredictorVOS(SAM2VideoPredictor): """Optimized for the VOS setting"""