Skip to content

Commit 86c349c

Browse files
author
mabingqi
committed
model loading revision
1 parent bd5b372 commit 86c349c

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

IVM.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ def load(ckpt_path, low_gpu_memory = False):
2727
url = "https://drive.google.com/uc?export=download&id=1OyVci6rAwnb2sJPxhObgK7AvlLYDLLHw"
2828
sam_ckpt = _download(url, "sam_vit_h_4b8939.pth", os.path.expanduser(f"~/.cache/IVM/Sam"))
2929
ckpt = torch.load(ckpt_path, map_location="cpu")
30-
model = IVM(sam_model=sam_ckpt)
30+
model = IVM(sam_model=sam_ckpt).eval()
3131
model.load_state_dict(ckpt, strict=False)
3232
if low_gpu_memory: return accelerate.cpu_offload(model, "cuda")
3333
else: return model.cuda()

0 commit comments

Comments
 (0)