Skip to content

Commit

Permalink
Use singleton class to manage sharing of models across image and video pipeline
Browse files Browse the repository at this point in the history
  • Loading branch information
parthchadha committed Dec 16, 2024
1 parent ab9fa85 commit 77d5d20
Showing 1 changed file with 111 additions and 100 deletions.
211 changes: 111 additions & 100 deletions tripy/examples/segment-anything-model-v2/sam2/build_sam.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from hydra import compose
from hydra.utils import instantiate
from omegaconf import OmegaConf
from typing import Dict, Any, Optional

import tripy as tp
import time
Expand Down Expand Up @@ -132,7 +133,7 @@ def get_component_configs(model, cfg):
dtype=getattr(tp, model_precision),
), # image_pe
tp.InputInfo(
(batchsize, 2, 256),
(batchsize, (2, 4, 6), 256),
dtype=getattr(tp, model_precision),
), # sparse_prompt_embeddings
tp.InputInfo(
Expand Down Expand Up @@ -250,54 +251,33 @@ def get_component_configs(model, cfg):
}


def build_sam2(
config_file,
ckpt_path=None,
device="cuda",
mode="eval",
hydra_overrides_extra=[],
apply_postprocessing=True,
**kwargs,
):
class SAM2ModelCache:
"""Singleton class to manage cached compiled models for SAM2."""

if apply_postprocessing:
hydra_overrides_extra = hydra_overrides_extra.copy()
hydra_overrides_extra += [
# dynamically fall back to multi-mask if the single mask is not stable
"++model.sam_mask_decoder_extra_args.dynamic_multimask_via_stability=true",
"++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_delta=0.05",
"++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_thresh=0.98",
]
# Read config and init model
cfg = compose(config_name=config_file, overrides=hydra_overrides_extra)
OmegaConf.resolve(cfg)
_instance = None

model = instantiate(cfg.model, _recursive_=True)
_load_checkpoint(model, ckpt_path, cfg)
def __new__(cls):
if cls._instance is None:
cls._instance = super(SAM2ModelCache, cls).__new__(cls)
cls._instance.cached_models = {}
cls._instance.saved_engines_path = os.path.join(os.getcwd(), "saved_engines")
if not os.path.exists(cls._instance.saved_engines_path):
os.makedirs(cls._instance.saved_engines_path)
return cls._instance

current_dir = os.getcwd()
saved_engines_path = os.path.join(current_dir, "saved_engines")
def get_or_compile_component(self, comp_name: str, comp_info: Dict[str, Any]) -> tp.Executable:
"""Get a cached compiled model or compile and cache it if not exists."""
if not comp_info["enabled"]:
return None

# Create the saved_engines directory if it doesn't exist
if not os.path.exists(saved_engines_path):
os.makedirs(saved_engines_path)
# Check if already in memory cache
if comp_name in self.cached_models:
print(f"Using in-memory cached model for {comp_name}")
return self.cached_models[comp_name]

# Get component configurations
components = get_component_configs(model, cfg)
required_components_for_image = [
"sam_mask_decoder_true",
"sam_mask_decoder.conv_s0",
"sam_mask_decoder.conv_s1",
"sam_prompt_encoder",
"sam_prompt_encoder.get_dense_pe",
"image_encoder.compiled_executable",
]
executable_file = os.path.join(self.saved_engines_path, comp_name)

for comp_name, comp_info in components.items():
if not comp_info["enabled"] or comp_name not in required_components_for_image:
continue

executable_file = os.path.join(saved_engines_path, comp_name)
# Check if compiled model exists on disk
if os.path.exists(executable_file):
print(f"Loading existing compiled {comp_name} from {executable_file}")
compiled_model = tp.Executable.load(executable_file)
Expand All @@ -308,92 +288,123 @@ def build_sam2(
print(f"Compilation took {time.time() - start:.2f}s")
compiled_model.save(executable_file)

old_model = comp_info["model"]
# If model is model.forward, retrieve the original model object
if hasattr(old_model, "__self__"):
old_model = old_model.__self__

set_model_attr(model, comp_name, compiled_model)
if "special_handling" in comp_info and comp_info["special_handling"] is not None:
comp_info["special_handling"](old_model)

model = model.to(device)
if mode == "eval":
model.eval()
return model
# Cache the compiled model in memory
self.cached_models[comp_name] = compiled_model
return compiled_model


def build_sam2_video_predictor(
config_file,
ckpt_path=None,
device="cuda",
mode="eval",
hydra_overrides_extra=[],
apply_postprocessing=True,
def build_sam2_base(
config_file: str,
ckpt_path: Optional[str] = None,
device: str = "cuda",
mode: str = "eval",
hydra_overrides: list = None,
apply_postprocessing: bool = True,
**kwargs,
):
hydra_overrides = [
"++model._target_=sam2.sam2_video_predictor.SAM2VideoPredictor",
]
) -> Any:
"""Base function for building SAM2 models with caching support."""
if hydra_overrides is None:
hydra_overrides = []

if apply_postprocessing:
hydra_overrides_extra = hydra_overrides_extra.copy()
hydra_overrides_extra += [
# dynamically fall back to multi-mask if the single mask is not stable
hydra_overrides = hydra_overrides.copy()
hydra_overrides += [
"++model.sam_mask_decoder_extra_args.dynamic_multimask_via_stability=true",
"++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_delta=0.05",
"++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_thresh=0.98",
# binarize the sigmoid mask logits on interacted frames with clicks in the memory encoder so that the encoded masks are exactly what users see from clicking
"++model.binarize_mask_from_pts_for_mem_enc=true",
# fill small holes in the low-res masks up to `fill_hole_area` (before resizing them to the original video resolution)
"++model.fill_hole_area=8",
]
hydra_overrides.extend(hydra_overrides_extra)

# Read config and init model
cfg = compose(config_name=config_file, overrides=hydra_overrides)
OmegaConf.resolve(cfg)

model = instantiate(cfg.model, _recursive_=True)
_load_checkpoint(model, ckpt_path, cfg)

current_dir = os.getcwd()
saved_engines_path = os.path.join(current_dir, "saved_engines")
# Create the saved_engines directory if it doesn't exist
if not os.path.exists(saved_engines_path):
os.makedirs(saved_engines_path)

# Get component configurations
# Get component configurations and initialize cache
components = get_component_configs(model, cfg)
model_cache = SAM2ModelCache()

# Compile or load all required components
for comp_name, comp_info in components.items():
if not comp_info["enabled"]:
continue

executable_file = os.path.join(saved_engines_path, comp_name)
if os.path.exists(executable_file):
print(f"Loading existing compiled {comp_name} from {executable_file}")
compiled_model = tp.Executable.load(executable_file)
else:
print(f"Compiling {comp_name}...")
start = time.time()
compiled_model = tp.compile(comp_info["model"], args=comp_info["compile_args"])
print(f"Compilation took {time.time() - start:.2f}s")
compiled_model.save(executable_file)

old_model = comp_info["model"]
# If model is model.forward, retrieve the original model object
if hasattr(old_model, "__self__"):
old_model = old_model.__self__
compiled_model = model_cache.get_or_compile_component(comp_name, comp_info)
if compiled_model is not None:
old_model = comp_info["model"]
if hasattr(old_model, "__self__"):
old_model = old_model.__self__

set_model_attr(model, comp_name, compiled_model)
if "special_handling" in comp_info and comp_info["special_handling"] is not None:
comp_info["special_handling"](old_model)
set_model_attr(model, comp_name, compiled_model)
if "special_handling" in comp_info and comp_info["special_handling"] is not None:
comp_info["special_handling"](old_model)

model = model.to(device)
if mode == "eval":
model.eval()
return model


def build_sam2(
    config_file: str,
    ckpt_path: Optional[str] = None,
    device: str = "cuda",
    mode: str = "eval",
    hydra_overrides_extra: list = None,
    apply_postprocessing: bool = True,
    use_tripy_image_encoder: bool = False,
    **kwargs,
) -> Any:
    """Build the SAM2 image model by delegating to ``build_sam2_base``.

    Args:
        config_file: Config name forwarded to ``build_sam2_base``.
        ckpt_path: Checkpoint path forwarded to ``build_sam2_base``.
        device: Target device forwarded to ``build_sam2_base``.
        mode: Forwarded to ``build_sam2_base``.
        hydra_overrides_extra: Extra Hydra overrides; forwarded as the
            ``hydra_overrides`` argument of ``build_sam2_base``.
        apply_postprocessing: Forwarded to ``build_sam2_base``.
        use_tripy_image_encoder: Accepted for interface compatibility only;
            this function neither reads it nor forwards it.
        **kwargs: Passed through to ``build_sam2_base`` untouched.

    Returns:
        Whatever ``build_sam2_base`` returns.
    """
    # Collect the explicit arguments first; note the rename of
    # ``hydra_overrides_extra`` to the base builder's ``hydra_overrides``.
    base_args = {
        "config_file": config_file,
        "ckpt_path": ckpt_path,
        "device": device,
        "mode": mode,
        "hydra_overrides": hydra_overrides_extra,
        "apply_postprocessing": apply_postprocessing,
    }
    return build_sam2_base(**base_args, **kwargs)


def build_sam2_video_predictor(
    config_file: str,
    ckpt_path: Optional[str] = None,
    device: str = "cuda",
    mode: str = "eval",
    hydra_overrides_extra: list = None,
    apply_postprocessing: bool = True,
    **kwargs,
) -> Any:
    """Build the SAM2 video predictor by delegating to ``build_sam2_base``.

    The Hydra override list is assembled in this order: the
    ``SAM2VideoPredictor`` target override, then the caller's
    ``hydra_overrides_extra``, then (when ``apply_postprocessing`` is true)
    the two video-specific postprocessing overrides.

    Args:
        config_file: Config name forwarded to ``build_sam2_base``.
        ckpt_path: Checkpoint path forwarded to ``build_sam2_base``.
        device: Target device forwarded to ``build_sam2_base``.
        mode: Forwarded to ``build_sam2_base``.
        hydra_overrides_extra: Optional extra Hydra overrides. The caller's
            list is never mutated.
        apply_postprocessing: When true, add the video-specific
            postprocessing overrides here; the flag is also forwarded to
            ``build_sam2_base``.
        **kwargs: Passed through to ``build_sam2_base`` untouched.

    Returns:
        Whatever ``build_sam2_base`` returns.
    """
    hydra_overrides = [
        "++model._target_=sam2.sam2_video_predictor.SAM2VideoPredictor",
    ]
    # ``extend`` copies the items, so the caller's list is left untouched.
    hydra_overrides.extend(hydra_overrides_extra or [])

    if apply_postprocessing:
        # Only the video-specific overrides are added here. The shared
        # ``dynamic_multimask_*`` overrides are appended by build_sam2_base
        # when it receives apply_postprocessing=True; adding them here as
        # well would put each of them in the override list twice.
        hydra_overrides += [
            # binarize the mask logits from point prompts before they are
            # passed to the memory encoder
            "++model.binarize_mask_from_pts_for_mem_enc=true",
            # fill small holes in the low-res masks up to `fill_hole_area`
            "++model.fill_hole_area=8",
        ]

    return build_sam2_base(
        config_file=config_file,
        ckpt_path=ckpt_path,
        device=device,
        mode=mode,
        hydra_overrides=hydra_overrides,
        apply_postprocessing=apply_postprocessing,
        **kwargs,
    )


def load_component_weights(comp_name, component_info, state_dict, checkpoint_dict):
"""
Load weights for a single component from checkpoint into state dict.
Expand Down

0 comments on commit 77d5d20

Please sign in to comment.