import gdown
import accelerate
import torch.nn as nn
+from typing import Type, Tuple
def _download(url: str, name: str, root: str):
    os.makedirs(root, exist_ok=True)

@@ -25,85 +26,15 @@ def load(ckpt_path, type = "lisa", low_gpu_memory = False):

    url = "https://drive.google.com/uc?export=download&id=1OyVci6rAwnb2sJPxhObgK7AvlLYDLLHw"
    model_path = _download(url, "sam_vit_h_4b8939.pth", os.path.expanduser(f"~/.cache/SeeWhatYouNeed/Sam"))
-    if type == 'lisa':
-        model = DeployModel_LISA(
-            ckpt_path=ckpt_path,
-            sam_ckpt=model_path,
-            offload_languageencoder=low_gpu_memory
-        )
-    else:
-        model = DeployModel_FAM(
-            ckpt_path=ckpt_path,
-            sam_ckpt=model_path
-        ).cuda()
+    model = DeployModel_LISA(
+        ckpt_path=ckpt_path,
+        sam_ckpt=model_path,
+        offload_languageencoder=low_gpu_memory
+    )
    return model


-
-class DeployModel_FAM(nn.Module):
-    def __init__(self,
-                 ckpt_path,
-                 sam_ckpt
-                 ):
-        super().__init__()
-        self.model = FAM(
-            sam_model=sam_ckpt
-        )
-        ckpt = torch.load(ckpt_path, map_location="cpu")
-        if 'module' in ckpt: ckpt = ckpt['module']
-        print(self.model.load_state_dict(ckpt, strict=False))
-
-
-    @torch.no_grad()
-    def forward(self,
-                image: Image,
-                instruction: str,
-                blur_kernel_size=201,
-                threshold=0.2,
-                dilate_kernel_size=11):
-
-        ori_size = image.size
-        ori_image = np.asarray(image).astype(np.float32)
-        resize_img = image.resize((self.model.image_encoder.img_size, self.model.image_encoder.img_size))
-
-        masks = F.interpolate(torch.sigmoid(masks),
-                              (ori_size[1], ori_size[0]),
-                              mode="bilinear",
-                              align_corners=False,
-                              )[0, 0, :, :].detach().cpu().numpy().astype(np.float32)[:, :, np.newaxis]
-
-
-        mask_output = np.where(masks > threshold, 1, 0).astype(np.uint8)
-        kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (dilate_kernel_size, dilate_kernel_size))
-        mask_output = cv2.dilate(mask_output, kernel, iterations=1).astype(np.float32)
-        mask_output = cv2.GaussianBlur(mask_output, (dilate_kernel_size, dilate_kernel_size), 0)[:, :, np.newaxis]
-
-        rgba = np.concatenate((ori_image, mask_output * 255), axis=-1)
-        ori_blurred_image = cv2.GaussianBlur(ori_image, (blur_kernel_size, blur_kernel_size), 0)
-        blur_image = mask_output * ori_image + (1 - mask_output) * ori_blurred_image
-        highlight_image = ori_image * mask_output
-
-        y_indices, x_indices = np.where(mask_output[:, :, 0] > 0)
-
-        # compute the crop bounds
-        x_min, x_max = x_indices.min(), x_indices.max()
-        y_min, y_max = y_indices.min(), y_indices.max()
-
-        # crop the images to those bounds
-        cropped_blur_img = blur_image[y_min:y_max + 1, x_min:x_max + 1]
-        cropped_highlight_img = highlight_image[y_min:y_max + 1, x_min:x_max + 1]
-        return {
-            'soft': masks,
-            'hard': mask_output,
-            'blur_image': blur_image,
-            'highlight_image': highlight_image,
-            'cropped_blur_img': cropped_blur_img,
-            'cropped_highlight_img': cropped_highlight_img,
-            'alhpa_image': rgba
-        }
-
-
class DeployModel_LISA(nn.Module):
    def __init__(self,
                 ckpt_path,
@@ -127,9 +58,65 @@ def __init__(self,
        self.model.pixel_std = self.model.pixel_std.cuda()
        self.model.mask_decoder = self.model.mask_decoder.cuda()

+    @torch.no_grad()
+    def forward_batch(
+        self,
+        image,                      # list of PIL.Image
+        instruction,                # list of instructions, one per image
+        blur_kernel_size=201,
+        threshold=0.5,
+        dilate_kernel_size=21,
+        fill_color=(255, 255, 255)
+    ):
+        ori_sizes = [img.size for img in image]
+        ori_images = [np.asarray(img).astype(np.float32) for img in image]
+        masks = self.model.generate_batch([img.resize((1024, 1024)) for img in image], instruction)
+
+
+        soft = []
+        hard = []
+        blur_image = []
+        highlight_image = []
+        cropped_blur_img = []
+        cropped_highlight_img = []
+        rgba = []
+        for mask, ori_image, ori_size in zip(masks, ori_images, ori_sizes):
+            # Upsample the mask logits to the original resolution, then squash to [0, 1].
+            mask = torch.sigmoid(F.interpolate(
+                mask.unsqueeze(0),
+                (ori_size[1], ori_size[0]),
+                mode="bilinear",
+                align_corners=False,
+            )[0, 0, :, :]).detach().cpu().numpy().astype(np.float32)[:, :, np.newaxis]
+            # Binarize, dilate to pad the region, then feather the edge with a Gaussian blur.
+            mask_output = np.where(mask > threshold, 1, 0).astype(np.uint8)
+            kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (dilate_kernel_size, dilate_kernel_size))
+            mask_output = cv2.dilate(mask_output, kernel, iterations=1).astype(np.float32)
+            mask_output = cv2.GaussianBlur(mask_output, (dilate_kernel_size, dilate_kernel_size), 0)[:, :, np.newaxis]
+            # Bounding box of the (dilated) mask, used for the cropped outputs.
+            y_indices, x_indices = np.where(mask_output[:, :, 0] > 0)
+            x_min, x_max = x_indices.min(), x_indices.max()
+            y_min, y_max = y_indices.min(), y_indices.max()
+
+
+            soft.append(mask)
+            hard.append(mask_output)
+            rgba.append(np.concatenate((ori_image, mask_output * 255), axis=-1))
+            blur_image.append(mask_output * ori_image + (1 - mask_output) * cv2.GaussianBlur(ori_image, (blur_kernel_size, blur_kernel_size), 0))
+            # Fill the background with fill_color outside the (feathered) mask.
+            highlight_image.append(ori_image * mask_output + torch.tensor(fill_color, dtype=torch.uint8).repeat(ori_size[1], ori_size[0], 1).numpy() * (1 - mask_output))
+            cropped_blur_img.append(blur_image[-1][y_min:y_max + 1, x_min:x_max + 1])
+            cropped_highlight_img.append(highlight_image[-1][y_min:y_max + 1, x_min:x_max + 1])
+        return {
+            'soft': soft,
+            'hard': hard,
+            'blur_image': blur_image,
+            'highlight_image': highlight_image,
+            'cropped_blur_img': cropped_blur_img,
+            'cropped_highlight_img': cropped_highlight_img,
+            'rgba_image': rgba
+        }
+

    @torch.no_grad()
-    def forward(self,
+    def forward(
+        self,
        image: Image,
        instruction: str,
        blur_kernel_size=201,
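
Usage after this change: `load` always builds a `DeployModel_LISA`, since the FAM branch is gone. A minimal sketch, assuming a CUDA device; the checkpoint path below is a placeholder, not a file shipped with the repo:

# "lisa_ckpt.pth" is a placeholder path; the SAM weights are fetched
# automatically to ~/.cache/SeeWhatYouNeed/Sam on first use.
model = load("lisa_ckpt.pth", low_gpu_memory=True)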
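And a hedged sketch of the new batched entry point; the image files and prompts below are made up, and every entry of the returned dict is a per-image list:

import cv2
import numpy as np
from PIL import Image

imgs = [Image.open("a.jpg").convert("RGB"), Image.open("b.jpg").convert("RGB")]
prompts = ["the red mug", "the dog on the left"]
out = model.forward_batch(imgs, prompts)
# Arrays are float32 RGB in [0, 255]; convert before saving with OpenCV (BGR).
hl = np.clip(out['highlight_image'][0], 0, 255).astype(np.uint8)
cv2.imwrite("highlight_0.png", hl[:, :, ::-1])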