This repository was archived by the owner on May 1, 2025. It is now read-only.

Commit f77760c

replicate demo
1 parent b726f8e commit f77760c

3 files changed: +232 -0 lines changed


README.md

Lines changed: 2 additions & 0 deletions
@@ -26,6 +26,8 @@ and visual grounding on RefCOCO+. Pre-trained and finetuned checkpoints are rele
 We provide code in visualize.ipynb to visualize the important areas in an image for each word in a text.
 Here is an example visualization using the visual grounding checkpoint.
 
+Try the Replicate demo here [![Replicate](https://replicate.com/salesforce/albef/badge)](https://replicate.com/salesforce/albef).
+
 <img src="examples/visualization.png" width="700">
 
 ### Pre-training on custom datasets:
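The badge added above links to the hosted demo on Replicate; the two new files in this commit, cog.yaml and predict.py, provide the Cog packaging behind it.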

cog.yaml

Lines changed: 21 additions & 0 deletions
build:
  gpu: true
  cuda: "11.1"
  python_version: "3.8"
  system_packages:
    - "libgl1-mesa-glx"
    - "libglib2.0-0"
  python_packages:
    - "ipython==7.30.1"
    - "torchvision==0.11.1"
    - "torch==1.10.0"
    - "timm==0.4.12"
    - "transformers==4.8.1"
    - "Pillow==8.3.2"
    - "numpy==1.21.1"
    - "opencv-python==4.5.5.62"
    - "scipy==1.8.0"
    - "scikit_image==0.19.2"
    - "matplotlib==3.4.3"

predict: "predict.py:Predictor"

predict.py

Lines changed: 209 additions & 0 deletions
# Cog predictor that runs the ALBEF visual grounding checkpoint and renders a
# per-word Grad-CAM visualization for a given image/caption pair.
import re
import tempfile
from functools import partial

import cv2
from PIL import Image
import numpy as np
from cog import BasePredictor, Path, Input

from skimage import transform as skimage_transform
from scipy.ndimage import filters
from matplotlib import pyplot as plt

import torch
from torch import nn
from torchvision import transforms

from models.vit import VisionTransformer
from models.xbert import BertConfig, BertModel
from models.tokenization_bert import BertTokenizer


class Predictor(BasePredictor):
    def setup(self):
        normalize = transforms.Normalize(
            (0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)
        )

        self.transform = transforms.Compose(
            [
                transforms.Resize((384, 384), interpolation=Image.BICUBIC),
                transforms.ToTensor(),
                normalize,
            ]
        )

        self.tokenizer = BertTokenizer.from_pretrained("bert/bert-base-uncased")

        bert_config_path = "configs/config_bert.json"
        self.model = VL_Transformer_ITM(
            text_encoder="bert/bert-base-uncased", config_bert=bert_config_path
        )

        checkpoint = torch.load("refcoco.pth", map_location="cpu")
        msg = self.model.load_state_dict(checkpoint, strict=False)
        self.model.eval()

        # Grad-CAM is computed from the cross-attention of this text-encoder block,
        # so tell it to save its attention maps and gradients.
        self.block_num = 8
        self.model.text_encoder.base_model.base_model.encoder.layer[
            self.block_num
        ].crossattention.self.save_attention = True

        self.model.cuda()

    def predict(
        self,
        image: Path = Input(description="Input image."),
        caption: str = Input(
            description="Caption for the image. Grad-CAM visualization will be generated "
            "for each word in the caption."
        ),
    ) -> Path:

        image_pil = Image.open(str(image)).convert("RGB")
        img = self.transform(image_pil).unsqueeze(0)

        text = pre_caption(caption)
        text_input = self.tokenizer(text, return_tensors="pt")

        img = img.cuda()
        text_input = text_input.to(img.device)

        # Compute Grad-CAM: backprop the image-text matching score, then combine
        # the saved cross-attention maps with their gradients.
        output = self.model(img, text_input)
        loss = output[:, 1].sum()

        self.model.zero_grad()
        loss.backward()

        with torch.no_grad():
            mask = text_input.attention_mask.view(
                text_input.attention_mask.size(0), 1, -1, 1, 1
            )

            grads = self.model.text_encoder.base_model.base_model.encoder.layer[
                self.block_num
            ].crossattention.self.get_attn_gradients()
            cams = self.model.text_encoder.base_model.base_model.encoder.layer[
                self.block_num
            ].crossattention.self.get_attention_map()

            cams = cams[:, :, :, 1:].reshape(img.size(0), 12, -1, 24, 24) * mask
            grads = (
                grads[:, :, :, 1:].clamp(0).reshape(img.size(0), 12, -1, 24, 24) * mask
            )

            gradcam = cams * grads
            gradcam = gradcam[0].mean(0).cpu().detach()

        # One panel for the original image plus one per caption token.
        num_image = len(text_input.input_ids[0])
        fig, ax = plt.subplots(num_image, 1, figsize=(20, 8 * num_image))

        rgb_image = cv2.imread(str(image))[:, :, ::-1]
        rgb_image = np.float32(rgb_image) / 255

        ax[0].imshow(rgb_image)
        ax[0].set_yticks([])
        ax[0].set_xticks([])
        ax[0].set_xlabel("Image")

        for i, token_id in enumerate(text_input.input_ids[0][1:]):
            word = self.tokenizer.decode([token_id])
            gradcam_image = getAttMap(rgb_image, gradcam[i + 1])
            ax[i + 1].imshow(gradcam_image)
            ax[i + 1].set_yticks([])
            ax[i + 1].set_xticks([])
            ax[i + 1].set_xlabel(word)

        out_path = Path(tempfile.mkdtemp()) / "output.png"
        fig.savefig(str(out_path))
        return out_path


class VL_Transformer_ITM(nn.Module):
    def __init__(self, text_encoder=None, config_bert=""):
        super().__init__()

        bert_config = BertConfig.from_json_file(config_bert)

        self.visual_encoder = VisionTransformer(
            img_size=384,
            patch_size=16,
            embed_dim=768,
            depth=12,
            num_heads=12,
            mlp_ratio=4,
            qkv_bias=True,
            norm_layer=partial(nn.LayerNorm, eps=1e-6),
        )

        self.text_encoder = BertModel.from_pretrained(
            text_encoder, config=bert_config, add_pooling_layer=False
        )

        self.itm_head = nn.Linear(768, 2)

    def forward(self, image, text):
        image_embeds = self.visual_encoder(image)

        image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(
            image.device
        )

        output = self.text_encoder(
            text.input_ids,
            attention_mask=text.attention_mask,
            encoder_hidden_states=image_embeds,
            encoder_attention_mask=image_atts,
            return_dict=True,
        )

        vl_embeddings = output.last_hidden_state[:, 0, :]
        vl_output = self.itm_head(vl_embeddings)
        return vl_output


def pre_caption(caption, max_words=30):
    # Lowercase, strip punctuation, and collapse whitespace.
    caption = (
        re.sub(
            r"([,.'!?\"()*#:;~])",
            "",
            caption.lower(),
        )
        .replace("-", " ")
        .replace("/", " ")
    )

    caption = re.sub(
        r"\s{2,}",
        " ",
        caption,
    )
    caption = caption.rstrip("\n")
    caption = caption.strip(" ")

    # truncate caption
    caption_words = caption.split(" ")
    if len(caption_words) > max_words:
        caption = " ".join(caption_words[:max_words])
    return caption


def getAttMap(img, attMap, blur=True, overlap=True):
    # Normalize the attention map, resize it to the image, and overlay it with a jet colormap.
    attMap -= attMap.min()
    if attMap.max() > 0:
        attMap /= attMap.max()
    attMap = skimage_transform.resize(attMap, (img.shape[:2]), order=3, mode="constant")
    if blur:
        attMap = filters.gaussian_filter(attMap, 0.02 * max(img.shape[:2]))
        attMap -= attMap.min()
        attMap /= attMap.max()
    cmap = plt.get_cmap("jet")
    attMapV = cmap(attMap)
    attMapV = np.delete(attMapV, 3, 2)
    if overlap:
        attMap = (
            1 * (1 - attMap ** 0.7).reshape(attMap.shape + (1,)) * img
            + (attMap ** 0.7).reshape(attMap.shape + (1,)) * attMapV
        )
    return attMap
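For a quick sanity check outside of Cog, the new Predictor can also be exercised directly from Python. The following is only a sketch under assumptions not stated in the commit: a CUDA GPU is available, the refcoco.pth checkpoint and local bert/bert-base-uncased files are in place, and examples/image0.jpg stands in for any input image.

# Hypothetical local smoke test for the new Predictor (not part of this commit).
from pathlib import Path

from predict import Predictor

predictor = Predictor()
predictor.setup()  # loads the checkpoint, tokenizer, and image transform

output_png = predictor.predict(
    image=Path("examples/image0.jpg"),     # placeholder input image
    caption="the woman in the red dress",  # placeholder caption
)
print(f"Grad-CAM visualization saved to {output_png}")

Each word of the caption gets its own Grad-CAM panel in the saved figure, mirroring what the hosted demo returns.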
