Skip to content

Commit 46f106e

Browse files
committed
support batch infer
1 parent e7b3ce2 commit 46f106e

File tree

3 files changed

+94
-76
lines changed

3 files changed

+94
-76
lines changed

deploy.py

Lines changed: 63 additions & 76 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import gdown
1111
import accelerate
1212
import torch.nn as nn
13+
from typing import Type, Tuple
1314
def _download(url: str, name: str,root: str):
1415
os.makedirs(root, exist_ok=True)
1516

@@ -25,85 +26,15 @@ def load(ckpt_path, type = "lisa", low_gpu_memory = False):
2526

2627
url = "https://drive.google.com/uc?export=download&id=1OyVci6rAwnb2sJPxhObgK7AvlLYDLLHw"
2728
model_path = _download(url, "sam_vit_h_4b8939.pth", os.path.expanduser(f"~/.cache/SeeWhatYouNeed/Sam"))
28-
if type == 'lisa':
29-
model = DeployModel_LISA(
30-
ckpt_path = ckpt_path,
31-
sam_ckpt=model_path,
32-
offload_languageencoder=low_gpu_memory
33-
)
34-
else:
35-
model = DeployModel_FAM(
36-
ckpt_path = ckpt_path,
37-
sam_ckpt=model_path
38-
).cuda()
29+
model = DeployModel_LISA(
30+
ckpt_path = ckpt_path,
31+
sam_ckpt=model_path,
32+
offload_languageencoder=low_gpu_memory
33+
)
3934
return model
4035

4136

4237

43-
44-
class DeployModel_FAM(nn.Module):
45-
def __init__(self,
46-
ckpt_path,
47-
sam_ckpt
48-
):
49-
super().__init__()
50-
self.model = FAM(
51-
sam_model=sam_ckpt
52-
)
53-
ckpt = torch.load(ckpt_path, map_location="cpu")
54-
if 'module' in ckpt: ckpt = ckpt['module']
55-
print(self.model.load_state_dict(ckpt, strict=False))
56-
57-
58-
@torch.no_grad()
59-
def forward(self,
60-
image: Image,
61-
instruction: str,
62-
blur_kernel_size = 201,
63-
threshold = 0.2,
64-
dilate_kernel_size = 11):
65-
66-
ori_size = image.size
67-
ori_image = np.asarray(image).astype(np.float32)
68-
resize_img = image.resize((self.model.image_encoder.img_size, self.model.image_encoder.img_size))
69-
70-
masks = F.interpolate(torch.sigmoid(masks),
71-
(ori_size[1], ori_size[0]),
72-
mode="bilinear",
73-
align_corners=False,
74-
)[0, 0, :, :].detach().cpu().numpy().astype(np.float32)[:,:,np.newaxis]
75-
76-
77-
mask_output = np.where(masks > threshold, 1, 0).astype(np.uint8)
78-
kernel = cv2.getStructuringElement(cv2.MORPH_RECT,(dilate_kernel_size,dilate_kernel_size)) #ksize=7x7,
79-
mask_output = cv2.dilate(mask_output,kernel,iterations=1).astype(np.float32)
80-
mask_output = cv2.GaussianBlur(mask_output, (dilate_kernel_size, dilate_kernel_size), 0)[:,:,np.newaxis]
81-
82-
rgba = np.concatenate((ori_image, mask_output * 255), axis=-1)
83-
ori_blurred_image = cv2.GaussianBlur(ori_image, (blur_kernel_size, blur_kernel_size), 0)
84-
blur_image = mask_output * ori_image + (1-mask_output) * ori_blurred_image
85-
highlight_image = ori_image * mask_output
86-
87-
y_indices, x_indices = np.where(mask_output[:,:,0] > 0)
88-
89-
# 计算裁剪边界
90-
x_min, x_max = x_indices.min(), x_indices.max()
91-
y_min, y_max = y_indices.min(), y_indices.max()
92-
93-
# 根据边界裁剪图片
94-
cropped_blur_img = blur_image[y_min:y_max+1, x_min:x_max+1]
95-
cropped_highlight_img = highlight_image[y_min:y_max+1, x_min:x_max+1]
96-
return {
97-
'soft': masks,
98-
'hard': mask_output,
99-
'blur_image': blur_image,
100-
'highlight_image': highlight_image,
101-
'cropped_blur_img': cropped_blur_img,
102-
'cropped_highlight_img': cropped_highlight_img,
103-
'alhpa_image': rgba
104-
}
105-
106-
10738
class DeployModel_LISA(nn.Module):
10839
def __init__(self,
10940
ckpt_path,
@@ -127,9 +58,65 @@ def __init__(self,
12758
self.model.pixel_std = self.model.pixel_std.cuda()
12859
self.model.mask_decoder = self.model.mask_decoder.cuda()
12960

61+
@torch.no_grad()
62+
def forward_batch(
63+
self,
64+
image, # list of PIL.Image
65+
instruction, # list of instruction
66+
blur_kernel_size = 201,
67+
threshold = 0.5,
68+
dilate_kernel_size = 21,
69+
fill_color=(255, 255, 255)
70+
):
71+
ori_sizes = [img.size for img in image]
72+
ori_images = [np.asarray(img).astype(np.float32) for img in image]
73+
masks = self.model.generate_batch([img.resize((1024, 1024)) for img in image], instruction)
74+
75+
76+
soft = []
77+
hard = []
78+
blur_image = []
79+
highlight_image = []
80+
cropped_blur_img = []
81+
cropped_highlight_img = []
82+
rgba = []
83+
for mask, ori_image, ori_size in zip(masks, ori_images, ori_sizes):
84+
mask = torch.sigmoid(F.interpolate(
85+
mask.unsqueeze(0),
86+
(ori_size[1], ori_size[0]),
87+
mode="bilinear",
88+
align_corners=False,
89+
)[0, 0, :, :]).detach().cpu().numpy().astype(np.float32)[:,:,np.newaxis]
90+
mask_output = np.where(mask > threshold, 1, 0).astype(np.uint8)
91+
kernel = cv2.getStructuringElement(cv2.MORPH_RECT,(dilate_kernel_size,dilate_kernel_size)) #ksize=7x7,
92+
mask_output = cv2.dilate(mask_output,kernel,iterations=1).astype(np.float32)
93+
mask_output = cv2.GaussianBlur(mask_output, (dilate_kernel_size, dilate_kernel_size), 0)[:,:,np.newaxis]
94+
y_indices, x_indices = np.where(mask_output[:,:,0] > 0)
95+
x_min, x_max = x_indices.min(), x_indices.max()
96+
y_min, y_max = y_indices.min(), y_indices.max()
97+
98+
99+
soft.append(mask)
100+
hard.append(mask_output)
101+
rgba.append(np.concatenate((ori_image, mask_output * 255), axis=-1))
102+
blur_image.append(mask_output * ori_image + (1-mask_output) * cv2.GaussianBlur(ori_image, (blur_kernel_size, blur_kernel_size), 0))
103+
highlight_image.append(ori_image * mask_output + torch.tensor(fill_color, dtype=torch.uint8).repeat(image.size[1], image.size[0], 1).numpy() * (1 - mask_output))
104+
cropped_blur_img.append(blur_image[-1][y_min:y_max+1, x_min:x_max+1])
105+
cropped_highlight_img.append(highlight_image[-1][y_min:y_max+1, x_min:x_max+1])
106+
return {
107+
'soft': masks,
108+
'hard': mask_output,
109+
'blur_image': blur_image,
110+
'highlight_image': highlight_image,
111+
'cropped_blur_img': cropped_blur_img,
112+
'cropped_highlight_img': cropped_highlight_img,
113+
'rgba_image': rgba
114+
}
115+
130116

131117
@torch.no_grad()
132-
def forward(self,
118+
def forward(
119+
self,
133120
image: Image,
134121
instruction: str,
135122
blur_kernel_size = 201,

model/LISA_vanilla.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,14 @@ def generate(self, images, instruction):
109109
masks = self.mask_decoder(image_embeddings, language_embeddings)
110110
mask_output = self.postprocess_masks(masks)
111111
return mask_output
112+
113+
def generate_batch(self, images, instruction):
114+
images = torch.stack([self.preprocess(img) for img in images], dim = 0)
115+
image_embeddings = self.image_encoder(images)
116+
language_embeddings = self.prompt_encoder(instruction, images)
117+
masks = self.mask_decoder(image_embeddings, language_embeddings)
118+
mask_output = self.postprocess_masks(masks)
119+
return mask_output
112120

113121
def forward(
114122
self,

utils.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
import numpy as np
2+
from PIL import Image
3+
import json
4+
def generate_boxes_mask(xmin, ymin, xmax, ymax, height, width):
5+
mask = np.zeros((height, width), dtype=np.float32)
6+
xmin = xmin * max(width, height)
7+
ymin = ymin * max(width, height)
8+
xmax = xmax * max(width, height)
9+
ymax = ymax * max(width, height)
10+
if width > height:
11+
overlay = (width - height) // 2
12+
ymin = max(0, ymin - overlay)
13+
ymax = max(0, ymax - overlay)
14+
else:
15+
overlay = (height - width) // 2
16+
xmin = max(0, xmin - overlay)
17+
xmax = max(0, xmax - overlay)
18+
mask[int(ymin):int(ymax), int(xmin):int(xmax)] = 1
19+
return mask
20+
21+
image = Image.open(...)
22+
item = json.loads(...)
23+
probability_mask = generate_boxes_mask(*(item['boxes'] + [image.size[1], image.size[0]]))

0 commit comments

Comments
 (0)