
Commit 5d59d81

Merge branch 'self_critical_bottom_up' into self-critical
* self_critical_bottom_up: (42 commits)
  - Add advanced. (Still nothing in it.)
  - Update readme.
  - Sort the features in the forwarding instead of dataloader.
  - Add compatibility to resnet features.
  - Add comments in Attmodel.
  - Make image_root an optional option when prepro_label.
  - Add options and verbose for make_bu_data.
  - Add cider submodule
  - Simplify resnet code.
  - Update more to 0.4 version.
  - Update to pytorch 0.4
  - Fix some in evals.
  - Simplify AttModel.
  - Update FC Model to the compatible version (previously FC Model is depreacated and not adapted to new structure.)
  - Move set_lr to the right place in train.py
  - Add max ppl option (beam search sorted by perplexity instead of logprob) (it doens't seem changing too much)
  - Fix a bug in ensemble sample.
  - Add logit layers option. (haven't reigourously tested if it works or not)
  - Allow new ways of computing (using pack sequence) capable of using dataparallel.
  - Add batch normalization layer in att_embed.
  ...

# Conflicts:
#	misc/rewards.py
#	train.py
2 parents: 3601e9c + 403141d

25 files changed (+1303 lines, -460 lines)

.gitmodules

Lines changed: 3 additions & 0 deletions
@@ -0,0 +1,3 @@
+[submodule "cider"]
+	path = cider
+	url = https://github.com/ruotianluo/cider.git

ADVANCED.md

Lines changed: 7 additions & 0 deletions
@@ -0,0 +1,7 @@
+# Advanced
+
+## Ensemble
+
+## Batch normalization
+
+## Box feature

README.md

Lines changed: 65 additions & 16 deletions
@@ -1,46 +1,78 @@
-# Self-critical Sequence Training for Image Captioning
+# Self-critical Sequence Training for Image Captioning (+ misc.)
 
-This is an unofficial implementation for [Self-critical Sequence Training for Image Captioning](https://arxiv.org/abs/1612.00563). The result of FC model can be replicated. (Not able to replicate Att2in result.)
+This repository includes the unofficial implementation [Self-critical Sequence Training for Image Captioning](https://arxiv.org/abs/1612.00563) and [Bottom-Up and Top-Down Attention for Image Captioning and Visual Question Answering](https://arxiv.org/abs/1707.07998).
 
-The author helped me a lot when I tried to replicate the result. Great thanks. The latest topdown and att2in2 model can achieve 1.12 Cider score on Karpathy's test split after self-critical training.
+The author of SCST helped me a lot when I tried to replicate the result. Great thanks. The att2in2 model can achieve more than 1.20 Cider score on Karpathy's test split (with self-critical training, bottom-up feature, large rnn hidden size, without ensemble)
 
-This is based on my [neuraltalk2.pytorch](https://github.com/ruotianluo/neuraltalk2.pytorch) repository. The modifications is:
-- Add self critical training.
+This is based on my [ImageCaptioning.pytorch](https://github.com/ruotianluo/ImageCaptioning.pytorch) repository. The modifications is:
+- Self critical training.
+- Bottom up feature support from [ref](https://arxiv.org/abs/1707.07998). (Evaluation on arbitrary images is not supported.)
+- Ensemble
+- Multi-GPU training
 
 ## Requirements
 Python 2.7 (because there is no [coco-caption](https://github.com/tylin/coco-caption) version for python 3)
-PyTorch 0.2 (along with torchvision)
+PyTorch 0.4 (along with torchvision)
+cider (already been added as a submodule)
 
-You need to download pretrained resnet model for both training and evaluation. The models can be downloaded from [here](https://drive.google.com/open?id=0B7fNdx_jAqhtbVYzOURMdDNHSGM), and should be placed in `data/imagenet_weights`.
+(**Skip if you are using bottom-up feature**): If you want to use resnet to extract image features, you need to download pretrained resnet model for both training and evaluation. The models can be downloaded from [here](https://drive.google.com/open?id=0B7fNdx_jAqhtbVYzOURMdDNHSGM), and should be placed in `data/imagenet_weights`.
 
-## Pretrained models
+## Pretrained models (using resnet101 feature)
 Pretrained models are provided [here](https://drive.google.com/open?id=0B7fNdx_jAqhtdE1JRXpmeGJudTg). And the performances of each model will be maintained in this [issue](https://github.com/ruotianluo/neuraltalk2.pytorch/issues/10).
 
-If you want to do evaluation only, then you can follow [this section](#generate-image-captions) after downloading the pretrained models.
+If you want to do evaluation only, you can then follow [this section](#generate-image-captions) after downloading the pretrained models (and also the pretrained resnet101).
 
 ## Train your own network on COCO
 
-### Download COCO dataset and preprocessing
-
-First, download the coco images from [link](http://mscoco.org/dataset/#download). We need 2014 training images and 2014 val. images. You should put the `train2014/` and `val2014/` in the same directory, denoted as `$IMAGE_ROOT`.
+### Download COCO captions and preprocess them
 
 Download preprocessed coco captions from [link](http://cs.stanford.edu/people/karpathy/deepimagesent/caption_datasets.zip) from Karpathy's homepage. Extract `dataset_coco.json` from the zip file and copy it in to `data/`. This file provides preprocessed captions and also standard train-val-test splits.
 
-Once we have these, we can now invoke the `prepro_*.py` script, which will read all of this in and create a dataset (two feature folders, a hdf5 label file and a json file).
+Then do:
 
 ```bash
 $ python scripts/prepro_labels.py --input_json data/dataset_coco.json --output_json data/cocotalk.json --output_h5 data/cocotalk
-$ python scripts/prepro_feats.py --input_json data/dataset_coco.json --output_dir data/cocotalk --images_root $IMAGE_ROOT
 ```
 
 `prepro_labels.py` will map all words that occur <= 5 times to a special `UNK` token, and create a vocabulary for all the remaining words. The image information and vocabulary are dumped into `data/cocotalk.json` and discretized caption data are dumped into `data/cocotalk_label.h5`.
 
+### Download COCO dataset and pre-extract the image features (Skip if you are using bottom-up feature)
+
+Download the coco images from [link](http://mscoco.org/dataset/#download). We need 2014 training images and 2014 val. images. You should put the `train2014/` and `val2014/` in the same directory, denoted as `$IMAGE_ROOT`.
+
+Then:
+
+```
+$ python scripts/prepro_feats.py --input_json data/dataset_coco.json --output_dir data/cocotalk --images_root $IMAGE_ROOT
+```
+
+
 `prepro_feats.py` extract the resnet101 features (both fc feature and last conv feature) of each image. The features are saved in `data/cocotalk_fc` and `data/cocotalk_att`, and resulting files are about 200GB.
 
 (Check the prepro scripts for more options, like other resnet models or other attention sizes.)
 
 **Warning**: the prepro script will fail with the default MSCOCO data because one of their images is corrupted. See [this issue](https://github.com/karpathy/neuraltalk2/issues/4) for the fix, it involves manually replacing one image in the dataset.
 
+### Download Bottom-up features (Skip if you are using resnet features)
+
+Download pre-extracted feature from [link](https://github.com/peteanderson80/bottom-up-attention). You can either download adaptive one or fixed one.
+
+For example:
+```
+mkdir data/bu_data; cd data/bu_data
+wget https://storage.googleapis.com/bottom-up-attention/trainval.zip
+unzip trainval.zip
+
+```
+
+Then:
+
+```bash
+python script/make_bu_data.py --output_dir data/cocobu
+```
+
+This will create `data/cocobu_fc`, `data/cocobu_att` and `data/cocobu_box`. If you want to use bottom-up feature, you can just follow the following steps and replace all cocotalk with cocobu.
+
 ### Start training
 
 ```bash
@@ -68,8 +100,6 @@ First you should preprocess the dataset and get the cache for calculating cider
 $ python scripts/prepro_ngrams.py --input_json .../dataset_coco.json --dict_json data/cocotalk.json --output_pkl data/coco-train --split train
 ```
 
-And also you need to clone my forked [cider](https://github.com/ruotianluo/cider) repository.
-
 Then, copy the model from the pretrained model using cross entropy. (It's not mandatory to copy the model, just for back-up)
 ```
 $ bash scripts/copy_model.sh fc fc_rl
@@ -122,6 +152,25 @@ The defualt split to evaluate is test. The default inference method is greedy de
 
 **Live demo**. Not supported now. Welcome pull request.
 
+## For more advanced features:
+
+Checkout `ADVANCED.md`.
+
+## Reference
+
+If you find this repo useful, please consider citing (no obligation at all):
+
+```
+@article{luo2018discriminability,
+  title={Discriminability objective for training descriptive captions},
+  author={Luo, Ruotian and Price, Brian and Cohen, Scott and Shakhnarovich, Gregory},
+  journal={arXiv preprint arXiv:1803.04376},
+  year={2018}
+}
+```
+
+Of course, please cite the original paper of models you are using (You can find references in the model files).
+
 ## Acknowledgements
 
 Thanks the original [neuraltalk2](https://github.com/karpathy/neuraltalk2) and awesome PyTorch team.
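
To make the `prepro_labels.py` step above concrete, here is a toy sketch of the thresholding the README describes (words occurring <= 5 times become `UNK`). This is illustrative only, not the script's actual code; the function name, threshold argument, and index conventions are invented for the example.

```python
from collections import Counter

def build_vocab(captions, min_count=5):
    # Keep words that appear more than min_count times; everything else
    # will be encoded as the special UNK token, as prepro_labels.py describes.
    counts = Counter(w for cap in captions for w in cap.split())
    vocab = sorted(w for w, c in counts.items() if c > min_count)
    word_to_ix = {w: i + 1 for i, w in enumerate(vocab)}  # 0 is reserved for padding
    word_to_ix['UNK'] = len(word_to_ix) + 1
    encode = lambda cap: [word_to_ix.get(w, word_to_ix['UNK']) for w in cap.split()]
    return word_to_ix, encode

word_to_ix, encode = build_vocab(['a man riding a horse', 'a man on a beach'], min_count=1)
print(encode('a man riding a zebra'))  # rare/unseen words map to the UNK index
```

The real script additionally writes the image information to `data/cocotalk.json` and the discretized captions to `data/cocotalk_label.h5`, as stated above.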

dataloader.py

Lines changed: 93 additions & 47 deletions
@@ -13,12 +13,6 @@
 
 import multiprocessing
 
-def get_npy_data(ix, fc_file, att_file, use_att):
-    if use_att == True:
-        return (np.load(fc_file), np.load(att_file)['feat'], ix)
-    else:
-        return (np.load(fc_file), np.zeros((1,1,1)), ix)
-
 class DataLoader(data.Dataset):
 
     def reset_iterator(self, split):
@@ -39,7 +33,12 @@ def __init__(self, opt):
         self.opt = opt
         self.batch_size = self.opt.batch_size
         self.seq_per_img = opt.seq_per_img
+
+        # feature related options
         self.use_att = getattr(opt, 'use_att', True)
+        self.use_box = getattr(opt, 'use_box', 0)
+        self.norm_att_feat = getattr(opt, 'norm_att_feat', 0)
+        self.norm_box_feat = getattr(opt, 'norm_box_feat', 0)
 
         # load the json file which contains additional information about the dataset
         print('DataLoader loading json file: ', opt.input_json)
@@ -49,11 +48,12 @@ def __init__(self, opt):
         print('vocab size is ', self.vocab_size)
 
         # open the hdf5 file
-        print('DataLoader loading h5 file: ', opt.input_fc_dir, opt.input_att_dir, opt.input_label_h5)
+        print('DataLoader loading h5 file: ', opt.input_fc_dir, opt.input_att_dir, opt.input_box_dir, opt.input_label_h5)
         self.h5_label_file = h5py.File(self.opt.input_label_h5, 'r', driver='core')
 
         self.input_fc_dir = self.opt.input_fc_dir
         self.input_att_dir = self.opt.input_att_dir
+        self.input_box_dir = self.opt.input_box_dir
 
         # load in the sequence data
         seq_size = self.h5_label_file['labels'].shape
@@ -96,6 +96,25 @@ def cleanup():
         import atexit
         atexit.register(cleanup)
 
+    def get_captions(self, ix, seq_per_img):
+        # fetch the sequence labels
+        ix1 = self.label_start_ix[ix] - 1 #label_start_ix starts from 1
+        ix2 = self.label_end_ix[ix] - 1
+        ncap = ix2 - ix1 + 1 # number of captions available for this image
+        assert ncap > 0, 'an image does not have any label. this can be handled but right now isn\'t'
+
+        if ncap < seq_per_img:
+            # we need to subsample (with replacement)
+            seq = np.zeros([seq_per_img, self.seq_length], dtype = 'int')
+            for q in range(seq_per_img):
+                ixl = random.randint(ix1,ix2)
+                seq[q, :] = self.h5_label_file['labels'][ixl, :self.seq_length]
+        else:
+            ixl = random.randint(ix1, ix2 - seq_per_img + 1)
+            seq = self.h5_label_file['labels'][ixl: ixl + seq_per_img, :self.seq_length]
+
+        return seq
+
     def get_batch(self, split, batch_size=None, seq_per_img=None):
         batch_size = batch_size or self.batch_size
         seq_per_img = seq_per_img or self.seq_per_img
@@ -111,31 +130,13 @@ def get_batch(self, split, batch_size=None, seq_per_img=None):
         gts = []
 
         for i in range(batch_size):
-            import time
-            t_start = time.time()
             # fetch image
             tmp_fc, tmp_att,\
                 ix, tmp_wrapped = self._prefetch_process[split].get()
-            fc_batch += [tmp_fc] * seq_per_img
-            att_batch += [tmp_att] * seq_per_img
-
-            # fetch the sequence labels
-            ix1 = self.label_start_ix[ix] - 1 #label_start_ix starts from 1
-            ix2 = self.label_end_ix[ix] - 1
-            ncap = ix2 - ix1 + 1 # number of captions available for this image
-            assert ncap > 0, 'an image does not have any label. this can be handled but right now isn\'t'
-
-            if ncap < seq_per_img:
-                # we need to subsample (with replacement)
-                seq = np.zeros([seq_per_img, self.seq_length], dtype = 'int')
-                for q in range(seq_per_img):
-                    ixl = random.randint(ix1,ix2)
-                    seq[q, :] = self.h5_label_file['labels'][ixl, :self.seq_length]
-            else:
-                ixl = random.randint(ix1, ix2 - seq_per_img + 1)
-                seq = self.h5_label_file['labels'][ixl: ixl + seq_per_img, :self.seq_length]
+            fc_batch.append(tmp_fc)
+            att_batch.append(tmp_att)
 
-            label_batch[i * seq_per_img : (i + 1) * seq_per_img, 1 : self.seq_length + 1] = seq
+            label_batch[i * seq_per_img : (i + 1) * seq_per_img, 1 : self.seq_length + 1] = self.get_captions(ix, seq_per_img)
 
             if tmp_wrapped:
                 wrapped = True
@@ -149,21 +150,34 @@ def get_batch(self, split, batch_size=None, seq_per_img=None):
             info_dict['id'] = self.info['images'][ix]['id']
             info_dict['file_path'] = self.info['images'][ix]['file_path']
             infos.append(info_dict)
-            #print(i, time.time() - t_start)
 
+        # #sort by att_feat length
+        # fc_batch, att_batch, label_batch, gts, infos = \
+        #    zip(*sorted(zip(fc_batch, att_batch, np.vsplit(label_batch, batch_size), gts, infos), key=lambda x: len(x[1]), reverse=True))
+        fc_batch, att_batch, label_batch, gts, infos = \
+            zip(*sorted(zip(fc_batch, att_batch, np.vsplit(label_batch, batch_size), gts, infos), key=lambda x: 0, reverse=True))
+        data = {}
+        data['fc_feats'] = np.stack(reduce(lambda x,y:x+y, [[_]*seq_per_img for _ in fc_batch]))
+        # merge att_feats
+        max_att_len = max([_.shape[0] for _ in att_batch])
+        data['att_feats'] = np.zeros([len(att_batch)*seq_per_img, max_att_len, att_batch[0].shape[1]], dtype = 'float32')
+        for i in range(len(att_batch)):
+            data['att_feats'][i*seq_per_img:(i+1)*seq_per_img, :att_batch[i].shape[0]] = att_batch[i]
+        data['att_masks'] = np.zeros(data['att_feats'].shape[:2], dtype='float32')
+        for i in range(len(att_batch)):
+            data['att_masks'][i*seq_per_img:(i+1)*seq_per_img, :att_batch[i].shape[0]] = 1
+        # set att_masks to None if attention features have same length
+        if data['att_masks'].sum() == data['att_masks'].size:
+            data['att_masks'] = None
+
+        data['labels'] = np.vstack(label_batch)
         # generate mask
-        t_start = time.time()
-        nonzeros = np.array(list(map(lambda x: (x != 0).sum()+2, label_batch)))
+        nonzeros = np.array(list(map(lambda x: (x != 0).sum()+2, data['labels'])))
        for ix, row in enumerate(mask_batch):
            row[:nonzeros[ix]] = 1
-        #print('mask', time.time() - t_start)
+        data['masks'] = mask_batch
 
-        data = {}
-        data['fc_feats'] = np.stack(fc_batch)
-        data['att_feats'] = np.stack(att_batch)
-        data['labels'] = label_batch
-        data['gts'] = gts
-        data['masks'] = mask_batch
+        data['gts'] = gts # all ground truth captions of each images
         data['bounds'] = {'it_pos_now': self.iterators[split], 'it_max': len(self.split_ix[split]), 'wrapped': wrapped}
         data['infos'] = infos
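The attention-feature merging added in the hunk above pads every image's `K x C` feature matrix to the largest `K` in the batch and records which rows are real with a 0/1 mask. A minimal standalone sketch of that idea (NumPy only, without the `seq_per_img` replication, and not the repository's code):

```python
import numpy as np

def pad_att_feats(att_batch):
    # Pad a list of (K_i, C) arrays to the largest K_i and build a 0/1 mask
    # marking the real (non-padded) rows, mirroring the new get_batch logic.
    max_len = max(f.shape[0] for f in att_batch)
    feat_dim = att_batch[0].shape[1]
    att_feats = np.zeros((len(att_batch), max_len, feat_dim), dtype='float32')
    att_masks = np.zeros((len(att_batch), max_len), dtype='float32')
    for i, f in enumerate(att_batch):
        att_feats[i, :f.shape[0]] = f
        att_masks[i, :f.shape[0]] = 1
    # When every image has the same number of regions the mask is all ones and
    # carries no information, so (like get_batch) we drop it.
    if att_masks.sum() == att_masks.size:
        att_masks = None
    return att_feats, att_masks

feats, masks = pad_att_feats([np.ones((10, 2048), 'float32'),
                              np.ones((36, 2048), 'float32')])
print(feats.shape, masks.shape)  # (2, 36, 2048) (2, 36)
```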

@@ -176,15 +190,47 @@ def __getitem__(self, index):
         """This function returns a tuple that is further passed to collate_fn
         """
         ix = index #self.split_ix[index]
-        return get_npy_data(ix, \
-                os.path.join(self.input_fc_dir, str(self.info['images'][ix]['id']) + '.npy'),
-                os.path.join(self.input_att_dir, str(self.info['images'][ix]['id']) + '.npz'),
-                self.use_att
-                )
+        if self.use_att:
+            att_feat = np.load(os.path.join(self.input_att_dir, str(self.info['images'][ix]['id']) + '.npz'))['feat']
+            # Reshape to K x C
+            att_feat = att_feat.reshape(-1, att_feat.shape[-1])
+            if self.norm_att_feat:
+                att_feat = att_feat / np.linalg.norm(att_feat, 2, 1, keepdims=True)
+            if self.use_box:
+                box_feat = np.load(os.path.join(self.input_box_dir, str(self.info['images'][ix]['id']) + '.npy'))
+                # devided by image width and height
+                x1,y1,x2,y2 = np.hsplit(box_feat, 4)
+                h,w = self.info['images'][ix]['height'], self.info['images'][ix]['width']
+                box_feat = np.hstack((x1/w, y1/h, x2/w, y2/h, (x2-x1)*(y2-y1)/(w*h))) # question? x2-x1+1??
+                if self.norm_box_feat:
+                    box_feat = box_feat / np.linalg.norm(box_feat, 2, 1, keepdims=True)
+                att_feat = np.hstack([att_feat, box_feat])
+                # sort the features by the size of boxes
+                att_feat = np.stack(sorted(att_feat, key=lambda x:x[-1], reverse=True))
+        else:
+            att_feat = np.zeros((1,1,1))
+        return (np.load(os.path.join(self.input_fc_dir, str(self.info['images'][ix]['id']) + '.npy')),
+                att_feat,
+                ix)
 
     def __len__(self):
         return len(self.info['images'])
 
+class SubsetSampler(torch.utils.data.sampler.Sampler):
+    r"""Samples elements randomly from a given list of indices, without replacement.
+    Arguments:
+        indices (list): a list of indices
+    """
+
+    def __init__(self, indices):
+        self.indices = indices
+
+    def __iter__(self):
+        return (self.indices[i] for i in range(len(self.indices)))
+
+    def __len__(self):
+        return len(self.indices)
+
 class BlobFetcher():
     """Experimental class for prefetching blobs in a separate process."""
     def __init__(self, split, dataloader, if_shuffle=False):
@@ -198,17 +244,17 @@ def __init__(self, split, dataloader, if_shuffle=False):
     # Add more in the queue
     def reset(self):
         """
-        Two cases:
+        Two cases for this function to be triggered:
         1. not hasattr(self, 'split_loader'): Resume from previous training. Create the dataset given the saved split_ix and iterator
         2. wrapped: a new epoch, the split_ix and iterator have been updated in the get_minibatch_inds already.
         """
-        # batch_size is 0, the merge is done in DataLoader class
+        # batch_size is 1, the merge is done in DataLoader class
         self.split_loader = iter(data.DataLoader(dataset=self.dataloader,
                                             batch_size=1,
-                                            sampler=self.dataloader.split_ix[self.split][self.dataloader.iterators[self.split]:],
+                                            sampler=SubsetSampler(self.dataloader.split_ix[self.split][self.dataloader.iterators[self.split]:]),
                                             shuffle=False,
                                             pin_memory=True,
-                                            num_workers=multiprocessing.cpu_count(),
+                                            num_workers=4, # 4 is usually enough
                                             collate_fn=lambda x: x[0]))
 
     def _get_next_minibatch_inds(self):
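A quick worked example of the box encoding introduced in `__getitem__` above: each `(x1, y1, x2, y2)` box is scaled by the image width/height and extended with its relative area before being appended to the attention features. The numbers below are made up for illustration; this is not a call into the dataloader.

```python
import numpy as np

w, h = 640., 480.                             # image width and height
boxes = np.array([[ 10.,  20., 110., 220.],   # a small box
                  [  0.,   0., 640., 480.]])  # a box covering the whole image

x1, y1, x2, y2 = np.hsplit(boxes, 4)
# Same five-number encoding as the diff: scaled corners plus relative area.
box_feat = np.hstack((x1 / w, y1 / h, x2 / w, y2 / h,
                      (x2 - x1) * (y2 - y1) / (w * h)))
print(np.round(box_feat, 3))
# first row  -> [0.016 0.042 0.172 0.458 0.065]
# second row -> [0.    0.    1.    1.    1.   ]  (covers the whole image)
```

The inline `x2-x1+1??` comment in the diff is asking whether width and height should be computed pixel-inclusively; the sketch follows the code as written (no `+1`).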

dataloaderraw.py

Lines changed: 3 additions & 3 deletions
@@ -8,7 +8,6 @@
 import numpy as np
 import random
 import torch
-from torch.autograd import Variable
 import skimage
 import skimage.io
 import scipy.misc
@@ -109,8 +108,9 @@ def get_batch(self, split, batch_size=None):
 
             img = img.astype('float32')/255.0
             img = torch.from_numpy(img.transpose([2,0,1])).cuda()
-            img = Variable(preprocess(img), volatile=True)
-            tmp_fc, tmp_att = self.my_resnet(img)
+            img = preprocess(img)
+            with torch.no_grad():
+                tmp_fc, tmp_att = self.my_resnet(img)
 
             fc_batch[i] = tmp_fc.data.cpu().float().numpy()
             att_batch[i] = tmp_att.data.cpu().float().numpy()
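
The change above is the standard PyTorch 0.4 migration: `Variable(..., volatile=True)` is gone, and inference is wrapped in `torch.no_grad()` instead. A minimal sketch of the pattern with a stand-in module (not the repository's resnet wrapper):

```python
import torch
import torch.nn as nn

model = nn.Linear(2048, 10).eval()  # stand-in for self.my_resnet in the diff
x = torch.randn(1, 2048)

# Pre-0.4 style (removed above):
#     img = Variable(preprocess(img), volatile=True)
#     tmp_fc, tmp_att = self.my_resnet(img)
# 0.4+ style: run the forward pass with autograd disabled.
with torch.no_grad():
    out = model(x)

print(out.requires_grad)  # False: no graph is built, so inference uses less memory
```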
