# generate_caption.py
# NOTE: PrefixLM is assumed to be defined in model alongside ChineseTokenizer;
# it is instantiated below when the checkpoint is loaded.
from model import PrefixLM, ChineseTokenizer
import torch
import argparse  # only used by the commented-out CLI block below
import os
import PIL
from PIL import Image
from PIL import ImageFile
# Be tolerant of truncated files and very large images.
ImageFile.LOAD_TRUNCATED_IMAGES = True
Image.MAX_IMAGE_PIXELS = None
from torchvision import transforms as T
from einops import rearrange  # only used by the commented-out greedy decoder below
'''parser = argparse.ArgumentParser()
parser.add_argument('--prefixLM_path', type=str, help='path to your trained PrefixLM')
parser.add_argument('--img_path', type=str, help='path to your images')
parser.add_argument('--prefix', type=str, default='', help='prefix for caption')
args = parser.parse_args()'''
# Inference runs on CPU here; swap in the commented expression to use a GPU.
device = torch.device("cpu")  # torch.device("cuda" if torch.cuda.is_available() else "cpu")
prefixLM_path = './SimVLP2.pt'
# Greedy decoding, kept for reference; interface() below uses
# model.generate with nucleus sampling instead.
# def generate_caption(model, image_tensor, tokenized_text):
#     # Encode the image and add positional embeddings.
#     img_embed = model.ResNet(image_tensor)
#     img_embed = rearrange(img_embed, 'b c h w -> b (h w) c')
#     img_embed += model.img_pos_embed(img_embed)
#
#     # Embed the (possibly empty) text prefix.
#     pre_txt_embed = model.txt_embed(tokenized_text)
#     pre_txt_embed += model.txt_pos_embed(torch.arange(model.prefix_txt_len, device=device))
#     # Start the target sequence with token id 4 (presumably the BOS/[CLS] id).
#     tgt_txt = torch.zeros(1, 1, dtype=torch.long, device=device) + 4
#     tgt_txt_embed = model.txt_embed(tgt_txt)
#     tgt_txt_embed += model.txt_pos_embed(torch.arange(1, device=device) + model.prefix_txt_len)
#
#     # Image tokens followed by the text prefix form the transformer prefix.
#     prefix = torch.cat((img_embed, pre_txt_embed), dim=1)
#     out = model.transformer(prefix, tgt_txt_embed)
#     logits = model.to_logits(out)[:, -1]
#     sample = torch.argmax(logits, dim=-1)
#     cur_len = 1
#     while cur_len < model.target_txt_len and sample != 5:  # 5 is the id of [SEP]
#         tgt_txt = torch.cat((tgt_txt, sample.unsqueeze(1)), dim=-1)
#         tgt_txt_embed = model.txt_embed(tgt_txt)
#         cur_len += 1
#         tgt_txt_embed += model.txt_pos_embed(torch.arange(cur_len, device=device) + model.prefix_txt_len)
#         out = model.transformer(prefix, tgt_txt_embed)
#         logits = model.to_logits(out)[:, -1]
#         sample = torch.argmax(logits, dim=-1)
#     return tgt_txt
assert os.path.exists(prefixLM_path), 'trained model path must exist'
# The checkpoint stores the model hyperparameters alongside the weights,
# so the PrefixLM can be rebuilt without a separate config file.
loaded_obj = torch.load(prefixLM_path, map_location='cpu')
PrefixLM_configure, weights = loaded_obj['hparams'], loaded_obj['weights']
model = PrefixLM(**PrefixLM_configure)
model.load_state_dict(weights)
model.to(device)
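# For reference, a checkpoint in this format would have been written with
# something like the line below (an assumption inferred from the keys read
# above, not taken from the training code):
# torch.save({'hparams': PrefixLM_configure, 'weights': model.state_dict()}, prefixLM_path)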
image_transform = T.Compose([
    T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
    T.RandomResizedCrop(model.input_resolution,
                        scale=(0.75, 1.),
                        ratio=(1., 1.)),
    T.ToTensor()
])
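# The transform forces RGB, takes a random square crop at the model's input
# resolution (aspect ratio fixed at 1:1), and scales pixel values to [0, 1],
# so a single image becomes a (3, input_resolution, input_resolution) tensor;
# unsqueeze(0) in interface() below adds the batch dimension.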
tokenizer = ChineseTokenizer()
'''tokenized_text = tokenizer.tokenize(
    args.prefix,
    model.prefix_txt_len,
    truncate_text=True
).to(device)
img = PIL.Image.open(args.img_path)
image_tensor = image_transform(img).unsqueeze(0).to(device)
model.eval()
cap = generate_caption(model, image_tensor, tokenized_text)
print(args.prefix + tokenizer.decode(cap.squeeze(0)).replace('[UNK]', ''))'''
def interface(image_path="C:/Users/17914/Pictures/Camera Roll/WIN_20210207_23_12_14_Pro.jpg", prefix=''):
    """Generate a caption for the image at image_path, optionally seeded with a text prefix."""
    img = PIL.Image.open(image_path)
    image_tensor = image_transform(img).unsqueeze(0).to(device)
    tokenized_text = tokenizer.tokenize(
        prefix,
        context_length=10,  # any length < model.txt_seq_len
        truncate_text=True
    ).to(device)
    model.eval()
    # Decode with nucleus (top-p) sampling from the preprocessed image tensor
    # (the original passed the raw PIL image here, leaving image_tensor unused).
    cap = model.generate(image_tensor, prefix_txt=tokenized_text,
                         sampling_method='nucleus', eos_id=0,
                         top_k=256, top_p=0.9, temperature=1.)
    return prefix + tokenizer.decode(cap.squeeze(0)).replace('[UNK]', '')
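# For reference, a minimal standalone sketch of the nucleus (top-p) filtering
# that sampling_method='nucleus' above is assumed to perform inside
# model.generate: keep the top_k logits, then the smallest set of tokens whose
# cumulative probability exceeds top_p, and sample from what remains.
# This helper is illustrative only and is not called by this script.
def _nucleus_filter_sketch(logits, top_k=256, top_p=0.9):
    """Filter a 1-D logits tensor for top-k / top-p sampling (sketch)."""
    logits = logits.clone()
    if top_k > 0:
        # Drop everything below the k-th largest logit.
        kth = torch.topk(logits, min(top_k, logits.size(-1)))[0][-1]
        logits[logits < kth] = float('-inf')
    sorted_logits, sorted_idx = torch.sort(logits, descending=True)
    cum_probs = torch.softmax(sorted_logits, dim=-1).cumsum(dim=-1)
    # Mask tokens once cumulative probability passes top_p, always keeping
    # at least the single most likely token.
    mask = cum_probs > top_p
    mask[1:] = mask[:-1].clone()
    mask[0] = False
    logits[sorted_idx[mask]] = float('-inf')
    # Sample from the survivors with, e.g.:
    # torch.multinomial(torch.softmax(logits / temperature, dim=-1), 1)
    return logits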
if __name__ == '__main__':
    print(interface("C:\\Users\\17914\\Pictures\\Camera Roll\\WIN_20210207_23_12_14_Pro.jpg"))