diff --git a/tripy/examples/segment-anything-model-v2/sam2/build_sam.py b/tripy/examples/segment-anything-model-v2/sam2/build_sam.py index 27ad7ca64..cb9c6ca74 100644 --- a/tripy/examples/segment-anything-model-v2/sam2/build_sam.py +++ b/tripy/examples/segment-anything-model-v2/sam2/build_sam.py @@ -29,6 +29,7 @@ from hydra import compose from hydra.utils import instantiate from omegaconf import OmegaConf +from typing import Dict, Any, Optional import tripy as tp import time @@ -132,7 +133,7 @@ def get_component_configs(model, cfg): dtype=getattr(tp, model_precision), ), # image_pe tp.InputInfo( - (batchsize, 2, 256), + (batchsize, (2, 4, 6), 256), dtype=getattr(tp, model_precision), ), # sparse_prompt_embeddings tp.InputInfo( @@ -250,54 +251,33 @@ def get_component_configs(model, cfg): } -def build_sam2( - config_file, - ckpt_path=None, - device="cuda", - mode="eval", - hydra_overrides_extra=[], - apply_postprocessing=True, - **kwargs, -): +class SAM2ModelCache: + """Singleton class to manage cached compiled models for SAM2.""" - if apply_postprocessing: - hydra_overrides_extra = hydra_overrides_extra.copy() - hydra_overrides_extra += [ - # dynamically fall back to multi-mask if the single mask is not stable - "++model.sam_mask_decoder_extra_args.dynamic_multimask_via_stability=true", - "++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_delta=0.05", - "++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_thresh=0.98", - ] - # Read config and init model - cfg = compose(config_name=config_file, overrides=hydra_overrides_extra) - OmegaConf.resolve(cfg) + _instance = None - model = instantiate(cfg.model, _recursive_=True) - _load_checkpoint(model, ckpt_path, cfg) + def __new__(cls): + if cls._instance is None: + cls._instance = super(SAM2ModelCache, cls).__new__(cls) + cls._instance.cached_models = {} + cls._instance.saved_engines_path = os.path.join(os.getcwd(), "saved_engines") + if not os.path.exists(cls._instance.saved_engines_path): + os.makedirs(cls._instance.saved_engines_path) + return cls._instance - current_dir = os.getcwd() - saved_engines_path = os.path.join(current_dir, "saved_engines") + def get_or_compile_component(self, comp_name: str, comp_info: Dict[str, Any]) -> tp.Executable: + """Get a cached compiled model or compile and cache it if not exists.""" + if not comp_info["enabled"]: + return None - # Create the saved_engines directory if it doesn't exist - if not os.path.exists(saved_engines_path): - os.makedirs(saved_engines_path) + # Check if already in memory cache + if comp_name in self.cached_models: + print(f"Using in-memory cached model for {comp_name}") + return self.cached_models[comp_name] - # Get component configurations - components = get_component_configs(model, cfg) - required_components_for_image = [ - "sam_mask_decoder_true", - "sam_mask_decoder.conv_s0", - "sam_mask_decoder.conv_s1", - "sam_prompt_encoder", - "sam_prompt_encoder.get_dense_pe", - "image_encoder.compiled_executable", - ] + executable_file = os.path.join(self.saved_engines_path, comp_name) - for comp_name, comp_info in components.items(): - if not comp_info["enabled"] or comp_name not in required_components_for_image: - continue - - executable_file = os.path.join(saved_engines_path, comp_name) + # Check if compiled model exists on disk if os.path.exists(executable_file): print(f"Loading existing compiled {comp_name} from {executable_file}") compiled_model = tp.Executable.load(executable_file) @@ -308,85 +288,57 @@ def build_sam2( print(f"Compilation took {time.time() - start:.2f}s") compiled_model.save(executable_file) - old_model = comp_info["model"] - # If model is model.forward, retrieve the original model object - if hasattr(old_model, "__self__"): - old_model = old_model.__self__ - - set_model_attr(model, comp_name, compiled_model) - if "special_handling" in comp_info and comp_info["special_handling"] is not None: - comp_info["special_handling"](old_model) - - model = model.to(device) - if mode == "eval": - model.eval() - return model + # Cache the compiled model in memory + self.cached_models[comp_name] = compiled_model + return compiled_model -def build_sam2_video_predictor( - config_file, - ckpt_path=None, - device="cuda", - mode="eval", - hydra_overrides_extra=[], - apply_postprocessing=True, +def build_sam2_base( + config_file: str, + ckpt_path: Optional[str] = None, + device: str = "cuda", + mode: str = "eval", + hydra_overrides: list = None, + apply_postprocessing: bool = True, **kwargs, -): - hydra_overrides = [ - "++model._target_=sam2.sam2_video_predictor.SAM2VideoPredictor", - ] +) -> Any: + """Base function for building SAM2 models with caching support.""" + if hydra_overrides is None: + hydra_overrides = [] + if apply_postprocessing: - hydra_overrides_extra = hydra_overrides_extra.copy() - hydra_overrides_extra += [ - # dynamically fall back to multi-mask if the single mask is not stable + hydra_overrides = hydra_overrides.copy() + hydra_overrides += [ "++model.sam_mask_decoder_extra_args.dynamic_multimask_via_stability=true", "++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_delta=0.05", "++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_thresh=0.98", - # the sigmoid mask logits on interacted frames with clicks in the memory encoder so that the encoded masks are exactly as what users see from clicking - "++model.binarize_mask_from_pts_for_mem_enc=true", - # fill small holes in the low-res masks up to `fill_hole_area` (before resizing them to the original video resolution) - "++model.fill_hole_area=8", ] - hydra_overrides.extend(hydra_overrides_extra) # Read config and init model cfg = compose(config_name=config_file, overrides=hydra_overrides) OmegaConf.resolve(cfg) + model = instantiate(cfg.model, _recursive_=True) _load_checkpoint(model, ckpt_path, cfg) - current_dir = os.getcwd() - saved_engines_path = os.path.join(current_dir, "saved_engines") - # Create the saved_engines directory if it doesn't exist - if not os.path.exists(saved_engines_path): - os.makedirs(saved_engines_path) - - # Get component configurations + # Get component configurations and initialize cache components = get_component_configs(model, cfg) + model_cache = SAM2ModelCache() + # Compile or load all required components for comp_name, comp_info in components.items(): if not comp_info["enabled"]: continue - executable_file = os.path.join(saved_engines_path, comp_name) - if os.path.exists(executable_file): - print(f"Loading existing compiled {comp_name} from {executable_file}") - compiled_model = tp.Executable.load(executable_file) - else: - print(f"Compiling {comp_name}...") - start = time.time() - compiled_model = tp.compile(comp_info["model"], args=comp_info["compile_args"]) - print(f"Compilation took {time.time() - start:.2f}s") - compiled_model.save(executable_file) - - old_model = comp_info["model"] - # If model is model.forward, retrieve the original model object - if hasattr(old_model, "__self__"): - old_model = old_model.__self__ + compiled_model = model_cache.get_or_compile_component(comp_name, comp_info) + if compiled_model is not None: + old_model = comp_info["model"] + if hasattr(old_model, "__self__"): + old_model = old_model.__self__ - set_model_attr(model, comp_name, compiled_model) - if "special_handling" in comp_info and comp_info["special_handling"] is not None: - comp_info["special_handling"](old_model) + set_model_attr(model, comp_name, compiled_model) + if "special_handling" in comp_info and comp_info["special_handling"] is not None: + comp_info["special_handling"](old_model) model = model.to(device) if mode == "eval": @@ -394,6 +346,65 @@ def build_sam2_video_predictor( return model +def build_sam2( + config_file: str, + ckpt_path: Optional[str] = None, + device: str = "cuda", + mode: str = "eval", + hydra_overrides_extra: list = None, + apply_postprocessing: bool = True, + use_tripy_image_encoder: bool = False, + **kwargs, +) -> Any: + """Build SAM2 model with caching support.""" + return build_sam2_base( + config_file=config_file, + ckpt_path=ckpt_path, + device=device, + mode=mode, + hydra_overrides=hydra_overrides_extra, + apply_postprocessing=apply_postprocessing, + **kwargs, + ) + + +def build_sam2_video_predictor( + config_file: str, + ckpt_path: Optional[str] = None, + device: str = "cuda", + mode: str = "eval", + hydra_overrides_extra: list = None, + apply_postprocessing: bool = True, + **kwargs, +) -> Any: + """Build SAM2 video predictor with caching support.""" + hydra_overrides = [ + "++model._target_=sam2.sam2_video_predictor.SAM2VideoPredictor", + ] + + if apply_postprocessing: + hydra_overrides_extra = hydra_overrides_extra.copy() if hydra_overrides_extra else [] + hydra_overrides_extra += [ + "++model.sam_mask_decoder_extra_args.dynamic_multimask_via_stability=true", + "++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_delta=0.05", + "++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_thresh=0.98", + "++model.binarize_mask_from_pts_for_mem_enc=true", + "++model.fill_hole_area=8", + ] + + hydra_overrides.extend(hydra_overrides_extra or []) + + return build_sam2_base( + config_file=config_file, + ckpt_path=ckpt_path, + device=device, + mode=mode, + hydra_overrides=hydra_overrides, + apply_postprocessing=apply_postprocessing, + **kwargs, + ) + + def load_component_weights(comp_name, component_info, state_dict, checkpoint_dict): """ Load weights for a single component from checkpoint into state dict.