diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 0000000..36a2815 --- /dev/null +++ b/Dockerfile @@ -0,0 +1,38 @@ +FROM nvcr.io/nvidia/pytorch:19.05-py3 + +# basic python packages +RUN pip install pytorch-pretrained-bert==0.6.2 \ + tensorboardX==1.7 ipdb==0.12 lz4==2.1.9 lmdb==0.97 + +####### horovod for multi-GPU (distributed) training ####### + +# update OpenMPI to avoid horovod bug +RUN rm -r /usr/local/mpi &&\ + wget https://download.open-mpi.org/release/open-mpi/v3.1/openmpi-3.1.4.tar.gz &&\ + gunzip -c openmpi-3.1.4.tar.gz | tar xf - &&\ + cd openmpi-3.1.4 &&\ + ./configure --prefix=/usr/local/mpi --enable-orterun-prefix-by-default \ + --with-verbs --disable-getpwuid &&\ + make -j$(nproc) all && make install &&\ + ldconfig &&\ + cd - && rm -r openmpi-3.1.4 && rm openmpi-3.1.4.tar.gz + +ENV OPENMPI_VERSION=3.1.4 + +# horovod +RUN HOROVOD_GPU_ALLREDUCE=NCCL HOROVOD_NCCL_LINK=SHARED HOROVOD_WITH_PYTORCH=1 \ + pip install --no-cache-dir horovod==0.16.4 &&\ + ldconfig + +# ssh +RUN apt-get update &&\ + apt-get install -y --no-install-recommends openssh-client openssh-server &&\ + mkdir -p /var/run/sshd + +# Allow OpenSSH to talk to containers without asking for confirmation +RUN cat /etc/ssh/ssh_config | grep -v StrictHostKeyChecking > /etc/ssh/ssh_config.new && \ + echo " StrictHostKeyChecking no" >> /etc/ssh/ssh_config.new && \ + mv /etc/ssh/ssh_config.new /etc/ssh/ssh_config + + +WORKDIR /src diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..9b7b0f7 --- /dev/null +++ b/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2019 Microsoft Corporation + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/README.md b/README.md new file mode 100644 index 0000000..8e76355 --- /dev/null +++ b/README.md @@ -0,0 +1,146 @@ +# NSGDC + +Some codes in this repo are copied/modified from opensource implementations made available by +[UNITER](https://github.com/ChenRocks/UNITER), +[PyTorch](https://github.com/pytorch/pytorch), +[HuggingFace](https://github.com/huggingface/transformers), +[OpenNMT](https://github.com/OpenNMT/OpenNMT-py), +and [Nvidia](https://github.com/NVIDIA/DeepLearningExamples/tree/master/PyTorch). +The image features are extracted using [BUTD](https://github.com/peteanderson80/bottom-up-attention). + + +## Requirements +This is following UNITER. We provide Docker image for easier reproduction. 
+Please install the following:
+ - [nvidia driver](https://docs.nvidia.com/cuda/cuda-installation-guide-linux/index.html#package-manager-installation) (418+),
+ - [Docker](https://docs.docker.com/install/linux/docker-ce/ubuntu/) (19.03+),
+ - [nvidia-container-toolkit](https://github.com/NVIDIA/nvidia-docker#quickstart).
+
+Our scripts require the user to have [docker group membership](https://docs.docker.com/install/linux/linux-postinstall/)
+so that docker commands can be run without sudo.
+We only support Linux with NVIDIA GPUs. We test on Ubuntu 18.04 and V100 cards.
+We use mixed-precision training, hence GPUs with Tensor Cores are recommended.
+
+## Image-Text Retrieval
+### Download Data
+```
+bash scripts/download_itm.sh $PATH_TO_STORAGE
+```
+
+### Launch the Docker Container
+```bash
+# docker image should be automatically pulled
+source launch_container.sh $PATH_TO_STORAGE/txt_db $PATH_TO_STORAGE/img_db \
+$PATH_TO_STORAGE/finetune $PATH_TO_STORAGE/pretrained
+```
+
+The same container can also be used in case you would like to reproduce the whole preprocessing pipeline.
+
+The launch script respects the $CUDA_VISIBLE_DEVICES environment variable.
+Note that the source code is mounted into the container under `/src` instead
+of built into the image, so that user modifications are reflected without
+re-building the image. (Data folders are mounted into the container separately
+for flexibility on folder structures.)
+
+### Image-Text Retrieval (Flickr30k)
+```
+# Train with the base setting
+bash run_cmds/tran_pnsgd_base_flickr.sh
+bash run_cmds/tran_pnsgd2_base_flickr.sh
+
+# Train with the large setting
+bash run_cmds/tran_pnsgd_large_flickr.sh
+bash run_cmds/tran_pnsgd2_large_flickr.sh
+```
+
+### Image-Text Retrieval (COCO)
+```
+# Train with the base setting
+bash run_cmds/tran_pnsgd_base_coco.sh
+bash run_cmds/tran_pnsgd2_base_coco.sh
+
+# Train with the large setting
+bash run_cmds/tran_pnsgd_large_coco.sh
+bash run_cmds/tran_pnsgd2_large_coco.sh
+```
+
+### Run Inference
+```
+bash run_cmds/inf_nsgd.sh
+```
+
+## Results
+
+Our models achieve the following performance.
+
+### MS-COCO
+
+| Model | Image-to-Text R@1 | R@5 | R@10 | Text-to-Image R@1 | R@5 | R@10 |
+| --- | --- | --- | --- | --- | --- | --- |
+| NSGDC-Base | 66.6 | 88.6 | 94.0 | 51.6 | 79.1 | 87.5 |
+| NSGDC-Large | 67.8 | 89.6 | 94.2 | 53.3 | 80.0 | 88.0 |
+
+### Flickr30K
+
+| Model | Image-to-Text R@1 | R@5 | R@10 | Text-to-Image R@1 | R@5 | R@10 |
+| --- | --- | --- | --- | --- | --- | --- |
+| NSGDC-Base | 87.9 | 98.1 | 99.3 | 74.5 | 93.3 | 96.3 |
+| NSGDC-Large | 90.6 | 98.8 | 99.1 | 77.3 | 94.3 | 97.3 |
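The retrieval finetuning hyperparameters (negative pool size, number of hard negatives, ranking margin, schedule, data paths) live in the JSON config files added below rather than on the command line. As a minimal sketch, assuming the repo root as working directory and the config files from this change in place, the ranking-specific fields can be inspected like this:

```python
import json

# Load one of the retrieval finetuning configs added in this change
# (path is relative to the repo root).
with open("config/train-itm-pnsgd-base-flickr.json") as f:
    cfg = json.load(f)

# Fields that control the ranking objective and negative sampling.
for key in ("train_batch_size", "negative_size", "hard_neg_size",
            "margin", "num_train_steps", "warmup_steps"):
    print(f"{key}: {cfg[key]}")
```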
diff --git a/config/train-itm-pnsgd-base-coco.json b/config/train-itm-pnsgd-base-coco.json new file mode 100644 index 0000000..9adfc0c --- /dev/null +++ b/config/train-itm-pnsgd-base-coco.json @@ -0,0 +1,43 @@ +{ + "compressed_db": false, + "checkpoint": "log/pretrained/uniter-base.pt", + "max_txt_len": 60, + "conf_th": 0.2, + "max_bb": 100, + "min_bb": 10, + "num_bb": 36, + "train_batch_size": 32, + "negative_size": 399, + "hard_neg_size": 31, + "inf_minibatch_size": 400, + "margin": 0.2, + "valid_steps": 500, + "num_train_steps": 5000, + "optim": "adamw", + "betas": [ + 0.9, + 0.98 + ], + "dropout": 0.1, + "weight_decay": 0.01, + "grad_norm": 2.0, + "warmup_steps": 500, + "seed": 42, + "full_val": true, + "fp16": true, + "n_workers": 4, + "pin_mem": true, + "train_txt_dbs": [ + "itm-data/txt_db3/itm_coco_train.db", + "itm-data/txt_db3/itm_coco_restval.db" + ], + "train_img_dbs": [ + "itm-data/img_db/coco_train2014/", + "itm-data/img_db/coco_val2014/" + ], + "val_txt_db": "itm-data/txt_db3/itm_coco_val.db", + "val_img_db": "itm-data/img_db/coco_val2014", + "test_txt_db": "itm-data/txt_db3/itm_coco_test.db", + "test_img_db": "itm-data/img_db/coco_val2014", + "model_config": "config/uniter-base.json" +} diff --git a/config/train-itm-pnsgd-base-flickr.json b/config/train-itm-pnsgd-base-flickr.json new file mode 100644 index 0000000..7c9b705 --- /dev/null +++ b/config/train-itm-pnsgd-base-flickr.json @@ -0,0 +1,41 @@ +{ + "compressed_db": false, + "checkpoint": "log/pretrained/uniter-base.pt", + "max_txt_len": 60, + "conf_th": 0.2, + "max_bb": 100, + "min_bb": 10, + "num_bb": 36, + "train_batch_size": 32, + "negative_size": 399, + "hard_neg_size": 31, + "inf_minibatch_size": 400, + "margin": 0.2, + "valid_steps": 500, + "num_train_steps": 5000, + "optim": "adamw", + "betas": [ + 0.9, + 0.98 + ], + "dropout": 0.1, + "weight_decay": 0.01, + "grad_norm": 2.0, + "warmup_steps": 500, + "seed": 42, + "full_val": true, + "fp16": true, + "n_workers": 4, + "pin_mem": true, + "train_txt_dbs": [ + "itm-data/txt_db3/itm_flickr30k_train.db" + ], + "train_img_dbs": [ + "itm-data/img_db/flickr30k/" + ], + "val_txt_db": "itm-data/txt_db3/itm_flickr30k_val.db", + "val_img_db": "itm-data/img_db/flickr30k/", + "test_txt_db": "itm-data/txt_db3/itm_flickr30k_test.db", + "test_img_db": "itm-data/img_db/flickr30k/", + "model_config": "config/uniter-base.json" +} diff --git a/config/train-itm-pnsgd-large-coco.json b/config/train-itm-pnsgd-large-coco.json new file mode 100644 index 0000000..cab857f --- /dev/null +++ b/config/train-itm-pnsgd-large-coco.json @@ -0,0 +1,43 @@ +{ + "compressed_db": false, + "checkpoint": "log/pretrained/uniter-large.pt", + "max_txt_len": 60, + "conf_th": 0.2, + "max_bb": 100, + "min_bb": 10, + "num_bb": 36, + "train_batch_size": 32, + "negative_size": 399, + "hard_neg_size": 31, + "inf_minibatch_size": 400, + "margin": 0.2, + "valid_steps": 500, + "num_train_steps": 5000, + "optim": "adamw", + "betas": [ + 0.9, + 0.98 + ], + "dropout": 0.1, + "weight_decay": 0.01, + "grad_norm": 2.0, + "warmup_steps": 500, + "seed": 42, + "full_val": true, + "fp16": true, + "n_workers": 4, + "pin_mem": true, + "train_txt_dbs": [ + "itm-data/txt_db3/itm_coco_train.db", + "itm-data/txt_db3/itm_coco_restval.db" + ], + "train_img_dbs": [ + "itm-data/img_db/coco_train2014/", + "itm-data/img_db/coco_val2014/" + ], + "val_txt_db": "itm-data/txt_db3/itm_coco_val.db", + "val_img_db": "itm-data/img_db/coco_val2014", + "test_txt_db": "itm-data/txt_db3/itm_coco_test.db", + "test_img_db": 
"itm-data/img_db/coco_val2014", + "model_config": "config/uniter-large.json" +} diff --git a/config/train-itm-pnsgd-large-flickr.json b/config/train-itm-pnsgd-large-flickr.json new file mode 100644 index 0000000..5862f15 --- /dev/null +++ b/config/train-itm-pnsgd-large-flickr.json @@ -0,0 +1,41 @@ +{ + "compressed_db": false, + "checkpoint": "log/pretrained/uniter-large.pt", + "max_txt_len": 60, + "conf_th": 0.2, + "max_bb": 100, + "min_bb": 10, + "num_bb": 36, + "train_batch_size": 16, + "negative_size": 399, + "hard_neg_size": 31, + "inf_minibatch_size": 400, + "margin": 0.2, + "valid_steps": 500, + "num_train_steps": 5000, + "optim": "adamw", + "betas": [ + 0.9, + 0.98 + ], + "dropout": 0.1, + "weight_decay": 0.01, + "grad_norm": 2.0, + "warmup_steps": 500, + "seed": 42, + "full_val": true, + "fp16": true, + "n_workers": 4, + "pin_mem": true, + "train_txt_dbs": [ + "itm-data/txt_db3/itm_flickr30k_train.db" + ], + "train_img_dbs": [ + "itm-data/img_db/flickr30k/" + ], + "val_txt_db": "itm-data/txt_db3/itm_flickr30k_val.db", + "val_img_db": "itm-data/img_db/flickr30k/", + "test_txt_db": "itm-data/txt_db3/itm_flickr30k_test.db", + "test_img_db": "itm-data/img_db/flickr30k/", + "model_config": "config/uniter-large.json" +} diff --git a/config/train-itm-pnsgd2-base-coco.json b/config/train-itm-pnsgd2-base-coco.json new file mode 100644 index 0000000..362a870 --- /dev/null +++ b/config/train-itm-pnsgd2-base-coco.json @@ -0,0 +1,41 @@ +{ + "compressed_db": false, + "max_txt_len": 60, + "conf_th": 0.2, + "max_bb": 100, + "min_bb": 10, + "num_bb": 36, + "train_batch_size": 32, + "negative_size": 399, + "hard_neg_size": 31, + "inf_minibatch_size": 400, + "margin": 0.2, + "valid_steps": 150, + "num_train_steps": 1500, + "optim": "adamw", + "betas": [ + 0.9, + 0.98 + ], + "dropout": 0.1, + "weight_decay": 0.01, + "grad_norm": 2.0, + "warmup_steps": 150, + "seed": 42, + "full_val": true, + "n_workers": 4, + "pin_mem": true, + "train_txt_dbs": [ + "itm-data/txt_db3/itm_coco_train.db", + "itm-data/txt_db3/itm_coco_restval.db" + ], + "train_img_dbs": [ + "itm-data/img_db/coco_train2014/", + "itm-data/img_db/coco_val2014/" + ], + "val_txt_db": "itm-data/txt_db3/itm_coco_val.db", + "val_img_db": "itm-data/img_db/coco_val2014", + "test_txt_db": "itm-data/txt_db3/itm_coco_test.db", + "test_img_db": "itm-data/img_db/coco_val2014", + "model_config": "config/uniter-base.json" +} diff --git a/config/train-itm-pnsgd2-base-flickr.json b/config/train-itm-pnsgd2-base-flickr.json new file mode 100644 index 0000000..9b791c6 --- /dev/null +++ b/config/train-itm-pnsgd2-base-flickr.json @@ -0,0 +1,39 @@ +{ + "compressed_db": false, + "max_txt_len": 60, + "conf_th": 0.2, + "max_bb": 100, + "min_bb": 10, + "num_bb": 36, + "train_batch_size": 32, + "negative_size": 399, + "hard_neg_size": 31, + "inf_minibatch_size": 400, + "margin": 0.2, + "valid_steps": 500, + "num_train_steps": 5000, + "optim": "adamw", + "betas": [ + 0.9, + 0.98 + ], + "dropout": 0.1, + "weight_decay": 0.01, + "grad_norm": 2.0, + "warmup_steps": 500, + "seed": 42, + "full_val": true, + "n_workers": 4, + "pin_mem": true, + "train_txt_dbs": [ + "itm-data/txt_db3/itm_flickr30k_train.db" + ], + "train_img_dbs": [ + "itm-data/img_db/flickr30k/" + ], + "val_txt_db": "itm-data/txt_db3/itm_flickr30k_val.db", + "val_img_db": "itm-data/img_db/flickr30k/", + "test_txt_db": "itm-data/txt_db3/itm_flickr30k_test.db", + "test_img_db": "itm-data/img_db/flickr30k/", + "model_config": "config/uniter-base.json" +} diff --git a/config/train-itm-pnsgd2-large-coco.json 
b/config/train-itm-pnsgd2-large-coco.json new file mode 100644 index 0000000..d24848e --- /dev/null +++ b/config/train-itm-pnsgd2-large-coco.json @@ -0,0 +1,42 @@ +{ + "compressed_db": false, + "max_txt_len": 60, + "conf_th": 0.2, + "max_bb": 100, + "min_bb": 10, + "num_bb": 36, + "train_batch_size": 32, + "negative_size": 399, + "hard_neg_size": 31, + "inf_minibatch_size": 400, + "margin": 0.2, + "valid_steps": 500, + "num_train_steps": 5000, + "optim": "adamw", + "betas": [ + 0.9, + 0.98 + ], + "dropout": 0.1, + "weight_decay": 0.01, + "grad_norm": 2.0, + "warmup_steps": 500, + "seed": 42, + "full_val": true, + "fp16": true, + "n_workers": 4, + "pin_mem": true, + "train_txt_dbs": [ + "itm-data/txt_db3/itm_coco_train.db", + "itm-data/txt_db3/itm_coco_restval.db" + ], + "train_img_dbs": [ + "itm-data/img_db/coco_train2014/", + "itm-data/img_db/coco_val2014/" + ], + "val_txt_db": "itm-data/txt_db3/itm_coco_val.db", + "val_img_db": "itm-data/img_db/coco_val2014", + "test_txt_db": "itm-data/txt_db3/itm_coco_test.db", + "test_img_db": "itm-data/img_db/coco_val2014", + "model_config": "config/uniter-large.json" +} diff --git a/config/train-itm-pnsgd2-large-flickr.json b/config/train-itm-pnsgd2-large-flickr.json new file mode 100644 index 0000000..94b9dfe --- /dev/null +++ b/config/train-itm-pnsgd2-large-flickr.json @@ -0,0 +1,39 @@ +{ + "compressed_db": false, + "max_txt_len": 60, + "conf_th": 0.2, + "max_bb": 100, + "min_bb": 10, + "num_bb": 36, + "train_batch_size": 32, + "negative_size": 399, + "hard_neg_size": 31, + "inf_minibatch_size": 400, + "margin": 0.2, + "valid_steps": 150, + "num_train_steps": 1500, + "optim": "adamw", + "betas": [ + 0.9, + 0.98 + ], + "dropout": 0.1, + "weight_decay": 0.01, + "grad_norm": 2.0, + "warmup_steps": 150, + "seed": 42, + "full_val": true, + "n_workers": 4, + "pin_mem": true, + "train_txt_dbs": [ + "itm-data/txt_db3/itm_flickr30k_train.db" + ], + "train_img_dbs": [ + "itm-data/img_db/flickr30k/" + ], + "val_txt_db": "itm-data/txt_db3/itm_flickr30k_val.db", + "val_img_db": "itm-data/img_db/flickr30k/", + "test_txt_db": "itm-data/txt_db3/itm_flickr30k_test.db", + "test_img_db": "itm-data/img_db/flickr30k/", + "model_config": "config/uniter-large.json" +} diff --git a/config/uniter-base.json b/config/uniter-base.json new file mode 100644 index 0000000..691dacf --- /dev/null +++ b/config/uniter-base.json @@ -0,0 +1,13 @@ +{ + "attention_probs_dropout_prob": 0.1, + "hidden_act": "gelu", + "hidden_dropout_prob": 0.1, + "hidden_size": 768, + "initializer_range": 0.02, + "intermediate_size": 3072, + "max_position_embeddings": 512, + "num_attention_heads": 12, + "num_hidden_layers": 12, + "type_vocab_size": 2, + "vocab_size": 28996 +} diff --git a/config/uniter-large.json b/config/uniter-large.json new file mode 100644 index 0000000..961e7ca --- /dev/null +++ b/config/uniter-large.json @@ -0,0 +1,13 @@ +{ + "attention_probs_dropout_prob": 0.1, + "hidden_act": "gelu", + "hidden_dropout_prob": 0.1, + "hidden_size": 1024, + "initializer_range": 0.02, + "intermediate_size": 4096, + "max_position_embeddings": 512, + "num_attention_heads": 16, + "num_hidden_layers": 24, + "type_vocab_size": 2, + "vocab_size": 28996 +} diff --git a/data/__init__.py b/data/__init__.py new file mode 100644 index 0000000..bf64845 --- /dev/null +++ b/data/__init__.py @@ -0,0 +1,17 @@ +""" +Copyright (c) Microsoft Corporation. +Licensed under the MIT license. 
+ +""" +from .data import (TxtTokLmdb, DetectFeatLmdb, + ImageLmdbGroup, ConcatDatasetWithLens) +from .sampler import TokenBucketSampler +from .loader import PrefetchLoader, MetaLoader +from .itm import (TokenBucketSamplerForItm, ItmDataset, + itm_collate, itm_ot_collate, + ItmRankDataset, ItmValDataset, ItmEvalDataset, ItmAdvEvalDataset, ItmDCEvalDataset, ItmStaticDataAttackEvalDataset, + ItmRankDatasetHardNegFromImage, + ItmRankDatasetHardNegFromText, + itm_rank_collate, itm_val_collate, itm_eval_collate, + itm_rank_hn_collate) +from .pnsgd import (PNSGDFromImage, PNSGDFromText, pnsgd_collate) diff --git a/data/data.py b/data/data.py new file mode 100644 index 0000000..fb02737 --- /dev/null +++ b/data/data.py @@ -0,0 +1,326 @@ +""" +Copyright (c) Microsoft Corporation. +Licensed under the MIT license. + +Dataset interfaces +""" +from collections import defaultdict +from contextlib import contextmanager +import io +import json +from os.path import exists + +import numpy as np +import torch +from torch.utils.data import Dataset, ConcatDataset +import horovod.torch as hvd +from tqdm import tqdm +import lmdb +from lz4.frame import compress, decompress + +import msgpack +import msgpack_numpy +msgpack_numpy.patch() + + +def _fp16_to_fp32(feat_dict): + out = {k: arr.astype(np.float32) + if arr.dtype == np.float16 else arr + for k, arr in feat_dict.items()} + return out + + +def compute_num_bb(confs, conf_th, min_bb, max_bb): + num_bb = max(min_bb, (confs > conf_th).sum()) + num_bb = min(max_bb, num_bb) + return num_bb + + +def _check_distributed(): + try: + dist = hvd.size() != hvd.local_size() + except ValueError: + # not using horovod + dist = False + return dist + + +class DetectFeatLmdb(object): + def __init__(self, img_dir, conf_th=0.2, max_bb=100, min_bb=10, num_bb=36, + compress=True): + self.img_dir = img_dir + print('| image_dir: {}'.format(img_dir)) + print('| conf_th: ', conf_th) + import os + print('| files: ', os.listdir(img_dir)) + if conf_th == -1: + db_name = f'feat_numbb{num_bb}' + self.name2nbb = defaultdict(lambda: num_bb) + else: + db_name = f'feat_th{conf_th}_max{max_bb}_min{min_bb}' + nbb = f'nbb_th{conf_th}_max{max_bb}_min{min_bb}.json' + if not exists(f'{img_dir}/{nbb}'): + # nbb is not pre-computed + self.name2nbb = None + else: + self.name2nbb = json.load(open(f'{img_dir}/{nbb}')) + self.compress = compress + if compress: + db_name += '_compressed' + + if self.name2nbb is None: + if compress: + db_name = 'all_compressed' + else: + db_name = 'all' + # only read ahead on single node training + self.env = lmdb.open(f'{img_dir}/{db_name}', + readonly=True, create=False, + readahead=not _check_distributed()) + self.txn = self.env.begin(buffers=True) + if self.name2nbb is None: + self.name2nbb = self._compute_nbb() + + def _compute_nbb(self): + name2nbb = {} + fnames = json.loads(self.txn.get(key=b'__keys__').decode('utf-8')) + for fname in tqdm(fnames, desc='reading images'): + dump = self.txn.get(fname.encode('utf-8')) + if self.compress: + with io.BytesIO(dump) as reader: + img_dump = np.load(reader, allow_pickle=True) + confs = img_dump['conf'] + else: + img_dump = msgpack.loads(dump, raw=False) + confs = img_dump['conf'] + name2nbb[fname] = compute_num_bb(confs, self.conf_th, + self.min_bb, self.max_bb) + + return name2nbb + + def __del__(self): + self.env.close() + + def get_dump(self, file_name): + # hack for MRC + dump = self.txn.get(file_name.encode('utf-8')) + nbb = self.name2nbb[file_name] + if self.compress: + with io.BytesIO(dump) as reader: + img_dump = 
np.load(reader, allow_pickle=True) + img_dump = _fp16_to_fp32(img_dump) + else: + img_dump = msgpack.loads(dump, raw=False) + img_dump = _fp16_to_fp32(img_dump) + img_dump = {k: arr[:nbb, ...] for k, arr in img_dump.items()} + return img_dump + + def __getitem__(self, file_name): + dump = self.txn.get(file_name.encode('utf-8')) + nbb = self.name2nbb[file_name] + if self.compress: + with io.BytesIO(dump) as reader: + img_dump = np.load(reader, allow_pickle=True) + img_dump = {'features': img_dump['features'], + 'norm_bb': img_dump['norm_bb']} + else: + img_dump = msgpack.loads(dump, raw=False) + img_feat = torch.tensor(img_dump['features'][:nbb, :]).float() + img_bb = torch.tensor(img_dump['norm_bb'][:nbb, :]).float() + return img_feat, img_bb + + +@contextmanager +def open_lmdb(db_dir, readonly=False): + db = TxtLmdb(db_dir, readonly) + try: + yield db + finally: + del db + + +class TxtLmdb(object): + def __init__(self, db_dir, readonly=True): + self.readonly = readonly + if readonly: + # training + print('| db_dir: ', db_dir) + self.env = lmdb.open(db_dir, + readonly=True, create=False, + readahead=not _check_distributed()) + self.txn = self.env.begin(buffers=True) + self.write_cnt = None + # cur = self.txn.cursor() # 生成迭代器指针 + # num = 0 + # for key, value in cur: + # num += 1 + # + # print('| number: ', num) + # raise Exception(" number: {}".format(num)) + + else: + # prepro + self.env = lmdb.open(db_dir, readonly=False, create=True, + map_size=4 * 1024**4) + self.txn = self.env.begin(write=True) + self.write_cnt = 0 + + def __del__(self): + if self.write_cnt: + self.txn.commit() + self.env.close() + + def __getitem__(self, key): + # print('| key: ', key, type(key)) + return msgpack.loads(decompress(self.txn.get(key.encode('utf-8'))), + raw=False) + + def __setitem__(self, key, value): + # NOTE: not thread safe + if self.readonly: + raise ValueError('readonly text DB') + ret = self.txn.put(key.encode('utf-8'), + compress(msgpack.dumps(value, use_bin_type=True))) + self.write_cnt += 1 + if self.write_cnt % 1000 == 0: + self.txn.commit() + self.txn = self.env.begin(write=True) + self.write_cnt = 0 + return ret + + +class TxtTokLmdb(object): + def __init__(self, db_dir, max_txt_len=60): + if max_txt_len == -1: + self.id2len = json.load(open(f'{db_dir}/id2len.json')) + else: + self.id2len = { + id_: len_ + for id_, len_ in json.load(open(f'{db_dir}/id2len.json') + ).items() + if len_ <= max_txt_len + } + self.db_dir = db_dir + self.db = TxtLmdb(db_dir, readonly=True) + meta = json.load(open(f'{db_dir}/meta.json', 'r')) + self.cls_ = meta['CLS'] + self.sep = meta['SEP'] + self.mask = meta['MASK'] + self.v_range = meta['v_range'] + + def __getitem__(self, id_): + txt_dump = self.db[id_] + return txt_dump + + def combine_inputs(self, *inputs): + input_ids = [self.cls_] + for ids in inputs: + input_ids.extend(ids + [self.sep]) + return torch.tensor(input_ids) + + @property + def txt2img(self): + txt2img = json.load(open(f'{self.db_dir}/txt2img.json')) + return txt2img + + @property + def img2txts(self): + img2txts = json.load(open(f'{self.db_dir}/img2txts.json')) + return img2txts + + +def get_ids_and_lens(db): + assert isinstance(db, TxtTokLmdb) + lens = [] + ids = [] + for id_ in list(db.id2len.keys())[hvd.rank()::hvd.size()]: + lens.append(db.id2len[id_]) + ids.append(id_) + return lens, ids + + +class DetectFeatTxtTokDataset(Dataset): + def __init__(self, txt_db, img_db): + assert isinstance(txt_db, TxtTokLmdb) + assert isinstance(img_db, DetectFeatLmdb) + self.txt_db = txt_db + self.img_db 
= img_db + txt_lens, self.ids = get_ids_and_lens(txt_db) + # self.ids would be split by the GPUs + txt2img = txt_db.txt2img + self.lens = [tl + self.img_db.name2nbb[txt2img[id_]] + for tl, id_ in zip(txt_lens, self.ids)] + + def __len__(self): + return len(self.ids) + + def __getitem__(self, i): + id_ = self.ids[i] + example = self.txt_db[id_] + return example + + def _get_img_feat(self, fname): + img_feat, bb = self.img_db[fname] + img_bb = torch.cat([bb, bb[:, 4:5]*bb[:, 5:]], dim=-1) + num_bb = img_feat.size(0) + return img_feat, img_bb, num_bb + + +def pad_tensors(tensors, lens=None, pad=0): + """B x [T, ...]""" + if lens is None: + lens = [t.size(0) for t in tensors] + max_len = max(lens) + bs = len(tensors) + hid = tensors[0].size(-1) + dtype = tensors[0].dtype + output = torch.zeros(bs, max_len, hid, dtype=dtype) + if pad: + output.data.fill_(pad) + for i, (t, l) in enumerate(zip(tensors, lens)): + output.data[i, :l, ...] = t.data + return output + + +def get_gather_index(txt_lens, num_bbs, batch_size, max_len, out_size): + assert len(txt_lens) == len(num_bbs) == batch_size + gather_index = torch.arange(0, out_size, dtype=torch.long, + ).unsqueeze(0).repeat(batch_size, 1) + + for i, (tl, nbb) in enumerate(zip(txt_lens, num_bbs)): + gather_index.data[i, tl:tl+nbb] = torch.arange(max_len, max_len+nbb, + dtype=torch.long).data + return gather_index + + +class ConcatDatasetWithLens(ConcatDataset): + """ A thin wrapper on pytorch concat dataset for lens batching """ + def __init__(self, datasets): + super().__init__(datasets) + self.lens = [l for dset in datasets for l in dset.lens] + + def __getattr__(self, name): + return self._run_method_on_all_dsets(name) + + def _run_method_on_all_dsets(self, name): + def run_all(*args, **kwargs): + return [dset.__getattribute__(name)(*args, **kwargs) + for dset in self.datasets] + return run_all + + +class ImageLmdbGroup(object): + def __init__(self, conf_th, max_bb, min_bb, num_bb, compress): + self.path2imgdb = {} + self.conf_th = conf_th + self.max_bb = max_bb + self.min_bb = min_bb + self.num_bb = num_bb + self.compress = compress + + def __getitem__(self, path): + img_db = self.path2imgdb.get(path, None) + if img_db is None: + img_db = DetectFeatLmdb(path, self.conf_th, self.max_bb, + self.min_bb, self.num_bb, self.compress) + return img_db diff --git a/data/itm.py b/data/itm.py new file mode 100644 index 0000000..c572e1d --- /dev/null +++ b/data/itm.py @@ -0,0 +1,722 @@ +""" +Copyright (c) Microsoft Corporation. +Licensed under the MIT license. + +Itm dataset +""" +from collections import defaultdict +import copy +import random + +import torch +from torch.nn.utils.rnn import pad_sequence +from toolz.sandbox import unzip +from cytoolz import concat +import numpy as np + +from .data import (DetectFeatTxtTokDataset, DetectFeatLmdb, TxtTokLmdb, + pad_tensors, get_gather_index, get_ids_and_lens) +from .sampler import TokenBucketSampler +from collections import Counter + + +def random_word(tokens, vocab_range, mask): + """ + Masking some random tokens for Language Model task with probabilities as in + the original BERT paper. + :param tokens: list of int, tokenized sentence. 
+ :param vocab_range: for choosing a random word + :return: (list of int, list of int), masked tokens and related labels for + LM prediction + """ + output_label = [] + + for i, token in enumerate(tokens): + prob = random.random() + # mask token with 15% probability + if prob < 0.15: + prob /= 0.15 + + # 80% randomly change token to mask token + if prob < 0.8: + tokens[i] = mask + + # 10% randomly change token to random token + elif prob < 0.9: + tokens[i] = random.choice(list(range(*vocab_range))) + + # -> rest 10% randomly keep current token + + # append current token to output (we will predict these later) + output_label.append(token) + else: + # no masking token (will be ignored by loss function later) + output_label.append(-1) + if all(o == -1 for o in output_label): + # at least mask 1 + output_label[0] = tokens[0] + tokens[0] = mask + + return tokens, output_label + + +def random_word2(tokens, vocab_range, mask, mlm_positions, position_to_prob): + """ + Masking some random tokens for Language Model task with probabilities as in + the original BERT paper. + :param tokens: list of int, tokenized sentence. + :param vocab_range: for choosing a random word + :param mlm_positions: the positions sequence. if select, all tokens in the positions would be masked. + :param position_to_prob: the sampling probability of each token in the specific position + :return: (list of int, list of int), masked tokens and related labels for + LM prediction + """ + random.shuffle(mlm_positions) + output_label = [-1] * len(tokens) + for positions in mlm_positions: + sample_prob = sum([position_to_prob.get(position, 0.15) for position in positions]) / max(len(positions), 1) + prob = random.random() + if prob < sample_prob: + prob /= sample_prob + + for position in positions: + token = tokens[position] + if output_label[position] == -1: + prob2 = random.random() + if prob2 < 0.8: + tokens[position] = mask + elif prob2 < 0.9: + tokens[position] = random.choice(list(range(*vocab_range))) + output_label[position] = token + + if all(o == -1 for o in output_label): + # at least mask 1 + select_positions = mlm_positions[0] + for position in select_positions: + token = tokens[position] + prob2 = random.random() + if prob2 < 0.8: + tokens[position] = mask + elif prob2 < 0.9: + tokens[position] = random.choice(list(range(*vocab_range))) + output_label[position] = token + + return tokens, output_label + + +class TokenBucketSamplerForItm(TokenBucketSampler): + def __init__(self, dset, *args, **kwargs): + super().__init__(dset.lens, *args, **kwargs) + self.dset = dset + + def __iter__(self): + it = super().__iter__() + self.dset.new_epoch() + self._lens = self.dset.lens + return it + + +def _has_overlap(la, lb): + if len(la) < len(lb): + la, lb = lb, la + s = set(la) + return any(b in s for b in lb) + + +def sample_negative(sample_pool, ground_truths, num_sample): + """ random and retry """ + outputs = ground_truths[:1] + while _has_overlap(outputs, ground_truths): + outputs = random.sample(sample_pool, num_sample) + return outputs + + +class ItmDataset(DetectFeatTxtTokDataset): + """ NOTE this Dataset handles distributed training itself + (for more efficient negative sampling) """ + def __init__(self, txt_db, img_db, neg_sample_p=0.5): + assert isinstance(txt_db, TxtTokLmdb) + assert isinstance(img_db, DetectFeatLmdb) + + self.txt_db = txt_db + self.img_db = img_db + + self.txt_lens, self.ids = get_ids_and_lens(txt_db) + self.all_imgs = list(set(txt_db[id_]['img_fname'] for id_ in self.ids)) + + self.neg_sample_p = 
neg_sample_p + self.new_epoch() + + def new_epoch(self): + """ should be called every epoch for more randomness""" + self.labels = np.random.choice( + [0, 1], size=len(self.ids), + p=[self.neg_sample_p, 1-self.neg_sample_p]) + + self.lens = [] + self.train_imgs = [] + for i, (id_, tl) in enumerate(zip(self.ids, self.txt_lens)): + img_fname = super().__getitem__(i)['img_fname'] + if self.labels[i] == 0: + img_fname = sample_negative(self.all_imgs, [img_fname], 1)[0] + self.train_imgs.append(img_fname) + self.lens.append(tl + self.img_db.name2nbb[img_fname]) + + def __getitem__(self, i): + example = super().__getitem__(i) + # labels and negative images should be sampled every epoch + ground_truth_label = self.labels[i] + img_fname = self.train_imgs[i] + img_feat, img_pos_feat, num_bb = self._get_img_feat(img_fname) + + # text input + input_ids = example['input_ids'] + input_ids = self.txt_db.combine_inputs(input_ids) + + attn_masks = torch.ones(len(input_ids) + num_bb, dtype=torch.long) + target = torch.Tensor(1).long() + target.data.fill_(ground_truth_label) + + return input_ids, img_feat, img_pos_feat, attn_masks, target + + +def itm_collate(inputs): + (input_ids, img_feats, img_pos_feats, attn_masks, targets + ) = map(list, unzip(inputs)) + + txt_lens = [i.size(0) for i in input_ids] + + input_ids = pad_sequence(input_ids, batch_first=True, padding_value=0) + position_ids = torch.arange(0, input_ids.size(1), dtype=torch.long + ).unsqueeze(0) + + num_bbs = [f.size(0) for f in img_feats] + img_feat = pad_tensors(img_feats, num_bbs) + img_pos_feat = pad_tensors(img_pos_feats, num_bbs) + + attn_masks = pad_sequence(attn_masks, batch_first=True, padding_value=0) + targets = torch.cat(targets, dim=0) + bs, max_tl = input_ids.size() + out_size = attn_masks.size(1) + gather_index = get_gather_index(txt_lens, num_bbs, bs, max_tl, out_size) + + batch = {'input_ids': input_ids, + 'position_ids': position_ids, + 'img_feat': img_feat, + 'img_pos_feat': img_pos_feat, + 'attn_masks': attn_masks, + 'gather_index': gather_index, + 'targets': targets} + return batch + + +def _compute_ot_scatter(txt_lens, max_txt_len, joint_len): + ot_scatter = torch.arange(0, joint_len, dtype=torch.long + ).unsqueeze(0).repeat(len(txt_lens), 1) + for i, tl in enumerate(txt_lens): + max_ind = max_txt_len + (joint_len-tl) + ot_scatter.data[i, tl:] = torch.arange(max_txt_len, max_ind, + dtype=torch.long).data + return ot_scatter + + +def _compute_pad(lens, max_len): + pad = torch.zeros(len(lens), max_len, dtype=torch.uint8) + for i, l in enumerate(lens): + pad.data[i, l:].fill_(1) + return pad + + +def itm_ot_collate(inputs): + (input_ids, img_feats, img_pos_feats, attn_masks, targets + ) = map(list, unzip(inputs)) + + txt_lens = [i.size(0) for i in input_ids] + + input_ids = pad_sequence(input_ids, batch_first=True, padding_value=0) + position_ids = torch.arange(0, input_ids.size(1), dtype=torch.long + ).unsqueeze(0) + + num_bbs = [f.size(0) for f in img_feats] + img_feat = pad_tensors(img_feats, num_bbs) + img_pos_feat = pad_tensors(img_pos_feats, num_bbs) + + attn_masks = pad_sequence(attn_masks, batch_first=True, padding_value=0) + targets = torch.cat(targets, dim=0) + bs, max_tl = input_ids.size() + out_size = attn_masks.size(1) + gather_index = get_gather_index(txt_lens, num_bbs, bs, max_tl, out_size) + + # OT inputs + max_tl = max(txt_lens) + max_nbb = max(num_bbs) + ot_scatter = _compute_ot_scatter(txt_lens, max_tl, attn_masks.size(1)) + txt_pad = _compute_pad(txt_lens, max_tl) + img_pad = _compute_pad(num_bbs, 
max_nbb) + ot_inputs = {'ot_scatter': ot_scatter, + 'scatter_max': ot_scatter.max().item(), + 'txt_pad': txt_pad, + 'img_pad': img_pad} + + batch = {'input_ids': input_ids, + 'position_ids': position_ids, + 'img_feat': img_feat, + 'img_pos_feat': img_pos_feat, + 'attn_masks': attn_masks, + 'gather_index': gather_index, + 'targets': targets, + 'ot_inputs': ot_inputs} + return batch + + +class ItmRankDataset(DetectFeatTxtTokDataset): + def __init__(self, txt_db, img_db, neg_sample_size=1): + assert neg_sample_size > 0, \ + "ItmRankDataset need at least 1 negative sample" + super().__init__(txt_db, img_db) + + txt2img = self.txt_db.txt2img + self.txt2img = {id_: txt2img[id_] for id_ in self.ids} + # images partitioned by rank + self.img2txts = defaultdict(list) + for id_, img in self.txt2img.items(): + self.img2txts[img].append(id_) + self.img_name_list = list(self.img2txts.keys()) + + assert neg_sample_size > 0 + self.neg_sample_size = neg_sample_size + + def __getitem__(self, i): + gt_txt_id = self.ids[i] + gt_img_fname = self.txt2img[gt_txt_id] + + id_pairs = [(gt_txt_id, gt_img_fname)] + # sample negatives + neg_sample_img_ids = sample_negative( + self.img_name_list, [gt_img_fname], self.neg_sample_size) + neg_sample_txt_ids = sample_negative( + self.ids, self.img2txts[gt_img_fname], self.neg_sample_size) + id_pairs.extend([(gt_txt_id, neg_img_id) + for neg_img_id in neg_sample_img_ids] + + [(neg_txt_id, gt_img_fname) + for neg_txt_id in neg_sample_txt_ids]) + inputs = self._collect_inputs(id_pairs) + assert len(inputs) == (1 + 2*self.neg_sample_size) + return inputs + + def _collect_inputs(self, id_pairs): + # create input features + inputs = [] + for txt_id, img_id in id_pairs: + example = self.txt_db[txt_id] + # text input + input_ids = example['input_ids'] + input_ids = self.txt_db.combine_inputs(input_ids) + # img input + img_feat, img_pos_feat, num_bb = self._get_img_feat(img_id) + # mask + attn_masks = torch.ones(len(input_ids) + num_bb, dtype=torch.long) + + inputs.append((input_ids, img_feat, img_pos_feat, attn_masks)) + + return inputs + + +def itm_rank_collate(inputs): + (input_ids, img_feats, img_pos_feats, attn_masks, + ) = map(list, unzip(concat(i for i in inputs))) + + txt_lens = [i.size(0) for i in input_ids] + input_ids = pad_sequence(input_ids, batch_first=True, padding_value=0) + position_ids = torch.arange(0, input_ids.size(1), dtype=torch.long + ).unsqueeze(0) + + num_bbs = [f.size(0) for f in img_feats] + img_feat = pad_tensors(img_feats, num_bbs) + img_pos_feat = pad_tensors(img_pos_feats, num_bbs) + + attn_masks = pad_sequence(attn_masks, batch_first=True, padding_value=0) + sample_size = len(inputs[0]) + assert all(sample_size == len(i) for i in inputs) + + bs, max_tl = input_ids.size() + out_size = attn_masks.size(1) + gather_index = get_gather_index(txt_lens, num_bbs, bs, max_tl, out_size) + + batch = {'input_ids': input_ids, + 'position_ids': position_ids, + 'img_feat': img_feat, + 'img_pos_feat': img_pos_feat, + 'attn_masks': attn_masks, + 'gather_index': gather_index, + 'sample_size': sample_size} + return batch + + +class ItmRankDatasetHardNegFromText(DetectFeatTxtTokDataset): + def __init__(self, txt_db, img_db, neg_sample_size=1): + assert neg_sample_size > 0, "need at least 1 negative sample" + super().__init__(txt_db, img_db) + + txt2img = self.txt_db.txt2img + self.txt2img = {id_: txt2img[id_] for id_ in self.ids} + self.img2txts = self.txt_db.img2txts + self.img_name_list = list(self.img2txts.keys()) + self.neg_sample_size = neg_sample_size + + def 
__getitem__(self, i): + gt_txt_id = self.ids[i] + gt_img_fname = self.txt2img[gt_txt_id] + + input_ids = self.txt_db[gt_txt_id]['input_ids'] + input_ids = self.txt_db.combine_inputs(input_ids) + input_ids = input_ids.unsqueeze(0) + position_ids = torch.arange(0, input_ids.size(1), dtype=torch.long + ).unsqueeze(0) + + neg_img_ids = sample_negative( + self.img_name_list, [gt_img_fname], self.neg_sample_size) + img_ids = [gt_img_fname] + neg_img_ids + # process image features (gt always first) + img_feats, img_pos_feats, num_bbs = map( + list, unzip(map(self._get_img_feat, img_ids))) + img_feat = pad_tensors(img_feats, num_bbs) + img_pos_feat = pad_tensors(img_pos_feats, num_bbs) + + tl = input_ids.size(1) + attn_masks = torch.zeros(len(img_ids), max(num_bbs) + tl).long() + for i, nbb in enumerate(num_bbs): + attn_masks.data[i, :tl+nbb].fill_(1) + out_size = attn_masks.size(1) + gather_index = get_gather_index([tl]*len(img_ids), num_bbs, + len(img_ids), tl, out_size) + + batch = {'input_ids': input_ids, + 'position_ids': position_ids, + 'img_feat': img_feat, + 'img_pos_feat': img_pos_feat, + 'attn_masks': attn_masks, + 'gather_index': gather_index} + return batch + + +class ItmRankDatasetHardNegFromImage(DetectFeatTxtTokDataset): + def __init__(self, txt_db, img_db, neg_sample_size=1): + assert neg_sample_size > 0, "need at least 1 negative sample" + super().__init__(txt_db, img_db) + + txt2img = self.txt_db.txt2img + self.txt2img = {id_: txt2img[id_] for id_ in self.ids} + self.img2txts = self.txt_db.img2txts + self.txt_name_list = list(self.txt2img.keys()) + self.neg_sample_size = neg_sample_size + + def __getitem__(self, i): + gt_txt_id = self.ids[i] + gt_img_id = self.txt2img[gt_txt_id] + gt_txt_ids = self.img2txts[gt_img_id] + + # process image features (gt always first) + img_feat, img_pos_feat, nbb = self._get_img_feat(gt_img_id) + img_feat = img_feat.unsqueeze(0) + img_pos_feat = img_pos_feat.unsqueeze(0) + + # sample negative + neg_txt_ids = sample_negative( + self.txt_name_list, gt_txt_ids, self.neg_sample_size) + txt_ids = [gt_txt_id] + neg_txt_ids + + # process text inputs + all_inputs = [] + txt_lens = [] + for txt_id in txt_ids: + input_ids = self.txt_db.combine_inputs( + self.txt_db[txt_id]['input_ids']) + all_inputs.append(input_ids) + txt_lens.append(len(input_ids)) + input_ids = pad_sequence(all_inputs, batch_first=True, padding_value=0) + position_ids = torch.arange(0, input_ids.size(1), dtype=torch.long + ).unsqueeze(0) + + attn_masks = torch.zeros(len(txt_ids), max(txt_lens) + nbb).long() + for i, tl in enumerate(txt_lens): + attn_masks.data[i, :tl+nbb].fill_(1) + out_size = attn_masks.size(1) + gather_index = get_gather_index(txt_lens, [nbb]*len(txt_ids), + len(txt_ids), tl, out_size) + + batch = {'input_ids': input_ids, + 'position_ids': position_ids, + 'img_feat': img_feat, + 'img_pos_feat': img_pos_feat, + 'attn_masks': attn_masks, + 'gather_index': gather_index} + return batch + + +def itm_rank_hn_collate(inputs): + assert len(inputs) == 1 + return inputs[0] + + +class ItmValDataset(DetectFeatTxtTokDataset): + """ For evaluating Image-Text-Retrieval task """ + def __init__(self, db_dir, img_dir, mini_batch_size=400, mlm_sample_size=1): + super().__init__(db_dir, img_dir) + + self.txt_lens = self.lens[:] + del self.lens + self.txt2img = self.txt_db.txt2img + self.img2txts = self.txt_db.img2txts + self.all_img_ids = list(self.img2txts.keys()) + self.id2len = self.txt_db.id2len + self.i2len = {i: self.id2len.get(self.ids[i]) for i in range(len(self.ids))} + + assert 
len(self.img2txts) >= mini_batch_size > 0 + self.bs = mini_batch_size + self.mlm_sample_size = mlm_sample_size + + def _get_batch_ids(self, i): + gt_txt_id = self.ids[i] + gt_img_id = self.txt2img[gt_txt_id] + + # sample fixed negatives for each gt image + i = self.all_img_ids.index(gt_img_id) + neg_st = i+1 + neg_end = neg_st+self.bs-1 + if neg_end > len(self.all_img_ids): + # warp around + neg_end -= len(self.all_img_ids) + neg_img_ids = (self.all_img_ids[neg_st:] + + self.all_img_ids[:neg_end]) + else: + neg_img_ids = self.all_img_ids[neg_st:neg_end] + + assert len(neg_img_ids) == (self.bs - 1),\ + "Did not sample enough neg samples" + + return gt_img_id, neg_img_ids + + def __getitem__(self, i): + """ this returns list of mini-batches """ + gt_img_id, neg_img_ids = self._get_batch_ids(i) + # NOTE 1st one is gt img + batch = self.get_batch(i, [gt_img_id] + neg_img_ids) + return batch + + def create_mlm_io(self, input_ids, nsample=1): + mlm_input_ids, mlm_txt_labels = [], [] + for i in range(nsample): + # tokens, vocab_range, mask, mlm_positions, position_to_prob + r_input_ids, txt_labels = random_word( + copy.copy(input_ids), self.txt_db.v_range, self.txt_db.mask) + mlm_input_ids.append(torch.tensor([self.txt_db.cls_] + r_input_ids + [self.txt_db.sep])) + mlm_txt_labels.append(torch.tensor([-1] + txt_labels + [-1])) + mlm_input_ids = torch.stack(mlm_input_ids, dim=0) + mlm_txt_labels = torch.stack(mlm_txt_labels, dim=0) + return mlm_input_ids, mlm_txt_labels + + def create_mlm_io2(self, input_ids, tree=None, nsample=1): + mlm_input_ids, mlm_txt_labels = [], [] + sample_prob = 0.15 + + mlm_positions = [[i] for i in range(len(input_ids))] + for struct_type in ['relation', 'attribute', 'node']: + struct_nodes = tree.get(struct_type) + for struct_node in struct_nodes: + positions = struct_node.get('ids') + if positions is not None: + mlm_positions.append(positions) + # mlm_positions = list(set(mlm_positions)) + position_counter = Counter() + for positions in mlm_positions: + position_counter.update(positions) + position_to_prob = {position: sample_prob / max(freq, 1.0) for position, freq in position_counter.items()} + + # print("| mlm_positions: ", mlm_positions) + for i in range(nsample): + r_input_ids, txt_labels = random_word2( + copy.copy(input_ids), self.txt_db.v_range, self.txt_db.mask, + mlm_positions=mlm_positions, position_to_prob=position_to_prob) + mlm_input_ids.append(torch.tensor([self.txt_db.cls_] + r_input_ids + [self.txt_db.sep])) + mlm_txt_labels.append(torch.tensor([-1] + txt_labels + [-1])) + mlm_input_ids = torch.stack(mlm_input_ids, dim=0) + mlm_txt_labels = torch.stack(mlm_txt_labels, dim=0) + return mlm_input_ids, mlm_txt_labels + + def get_batch(self, i, img_ids, forward_mlm=False): + batch = {} + + example = super().__getitem__(i) + + input_ids = example['input_ids'] + + if forward_mlm: + gt_txt_id = self.ids[i] + gt_img_fname = self.txt2img[gt_txt_id] + img_ids = [gt_img_fname] + # Process Masked Text to Generate Adversarial Samples + # mlm_input_ids, mlm_txt_labels = self.create_mlm_io(input_ids, nsample=self.mlm_sample_size) + + tree = self.txt_db[gt_txt_id]['tree'] + mlm_input_ids, mlm_txt_labels = self.create_mlm_io2(input_ids, tree=tree, nsample=self.mlm_sample_size) + + mlm_position_ids = torch.arange(0, mlm_input_ids.size(1), dtype=torch.long).\ + unsqueeze(0).expand(self.mlm_sample_size, -1) + img_feat, img_pos_feat, num_bbs = self._get_img_feat(gt_img_fname) + mlm_img_feat = img_feat.unsqueeze(dim=0).expand(self.mlm_sample_size, *list(img_feat.size())) + 
mlm_img_pos_feat = img_pos_feat.unsqueeze(dim=0).expand(self.mlm_sample_size, *list(img_pos_feat.size())) + tl = mlm_input_ids.size(1) + mlm_attn_masks = torch.zeros(self.mlm_sample_size, tl+num_bbs).long() + mlm_attn_masks.data[:, :tl+num_bbs].fill_(1) + mlm_gather_index = get_gather_index( + [tl]*self.mlm_sample_size, [num_bbs]*self.mlm_sample_size, self.mlm_sample_size, tl, tl+num_bbs) + + batch['mlm_input_ids'] = mlm_input_ids + batch['mlm_position_ids'] = mlm_position_ids + batch['mlm_img_feat'] = mlm_img_feat + batch['mlm_img_pos_feat'] = mlm_img_pos_feat + batch['mlm_attn_masks'] = mlm_attn_masks + batch['mlm_gather_index'] = mlm_gather_index + batch['mlm_txt_labels'] = mlm_txt_labels + + input_ids = self.txt_db.combine_inputs(input_ids) + input_ids = input_ids.unsqueeze(0).expand(len(img_ids), -1).clone() + position_ids = torch.arange(0, input_ids.size(1), dtype=torch.long + ).unsqueeze(0) + + # process image features (gt always first) + img_feats, img_pos_feats, num_bbs = map( + list, unzip(map(self._get_img_feat, img_ids))) + img_feat = pad_tensors(img_feats, num_bbs) + img_pos_feat = pad_tensors(img_pos_feats, num_bbs) + + tl = input_ids.size(1) + attn_masks = torch.zeros(len(img_ids), max(num_bbs) + tl).long() + for i, nbb in enumerate(num_bbs): + attn_masks.data[i, :tl+nbb].fill_(1) + out_size = attn_masks.size(1) + gather_index = get_gather_index([tl]*len(img_ids), num_bbs, + len(img_ids), tl, out_size) + + batch['input_ids'] = input_ids + batch['position_ids'] = position_ids + batch['img_feat'] = img_feat + batch['img_pos_feat'] = img_pos_feat + batch['attn_masks'] = attn_masks + batch['gather_index'] = gather_index + return batch + + +def itm_val_collate(inputs): + assert len(inputs) == 1, "input batch size > 1" + return inputs[0] + + +class ItmEvalDataset(ItmValDataset): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.all_img_ids = sorted(copy.deepcopy(self.all_img_ids), + key=lambda i: self.img_db.name2nbb[i]) + + def __getitem__(self, i): + mini_batches = [] + for st in range(0, len(self.all_img_ids), self.bs): + mini_batches.append( + self.get_batch(i, self.all_img_ids[st:st+self.bs])) + return mini_batches + + +class ItmAdvEvalDataset(ItmValDataset): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.all_img_ids = sorted(copy.deepcopy(self.all_img_ids), + key=lambda i: self.img_db.name2nbb[i]) + self.all_img_ids = np.array(self.all_img_ids) + + def __getitem__(self, i): + return self.get_batch(i, [], forward_mlm=True) + + +class ItmDCEvalDataset(ItmValDataset): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.all_img_ids = sorted(copy.deepcopy(self.all_img_ids), + key=lambda i: self.img_db.name2nbb[i]) + + def __getitem__(self, i): + return self.get_batch(i, [], forward_mlm=True) + + def get_by_img_index(self, i, image_indices): + mini_batch = self.get_batch(i, self.all_img_ids[image_indices]) + mini_batch_img_txt_input_ids = [] + for image_index in image_indices: + img_id = self.all_img_ids[image_index] + txt_ids = self.img2txts[img_id] + img_txt_input_ids = [] + for txt_id in txt_ids: + input_ids = self.txt_db[txt_id]['input_ids'] + img_txt_input_ids.extend(input_ids) + mini_batch_img_txt_input_ids.append(torch.tensor(list(set(mini_batch_img_txt_input_ids)))) + mini_batch['img_txt_input_ids'] = mini_batch_img_txt_input_ids + mini_batches = [mini_batch] + return mini_batches + + +class ItmStaticDataAttackEvalDataset(DetectFeatTxtTokDataset): + def __init__(self, db_dir, 
img_dir, mini_batch_size=400, mlm_sample_size=1): + super().__init__(db_dir, img_dir) + + self.txt_lens = self.lens[:] + del self.lens + self.txt2img = self.txt_db.txt2img + self.img2txts = self.txt_db.img2txts + self.all_img_ids = list(self.img2txts.keys()) + self.id2len = self.txt_db.id2len + self.i2len = {i: self.id2len.get(self.ids[i]) for i in range(len(self.ids))} + + assert len(self.img2txts) >= mini_batch_size > 0 + self.bs = mini_batch_size + self.mlm_sample_size = mlm_sample_size + + def __getitem__(self, i): + return self.get_batch(i, [], forward_mlm=True) + + def get_batch(self, i, img_ids, forward_mlm=False): + batch = {} + example = super().__getitem__(i) + input_ids = example['input_ids'] + all_attack_text_ids = example['attack_data'][0] + all_attack_text_ids = all_attack_text_ids[:400] + + all_inputs = [self.txt_db.combine_inputs(token_ids) for token_ids in [input_ids] + all_attack_text_ids] + txt_lens = [len(token_ids) for token_ids in all_inputs] + input_ids = pad_sequence(all_inputs, batch_first=True, padding_value=0) + # print('| inputs_ids: ', len(all_attack_text_ids), input_ids.size()) + position_ids = torch.arange(0, input_ids.size(1), dtype=torch.long).unsqueeze(0) + + gt_txt_id = self.ids[i] + gt_img_id = self.txt2img[gt_txt_id] + + # process image features (gt always first) + img_feat, img_pos_feat, nbb = self._get_img_feat(gt_img_id) + img_feat = img_feat.unsqueeze(0) + img_pos_feat = img_pos_feat.unsqueeze(0) + # print("| img_feat: ", img_feat.size(), img_pos_feat.size()) + + attn_masks = torch.zeros(len(txt_lens), max(txt_lens) + nbb).long() + for i, tl in enumerate(txt_lens): + attn_masks.data[i, :tl+nbb].fill_(1) + out_size = attn_masks.size(1) + gather_index = get_gather_index(txt_lens, [nbb]*len(txt_lens), len(txt_lens), tl, out_size) + + batch['input_ids'] = input_ids + batch['position_ids'] = position_ids + batch['img_feat'] = img_feat + batch['img_pos_feat'] = img_pos_feat + batch['attn_masks'] = attn_masks + batch['gather_index'] = gather_index + return batch + + +itm_eval_collate = itm_val_collate diff --git a/data/loader.py b/data/loader.py new file mode 100644 index 0000000..e177c07 --- /dev/null +++ b/data/loader.py @@ -0,0 +1,150 @@ +""" +Copyright (c) Microsoft Corporation. +Licensed under the MIT license. + +A prefetch loader to speedup data loading +Modified from Nvidia Deep Learning Examples +(https://github.com/NVIDIA/DeepLearningExamples/tree/master/PyTorch). 
+""" +import random + +import torch +from torch.utils.data import DataLoader + +from utils.distributed import any_broadcast + + +class MetaLoader(object): + """ wraps multiple data loaders """ + def __init__(self, loaders, accum_steps=1, distributed=False): + assert isinstance(loaders, dict) + self.name2loader = {} + self.name2iter = {} + self.sampling_pools = [] + for n, l in loaders.items(): + if isinstance(l, tuple): + l, r = l + elif isinstance(l, DataLoader): + r = 1 + else: + raise ValueError() + self.name2loader[n] = l + self.name2iter[n] = iter(l) + self.sampling_pools.extend([n]*r) + + self.accum_steps = accum_steps + self.distributed = distributed + self.step = 0 + + def __iter__(self): + """ this iterator will run indefinitely """ + task = self.sampling_pools[0] + while True: + if self.step % self.accum_steps == 0: + task = random.choice(self.sampling_pools) + if self.distributed: + # make sure all process is training same task + task = any_broadcast(task, 0) + self.step += 1 + iter_ = self.name2iter[task] + try: + batch = next(iter_) + except StopIteration: + iter_ = iter(self.name2loader[task]) + batch = next(iter_) + self.name2iter[task] = iter_ + + yield task, batch + + +def move_to_cuda(batch): + if isinstance(batch, torch.Tensor): + try: + return batch.cuda(non_blocking=True) + except RuntimeError: + # print('| Error |', batch.size(), batch.dtype) + return batch.contiguous().cuda(non_blocking=True) + elif isinstance(batch, list): + new_batch = [move_to_cuda(t) for t in batch] + elif isinstance(batch, tuple): + new_batch = tuple(move_to_cuda(t) for t in batch) + elif isinstance(batch, dict): + new_batch = {} + for n, t in batch.items(): + # print(n, t.size(), t[0]) + new_batch[n] = move_to_cuda(t) + # new_batch = {n: move_to_cuda(t) for n, t in batch.items()} + else: + return batch + return new_batch + + +def record_cuda_stream(batch): + if isinstance(batch, torch.Tensor): + batch.record_stream(torch.cuda.current_stream()) + elif isinstance(batch, list) or isinstance(batch, tuple): + for t in batch: + record_cuda_stream(t) + elif isinstance(batch, dict): + for t in batch.values(): + record_cuda_stream(t) + else: + pass + + +class PrefetchLoader(object): + """ + overlap compute and cuda data transfer + (copied and then modified from nvidia apex) + """ + def __init__(self, loader): + self.loader = loader + self.stream = torch.cuda.Stream() + + def __iter__(self): + loader_it = iter(self.loader) + self.preload(loader_it) + batch = self.next(loader_it) + while batch is not None: + yield batch + batch = self.next(loader_it) + + def __len__(self): + return len(self.loader) + + def preload(self, it): + try: + self.batch = next(it) + except StopIteration: + self.batch = None + return + # if record_stream() doesn't work, another option is to make sure + # device inputs are created on the main stream. + # self.next_input_gpu = torch.empty_like(self.next_input, + # device='cuda') + # self.next_target_gpu = torch.empty_like(self.next_target, + # device='cuda') + # Need to make sure the memory allocated for next_* is not still in use + # by the main stream at the time we start copying to next_*: + # self.stream.wait_stream(torch.cuda.current_stream()) + with torch.cuda.stream(self.stream): + self.batch = move_to_cuda(self.batch) + # more code for the alternative if record_stream() doesn't work: + # copy_ will record the use of the pinned source tensor in this + # side stream. 
+ # self.next_input_gpu.copy_(self.next_input, non_blocking=True) + # self.next_target_gpu.copy_(self.next_target, non_blocking=True) + # self.next_input = self.next_input_gpu + # self.next_target = self.next_target_gpu + + def next(self, it): + torch.cuda.current_stream().wait_stream(self.stream) + batch = self.batch + if batch is not None: + record_cuda_stream(batch) + self.preload(it) + return batch + + def __getattr__(self, name): + method = self.loader.__getattribute__(name) + return method diff --git a/data/pnsgd.py b/data/pnsgd.py new file mode 100644 index 0000000..c35ff26 --- /dev/null +++ b/data/pnsgd.py @@ -0,0 +1,616 @@ +""" +Copyright (c) Microsoft Corporation. +Licensed under the MIT license. + +Itm dataset +""" +from collections import defaultdict +import copy +import random +from collections import Counter +import torch +from torch.nn.utils.rnn import pad_sequence +from toolz.sandbox import unzip +from cytoolz import concat +import numpy as np + +from .data import (DetectFeatTxtTokDataset, DetectFeatLmdb, TxtTokLmdb, + pad_tensors, get_gather_index, get_ids_and_lens) +from .sampler import TokenBucketSampler + + +def random_word(tokens, vocab_range, mask): + """ + Masking some random tokens for Language Model task with probabilities as in + the original BERT paper. + :param tokens: list of int, tokenized sentence. + :param vocab_range: for choosing a random word + :return: (list of int, list of int), masked tokens and related labels for + LM prediction + """ + output_label = [] + + for i, token in enumerate(tokens): + prob = random.random() + # mask token with 15% probability + if prob < 0.15: + prob /= 0.15 + + # 80% randomly change token to mask token + if prob < 0.8: + tokens[i] = mask + + # 10% randomly change token to random token + elif prob < 0.9: + tokens[i] = random.choice(list(range(*vocab_range))) + + # -> rest 10% randomly keep current token + + # append current token to output (we will predict these later) + output_label.append(token) + else: + # no masking token (will be ignored by loss function later) + output_label.append(-1) + if all(o == -1 for o in output_label): + # at least mask 1 + output_label[0] = tokens[0] + tokens[0] = mask + + return tokens, output_label + + +def random_word2(tokens, vocab_range, mask, mlm_positions, position_to_prob): + """ + Masking some random tokens for Language Model task with probabilities as in + the original BERT paper. + :param tokens: list of int, tokenized sentence. + :param vocab_range: for choosing a random word + :param mlm_positions: the positions sequence. if select, all tokens in the positions would be masked. 
+ :param position_to_prob: the sampling probability of each token in the specific position + :return: (list of int, list of int), masked tokens and related labels for + LM prediction + """ + random.shuffle(mlm_positions) + output_label = [-1] * len(tokens) + for positions in mlm_positions: + sample_prob = sum([position_to_prob.get(position, 0.15) for position in positions]) / max(len(positions), 1) + prob = random.random() + if prob < sample_prob: + prob /= sample_prob + + for position in positions: + token = tokens[position] + if output_label[position] == -1: + prob2 = random.random() + if prob2 < 0.8: + tokens[position] = mask + elif prob2 < 0.9: + tokens[position] = random.choice(list(range(*vocab_range))) + output_label[position] = token + + if all(o == -1 for o in output_label): + # at least mask 1 + select_positions = mlm_positions[0] + for position in select_positions: + token = tokens[position] + prob2 = random.random() + if prob2 < 0.8: + tokens[position] = mask + elif prob2 < 0.9: + tokens[position] = random.choice(list(range(*vocab_range))) + output_label[position] = token + + return tokens, output_label + + +class TokenBucketSamplerForItm(TokenBucketSampler): + def __init__(self, dset, *args, **kwargs): + super().__init__(dset.lens, *args, **kwargs) + self.dset = dset + + def __iter__(self): + it = super().__iter__() + self.dset.new_epoch() + self._lens = self.dset.lens + return it + + +def _has_overlap(la, lb): + if len(la) < len(lb): + la, lb = lb, la + s = set(la) + return any(b in s for b in lb) + + +def sample_negative(sample_pool, ground_truths, num_sample): + """ random and retry """ + outputs = ground_truths[:1] + while _has_overlap(outputs, ground_truths): + outputs = random.sample(sample_pool, num_sample) + return outputs + + +class ItmDataset(DetectFeatTxtTokDataset): + """ NOTE this Dataset handles distributed training itself + (for more efficient negative sampling) """ + def __init__(self, txt_db, img_db, neg_sample_p=0.5): + assert isinstance(txt_db, TxtTokLmdb) + assert isinstance(img_db, DetectFeatLmdb) + + self.txt_db = txt_db + self.img_db = img_db + + self.txt_lens, self.ids = get_ids_and_lens(txt_db) + self.all_imgs = list(set(txt_db[id_]['img_fname'] for id_ in self.ids)) + + self.neg_sample_p = neg_sample_p + self.new_epoch() + + def new_epoch(self): + """ should be called every epoch for more randomness""" + self.labels = np.random.choice( + [0, 1], size=len(self.ids), + p=[self.neg_sample_p, 1-self.neg_sample_p]) + + self.lens = [] + self.train_imgs = [] + for i, (id_, tl) in enumerate(zip(self.ids, self.txt_lens)): + img_fname = super().__getitem__(i)['img_fname'] + if self.labels[i] == 0: + img_fname = sample_negative(self.all_imgs, [img_fname], 1)[0] + self.train_imgs.append(img_fname) + self.lens.append(tl + self.img_db.name2nbb[img_fname]) + + def __getitem__(self, i): + example = super().__getitem__(i) + # labels and negative images should be sampled every epoch + ground_truth_label = self.labels[i] + img_fname = self.train_imgs[i] + img_feat, img_pos_feat, num_bb = self._get_img_feat(img_fname) + + # text input + input_ids = example['input_ids'] + input_ids = self.txt_db.combine_inputs(input_ids) + + attn_masks = torch.ones(len(input_ids) + num_bb, dtype=torch.long) + target = torch.Tensor(1).long() + target.data.fill_(ground_truth_label) + + return input_ids, img_feat, img_pos_feat, attn_masks, target + + +def itm_collate(inputs): + (input_ids, img_feats, img_pos_feats, attn_masks, targets + ) = map(list, unzip(inputs)) + + txt_lens = [i.size(0) 
for i in input_ids] + + input_ids = pad_sequence(input_ids, batch_first=True, padding_value=0) + position_ids = torch.arange(0, input_ids.size(1), dtype=torch.long + ).unsqueeze(0) + + num_bbs = [f.size(0) for f in img_feats] + img_feat = pad_tensors(img_feats, num_bbs) + img_pos_feat = pad_tensors(img_pos_feats, num_bbs) + + attn_masks = pad_sequence(attn_masks, batch_first=True, padding_value=0) + targets = torch.cat(targets, dim=0) + bs, max_tl = input_ids.size() + out_size = attn_masks.size(1) + gather_index = get_gather_index(txt_lens, num_bbs, bs, max_tl, out_size) + + batch = {'input_ids': input_ids, + 'position_ids': position_ids, + 'img_feat': img_feat, + 'img_pos_feat': img_pos_feat, + 'attn_masks': attn_masks, + 'gather_index': gather_index, + 'targets': targets} + return batch + + +def _compute_ot_scatter(txt_lens, max_txt_len, joint_len): + ot_scatter = torch.arange(0, joint_len, dtype=torch.long + ).unsqueeze(0).repeat(len(txt_lens), 1) + for i, tl in enumerate(txt_lens): + max_ind = max_txt_len + (joint_len-tl) + ot_scatter.data[i, tl:] = torch.arange(max_txt_len, max_ind, + dtype=torch.long).data + return ot_scatter + + +def _compute_pad(lens, max_len): + pad = torch.zeros(len(lens), max_len, dtype=torch.uint8) + for i, l in enumerate(lens): + pad.data[i, l:].fill_(1) + return pad + + +def itm_ot_collate(inputs): + (input_ids, img_feats, img_pos_feats, attn_masks, targets + ) = map(list, unzip(inputs)) + + txt_lens = [i.size(0) for i in input_ids] + + input_ids = pad_sequence(input_ids, batch_first=True, padding_value=0) + position_ids = torch.arange(0, input_ids.size(1), dtype=torch.long + ).unsqueeze(0) + + num_bbs = [f.size(0) for f in img_feats] + img_feat = pad_tensors(img_feats, num_bbs) + img_pos_feat = pad_tensors(img_pos_feats, num_bbs) + + attn_masks = pad_sequence(attn_masks, batch_first=True, padding_value=0) + targets = torch.cat(targets, dim=0) + bs, max_tl = input_ids.size() + out_size = attn_masks.size(1) + gather_index = get_gather_index(txt_lens, num_bbs, bs, max_tl, out_size) + + # OT inputs + max_tl = max(txt_lens) + max_nbb = max(num_bbs) + ot_scatter = _compute_ot_scatter(txt_lens, max_tl, attn_masks.size(1)) + txt_pad = _compute_pad(txt_lens, max_tl) + img_pad = _compute_pad(num_bbs, max_nbb) + ot_inputs = {'ot_scatter': ot_scatter, + 'scatter_max': ot_scatter.max().item(), + 'txt_pad': txt_pad, + 'img_pad': img_pad} + + batch = {'input_ids': input_ids, + 'position_ids': position_ids, + 'img_feat': img_feat, + 'img_pos_feat': img_pos_feat, + 'attn_masks': attn_masks, + 'gather_index': gather_index, + 'targets': targets, + 'ot_inputs': ot_inputs} + return batch + + +class ItmRankDataset(DetectFeatTxtTokDataset): + def __init__(self, txt_db, img_db, neg_sample_size=1): + assert neg_sample_size > 0, \ + "ItmRankDataset need at least 1 negative sample" + super().__init__(txt_db, img_db) + + txt2img = self.txt_db.txt2img + self.txt2img = {id_: txt2img[id_] for id_ in self.ids} + # images partitioned by rank + self.img2txts = defaultdict(list) + for id_, img in self.txt2img.items(): + self.img2txts[img].append(id_) + self.img_name_list = list(self.img2txts.keys()) + + assert neg_sample_size > 0 + self.neg_sample_size = neg_sample_size + + def __getitem__(self, i): + gt_txt_id = self.ids[i] + gt_img_fname = self.txt2img[gt_txt_id] + + id_pairs = [(gt_txt_id, gt_img_fname)] + # sample negatives + neg_sample_img_ids = sample_negative( + self.img_name_list, [gt_img_fname], self.neg_sample_size) + neg_sample_txt_ids = sample_negative( + self.ids, 
self.img2txts[gt_img_fname], self.neg_sample_size) + id_pairs.extend([(gt_txt_id, neg_img_id) + for neg_img_id in neg_sample_img_ids] + + [(neg_txt_id, gt_img_fname) + for neg_txt_id in neg_sample_txt_ids]) + inputs = self._collect_inputs(id_pairs) + assert len(inputs) == (1 + 2*self.neg_sample_size) + return inputs + + def _collect_inputs(self, id_pairs): + # create input features + inputs = [] + for txt_id, img_id in id_pairs: + example = self.txt_db[txt_id] + # text input + input_ids = example['input_ids'] + input_ids = self.txt_db.combine_inputs(input_ids) + # img input + img_feat, img_pos_feat, num_bb = self._get_img_feat(img_id) + # mask + attn_masks = torch.ones(len(input_ids) + num_bb, dtype=torch.long) + + inputs.append((input_ids, img_feat, img_pos_feat, attn_masks)) + + return inputs + + +def itm_rank_collate(inputs): + (input_ids, img_feats, img_pos_feats, attn_masks, + ) = map(list, unzip(concat(i for i in inputs))) + + txt_lens = [i.size(0) for i in input_ids] + input_ids = pad_sequence(input_ids, batch_first=True, padding_value=0) + position_ids = torch.arange(0, input_ids.size(1), dtype=torch.long + ).unsqueeze(0) + + num_bbs = [f.size(0) for f in img_feats] + img_feat = pad_tensors(img_feats, num_bbs) + img_pos_feat = pad_tensors(img_pos_feats, num_bbs) + + attn_masks = pad_sequence(attn_masks, batch_first=True, padding_value=0) + sample_size = len(inputs[0]) + assert all(sample_size == len(i) for i in inputs) + + bs, max_tl = input_ids.size() + out_size = attn_masks.size(1) + gather_index = get_gather_index(txt_lens, num_bbs, bs, max_tl, out_size) + + batch = {'input_ids': input_ids, + 'position_ids': position_ids, + 'img_feat': img_feat, + 'img_pos_feat': img_pos_feat, + 'attn_masks': attn_masks, + 'gather_index': gather_index, + 'sample_size': sample_size} + return batch + + +class PNSGDFromText(DetectFeatTxtTokDataset): + def __init__(self, txt_db, img_db, neg_sample_size=1, mlm_sample_size=1): + assert neg_sample_size > 0, "need at least 1 negative sample" + super().__init__(txt_db, img_db) + + txt2img = self.txt_db.txt2img + self.txt2img = {id_: txt2img[id_] for id_ in self.ids} + self.img2txts = self.txt_db.img2txts + self.img_name_list = list(self.img2txts.keys()) + self.neg_sample_size = neg_sample_size + self.mlm_sample_size = mlm_sample_size + + def create_mlm_io(self, input_ids, tree=None, nsample=1): + mlm_input_ids, mlm_txt_labels = [], [] + sample_prob = 0.15 + + mlm_positions = [] + for struct_type in ['relation', 'attribute', 'node']: + struct_nodes = tree.get(struct_type) + for struct_node in struct_nodes: + positions = struct_node.get('ids') + if positions is not None: + mlm_positions.append(positions) + if len(mlm_positions) < 1: + mlm_positions = [[i] for i in range(len(input_ids))] + + # mlm_positions = list(set(mlm_positions)) + position_counter = Counter() + for positions in mlm_positions: + position_counter.update(positions) + position_to_prob = {position: sample_prob / max(freq, 1.0) for position, freq in position_counter.items()} + + # print("| mlm_positions: ", mlm_positions) + for i in range(nsample): + # tokens, vocab_range, mask, mlm_positions, position_to_prob + r_input_ids, txt_labels = random_word2( + copy.copy(input_ids), self.txt_db.v_range, self.txt_db.mask, + mlm_positions=mlm_positions, position_to_prob=position_to_prob) + mlm_input_ids.append(torch.tensor([self.txt_db.cls_] + r_input_ids + [self.txt_db.sep])) + mlm_txt_labels.append(torch.tensor([-1] + txt_labels + [-1])) + mlm_input_ids = torch.stack(mlm_input_ids, dim=0) + 
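+        # each stacked tensor has shape (nsample, len(input_ids) + 2) after the
+        # [CLS]/[SEP] tokens are added; labels stay -1 wherever no token was masked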
mlm_txt_labels = torch.stack(mlm_txt_labels, dim=0) + return mlm_input_ids, mlm_txt_labels + + def __getitem__(self, i): + gt_txt_id = self.ids[i] + gt_img_fname = self.txt2img[gt_txt_id] + + txt_ids = self.txt_db.img2txts[gt_img_fname] + tot_input_ids = [] + for txt_id in txt_ids: + tot_input_ids.extend(self.txt_db[txt_id]['input_ids']) + tot_input_ids = torch.tensor(tot_input_ids) + + input_ids = self.txt_db[gt_txt_id]['input_ids'] + tree = self.txt_db[gt_txt_id]['tree'] + + mlm_input_ids, mlm_txt_labels = self.create_mlm_io(input_ids, tree=tree, nsample=self.mlm_sample_size) + mlm_position_ids = torch.arange(0, mlm_input_ids.size(1), dtype=torch.long).\ + unsqueeze(0).expand(self.mlm_sample_size, -1) + img_feat, img_pos_feat, num_bbs = self._get_img_feat(gt_img_fname) + mlm_img_feat = img_feat.unsqueeze(dim=0).expand(self.mlm_sample_size, *list(img_feat.size())) + mlm_img_pos_feat = img_pos_feat.unsqueeze(dim=0).expand(self.mlm_sample_size, *list(img_pos_feat.size())) + tl = mlm_input_ids.size(1) + mlm_attn_masks = torch.zeros(self.mlm_sample_size, tl+num_bbs).long() + mlm_attn_masks.data[:, :tl+num_bbs].fill_(1) + mlm_gather_index = get_gather_index( + [tl]*self.mlm_sample_size, [num_bbs]*self.mlm_sample_size, self.mlm_sample_size, tl, tl+num_bbs) + + # Process Text for Image and Hard Text Matching + input_ids = self.txt_db.combine_inputs(input_ids) + input_ids = input_ids.unsqueeze(0) + position_ids = torch.arange(0, input_ids.size(1), dtype=torch.long).unsqueeze(0) + + neg_img_ids = sample_negative( + self.img_name_list, [gt_img_fname], self.neg_sample_size) + img_ids = [gt_img_fname] + neg_img_ids + + # Process image features (gt always first) + img_feats, img_pos_feats, num_bbs = map( + list, unzip(map(self._get_img_feat, img_ids))) + img_feat = pad_tensors(img_feats, num_bbs) + img_pos_feat = pad_tensors(img_pos_feats, num_bbs) + + tl = input_ids.size(1) + attn_masks = torch.zeros(len(img_ids), max(num_bbs) + tl).long() + for i, nbb in enumerate(num_bbs): + attn_masks.data[i, :tl+nbb].fill_(1) + out_size = attn_masks.size(1) + gather_index = get_gather_index([tl]*len(img_ids), num_bbs, + len(img_ids), tl, out_size) + + batch = { + 'input_ids': input_ids, + 'position_ids': position_ids, + 'img_feat': img_feat, + 'img_pos_feat': img_pos_feat, + 'attn_masks': attn_masks, + 'gather_index': gather_index, + 'mlm_input_ids': mlm_input_ids, + 'mlm_position_ids': mlm_position_ids, + 'mlm_img_feat': mlm_img_feat, + 'mlm_img_pos_feat': mlm_img_pos_feat, + 'mlm_attn_masks': mlm_attn_masks, + 'mlm_gather_index': mlm_gather_index, + 'mlm_txt_labels': mlm_txt_labels, + 'input_dict': tot_input_ids + } + return batch + + +class PNSGDFromImage(DetectFeatTxtTokDataset): + def __init__(self, txt_db, img_db, neg_sample_size=1): + assert neg_sample_size > 0, "need at least 1 negative sample" + super().__init__(txt_db, img_db) + + txt2img = self.txt_db.txt2img + self.txt2img = {id_: txt2img[id_] for id_ in self.ids} + self.img2txts = self.txt_db.img2txts + self.txt_name_list = list(self.txt2img.keys()) + self.neg_sample_size = neg_sample_size + + def __getitem__(self, i): + gt_txt_id = self.ids[i] + gt_img_id = self.txt2img[gt_txt_id] + gt_txt_ids = self.img2txts[gt_img_id] + + # process image features (gt always first) + img_feat, img_pos_feat, nbb = self._get_img_feat(gt_img_id) + img_feat = img_feat.unsqueeze(0) + img_pos_feat = img_pos_feat.unsqueeze(0) + + # sample negative + neg_txt_ids = sample_negative( + self.txt_name_list, gt_txt_ids, self.neg_sample_size) + txt_ids = [gt_txt_id] + neg_txt_ids 
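+        # txt_ids: the ground-truth caption first, followed by neg_sample_size
+        # captions sampled from other images; all are scored against the gt image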
+ + # process text inputs + all_inputs = [] + txt_lens = [] + for txt_id in txt_ids: + input_ids = self.txt_db.combine_inputs( + self.txt_db[txt_id]['input_ids']) + all_inputs.append(input_ids) + txt_lens.append(len(input_ids)) + input_ids = pad_sequence(all_inputs, batch_first=True, padding_value=0) + position_ids = torch.arange(0, input_ids.size(1), dtype=torch.long + ).unsqueeze(0) + + attn_masks = torch.zeros(len(txt_ids), max(txt_lens) + nbb).long() + for i, tl in enumerate(txt_lens): + attn_masks.data[i, :tl+nbb].fill_(1) + out_size = attn_masks.size(1) + gather_index = get_gather_index(txt_lens, [nbb]*len(txt_ids), + len(txt_ids), tl, out_size) + + batch = {'input_ids': input_ids, + 'position_ids': position_ids, + 'img_feat': img_feat, + 'img_pos_feat': img_pos_feat, + 'attn_masks': attn_masks, + 'gather_index': gather_index} + return batch + + +def pnsgd_collate(inputs): + assert len(inputs) == 1 + return inputs[0] + + +class ItmValDataset(DetectFeatTxtTokDataset): + """ For evaluating Image-Text-Retrieval task """ + def __init__(self, db_dir, img_dir, mini_batch_size=400): + super().__init__(db_dir, img_dir) + del self.lens + self.txt2img = self.txt_db.txt2img + self.img2txts = self.txt_db.img2txts + self.all_img_ids = list(self.img2txts.keys()) + + assert len(self.img2txts) >= mini_batch_size > 0 + self.bs = mini_batch_size + + def _get_batch_ids(self, i): + gt_txt_id = self.ids[i] + gt_img_id = self.txt2img[gt_txt_id] + + # sample fixed negatives for each gt image + i = self.all_img_ids.index(gt_img_id) + neg_st = i+1 + neg_end = neg_st+self.bs-1 + if neg_end > len(self.all_img_ids): + # warp around + neg_end -= len(self.all_img_ids) + neg_img_ids = (self.all_img_ids[neg_st:] + + self.all_img_ids[:neg_end]) + else: + neg_img_ids = self.all_img_ids[neg_st:neg_end] + + assert len(neg_img_ids) == (self.bs - 1),\ + "Did not sample enough neg samples" + + return gt_img_id, neg_img_ids + + def __getitem__(self, i): + """ this returns list of mini-batches """ + gt_img_id, neg_img_ids = self._get_batch_ids(i) + # NOTE 1st one is gt img + batch = self.get_batch(i, [gt_img_id] + neg_img_ids) + return batch + + def get_batch(self, i, img_ids): + example = super().__getitem__(i) + + input_ids = example['input_ids'] + input_ids = self.txt_db.combine_inputs(input_ids) + input_ids = input_ids.unsqueeze(0).expand(len(img_ids), -1).clone() + position_ids = torch.arange(0, input_ids.size(1), dtype=torch.long + ).unsqueeze(0) + + # process image features (gt always first) + img_feats, img_pos_feats, num_bbs = map( + list, unzip(map(self._get_img_feat, img_ids))) + img_feat = pad_tensors(img_feats, num_bbs) + img_pos_feat = pad_tensors(img_pos_feats, num_bbs) + + tl = input_ids.size(1) + attn_masks = torch.zeros(len(img_ids), max(num_bbs) + tl).long() + for i, nbb in enumerate(num_bbs): + attn_masks.data[i, :tl+nbb].fill_(1) + out_size = attn_masks.size(1) + gather_index = get_gather_index([tl]*len(img_ids), num_bbs, + len(img_ids), tl, out_size) + + batch = {'input_ids': input_ids, + 'position_ids': position_ids, + 'img_feat': img_feat, + 'img_pos_feat': img_pos_feat, + 'attn_masks': attn_masks, + 'gather_index': gather_index} + return batch + + +def itm_val_collate(inputs): + assert len(inputs) == 1, "input batch size > 1" + return inputs[0] + + +class ItmEvalDataset(ItmValDataset): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.all_img_ids = sorted(copy.deepcopy(self.all_img_ids), + key=lambda i: self.img_db.name2nbb[i]) + + def __getitem__(self, i): + 
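+        """ returns the list of mini-batches pairing text i with every image in
+        the split (images sorted by #boxes, chunked into groups of self.bs) """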
mini_batches = [] + for st in range(0, len(self.all_img_ids), self.bs): + mini_batches.append( + self.get_batch(i, self.all_img_ids[st:st+self.bs])) + return mini_batches + + +itm_eval_collate = itm_val_collate diff --git a/data/sampler.py b/data/sampler.py new file mode 100644 index 0000000..22863ce --- /dev/null +++ b/data/sampler.py @@ -0,0 +1,58 @@ +""" +Copyright (c) Microsoft Corporation. +Licensed under the MIT license. + +sampler for length bucketing (batch by tokens) +""" +import random + +from torch.utils.data import Sampler +from cytoolz import partition_all + + +class TokenBucketSampler(Sampler): + def __init__(self, lens, bucket_size, batch_size, + droplast=False, size_multiple=8): + self._lens = lens + self._max_tok = batch_size + self._bucket_size = bucket_size + self._droplast = droplast + self._size_mul = size_multiple + + def _create_ids(self): + return list(range(len(self._lens))) + + def _sort_fn(self, i): + return self._lens[i] + + def __iter__(self): + ids = self._create_ids() + random.shuffle(ids) + buckets = [sorted(ids[i:i+self._bucket_size], + key=self._sort_fn, reverse=True) + for i in range(0, len(ids), self._bucket_size)] + # fill batches until max_token (include padding) + batches = [] + for bucket in buckets: + max_len = 0 + batch_indices = [] + for indices in partition_all(self._size_mul, bucket): + max_len = max(max_len, max(self._lens[i] for i in indices)) + if (max_len * (len(batch_indices) + self._size_mul) + > self._max_tok): + if not batch_indices: + raise ValueError( + "max_tokens too small / max_seq_len too long") + assert len(batch_indices) % self._size_mul == 0 + batches.append(batch_indices) + batch_indices = list(indices) + else: + batch_indices.extend(indices) + if not self._droplast and batch_indices: + batches.append(batch_indices) + random.shuffle(batches) + return iter(batches) + + def __len__(self): + raise ValueError("NOT supported. " + "This has some randomness across epochs") diff --git a/inf_nsgd.py b/inf_nsgd.py new file mode 100644 index 0000000..03ad27a --- /dev/null +++ b/inf_nsgd.py @@ -0,0 +1,177 @@ +""" +Copyright (c) Microsoft Corporation. +Licensed under the MIT license. 
+ +run inference for Image Text Retrieval +""" +import argparse +import json +import os +from os.path import exists +import pickle +from time import time + +import torch +from torch.utils.data import DataLoader + +from apex import amp +from horovod import torch as hvd + +from data import (PrefetchLoader, + DetectFeatLmdb, TxtTokLmdb, ItmEvalDataset, itm_eval_collate) +from model.nsgd import UniterForNSGD + +from utils.logger import LOGGER +from utils.distributed import all_gather_list +from utils.misc import Struct +from utils.const import IMG_DIM +from utils.itm_eval import inference, itm_eval + + +def main(opts): + hvd.init() + n_gpu = hvd.size() + device = torch.device("cuda", hvd.local_rank()) + torch.cuda.set_device(hvd.local_rank()) + rank = hvd.rank() + LOGGER.info("device: {} n_gpu: {}, rank: {}, " + "16-bits training: {}".format( + device, n_gpu, hvd.rank(), opts.fp16)) + + if opts.train_config is not None: + train_opts = Struct(json.load(open(opts.train_config))) + opts.conf_th = train_opts.conf_th + opts.max_bb = train_opts.max_bb + opts.min_bb = train_opts.min_bb + opts.num_bb = train_opts.num_bb + + # load DBs and image dirs + eval_img_db = DetectFeatLmdb(opts.img_db, + opts.conf_th, opts.max_bb, + opts.min_bb, opts.num_bb, + opts.compressed_db) + eval_txt_db = TxtTokLmdb(opts.txt_db, -1) + eval_dataset = ItmEvalDataset(eval_txt_db, eval_img_db, opts.batch_size) + + # Prepare model + checkpoint = torch.load(opts.checkpoint) + model = UniterForNSGD.from_pretrained( + opts.model_config, checkpoint, img_dim=IMG_DIM) + if 'rank_output' not in checkpoint: + model.init_output() # zero shot setting + + model.to(device) + model = amp.initialize(model, enabled=opts.fp16, opt_level='O2') + + print('| eval_dataset id2len: ', len(eval_dataset.id2len)) + print('| eval_dataset_id2len example key: ', list(eval_dataset.id2len.keys())[:5], + max([int(k) for k in list(eval_dataset.id2len.keys())])) + print('| eval_dataset_id2len example value: ', list(eval_dataset.id2len.values())[:5]) + # for k in range(10): + # print('| example:', k, eval_dataset.i2len[k]) + print('| i2len: ', len(eval_dataset.id2len), min(list(eval_dataset.id2len.keys())), max(eval_dataset.id2len.keys()), + min(list(eval_dataset.id2len.values())), max(eval_dataset.id2len.values())) + + print('| mean of all_txt_lens:', sum(list(eval_dataset.id2len.values())) / float(len(list(eval_dataset.id2len.values())))) + + + eval_dataloader = DataLoader(eval_dataset, batch_size=1, + num_workers=opts.n_workers, + pin_memory=opts.pin_mem, + collate_fn=itm_eval_collate) + eval_dataloader = PrefetchLoader(eval_dataloader) + + eval_log, results = evaluate(model, eval_dataloader) + if hvd.rank() == 0: + if not exists(opts.output_dir) and rank == 0: + os.makedirs(opts.output_dir) + with open(f'{opts.output_dir}/config.json', 'w') as f: + json.dump(vars(opts), f) + with open(f'{opts.output_dir}/results.bin', 'wb') as f: + pickle.dump(results, f) + with open(f'{opts.output_dir}/scores.json', 'w') as f: + json.dump(eval_log, f) + LOGGER.info(f'evaluation finished') + LOGGER.info( + f"======================== Results =========================\n" + f"image retrieval R1: {eval_log['img_r1']*100:.2f},\n" + f"image retrieval R5: {eval_log['img_r5']*100:.2f},\n" + f"image retrieval R10: {eval_log['img_r10']*100:.2f}\n" + f"text retrieval R1: {eval_log['txt_r1']*100:.2f},\n" + f"text retrieval R5: {eval_log['txt_r5']*100:.2f},\n" + f"text retrieval R10: {eval_log['txt_r10']*100:.2f}") + 
LOGGER.info("========================================================") + + +@torch.no_grad() +def evaluate(model, eval_loader): + model.eval() + st = time() + LOGGER.info("start running Image/Text Retrieval evaluation ...") + score_matrix = inference(model, eval_loader) + dset = eval_loader.dataset + all_score = hvd.allgather(score_matrix) + all_txt_ids = [i for ids in all_gather_list(dset.ids) + for i in ids] + # all_txt_lens = [l for lens in all_gather_list(dset.txt_lens) for l in lens] + print('| mean of all_txt_lens:', sum(list(dset.id2len.values())) / float(len(list(dset.id2len.values())))) + all_img_ids = dset.all_img_ids + assert all_score.size() == (len(all_txt_ids), len(all_img_ids)) + if hvd.rank() != 0: + return {}, tuple() + # NOTE: only use rank0 to compute final scores + eval_log = itm_eval(all_score, all_txt_ids, all_img_ids, + dset.txt2img, dset.img2txts, dset.id2len) + + results = (all_score, all_txt_ids, all_img_ids) + tot_time = time()-st + LOGGER.info(f"evaluation finished in {int(tot_time)} seconds, ") + return eval_log, results + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + + # Required parameters + parser.add_argument("--txt_db", default=None, type=str, + help="The input train corpus. (LMDB)") + parser.add_argument("--img_db", default=None, type=str, + help="The input train images.") + parser.add_argument("--checkpoint", default=None, type=str, + help="model checkpoint binary") + parser.add_argument("--model_config", default=None, type=str, + help="model config json") + parser.add_argument( + "--output_dir", default=None, type=str, + help="The output directory where the inference results will be " + "written.") + + # optional parameters + parser.add_argument("--train_config", default=None, type=str, + help="hps.json from training (for prepro hps)") + parser.add_argument('--compressed_db', action='store_true', + help='use compressed LMDB') + parser.add_argument('--conf_th', type=float, default=0.2, + help='threshold for dynamic bounding boxes ' + '(-1 for fixed)') + parser.add_argument('--max_bb', type=int, default=100, + help='max number of bounding boxes') + parser.add_argument('--min_bb', type=int, default=10, + help='min number of bounding boxes') + parser.add_argument('--num_bb', type=int, default=36, + help='static number of bounding boxes') + parser.add_argument("--batch_size", default=400, type=int, + help="number of tokens in a batch") + + # device parameters + parser.add_argument('--fp16', action='store_true', + help="Whether to use 16-bit float precision instead " + "of 32-bit") + parser.add_argument('--n_workers', type=int, default=4, + help="number of data workers") + parser.add_argument('--pin_mem', action='store_true', + help="pin memory") + + args = parser.parse_args() + + main(args) diff --git a/launch_container.sh b/launch_container.sh new file mode 100644 index 0000000..45ceefb --- /dev/null +++ b/launch_container.sh @@ -0,0 +1,21 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. 
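+
+# Usage: source launch_container.sh TXT_DB IMG_DIR OUTPUT PRETRAIN_DIR
+# The four paths are bind-mounted into the container as /txt, /img, /storage
+# and /pretrain; set CUDA_VISIBLE_DEVICES beforehand to limit the visible GPUs.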
+ +TXT_DB=$1 +IMG_DIR=$2 +OUTPUT=$3 +PRETRAIN_DIR=$4 + +if [ -z $CUDA_VISIBLE_DEVICES ]; then + CUDA_VISIBLE_DEVICES='all' +fi + + +docker run --gpus '"'device=$CUDA_VISIBLE_DEVICES'"' --ipc=host --rm -it \ + --mount src=$(pwd),dst=/src,type=bind \ + --mount src=$OUTPUT,dst=/storage,type=bind \ + --mount src=$PRETRAIN_DIR,dst=/pretrain,type=bind,readonly \ + --mount src=$TXT_DB,dst=/txt,type=bind,readonly \ + --mount src=$IMG_DIR,dst=/img,type=bind,readonly \ + -e NVIDIA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES \ + -w /src chenrocks/uniter \ No newline at end of file diff --git a/model/attention.py b/model/attention.py new file mode 100644 index 0000000..7d7e2b0 --- /dev/null +++ b/model/attention.py @@ -0,0 +1,402 @@ +""" +copy multi-head attention code from pytorch +(https://github.com/pytorch/pytorch), +""" +import warnings + +import torch +from torch.nn import Module, Parameter, Linear +from torch.nn.init import xavier_normal_, xavier_uniform_, constant_ +from torch.nn.functional import linear, softmax, dropout + + +def multi_head_attention_forward(query, # type: Tensor + key, # type: Tensor + value, # type: Tensor + embed_dim_to_check, # type: int + num_heads, # type: int + in_proj_weight, # type: Tensor + in_proj_bias, # type: Tensor + bias_k, # type: Optional[Tensor] + bias_v, # type: Optional[Tensor] + add_zero_attn, # type: bool + dropout_p, # type: float + out_proj_weight, # type: Tensor + out_proj_bias, # type: Tensor + training=True, # type: bool + key_padding_mask=None, # type: Optional[Tensor] + need_weights=True, # type: bool + attn_mask=None, # type: Optional[Tensor] + use_separate_proj_weight=False, # type: bool + q_proj_weight=None, # type: Optional[Tensor] + k_proj_weight=None, # type: Optional[Tensor] + v_proj_weight=None, # type: Optional[Tensor] + static_k=None, # type: Optional[Tensor] + static_v=None # type: Optional[Tensor] + ): + # type: (...) -> Tuple[Tensor, Optional[Tensor]] + r""" + Args: + query, key, value: map a query and a set of key-value pairs to an output. + See "Attention Is All You Need" for more details. + embed_dim_to_check: total dimension of the model. + num_heads: parallel attention heads. + in_proj_weight, in_proj_bias: input projection weight and bias. + bias_k, bias_v: bias of the key and value sequences to be added at dim=0. + add_zero_attn: add a new batch of zeros to the key and + value sequences at dim=1. + dropout_p: probability of an element to be zeroed. + out_proj_weight, out_proj_bias: the output projection weight and bias. + training: apply dropout if is ``True``. + key_padding_mask: if provided, specified padding elements in the key will + be ignored by the attention. This is an binary mask. When the value is True, + the corresponding value on the attention layer will be filled with -inf. + need_weights: output attn_output_weights. + attn_mask: mask that prevents attention to certain positions. This is an additive mask + (i.e. the values will be added to the attention layer). + use_separate_proj_weight: the function accept the proj. weights for query, key, + and value in differnt forms. If false, in_proj_weight will be used, which is + a combination of q_proj_weight, k_proj_weight, v_proj_weight. + q_proj_weight, k_proj_weight, v_proj_weight, in_proj_bias: input projection weight and bias. + static_k, static_v: static key and value used for attention operators. + + + Shape: + Inputs: + - query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is + the embedding dimension. 
+ - key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is + the embedding dimension. + - value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is + the embedding dimension. + - key_padding_mask: :math:`(N, S)`, ByteTensor, where N is the batch size, S is the source sequence length. + - attn_mask: :math:`(L, S)` where L is the target sequence length, S is the source sequence length. + - static_k: :math:`(N*num_heads, S, E/num_heads)`, where S is the source sequence length, + N is the batch size, E is the embedding dimension. E/num_heads is the head dimension. + - static_v: :math:`(N*num_heads, S, E/num_heads)`, where S is the source sequence length, + N is the batch size, E is the embedding dimension. E/num_heads is the head dimension. + + Outputs: + - attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, + E is the embedding dimension. + - attn_output_weights: :math:`(N, L, S)` where N is the batch size, + L is the target sequence length, S is the source sequence length. + """ + + qkv_same = torch.equal(query, key) and torch.equal(key, value) + kv_same = torch.equal(key, value) + + tgt_len, bsz, embed_dim = query.size() + assert embed_dim == embed_dim_to_check + assert list(query.size()) == [tgt_len, bsz, embed_dim] + assert key.size() == value.size() + + head_dim = embed_dim // num_heads + assert head_dim * num_heads == embed_dim, "embed_dim must be divisible by num_heads" + scaling = float(head_dim) ** -0.5 + + if use_separate_proj_weight is not True: + if qkv_same: + # self-attention + q, k, v = linear(query, in_proj_weight, in_proj_bias).chunk(3, dim=-1) + + elif kv_same: + # encoder-decoder attention + # This is inline in_proj function with in_proj_weight and in_proj_bias + _b = in_proj_bias + _start = 0 + _end = embed_dim + _w = in_proj_weight[_start:_end, :] + if _b is not None: + _b = _b[_start:_end] + q = linear(query, _w, _b) + + if key is None: + assert value is None + k = None + v = None + else: + + # This is inline in_proj function with in_proj_weight and in_proj_bias + _b = in_proj_bias + _start = embed_dim + _end = None + _w = in_proj_weight[_start:, :] + if _b is not None: + _b = _b[_start:] + k, v = linear(key, _w, _b).chunk(2, dim=-1) + + else: + # This is inline in_proj function with in_proj_weight and in_proj_bias + _b = in_proj_bias + _start = 0 + _end = embed_dim + _w = in_proj_weight[_start:_end, :] + if _b is not None: + _b = _b[_start:_end] + q = linear(query, _w, _b) + + # This is inline in_proj function with in_proj_weight and in_proj_bias + _b = in_proj_bias + _start = embed_dim + _end = embed_dim * 2 + _w = in_proj_weight[_start:_end, :] + if _b is not None: + _b = _b[_start:_end] + k = linear(key, _w, _b) + + # This is inline in_proj function with in_proj_weight and in_proj_bias + _b = in_proj_bias + _start = embed_dim * 2 + _end = None + _w = in_proj_weight[_start:, :] + if _b is not None: + _b = _b[_start:] + v = linear(value, _w, _b) + else: + q_proj_weight_non_opt = torch.jit._unwrap_optional(q_proj_weight) + len1, len2 = q_proj_weight_non_opt.size() + assert len1 == embed_dim and len2 == query.size(-1) + + k_proj_weight_non_opt = torch.jit._unwrap_optional(k_proj_weight) + len1, len2 = k_proj_weight_non_opt.size() + assert len1 == embed_dim and len2 == key.size(-1) + + v_proj_weight_non_opt = torch.jit._unwrap_optional(v_proj_weight) + len1, len2 = v_proj_weight_non_opt.size() + assert len1 == embed_dim and len2 == value.size(-1) + + if in_proj_bias is 
not None: + q = linear(query, q_proj_weight_non_opt, in_proj_bias[0:embed_dim]) + k = linear(key, k_proj_weight_non_opt, in_proj_bias[embed_dim:(embed_dim * 2)]) + v = linear(value, v_proj_weight_non_opt, in_proj_bias[(embed_dim * 2):]) + else: + q = linear(query, q_proj_weight_non_opt, in_proj_bias) + k = linear(key, k_proj_weight_non_opt, in_proj_bias) + v = linear(value, v_proj_weight_non_opt, in_proj_bias) + q = q * scaling + + if bias_k is not None and bias_v is not None: + if static_k is None and static_v is None: + k = torch.cat([k, bias_k.repeat(1, bsz, 1)]) + v = torch.cat([v, bias_v.repeat(1, bsz, 1)]) + if attn_mask is not None: + attn_mask = torch.cat([attn_mask, + torch.zeros((attn_mask.size(0), 1), + dtype=attn_mask.dtype, + device=attn_mask.device)], dim=1) + if key_padding_mask is not None: + key_padding_mask = torch.cat( + [key_padding_mask, torch.zeros((key_padding_mask.size(0), 1), + dtype=key_padding_mask.dtype, + device=key_padding_mask.device)], dim=1) + else: + assert static_k is None, "bias cannot be added to static key." + assert static_v is None, "bias cannot be added to static value." + else: + assert bias_k is None + assert bias_v is None + + q = q.contiguous().view(tgt_len, bsz * num_heads, head_dim).transpose(0, 1) + if k is not None: + k = k.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1) + if v is not None: + v = v.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1) + + if static_k is not None: + assert static_k.size(0) == bsz * num_heads + assert static_k.size(2) == head_dim + k = static_k + + if static_v is not None: + assert static_v.size(0) == bsz * num_heads + assert static_v.size(2) == head_dim + v = static_v + + src_len = k.size(1) + + if key_padding_mask is not None: + assert key_padding_mask.size(0) == bsz + assert key_padding_mask.size(1) == src_len + + if add_zero_attn: + src_len += 1 + k = torch.cat([k, torch.zeros((k.size(0), 1) + k.size()[2:], dtype=k.dtype, device=k.device)], dim=1) + v = torch.cat([v, torch.zeros((v.size(0), 1) + v.size()[2:], dtype=v.dtype, device=v.device)], dim=1) + if attn_mask is not None: + attn_mask = torch.cat([attn_mask, torch.zeros((attn_mask.size(0), 1), + dtype=attn_mask.dtype, + device=attn_mask.device)], dim=1) + if key_padding_mask is not None: + key_padding_mask = torch.cat( + [key_padding_mask, torch.zeros((key_padding_mask.size(0), 1), + dtype=key_padding_mask.dtype, + device=key_padding_mask.device)], dim=1) + + attn_output_weights = torch.bmm(q, k.transpose(1, 2)) + assert list(attn_output_weights.size()) == [bsz * num_heads, tgt_len, src_len] + + if attn_mask is not None: + attn_mask = attn_mask.unsqueeze(0) + attn_output_weights += attn_mask + + if key_padding_mask is not None: + attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len) + attn_output_weights = attn_output_weights.masked_fill( + key_padding_mask.unsqueeze(1).unsqueeze(2), + float('-inf'), + ) + attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, src_len) + + attn_output_weights = softmax( + attn_output_weights, dim=-1) + attn_output_weights = dropout(attn_output_weights, p=dropout_p, training=training) + + attn_output = torch.bmm(attn_output_weights, v) + assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim] + attn_output = attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim) + attn_output = linear(attn_output, out_proj_weight, out_proj_bias) + + if need_weights: + # average attention weights over heads + attn_output_weights = 
attn_output_weights.view(bsz, num_heads, tgt_len, src_len) + return attn_output, attn_output_weights.sum(dim=1) / num_heads + else: + return attn_output, None + + +class MultiheadAttention(Module): + r"""Allows the model to jointly attend to information + from different representation subspaces. + See reference: Attention Is All You Need + + .. math:: + \text{MultiHead}(Q, K, V) = \text{Concat}(head_1,\dots,head_h)W^O + \text{where} head_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V) + + Args: + embed_dim: total dimension of the model. + num_heads: parallel attention heads. + dropout: a Dropout layer on attn_output_weights. Default: 0.0. + bias: add bias as module parameter. Default: True. + add_bias_kv: add bias to the key and value sequences at dim=0. + add_zero_attn: add a new batch of zeros to the key and + value sequences at dim=1. + kdim: total number of features in key. Default: None. + vdim: total number of features in key. Default: None. + + Note: if kdim and vdim are None, they will be set to embed_dim such that + query, key, and value have the same number of features. + + Examples:: + + >>> multihead_attn = nn.MultiheadAttention(embed_dim, num_heads) + >>> attn_output, attn_output_weights = multihead_attn(query, key, value) + """ + + def __init__(self, embed_dim, num_heads, dropout=0., bias=True, add_bias_kv=False, add_zero_attn=False, kdim=None, vdim=None): + super(MultiheadAttention, self).__init__() + self.embed_dim = embed_dim + self.kdim = kdim if kdim is not None else embed_dim + self.vdim = vdim if vdim is not None else embed_dim + self._qkv_same_embed_dim = self.kdim == embed_dim and self.vdim == embed_dim + + self.num_heads = num_heads + self.dropout = dropout + self.head_dim = embed_dim // num_heads + assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads" + + self.in_proj_weight = Parameter(torch.empty(3 * embed_dim, embed_dim)) + + if self._qkv_same_embed_dim is False: + self.q_proj_weight = Parameter(torch.Tensor(embed_dim, embed_dim)) + self.k_proj_weight = Parameter(torch.Tensor(embed_dim, self.kdim)) + self.v_proj_weight = Parameter(torch.Tensor(embed_dim, self.vdim)) + + if bias: + self.in_proj_bias = Parameter(torch.empty(3 * embed_dim)) + else: + self.register_parameter('in_proj_bias', None) + self.out_proj = Linear(embed_dim, embed_dim, bias=bias) + + if add_bias_kv: + self.bias_k = Parameter(torch.empty(1, 1, embed_dim)) + self.bias_v = Parameter(torch.empty(1, 1, embed_dim)) + else: + self.bias_k = self.bias_v = None + + self.add_zero_attn = add_zero_attn + + self._reset_parameters() + + def _reset_parameters(self): + if self._qkv_same_embed_dim: + xavier_uniform_(self.in_proj_weight) + else: + xavier_uniform_(self.q_proj_weight) + xavier_uniform_(self.k_proj_weight) + xavier_uniform_(self.v_proj_weight) + + if self.in_proj_bias is not None: + constant_(self.in_proj_bias, 0.) + constant_(self.out_proj.bias, 0.) + if self.bias_k is not None: + xavier_normal_(self.bias_k) + if self.bias_v is not None: + xavier_normal_(self.bias_v) + + def forward(self, query, key, value, key_padding_mask=None, + need_weights=True, attn_mask=None): + r""" + Args: + query, key, value: map a query and a set of key-value pairs to an output. + See "Attention Is All You Need" for more details. + key_padding_mask: if provided, specified padding elements in the key will + be ignored by the attention. This is an binary mask. When the value is True, + the corresponding value on the attention layer will be filled with -inf. 
+ need_weights: output attn_output_weights. + attn_mask: mask that prevents attention to certain positions. This is an additive mask + (i.e. the values will be added to the attention layer). + + Shape: + - Inputs: + - query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is + the embedding dimension. + - key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is + the embedding dimension. + - value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is + the embedding dimension. + - key_padding_mask: :math:`(N, S)`, ByteTensor, where N is the batch size, S is the source sequence length. + - attn_mask: :math:`(L, S)` where L is the target sequence length, S is the source sequence length. + + - Outputs: + - attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, + E is the embedding dimension. + - attn_output_weights: :math:`(N, L, S)` where N is the batch size, + L is the target sequence length, S is the source sequence length. + """ + if hasattr(self, '_qkv_same_embed_dim') and self._qkv_same_embed_dim is False: + return multi_head_attention_forward( + query, key, value, self.embed_dim, self.num_heads, + self.in_proj_weight, self.in_proj_bias, + self.bias_k, self.bias_v, self.add_zero_attn, + self.dropout, self.out_proj.weight, self.out_proj.bias, + training=self.training, + key_padding_mask=key_padding_mask, need_weights=need_weights, + attn_mask=attn_mask, use_separate_proj_weight=True, + q_proj_weight=self.q_proj_weight, k_proj_weight=self.k_proj_weight, + v_proj_weight=self.v_proj_weight) + else: + if not hasattr(self, '_qkv_same_embed_dim'): + warnings.warn('A new version of MultiheadAttention module has been implemented. \ + Please re-train your model with the new module', + UserWarning) + + return multi_head_attention_forward( + query, key, value, self.embed_dim, self.num_heads, + self.in_proj_weight, self.in_proj_bias, + self.bias_k, self.bias_v, self.add_zero_attn, + self.dropout, self.out_proj.weight, self.out_proj.bias, + training=self.training, + key_padding_mask=key_padding_mask, need_weights=need_weights, + attn_mask=attn_mask) diff --git a/model/itm.py b/model/itm.py new file mode 100644 index 0000000..a7aefd0 --- /dev/null +++ b/model/itm.py @@ -0,0 +1,140 @@ +""" +Copyright (c) Microsoft Corporation. +Licensed under the MIT license. 
+ +UNITER for ITM model +""" +from collections import defaultdict + +import torch +from torch import nn +from .model import UniterPreTrainedModel, UniterModel + + +class UniterForImageTextRetrieval(UniterPreTrainedModel): + """ Finetune UNITER for image text retrieval + """ + def __init__(self, config, img_dim, margin=0.2): + super().__init__(config) + self.uniter = UniterModel(config, img_dim) + self.itm_output = nn.Linear(config.hidden_size, 2) + self.rank_output = nn.Linear(config.hidden_size, 1) + self.margin = margin + self.apply(self.init_weights) + + def init_output(self): + """ need to be called after from pretrained """ + self.rank_output.weight.data = self.itm_output.weight.data[1:, :] + self.rank_output.bias.data = self.itm_output.bias.data[1:] + + def forward(self, batch, compute_loss=True): + batch = defaultdict(lambda: None, batch) + input_ids = batch['input_ids'] + position_ids = batch['position_ids'] + img_feat = batch['img_feat'] + img_pos_feat = batch['img_pos_feat'] + attention_mask = batch['attn_masks'] + gather_index = batch['gather_index'] + sequence_output = self.uniter(input_ids, position_ids, + img_feat, img_pos_feat, + attention_mask, gather_index, + output_all_encoded_layers=False) + pooled_output = self.uniter.pooler(sequence_output) + rank_scores = self.rank_output(pooled_output) + + if compute_loss: + # triplet loss + rank_scores_sigmoid = torch.sigmoid(rank_scores) + sample_size = batch['sample_size'] + scores = rank_scores_sigmoid.contiguous().view(-1, sample_size) + pos = scores[:, :1] + neg = scores[:, 1:] + rank_loss = torch.clamp(self.margin + neg - pos, 0) + # print('| success ratio: ', neg.lt(pos).float().sum().div(neg.numel())) + return rank_loss + else: + return rank_scores + + +class UniterForImageTextRetrievalHardNeg(UniterForImageTextRetrieval): + """ Finetune UNITER for image text retrieval + """ + def __init__(self, config, img_dim, margin=0.2, hard_size=16): + super().__init__(config, img_dim, margin) + self.hard_size = hard_size + + def forward(self, batch, sample_from='t', compute_loss=True): + # expect same input_ids for all pairs + batch_size = batch['attn_masks'].size(0) + input_ids = batch['input_ids'] + img_feat = batch['img_feat'] + img_pos_feat = batch['img_pos_feat'] + if sample_from == 't': + if input_ids.size(0) == 1: + batch['input_ids'] = input_ids.expand(batch_size, -1) + elif sample_from == 'i': + if img_feat.size(0) == 1: + batch['img_feat'] = img_feat.expand(batch_size, -1, -1) + if img_pos_feat.size(0) == 1: + batch['img_pos_feat'] = img_pos_feat.expand(batch_size, -1, -1) + else: + raise ValueError() + + if self.training and compute_loss: + with torch.no_grad(): + self.eval() + scores = super().forward(batch, compute_loss=False) + hard_batch = self._get_hard_batch(batch, scores, sample_from) + self.train() + return super().forward(hard_batch, compute_loss=True) + else: + return super().forward(batch, compute_loss) + + def _get_hard_batch(self, batch, scores, sample_from='t'): + batch = defaultdict(lambda: None, batch) + input_ids = batch['input_ids'] + position_ids = batch['position_ids'] + img_feat = batch['img_feat'] + img_pos_feat = batch['img_pos_feat'] + attention_mask = batch['attn_masks'] + gather_index = batch['gather_index'] + hard_batch = {'sample_size': self.hard_size + 1} + + # NOTE first example is positive + hard_indices = scores.squeeze(-1)[1:].topk( + self.hard_size, sorted=False)[1] + 1 + indices = torch.cat([torch.zeros(1, dtype=torch.long, + device=hard_indices.device), + hard_indices]) + + attention_mask 
= attention_mask.index_select(0, indices) + gather_index = gather_index.index_select(0, indices) + if position_ids.size(0) != 1: + position_ids = position_ids[:self.hard_size+1] + + if sample_from == 't': + # cut to minimum padding + max_len = attention_mask.sum(dim=1).max().item() + max_i = max_len - input_ids.size(1) + attention_mask = attention_mask[:, :max_len] + gather_index = gather_index[:, :max_len] + img_feat = img_feat.index_select(0, indices)[:, :max_i, :] + img_pos_feat = img_pos_feat.index_select(0, indices)[:, :max_i, :] + # expect same input_ids for all pairs + input_ids = input_ids[:self.hard_size+1] + elif sample_from == 'i': + input_ids = input_ids.index_select(0, indices) + # expect same image features for all pairs + img_feat = img_feat[:self.hard_size+1] + img_pos_feat = img_pos_feat[:self.hard_size+1] + else: + raise ValueError() + + hard_batch['input_ids'] = input_ids + hard_batch['position_ids'] = position_ids + hard_batch['img_feat'] = img_feat + hard_batch['img_pos_feat'] = img_pos_feat + hard_batch['attn_masks'] = attention_mask + hard_batch['gather_index'] = gather_index + + return hard_batch diff --git a/model/layer.py b/model/layer.py new file mode 100644 index 0000000..ed0203f --- /dev/null +++ b/model/layer.py @@ -0,0 +1,233 @@ +""" +BERT layers from the huggingface implementation +(https://github.com/huggingface/transformers) +""" +# coding=utf-8 +# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. +# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging +import math + +import torch +from torch import nn +from apex.normalization.fused_layer_norm import FusedLayerNorm as BertLayerNorm + + +logger = logging.getLogger(__name__) + + +def gelu(x): + """Implementation of the gelu activation function. 
+ For information: OpenAI GPT's gelu is slightly different (and gives slightly different results): + 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))) + Also see https://arxiv.org/abs/1606.08415 + """ + return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0))) + + +def swish(x): + return x * torch.sigmoid(x) + + +ACT2FN = {"gelu": gelu, "relu": torch.nn.functional.relu, "swish": swish} + + +class GELU(nn.Module): + def forward(self, input_): + output = gelu(input_) + return output + + +class BertSelfAttention(nn.Module): + def __init__(self, config): + super(BertSelfAttention, self).__init__() + if config.hidden_size % config.num_attention_heads != 0: + raise ValueError( + "The hidden size (%d) is not a multiple of the number of attention " + "heads (%d)" % (config.hidden_size, config.num_attention_heads)) + self.num_attention_heads = config.num_attention_heads + self.attention_head_size = int(config.hidden_size / config.num_attention_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + + self.query = nn.Linear(config.hidden_size, self.all_head_size) + self.key = nn.Linear(config.hidden_size, self.all_head_size) + self.value = nn.Linear(config.hidden_size, self.all_head_size) + + self.dropout = nn.Dropout(config.attention_probs_dropout_prob) + + def transpose_for_scores(self, x): + new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) + x = x.view(*new_x_shape) + return x.permute(0, 2, 1, 3) + + def forward(self, hidden_states, attention_mask): + mixed_query_layer = self.query(hidden_states) + mixed_key_layer = self.key(hidden_states) + mixed_value_layer = self.value(hidden_states) + + query_layer = self.transpose_for_scores(mixed_query_layer) + key_layer = self.transpose_for_scores(mixed_key_layer) + value_layer = self.transpose_for_scores(mixed_value_layer) + + # Take the dot product between "query" and "key" to get the raw attention scores. + attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) + attention_scores = attention_scores / math.sqrt(self.attention_head_size) + # Apply the attention mask is (precomputed for all layers in BertModel forward() function) + attention_scores = attention_scores + attention_mask + + # Normalize the attention scores to probabilities. + attention_probs = nn.Softmax(dim=-1)(attention_scores) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. 
+ attention_probs = self.dropout(attention_probs) + + context_layer = torch.matmul(attention_probs, value_layer) + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) + context_layer = context_layer.view(*new_context_layer_shape) + return context_layer + + +class BertSelfOutput(nn.Module): + def __init__(self, config): + super(BertSelfOutput, self).__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states, input_tensor): + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +class BertAttention(nn.Module): + def __init__(self, config): + super(BertAttention, self).__init__() + self.self = BertSelfAttention(config) + self.output = BertSelfOutput(config) + + def forward(self, input_tensor, attention_mask): + self_output = self.self(input_tensor, attention_mask) + attention_output = self.output(self_output, input_tensor) + return attention_output + + +class BertIntermediate(nn.Module): + def __init__(self, config): + super(BertIntermediate, self).__init__() + self.dense = nn.Linear(config.hidden_size, config.intermediate_size) + if isinstance(config.hidden_act, str): + self.intermediate_act_fn = ACT2FN[config.hidden_act] + else: + self.intermediate_act_fn = config.hidden_act + + def forward(self, hidden_states): + hidden_states = self.dense(hidden_states) + hidden_states = self.intermediate_act_fn(hidden_states) + return hidden_states + + +class BertOutput(nn.Module): + def __init__(self, config): + super(BertOutput, self).__init__() + self.dense = nn.Linear(config.intermediate_size, config.hidden_size) + self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states, input_tensor): + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +class BertLayer(nn.Module): + def __init__(self, config): + super(BertLayer, self).__init__() + self.attention = BertAttention(config) + self.intermediate = BertIntermediate(config) + self.output = BertOutput(config) + + def forward(self, hidden_states, attention_mask): + attention_output = self.attention(hidden_states, attention_mask) + intermediate_output = self.intermediate(attention_output) + layer_output = self.output(intermediate_output, attention_output) + return layer_output + + +class BertPooler(nn.Module): + def __init__(self, config): + super(BertPooler, self).__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.activation = nn.Tanh() + + def forward(self, hidden_states): + # We "pool" the model by simply taking the hidden state corresponding + # to the first token. 
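+        # (position 0 of the joint text+image sequence corresponds to the text
+        #  [CLS] token in this model)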
+ first_token_tensor = hidden_states[:, 0] + pooled_output = self.dense(first_token_tensor) + pooled_output = self.activation(pooled_output) + return pooled_output + + +class BertPredictionHeadTransform(nn.Module): + def __init__(self, config): + super(BertPredictionHeadTransform, self).__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + if isinstance(config.hidden_act, str): + self.transform_act_fn = ACT2FN[config.hidden_act] + else: + self.transform_act_fn = config.hidden_act + self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12) + + def forward(self, hidden_states): + hidden_states = self.dense(hidden_states) + hidden_states = self.transform_act_fn(hidden_states) + hidden_states = self.LayerNorm(hidden_states) + return hidden_states + + +class BertLMPredictionHead(nn.Module): + def __init__(self, config, bert_model_embedding_weights): + super(BertLMPredictionHead, self).__init__() + self.transform = BertPredictionHeadTransform(config) + + # The output weights are the same as the input embeddings, but there is + # an output-only bias for each token. + self.decoder = nn.Linear(bert_model_embedding_weights.size(1), + bert_model_embedding_weights.size(0), + bias=False) + self.decoder.weight = bert_model_embedding_weights + self.bias = nn.Parameter( + torch.zeros(bert_model_embedding_weights.size(0))) + + def forward(self, hidden_states): + hidden_states = self.transform(hidden_states) + hidden_states = self.decoder(hidden_states) + self.bias + return hidden_states + + +class BertOnlyMLMHead(nn.Module): + def __init__(self, config, bert_model_embedding_weights): + super(BertOnlyMLMHead, self).__init__() + self.predictions = BertLMPredictionHead(config, + bert_model_embedding_weights) + + def forward(self, sequence_output): + prediction_scores = self.predictions(sequence_output) + return prediction_scores diff --git a/model/model.py b/model/model.py new file mode 100644 index 0000000..7b2e6a9 --- /dev/null +++ b/model/model.py @@ -0,0 +1,367 @@ +""" +Copyright (c) Microsoft Corporation. +Licensed under the MIT license. + +Pytorch modules +some classes are modified from HuggingFace +(https://github.com/huggingface/transformers) +""" +import copy +import json +import logging +from io import open + +import torch +from torch import nn +from apex.normalization.fused_layer_norm import FusedLayerNorm + +from .layer import BertLayer, BertPooler + + +logger = logging.getLogger(__name__) + + +class UniterConfig(object): + """Configuration class to store the configuration of a `UniterModel`. + """ + def __init__(self, + vocab_size_or_config_json_file, + hidden_size=768, + num_hidden_layers=12, + num_attention_heads=12, + intermediate_size=3072, + hidden_act="gelu", + hidden_dropout_prob=0.1, + attention_probs_dropout_prob=0.1, + max_position_embeddings=512, + type_vocab_size=2, + initializer_range=0.02): + """Constructs UniterConfig. + Args: + vocab_size_or_config_json_file: Vocabulary size of `inputs_ids` in + `UniterModel`. + hidden_size: Size of the encoder layers and the pooler layer. + num_hidden_layers: Number of hidden layers in the Transformer + encoder. + num_attention_heads: Number of attention heads for each attention + layer in the Transformer encoder. + intermediate_size: The size of the "intermediate" (i.e. + feed-forward) layer in the Transformer encoder. + hidden_act: The non-linear activation function (function or string) + in the encoder and pooler. If string, "gelu", "relu" and + "swish" are supported. 
+ hidden_dropout_prob: The dropout probabilitiy for all fully + connected layers in the embeddings, encoder, and pooler. + attention_probs_dropout_prob: The dropout ratio for the attention + probabilities. + max_position_embeddings: The maximum sequence length that this + model might ever be used with. Typically set this to something + large just in case (e.g., 512 or 1024 or 2048). + type_vocab_size: The vocabulary size of the `token_type_ids` passed + into `UniterModel`. + initializer_range: The sttdev of the truncated_normal_initializer + for initializing all weight matrices. + """ + if isinstance(vocab_size_or_config_json_file, str): + with open(vocab_size_or_config_json_file, + "r", encoding='utf-8') as reader: + json_config = json.loads(reader.read()) + for key, value in json_config.items(): + self.__dict__[key] = value + elif isinstance(vocab_size_or_config_json_file, int): + self.vocab_size = vocab_size_or_config_json_file + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.hidden_act = hidden_act + self.intermediate_size = intermediate_size + self.hidden_dropout_prob = hidden_dropout_prob + self.attention_probs_dropout_prob = attention_probs_dropout_prob + self.max_position_embeddings = max_position_embeddings + self.type_vocab_size = type_vocab_size + self.initializer_range = initializer_range + else: + raise ValueError("First argument must be either a vocabulary size " + "(int) or the path to a pretrained model config " + "file (str)") + + @classmethod + def from_dict(cls, json_object): + """Constructs a `UniterConfig` from a + Python dictionary of parameters.""" + config = UniterConfig(vocab_size_or_config_json_file=-1) + for key, value in json_object.items(): + config.__dict__[key] = value + return config + + @classmethod + def from_json_file(cls, json_file): + """Constructs a `UniterConfig` from a json file of parameters.""" + with open(json_file, "r", encoding='utf-8') as reader: + text = reader.read() + return cls.from_dict(json.loads(text)) + + def __repr__(self): + return str(self.to_json_string()) + + def to_dict(self): + """Serializes this instance to a Python dictionary.""" + output = copy.deepcopy(self.__dict__) + return output + + def to_json_string(self): + """Serializes this instance to a JSON string.""" + return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n" + + +class UniterPreTrainedModel(nn.Module): + """ An abstract class to handle weights initialization and + a simple interface for dowloading and loading pretrained models. + """ + def __init__(self, config, *inputs, **kwargs): + super().__init__() + if not isinstance(config, UniterConfig): + raise ValueError( + "Parameter config in `{}(config)` should be an instance of " + "class `UniterConfig`. To create a model from a Google " + "pretrained model use " + "`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`".format( + self.__class__.__name__, self.__class__.__name__ + )) + self.config = config + + def init_weights(self, module): + """ Initialize the weights. 
+ """ + if isinstance(module, (nn.Linear, nn.Embedding)): + # Slightly different from the TF version which uses + # truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + module.weight.data.normal_(mean=0.0, + std=self.config.initializer_range) + elif isinstance(module, FusedLayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + if isinstance(module, nn.Linear) and module.bias is not None: + module.bias.data.zero_() + + @classmethod + def from_pretrained(cls, config_file, state_dict, *inputs, **kwargs): + """ + Instantiate a UniterPreTrainedModel from a pre-trained model file or a + pytorch state dict. + Params: + config_file: config json file + state_dict: an state dictionnary + *inputs, **kwargs: additional input for the specific Uniter class + """ + # Load config + config = UniterConfig.from_json_file(config_file) + logger.info("Model config {}".format(config)) + # Instantiate model. + model = cls(config, *inputs, **kwargs) + # Load from a PyTorch state_dict + old_keys = [] + new_keys = [] + for key in state_dict.keys(): + new_key = None + if 'gamma' in key: + new_key = key.replace('gamma', 'weight') + if 'beta' in key: + new_key = key.replace('beta', 'bias') + if new_key: + old_keys.append(key) + new_keys.append(new_key) + for old_key, new_key in zip(old_keys, new_keys): + state_dict[new_key] = state_dict.pop(old_key) + + missing_keys = [] + unexpected_keys = [] + error_msgs = [] + # copy state_dict so _load_from_state_dict can modify it + metadata = getattr(state_dict, '_metadata', None) + state_dict = state_dict.copy() + if metadata is not None: + state_dict._metadata = metadata + + def load(module, prefix=''): + local_metadata = ({} if metadata is None + else metadata.get(prefix[:-1], {})) + module._load_from_state_dict( + state_dict, prefix, local_metadata, True, missing_keys, + unexpected_keys, error_msgs) + for name, child in module._modules.items(): + if child is not None: + load(child, prefix + name + '.') + start_prefix = '' + if not hasattr(model, 'bert') and any(s.startswith('bert.') + for s in state_dict.keys()): + start_prefix = 'bert.' 
+ load(model, prefix=start_prefix) + if len(missing_keys) > 0: + logger.info("Weights of {} not initialized from " + "pretrained model: {}".format( + model.__class__.__name__, missing_keys)) + if len(unexpected_keys) > 0: + logger.info("Weights from pretrained model not used in " + "{}: {}".format( + model.__class__.__name__, unexpected_keys)) + if len(error_msgs) > 0: + raise RuntimeError('Error(s) in loading state_dict for ' + '{}:\n\t{}'.format( + model.__class__.__name__, + "\n\t".join(error_msgs))) + return model + + +class UniterTextEmbeddings(nn.Module): + def __init__(self, config): + super().__init__() + self.word_embeddings = nn.Embedding(config.vocab_size, + config.hidden_size, padding_idx=0) + self.position_embeddings = nn.Embedding(config.max_position_embeddings, + config.hidden_size) + self.token_type_embeddings = nn.Embedding(config.type_vocab_size, + config.hidden_size) + + # self.LayerNorm is not snake-cased to stick with TensorFlow model + # variable name and be able to load any TensorFlow checkpoint file + self.LayerNorm = FusedLayerNorm(config.hidden_size, eps=1e-12) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, input_ids, position_ids, token_type_ids=None): + if token_type_ids is None: + token_type_ids = torch.zeros_like(input_ids) + + words_embeddings = self.word_embeddings(input_ids) + position_embeddings = self.position_embeddings(position_ids) + token_type_embeddings = self.token_type_embeddings(token_type_ids) + + embeddings = (words_embeddings + + position_embeddings + + token_type_embeddings) + embeddings = self.LayerNorm(embeddings) + embeddings = self.dropout(embeddings) + return embeddings + + +class UniterImageEmbeddings(nn.Module): + def __init__(self, config, img_dim): + super().__init__() + self.img_linear = nn.Linear(img_dim, config.hidden_size) + self.img_layer_norm = FusedLayerNorm(config.hidden_size, eps=1e-12) + self.pos_layer_norm = FusedLayerNorm(config.hidden_size, eps=1e-12) + self.pos_linear = nn.Linear(7, config.hidden_size) + self.mask_embedding = nn.Embedding(2, img_dim, padding_idx=0) + + # tf naming convention for layer norm + self.LayerNorm = FusedLayerNorm(config.hidden_size, eps=1e-12) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, img_feat, img_pos_feat, type_embeddings, img_masks=None): + if img_masks is not None: + self.mask_embedding.weight.data[0, :].fill_(0) + mask = self.mask_embedding(img_masks.long()) + img_feat = img_feat + mask + + transformed_im = self.img_layer_norm(self.img_linear(img_feat)) + transformed_pos = self.pos_layer_norm(self.pos_linear(img_pos_feat)) + embeddings = transformed_im + transformed_pos + type_embeddings + embeddings = self.LayerNorm(embeddings) + embeddings = self.dropout(embeddings) + return embeddings + + +class UniterEncoder(nn.Module): + def __init__(self, config): + super().__init__() + layer = BertLayer(config) + self.layer = nn.ModuleList([copy.deepcopy(layer) + for _ in range(config.num_hidden_layers)]) + + def forward(self, input_, attention_mask, + output_all_encoded_layers=True): + all_encoder_layers = [] + hidden_states = input_ + for layer_module in self.layer: + hidden_states = layer_module(hidden_states, attention_mask) + if output_all_encoded_layers: + all_encoder_layers.append(hidden_states) + if not output_all_encoded_layers: + all_encoder_layers.append(hidden_states) + return all_encoder_layers + + +class UniterModel(UniterPreTrainedModel): + """ Modification for Joint Vision-Language Encoding + """ + def __init__(self, 
config, img_dim): + super().__init__(config) + self.embeddings = UniterTextEmbeddings(config) + self.img_embeddings = UniterImageEmbeddings(config, img_dim) + self.encoder = UniterEncoder(config) + self.pooler = BertPooler(config) + self.apply(self.init_weights) + + def _compute_txt_embeddings(self, input_ids, position_ids, + txt_type_ids=None): + output = self.embeddings(input_ids, position_ids, txt_type_ids) + return output + + def _compute_img_embeddings(self, img_feat, img_pos_feat, img_masks=None, + img_type_ids=None): + if img_type_ids is None: + img_type_ids = torch.ones_like(img_feat[:, :, 0].long()) + img_type_embeddings = self.embeddings.token_type_embeddings( + img_type_ids) + output = self.img_embeddings(img_feat, img_pos_feat, + img_type_embeddings, img_masks) + return output + + def _compute_img_txt_embeddings(self, input_ids, position_ids, + img_feat, img_pos_feat, + gather_index, img_masks=None, + txt_type_ids=None, img_type_ids=None): + txt_emb = self._compute_txt_embeddings( + input_ids, position_ids, txt_type_ids) + img_emb = self._compute_img_embeddings( + img_feat, img_pos_feat, img_masks, img_type_ids) + # align back to most compact input + gather_index = gather_index.unsqueeze(-1).expand( + -1, -1, self.config.hidden_size) + txt_img_emb = torch.cat([txt_emb, img_emb], dim=1) + embedding_output = torch.gather(txt_img_emb, dim=1, index=gather_index) + return embedding_output + + def forward(self, input_ids, position_ids, + img_feat, img_pos_feat, + attention_mask, gather_index=None, img_masks=None, + output_all_encoded_layers=True, + txt_type_ids=None, img_type_ids=None): + # compute self-attention mask + extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2) + extended_attention_mask = extended_attention_mask.to( + dtype=next(self.parameters()).dtype) # fp16 compatibility + extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0 + + # embedding layer + if input_ids is None: + # image only + embedding_output = self._compute_img_embeddings( + img_feat, img_pos_feat, img_masks, img_type_ids) + elif img_feat is None: + # text only + embedding_output = self._compute_txt_embeddings( + input_ids, position_ids, txt_type_ids) + else: + embedding_output = self._compute_img_txt_embeddings( + input_ids, position_ids, + img_feat, img_pos_feat, + gather_index, img_masks, txt_type_ids, img_type_ids) + + encoded_layers = self.encoder( + embedding_output, extended_attention_mask, + output_all_encoded_layers=output_all_encoded_layers) + if not output_all_encoded_layers: + encoded_layers = encoded_layers[-1] + return encoded_layers diff --git a/model/nsgd.py b/model/nsgd.py new file mode 100644 index 0000000..a9bc915 --- /dev/null +++ b/model/nsgd.py @@ -0,0 +1,358 @@ +""" +Copyright (c) Microsoft Corporation. +Licensed under the MIT license. 
+ +UNITER for ITM model +""" +from collections import defaultdict + +import torch +from torch import nn +import torch.nn.functional as F +from .model import UniterPreTrainedModel, UniterModel +from .layer import GELU, BertOnlyMLMHead + + +def repeat_interleave(x, n_repeat=1, dim=0): + repeat_list = [1] * (dim + 1) + [n_repeat] + [1] * (x.dim() - dim - 1) + x_size = list(x.size()) + x_size[dim] = x_size[dim] * n_repeat + x = x.unsqueeze(dim+1).repeat(*repeat_list).view(*x_size) + return x + + +class UniterForNSGD(UniterPreTrainedModel): + """ Finetune UNITER for image text retrieval + """ + def __init__(self, config, img_dim, margin=0.2, hard_size=16, nsgd_sample_size=16, nsgd_sample_temperature=1.0): + super().__init__(config) + self.uniter = UniterModel(config, img_dim) + self.cls = BertOnlyMLMHead( + config, self.uniter.embeddings.word_embeddings.weight) + self.itm_output = nn.Linear(config.hidden_size, 2) + self.rank_output = nn.Linear(config.hidden_size, 1) + self.margin = margin + self.apply(self.init_weights) + self.hard_size = hard_size + self.nsgd_sample_size = nsgd_sample_size + self.nsgd_sample_temperature = nsgd_sample_temperature + + def init_output(self): + """ need to be called after from pretrained """ + self.rank_output.weight.data = self.itm_output.weight.data[1:, :] + self.rank_output.bias.data = self.itm_output.bias.data[1:] + + def forward_uniter( + self, + sample_size=None, + input_ids=None, position_ids=None, + img_feat=None, img_pos_feat=None, + attn_masks=None, gather_index=None, + compute_loss=True, sigmoid_norm=False + ): + model_outputs = {} + sequence_output = self.uniter( + input_ids, position_ids, + img_feat, img_pos_feat, + attn_masks, gather_index, + output_all_encoded_layers=False + ) + pooled_output = self.uniter.pooler(sequence_output) + rank_scores = self.rank_output(pooled_output) + model_outputs['rank_scores'] = rank_scores + if compute_loss: + # triplet loss + rank_scores_sigmoid = torch.sigmoid(rank_scores) + # sample_size = batch['sample_size'] + scores = rank_scores_sigmoid.contiguous().view(-1, sample_size) + pos = scores[:, :1] + neg = scores[:, 1:] + rank_loss = torch.clamp(self.margin + neg - pos, 0) + rank_corrects = neg.sub(pos).le(0).float() + model_outputs['rank_loss'] = rank_loss + model_outputs['rank_corrects'] = rank_corrects + if sigmoid_norm: + rank_scores_sigmoid = torch.sigmoid(rank_scores) + model_outputs['rank_scores_sigmoid'] = rank_scores_sigmoid + # sample_size = batch['sample_size'] + return model_outputs + + def _compute_masked_hidden(self, hidden, mask): + """ get only the masked region (don't compute unnecessary hiddens) """ + mask = mask.unsqueeze(-1).expand_as(hidden) + hidden_masked = hidden[mask].contiguous().view(-1, hidden.size(-1)) + return hidden_masked + + def forward_mlm( + self, + input_ids=None, position_ids=None, + img_feat=None, img_pos_feat=None, + attn_masks=None, gather_index=None, + txt_labels=None, compute_loss=True, sampling=True + ): + model_outputs = {} + sequence_output = self.uniter( + input_ids, position_ids, + img_feat, img_pos_feat, + attn_masks, gather_index, + output_all_encoded_layers=False) + # get only the text part + sequence_output = sequence_output[:, :input_ids.size(1), :] + + if compute_loss: + # only compute masked tokens for better efficiency + masked_output = self._compute_masked_hidden(sequence_output, txt_labels != -1) + prediction_scores = self.cls(masked_output) + masked_lm_loss = F.cross_entropy(prediction_scores, txt_labels[txt_labels != -1], reduction='none') + 
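+ # the loss covers only masked positions (txt_labels != -1) and is kept
+ # per-token (reduction='none') so the caller chooses how to aggregate it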
model_outputs['masked_lm_loss'] = masked_lm_loss + model_outputs['mlm_corrects'] = prediction_scores.max(-1)[1].eq(txt_labels[txt_labels != -1]).float() + if sampling: + bsz, caption_len = input_ids.size(0), input_ids.size(1) + prediction_scores = self.cls(sequence_output) + sample_caption_tokens = torch.multinomial( + prediction_scores.div(self.nsgd_sample_temperature).softmax(-1).view(-1, prediction_scores.size(-1)), + num_samples=self.nsgd_sample_size, + replacement=True, + ).view(bsz, caption_len, self.nsgd_sample_size).permute(0, 2, 1) + mask_indicator = txt_labels.ne(-1).long().unsqueeze(1) + synthetic_input_ids = input_ids.unsqueeze(1).mul(1-mask_indicator).\ + add(sample_caption_tokens.mul(mask_indicator)).reshape(-1, caption_len) + model_outputs['fill_input_ids'] = synthetic_input_ids + return model_outputs + + def forward(self, batch, sample_from='t', compute_loss=True, compute_mlm=False): + # expect same input_ids for all pairs + model_outputs = {} + if not sample_from.startswith('g'): + batch_size = batch['attn_masks'].size(0) + input_ids = batch['input_ids'] + img_feat = batch['img_feat'] + img_pos_feat = batch['img_pos_feat'] + if sample_from == 't': + if input_ids.size(0) == 1: + batch['input_ids'] = input_ids.expand(batch_size, -1) + elif sample_from == 'i': + if img_feat.size(0) == 1: + batch['img_feat'] = img_feat.expand(batch_size, -1, -1) + if img_pos_feat.size(0) == 1: + batch['img_pos_feat'] = img_pos_feat.expand(batch_size, -1, -1) + else: + raise ValueError() + + if self.training and compute_loss: + with torch.no_grad(): + self.eval() + # print(f'| is_training: {self.training} | compute_loss: {compute_loss} |') + if torch.isnan(batch['input_ids']).sum().item() > 0 or \ + torch.isnan(batch['position_ids']).sum().item() > 0 or \ + torch.isnan(batch['img_feat']).sum().item() > 0 or \ + torch.isnan(batch['img_pos_feat']).sum().item() > 0 or \ + torch.isnan(batch['attn_masks']).sum().item() > 0 or \ + torch.isnan(batch['gather_index']).sum().item() > 0: + print(' | nan appear!') + if torch.isinf(batch['input_ids']).sum().item() > 0 or \ + torch.isinf(batch['position_ids']).sum().item() > 0 or \ + torch.isinf(batch['img_feat']).sum().item() > 0 or \ + torch.isinf(batch['img_pos_feat']).sum().item() > 0 or \ + torch.isinf(batch['attn_masks']).sum().item() > 0 or \ + torch.isinf(batch['gather_index']).sum().item() > 0: + print(' | inf appear!') + + forward_uniter_outputs = self.forward_uniter( + input_ids=batch['input_ids'], position_ids=batch['position_ids'], + img_feat=batch['img_feat'], img_pos_feat=batch['img_pos_feat'], + attn_masks=batch['attn_masks'], gather_index=batch['gather_index'], + compute_loss=False, + ) + hard_batch = self._get_hard_batch(batch, forward_uniter_outputs['rank_scores'], sample_from) + self.train() + forward_uniter_outputs = self.forward_uniter( + sample_size=hard_batch['sample_size'], + input_ids=hard_batch['input_ids'], position_ids=hard_batch['position_ids'], + img_feat=hard_batch['img_feat'], img_pos_feat=hard_batch['img_pos_feat'], + attn_masks=hard_batch['attn_masks'], gather_index=hard_batch['gather_index'], + compute_loss=True) + model_outputs['rank_loss'] = forward_uniter_outputs['rank_loss'] + model_outputs['rank_corrects'] = forward_uniter_outputs['rank_corrects'] + else: + forward_uniter_outputs = self.forward_uniter( + input_ids=batch['input_ids'], position_ids=batch['position_ids'], + img_feat=batch['img_feat'], img_pos_feat=batch['img_pos_feat'], + attn_masks=batch['attn_masks'], gather_index=batch['gather_index'], + 
compute_loss=compute_loss, sigmoid_norm=True) + model_outputs.update(forward_uniter_outputs) + return model_outputs + else: + if compute_mlm: + self.train() + mlm_outputs = self.forward_mlm( + input_ids=batch['mlm_input_ids'], position_ids=batch['mlm_position_ids'], + img_feat=batch['mlm_img_feat'], img_pos_feat=batch['mlm_img_pos_feat'], + attn_masks=batch['mlm_attn_masks'], gather_index=batch['mlm_gather_index'], + txt_labels=batch['mlm_txt_labels'], compute_loss=True, sampling=True + ) + model_outputs['masked_lm_loss'] = mlm_outputs['masked_lm_loss'] + # model_outputs['effect_nsgd_number'] = mlm_outputs['effect_nsgd_number'] + model_outputs['mlm_corrects'] = mlm_outputs['mlm_corrects'] + else: + with torch.no_grad(): + self.eval() + # print('| mlm_inference | mlm_input_ids: ', batch['mlm_input_ids'].size()) + mlm_outputs = self.forward_mlm( + input_ids=batch['mlm_input_ids'], position_ids=batch['mlm_position_ids'], + img_feat=batch['mlm_img_feat'], img_pos_feat=batch['mlm_img_pos_feat'], + attn_masks=batch['mlm_attn_masks'], gather_index=batch['mlm_gather_index'], + txt_labels=batch['mlm_txt_labels'], compute_loss=False, sampling=True + ) + with torch.no_grad(): + # select_indices = mlm_outputs['select_indices'] + nsgd_batch = {} + nsgd_batch['txt_ids'] = batch['input_ids'] + nsgd_batch['mlm_sample_size'] = len(batch['mlm_position_ids']) + nsgd_batch['nsgd_sample_size'] = self.nsgd_sample_size + nsgd_batch['input_ids'] = torch.cat( + [batch['input_ids'], mlm_outputs['fill_input_ids']], dim=0) + nsgd_batch['position_ids'] = torch.cat( + [batch['mlm_position_ids'][:1], + repeat_interleave(batch['mlm_position_ids'], self.nsgd_sample_size, dim=0)], dim=0) + nsgd_batch['img_feat'] = torch.cat( + [batch['mlm_img_feat'][:1], + repeat_interleave(batch['mlm_img_feat'], self.nsgd_sample_size, dim=0)], dim=0) + nsgd_batch['img_pos_feat'] = torch.cat( + [batch['mlm_img_pos_feat'][:1], + repeat_interleave(batch['mlm_img_pos_feat'], self.nsgd_sample_size, dim=0)], dim=0) + nsgd_batch['attn_masks'] = torch.cat( + [batch['mlm_attn_masks'][:1], + repeat_interleave(batch['mlm_attn_masks'], self.nsgd_sample_size, dim=0)], dim=0) + nsgd_batch['gather_index'] = torch.cat( + [batch['mlm_attn_masks'][:1], + repeat_interleave(batch['mlm_gather_index'], self.nsgd_sample_size, dim=0)], dim=0) + self.eval() + forward_uniter_outputs = self.forward_uniter( + input_ids=nsgd_batch['input_ids'], position_ids=nsgd_batch['position_ids'], + img_feat=nsgd_batch['img_feat'], img_pos_feat=nsgd_batch['img_pos_feat'], + attn_masks=nsgd_batch['attn_masks'], gather_index=nsgd_batch['gather_index'], + compute_loss=False, sigmoid_norm=True + ) + nsgd_batch = self._get_nsgd_batch( + nsgd_batch, scores=forward_uniter_outputs['rank_scores'], + clean=compute_loss) + self.train() + assert batch['input_ids'].ne(nsgd_batch['input_ids'][0]).long().sum().item() == 0 + model_outputs['effect_nsgd_number'] = nsgd_batch['effect_num'] + model_outputs['rank_adv_scores'] = forward_uniter_outputs['rank_scores_sigmoid'] + if compute_loss: + forward_uniter_outputs = self.forward_uniter( + sample_size=nsgd_batch['sample_size'], + input_ids=nsgd_batch['input_ids'], position_ids=nsgd_batch['position_ids'], + img_feat=nsgd_batch['img_feat'], img_pos_feat=nsgd_batch['img_pos_feat'], + attn_masks=nsgd_batch['attn_masks'], gather_index=nsgd_batch['gather_index'], + compute_loss=True + ) + model_outputs.update(forward_uniter_outputs) + else: + model_outputs['nsgd_adv_batch'] = nsgd_batch + return model_outputs + + def _get_nsgd_batch(self, batch, scores, 
clean=True): + batch = defaultdict(lambda: None, batch) + txt_ids = batch['txt_ids'] + mlm_sample_size, nsgd_sample_size = batch['mlm_sample_size'], batch['nsgd_sample_size'] + + input_ids = batch['input_ids'] + position_ids = batch['position_ids'] + img_feat = batch['img_feat'] + img_pos_feat = batch['img_pos_feat'] + attention_mask = batch['attn_masks'] + gather_index = batch['gather_index'] + hard_batch = {'sample_size': self.hard_size + 1} + + # NOTE first example is positive + if clean: + # print('| clean:', clean) + c1_penalty_scores = repeat_interleave(txt_ids, mlm_sample_size * nsgd_sample_size + 1, dim=0).\ + ne(input_ids).float().sum(-1).le(0).float().mul(3) + c2_penalty_scores = input_ids.unsqueeze(1).ne(input_ids.unsqueeze(0)).\ + float().sum(-1).le(0).float().\ + triu(diagonal=1).sum(-1).gt(0).float().mul(2) + effect_num = c1_penalty_scores.add(c2_penalty_scores).eq(0).long().sum() + hard_batch['effect_num'] = effect_num + hard_scores = scores.squeeze(-1).sub(c1_penalty_scores.add(c2_penalty_scores).type_as(scores)) + else: + hard_batch['effect_num'] = len(scores) + hard_scores = scores.squeeze(-1) + # print('| hard_scores: ', hard_scores.size(), self.hard_size) + hard_indices = hard_scores[1:].topk(self.hard_size, sorted=True)[1] + 1 + # hard_indices = hard_indices[256:] + hard_size = len(hard_indices) + indices = torch.cat([torch.zeros(1, dtype=torch.long, device=hard_indices.device), + hard_indices]) + # indices = hard_indices + + attention_mask = attention_mask.index_select(0, indices) + gather_index = gather_index.index_select(0, indices) + if position_ids.size(0) != 1: + position_ids = position_ids[:hard_size+1] + + input_ids = input_ids.index_select(0, indices) + # expect same image features for all pairs + img_feat = img_feat[:hard_size+1] + img_pos_feat = img_pos_feat[:hard_size+1] + + hard_batch['input_ids'] = input_ids + hard_batch['position_ids'] = position_ids + hard_batch['img_feat'] = img_feat + hard_batch['img_pos_feat'] = img_pos_feat + hard_batch['attn_masks'] = attention_mask + hard_batch['gather_index'] = gather_index + + return hard_batch + + def _get_hard_batch(self, batch, scores, sample_from='t'): + batch = defaultdict(lambda: None, batch) + input_ids = batch['input_ids'] + position_ids = batch['position_ids'] + img_feat = batch['img_feat'] + img_pos_feat = batch['img_pos_feat'] + attention_mask = batch['attn_masks'] + gather_index = batch['gather_index'] + hard_batch = {'sample_size': self.hard_size + 1} + + # NOTE first example is positive + hard_indices = scores.squeeze(-1)[1:].topk( + self.hard_size, sorted=False)[1] + 1 + indices = torch.cat([torch.zeros(1, dtype=torch.long, + device=hard_indices.device), + hard_indices]) + + attention_mask = attention_mask.index_select(0, indices) + gather_index = gather_index.index_select(0, indices) + if position_ids.size(0) != 1: + position_ids = position_ids[:self.hard_size+1] + + if sample_from == 't': + # cut to minimum padding + max_len = attention_mask.sum(dim=1).max().item() + max_i = max_len - input_ids.size(1) + attention_mask = attention_mask[:, :max_len] + gather_index = gather_index[:, :max_len] + img_feat = img_feat.index_select(0, indices)[:, :max_i, :] + img_pos_feat = img_pos_feat.index_select(0, indices)[:, :max_i, :] + # expect same input_ids for all pairs + input_ids = input_ids[:self.hard_size+1] + elif sample_from == 'i': + input_ids = input_ids.index_select(0, indices) + # expect same image features for all pairs + img_feat = img_feat[:self.hard_size+1] + img_pos_feat = 
img_pos_feat[:self.hard_size+1] + else: + raise ValueError() + + hard_batch['input_ids'] = input_ids + hard_batch['position_ids'] = position_ids + hard_batch['img_feat'] = img_feat + hard_batch['img_pos_feat'] = img_pos_feat + hard_batch['attn_masks'] = attention_mask + hard_batch['gather_index'] = gather_index + + return hard_batch diff --git a/model/nsgd2.py b/model/nsgd2.py new file mode 100644 index 0000000..334a7e2 --- /dev/null +++ b/model/nsgd2.py @@ -0,0 +1,456 @@ +""" +Copyright (c) Microsoft Corporation. +Licensed under the MIT license. + +UNITER for ITM model +""" +from collections import defaultdict + +import torch +from torch import nn +import torch.nn.functional as F +from .model import UniterPreTrainedModel, UniterModel +from .layer import GELU, BertOnlyMLMHead + + +def repeat_interleave(x, n_repeat=1, dim=0): + repeat_list = [1] * (dim + 1) + [n_repeat] + [1] * (x.dim() - dim - 1) + x_size = list(x.size()) + x_size[dim] = x_size[dim] * n_repeat + x = x.unsqueeze(dim+1).repeat(*repeat_list).view(*x_size) + return x + + +class UniterForNSGD2(UniterPreTrainedModel): + """ Finetune UNITER for image text retrieval + """ + def __init__(self, config, img_dim, margin=0.2, hard_size=16, nsgd_sample_size=16, nsgd_sample_temperature=1.0, + disc_weights=(1.0, 0.11)): + super().__init__(config) + self.uniter = UniterModel(config, img_dim) + self.cls = BertOnlyMLMHead( + config, self.uniter.embeddings.word_embeddings.weight) + self.itm_output = nn.Linear(config.hidden_size, 2) + self.rank_output = nn.Linear(config.hidden_size, 1) + self.disc_output = nn.Linear(config.hidden_size, 2) + self.correction_output = BertOnlyMLMHead( + config, self.uniter.embeddings.word_embeddings.weight) + + self.margin = margin + self.apply(self.init_weights) + self.hard_size = hard_size + self.nsgd_sample_size = nsgd_sample_size + self.nsgd_sample_temperature = nsgd_sample_temperature + self.disc_weights = torch.tensor(disc_weights) + print('| disc_weights: ', self.disc_weights) + print('-'*100) + + def init_output(self): + """ need to be called after from pretrained """ + self.rank_output.weight.data = self.itm_output.weight.data[1:, :] + self.rank_output.bias.data = self.itm_output.bias.data[1:] + self.disc_output.weight.data.copy_(torch.tensor(self.itm_output.weight.data.cpu().numpy())) + self.disc_output.bias.data.copy_(torch.tensor(self.itm_output.bias.data.cpu().numpy())) + for name, param in self.correction_output.named_parameters(): + cls_param = self.cls.state_dict().get(name) + param.data.copy_(torch.tensor(cls_param.data.cpu().numpy())) + + def forward_uniter( + self, + sample_size=None, + input_ids=None, position_ids=None, + img_feat=None, img_pos_feat=None, + attn_masks=None, gather_index=None, + compute_loss=True, sigmoid_norm=False + ): + model_outputs = {} + sequence_output = self.uniter( + input_ids, position_ids, + img_feat, img_pos_feat, + attn_masks, gather_index, + output_all_encoded_layers=False + ) + pooled_output = self.uniter.pooler(sequence_output) + rank_scores = self.rank_output(pooled_output) + model_outputs['rank_scores'] = rank_scores + if compute_loss: + # triplet loss + rank_scores_sigmoid = torch.sigmoid(rank_scores) + # sample_size = batch['sample_size'] + scores = rank_scores_sigmoid.contiguous().view(-1, sample_size) + pos = scores[:, :1] + neg = scores[:, 1:] + rank_loss = torch.clamp(self.margin + neg - pos, 0) + rank_corrects = neg.sub(pos).le(0).float() + model_outputs['rank_loss'] = rank_loss + model_outputs['rank_corrects'] = rank_corrects + if sigmoid_norm: + 
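+ # squash the raw rank logits to (0, 1); these normalised scores are what
+ # downstream code uses to rank and pick hard synthetic negatives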
rank_scores_sigmoid = torch.sigmoid(rank_scores) + model_outputs['rank_scores_sigmoid'] = rank_scores_sigmoid + return model_outputs + + def _compute_masked_hidden(self, hidden, mask): + """ get only the masked region (don't compute unnecessary hiddens) """ + mask = mask.unsqueeze(-1).expand_as(hidden) + hidden_masked = hidden[mask].contiguous().view(-1, hidden.size(-1)) + return hidden_masked + + def forward_mlm( + self, + input_ids=None, position_ids=None, + img_feat=None, img_pos_feat=None, + attn_masks=None, gather_index=None, + txt_labels=None, compute_loss=True, sampling=True + ): + model_outputs = {} + if gather_index.size(0) == 1 and gather_index.size(0) < input_ids.size(0): + rep_in_other_dimension_size = [1] * (len(gather_index.size()) - 1) + gather_index = gather_index.repeat(input_ids.size(0), *rep_in_other_dimension_size) + + sequence_output = self.uniter( + input_ids, position_ids, + img_feat, img_pos_feat, + attn_masks, gather_index, + output_all_encoded_layers=False) + # get only the text part + sequence_output = sequence_output[:, :input_ids.size(1), :] + + if compute_loss: + # only compute masked tokens for better efficiency + masked_output = self._compute_masked_hidden(sequence_output, txt_labels != -1) + # print('| forward_mlm: ', masked_output.min(), masked_output.max()) + prediction_scores = self.cls(masked_output) + masked_lm_loss = F.cross_entropy(prediction_scores, txt_labels[txt_labels != -1], reduction='none') + model_outputs['masked_lm_loss'] = masked_lm_loss + model_outputs['mlm_corrects'] = prediction_scores.max(-1)[1].eq(txt_labels[txt_labels != -1]).float() + if sampling: + bsz, caption_len = input_ids.size(0), input_ids.size(1) + prediction_scores = self.cls(sequence_output) + sample_caption_tokens = torch.multinomial( + prediction_scores.div(self.nsgd_sample_temperature).softmax(-1).view(-1, prediction_scores.size(-1)), + num_samples=self.nsgd_sample_size, + replacement=True, + ).view(bsz, caption_len, self.nsgd_sample_size).permute(0, 2, 1) + mask_indicator = txt_labels.ne(-1).long().unsqueeze(1) + synthetic_input_ids = input_ids.unsqueeze(1).mul(1-mask_indicator).\ + add(sample_caption_tokens.mul(mask_indicator)).reshape(-1, caption_len) + model_outputs['fill_input_ids'] = synthetic_input_ids + return model_outputs + + def forward_gd( + self, + syn_input_ids=None, syn_position_ids=None, + img_feat=None, img_pos_feat=None, + attn_masks=None, gather_index=None, + txt_labels=None, + gt_input_ids=None, compute_loss=True, compute_all_correction_output=False, + ): + model_outputs = {} + if gather_index.size(0) == 1 and gather_index.size(0) < syn_input_ids.size(0): + rep_in_other_dimension_size = [1] * (len(gather_index.size()) - 1) + gather_index = gather_index.repeat(syn_input_ids.size(0), *rep_in_other_dimension_size) + + sequence_output = self.uniter( + syn_input_ids, syn_position_ids, + img_feat, img_pos_feat, + attn_masks, gather_index, + output_all_encoded_layers=False) + # get only the text part + sequence_output = sequence_output[:, :syn_input_ids.size(1), :] + + # disc loss + disc_p = self.disc_output(sequence_output).softmax(-1) + model_outputs['disc_p'] = disc_p + if compute_loss: + disc_labels = syn_input_ids.eq(gt_input_ids).long() + # print('| disc_labels: ', disc_labels[:3]) + disc_loss = F.nll_loss( + disc_p.clamp(1e-10).log().reshape(-1, disc_p.size(-1)), + disc_labels.view(-1), + reduction='none', + weight=self.disc_weights.type_as(disc_p)) + model_outputs['disc_loss'] = disc_loss + model_outputs['disc_pos_corrects'] = 
disc_p.max(-1)[1].eq(disc_labels).float().\ + mul(disc_labels.eq(1).float()).sum() + model_outputs['disc_pos_samples'] = disc_pos_ntokens = disc_labels.eq(1).sum() + model_outputs['disc_neg_corrects'] = disc_p.max(-1)[1].eq(disc_labels).float().\ + mul(disc_labels.eq(0).float()).sum() + model_outputs['disc_neg_samples'] = disc_neg_ntokens = disc_labels.eq(0).sum() + model_outputs['disc_corrects'] = disc_p.max(-1)[1].eq(disc_labels).float().sum() + model_outputs['disc_samples'] = syn_input_ids.numel() + + if compute_loss: + masked_output = self._compute_masked_hidden(sequence_output, txt_labels != -1) + logits = self.correction_output(masked_output) # .div(3.0).clamp(-8.0, 8.0) + # print('| logits: ', logits.min(), logits.max()) + overall_p = correction_p = logits.softmax(-1) + masked_lm_loss = F.nll_loss( + overall_p.clamp(1e-10).log(), + txt_labels[txt_labels != -1], + reduction='none') + + model_outputs['correction_loss'] = masked_lm_loss + model_outputs['correction_corrects'] = correction_corrects = \ + overall_p.max(-1)[1].eq(txt_labels[txt_labels != -1]).float() + + if compute_all_correction_output: + logits = self.correction_output(sequence_output) # .div(3.0).clamp(-8.0, 8.0) + model_outputs['correction_texts'] = logits.max(-1)[1] + model_outputs['all_correction_corrects'] = all_correction_corrects = \ + logits.max(-1)[1][..., 1:-1].eq(gt_input_ids[..., 1:-1]).float() + + # print('| correction_corrects: {} | all_correction_corrects: {}'.format( + # correction_corrects.size(), all_correction_corrects.size())) + # print(correction_corrects.sum().item(), correction_corrects.numel(), + # all_correction_corrects.sum().item(), all_correction_corrects.numel()) + return model_outputs + + def forward(self, batch, sample_from='t', compute_loss=True, compute_mlm=False, compute_all_correction_output=False): + # expect same input_ids for all pairs + model_outputs = {} + if not sample_from.startswith('g'): + batch_size = batch['attn_masks'].size(0) + input_ids = batch['input_ids'] + img_feat = batch['img_feat'] + img_pos_feat = batch['img_pos_feat'] + if sample_from == 't': + if input_ids.size(0) == 1: + batch['input_ids'] = input_ids.expand(batch_size, -1) + elif sample_from == 'i': + if img_feat.size(0) == 1: + batch['img_feat'] = img_feat.expand(batch_size, -1, -1) + if img_pos_feat.size(0) == 1: + batch['img_pos_feat'] = img_pos_feat.expand(batch_size, -1, -1) + else: + raise ValueError() + + if self.training and compute_loss: + with torch.no_grad(): + self.eval() + forward_uniter_outputs = self.forward_uniter( + input_ids=batch['input_ids'], position_ids=batch['position_ids'], + img_feat=batch['img_feat'], img_pos_feat=batch['img_pos_feat'], + attn_masks=batch['attn_masks'], gather_index=batch['gather_index'], + compute_loss=False, + ) + hard_batch = self._get_hard_batch(batch, forward_uniter_outputs['rank_scores'], sample_from) + self.train() + forward_uniter_outputs = self.forward_uniter( + sample_size=hard_batch['sample_size'], + input_ids=hard_batch['input_ids'], position_ids=hard_batch['position_ids'], + img_feat=hard_batch['img_feat'], img_pos_feat=hard_batch['img_pos_feat'], + attn_masks=hard_batch['attn_masks'], gather_index=hard_batch['gather_index'], + compute_loss=True) + model_outputs['rank_loss'] = forward_uniter_outputs['rank_loss'] + model_outputs['rank_corrects'] = forward_uniter_outputs['rank_corrects'] + else: + forward_uniter_outputs = self.forward_uniter( + input_ids=batch['input_ids'], position_ids=batch['position_ids'], + img_feat=batch['img_feat'], 
img_pos_feat=batch['img_pos_feat'], + attn_masks=batch['attn_masks'], gather_index=batch['gather_index'], + compute_loss=compute_loss) + model_outputs.update(forward_uniter_outputs) + return model_outputs + elif sample_from == 'gsynt': + batch_size = batch['syn_attn_masks'].size(0) + if batch['gt_input_ids'].size(0) == 1: + batch['gt_input_ids'] = batch['gt_input_ids'].expand(batch_size, -1) + forward_uniter_outputs = self.forward_gd( + syn_input_ids=batch['syn_input_ids'], syn_position_ids=batch['syn_position_ids'], + img_feat=batch['syn_img_feat'], img_pos_feat=batch['syn_img_pos_feat'], + attn_masks=batch['syn_attn_masks'], gather_index=batch['syn_gather_index'], + txt_labels=batch['syn_txt_labels'], + gt_input_ids=batch['gt_input_ids'], compute_all_correction_output=compute_all_correction_output) + model_outputs.update(forward_uniter_outputs) + return model_outputs + else: + if compute_mlm: + self.train() + mlm_outputs = self.forward_mlm( + input_ids=batch['mlm_input_ids'], position_ids=batch['mlm_position_ids'], + img_feat=batch['mlm_img_feat'], img_pos_feat=batch['mlm_img_pos_feat'], + attn_masks=batch['mlm_attn_masks'], gather_index=batch['mlm_gather_index'], + txt_labels=batch['mlm_txt_labels'], compute_loss=True, sampling=True + ) + model_outputs['masked_lm_loss'] = mlm_outputs['masked_lm_loss'] + # model_outputs['effect_nsgd_number'] = mlm_outputs['effect_nsgd_number'] + model_outputs['mlm_corrects'] = mlm_outputs['mlm_corrects'] + else: + with torch.no_grad(): + self.eval() + mlm_outputs = self.forward_mlm( + input_ids=batch['mlm_input_ids'], position_ids=batch['mlm_position_ids'], + img_feat=batch['mlm_img_feat'], img_pos_feat=batch['mlm_img_pos_feat'], + attn_masks=batch['mlm_attn_masks'], gather_index=batch['gather_index'], + txt_labels=batch['mlm_txt_labels'], compute_loss=False, sampling=True + ) + with torch.no_grad(): + # select_indices = mlm_outputs['select_indices'] + nsgd_batch = {} + nsgd_batch['txt_ids'] = batch['input_ids'] + nsgd_batch['mlm_sample_size'] = len(batch['mlm_position_ids']) + nsgd_batch['nsgd_sample_size'] = self.nsgd_sample_size + nsgd_batch['input_dict'] = batch['input_dict'] + nsgd_batch['input_ids'] = torch.cat( + [batch['input_ids'], mlm_outputs['fill_input_ids']], dim=0) + nsgd_batch['position_ids'] = torch.cat( + [batch['mlm_position_ids'][:1], + repeat_interleave(batch['mlm_position_ids'], self.nsgd_sample_size, dim=0)], dim=0) + nsgd_batch['img_feat'] = torch.cat( + [batch['mlm_img_feat'][:1], + repeat_interleave(batch['mlm_img_feat'], self.nsgd_sample_size, dim=0)], dim=0) + nsgd_batch['img_pos_feat'] = torch.cat( + [batch['mlm_img_pos_feat'][:1], + repeat_interleave(batch['mlm_img_pos_feat'], self.nsgd_sample_size, dim=0)], dim=0) + nsgd_batch['attn_masks'] = torch.cat( + [batch['mlm_attn_masks'][:1], + repeat_interleave(batch['mlm_attn_masks'], self.nsgd_sample_size, dim=0)], dim=0) + nsgd_batch['gather_index'] = torch.cat( + [batch['mlm_attn_masks'][:1], + repeat_interleave(batch['mlm_gather_index'], self.nsgd_sample_size, dim=0)], dim=0) + nsgd_batch['txt_labels'] = torch.cat( + [batch['mlm_txt_labels'][:1], + repeat_interleave(batch['mlm_txt_labels'], self.nsgd_sample_size, dim=0)], dim=0) + self.eval() + forward_uniter_outputs = self.forward_uniter( + input_ids=nsgd_batch['input_ids'], position_ids=nsgd_batch['position_ids'], + img_feat=nsgd_batch['img_feat'], img_pos_feat=nsgd_batch['img_pos_feat'], + attn_masks=nsgd_batch['attn_masks'], gather_index=nsgd_batch['gather_index'], + compute_loss=False, sigmoid_norm=True + ) + nsgd_batch = 
self._get_nsgd_batch( + nsgd_batch, scores=forward_uniter_outputs['rank_scores_sigmoid']) + assert batch['input_ids'].ne(nsgd_batch['input_ids'][0]).long().sum().item() == 0 + # print('| gt_input_ids: {} | syn_input_ids[:2]: {} | txt_labels[:2]: {}'. + # format(batch['input_ids'], nsgd_batch['input_ids'][:3], nsgd_batch['txt_labels'][:3])) + model_outputs['syn_input_ids'] = nsgd_batch['input_ids'][1:] + model_outputs['syn_position_ids'] = nsgd_batch['position_ids'][1:] + model_outputs['syn_img_feat'] = nsgd_batch['img_feat'][1:] + model_outputs['syn_img_pos_feat'] = nsgd_batch['img_pos_feat'][1:] + model_outputs['syn_attn_masks'] = nsgd_batch['attn_masks'][1:] + model_outputs['syn_gather_index'] = nsgd_batch['gather_index'][1:] + model_outputs['syn_txt_labels'] = nsgd_batch['txt_labels'][1:] + model_outputs['gt_input_ids'] = batch['input_ids'][0].unsqueeze(0) + model_outputs['effect_nsgd_number'] = nsgd_batch['effect_num'] + model_outputs['rank_adv_scores'] = nsgd_batch['rank_scores_sigmoid'] + + if compute_loss: + self.train() + forward_uniter_outputs = self.forward_uniter( + sample_size=nsgd_batch['sample_size'], + input_ids=nsgd_batch['input_ids'], position_ids=nsgd_batch['position_ids'], + img_feat=nsgd_batch['img_feat'], img_pos_feat=nsgd_batch['img_pos_feat'], + attn_masks=nsgd_batch['attn_masks'], gather_index=nsgd_batch['gather_index'], + compute_loss=True + ) + model_outputs.update(forward_uniter_outputs) + else: + model_outputs['nsgd_adv_batch'] = nsgd_batch + return model_outputs + + def _get_nsgd_batch(self, batch, scores, clean=True): + batch = defaultdict(lambda: None, batch) + # txt_ids = batch['txt_ids'] + input_dict = batch['input_dict'] + + mlm_sample_size, nsgd_sample_size = batch['mlm_sample_size'], batch['nsgd_sample_size'] + + input_ids = batch['input_ids'] + position_ids = batch['position_ids'] + img_feat = batch['img_feat'] + img_pos_feat = batch['img_pos_feat'] + attention_mask = batch['attn_masks'] + gather_index = batch['gather_index'] + txt_labels = batch['txt_labels'] + hard_batch = {'sample_size': self.hard_size + 1} + + # NOTE first example is positive + if clean: + # c1_penalty_scores = repeat_interleave(txt_ids, mlm_sample_size * nsgd_sample_size + 1, dim=0).\ + # ne(input_ids).float().sum(-1).le(0).float().mul(3) + c1_penalty_scores = repeat_interleave(input_dict, mlm_sample_size * nsgd_sample_size + 1, dim=0).\ + ne(input_ids).float().sum(-1).le(0).float().mul(3) + + c2_penalty_scores = input_ids.unsqueeze(1).ne(input_ids.unsqueeze(0)).\ + float().sum(-1).le(0).float().\ + triu(diagonal=1).sum(-1).gt(0).float().mul(2) + effect_num = c1_penalty_scores.add(c2_penalty_scores).eq(0).long().sum() + hard_batch['effect_num'] = effect_num + hard_scores = scores.squeeze(-1).sub(c1_penalty_scores.add(c2_penalty_scores).type_as(scores)) + else: + hard_batch['effect_num'] = len(scores) + hard_scores = scores.squeeze(-1) + hard_indices = hard_scores[1:].topk(self.hard_size, sorted=True)[1] + 1 + hard_size = len(hard_indices) + indices = torch.cat([torch.zeros(1, dtype=torch.long, device=hard_indices.device), + hard_indices]) + + attention_mask = attention_mask.index_select(0, indices) + gather_index = gather_index.index_select(0, indices) + txt_labels = txt_labels.index_select(0, indices) + if position_ids.size(0) != 1: + position_ids = position_ids[:hard_size+1] + + input_ids = input_ids.index_select(0, indices) + # expect same image features for all pairs + img_feat = img_feat[:hard_size+1] + img_pos_feat = img_pos_feat[:hard_size+1] + + hard_batch['input_ids'] = 
input_ids + hard_batch['position_ids'] = position_ids + hard_batch['img_feat'] = img_feat + hard_batch['img_pos_feat'] = img_pos_feat + hard_batch['attn_masks'] = attention_mask + hard_batch['gather_index'] = gather_index + hard_batch['txt_labels'] = txt_labels + hard_batch['rank_scores_sigmoid'] = torch.cat([scores[0], hard_scores.index_select(0, hard_indices)], dim=0) + return hard_batch + + def _get_hard_batch(self, batch, scores, sample_from='t'): + batch = defaultdict(lambda: None, batch) + input_ids = batch['input_ids'] + position_ids = batch['position_ids'] + img_feat = batch['img_feat'] + img_pos_feat = batch['img_pos_feat'] + attention_mask = batch['attn_masks'] + gather_index = batch['gather_index'] + hard_batch = {'sample_size': self.hard_size + 1} + + # NOTE first example is positive + hard_indices = scores.squeeze(-1)[1:].topk( + self.hard_size, sorted=False)[1] + 1 + indices = torch.cat([torch.zeros(1, dtype=torch.long, + device=hard_indices.device), + hard_indices]) + + attention_mask = attention_mask.index_select(0, indices) + gather_index = gather_index.index_select(0, indices) + if position_ids.size(0) != 1: + position_ids = position_ids[:self.hard_size+1] + + if sample_from == 't': + # cut to minimum padding + max_len = attention_mask.sum(dim=1).max().item() + max_i = max_len - input_ids.size(1) + attention_mask = attention_mask[:, :max_len] + gather_index = gather_index[:, :max_len] + img_feat = img_feat.index_select(0, indices)[:, :max_i, :] + img_pos_feat = img_pos_feat.index_select(0, indices)[:, :max_i, :] + # expect same input_ids for all pairs + input_ids = input_ids[:self.hard_size+1] + elif sample_from == 'i': + input_ids = input_ids.index_select(0, indices) + # expect same image features for all pairs + img_feat = img_feat[:self.hard_size+1] + img_pos_feat = img_pos_feat[:self.hard_size+1] + else: + raise ValueError() + + hard_batch['input_ids'] = input_ids + hard_batch['position_ids'] = position_ids + hard_batch['img_feat'] = img_feat + hard_batch['img_pos_feat'] = img_pos_feat + hard_batch['attn_masks'] = attention_mask + hard_batch['gather_index'] = gather_index + + return hard_batch diff --git a/optim/__init__.py b/optim/__init__.py new file mode 100644 index 0000000..3c21fa9 --- /dev/null +++ b/optim/__init__.py @@ -0,0 +1,7 @@ +""" +Copyright (c) Microsoft Corporation. +Licensed under the MIT license. + +""" +from .sched import noam_schedule, warmup_linear, vqa_schedule, get_lr_sched +from .adamw import AdamW diff --git a/optim/adamw.py b/optim/adamw.py new file mode 100644 index 0000000..f8e346c --- /dev/null +++ b/optim/adamw.py @@ -0,0 +1,103 @@ +""" +AdamW optimizer (weight decay fix) +copied from hugginface (https://github.com/huggingface/transformers). +""" +import math + +import torch +from torch.optim import Optimizer + + +class AdamW(Optimizer): + """ Implements Adam algorithm with weight decay fix. + Parameters: + lr (float): learning rate. Default 1e-3. + betas (tuple of 2 floats): Adams beta parameters (b1, b2). + Default: (0.9, 0.999) + eps (float): Adams epsilon. Default: 1e-6 + weight_decay (float): Weight decay. Default: 0.0 + correct_bias (bool): can be set to False to avoid correcting bias + in Adam (e.g. like in Bert TF repository). Default True. 
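+ Example (illustrative usage only; `model` / `loss` and the
+ hyper-parameter values are placeholders, not repo defaults):
+ >>> optimizer = AdamW(model.parameters(), lr=2e-5, weight_decay=0.01)
+ >>> loss.backward()
+ >>> optimizer.step()
+ >>> optimizer.zero_grad()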
+ """ + def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-6, + weight_decay=0.0, correct_bias=True): + if lr < 0.0: + raise ValueError( + "Invalid learning rate: {} - should be >= 0.0".format(lr)) + if not 0.0 <= betas[0] < 1.0: + raise ValueError("Invalid beta parameter: {} - " + "should be in [0.0, 1.0[".format(betas[0])) + if not 0.0 <= betas[1] < 1.0: + raise ValueError("Invalid beta parameter: {} - " + "should be in [0.0, 1.0[".format(betas[1])) + if not 0.0 <= eps: + raise ValueError("Invalid epsilon value: {} - " + "should be >= 0.0".format(eps)) + defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, + correct_bias=correct_bias) + super(AdamW, self).__init__(params, defaults) + + def step(self, closure=None): + """Performs a single optimization step. + Arguments: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. + """ + loss = None + if closure is not None: + loss = closure() + + for group in self.param_groups: + for p in group['params']: + if p.grad is None: + continue + grad = p.grad.data + if grad.is_sparse: + raise RuntimeError( + 'Adam does not support sparse ' + 'gradients, please consider SparseAdam instead') + + state = self.state[p] + + # State initialization + if len(state) == 0: + state['step'] = 0 + # Exponential moving average of gradient values + state['exp_avg'] = torch.zeros_like(p.data) + # Exponential moving average of squared gradient values + state['exp_avg_sq'] = torch.zeros_like(p.data) + + exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] + beta1, beta2 = group['betas'] + + state['step'] += 1 + + # Decay the first and second moment running average coefficient + # In-place operations to update the averages at the same time + exp_avg.mul_(beta1).add_(1.0 - beta1, grad) + exp_avg_sq.mul_(beta2).addcmul_(1.0 - beta2, grad, grad) + denom = exp_avg_sq.sqrt().add_(group['eps']) + + step_size = group['lr'] + if group['correct_bias']: # No bias correction for Bert + bias_correction1 = 1.0 - beta1 ** state['step'] + bias_correction2 = 1.0 - beta2 ** state['step'] + step_size = (step_size * math.sqrt(bias_correction2) + / bias_correction1) + + p.data.addcdiv_(-step_size, exp_avg, denom) + + # Just adding the square of the weights to the loss function is + # *not* the correct way of using L2 regularization/weight decay + # with Adam, since that will interact with the m and v + # parameters in strange ways. + # + # Instead we want to decay the weights in a manner that doesn't + # interact with the m/v parameters. This is equivalent to + # adding the square of the weights to the loss with plain + # (non-momentum) SGD. + # Add weight decay at the end (fixed version) + if group['weight_decay'] > 0.0: + p.data.add_(-group['lr'] * group['weight_decay'], p.data) + + return loss diff --git a/optim/misc.py b/optim/misc.py new file mode 100644 index 0000000..1368ab4 --- /dev/null +++ b/optim/misc.py @@ -0,0 +1,35 @@ +""" +Copyright (c) Microsoft Corporation. +Licensed under the MIT license. 
+ +Misc lr helper +""" +from torch.optim import Adam, Adamax + +from .adamw import AdamW + + +def build_optimizer(model, opts): + param_optimizer = list(model.named_parameters()) + no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight'] + optimizer_grouped_parameters = [ + {'params': [p for n, p in param_optimizer + if not any(nd in n for nd in no_decay)], + 'weight_decay': opts.weight_decay}, + {'params': [p for n, p in param_optimizer + if any(nd in n for nd in no_decay)], + 'weight_decay': 0.0} + ] + + # currently Adam only + if opts.optim == 'adam': + OptimCls = Adam + elif opts.optim == 'adamax': + OptimCls = Adamax + elif opts.optim == 'adamw': + OptimCls = AdamW + else: + raise ValueError('invalid optimizer') + optimizer = OptimCls(optimizer_grouped_parameters, + lr=opts.learning_rate, betas=opts.betas) + return optimizer diff --git a/optim/sched.py b/optim/sched.py new file mode 100644 index 0000000..a4c46d6 --- /dev/null +++ b/optim/sched.py @@ -0,0 +1,46 @@ +""" +Copyright (c) Microsoft Corporation. +Licensed under the MIT license. + +optimizer learning rate scheduling helpers +""" +from math import ceil + + +def noam_schedule(step, warmup_step=4000): + """ original Transformer schedule""" + if step <= warmup_step: + return step / warmup_step + return (warmup_step ** 0.5) * (step ** -0.5) + + +def warmup_linear(step, warmup_step, tot_step): + """ BERT schedule """ + if step < warmup_step: + return step / warmup_step + return max(0, (tot_step-step)/(tot_step-warmup_step)) + + +def vqa_schedule(step, warmup_interval, decay_interval, + decay_start, decay_rate): + """ VQA schedule from MCAN """ + if step < warmup_interval: + return 1/4 + elif step < 2 * warmup_interval: + return 2/4 + elif step < 3 * warmup_interval: + return 3/4 + elif step >= decay_start: + num_decay = ceil((step - decay_start) / decay_interval) + return decay_rate ** num_decay + else: + return 1 + + +def get_lr_sched(global_step, opts): + # learning rate scheduling + lr_this_step = opts.learning_rate * warmup_linear( + global_step, opts.warmup_steps, opts.num_train_steps) + if lr_this_step <= 0: + lr_this_step = 1e-8 + return lr_this_step diff --git a/prepro.py b/prepro.py new file mode 100644 index 0000000..39236cf --- /dev/null +++ b/prepro.py @@ -0,0 +1,101 @@ +""" +Copyright (c) Microsoft Corporation. +Licensed under the MIT license. 
+ +preprocess NLVR annotations into LMDB +""" +import argparse +import json +import os +from os.path import exists + +from cytoolz import curry +from tqdm import tqdm +from pytorch_pretrained_bert import BertTokenizer + +from data.data import open_lmdb + + +@curry +def bert_tokenize(tokenizer, text): + ids = [] + for word in text.strip().split(): + ws = tokenizer.tokenize(word) + if not ws: + # some special char + continue + ids.extend(tokenizer.convert_tokens_to_ids(ws)) + return ids + + +def process_nlvr2(jsonl, db, tokenizer, missing=None): + id2len = {} + txt2img = {} # not sure if useful + for line in tqdm(jsonl, desc='processing NLVR2'): + example = json.loads(line) + id_ = example['identifier'] + img_id = '-'.join(id_.split('-')[:-1]) + img_fname = (f'nlvr2_{img_id}-img0.npz', f'nlvr2_{img_id}-img1.npz') + if missing and (img_fname[0] in missing or img_fname[1] in missing): + continue + input_ids = tokenizer(example['sentence']) + if 'label' in example: + target = 1 if example['label'] == 'True' else 0 + else: + target = None + txt2img[id_] = img_fname + id2len[id_] = len(input_ids) + example['input_ids'] = input_ids + example['img_fname'] = img_fname + example['target'] = target + db[id_] = example + return id2len, txt2img + + +def main(opts): + if not exists(opts.output): + os.makedirs(opts.output) + else: + raise ValueError('Found existing DB. Please explicitly remove ' + 'for re-processing') + meta = vars(opts) + meta['tokenizer'] = opts.toker + toker = BertTokenizer.from_pretrained( + opts.toker, do_lower_case='uncased' in opts.toker) + tokenizer = bert_tokenize(toker) + meta['UNK'] = toker.convert_tokens_to_ids(['[UNK]'])[0] + meta['CLS'] = toker.convert_tokens_to_ids(['[CLS]'])[0] + meta['SEP'] = toker.convert_tokens_to_ids(['[SEP]'])[0] + meta['MASK'] = toker.convert_tokens_to_ids(['[MASK]'])[0] + meta['v_range'] = (toker.convert_tokens_to_ids('!')[0], + len(toker.vocab)) + with open(f'{opts.output}/meta.json', 'w') as f: + json.dump(vars(opts), f, indent=4) + + open_db = curry(open_lmdb, opts.output, readonly=False) + with open_db() as db: + with open(opts.annotation) as ann: + if opts.missing_imgs is not None: + missing_imgs = set(json.load(open(opts.missing_imgs))) + else: + missing_imgs = None + id2lens, txt2img = process_nlvr2(ann, db, tokenizer, missing_imgs) + + with open(f'{opts.output}/id2len.json', 'w') as f: + json.dump(id2lens, f) + with open(f'{opts.output}/txt2img.json', 'w') as f: + json.dump(txt2img, f) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('--annotation', required=True, + help='annotation JSON') + parser.add_argument('--missing_imgs', + help='some training image features are corrupted') + parser.add_argument('--output', required=True, + help='output dir of DB') + parser.add_argument('--toker', default='bert-base-cased', + help='which BERT tokenizer to used') + args = parser.parse_args() + main(args) diff --git a/run_cmds/inf_nsgd.sh b/run_cmds/inf_nsgd.sh new file mode 100644 index 0000000..972dacb --- /dev/null +++ b/run_cmds/inf_nsgd.sh @@ -0,0 +1,15 @@ +#!/usr/bin/env bash + +export NGPU=8 +export DATA_ROOT=./UNITER/itm-data + +export model_version=1 +export arch=2000 +export MODEL_ROOT=/mnt/Projects/UNITER/log/itm_nsgd/uniter_nsgd_base_v${model_version}/ckpt +export OUTPUT_DIR=./model_log/uniter_nsgd_base_v${model_version}/step${arch} +mkdir -p ${OUTPUT_DIR} + +horovodrun -np $NGPU python inf_nsgd.py \ + --txt_db ${DATA_ROOT}/txt_db3/itm_flickr30k_test.db --img_db ${DATA_ROOT}/img_db/flickr30k \ + 
--checkpoint ${MODEL_ROOT}/model_step_${arch}.pt --model_config ./config/uniter-base.json \ + --output_dir ${OUTPUT_DIR} --fp16 --pin_mem diff --git a/run_cmds/train_pnsgd2_base_coco.sh b/run_cmds/train_pnsgd2_base_coco.sh new file mode 100644 index 0000000..202cb3b --- /dev/null +++ b/run_cmds/train_pnsgd2_base_coco.sh @@ -0,0 +1,38 @@ +#!/usr/bin/env bash + +pwd + +export ROOT_DIR=. + +export NGPU=8 +export LR=1e-5 +export NSGD_SAMPLE_TEMPERATURE=1.5 +export NSGD_RANK_LAMBDA=0.0001 +export MLM_LAMBDA=0.05 +export DISC_LAMBDA=5e-4 # 05 +export CORRECTION_LAMBDA=5e-4 +export OUTPUT_DIR=${ROOT_DIR}/log/coco/itm_pnsgd2_base_v1 +export CHECKPOINT=${ROOT_DIR}/log/coco/itm_pnsgd_base_v1/ckpt/model_step_2000.pt + +export NEGATIVE_SIZE=399 +export HARD_NEG_SIZE=31 +export MLM_SAMPLE_SIZE=20 +export NSGD_SAMPLE_SIZE=20 +export TRAIN_BATCH_SIZE=32 +export MARGIN=0.2 + +export VALID_STEPS=150 +export NUM_TRAIN_STEPS=1500 +export WARMUP_STEPS=150 + +rm ${OUTPUT_DIR} -rf +ls -lh ${OUTPUT_DIR} +mkdir -p ${OUTPUT_DIR} + +horovodrun -np ${NGPU} python train_pnsgd2.py --config config/train-itm-pnsgd2-base-coco.json \ + --output_dir ${OUTPUT_DIR} --learning_rate ${LR} --negative_size ${NEGATIVE_SIZE} \ + --hard_neg_size ${HARD_NEG_SIZE} --mlm_sample_size ${MLM_SAMPLE_SIZE} --nsgd_sample_size ${NSGD_SAMPLE_SIZE} \ + --nsgd_sample_temperature ${NSGD_SAMPLE_TEMPERATURE} --train_batch_size ${TRAIN_BATCH_SIZE} \ + --mlm_lambda ${MLM_LAMBDA} --nsgd_rank_lambda ${NSGD_RANK_LAMBDA} --margin ${MARGIN} \ + --disc_lambda ${DISC_LAMBDA} --correction_lambda ${CORRECTION_LAMBDA} --checkpoint ${CHECKPOINT} \ +| tee -a ${OUTPUT_DIR}/train_log.txt diff --git a/run_cmds/train_pnsgd2_base_flickr.sh b/run_cmds/train_pnsgd2_base_flickr.sh new file mode 100644 index 0000000..d5deb59 --- /dev/null +++ b/run_cmds/train_pnsgd2_base_flickr.sh @@ -0,0 +1,38 @@ +#!/usr/bin/env bash + +pwd + +export ROOT_DIR=. + +export NGPU=8 +export LR=1e-5 +export NSGD_SAMPLE_TEMPERATURE=1.5 +export NSGD_RANK_LAMBDA=5e-4 +export MLM_LAMBDA=0.1 +export DISC_LAMBDA=5e-4 +export CORRECTION_LAMBDA=5e-4 +export OUTPUT_DIR=${ROOT_DIR}/log/flickr/itm_pnsgd2_base_v1 +export CHECKPOINT=${ROOT_DIR}/log/flickr/itm_pnsgd_base_v1/ckpt/model_step_2000.pt + +export NEGATIVE_SIZE=399 +export HARD_NEG_SIZE=31 +export MLM_SAMPLE_SIZE=20 +export NSGD_SAMPLE_SIZE=20 +export TRAIN_BATCH_SIZE=32 +export MARGIN=0.2 + +export VALID_STEPS=150 +export NUM_TRAIN_STEPS=1500 +export WARMUP_STEPS=150 + +rm ${OUTPUT_DIR} -rf +ls -lh ${OUTPUT_DIR} +mkdir -p ${OUTPUT_DIR} + +horovodrun -np ${NGPU} python train_pnsgd2.py --config config/train-itm-pnsgd2-base-flickr.json \ + --output_dir ${OUTPUT_DIR} --learning_rate ${LR} --negative_size ${NEGATIVE_SIZE} \ + --hard_neg_size ${HARD_NEG_SIZE} --mlm_sample_size ${MLM_SAMPLE_SIZE} --nsgd_sample_size ${NSGD_SAMPLE_SIZE} \ + --nsgd_sample_temperature ${NSGD_SAMPLE_TEMPERATURE} --train_batch_size ${TRAIN_BATCH_SIZE} \ + --mlm_lambda ${MLM_LAMBDA} --nsgd_rank_lambda ${NSGD_RANK_LAMBDA} --margin ${MARGIN} \ + --disc_lambda ${DISC_LAMBDA} --correction_lambda ${CORRECTION_LAMBDA} --checkpoint ${CHECKPOINT} \ +| tee -a ${OUTPUT_DIR}/train_log.txt diff --git a/run_cmds/train_pnsgd2_large_coco.sh b/run_cmds/train_pnsgd2_large_coco.sh new file mode 100644 index 0000000..511e46b --- /dev/null +++ b/run_cmds/train_pnsgd2_large_coco.sh @@ -0,0 +1,38 @@ +#!/usr/bin/env bash + +pwd + +export ROOT_DIR=. 
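+# PNSGD2 fine-tuning on COCO (large config), initialised from the PNSGD
+# checkpoint set in CHECKPOINT below.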
+ +export NGPU=8 +export LR=1e-5 +export NSGD_SAMPLE_TEMPERATURE=1.5 +export NSGD_RANK_LAMBDA=0.0001 +export MLM_LAMBDA=0.05 +export DISC_LAMBDA=5e-4 # 05 +export CORRECTION_LAMBDA=5e-4 +export OUTPUT_DIR=${ROOT_DIR}/log/coco/itm_pnsgd2_large_v1 +export CHECKPOINT=${ROOT_DIR}/log/coco/itm_pnsgd_large_v1/ckpt/model_step_2000.pt + +export NEGATIVE_SIZE=399 +export HARD_NEG_SIZE=31 +export MLM_SAMPLE_SIZE=20 +export NSGD_SAMPLE_SIZE=20 +export TRAIN_BATCH_SIZE=32 +export MARGIN=0.2 + +export VALID_STEPS=150 +export NUM_TRAIN_STEPS=1500 +export WARMUP_STEPS=150 + +rm ${OUTPUT_DIR} -rf +ls -lh ${OUTPUT_DIR} +mkdir -p ${OUTPUT_DIR} + +horovodrun -np ${NGPU} python train_pnsgd2.py --config config/train-itm-pnsgd2-large-coco.json \ + --output_dir ${OUTPUT_DIR} --learning_rate ${LR} --negative_size ${NEGATIVE_SIZE} \ + --hard_neg_size ${HARD_NEG_SIZE} --mlm_sample_size ${MLM_SAMPLE_SIZE} --nsgd_sample_size ${NSGD_SAMPLE_SIZE} \ + --nsgd_sample_temperature ${NSGD_SAMPLE_TEMPERATURE} --train_batch_size ${TRAIN_BATCH_SIZE} \ + --mlm_lambda ${MLM_LAMBDA} --nsgd_rank_lambda ${NSGD_RANK_LAMBDA} --margin ${MARGIN} \ + --disc_lambda ${DISC_LAMBDA} --correction_lambda ${CORRECTION_LAMBDA} --checkpoint ${CHECKPOINT} \ +| tee -a ${OUTPUT_DIR}/train_log.txt diff --git a/run_cmds/train_pnsgd2_large_flickr.sh b/run_cmds/train_pnsgd2_large_flickr.sh new file mode 100644 index 0000000..41451fe --- /dev/null +++ b/run_cmds/train_pnsgd2_large_flickr.sh @@ -0,0 +1,39 @@ +#!/usr/bin/env bash + +pwd + +export ROOT_DIR=. + +export NGPU=8 +export LR=1e-5 +export NSGD_SAMPLE_TEMPERATURE=1.5 +export NSGD_RANK_LAMBDA=5e-4 +export MLM_LAMBDA=0.05 +export DISC_LAMBDA=5e-4 +export CORRECTION_LAMBDA=5e-4 + +export OUTPUT_DIR=${ROOT_DIR}/log/flickr/itm_pnsgd2_large_v1 +export CHECKPOINT=${ROOT_DIR}/log/flickr/itm_pnsgd_large_v1/ckpt/model_step_2000.pt + +export NEGATIVE_SIZE=399 +export HARD_NEG_SIZE=23 +export MLM_SAMPLE_SIZE=20 +export NSGD_SAMPLE_SIZE=20 +export TRAIN_BATCH_SIZE=32 +export MARGIN=0.2 + +export VALID_STEPS=500 +export NUM_TRAIN_STEPS=5000 +export WARMUP_STEPS=500 + +rm ${OUTPUT_DIR} -rf +ls -lh ${OUTPUT_DIR} +mkdir -p ${OUTPUT_DIR} + +horovodrun -np ${NGPU} python train_pnsgd2.py --config config/train-itm-pnsgd2-large-flickr.json \ + --output_dir ${OUTPUT_DIR} --learning_rate ${LR} --negative_size ${NEGATIVE_SIZE} \ + --hard_neg_size ${HARD_NEG_SIZE} --mlm_sample_size ${MLM_SAMPLE_SIZE} --nsgd_sample_size ${NSGD_SAMPLE_SIZE} \ + --nsgd_sample_temperature ${NSGD_SAMPLE_TEMPERATURE} --train_batch_size ${TRAIN_BATCH_SIZE} \ + --mlm_lambda ${MLM_LAMBDA} --nsgd_rank_lambda ${NSGD_RANK_LAMBDA} --margin ${MARGIN} \ + --disc_lambda ${DISC_LAMBDA} --correction_lambda ${CORRECTION_LAMBDA} --checkpoint ${CHECKPOINT} \ +| tee -a ${OUTPUT_DIR}/train_log.txt diff --git a/run_cmds/train_pnsgd_base_coco.sh b/run_cmds/train_pnsgd_base_coco.sh new file mode 100644 index 0000000..82dd7b5 --- /dev/null +++ b/run_cmds/train_pnsgd_base_coco.sh @@ -0,0 +1,34 @@ +#!/usr/bin/env bash + +pwd + +export ROOT_DIR=. 
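+# PNSGD fine-tuning on COCO (base config). NEGATIVE_SIZE sampled negatives
+# are scored per positive pair and the HARD_NEG_SIZE hardest are kept for
+# the margin-based ranking loss (values set in the exports below).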
+ +export NGPU=8 +export LR=5e-5 +export NSGD_SAMPLE_TEMPERATURE=1.5 +export NSGD_RANK_LAMBDA=5e-5 +export MLM_LAMBDA=0.01 +export OUTPUT_DIR=${ROOT_DIR}/log/coco/itm_pnsgd_base_v1 + +export NEGATIVE_SIZE=399 +export HARD_NEG_SIZE=31 +export MLM_SAMPLE_SIZE=20 +export NSGD_SAMPLE_SIZE=20 +export TRAIN_BATCH_SIZE=32 +export MARGIN=0.2 + +export VALID_STEPS=500 +export NUM_TRAIN_STEPS=5000 +export WARMUP_STEPS=500 + +rm ${OUTPUT_DIR} -rf +ls -lh ${OUTPUT_DIR} +mkdir -p ${OUTPUT_DIR} + +horovodrun -np ${NGPU} python train_pnsgd.py --config config/train-itm-pnsgd-base-coco.json \ + --output_dir ${OUTPUT_DIR} --learning_rate ${LR} --negative_size ${NEGATIVE_SIZE} \ + --hard_neg_size ${HARD_NEG_SIZE} --mlm_sample_size ${MLM_SAMPLE_SIZE} --nsgd_sample_size ${NSGD_SAMPLE_SIZE} \ + --nsgd_sample_temperature ${NSGD_SAMPLE_TEMPERATURE} --train_batch_size ${TRAIN_BATCH_SIZE} \ + --mlm_lambda ${MLM_LAMBDA} --nsgd_rank_lambda ${NSGD_RANK_LAMBDA} --margin ${MARGIN} \ +| tee -a ${OUTPUT_DIR}/train_log.txt diff --git a/run_cmds/train_pnsgd_base_flickr.sh b/run_cmds/train_pnsgd_base_flickr.sh new file mode 100644 index 0000000..77faa17 --- /dev/null +++ b/run_cmds/train_pnsgd_base_flickr.sh @@ -0,0 +1,34 @@ +#!/usr/bin/env bash + +pwd + +export ROOT_DIR=. + +export NGPU=8 +export LR=5e-5 +export NSGD_SAMPLE_TEMPERATURE=2.0 +export NSGD_RANK_LAMBDA=0.0005 +export MLM_LAMBDA=0.05 +export OUTPUT_DIR=${ROOT_DIR}/log/flickr/itm_pnsgd_base_v1 + +export NEGATIVE_SIZE=399 +export HARD_NEG_SIZE=31 +export MLM_SAMPLE_SIZE=20 +export NSGD_SAMPLE_SIZE=20 +export TRAIN_BATCH_SIZE=16 +export MARGIN=0.2 + +export VALID_STEPS=500 +export NUM_TRAIN_STEPS=5000 +export WARMUP_STEPS=500 + +rm ${OUTPUT_DIR} -rf +ls -lh ${OUTPUT_DIR} +mkdir -p ${OUTPUT_DIR} + +horovodrun -np ${NGPU} python train_pnsgd.py --config config/train-itm-pnsgd-base-flickr.json \ + --output_dir ${OUTPUT_DIR} --learning_rate ${LR} --negative_size ${NEGATIVE_SIZE} \ + --hard_neg_size ${HARD_NEG_SIZE} --mlm_sample_size ${MLM_SAMPLE_SIZE} --nsgd_sample_size ${NSGD_SAMPLE_SIZE} \ + --nsgd_sample_temperature ${NSGD_SAMPLE_TEMPERATURE} --train_batch_size ${TRAIN_BATCH_SIZE} \ + --mlm_lambda ${MLM_LAMBDA} --nsgd_rank_lambda ${NSGD_RANK_LAMBDA} --margin ${MARGIN} \ +| tee -a ${OUTPUT_DIR}/train_log.txt diff --git a/run_cmds/train_pnsgd_large_coco.sh b/run_cmds/train_pnsgd_large_coco.sh new file mode 100644 index 0000000..db9b4ba --- /dev/null +++ b/run_cmds/train_pnsgd_large_coco.sh @@ -0,0 +1,35 @@ +#!/usr/bin/env bash + +pwd + +export ROOT_DIR=. 
+ +export NGPU=8 +export LR=4e-5 +export NSGD_SAMPLE_TEMPERATURE=1.5 +export NSGD_RANK_LAMBDA=1e-3 +export MLM_LAMBDA=0.05 +export OUTPUT_DIR=${ROOT_DIR}/log/coco/itm_pnsgd_large_v1 + +export NEGATIVE_SIZE=399 +export HARD_NEG_SIZE=23 +export MLM_SAMPLE_SIZE=20 +export NSGD_SAMPLE_SIZE=20 +export TRAIN_BATCH_SIZE=32 +export MARGIN=0.2 + + +export VALID_STEPS=500 +export NUM_TRAIN_STEPS=5000 +export WARMUP_STEPS=500 + +rm ${OUTPUT_DIR} -rf +ls -lh ${OUTPUT_DIR} +mkdir -p ${OUTPUT_DIR} + +horovodrun -np ${NGPU} python train_pnsgd.py --config config/train-itm-pnsgd-large-coco.json \ + --output_dir ${OUTPUT_DIR} --learning_rate ${LR} --negative_size ${NEGATIVE_SIZE} \ + --hard_neg_size ${HARD_NEG_SIZE} --mlm_sample_size ${MLM_SAMPLE_SIZE} --nsgd_sample_size ${NSGD_SAMPLE_SIZE} \ + --nsgd_sample_temperature ${NSGD_SAMPLE_TEMPERATURE} --train_batch_size ${TRAIN_BATCH_SIZE} \ + --mlm_lambda ${MLM_LAMBDA} --nsgd_rank_lambda ${NSGD_RANK_LAMBDA} --margin ${MARGIN} \ +| tee -a ${OUTPUT_DIR}/train_log.txt diff --git a/run_cmds/train_pnsgd_large_flickr.sh b/run_cmds/train_pnsgd_large_flickr.sh new file mode 100644 index 0000000..7c913a9 --- /dev/null +++ b/run_cmds/train_pnsgd_large_flickr.sh @@ -0,0 +1,34 @@ +#!/usr/bin/env bash + +pwd + +export ROOT_DIR=. + +export NGPU=8 +export LR=5e-5 +export NSGD_SAMPLE_TEMPERATURE=2.0 +export NSGD_RANK_LAMBDA=0.001 +export MLM_LAMBDA=0.05 +export OUTPUT_DIR=${ROOT_DIR}/log/flickr/itm_pnsgd_large_v1 + +export NEGATIVE_SIZE=399 +export HARD_NEG_SIZE=23 +export MLM_SAMPLE_SIZE=20 +export NSGD_SAMPLE_SIZE=20 +export TRAIN_BATCH_SIZE=32 +export MARGIN=0.2 + +export VALID_STEPS=500 +export NUM_TRAIN_STEPS=5000 +export WARMUP_STEPS=500 + +rm ${OUTPUT_DIR} -rf +ls -lh ${OUTPUT_DIR} +mkdir -p ${OUTPUT_DIR} + +horovodrun -np ${NGPU} python train_pnsgd.py --config config/train-itm-pnsgd-large-flickr.json \ + --output_dir ${OUTPUT_DIR} --learning_rate ${LR} --negative_size ${NEGATIVE_SIZE} \ + --hard_neg_size ${HARD_NEG_SIZE} --mlm_sample_size ${MLM_SAMPLE_SIZE} --nsgd_sample_size ${NSGD_SAMPLE_SIZE} \ + --nsgd_sample_temperature ${NSGD_SAMPLE_TEMPERATURE} --train_batch_size ${TRAIN_BATCH_SIZE} \ + --mlm_lambda ${MLM_LAMBDA} --nsgd_rank_lambda ${NSGD_RANK_LAMBDA} --margin ${MARGIN} \ +| tee -a ${OUTPUT_DIR}/train_log.txt diff --git a/scripts/convert_ckpt.py b/scripts/convert_ckpt.py new file mode 100644 index 0000000..9bb8a4e --- /dev/null +++ b/scripts/convert_ckpt.py @@ -0,0 +1,13 @@ +import sys +from collections import OrderedDict + +import torch + +bert_ckpt, output_ckpt = sys.argv[1:] + +bert = torch.load(bert_ckpt) +uniter = OrderedDict() +for k, v in bert.items(): + uniter[k.replace('bert', 'uniter')] = v + +torch.save(uniter, output_ckpt) diff --git a/scripts/convert_imgdir.py b/scripts/convert_imgdir.py new file mode 100644 index 0000000..ca4440c --- /dev/null +++ b/scripts/convert_imgdir.py @@ -0,0 +1,142 @@ +""" +Copyright (c) Microsoft Corporation. +Licensed under the MIT license. 
+ +convert image npz to LMDB +""" +import argparse +import glob +import io +import json +import multiprocessing as mp +import os +from os.path import basename, exists + +from cytoolz import curry +import numpy as np +from tqdm import tqdm +import lmdb + +import msgpack +import msgpack_numpy +msgpack_numpy.patch() + + +def _compute_nbb(img_dump, conf_th, max_bb, min_bb, num_bb): + num_bb = max(min_bb, (img_dump['conf'] > conf_th).sum()) + num_bb = min(max_bb, num_bb) + return int(num_bb) + + +@curry +def load_npz(conf_th, max_bb, min_bb, num_bb, fname, keep_all=False): + try: + img_dump = np.load(fname, allow_pickle=True) + if keep_all: + nbb = None + else: + nbb = _compute_nbb(img_dump, conf_th, max_bb, min_bb, num_bb) + dump = {} + for key, arr in img_dump.items(): + if arr.dtype == np.float32: + arr = arr.astype(np.float16) + if arr.ndim == 2: + dump[key] = arr[:nbb, :] + elif arr.ndim == 1: + dump[key] = arr[:nbb] + else: + raise ValueError('wrong ndim') + except Exception as e: + # corrupted file + print(f'corrupted file {fname}', e) + dump = {} + nbb = 0 + + name = basename(fname) + return name, dump, nbb + + +def dumps_npz(dump, compress=False): + with io.BytesIO() as writer: + if compress: + np.savez_compressed(writer, **dump, allow_pickle=True) + else: + np.savez(writer, **dump, allow_pickle=True) + return writer.getvalue() + + +def dumps_msgpack(dump): + return msgpack.dumps(dump, use_bin_type=True) + + +def main(opts): + if opts.img_dir[-1] == '/': + opts.img_dir = opts.img_dir[:-1] + split = basename(opts.img_dir) + if opts.keep_all: + db_name = 'all' + else: + if opts.conf_th == -1: + db_name = f'feat_numbb{opts.num_bb}' + else: + db_name = (f'feat_th{opts.conf_th}_max{opts.max_bb}' + f'_min{opts.min_bb}') + if opts.compress: + db_name += '_compressed' + if not exists(f'{opts.output}/{split}'): + os.makedirs(f'{opts.output}/{split}') + env = lmdb.open(f'{opts.output}/{split}/{db_name}', map_size=1024**4) + txn = env.begin(write=True) + files = glob.glob(f'{opts.img_dir}/*.npz') + load = load_npz(opts.conf_th, opts.max_bb, opts.min_bb, opts.num_bb, + keep_all=opts.keep_all) + name2nbb = {} + with mp.Pool(opts.nproc) as pool, tqdm(total=len(files)) as pbar: + for i, (fname, features, nbb) in enumerate( + pool.imap_unordered(load, files, chunksize=128)): + if not features: + continue # corrupted feature + if opts.compress: + dump = dumps_npz(features, compress=True) + else: + dump = dumps_msgpack(features) + txn.put(key=fname.encode('utf-8'), value=dump) + if i % 1000 == 0: + txn.commit() + txn = env.begin(write=True) + name2nbb[fname] = nbb + pbar.update(1) + txn.put(key=b'__keys__', + value=json.dumps(list(name2nbb.keys())).encode('utf-8')) + txn.commit() + env.close() + if opts.conf_th != -1 and not opts.keep_all: + with open(f'{opts.output}/{split}/' + f'nbb_th{opts.conf_th}_' + f'max{opts.max_bb}_min{opts.min_bb}.json', 'w') as f: + json.dump(name2nbb, f) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument("--img_dir", default=None, type=str, + help="The input images.") + parser.add_argument("--output", default=None, type=str, + help="output lmdb") + parser.add_argument('--nproc', type=int, default=8, + help='number of cores used') + parser.add_argument('--compress', action='store_true', + help='compress the tensors') + parser.add_argument('--keep_all', action='store_true', + help='keep all features, overrides all following args') + parser.add_argument('--conf_th', type=float, default=0.2, + help='threshold for dynamic bounding boxes ' + 
'(-1 for fixed)') + parser.add_argument('--max_bb', type=int, default=100, + help='max number of bounding boxes') + parser.add_argument('--min_bb', type=int, default=10, + help='min number of bounding boxes') + parser.add_argument('--num_bb', type=int, default=100, + help='number of bounding boxes (fixed)') + args = parser.parse_args() + main(args) diff --git a/scripts/create_imgdb.sh b/scripts/create_imgdb.sh new file mode 100644 index 0000000..21264bf --- /dev/null +++ b/scripts/create_imgdb.sh @@ -0,0 +1,21 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +IMG_NPY=$1 +OUT_DIR=$2 + +set -e + +echo "converting image features ..." +if [ ! -d $OUT_DIR ]; then + mkdir -p $OUT_DIR +fi +NAME=$(basename $IMG_NPY) +docker run --ipc=host --rm -it \ + --mount src=$(pwd),dst=/src,type=bind \ + --mount src=$OUT_DIR,dst=/img_db,type=bind \ + --mount src=$IMG_NPY,dst=/$NAME,type=bind,readonly \ + -w /src chenrocks/uniter \ + python scripts/convert_imgdir.py --img_dir /$NAME --output /img_db + +echo "done" diff --git a/scripts/create_txtdb.sh b/scripts/create_txtdb.sh new file mode 100644 index 0000000..6789ef3 --- /dev/null +++ b/scripts/create_txtdb.sh @@ -0,0 +1,40 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +OUT_DIR=$1 +ANN_DIR=$2 + +set -e + +URL='https://raw.githubusercontent.com/lil-lab/nlvr/master/nlvr2/data' +if [ ! -d $OUT_DIR ]; then + mkdir -p $OUT_DIR +fi +if [ ! -d $ANN_DIR ]; then + mkdir -p $ANN_DIR +fi + +BLOB='https://convaisharables.blob.core.windows.net/uniter' +MISSING=$BLOB/ann/missing_nlvr2_imgs.json +if [ ! -f $ANN_DIR/missing_nlvr2_imgs.json ]; then + wget $MISSING -O $ANN_DIR/missing_nlvr2_imgs.json +fi + +for SPLIT in 'train' 'dev' 'test1'; do + if [ ! -f $ANN_DIR/$SPLIT.json ]; then + echo "downloading ${SPLIT} annotations..." + wget $URL/$SPLIT.json -O $ANN_DIR/$SPLIT.json + fi + + echo "preprocessing ${SPLIT} annotations..." + docker run --ipc=host --rm -it \ + --mount src=$(pwd),dst=/src,type=bind \ + --mount src=$OUT_DIR,dst=/txt_db,type=bind \ + --mount src=$ANN_DIR,dst=/ann,type=bind,readonly \ + -w /src chenrocks/uniter \ + python prepro.py --annotation /ann/$SPLIT.json \ + --missing_imgs /ann/missing_nlvr2_imgs.json \ + --output /txt_db/nlvr2_${SPLIT}.db +done + +echo "done" diff --git a/scripts/download_itm.sh b/scripts/download_itm.sh new file mode 100644 index 0000000..77ac0ea --- /dev/null +++ b/scripts/download_itm.sh @@ -0,0 +1,39 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +DOWNLOAD=$1 + +for FOLDER in 'img_db' 'txt_db' 'pretrained' 'finetune'; do + if [ ! -d $DOWNLOAD/$FOLDER ] ; then + mkdir -p $DOWNLOAD/$FOLDER + fi +done + +BLOB='https://convaisharables.blob.core.windows.net/uniter' + +# image dbs +for SPLIT in 'train2014' 'val2014'; do + if [ ! -d $DOWNLOAD/img_db/coco_$SPLIT ] ; then + wget $BLOB/img_db/coco_$SPLIT.tar -P $DOWNLOAD/img_db/ + tar -xvf $DOWNLOAD/img_db/coco_$SPLIT.tar -C $DOWNLOAD/img_db + fi +done +if [ ! 
-d $DOWNLOAD/img_db/flickr30k ] ; then + wget $BLOB/img_db/flickr30k.tar -P $DOWNLOAD/img_db/ + tar -xvf $DOWNLOAD/img_db/flickr30k.tar -C $DOWNLOAD/img_db +fi + +# text dbs +for SPLIT in 'train' 'restval' 'val' 'test'; do + wget $BLOB/txt_db/itm_coco_$SPLIT.db.tar -P $DOWNLOAD/txt_db/ + tar -xvf $DOWNLOAD/txt_db/itm_coco_$SPLIT.db.tar -C $DOWNLOAD/txt_db +done +for SPLIT in 'train' 'val' 'test'; do + wget $BLOB/txt_db/itm_flickr30k_$SPLIT.db.tar -P $DOWNLOAD/txt_db/ + tar -xvf $DOWNLOAD/txt_db/itm_flickr30k_$SPLIT.db.tar -C $DOWNLOAD/txt_db +done + +if [ ! -f $DOWNLOAD/pretrained/uniter-base.pt ] ; then + wget $BLOB/pretrained/uniter-base.pt -P $DOWNLOAD/pretrained/ +fi + diff --git a/scripts/eval_nlvr2.py b/scripts/eval_nlvr2.py new file mode 100644 index 0000000..2e48569 --- /dev/null +++ b/scripts/eval_nlvr2.py @@ -0,0 +1,64 @@ +""" +copied from official NLVR2 github +(https://github.com/lil-lab/nlvr/tree/master/nlvr2) + +python scripts/eval_nlvr2.py +""" +import json +import sys + +# Load the predictions file. Assume it is a CSV. +predictions = { } +for line in open(sys.argv[1]).readlines(): + if line: + splits = line.strip().split(",") + # We assume identifiers are in the format "split-####-#-#.png". + identifier = splits[0] + prediction = splits[1] + predictions[identifier] = prediction + +# Load the labeled examples. +labeled_examples = [json.loads(line) for line in open(sys.argv[2]).readlines() if line] + +# If not, identify the ones that are missing, and exit. +total_num = len(labeled_examples) +if len(predictions) < total_num: + print("Some predictions are missing!") + print("Got " + str(len(predictions)) + " predictions but expected " + str(total_num)) + + for example in labeled_examples: + lookup = example["identifier"] + if not lookup in predictions: + print("Missing prediction for item " + str(lookup)) + exit() + +# Get the precision by iterating through the examples and checking the value +# that was predicted. +# Also update the "consistency" dictionary that keeps track of whether all +# predictions for a given sentence were correct. +num_correct = 0. +consistency_dict = { } + +for example in labeled_examples: + anon_label = example["identifier"].split("-") + anon_label[2] = '' + anon_label = '-'.join(anon_label) + if not anon_label in consistency_dict: + consistency_dict[anon_label] = True + lookup = example["identifier"] + prediction = predictions[lookup] + if prediction.lower() == example["label"].lower(): + num_correct += 1. + else: + consistency_dict[anon_label] = False + +# Calculate consistency. +num_consistent = 0. +unique_sentence = len(consistency_dict) +for identifier, consistent in consistency_dict.items(): + if consistent: + num_consistent += 1 + +# Report values. +print("accuracy=" + str(num_correct / total_num)) +print("consistency=" + str(num_consistent / unique_sentence)) diff --git a/scripts/eval_zs_itm_flickr.sh b/scripts/eval_zs_itm_flickr.sh new file mode 100644 index 0000000..572beae --- /dev/null +++ b/scripts/eval_zs_itm_flickr.sh @@ -0,0 +1,5 @@ +#!/usr/bin/env bash + +horovodrun -np $NGPU python inf_itm.py \ + --txt_db ${TXT_DB} --img_db ${IMG_DB} --checkpoint ${PRETRAINED_MODEL} \ + --model_config ${MODEL_CONFIG} --output_dir $ZS_ITM_RESULT --fp16 --pin_mem \ No newline at end of file diff --git a/scripts/extract_imgfeat.sh b/scripts/extract_imgfeat.sh new file mode 100644 index 0000000..849589b --- /dev/null +++ b/scripts/extract_imgfeat.sh @@ -0,0 +1,19 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. 
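+# Usage: bash scripts/extract_imgfeat.sh <IMG_DIR> <OUT_DIR>
+# Runs bottom-up feature extraction inside the chenrocks/butd-caffe:nlvr2 container; the GPU used is selected through CUDA_VISIBLE_DEVICES.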
+ +IMG_DIR=$1 +OUT_DIR=$2 + +set -e + +echo "extracting image features..." +if [ ! -d $OUT_DIR ]; then + mkdir -p $OUT_DIR +fi +docker run --gpus '"'device=$CUDA_VISIBLE_DEVICES'"' --ipc=host --rm \ + --mount src=$IMG_DIR,dst=/img,type=bind,readonly \ + --mount src=$OUT_DIR,dst=/output,type=bind \ + -w /src chenrocks/butd-caffe:nlvr2 \ + bash -c "python tools/generate_npz.py --gpu 0" + +echo "done" diff --git a/scripts/train_itm.sh b/scripts/train_itm.sh new file mode 100644 index 0000000..4aa70b8 --- /dev/null +++ b/scripts/train_itm.sh @@ -0,0 +1,15 @@ +#!/usr/bin/env bash + +ngpu=1 +db_root=${db_root} # ../itm-data +txt_dir=${db_root}/text_db +img_dir=${db_root}/img_db +model_dir=${db_root}/pretrained +config_dir=. +zs_itm_result=./log +mkdir -p ${zs_itm_result} + +horovodrun -np $ngpu python inf_itm.py \ + --txt_db ${txt_dir}/itm_flickr30k_test.db --img_db ${img_dir}/img/flickr30k \ + --checkpoint ${model_dir}/uniter-base.pt --model_config ${config_dir}/config/uniter-base.json \ + --output_dir ${zs_itm_result} --fp16 --pin_mem \ No newline at end of file diff --git a/setup_container.sh b/setup_container.sh new file mode 100644 index 0000000..4b3245c --- /dev/null +++ b/setup_container.sh @@ -0,0 +1,27 @@ +#!/usr/bin/env bash + +sudo apt-get remove docker docker-engine docker.io containerd runc + +sudo apt-get update + +sudo apt-get install apt-transport-https ca-certificates curl gnupg + +curl -fsSL https://download.docker.com/linux/ubuntu/gpg | sudo gpg --dearmor -o /usr/share/keyrings/docker-archive-keyring.gpg + +echo "deb [arch=amd64 signed-by=/usr/share/keyrings/docker-archive-keyring.gpg] https://download.docker.com/linux/ubuntu \ + $(lsb_release -cs) stable" | sudo tee /etc/apt/sources.list.d/docker.list > /dev/null + +sudo apt-get update + +sudo apt-get install docker-ce docker-ce-cli containerd.io + +apt-cache madison docker-ce + +sudo apt-get install docker-ce=5:19.03.12~3-0~ubuntu-bionic docker-ce-cli=5:19.03.12~3-0~ubuntu-bionic containerd.io + +sudo docker run hello-world + + +#INSTALL_DIR=~ +#source launch_container.sh ${INSTALL_DIR}/txt_db ${INSTALL_DIR}/img_db \ +# ${INSTALL_DIR}/finetune ${INSTALL_DIR}/pretrained diff --git a/support/case_study.py b/support/case_study.py new file mode 100644 index 0000000..064b7a2 --- /dev/null +++ b/support/case_study.py @@ -0,0 +1,49 @@ +import os +import json +import numpy as np + + +def main(): + log_dir = './log/adv_dc_visual2' + for filename in os.listdir(log_dir): + cont = False + for cand_imgname in [202175131, 539676201, 4373983146]: + if str(cand_imgname) in filename: + cont = True + if not cont: + continue + log = json.load(open(os.path.join(log_dir, filename), 'r')) + print('='*100) + print(log) + # print('='*100) + # print(log['idx'], log['imgname']) + # print("| gt_text: ", log['gt_text']) + gt_text = log['gt_text'] + syn_texts = log['syn_texts'] + disc_ps = log['disc_p'] + correction_texts = log['correction_texts'] + syn_scores = log['syn_scores'] + + i = 0 + for syn_score, syn_text, disc_p, correction_text in zip(syn_scores, syn_texts, disc_ps, correction_texts): + syn_idt = (np.array(syn_text) == np.array(gt_text)).astype(np.int64) + new_correction_text = [] + for idt, gt_token, correction_token in zip(syn_idt, gt_text, correction_text): + if idt == 1: + new_correction_text.append(gt_token) + else: + new_correction_text.append(correction_token) + correction_text = new_correction_text + # correction_text = (syn_idt * gt_text) + (correction_text * (1-syn_idt)) + + correction_idt = (np.array(correction_text) == 
np.array(gt_text)).astype(np.int64) + # print(syn_idt, correction_idt) + if correction_idt.sum() >= syn_idt.sum() + 2: # and (1 - correction_idt).sum() == 0: + print(i, syn_score, log['imgname']) + print(list(zip(gt_text, syn_text, correction_text, disc_p))) + + i += 1 + + +if __name__ == '__main__': + main() diff --git a/support/plot.py b/support/plot.py new file mode 100644 index 0000000..3242e4d --- /dev/null +++ b/support/plot.py @@ -0,0 +1,118 @@ +import os +import json +import math +import numpy as np +import matplotlib.pyplot as plt + +plt.style.use('ggplot') +np.random.seed(19680801) + + +def main(): + data_dir = "/Users/fanzhihao/Documents/Research/NIPS2021" + kwargs = dict(alpha=0.2, bins=2000, density=True) + + max_d = 4000.0 + data0 = [math.exp(x) for x in json.load(open(os.path.join(data_dir, 'dataset_{}.json'.format(0)), 'r')) if math.exp(x) < max_d] + print('data0', len(data0)) + plt.hist(data0, color='g', **kwargs) + data1 = [math.exp(x) for x in json.load(open(os.path.join(data_dir, 'dataset_{}.json'.format(1)), 'r')) if math.exp(x) < max_d] + print('data1', len(data1)) + plt.hist(data1, color='b', **kwargs) + data2 = [math.exp(x) for x in json.load(open(os.path.join(data_dir, 'dataset_{}.json'.format(2)), 'r')) if math.exp(x) < max_d] + print('data2', len(data2)) + plt.hist(data2, color='r', **kwargs) + # plt.gca().set(title='Frequency Histogram of Diamond Depths', ylabel='Frequency') + plt.xlim(0, 250) + plt.show() + + +def main2(): + data_dir = "/Users/fanzhihao/Documents/Research/NIPS2021" + import seaborn as sns + # white, dark, whitegrid, darkgrid, ticks + + sns.set_style("darkgrid") + # sns.set_style("ticks") + + # Import data + # df = pd.read_csv('https://raw.githubusercontent.com/selva86/datasets/master/diamonds.csv') + # x1 = df.loc[df.cut == 'Ideal', 'depth'] + # x2 = df.loc[df.cut == 'Fair', 'depth'] + # x3 = df.loc[df.cut == 'Good', 'depth'] + + # Plot + # kwargs = dict(hist_kws={'alpha': .6}, kde_kws={'linewidth': 2}) + # + # plt.figure(figsize=(10, 7), dpi=80) + # sns.distplot(x1, color="dodgerblue", label="Compact", **kwargs) + # sns.distplot(x2, color="orange", label="SUV", **kwargs) + # sns.distplot(x3, color="deeppink", label="minivan", **kwargs) + # plt.xlim(50, 75) + # plt.legend(); + + kwargs = dict(alpha=0.8, bins=2000) + + # plt.rcParams['figure.figsize'] = (40.0, 8.0) + max_d = 1500.0 + + kwargs = dict(hist_kws={'alpha': .2}, kde_kws={'linewidth': 2, 'shade': True}, bins=5000) + + plt.figure(figsize=(9.2, 6), dpi=80) + # data0 = [math.exp(x) for x in json.load(open(os.path.join(data_dir, 'dataset_{}.json'.format(0)), 'r')) if math.exp(x) < max_d] + data0 = [x for x in json.load(open(os.path.join(data_dir, 'dataset_{}.json'.format(0)), 'r')) if x < max_d] + print('data0', len(data0)) + line0 = sns.distplot(data0, **kwargs, color='g', hist=False, kde=True, label="Positive Text") + # data1 = [math.exp(x)+20 for x in json.load(open(os.path.join(data_dir, 'dataset_{}.json'.format(1)), 'r')) if math.exp(x) < max_d] + data1 = [x for x in json.load(open(os.path.join(data_dir, 'dataset_{}.json'.format(1)), 'r')) if x < max_d] + print('data1', len(data1)) + line1 = sns.distplot(data1, **kwargs, color='b', hist=False, kde=True, label="Synthetic Text") + # data2 = [math.exp(x)+10 for x in json.load(open(os.path.join(data_dir, 'dataset_{}.json'.format(2)), 'r')) if math.exp(x) < max_d] + data2 = [x for x in json.load(open(os.path.join(data_dir, 'dataset_{}.json'.format(2)), 'r')) if x < max_d] + line2 = sns.distplot(data2, **kwargs, color='r', hist=False, 
kde=True, label="Corrected Text") + plt.xlim(0, 300) + + plt.legend(loc="upper right", fontsize=20) + + # plt.xticks(np.arange(-5, 5, 0.5), fontproperties='Times New Roman', size=10) + # plt.yticks(np.arange(-2, 2, 0.3), fontproperties='Times New Roman', size=10) + + plt.xlabel('', fontdict={'family': 'Times New Roman', 'size': 20}) + plt.ylabel('', fontdict={'family': 'Times New Roman', 'size': 20}) + + plt.yticks(np.arange(0.0, 0.012, 0.002), fontproperties='Times New Roman', size=20) + plt.xticks(np.arange(0, 300, 50), fontproperties='Times New Roman', size=20) + plt.legend(prop={'family': 'Times New Roman', 'size': 20}) + # plt.rcParams['xtick.direction'] = 'out' + # plt.rcParams['ytick.direction'] = 'out' + # plt.bar([0, 50, 100, 150, 200, 250], [0.0001, 0.0001, 0.0001, 0.0001, 0.0001, 0.0001]) + + # plt.minorticks_on() + # plt.tick_params(which='major', width=4, length=10, direction="inout") + + plt.show() + + # plt.gca().set(title='Frequency Histogram of Diamond Depths', ylabel='Frequency') + # plt.show() + # plt.legend() + + +def main3(): + import matplotlib.pyplot as plt + import numpy as np + from scipy.stats import gaussian_kde + + data_dir = "/Users/fanzhihao/Documents/Research/NIPS2021" + max_d = 4000.0 + data0 = [math.exp(x) for x in json.load(open(os.path.join(data_dir, 'dataset_{}.json'.format(0)), 'r')) if math.exp(x) < max_d] + + density = gaussian_kde(data0) + xs = np.linspace(0, 8, 200) + density.covariance_factor = lambda: .25 + density._compute_covariance() + plt.plot(xs, density(xs)) + plt.show() + + +if __name__ == '__main__': + main2() diff --git a/support/ppl_solver.py b/support/ppl_solver.py new file mode 100644 index 0000000..32bf5cd --- /dev/null +++ b/support/ppl_solver.py @@ -0,0 +1,140 @@ +import os +import json + +import torch +from transformers import GPT2LMHeadModel, GPT2TokenizerFast +# from nlp import load_dataset +from tqdm import tqdm +import numpy as np +import string + + +def load_data(): + puncs = list(set(list(string.punctuation)) - set([',', '.', '-', '?', '!', '\''])) + puncs = [punc + ' ' for punc in puncs] + + def bpe2token(bpes): + text = ' '.join(bpes).replace(' ##', '').replace(" ,", ',').replace(" .", ".").replace(' - ', '-').\ + replace(' ?', '?').replace(' !', '!').replace(' \' ', '\'') + for punc in puncs: + text = text.replace(punc, ' ') + text = text.strip() + return text + + log_dir = './log/adv_dc_visual2' + dataset_gt_texts, dataset_syn_texts, dataset_correct_texts = [], [], [] + + max_num = 1 + for filename in tqdm(os.listdir(log_dir)): + log = json.load(open(os.path.join(log_dir, filename), 'r')) + gt_text = log['gt_text'][1:-1] + syn_texts = [text[1:-1] for text in log['syn_texts'][:max_num]] + correct_texts = [text[1:-1] for text in log['correction_texts'][:max_num]] + # print("| gt_text: ", gt_text) + # print("| syn_text: ", syn_texts[0]) + # print("| correct_text: ", correct_texts[0]) + dataset_gt_texts.append(bpe2token(gt_text)) + dataset_syn_texts.extend([bpe2token(syn_text) for syn_text in syn_texts]) + dataset_correct_texts.extend([bpe2token(correct_text) for correct_text in correct_texts]) + return dataset_gt_texts, dataset_syn_texts, dataset_correct_texts + + +def main(): + # device = 'cuda' + # from transformers import GPT2Tokenizer, GPT2LMHeadModel + # # Load pre-trained model (weights) + # with torch.no_grad(): + # model = GPT2LMHeadModel.from_pretrained('gpt2').to(device) + # model.eval() + # # Load pre-trained model tokenizer (vocabulary) + # tokenizer = GPT2Tokenizer.from_pretrained('gpt2') + + device = 'cuda' + 
model_id = 'gpt2-large' + model = GPT2LMHeadModel.from_pretrained(model_id).to(device) + model.eval() + tokenizer = GPT2TokenizerFast.from_pretrained(model_id) + + # text = "From large scale power generators to the basic cooking in our homes, fuel is essential for all of these to happen and work." + # text = "A woman plays with finger puppets as a small child in a costume walks by." + # # "From large scale power generators to the basic cooking in our homes, fuel is essential for all of these to happen and work." + # input_ids = torch.tensor(tokenizer([text])["input_ids"]).to(device) + # target_ids = input_ids.clone() + # with torch.no_grad(): + # log_likelihood = model(input_ids, labels=target_ids)[0] + # ppl = torch.exp(log_likelihood) + # print(np.exp(log_likelihood.data.cpu().numpy())) + # print('1: ', log_likelihood, ppl) + # + # raise Exception + + # test = load_dataset('wikitext', 'wikitext-2-raw-v1', split='test') + # encodings = tokenizer('\n\n'.join(test['text']), return_tensors='pt') + + # max_length = model.config.n_positions + # stride = 512 + + dataset_gt_texts, dataset_syn_texts, dataset_correct_texts = load_data() + for i, dataset_texts in enumerate([dataset_gt_texts, dataset_syn_texts, dataset_correct_texts]): + lls = [] + end_loc = 0 + lls2 = [] + end_loc2 = 0 + lls_save = [] + ppls = [] + for text in tqdm(dataset_texts): + # print(" text: ", text, type(text)) + # print('| text: ', text[0], text[1], text[-1]) + input_ids = torch.tensor(tokenizer([text])["input_ids"]).to(device) + # print('| ', text, input_ids) + trg_len = input_ids.size(-1) + end_loc += 1 + end_loc2 += trg_len + # print(type(input_ids)) + # print(input_ids) + # print('| input_ids: ', input_ids.size(), input_ids) + input_ids = input_ids + target_ids = input_ids.clone() + # print('| input_ids', input_ids.size(), target_ids.size()) + # print(input_ids.size(), target_ids.size()) + # target_ids[-1] = -100 + with torch.no_grad(): + outputs = model(input_ids, labels=target_ids) + ll = outputs[0].item() + log_likelihood = outputs[0] + ppl = torch.exp(log_likelihood) + if ppl > 200.0: + print('| large ppl: ', text) + # print(ll, ppl) + + ppls.append(ppl) + # print('| log_likelihood: ', log_likelihood, ppl) + lls.append(log_likelihood) + lls2.append(log_likelihood * trg_len) + lls_save.append(ppl.item()) + + ppl_mean = torch.stack(ppls).mean() + ppl = torch.exp(torch.stack(lls).sum() / end_loc) + ppl2 = torch.exp(torch.stack(lls2).sum() / end_loc2) + print('| ppl: ', ppl_mean.item(), ppl.item(), ppl2.item()) + + save_path = os.path.join('log', 'dataset_{}.json'.format(i)) + with open(save_path, 'w') as fw: + json.dump(lls_save, fw) + fw.close() + print('ppl: ', ppl) + + # for i in tqdm(range(0, encodings.input_ids.size(1), stride)): + # begin_loc = max(i + stride - max_length, 0) + # end_loc = min(i + stride, encodings.input_ids.size(1)) + # trg_len = end_loc - i # may be different from stride on last loop + # input_ids = encodings.input_ids[:,begin_loc:end_loc].to(device) + # target_ids = input_ids.clone() + # target_ids[:,:-trg_len] = -100 + # + # with torch.no_grad(): + # outputs = model(input_ids, labels=target_ids) + # log_likelihood = outputs[0] * trg_len + +if __name__ == '__main__': + main() diff --git a/support/tensorboard_reload.py b/support/tensorboard_reload.py new file mode 100644 index 0000000..ab30c27 --- /dev/null +++ b/support/tensorboard_reload.py @@ -0,0 +1,18 @@ +import os +from tensorboard.backend.event_processing import event_accumulator + + +def main(): + data_dir = 
"/Users/fanzhihao/Documents/Research/NIPS2021" + tensorboard_path = os.path.join(data_dir, "events.out.tfevents.uniter.base.pm.dm") + ea = event_accumulator.EventAccumulator(tensorboard_path) + ea.Reload() + print(ea.scalars.Keys()) + + # val_psnr = ea.scalars.Items('val_psnr') + # print(len(val_psnr)) + # print([(i.step, i.value) for i in val_psnr]) + + +if __name__ == "__main__": + main() diff --git a/train_pnsgd.py b/train_pnsgd.py new file mode 100644 index 0000000..7be121e --- /dev/null +++ b/train_pnsgd.py @@ -0,0 +1,578 @@ +""" +Copyright (c) Microsoft Corporation. +Licensed under the MIT license. + +UNITER finetuning for Image-Text Retrieval with hard negatives +""" +import argparse +import os +from os.path import exists, join +from time import time +import math + +import torch +from torch.nn.utils import clip_grad_norm_ +from torch.utils.data import DataLoader, ConcatDataset +from torch.optim.lr_scheduler import LambdaLR +from apex import amp +from horovod import torch as hvd +from tqdm import tqdm + +from data import (PrefetchLoader, TxtTokLmdb, ImageLmdbGroup, + PNSGDFromText, PNSGDFromImage, pnsgd_collate, itm_rank_hn_collate, + ItmValDataset, itm_val_collate, + ItmEvalDataset, itm_eval_collate) +from model.nsgd import UniterForNSGD +from optim import get_lr_sched +from optim.misc import build_optimizer + +from utils.logger import LOGGER, TB_LOGGER, RunningMeter, add_log_to_file +from utils.distributed import (all_reduce_and_rescale_tensors, all_gather_list, + broadcast_tensors) +from utils.save import ModelSaver, save_training_meta +from utils.misc import NoOp, parse_with_config, set_dropout, set_random_seed +from utils.const import IMG_DIM +from utils.itm_eval import evaluate + + +def build_dataloader(dataset, collate_fn, is_train, opts): + dataloader = DataLoader(dataset, batch_size=1, + shuffle=is_train, drop_last=is_train, + num_workers=opts.n_workers, + pin_memory=opts.pin_mem, collate_fn=collate_fn) + dataloader = PrefetchLoader(dataloader) + return dataloader + + +def build_lr_scheduler(opts, num_training_steps, optimizer): + num_warmup_steps = ( + opts.warmup_steps + if opts.warmup_steps > 0 + else opts.ceil(num_training_steps * opts.args.warmup_ratio) + ) + + def lr_lambda(current_step: int): + if current_step < num_warmup_steps: + return float(current_step) / float(max(1, num_warmup_steps)) + return max( + 0.0, float(num_training_steps - current_step) / float(max(1, num_training_steps - num_warmup_steps)) + ) + return LambdaLR(optimizer, lr_lambda) + + +def load_optimizer_and_scheduler(optimizer, lr_scheduler, checkpoint): + """If optimizer and scheduler states exist, load them.""" + if checkpoint is None: + return + + if os.path.isfile(os.path.join(checkpoint, "optimizer.pt")) and os.path.isfile( + os.path.join(checkpoint, "scheduler.pt") + ): + # Load in optimizer and scheduler states + device = 'gpu' if torch.cuda.is_available() else 'cpu' + optimizer.load_state_dict( + torch.load(os.path.join(checkpoint, "optimizer.pt"), map_location=device) + ) + lr_scheduler.load_state_dict(torch.load(os.path.join(checkpoint, "scheduler.pt"))) + return optimizer, lr_scheduler + + +def main(opts): + hvd.init() + n_gpu = hvd.size() + device = torch.device("cuda", hvd.local_rank()) + torch.cuda.set_device(hvd.local_rank()) + rank = hvd.rank() + opts.rank = rank + LOGGER.info("device: {} n_gpu: {}, rank: {}, " + "16-bits training: {}".format( + device, n_gpu, hvd.rank(), opts.fp16)) + + set_random_seed(opts.seed) + + if hvd.rank() == 0: + save_training_meta(opts) + 
TB_LOGGER.create(join(opts.output_dir, 'log')) + pbar = tqdm(total=opts.num_train_steps) + model_saver = ModelSaver(join(opts.output_dir, 'ckpt')) + add_log_to_file(join(opts.output_dir, 'log', 'log.txt')) + # store ITM predictions + os.makedirs(join(opts.output_dir, 'results_val')) + os.makedirs(join(opts.output_dir, 'results_test')) + os.makedirs(join(opts.output_dir, 'results_train')) + else: + LOGGER.disabled = True + pbar = NoOp() + model_saver = NoOp() + + # train_examples = None + LOGGER.info(f"Loading Train Dataset {opts.train_txt_dbs}, " + f"{opts.train_img_dbs}") + # check multiple DBs + assert len(opts.train_txt_dbs) == len(opts.train_img_dbs), \ + "train txt_db and img_db have different length" + + # load DBs and image dirs + all_img_dbs = ImageLmdbGroup(opts.conf_th, opts.max_bb, opts.min_bb, + opts.num_bb, opts.compressed_db) + # train + LOGGER.info(f"Loading Train Dataset " + f"{opts.train_txt_dbs}, {opts.train_img_dbs}") + train_datasets_t = [] + train_datasets_i = [] + for txt_path, img_path in zip(opts.train_txt_dbs, opts.train_img_dbs): + img_db = all_img_dbs[img_path] + txt_db = TxtTokLmdb(txt_path, opts.max_txt_len) + train_datasets_t.append( + PNSGDFromText(txt_db, img_db, opts.negative_size, opts.mlm_sample_size)) + train_datasets_i.append( + PNSGDFromImage(txt_db, img_db, opts.negative_size)) + train_dataset_t = ConcatDataset(train_datasets_t) + train_dataset_i = ConcatDataset(train_datasets_i) + train_dataloader_t = build_dataloader( + train_dataset_t, pnsgd_collate, True, opts) + train_dataloader_i = build_dataloader( + train_dataset_i, pnsgd_collate, True, opts) + + # val + LOGGER.info(f"Loading Val Dataset {opts.val_txt_db}, {opts.val_img_db}") + val_img_db = all_img_dbs[opts.val_img_db] + val_txt_db = TxtTokLmdb(opts.val_txt_db, -1) + val_dataset = ItmValDataset(val_txt_db, val_img_db, + opts.inf_minibatch_size) + val_dataloader = build_dataloader(val_dataset, itm_val_collate, + False, opts) + # eval + LOGGER.info(f"Loading val, test Dataset for full evaluation: " + f"{opts.val_txt_db}, {opts.val_img_db}" + f"{opts.test_txt_db}, {opts.test_img_db}") + eval_dataset_val = ItmEvalDataset(val_txt_db, val_img_db, + opts.inf_minibatch_size) + eval_loader_val = build_dataloader(eval_dataset_val, itm_eval_collate, + False, opts) + test_img_db = all_img_dbs[opts.test_img_db] + test_txt_db = TxtTokLmdb(opts.test_txt_db, -1) + eval_dataset_test = ItmEvalDataset(test_txt_db, test_img_db, + opts.inf_minibatch_size) + eval_loader_test = build_dataloader(eval_dataset_test, itm_eval_collate, + False, opts) + + # Prepare model + if opts.checkpoint: + checkpoint = torch.load(opts.checkpoint) + else: + checkpoint = {} + + model = UniterForNSGD.from_pretrained( + opts.model_config, state_dict=checkpoint, + img_dim=IMG_DIM, margin=opts.margin, hard_size=opts.hard_neg_size, + nsgd_sample_size=opts.nsgd_sample_size, nsgd_sample_temperature=opts.nsgd_sample_temperature) + model.init_output() # pretrain ITM head is different from ranking head + model.to(device) + # make sure every process has same model parameters in the beginning + broadcast_tensors([p.data for p in model.parameters()], 0) + set_dropout(model, opts.dropout) + + # Prepare optimizer + optimizer = build_optimizer(model, opts) + # print('| optimizer file:', os.path.join(checkpoint, "optimizer.pt")) + # if os.path.exists(os.path.join(checkpoint, "optimizer.pt")) and \ + # os.path.isfile(os.path.join(checkpoint, "optimizer.pt")): + # # Load in optimizer and scheduler states + # device = 'gpu' if torch.cuda.is_available() 
else 'cpu' + # optimizer.load_state_dict(torch.load(os.path.join(checkpoint, "optimizer.pt"), map_location=device)) + model, optimizer = amp.initialize(model, optimizer, enabled=opts.fp16, opt_level='O2') + # Prepare scheduler + # lr_scheduler = build_lr_scheduler(opts, opts.num_train_steps, optimizer) + # print('| scheduler file:', os.path.join(checkpoint, "scheduler.pt")) + # if os.path.exists(os.path.join(checkpoint, "scheduler.pt")) and \ + # os.path.isfile(os.path.join(checkpoint, "scheduler.pt")): + # lr_scheduler.load_state_dict(torch.load(os.path.join(checkpoint, 'scheduler.pt'))) + + LOGGER.info(f"***** Running training on {n_gpu} GPUs *****") + LOGGER.info(" Num examples = %d", + sum(all_gather_list(len(train_dataset_t)))) + LOGGER.info(" Batch size = %d", opts.train_batch_size) + LOGGER.info(" Num steps = %d", opts.num_train_steps) + + # running_loss = RunningMeter('loss') + running_loss_dict = {loss: RunningMeter(loss) for loss in + ['mlm_loss', 'nsgd_rank_loss', 'i2t_hn_rank_loss', 't2i_hn_rank_loss']} + model.train() + + global_step = 0 + step = 0 + n_examples = 0 + n_hard_ex = 0 + start = time() + train_iter_i = iter(train_dataloader_i) + # quick hack for amp delay_unscale bug + optimizer.zero_grad() + optimizer.step() + effect_nsgd_number, mlm_corrects, nsgd_rank_corrects, i2t_hn_rank_corrects, t2i_hn_rank_corrects = \ + [], [[], []], [[], []], [[], []], [[], []] + incremental_adv_scores = [] + while True: + for batch in train_dataloader_t: + + # hard text from image + try: + batch_i = next(train_iter_i) + except StopIteration: + train_iter_i = iter(train_dataloader_i) + batch_i = next(train_iter_i) + + n_examples += batch['mlm_attn_masks'].size(0) + model_outputs = model( + batch, sample_from='gi', compute_loss=True, + compute_mlm=True if isinstance(args.mlm_lambda, float) and args.mlm_lambda > 0 else False) + loss = 0.0 + if model_outputs.get('masked_lm_loss') is not None: + mlm_loss = model_outputs.get('masked_lm_loss').mean() + # print('| mlm_loss: ', mlm_loss) + loss += args.mlm_lambda * mlm_loss + running_loss_dict.get('mlm_loss')(mlm_loss.item()) + mlm_corrects[0].append(model_outputs['mlm_corrects'].sum().item()) + mlm_corrects[1].append(model_outputs['mlm_corrects'].numel()) + # print('mlm_corrects:', mlm_corrects[0][-1]/mlm_corrects[1][-1]) + effect_nsgd_number.append(model_outputs['effect_nsgd_number']) + if model_outputs.get('rank_loss') is not None: + rank_adv_scores = model_outputs['rank_adv_scores'].squeeze(-1) + # print('| rank_adv_scores: ', rank_adv_scores.min(), rank_adv_scores.max()) + incremental_adv_score = rank_adv_scores[1:] - rank_adv_scores[0] + incremental_adv_scores.append(incremental_adv_score) + nsgd_rank_loss = model_outputs.get('rank_loss') + rank_corrects = model_outputs.get('rank_corrects') + nsgd_hard_ex = rank_corrects.numel() + nsgd_rank_corrects[0].append(rank_corrects.sum().item()) + nsgd_rank_corrects[1].append(nsgd_hard_ex) + # print('nsgd_rank_corrects:', nsgd_rank_corrects[0][-1]/nsgd_rank_corrects[1][-1]) + n_hard_ex += nsgd_hard_ex + nsgd_rank_loss = nsgd_rank_loss.mean() + running_loss_dict.get('nsgd_rank_loss')(nsgd_rank_loss.item()) + loss += args.nsgd_rank_lambda * nsgd_rank_loss.mean() + if isinstance(loss, torch.Tensor): + loss = loss / opts.train_batch_size + with amp.scale_loss(loss, optimizer, delay_unscale=True, # loss_id=0 + ) as scaled_loss: + scaled_loss.backward() + + n_examples += batch_i['attn_masks'].size(0) + model_outputs = model(batch_i, sample_from='i', compute_loss=True) + loss = 
model_outputs.get('rank_loss') + rank_corrects = model_outputs.get('rank_corrects') + i2t_hard_ex = rank_corrects.numel() + i2t_hn_rank_corrects[0].append(rank_corrects.sum().item()) + i2t_hn_rank_corrects[1].append(i2t_hard_ex) + n_hard_ex += i2t_hard_ex + loss = loss.mean() / opts.train_batch_size + with amp.scale_loss(loss, optimizer, delay_unscale=True, # loss_id=1 + ) as scaled_loss: + scaled_loss.backward() + running_loss_dict.get('t2i_hn_rank_loss')(loss.item()) + + # hard image from text + n_examples += batch['attn_masks'].size(0) + model_outputs = model(batch, sample_from='t', compute_loss=True) + loss = model_outputs.get('rank_loss') + rank_corrects = model_outputs.get('rank_corrects') + t2i_hard_ex = rank_corrects.numel() + t2i_hn_rank_corrects[0].append(rank_corrects.sum().item()) + t2i_hn_rank_corrects[1].append(t2i_hard_ex) + n_hard_ex += t2i_hard_ex + # NOTE we use gradient accumulation to implemented train_batch_size + loss = loss.mean() / opts.train_batch_size + + step += 1 + delay_unscale = step % opts.train_batch_size != 0 + with amp.scale_loss(loss, optimizer, delay_unscale=delay_unscale, #loss_id=2 + ) as scaled_loss: + scaled_loss.backward() + if not delay_unscale: + # gather gradients from every processes + # do this before unscaling to make sure every process uses + # the same gradient scale + grads = [p.grad.data for p in model.parameters() + if p.requires_grad and p.grad is not None] + all_reduce_and_rescale_tensors(grads, float(1)) + running_loss_dict.get('i2t_hn_rank_loss')(loss.item()) + + if step % opts.train_batch_size == 0: + global_step += 1 + + # learning rate scheduling + lr_this_step = get_lr_sched(global_step, opts) + for param_group in optimizer.param_groups: + param_group['lr'] = lr_this_step + TB_LOGGER.add_scalar('lr', lr_this_step, global_step) + + # log loss + # NOTE: not gathered across GPUs for efficiency + incremental_adv_scores = torch.cat(incremental_adv_scores, dim=0) + TB_LOGGER.add_histogram( + 'incremental_adv_score', incremental_adv_scores.data.cpu().numpy(), global_step) + incremental_adv_scores = [] + + TB_LOGGER.add_scalar( + 'mlm_loss', running_loss_dict['mlm_loss'].val, global_step) + TB_LOGGER.add_scalar( + 'nsgd_rank_loss', running_loss_dict['nsgd_rank_loss'].val, global_step) + TB_LOGGER.add_scalar( + 'i2t_hn_rank_loss', running_loss_dict['i2t_hn_rank_loss'].val, global_step) + TB_LOGGER.add_scalar( + 't2i_hn_rank_loss', running_loss_dict['t2i_hn_rank_loss'].val, global_step) + TB_LOGGER.add_scalar( + 'nsgd_rank_cr', sum(nsgd_rank_corrects[0]) / float(sum(nsgd_rank_corrects[1])), global_step) + TB_LOGGER.add_scalar( + 'i2t_hn_rank_cr', sum(i2t_hn_rank_corrects[0]) / float(sum(i2t_hn_rank_corrects[1])), global_step) + TB_LOGGER.add_scalar( + 't2i_hn_rank_cr', sum(t2i_hn_rank_corrects[0]) / float(sum(t2i_hn_rank_corrects[1])), global_step) + TB_LOGGER.add_scalar( + 'effect_nsgd_number', sum(effect_nsgd_number) / float(len(effect_nsgd_number)), global_step) + TB_LOGGER.step() + effect_nsgd_number, mlm_corrects, nsgd_rank_corrects, i2t_hn_rank_corrects, t2i_hn_rank_corrects = \ + [], [[], []], [[], []], [[], []], [[], []] + + # update model params + if opts.grad_norm != -1: + grad_norm = clip_grad_norm_(amp.master_params(optimizer), + opts.grad_norm) + TB_LOGGER.add_scalar('grad_norm', grad_norm, global_step) + optimizer.step() + optimizer.zero_grad() + pbar.update(1) + + if global_step % 50 == 0: + # monitor training throughput + LOGGER.info(f'------------Step {global_step}-------------') + tot_ex = sum(all_gather_list(n_examples)) 
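+ # n_examples and n_hard_ex are per-rank counters; all_gather_list collects them from every Horovod worker, so the throughput logged below is global rather than per-GPU.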
+ ex_per_sec = int(tot_ex / (time()-start)) + tot_hn = sum(all_gather_list(n_hard_ex)) + hn_per_sec = int(tot_hn / (time()-start)) + LOGGER.info(f'{tot_ex} ({tot_hn}) examples (hard) ' + f'trained at {ex_per_sec} ({hn_per_sec}) ex/s') + TB_LOGGER.add_scalar('perf/ex_per_s', + ex_per_sec, global_step) + TB_LOGGER.add_scalar('perf/hn_per_s', + hn_per_sec, global_step) + LOGGER.info(f'-------------------------------------------') + + if global_step % opts.valid_steps == 0: + if opts.full_val: + LOGGER.info( + f"========================== Step {global_step} " + f"==========================") + val_log = evaluate(model, eval_loader_val) + try: + TB_LOGGER.log_scaler_dict( + {f"valid/{k}": v for k, v in val_log.items()}) + LOGGER.info(f"image retrieval R1: " + f"{val_log['img_r1']*100:.2f},\n" + f"image retrieval R5: " + f"{val_log['img_r5']*100:.2f},\n" + f"image retrieval R10: " + f"{val_log['img_r10']*100:.2f}\n" + f"text retrieval R1: " + f"{val_log['txt_r1']*100:.2f},\n" + f"text retrieval R5: " + f"{val_log['txt_r5']*100:.2f},\n" + f"text retrieval R10: " + f"{val_log['txt_r10']*100:.2f}") + LOGGER.info("=================================" + "=================================") + except KeyError: + pass + else: + val_log = validate(model, val_dataloader) + TB_LOGGER.log_scaler_dict(val_log) + model_saver.save(model, global_step, optimizer=optimizer) + # torch.save(optimizer.state_dict(), os.path.join()) + # torch.save(lr_scheduler.state_dict(), os.path.join()) + + if global_step >= opts.num_train_steps: + break + + if global_step >= opts.num_train_steps: + break + + pbar.close() + # final validation + val_log = validate(model, val_dataloader) + TB_LOGGER.log_scaler_dict(val_log) + model_saver.save(model, f'{global_step}_final') + + # evaluation + for split, loader in [('val', eval_loader_val), + ('test', eval_loader_test)]: + eval_log = evaluate(model, loader) + TB_LOGGER.log_scaler_dict({f"eval/{split}_{k}": v + for k, v in eval_log.items()}) + if hvd.rank() != 0: + continue + LOGGER.info( + f"========================= {split} ===========================\n" + f"image retrieval R1: {eval_log['img_r1']*100:.2f},\n" + f"image retrieval R5: {eval_log['img_r5']*100:.2f},\n" + f"image retrieval R10: {eval_log['img_r10']*100:.2f}\n" + f"text retrieval R1: {eval_log['txt_r1']*100:.2f},\n" + f"text retrieval R5: {eval_log['txt_r5']*100:.2f},\n" + f"text retrieval R10: {eval_log['txt_r10']*100:.2f}") + LOGGER.info("=========================================================") + + +@torch.no_grad() +def validate(model, val_loader): + if hvd.rank() == 0: + pbar = tqdm(total=len(val_loader)) + else: + pbar = NoOp() + LOGGER.info("start running Image Retrieval validation ...") + model.eval() + n_ex = 0 + st = time() + + recall_at_1, recall_at_5, recall_at_10 = 0, 0, 0 + for batch in val_loader: + model_outputs = model(batch, compute_loss=False) + if isinstance(model_outputs, dict): + scores = model_outputs['rank_scores'] + else: + scores = model_outputs + _, indices = scores.squeeze(1).topk(10, dim=0) + rank = (indices == 0).nonzero() + if rank.numel(): + rank = rank.item() + if rank < 1: + recall_at_1 += 1 + if rank < 5: + recall_at_5 += 1 + if rank < 10: + recall_at_10 += 1 + n_ex += 1 + pbar.update(1) + n_ex = sum(all_gather_list(n_ex)) + recall_at_1 = sum(all_gather_list(recall_at_1)) / n_ex + recall_at_5 = sum(all_gather_list(recall_at_5)) / n_ex + recall_at_10 = sum(all_gather_list(recall_at_10)) / n_ex + tot_time = time()-st + val_log = {'valid/ex_per_s': n_ex/tot_time, + 'valid/recall_1': 
recall_at_1, + 'valid/recall_5': recall_at_5, + 'valid/recall_10': recall_at_10} + model.train() + LOGGER.info(f"validation finished in {int(tot_time)} seconds, " + f"recall_1: {recall_at_1*100:.2f}, " + f"recall_5: {recall_at_5*100:.2f}, " + f"recall_10: {recall_at_10*100:.2f}") + pbar.close() + return val_log + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + + # Required parameters + + parser.add_argument('--compressed_db', action='store_true', + help='use compressed LMDB') + parser.add_argument("--checkpoint", + default=None, type=str, + help="pretrained MLM") + + parser.add_argument("--output_dir", default=None, type=str, + help="The output directory where the model " + "checkpoints will be written.") + + # Prepro parameters + parser.add_argument('--max_txt_len', type=int, default=60, + help='max number of tokens in text (BERT BPE)') + parser.add_argument('--conf_th', type=float, default=0.2, + help='threshold for dynamic bounding boxes ' + '(-1 for fixed)') + parser.add_argument('--max_bb', type=int, default=100, + help='max number of bounding boxes') + parser.add_argument('--min_bb', type=int, default=10, + help='min number of bounding boxes') + parser.add_argument('--num_bb', type=int, default=36, + help='static number of bounding boxes') + + # training parameters + parser.add_argument("--train_batch_size", default=32, type=int, + help="batch size (# positive examples) for training. " + "(implemented with gradient accumulation)") + + parser.add_argument("--negative_size", default=511, type=int, + help="Number of negative samples per positive sample" + "(forward only)") + parser.add_argument("--hard_neg_size", default=31, type=int, + help="Number of hard negative samples " + "per positive sample (acutally used to train)") + parser.add_argument("--mlm_sample_size", default=22, type=int, + help="Number of samples following masked language masking" + "per positive sample (acutally used to train)") + parser.add_argument("--nsgd_sample_size", default=22, type=int, + help="Number of NSGD for each mlm sample" + "per positive sample (acutally used to train)") + parser.add_argument("--nsgd_sample_temperature", default=2.0, type=float, + help="sampling temperature of NSGD sampling. 
") + + parser.add_argument("--mlm_lambda", default=0.1, type=float, + help="lambda in training of mask language modeling") + parser.add_argument("--nsgd_rank_lambda", default=1.0, type=float, + help="lambda in training of NSGD ranking loss") + parser.add_argument("--margin", default=0.2, type=float, + help="margin of ranking loss") + parser.add_argument("--learning_rate", default=3e-5, type=float, + help="The initial learning rate for Adam.") + parser.add_argument("--valid_steps", default=1000, type=int, + help="Run validation every X steps") + parser.add_argument("--num_train_steps", default=100000, type=int, + help="Total number of training updates to perform.") + parser.add_argument("--optim", default='adam', + choices=['adam', 'adamax', 'adamw'], + help="optimizer") + parser.add_argument("--betas", default=[0.9, 0.98], nargs='+', + help="beta for adam optimizer") + parser.add_argument("--dropout", default=0.1, type=float, + help="tune dropout regularization") + parser.add_argument("--weight_decay", default=0.01, type=float, + help="weight decay (L2) regularization") + parser.add_argument("--grad_norm", default=0.25, type=float, + help="gradient clipping (-1 for no clipping)") + parser.add_argument("--warmup_steps", default=4000, type=int, + help="Number of training steps to perform linear " + "learning rate warmup for.") + + # device parameters + parser.add_argument('--seed', type=int, default=42, + help="random seed for initialization") + parser.add_argument('--full_val', action='store_true', + help="Always run full evaluation during training") + parser.add_argument('--fp16', action='store_true', + help="Whether to use 16-bit float precision instead " + "of 32-bit") + parser.add_argument('--n_workers', type=int, default=4, + help="number of data workers") + parser.add_argument('--pin_mem', action='store_true', + help="pin memory") + + # can use config files + parser.add_argument('--config', help='JSON config files') + + args = parse_with_config(parser) + + # if exists(args.output_dir) and os.listdir(args.output_dir): + # raise ValueError("Output directory ({}) already exists and is not " + # "empty.".format(args.output_dir)) + + # options safe guard + if args.conf_th == -1: + assert args.max_bb + args.max_txt_len + 2 <= 512 + else: + assert args.num_bb + args.max_txt_len + 2 <= 512 + + # for tensor core + assert (args.negative_size+1) % 8 == (args.hard_neg_size+1) % 8 == 0 + + print('| args: ', args) + main(args) diff --git a/train_pnsgd2.py b/train_pnsgd2.py new file mode 100644 index 0000000..0ddb3b9 --- /dev/null +++ b/train_pnsgd2.py @@ -0,0 +1,630 @@ +""" +Copyright (c) Microsoft Corporation. +Licensed under the MIT license. 
+ +UNITER finetuning for Image-Text Retrieval with hard negatives +""" +import argparse +import os +from os.path import exists, join +from time import time +import math + +import torch +from torch.nn.utils import clip_grad_norm_ +from torch.utils.data import DataLoader, ConcatDataset +from torch.optim.lr_scheduler import LambdaLR +from apex import amp +from horovod import torch as hvd +from tqdm import tqdm + +from data import (PrefetchLoader, TxtTokLmdb, ImageLmdbGroup, + PNSGDFromText, PNSGDFromImage, pnsgd_collate, itm_rank_hn_collate, + ItmValDataset, itm_val_collate, + ItmEvalDataset, itm_eval_collate) +from model.nsgd2 import UniterForNSGD2 +from optim import get_lr_sched +from optim.misc import build_optimizer + +from utils.logger import LOGGER, TB_LOGGER, RunningMeter, add_log_to_file +from utils.distributed import (all_reduce_and_rescale_tensors, all_gather_list, + broadcast_tensors) +from utils.save import ModelSaver, save_training_meta +from utils.misc import NoOp, parse_with_config, set_dropout, set_random_seed +from utils.const import IMG_DIM +from utils.itm_eval import evaluate + + +def build_dataloader(dataset, collate_fn, is_train, opts): + dataloader = DataLoader(dataset, batch_size=1, + shuffle=is_train, drop_last=is_train, + num_workers=opts.n_workers, + pin_memory=opts.pin_mem, collate_fn=collate_fn) + dataloader = PrefetchLoader(dataloader) + return dataloader + + +def build_lr_scheduler(opts, num_training_steps, optimizer): + num_warmup_steps = ( + opts.warmup_steps + if opts.warmup_steps > 0 + else opts.ceil(num_training_steps * opts.args.warmup_ratio) + ) + + def lr_lambda(current_step: int): + if current_step < num_warmup_steps: + return float(current_step) / float(max(1, num_warmup_steps)) + return max( + 0.0, float(num_training_steps - current_step) / float(max(1, num_training_steps - num_warmup_steps)) + ) + return LambdaLR(optimizer, lr_lambda) + + +def load_optimizer_and_scheduler(optimizer, lr_scheduler, checkpoint): + """If optimizer and scheduler states exist, load them.""" + if checkpoint is None: + return + + if os.path.isfile(os.path.join(checkpoint, "optimizer.pt")) and os.path.isfile( + os.path.join(checkpoint, "scheduler.pt") + ): + # Load in optimizer and scheduler states + device = 'gpu' if torch.cuda.is_available() else 'cpu' + optimizer.load_state_dict( + torch.load(os.path.join(checkpoint, "optimizer.pt"), map_location=device) + ) + lr_scheduler.load_state_dict(torch.load(os.path.join(checkpoint, "scheduler.pt"))) + return optimizer, lr_scheduler + + +def main(opts): + hvd.init() + n_gpu = hvd.size() + device = torch.device("cuda", hvd.local_rank()) + torch.cuda.set_device(hvd.local_rank()) + rank = hvd.rank() + opts.rank = rank + LOGGER.info("device: {} n_gpu: {}, rank: {}, " + "16-bits training: {}".format( + device, n_gpu, hvd.rank(), opts.fp16)) + + set_random_seed(opts.seed) + + if hvd.rank() == 0: + save_training_meta(opts) + TB_LOGGER.create(join(opts.output_dir, 'log')) + pbar = tqdm(total=opts.num_train_steps) + model_saver = ModelSaver(join(opts.output_dir, 'ckpt')) + add_log_to_file(join(opts.output_dir, 'log', 'log.txt')) + # store ITM predictions + os.makedirs(join(opts.output_dir, 'results_val')) + os.makedirs(join(opts.output_dir, 'results_test')) + os.makedirs(join(opts.output_dir, 'results_train')) + else: + LOGGER.disabled = True + pbar = NoOp() + model_saver = NoOp() + + # train_examples = None + LOGGER.info(f"Loading Train Dataset {opts.train_txt_dbs}, " + f"{opts.train_img_dbs}") + # check multiple DBs + assert 
len(opts.train_txt_dbs) == len(opts.train_img_dbs), \ + "train txt_db and img_db have different length" + + # load DBs and image dirs + all_img_dbs = ImageLmdbGroup(opts.conf_th, opts.max_bb, opts.min_bb, + opts.num_bb, opts.compressed_db) + # train + LOGGER.info(f"Loading Train Dataset " + f"{opts.train_txt_dbs}, {opts.train_img_dbs}") + train_datasets_t = [] + train_datasets_i = [] + for txt_path, img_path in zip(opts.train_txt_dbs, opts.train_img_dbs): + img_db = all_img_dbs[img_path] + txt_db = TxtTokLmdb(txt_path, opts.max_txt_len) + train_datasets_t.append( + PNSGDFromText(txt_db, img_db, opts.negative_size, opts.mlm_sample_size)) + train_datasets_i.append( + PNSGDFromImage(txt_db, img_db, opts.negative_size)) + train_dataset_t = ConcatDataset(train_datasets_t) + train_dataset_i = ConcatDataset(train_datasets_i) + train_dataloader_t = build_dataloader( + train_dataset_t, pnsgd_collate, True, opts) + train_dataloader_i = build_dataloader( + train_dataset_i, pnsgd_collate, True, opts) + + # val + LOGGER.info(f"Loading Val Dataset {opts.val_txt_db}, {opts.val_img_db}") + val_img_db = all_img_dbs[opts.val_img_db] + val_txt_db = TxtTokLmdb(opts.val_txt_db, -1) + val_dataset = ItmValDataset(val_txt_db, val_img_db, + opts.inf_minibatch_size) + val_dataloader = build_dataloader(val_dataset, itm_val_collate, + False, opts) + # eval + LOGGER.info(f"Loading val, test Dataset for full evaluation: " + f"{opts.val_txt_db}, {opts.val_img_db}" + f"{opts.test_txt_db}, {opts.test_img_db}") + eval_dataset_val = ItmEvalDataset(val_txt_db, val_img_db, + opts.inf_minibatch_size) + eval_loader_val = build_dataloader(eval_dataset_val, itm_eval_collate, + False, opts) + test_img_db = all_img_dbs[opts.test_img_db] + test_txt_db = TxtTokLmdb(opts.test_txt_db, -1) + eval_dataset_test = ItmEvalDataset(test_txt_db, test_img_db, + opts.inf_minibatch_size) + eval_loader_test = build_dataloader(eval_dataset_test, itm_eval_collate, + False, opts) + + # Prepare model + if opts.checkpoint: + checkpoint = torch.load(opts.checkpoint) + else: + checkpoint = {} + + model = UniterForNSGD2.from_pretrained( + opts.model_config, state_dict=checkpoint, + img_dim=IMG_DIM, margin=opts.margin, hard_size=opts.hard_neg_size, + nsgd_sample_size=opts.nsgd_sample_size, nsgd_sample_temperature=opts.nsgd_sample_temperature) + model.init_output() # pretrain ITM head is different from ranking head + model.to(device) + # for name, param in model.named_parameters(): + # print(name, param.size()) + # make sure every process has same model parameters in the beginning + broadcast_tensors([p.data for p in model.parameters()], 0) + set_dropout(model, opts.dropout) + + # Prepare optimizer + optimizer = build_optimizer(model, opts) + # print('| optimizer file:', os.path.join(checkpoint, "optimizer.pt")) + # if os.path.exists(os.path.join(checkpoint, "optimizer.pt")) and \ + # os.path.isfile(os.path.join(checkpoint, "optimizer.pt")): + # # Load in optimizer and scheduler states + # device = 'gpu' if torch.cuda.is_available() else 'cpu' + # optimizer.load_state_dict(torch.load(os.path.join(checkpoint, "optimizer.pt"), map_location=device)) + model, optimizer = amp.initialize(model, optimizer, enabled=opts.fp16, opt_level='O0') # , num_losses=4) + # Prepare scheduler + # lr_scheduler = build_lr_scheduler(opts, opts.num_train_steps, optimizer) + # print('| scheduler file:', os.path.join(checkpoint, "scheduler.pt")) + # if os.path.exists(os.path.join(checkpoint, "scheduler.pt")) and \ + # os.path.isfile(os.path.join(checkpoint, "scheduler.pt")): + # 
lr_scheduler.load_state_dict(torch.load(os.path.join(checkpoint, 'scheduler.pt'))) + + LOGGER.info(f"***** Running training on {n_gpu} GPUs *****") + LOGGER.info(" Num examples = %d", + sum(all_gather_list(len(train_dataset_t)))) + LOGGER.info(" Batch size = %d", opts.train_batch_size) + LOGGER.info(" Num steps = %d", opts.num_train_steps) + + # running_loss = RunningMeter('loss') + running_loss_dict = {loss: RunningMeter(loss) for loss in + ['mlm_loss', 'nsgd_rank_loss', 'i2t_hn_rank_loss', 't2i_hn_rank_loss', + 'disc_loss', 'correction_loss']} + model.train() + + global_step = 0 + step = 0 + n_examples = 0 + n_hard_ex = 0 + start = time() + train_iter_i = iter(train_dataloader_i) + # quick hack for amp delay_unscale bug + optimizer.zero_grad() + optimizer.step() + effect_nsgd_number, mlm_corrects, nsgd_rank_corrects, i2t_hn_rank_corrects, t2i_hn_rank_corrects = \ + [], [[], []], [[], []], [[], []], [[], []] + correction_corrects = [[], []] + disc_pos_corrects, disc_neg_corrects = [[], []], [[], []] + while True: + for batch in train_dataloader_t: + + # hard text from image + try: + batch_i = next(train_iter_i) + except StopIteration: + train_iter_i = iter(train_dataloader_i) + batch_i = next(train_iter_i) + + n_examples += batch['mlm_attn_masks'].size(0) + model_outputs = model( + batch, sample_from='gi', compute_loss=True, + compute_mlm=True if isinstance(args.mlm_lambda, float) and args.mlm_lambda > 0 else False) + loss = 0.0 + if model_outputs.get('masked_lm_loss') is not None: + mlm_loss = model_outputs.get('masked_lm_loss').mean() + loss += args.mlm_lambda * mlm_loss + running_loss_dict.get('mlm_loss')(mlm_loss.item() / opts.train_batch_size) + # print('| mlm_loss:', mlm_loss.item() / opts.train_batch_size) + mlm_corrects[0].append(model_outputs['mlm_corrects'].sum().item()) + mlm_corrects[1].append(model_outputs['mlm_corrects'].numel()) + # print('mlm_corrects:', mlm_corrects[0][-1]/mlm_corrects[1][-1]) + effect_nsgd_number.append(model_outputs['effect_nsgd_number']) + if model_outputs.get('rank_loss') is not None: + nsgd_rank_loss = model_outputs.get('rank_loss').mean() + running_loss_dict.get('nsgd_rank_loss')(nsgd_rank_loss.item() / opts.train_batch_size) + # print('| nsgd_rank_loss:', nsgd_rank_loss.item() / opts.train_batch_size) + loss += args.nsgd_rank_lambda * nsgd_rank_loss + rank_corrects = model_outputs.get('rank_corrects') + nsgd_hard_ex = rank_corrects.numel() + nsgd_rank_corrects[0].append(rank_corrects.sum().item()) + nsgd_rank_corrects[1].append(nsgd_hard_ex) + n_hard_ex += nsgd_hard_ex + if isinstance(loss, torch.Tensor): + loss = loss / opts.train_batch_size + with amp.scale_loss(loss, optimizer, delay_unscale=False, # loss_id=0 + ) as scaled_loss: + scaled_loss.backward() + + # discrimination and correction of synthetic text samples + batch['syn_input_ids'] = model_outputs['syn_input_ids'] + batch['syn_position_ids'] = model_outputs['syn_position_ids'] + batch['syn_img_feat'] = model_outputs['syn_img_feat'] + batch['syn_img_pos_feat'] = model_outputs['syn_img_pos_feat'] + batch['syn_attn_masks'] = model_outputs['syn_attn_masks'] + batch['syn_gather_index'] = model_outputs['syn_gather_index'] + batch['syn_txt_labels'] = model_outputs['syn_txt_labels'] + batch['gt_input_ids'] = model_outputs['gt_input_ids'] + n_examples += batch['syn_input_ids'].size(0) + model_outputs = model( + batch, sample_from='gsynt', compute_loss=True) + loss = 0.0 + if model_outputs.get('disc_loss') is not None: + disc_loss = model_outputs.get('disc_loss').mean() + loss += args.disc_lambda 
* disc_loss + # print('| disc_loss:', disc_loss.item() / opts.train_batch_size) + running_loss_dict.get('disc_loss')(disc_loss.item() / opts.train_batch_size) + disc_pos_corrects[0].append(model_outputs['disc_pos_corrects'].item()) + disc_pos_corrects[1].append(model_outputs['disc_pos_samples'].item()) + disc_neg_corrects[0].append(model_outputs['disc_neg_corrects'].item()) + disc_neg_corrects[1].append(model_outputs['disc_neg_samples'].item()) + if model_outputs.get('correction_loss') is not None: + correction_loss = model_outputs.get('correction_loss').mean() + loss += args.correction_lambda * correction_loss + # print('| correction_loss:', correction_loss.item() / opts.train_batch_size) + running_loss_dict.get('correction_loss')(correction_loss.item() / opts.train_batch_size) + correction_corrects[0].append(model_outputs['correction_corrects'].sum().item()) + correction_corrects[1].append(model_outputs['correction_corrects'].numel()) + if isinstance(loss, torch.Tensor): + loss = loss / opts.train_batch_size + with amp.scale_loss(loss, optimizer, delay_unscale=False, # loss_id=1 + ) as scaled_loss: + scaled_loss.backward() + + # ITM training with hard negative samples + n_examples += batch_i['attn_masks'].size(0) + model_outputs = model(batch_i, sample_from='i', compute_loss=True) + loss = model_outputs.get('rank_loss') + rank_corrects = model_outputs.get('rank_corrects') + i2t_hard_ex = rank_corrects.numel() + i2t_hn_rank_corrects[0].append(rank_corrects.sum().item()) + i2t_hn_rank_corrects[1].append(i2t_hard_ex) + n_hard_ex += i2t_hard_ex + loss = loss.mean() / opts.train_batch_size + + with amp.scale_loss(loss, optimizer, delay_unscale=False, # loss_id=2 + ) as scaled_loss: + scaled_loss.backward() + running_loss_dict.get('t2i_hn_rank_loss')(loss.item()) + + # hard image from text + n_examples += batch['attn_masks'].size(0) + model_outputs = model(batch, sample_from='t', compute_loss=True) + loss = model_outputs.get('rank_loss') + rank_corrects = model_outputs.get('rank_corrects') + t2i_hard_ex = rank_corrects.numel() + t2i_hn_rank_corrects[0].append(rank_corrects.sum().item()) + t2i_hn_rank_corrects[1].append(t2i_hard_ex) + n_hard_ex += t2i_hard_ex + # NOTE we use gradient accumulation to implemented train_batch_size + loss = loss.mean() / opts.train_batch_size + running_loss_dict.get('i2t_hn_rank_loss')(loss.item()) + # print('| t2i_triplet_loss:', loss.item()) + + step += 1 + delay_unscale = step % opts.train_batch_size != 0 + with amp.scale_loss(loss, optimizer, delay_unscale=delay_unscale, # loss_id=3 + ) as scaled_loss: + scaled_loss.backward() + if not delay_unscale: + # gather gradients from every processes + # do this before unscaling to make sure every process uses + # the same gradient scale + grads = [p.grad.data for p in model.parameters() + if p.requires_grad and p.grad is not None] + all_reduce_and_rescale_tensors(grads, float(1)) + + if step % opts.train_batch_size == 0: + global_step += 1 + + # learning rate scheduling + lr_this_step = get_lr_sched(global_step, opts) + for param_group in optimizer.param_groups: + param_group['lr'] = lr_this_step + TB_LOGGER.add_scalar('lr', lr_this_step, global_step) + + # log loss + # NOTE: not gathered across GPUs for efficiency + TB_LOGGER.add_scalar( + 'mlm_loss', running_loss_dict['mlm_loss'].val, global_step) + TB_LOGGER.add_scalar( + 'nsgd_rank_loss', running_loss_dict['nsgd_rank_loss'].val, global_step) + TB_LOGGER.add_scalar( + 'i2t_hn_rank_loss', running_loss_dict['i2t_hn_rank_loss'].val, global_step) + TB_LOGGER.add_scalar( 
+ 't2i_hn_rank_loss', running_loss_dict['t2i_hn_rank_loss'].val, global_step) + TB_LOGGER.add_scalar( + 'disc_loss', running_loss_dict['disc_loss'].val, global_step) + TB_LOGGER.add_scalar( + 'correction_loss', running_loss_dict['correction_loss'].val, global_step) + TB_LOGGER.add_scalar( + 'nsgd_rank_cr', sum(nsgd_rank_corrects[0]) / float(sum(nsgd_rank_corrects[1])), global_step) + TB_LOGGER.add_scalar( + 'i2t_hn_rank_cr', sum(i2t_hn_rank_corrects[0]) / float(sum(i2t_hn_rank_corrects[1])), global_step) + TB_LOGGER.add_scalar( + 't2i_hn_rank_cr', sum(t2i_hn_rank_corrects[0]) / float(sum(t2i_hn_rank_corrects[1])), global_step) + TB_LOGGER.add_scalar( + 'effect_nsgd_number', sum(effect_nsgd_number) / float(len(effect_nsgd_number)), global_step) + TB_LOGGER.add_scalar( + 'correction_cr', sum(correction_corrects[0]) / max(float(sum(correction_corrects[1])), 1.0), global_step) + TB_LOGGER.add_scalar( + 'disc_pos_ntokens', sum(disc_pos_corrects[1]), global_step) + TB_LOGGER.add_scalar( + 'disc_neg_ntokens', sum(disc_neg_corrects[1]), global_step) + TB_LOGGER.add_scalar( + 'disc_pos_cr', sum(disc_pos_corrects[0]) / max(float(sum(disc_pos_corrects[1])), 1.0), global_step) + TB_LOGGER.add_scalar( + 'disc_neg_cr', sum(disc_neg_corrects[0]) / max(float(sum(disc_neg_corrects[1])), 1.0), global_step) + TB_LOGGER.step() + effect_nsgd_number, mlm_corrects, nsgd_rank_corrects, i2t_hn_rank_corrects, t2i_hn_rank_corrects = \ + [], [[], []], [[], []], [[], []], [[], []] + correction_corrects = [[], []] + disc_pos_corrects, disc_neg_corrects = [[], []], [[], []] + + # update model params + if opts.grad_norm != -1: + grad_norm = clip_grad_norm_(amp.master_params(optimizer), + opts.grad_norm) + TB_LOGGER.add_scalar('grad_norm', grad_norm, global_step) + optimizer.step() + optimizer.zero_grad() + pbar.update(1) + + if global_step % 50 == 0: + # monitor training throughput + LOGGER.info(f'------------Step {global_step}-------------') + tot_ex = sum(all_gather_list(n_examples)) + ex_per_sec = int(tot_ex / (time()-start)) + tot_hn = sum(all_gather_list(n_hard_ex)) + hn_per_sec = int(tot_hn / (time()-start)) + LOGGER.info(f'{tot_ex} ({tot_hn}) examples (hard) ' + f'trained at {ex_per_sec} ({hn_per_sec}) ex/s') + TB_LOGGER.add_scalar('perf/ex_per_s', + ex_per_sec, global_step) + TB_LOGGER.add_scalar('perf/hn_per_s', + hn_per_sec, global_step) + LOGGER.info(f'-------------------------------------------') + + if global_step % opts.valid_steps == 0: + if opts.full_val: + LOGGER.info( + f"========================== Step {global_step} " + f"==========================") + val_log = evaluate(model, eval_loader_val) + try: + TB_LOGGER.log_scaler_dict( + {f"valid/{k}": v for k, v in val_log.items()}) + LOGGER.info(f"image retrieval R1: " + f"{val_log['img_r1']*100:.2f},\n" + f"image retrieval R5: " + f"{val_log['img_r5']*100:.2f},\n" + f"image retrieval R10: " + f"{val_log['img_r10']*100:.2f}\n" + f"text retrieval R1: " + f"{val_log['txt_r1']*100:.2f},\n" + f"text retrieval R5: " + f"{val_log['txt_r5']*100:.2f},\n" + f"text retrieval R10: " + f"{val_log['txt_r10']*100:.2f}") + LOGGER.info("=================================" + "=================================") + except KeyError: + pass + else: + val_log = validate(model, val_dataloader) + TB_LOGGER.log_scaler_dict(val_log) + model_saver.save(model, global_step, optimizer=optimizer) + # torch.save(optimizer.state_dict(), os.path.join()) + # torch.save(lr_scheduler.state_dict(), os.path.join()) + + if global_step >= opts.num_train_steps: + break + + if global_step >= 
opts.num_train_steps: + break + + pbar.close() + # final validation + val_log = validate(model, val_dataloader) + TB_LOGGER.log_scaler_dict(val_log) + model_saver.save(model, f'{global_step}_final') + + # evaluation + for split, loader in [('val', eval_loader_val), + ('test', eval_loader_test)]: + eval_log = evaluate(model, loader) + TB_LOGGER.log_scaler_dict({f"eval/{split}_{k}": v + for k, v in eval_log.items()}) + if hvd.rank() != 0: + continue + LOGGER.info( + f"========================= {split} ===========================\n" + f"image retrieval R1: {eval_log['img_r1']*100:.2f},\n" + f"image retrieval R5: {eval_log['img_r5']*100:.2f},\n" + f"image retrieval R10: {eval_log['img_r10']*100:.2f}\n" + f"text retrieval R1: {eval_log['txt_r1']*100:.2f},\n" + f"text retrieval R5: {eval_log['txt_r5']*100:.2f},\n" + f"text retrieval R10: {eval_log['txt_r10']*100:.2f}") + LOGGER.info("=========================================================") + + +@torch.no_grad() +def validate(model, val_loader): + if hvd.rank() == 0: + pbar = tqdm(total=len(val_loader)) + else: + pbar = NoOp() + LOGGER.info("start running Image Retrieval validation ...") + model.eval() + n_ex = 0 + st = time() + + recall_at_1, recall_at_5, recall_at_10 = 0, 0, 0 + for batch in val_loader: + model_outputs = model(batch, compute_loss=False) + if isinstance(model_outputs, dict): + scores = model_outputs['rank_scores'] + else: + scores = model_outputs + _, indices = scores.squeeze(1).topk(10, dim=0) + rank = (indices == 0).nonzero() + if rank.numel(): + rank = rank.item() + if rank < 1: + recall_at_1 += 1 + if rank < 5: + recall_at_5 += 1 + if rank < 10: + recall_at_10 += 1 + n_ex += 1 + pbar.update(1) + n_ex = sum(all_gather_list(n_ex)) + recall_at_1 = sum(all_gather_list(recall_at_1)) / n_ex + recall_at_5 = sum(all_gather_list(recall_at_5)) / n_ex + recall_at_10 = sum(all_gather_list(recall_at_10)) / n_ex + tot_time = time()-st + val_log = {'valid/ex_per_s': n_ex/tot_time, + 'valid/recall_1': recall_at_1, + 'valid/recall_5': recall_at_5, + 'valid/recall_10': recall_at_10} + model.train() + LOGGER.info(f"validation finished in {int(tot_time)} seconds, " + f"recall_1: {recall_at_1*100:.2f}, " + f"recall_5: {recall_at_5*100:.2f}, " + f"recall_10: {recall_at_10*100:.2f}") + pbar.close() + return val_log + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + + # Required parameters + + parser.add_argument('--compressed_db', action='store_true', + help='use compressed LMDB') + parser.add_argument("--checkpoint", + default=None, type=str, + help="pretrained MLM") + + parser.add_argument("--output_dir", default=None, type=str, + help="The output directory where the model " + "checkpoints will be written.") + + # Prepro parameters + parser.add_argument('--max_txt_len', type=int, default=60, + help='max number of tokens in text (BERT BPE)') + parser.add_argument('--conf_th', type=float, default=0.2, + help='threshold for dynamic bounding boxes ' + '(-1 for fixed)') + parser.add_argument('--max_bb', type=int, default=100, + help='max number of bounding boxes') + parser.add_argument('--min_bb', type=int, default=10, + help='min number of bounding boxes') + parser.add_argument('--num_bb', type=int, default=36, + help='static number of bounding boxes') + + # training parameters + parser.add_argument("--train_batch_size", default=32, type=int, + help="batch size (# positive examples) for training. 
" + "(implemented with gradient accumulation)") + + parser.add_argument("--negative_size", default=511, type=int, + help="Number of negative samples per positive sample" + "(forward only)") + parser.add_argument("--hard_neg_size", default=31, type=int, + help="Number of hard negative samples " + "per positive sample (acutally used to train)") + parser.add_argument("--mlm_sample_size", default=22, type=int, + help="Number of samples following masked language masking" + "per positive sample (acutally used to train)") + parser.add_argument("--nsgd_sample_size", default=22, type=int, + help="Number of NSGD for each mlm sample" + "per positive sample (acutally used to train)") + parser.add_argument("--nsgd_sample_temperature", default=2.0, type=float, + help="sampling temperature of NSGD sampling. ") + + parser.add_argument("--disc_lambda", default=0.1, type=float, + help="lambda in training of discrimination") + parser.add_argument("--correction_lambda", default=0.1, type=float, + help="lambda in training of correction") + parser.add_argument("--mlm_lambda", default=0.1, type=float, + help="lambda in training of mask language modeling") + parser.add_argument("--nsgd_rank_lambda", default=1.0, type=float, + help="lambda in training of NSGD ranking loss") + parser.add_argument("--margin", default=0.2, type=float, + help="margin of ranking loss") + parser.add_argument("--learning_rate", default=3e-5, type=float, + help="The initial learning rate for Adam.") + parser.add_argument("--valid_steps", default=1000, type=int, + help="Run validation every X steps") + parser.add_argument("--num_train_steps", default=100000, type=int, + help="Total number of training updates to perform.") + parser.add_argument("--optim", default='adam', + choices=['adam', 'adamax', 'adamw'], + help="optimizer") + parser.add_argument("--betas", default=[0.9, 0.98], nargs='+', + help="beta for adam optimizer") + parser.add_argument("--dropout", default=0.1, type=float, + help="tune dropout regularization") + parser.add_argument("--weight_decay", default=0.01, type=float, + help="weight decay (L2) regularization") + parser.add_argument("--grad_norm", default=0.25, type=float, + help="gradient clipping (-1 for no clipping)") + parser.add_argument("--warmup_steps", default=4000, type=int, + help="Number of training steps to perform linear " + "learning rate warmup for.") + + # device parameters + parser.add_argument('--seed', type=int, default=42, + help="random seed for initialization") + parser.add_argument('--full_val', action='store_true', + help="Always run full evaluation during training") + parser.add_argument('--fp16', action='store_true', + help="Whether to use 16-bit float precision instead " + "of 32-bit") + parser.add_argument('--n_workers', type=int, default=4, + help="number of data workers") + parser.add_argument('--pin_mem', action='store_true', + help="pin memory") + + # can use config files + parser.add_argument('--config', help='JSON config files') + + args = parse_with_config(parser) + + # if exists(args.output_dir) and os.listdir(args.output_dir): + # raise ValueError("Output directory ({}) already exists and is not " + # "empty.".format(args.output_dir)) + + # options safe guard + if args.conf_th == -1: + assert args.max_bb + args.max_txt_len + 2 <= 512 + else: + assert args.num_bb + args.max_txt_len + 2 <= 512 + + # for tensor core + assert (args.negative_size+1) % 8 == (args.hard_neg_size+1) % 8 == 0 + + print('| args: ', args) + main(args) diff --git a/utils/__init__.py b/utils/__init__.py new file mode 
100644 index 0000000..e69de29 diff --git a/utils/ans2label.json b/utils/ans2label.json new file mode 100644 index 0000000..9fc717b --- /dev/null +++ b/utils/ans2label.json @@ -0,0 +1 @@ +{"net": 0, "pitcher": 1, "orange": 2, "yes": 3, "white": 4, "skiing": 5, "red": 6, "frisbee": 7, "brushing teeth": 8, "no": 9, "black and white": 10, "skateboard": 11, "1": 12, "blue": 13, "green": 14, "motorcycle": 15, "gray": 16, "2": 17, "purse": 18, "skis": 19, "poles": 20, "surfboard": 21, "dog": 22, "on": 23, "office": 24, "large": 25, "very big": 26, "laptop": 27, "vent": 28, "computer": 29, "black": 30, "bear": 31, "3": 32, "wii": 33, "glasses": 34, "tree": 35, "eating": 36, "log": 37, "5": 38, "raft": 39, "left": 40, "living room": 41, "pink": 42, "right": 43, "railing": 44, "grass": 45, "wire": 46, "10 years": 47, "knife": 48, "cake": 49, "banana": 50, "chef": 51, "vanilla": 52, "4": 53, "outdoor": 54, "mustard": 55, "bun": 56, "clouds": 57, "dock": 58, "brown": 59, "silver": 60, "refrigerator": 61, "square": 62, "teddy": 63, "elm": 64, "stripes": 65, "baseball": 66, "catcher": 67, "beer": 68, "bottom": 69, "north": 70, "nike": 71, "yellow and white": 72, "morning": 73, "elephant": 74, "red and white": 75, "propeller": 76, "tan": 77, "wall": 78, "rolex": 79, "clock": 80, "table": 81, "0": 82, "wood": 83, "christmas": 84, "spinach": 85, "thick": 86, "bag": 87, "leaves": 88, "necklace": 89, "6": 90, "bathroom": 91, "shower": 92, "towel": 93, "solid": 94, "referee": 95, "wilson": 96, "8:00": 97, "e": 98, "24": 99, "hat": 100, "grazing": 101, "sheep": 102, "10": 103, "tag": 104, "spanish": 105, "hot dog": 106, "plate": 107, "lunch": 108, "butter": 109, "peppers": 110, "onions": 111, "very": 112, "mayonnaise": 113, "mayo": 114, "sweet potato": 115, "pig": 116, "sweet": 117, "flowers": 118, "floral": 119, "yellow": 120, "window": 121, "7": 122, "pizza": 123, "car": 124, "": 125, "cargo": 126, "stairs": 127, "abstract": 128, "rug": 129, "baseball cap": 130, "texting": 131, "pole": 132, "crosswalk": 133, "nothing": 134, "urban": 135, "bus": 136, "light": 137, "afternoon": 138, "boat": 139, "cheese": 140, "paper": 141, "real": 142, "sun": 143, "birthday": 144, "words": 145, "inside": 146, "shadows": 147, "tomato": 148, "evergreen": 149, "100 feet": 150, "shingles": 151, "trees": 152, "building": 153, "hay": 154, "ski pole": 155, "patterned": 156, "walking": 157, "ice": 158, "laundry": 159, "pepsi": 160, "good": 161, "1:50": 162, "purple": 163, "13": 164, "africa": 165, "teddy bears": 166, "socks": 167, "giraffe": 168, "soccer": 169, "blue and yellow": 170, "zebras": 171, "cupcake": 172, "broccoli": 173, "soldier": 174, "parking lot": 175, "cows": 176, "herding": 177, "on table": 178, "fish": 179, "nightstand": 180, "50": 181, "overcast": 182, "cross": 183, "toaster oven": 184, "tile": 185, "11:55": 186, "red and yellow": 187, "nowhere": 188, "hair dryer": 189, "truck": 190, "11": 191, "people": 192, "rectangle": 193, "hot dogs": 194, "party": 195, "12:55": 196, "apron": 197, "kitchen": 198, "cooking": 199, "ring": 200, "1 way": 201, "stop": 202, "neither": 203, "many": 204, "female": 205, "brushing": 206, "tie": 207, "tennis racket": 208, "knife and fork": 209, "restaurant": 210, "cat": 211, "bed": 212, "sand": 213, "ocean": 214, "cold": 215, "kites": 216, "cumulus": 217, "standing": 218, "male": 219, "star": 220, "tracks": 221, "chocolate": 222, "round": 223, "fork and knife": 224, "yankees": 225, "pictures": 226, "dots": 227, "bird": 228, "parrot": 229, "red white and blue": 230, "man": 231, "metal": 
232, "fence": 233, "snowboarding": 234, "pine": 235, "snow": 236, "shorts": 237, "swim": 238, "wine": 239, "brick": 240, "no parking": 241, "children": 242, "beef": 243, "phone": 244, "english": 245, "cell phone": 246, "pink and yellow": 247, "clear": 248, "watermelon": 249, "bedroom": 250, "fork": 251, "cow": 252, "rackets": 253, "tennis rackets": 254, "8": 255, "collar": 256, "tennis": 257, "1950s": 258, "playing tennis": 259, "skirt": 260, "30": 261, "polka dot": 262, "beach": 263, "horse": 264, "grill": 265, "african american": 266, "down": 267, "street": 268, "in air": 269, "sweater": 270, "yellow and blue": 271, "park": 272, "backyard": 273, "spectators": 274, "parasailing": 275, "31": 276, "river": 277, "55": 278, "shadow": 279, "winter": 280, "chicken": 281, "tea": 282, "evening": 283, "dusk": 284, "ski resort": 285, "helmet": 286, "penne": 287, "bench": 288, "resting": 289, "elephants": 290, "southwest": 291, "usa": 292, "cars": 293, "town": 294, "bananas": 295, "umbrella": 296, "container": 297, "woman": 298, "on counter": 299, "salad": 300, "striped": 301, "motel": 302, "vertical": 303, "oranges": 304, "hot sauce": 305, "bottle": 306, "juice": 307, "eyes": 308, "ground": 309, "backpack": 310, "black and yellow": 311, "forward": 312, "jackets": 313, "1 on right": 314, "green and yellow": 315, "playing baseball": 316, "riding": 317, "sitting": 318, "carrot": 319, "basket": 320, "seagull": 321, "ski poles": 322, "p": 323, "parking": 324, "street light": 325, "mets": 326, "strap": 327, "bike": 328, "riding bike": 329, "poodle": 330, "shoes": 331, "carpet": 332, "lettuce": 333, "food": 334, "1 foot": 335, "roses": 336, "mountains": 337, "scissors": 338, "camera": 339, "beige": 340, "beard": 341, "cutting": 342, "baby": 343, "tape": 344, "watch": 345, "never": 346, "taking picture": 347, "eggs": 348, "syrup": 349, "sandwich": 350, "water skiing": 351, "microphone": 352, "back": 353, "bears": 354, "donuts": 355, "w": 356, "sky": 357, "double decker": 358, "england": 359, "surfing": 360, "running": 361, "shirt": 362, "barn": 363, "weather vane": 364, "white and blue": 365, "fishing": 366, "bridge": 367, "los angeles": 368, "open": 369, "red sox": 370, "bat": 371, "plane": 372, "white and green": 373, "transportation": 374, "sunny": 375, "bus stop": 376, "city": 377, "brown and white": 378, "bicycle": 379, "crow": 380, "magazines": 381, "daisy": 382, "14": 383, "old": 384, "curtains": 385, "jumped": 386, "snowboard": 387, "dinosaur": 388, "racing": 389, "asphalt": 390, "court": 391, "plastic": 392, "circle": 393, "red and blue": 394, "zebra": 395, "12": 396, "biplane": 397, "shallow": 398, "brazil": 399, "logo": 400, "2:20": 401, "electric": 402, "night time": 403, "motion": 404, "toothbrushes": 405, "orange and white": 406, "66": 407, "spoon": 408, "toyota": 409, "tennis shoes": 410, "46": 411, "second": 412, "no 1": 413, "iphone": 414, "friend": 415, "apple": 416, "carnation": 417, "15": 418, "tiger": 419, "glove": 420, "airplane": 421, "bow": 422, "air france": 423, "passengers": 424, "tv": 425, "on building": 426, "3:55": 427, "victorian": 428, "steeple": 429, "happy": 430, "skateboarding": 431, "fruit": 432, "cutting board": 433, "cantaloupe": 434, "kiwi": 435, "sliced": 436, "heart": 437, "water": 438, "rainy": 439, "carrots": 440, "giraffes": 441, "eat": 442, "ramp": 443, "lab": 444, "field": 445, "horizontal": 446, "birds": 447, "home": 448, "shrimp": 449, "12 feet": 450, "girl": 451, "modern": 452, "turtle": 453, "dell": 454, "boots": 455, "sunglasses": 456, "black and orange": 
457, "yellow and black": 458, "gloves": 459, "hp": 460, "desk": 461, "both": 462, "sign": 463, "on street": 464, "2000": 465, "cirrus": 466, "to dry": 467, "ceiling": 468, "fluorescent": 469, "up": 470, "9": 471, "boys": 472, "playing soccer": 473, "american": 474, "passenger": 475, "turn": 476, "palm": 477, "no train": 478, "wedding": 479, "branch": 480, "parrots": 481, "air force": 482, "on tracks": 483, "small": 484, "tank": 485, "dirty": 486, "france": 487, "honda": 488, "2.00": 489, "whale": 490, "vase": 491, "flying": 492, "professional": 493, "driving": 494, "tissue": 495, "protest": 496, "corona": 497, "for balance": 498, "twin": 499, "clothes": 500, "t shirt": 501, "window sill": 502, "wild": 503, "noon": 504, "caution": 505, "spring": 506, "raining": 507, "cane": 508, "school": 509, "windsurfing": 510, "parachute": 511, "black and red": 512, "25": 513, "background": 514, "toaster": 515, "planes": 516, "yellow and red": 517, "spatula": 518, "10:10": 519, "ivory": 520, "train": 521, "welcome": 522, "highway": 523, "off": 524, "on track": 525, "electricity": 526, "italy": 527, "dinner": 528, "sink": 529, "squares": 530, "5 ft": 531, "parked": 532, "store": 533, "dress": 534, "signs": 535, "meow": 536, "football": 537, "rugby": 538, "stainless steel": 539, "la": 540, "dirt": 541, "blue and white": 542, "klm": 543, "house": 544, "unknown": 545, "ford": 546, "reading": 547, "chair": 548, "mountain": 549, "alive": 550, "water skis": 551, "picture": 552, "parade": 553, "slippers": 554, "trailer": 555, "boating": 556, "holding it": 557, "shade": 558, "cloth": 559, "6:20": 560, "candle": 561, "hose": 562, "hand": 563, "3:25": 564, "on sidewalk": 565, "poster": 566, "downhill": 567, "68": 568, "reflection": 569, "summer": 570, "pickles": 571, "halloween": 572, "bats": 573, "london": 574, "zoo": 575, "surfer": 576, "racket": 577, "flickr": 578, "cutting hair": 579, "strawberries": 580, "mushroom": 581, "teddy bear": 582, "big": 583, "suitcase": 584, "veggie": 585, "pepper": 586, "houses": 587, "70": 588, "toshiba": 589, "triangle": 590, "boxes": 591, "photograph": 592, "smoke": 593, "engine": 594, "camel": 595, "sidewalk": 596, "left 1": 597, "red and green": 598, "4:35": 599, "on couch": 600, "candy": 601, "minnie mouse": 602, "homemade": 603, "mouse": 604, "box": 605, "movie": 606, "45": 607, "strawberry": 608, "fridge": 609, "full": 610, "vegetables": 611, "bright": 612, "play": 613, "remote": 614, "pond": 615, "savannah": 616, "celery": 617, "concrete": 618, "semi": 619, "dump": 620, "scania": 621, "safety": 622, "posing": 623, "fabric": 624, "laying": 625, "couch": 626, "blueberries": 627, "handle": 628, "pipe": 629, "stick": 630, "parmesan": 631, "steak": 632, "chain link": 633, "catch": 634, "barbed wire": 635, "mozzarella": 636, "soda": 637, "fire hydrant": 638, "cat food": 639, "pepperoni": 640, "lot": 641, "licking": 642, "red and black": 643, "clay": 644, "tennis court": 645, "jumping": 646, "potatoes": 647, "toothbrush": 648, "kite": 649, "not at all": 650, "flying kite": 651, "broken": 652, "black and silver": 653, "lap": 654, "outside": 655, "44": 656, "delta": 657, "greyhound": 658, "ring finger": 659, "talking on phone": 660, "bad": 661, "kettle": 662, "35": 663, "motorcycles": 664, "produce": 665, "comfort": 666, "steering wheel": 667, "18": 668, "humans": 669, "coffee": 670, "white and brown": 671, "fall": 672, "bread": 673, "cherry": 674, "4:30": 675, "flag": 676, "night": 677, "lamp": 678, "cucumber": 679, "can't see": 680, "porcelain": 681, "oval": 682, "museum": 683, 
"rain": 684, "sprinkles": 685, "20": 686, "kids": 687, "bracelet": 688, "sneakers": 689, "mask": 690, "mickey mouse": 691, "twins": 692, "very high": 693, "costume": 694, "cabbage": 695, "paint": 696, "lighting": 697, "young": 698, "air conditioner": 699, "wooden": 700, "board": 701, "someone": 702, "beets": 703, "16": 704, "day time": 705, "4 inches": 706, "lights": 707, "ladder": 708, "glass": 709, "ferris wheel": 710, "fries": 711, "steamed": 712, "shepherd": 713, "cotton": 714, "suit": 715, "goatee": 716, "on his head": 717, "print": 718, "happy birthday": 719, "forks": 720, "travel": 721, "maple": 722, "200": 723, "oil": 724, "jeans": 725, "can": 726, "chopsticks": 727, "on wall": 728, "construction": 729, "mack": 730, "36": 731, "chinese": 732, "moped": 733, "festival": 734, "gas": 735, "throwing": 736, "circus": 737, "wires": 738, "not possible": 739, "plates": 740, "sugar": 741, "in": 742, "women's": 743, "door": 744, "no man": 745, "volleyball": 746, "serving": 747, "ponytail": 748, "business": 749, "decoration": 750, "santa": 751, "flat": 752, "barrel": 753, "12:15": 754, "candles": 755, "atv": 756, "free": 757, "hair": 758, "waffle": 759, "ball": 760, "stop sign": 761, "wetsuit": 762, "very deep": 763, "swimsuit": 764, "green and black": 765, "foreground": 766, "stands": 767, "china airlines": 768, "flower": 769, "300": 770, "lobster": 771, "on bench": 772, "plaster": 773, "phones": 774, "sailboat": 775, "apples": 776, "road": 777, "recently": 778, "cones": 779, "cactus": 780, "rice": 781, "vegetarian": 782, "donut": 783, "ketchup": 784, "police": 785, "mirror": 786, "rock": 787, "meat": 788, "blinds": 789, "cell phones": 790, "china": 791, "rust": 792, "7:25": 793, "stone": 794, "vans": 795, "middle": 796, "eagle": 797, "9:30": 798, "ping pong": 799, "microwave": 800, "gmc": 801, "umbrellas": 802, "wrist": 803, "cuddling": 804, "laughing": 805, "boy": 806, "next to toilet": 807, "tabby": 808, "petting": 809, "south": 810, "40": 811, "name tag": 812, "checkered": 813, "name": 814, "slow": 815, "cardboard": 816, "windows": 817, "croissant": 818, "plain": 819, "cookie": 820, "on ground": 821, "low": 822, "water bottle": 823, "goggles": 824, "turkey": 825, "pull": 826, "shut": 827, "kite flying": 828, "bowl": 829, "smile": 830, "in bowl": 831, "bush": 832, "cloudy": 833, "top left": 834, "skateboarder": 835, "coca cola": 836, "pan": 837, "drinking": 838, "short": 839, "floor": 840, "thanksgiving": 841, "radio": 842, "drink": 843, "on toilet": 844, "bike rack": 845, "bleachers": 846, "train tracks": 847, "horses": 848, "far": 849, "top": 850, "toilet": 851, "in water": 852, "private": 853, "nature": 854, "checkers": 855, "commercial": 856, "stroller": 857, "power": 858, "stuffed animals": 859, "uniforms": 860, "japan": 861, "liquor": 862, "faucet": 863, "green and orange": 864, "corn": 865, "sub": 866, "white and yellow": 867, "mercedes": 868, "in sky": 869, "tarp": 870, "indian": 871, "counter": 872, "multicolored": 873, "polar": 874, "go": 875, "now": 876, "no number": 877, "swimming": 878, "bridle": 879, "cowboy": 880, "union station": 881, "salt and pepper": 882, "olives": 883, "pizza cutter": 884, "british airways": 885, "nighttime": 886, "domestic": 887, "trolley": 888, "australia": 889, "tiles": 890, "pug": 891, "wicker": 892, "british": 893, "us airways express": 894, "burton": 895, "christmas tree": 896, "napkin": 897, "writing": 898, "rocks": 899, "hello kitty": 900, "lacoste": 901, "gold": 902, "fan": 903, "skateboards": 904, "day": 905, "on floor": 906, "2008": 907, 
"dark": 908, "flying kites": 909, "rural": 910, "olympics": 911, "bmw": 912, "34": 913, "factory": 914, "denim": 915, "typing": 916, "for fun": 917, "steel": 918, "watching tv": 919, "chevron": 920, "driver": 921, "baggage claim": 922, "grapes": 923, "f": 924, "angels": 925, "roof": 926, "handlebars": 927, "train station": 928, "public": 929, "oak": 930, "sleeping": 931, "canada": 932, "on runway": 933, "air canada": 934, "on top": 935, "tired": 936, "blonde": 937, "cups": 938, "little": 939, "adidas": 940, "10 feet": 941, "white and gray": 942, "leaf": 943, "fisheye": 944, "forest": 945, "war": 946, "octagon": 947, "raspberry": 948, "helmets": 949, "united states": 950, "29": 951, "noodles": 952, "van": 953, "long": 954, "traveling": 955, "luggage": 956, "airport": 957, "single": 958, "pitching": 959, "dugout": 960, "garbage": 961, "in street": 962, "happiness": 963, "cigarette": 964, "on tower": 965, "antelope": 966, "graffiti": 967, "skating": 968, "on road": 969, "curved": 970, "red light": 971, "washington": 972, "ski lift": 973, "athletics": 974, "brace": 975, "squatting": 976, "catching": 977, "batter": 978, "batting": 979, "game": 980, "towards": 981, "33": 982, "sliding": 983, "makeup": 984, "japanese": 985, "person": 986, "pirates": 987, "plaid": 988, "rose": 989, "daytime": 990, "keyboard": 991, "surfboards": 992, "hummingbird": 993, "ollie": 994, "11:30": 995, "clock tower": 996, "5:55": 997, "san francisco": 998, "stopping": 999, "tags": 1000, "samsung": 1001, "computers": 1002, "cabinets": 1003, "talking": 1004, "cage": 1005, "asparagus": 1006, "5 years": 1007, "hanger": 1008, "adult": 1009, "rabbit": 1010, "empty": 1011, "softball": 1012, "1st": 1013, "playing": 1014, "chairs": 1015, "farm": 1016, "cross country": 1017, "dump truck": 1018, "women": 1019, "snowboarder": 1020, "tall": 1021, "monkey": 1022, "mantle": 1023, "fire": 1024, "books": 1025, "quilt": 1026, "cessna": 1027, "chandelier": 1028, "dunkin donuts": 1029, "beans": 1030, "relish": 1031, "no flag": 1032, "parking meter": 1033, "spots": 1034, "ducks": 1035, "sandals": 1036, "doughnut": 1037, "lighthouse": 1038, "yacht": 1039, "german shepherd": 1040, "in middle": 1041, "raw": 1042, "chain": 1043, "2 feet": 1044, "pedestal": 1045, "sauerkraut": 1046, "bagels": 1047, "mutt": 1048, "dog and cat": 1049, "race": 1050, "poor": 1051, "cat and dog": 1052, "station": 1053, "printer": 1054, "daisies": 1055, "front": 1056, "gravel": 1057, "rear": 1058, "grassy": 1059, "pigeons": 1060, "dogs": 1061, "in car": 1062, "life": 1063, "wii remotes": 1064, "suv": 1065, "leather": 1066, "bottom right": 1067, "peace": 1068, "facebook": 1069, "blanket": 1070, "fountain": 1071, "frisbees": 1072, "12:30": 1073, "am": 1074, "scooter": 1075, "going": 1076, "analog": 1077, "america": 1078, "pitbull": 1079, "relaxing": 1080, "paddle boarding": 1081, "white and pink": 1082, "shampoo": 1083, "alps": 1084, "ride": 1085, "side": 1086, "mane": 1087, "on desk": 1088, "on chair": 1089, "2012": 1090, "multi": 1091, "straight": 1092, "big ben": 1093, "closed": 1094, "frosted": 1095, "3 feet": 1096, "waves": 1097, "buoy": 1098, "life vest": 1099, "trash can": 1100, "medium": 1101, "boxer": 1102, "very tall": 1103, "yamaha": 1104, "sunlight": 1105, "hit ball": 1106, "dry": 1107, "coke": 1108, "gym": 1109, "orange and black": 1110, "center": 1111, "rope": 1112, "flip flops": 1113, "4th of july": 1114, "siamese": 1115, "crafts": 1116, "color": 1117, "italian": 1118, "playing frisbee": 1119, "skate park": 1120, "orange juice": 1121, "windowsill": 1122, 
"corgi": 1123, "thumb": 1124, "peanut butter": 1125, "pie": 1126, "toast": 1127, "no hat": 1128, "benches": 1129, "diamond": 1130, "blender": 1131, "avocado": 1132, "television": 1133, "speakers": 1134, "pony": 1135, "baseball field": 1136, "pavement": 1137, "sydney": 1138, "not there": 1139, "diamonds": 1140, "4 feet": 1141, "goalie": 1142, "soccer ball": 1143, "runway": 1144, "video game": 1145, "gaming": 1146, "casual": 1147, "green and white": 1148, "toilet brush": 1149, "working": 1150, "pickup": 1151, "girls": 1152, "remotes": 1153, "pasta": 1154, "hood": 1155, "braves": 1156, "skier": 1157, "motorola": 1158, "17": 1159, "b": 1160, "100": 1161, "diet coke": 1162, "hospital": 1163, "wagon": 1164, "milk": 1165, "ferry": 1166, "rainbow": 1167, "on bed": 1168, "toward": 1169, "1:30": 1170, "19": 1171, "security": 1172, "herself": 1173, "mercedes benz": 1174, "supreme": 1175, "thin": 1176, "platform": 1177, "gray and red": 1178, "thai": 1179, "storage": 1180, "thailand": 1181, "swan": 1182, "peach": 1183, "10:05": 1184, "dome": 1185, "chiquita": 1186, "2:00": 1187, "mountain dew": 1188, "23": 1189, "knives": 1190, "street sign": 1191, "on beach": 1192, "playing wii": 1193, "using laptop": 1194, "stickers": 1195, "yogurt": 1196, "on grass": 1197, "9:50": 1198, "9:45": 1199, "sweat": 1200, "gatorade": 1201, "umpire": 1202, "37": 1203, "transport": 1204, "desktop": 1205, "desserts": 1206, "main": 1207, "boston": 1208, "fell": 1209, "top right": 1210, "case": 1211, "asleep": 1212, "over": 1213, "9:55": 1214, "grapefruit": 1215, "breakfast": 1216, "headphones": 1217, "freight": 1218, "cup": 1219, "sweatband": 1220, "nobody": 1221, "lamps": 1222, "9:25": 1223, "scarf": 1224, "on fridge": 1225, "main st": 1226, "moving": 1227, "confused": 1228, "fresh": 1229, "kiting": 1230, "blue jay": 1231, "flats": 1232, "long time": 1233, "chihuahua": 1234, "ceramic": 1235, "mushrooms": 1236, "on plate": 1237, "human": 1238, "power lines": 1239, "hotel": 1240, "map": 1241, "earring": 1242, "boarding": 1243, "display": 1244, "warm": 1245, "napkins": 1246, "brown and black": 1247, "broom": 1248, "basketball": 1249, "papers": 1250, "holding baby": 1251, "sad": 1252, "kickstand": 1253, "60": 1254, "shoulder": 1255, "sleep": 1256, "footprints": 1257, "tunnel": 1258, "1990": 1259, "hats": 1260, "6 inches": 1261, "ham": 1262, "bacon": 1263, "church": 1264, "53": 1265, "pineapple": 1266, "at camera": 1267, "red bull": 1268, "pilot": 1269, "tattoo": 1270, "work": 1271, "polar bear": 1272, "taking off": 1273, "website": 1274, "22": 1275, "4:00": 1276, "coffee maker": 1277, "fast": 1278, "fur": 1279, "rubber": 1280, "tongs": 1281, "german": 1282, "germany": 1283, "3 inches": 1284, "toy": 1285, "3:20": 1286, "calm": 1287, "pots": 1288, "balloons": 1289, "fruits": 1290, "9:20": 1291, "drawer": 1292, "oven": 1293, "soup": 1294, "stove": 1295, "heels": 1296, "wind": 1297, "island": 1298, "blood": 1299, "leg": 1300, "theater": 1301, "tennis racquet": 1302, "21": 1303, "gothic": 1304, "2:35": 1305, "wii remote": 1306, "turning": 1307, "20 feet": 1308, "pink and black": 1309, "ears": 1310, "fun": 1311, "wreath": 1312, "to right": 1313, "child": 1314, "fly": 1315, "head": 1316, "drywall": 1317, "shorter": 1318, "pier": 1319, "feeding giraffe": 1320, "in vase": 1321, "burger": 1322, "easter": 1323, "onion": 1324, "uniform": 1325, "remote control": 1326, "guitar": 1327, "time": 1328, "verizon": 1329, "tomatoes": 1330, "ship": 1331, "tulips": 1332, "glaze": 1333, "on suitcase": 1334, "tent": 1335, "1:45": 1336, "market": 1337, 
"bnsf": 1338, "bandana": 1339, "still": 1340, "don't know": 1341, "piano": 1342, "mouth": 1343, "run": 1344, "sparrow": 1345, "throw": 1346, "lines": 1347, "vest": 1348, "1950": 1349, "jet": 1350, "sepia": 1351, "2015": 1352, "busy": 1353, "lighter": 1354, "dessert": 1355, "bending": 1356, "75": 1357, "finch": 1358, "pastries": 1359, "outdoors": 1360, "bakery": 1361, "clean": 1362, "ipod": 1363, "tablecloth": 1364, "cigarettes": 1365, "looking at phone": 1366, "in front": 1367, "food truck": 1368, "face": 1369, "swinging": 1370, "safari": 1371, "500": 1372, "volkswagen": 1373, "2010": 1374, "shape": 1375, "shelves": 1376, "riding horses": 1377, "2016": 1378, "behind bus": 1379, "towels": 1380, "lemon": 1381, "straw": 1382, "bamboo": 1383, "5 feet": 1384, "hardwood": 1385, "oregon": 1386, "schnauzer": 1387, "organic": 1388, "h": 1389, "kid": 1390, "meter": 1391, "61": 1392, "charging": 1393, "bald": 1394, "caucasian": 1395, "man on left": 1396, "stand": 1397, "27": 1398, "dining room": 1399, "sandwiches": 1400, "32": 1401, "apartment": 1402, "tower": 1403, "virgin": 1404, "out": 1405, "white and red": 1406, "2:05": 1407, "i don't know": 1408, "chains": 1409, "legs": 1410, "age": 1411, "goats": 1412, "s": 1413, "congratulations": 1414, "dresser": 1415, "camper": 1416, "half": 1417, "silverware": 1418, "decorative": 1419, "hawaiian": 1420, "petting horse": 1421, "wheel": 1422, "florida": 1423, "reds": 1424, "washington dc": 1425, "moon": 1426, "conference": 1427, "screen": 1428, "controller": 1429, "robin": 1430, "men": 1431, "protection": 1432, "roll": 1433, "harley davidson": 1434, "coal": 1435, "mustache": 1436, "smiling": 1437, "pedestrians": 1438, "88": 1439, "me": 1440, "tray": 1441, "males": 1442, "monitor": 1443, "bell": 1444, "landscape": 1445, "club": 1446, "toothpick": 1447, "seagulls": 1448, "bowtie": 1449, "lake": 1450, "steam": 1451, "surf": 1452, "baseball glove": 1453, "blinders": 1454, "woods": 1455, "stuffed": 1456, "sunbathing": 1457, "shearing": 1458, "dad": 1459, "mixer": 1460, "pot": 1461, "blending": 1462, "identification": 1463, "owl": 1464, "wine glass": 1465, "on bike": 1466, "billabong": 1467, "new york": 1468, "yarn": 1469, "tube": 1470, "tennis ball": 1471, "2:55": 1472, "ice cream": 1473, "chevrolet": 1474, "shirt and tie": 1475, "taking selfie": 1476, "blue and green": 1477, "he isn't": 1478, "cutting cake": 1479, "east": 1480, "setting": 1481, "brewers": 1482, "riding bikes": 1483, "7 eleven": 1484, "stars": 1485, "jockey": 1486, "jacket": 1487, "standing still": 1488, "book": 1489, "gray and white": 1490, "pen": 1491, "red white blue": 1492, "above": 1493, "alaska": 1494, "tongue": 1495, "feathers": 1496, "k": 1497, "camping": 1498, "pasture": 1499, "corner": 1500, "away": 1501, "ski": 1502, "texas": 1503, "fire truck": 1504, "sailboats": 1505, "jump": 1506, "walk": 1507, "spray paint": 1508, "loading": 1509, "united": 1510, "1000": 1511, "brushing his teeth": 1512, "roman numerals": 1513, "garlic": 1514, "surprise": 1515, "3rd": 1516, "first": 1517, "side of road": 1518, "dodgers": 1519, "airplanes": 1520, "unsure": 1521, "russian": 1522, "wet": 1523, "skyscraper": 1524, "5 star": 1525, "brushing her teeth": 1526, "blankets": 1527, "natural": 1528, "across street": 1529, "smartphone": 1530, "duck": 1531, "sausage": 1532, "paris": 1533, "newspaper": 1534, "pants": 1535, "spices": 1536, "pillow": 1537, "to left": 1538, "snowboards": 1539, "colgate": 1540, "on elephant": 1541, "string": 1542, "horns": 1543, "2:40": 1544, "men's": 1545, "cobblestone": 1546, 
"regular": 1547, "staring": 1548, "28": 1549, "barber shop": 1550, "linoleum": 1551, "grind": 1552, "cut": 1553, "x": 1554, "above sink": 1555, "above stove": 1556, "dishes": 1557, "dalmatian": 1558, "watching": 1559, "glazed": 1560, "5:25": 1561, "j": 1562, "messy": 1563, "wallet": 1564, "tuna": 1565, "toasted": 1566, "grilled": 1567, "french": 1568, "green and blue": 1569, "sunflowers": 1570, "to catch frisbee": 1571, "wool": 1572, "sprint": 1573, "no grass": 1574, "cabinet": 1575, "shell": 1576, "foil": 1577, "bottles": 1578, "bar": 1579, "king": 1580, "paper towels": 1581, "friends": 1582, "beagle": 1583, "school bus": 1584, "laptops": 1585, "snowing": 1586, "cement": 1587, "pc": 1588, "accident": 1589, "stuffed animal": 1590, "wakeboard": 1591, "balance": 1592, "in suitcase": 1593, "white and black": 1594, "nikon": 1595, "cleats": 1596, "on sink": 1597, "pool": 1598, "mom": 1599, "downtown": 1600, "asian": 1601, "heater": 1602, "bathing": 1603, "193": 1604, "against wall": 1605, "canopy": 1606, "jungle": 1607, "berries": 1608, "military": 1609, "pickle": 1610, "clams": 1611, "seafood": 1612, "in box": 1613, "boats": 1614, "tables": 1615, "lizard": 1616, "lemonade": 1617, "m": 1618, "soft": 1619, "illinois": 1620, "country": 1621, "for sale": 1622, "arm": 1623, "listening": 1624, "curly": 1625, "play tennis": 1626, "hands": 1627, "cereal": 1628, "blue and red": 1629, "robe": 1630, "around neck": 1631, "red and silver": 1632, "soap": 1633, "trains": 1634, "throwing frisbee": 1635, "smoking": 1636, "india": 1637, "headband": 1638, "not very": 1639, "westin": 1640, "serve": 1641, "bicycles": 1642, "can't tell": 1643, "to catch ball": 1644, "visibility": 1645, "ana": 1646, "reins": 1647, "rodeo": 1648, "boot": 1649, "on horse": 1650, "12:35": 1651, "riding motorcycle": 1652, "mexico": 1653, "mother": 1654, "african": 1655, "left and right": 1656, "button": 1657, "earrings": 1658, "blackberry": 1659, "cell": 1660, "10:00": 1661, "harness": 1662, "pillows": 1663, "vegetable": 1664, "tablet": 1665, "fern": 1666, "cats": 1667, "golden retriever": 1668, "goat": 1669, "tractor": 1670, "valentine's day": 1671, "hearts": 1672, "khaki": 1673, "man on right": 1674, "mcdonald's": 1675, "player": 1676, "arriving": 1677, "husky": 1678, "on skateboard": 1679, "vases": 1680, "coat": 1681, "beanie": 1682, "coming": 1683, "granite": 1684, "shopping cart": 1685, "it's raining": 1686, "sports": 1687, "leash": 1688, "balls": 1689, "blurry": 1690, "baseball bat": 1691, "team": 1692, "mango": 1693, "mug": 1694, "eiffel tower": 1695, "worms": 1696, "trash": 1697, "robot": 1698, "show": 1699, "terrier": 1700, "painting": 1701, "rooster": 1702, "42": 1703, "jones": 1704, "state farm": 1705, "balloon": 1706, "trunk": 1707, "coach": 1708, "t": 1709, "playing game": 1710, "fireplace": 1711, "behind clouds": 1712, "uphill": 1713, "motocross": 1714, "sony": 1715, "magazine": 1716, "kitesurfing": 1717, "catching frisbee": 1718, "catch frisbee": 1719, "bud light": 1720, "drive": 1721, "fighting": 1722, "1 on left": 1723, "very old": 1724, "hallway": 1725, "lexus": 1726, "wii controller": 1727, "9:15": 1728, "fast food": 1729, "5:45": 1730, "catholic": 1731, "muffin": 1732, "traffic light": 1733, "band": 1734, "button up": 1735, "grocery": 1736, "shelf": 1737, "2:25": 1738, "honey": 1739, "plants": 1740, "oars": 1741, "foggy": 1742, "nathan's": 1743, "cord": 1744, "yard": 1745, "48": 1746, "donut shop": 1747, "chimney": 1748, "calico": 1749, "suits": 1750, "sideways": 1751, "animals": 1752, "black and blue": 1753, 
"bikini": 1754, "photographer": 1755, "700": 1756, "queen": 1757, "1:00": 1758, "12:05": 1759, "horseback riding": 1760, "awake": 1761, "bunny": 1762, "12:00": 1763, "continental": 1764, "flamingo": 1765, "rye": 1766, "family": 1767, "lots": 1768, "owner": 1769, "stew": 1770, "palm tree": 1771, "cruise ship": 1772, "56": 1773, "design": 1774, "ny": 1775, "far right": 1776, "tire": 1777, "younger": 1778, "biking": 1779, "at&t": 1780, "giants": 1781, "marshmallows": 1782, "caramel": 1783, "polo": 1784, "emirates": 1785, "salon": 1786, "focus": 1787, "on motorcycle": 1788, "magnets": 1789, "mat": 1790, "ivy": 1791, "cakes": 1792, "chrome": 1793, "bob": 1794, "asia": 1795, "graduation": 1796, "cauliflower": 1797, "in snow": 1798, "c": 1799, "rough": 1800, "vacation": 1801, "air": 1802, "windy": 1803, "victoria": 1804, "4:45": 1805, "trick": 1806, "coconut": 1807, "labrador": 1808, "on left": 1809, "yellow and green": 1810, "butterfly": 1811, "fake": 1812, "on napkin": 1813, "bricks": 1814, "wine glasses": 1815, "detroit": 1816, "man's": 1817, "parsley": 1818, "art": 1819, "subway": 1820, "wave": 1821, "placemat": 1822, "hydrant": 1823, "sofa": 1824, "pigeon": 1825, "riding elephant": 1826, "all": 1827, "branches": 1828, "plant": 1829, "to eat": 1830, "zucchini": 1831, "feta": 1832, "neon": 1833, "mouse pad": 1834, "cloud": 1835, "toilet paper": 1836, "pumpkin": 1837, "rowing": 1838, "toronto": 1839, "handicap": 1840, "seeds": 1841, "fly kite": 1842, "chicago": 1843, "marble": 1844, "frame": 1845, "150": 1846, "rocky": 1847, "give way": 1848, "sauce": 1849, "it's not": 1850, "control": 1851, "high chair": 1852, "playstation": 1853, "xbox": 1854, "not likely": 1855, "roman": 1856, "land": 1857, "1:35": 1858, "lifeguard": 1859, "on pizza": 1860, "size": 1861, "bull": 1862, "dandelions": 1863, "equestrian": 1864, "goose": 1865, "8 feet": 1866, "recessed": 1867, "statue": 1868, "index": 1869, "phillies": 1870, "strike": 1871, "mirrors": 1872, "pointing": 1873, "farmer": 1874, "collie": 1875, "motorbike": 1876, "lanes": 1877, "bikes": 1878, "biker": 1879, "arrows": 1880, "gas station": 1881, "logs": 1882, "smaller": 1883, "desert": 1884, "yield": 1885, "flags": 1886, "stool": 1887, "kitten": 1888, "doll": 1889, "daffodils": 1890, "letters": 1891, "dishwasher": 1892, "first base": 1893, "nuts": 1894, "2013": 1895, "persian": 1896, "swim trunks": 1897, "deep": 1898, "o": 1899, "doubles": 1900, "toothpicks": 1901, "in field": 1902, "wristband": 1903, "wheels": 1904, "baking": 1905, "4:15": 1906, "11:00": 1907, "ear": 1908, "2007": 1909, "51": 1910, "chevy": 1911, "using computer": 1912, "frog": 1913, "storm": 1914, "boogie board": 1915, "hungry": 1916, "by window": 1917, "ambulance": 1918, "pigtails": 1919, "audi": 1920, "microsoft": 1921, "on man": 1922, "cannot tell": 1923, "stained glass": 1924, "hugging": 1925, "laying down": 1926, "3:00": 1927, "taxi": 1928, "pedestrian": 1929, "landing": 1930, "numbers": 1931, "38": 1932, "stones": 1933, "on tree": 1934, "clocks": 1935, "new": 1936, "picnic": 1937, "fog": 1938, "buffalo": 1939, "under armour": 1940, "cocker spaniel": 1941, "orioles": 1942, "no sign": 1943, "telling time": 1944, "bags": 1945, "golden gate": 1946, "cover": 1947, "castle": 1948, "canoe": 1949, "selfie": 1950, "cream": 1951, "floating": 1952, "indoor": 1953, "antique": 1954, "aluminum": 1955, "silver and black": 1956, "cast iron": 1957, "peas": 1958, "sun hat": 1959, "on right": 1960, "swiss": 1961, "flour": 1962, "under sink": 1963, "fashion": 1964, "fedora": 1965, "shells": 1966, 
"1 hour": 1967, "puppy": 1968, "in stands": 1969, "not here": 1970, "motor": 1971, "thousands": 1972, "120": 1973, "sail": 1974, "butt": 1975, "mexican": 1976, "dead end": 1977, "paddle": 1978, "bathing suit": 1979, "shop": 1980, "onion rings": 1981, "boxing": 1982, "birthday cake": 1983, "chalk": 1984, "scenery": 1985, "style": 1986, "nissan": 1987, "sticker": 1988, "on rack": 1989, "1 4": 1990, "woman's": 1991, "surprised": 1992, "north face": 1993, "squash": 1994, "not sure": 1995, "email": 1996, "spotted": 1997, "seat": 1998, "himself": 1999, "circles": 2000, "san diego": 2001, "kia": 2002, "mattress": 2003, "obama": 2004, "lamb": 2005, "american flag": 2006, "climbing": 2007, "skull and crossbones": 2008, "roast beef": 2009, "visor": 2010, "herd": 2011, "double": 2012, "52": 2013, "high": 2014, "stagecoach": 2015, "cart": 2016, "feeding": 2017, "eaten": 2018, "cone": 2019, "11:15": 2020, "smoothie": 2021, "golf": 2022, "colorado": 2023, "electronics": 2024, "5:15": 2025, "bowling": 2026, "players": 2027, "ketchup and mustard": 2028, "styrofoam": 2029, "6 feet": 2030, "hawk": 2031, "cheddar": 2032, "12:28": 2033, "arabic": 2034, "12:25": 2035, "12:10": 2036, "shower curtain": 2037, "army": 2038, "salmon": 2039, "10:40": 2040, "hanging": 2041, "whole": 2042, "behind fence": 2043, "bars": 2044, "moss": 2045, "no dog": 2046, "traffic": 2047, "10:25": 2048, "r": 2049, "countryside": 2050, "machine": 2051, "directions": 2052, "cooked": 2053, "aa": 2054, "6:45": 2055, "4 way": 2056, "stripe": 2057, "brand": 2058, "baseball player": 2059, "bunk": 2060, "coleslaw": 2061, "fishing boat": 2062, "at table": 2063, "europe": 2064, "dead": 2065, "arch": 2066, "scrambled": 2067, "clothing": 2068, "closet": 2069, "egg": 2070, "suitcases": 2071, "indoors": 2072, "coffee pot": 2073, "tires": 2074, "lilies": 2075, "cafe": 2076, "9:35": 2077, "teal": 2078, "toothpaste": 2079, "in background": 2080, "tarmac": 2081, "painted": 2082, "sunset": 2083, "orange and yellow": 2084, "oar": 2085, "peaches": 2086, "zebra and giraffe": 2087, "ladybug": 2088, "20 ft": 2089, "sesame seeds": 2090, "hills": 2091, "2:30": 2092, "stucco": 2093, "tail": 2094, "couple": 2095, "kawasaki": 2096, "smooth": 2097, "powdered sugar": 2098, "pedestrian crossing": 2099, "french fries": 2100, "picnic table": 2101, "teeth": 2102, "ribbon": 2103, "saddle": 2104, "15 feet": 2105, "earbuds": 2106, "on train": 2107, "39": 2108, "curb": 2109, "tow": 2110, "shark": 2111, "white and orange": 2112, "6:25": 2113, "gravy": 2114, "fork and spoon": 2115, "pooping": 2116, "curtain": 2117, "lime": 2118, "skull": 2119, "crossing": 2120, "speed limit": 2121, "peacock": 2122, "boredom": 2123, "neck": 2124, "hit": 2125, "dragon": 2126, "tissues": 2127, "basil": 2128, "waving": 2129, "blue team": 2130, "rectangles": 2131, "helicopter": 2132, "mud": 2133, "us": 2134, "balcony": 2135, "red and gray": 2136, "firefighter": 2137, "sunflower": 2138, "wallpaper": 2139, "best buy": 2140, "11:20": 2141, "public market center": 2142, "seattle": 2143, "bookshelf": 2144, "looking": 2145, "1 inch": 2146, "harley": 2147, "urinal": 2148, "cartoon": 2149, "t shirt and jeans": 2150, "navy": 2151, "fedex": 2152, "rays": 2153, "deck": 2154, "coaster": 2155, "1:20": 2156, "50 feet": 2157, "4:20": 2158, "us open": 2159, "looking at camera": 2160, "600": 2161, "national express": 2162, "white house": 2163, "5:00": 2164, "jp morgan": 2165, "palm trees": 2166, "tub": 2167, "pens": 2168, "soldiers": 2169, "2 people": 2170, "animal": 2171, "speaker": 2172, "hamburger": 2173, 
"spaghetti": 2174, "green beans": 2175, "it isn't": 2176, "10:20": 2177, "buildings": 2178, "on shelf": 2179, "baseball uniform": 2180, "tiled": 2181, "orange and blue": 2182, "90": 2183, "north america": 2184, "arrow": 2185, "news": 2186, "tropicana": 2187, "formal": 2188, "in grass": 2189, "thumbs up": 2190, "clip": 2191, "gate": 2192, "tennis player": 2193, "lilac": 2194, "pastry": 2195, "nose": 2196, "pacifier": 2197, "11:35": 2198, "different teams": 2199, "cardinals": 2200, "exhaust": 2201, "hauling": 2202, "on tray": 2203, "bagel": 2204, "huge": 2205, "out of focus": 2206, "cook": 2207, "wheat": 2208, "photo": 2209, "ghost": 2210, "sedan": 2211, "qatar": 2212, "zig zag": 2213, "lanyard": 2214, "pink and white": 2215, "sesame": 2216, "space": 2217, "no clock": 2218, "warning": 2219, "snowy": 2220, "tater tots": 2221, "tropical": 2222, "grandfather": 2223, "mac": 2224, "magnet": 2225, "photoshop": 2226, "pajamas": 2227, "350": 2228, "casserole": 2229, "4:55": 2230, "pelican": 2231, "2009": 2232, "clydesdale": 2233, "tow truck": 2234, "belt": 2235, "west": 2236, "omelet": 2237, "heavy": 2238, "crown": 2239, "in corner": 2240, "hexagon": 2241, "mound": 2242, "iris": 2243, "g": 2244, "12:45": 2245, "2:15": 2246, "3:10": 2247, "drawing": 2248, "only": 2249, "little girl": 2250, "washing": 2251, "nokia": 2252, "windsor": 2253, "2 men": 2254, "parmesan cheese": 2255, "on woman": 2256, "freezer": 2257, "icing": 2258, "venice": 2259, "dairy": 2260, "several": 2261, "concentration": 2262, "3:15": 2263, "no smoking": 2264, "kayak": 2265, "frosting": 2266, "jetblue": 2267, "thoroughbred": 2268, "parakeet": 2269, "shoe": 2270, "skeleton": 2271, "britain": 2272, "ties": 2273, "in sink": 2274, "patio": 2275, "bank": 2276, "camouflage": 2277, "privacy": 2278, "bib": 2279, "blue and gray": 2280, "looking out window": 2281, "falling": 2282, "bucket": 2283, "cupcakes": 2284, "throw ball": 2285, "garden": 2286, "almonds": 2287, "ducati": 2288, "ireland": 2289, "plastic wrap": 2290, "starbucks": 2291, "all way": 2292, "bark": 2293, "home plate": 2294, "base": 2295, "dog food": 2296, "toys": 2297, "blue and orange": 2298, "1 in front": 2299, "foot": 2300, "dc": 2301, "california": 2302, "towing": 2303, "cheesecake": 2304, "bushes": 2305, "bow tie": 2306, "millions": 2307, "down street": 2308, "2011": 2309, "police officer": 2310, "windmill": 2311, "taking pictures": 2312, "street name": 2313, "cleaning": 2314, "on pole": 2315, "russia": 2316, "main street": 2317, "catch ball": 2318, "mario": 2319, "pirate": 2320, "track": 2321, "garage": 2322, "7:10": 2323, "they aren't": 2324, "mother and child": 2325, "tents": 2326, "fancy": 2327, "tattoos": 2328, "alcohol": 2329, "2:45": 2330, "wheelchair": 2331, "money": 2332, "top hat": 2333, "willow": 2334, "cd": 2335, "brushing hair": 2336, "pancake": 2337, "80": 2338, "listening to music": 2339, "green and red": 2340, "barrier": 2341, "vests": 2342, "hiking": 2343, "tank top": 2344, "lufthansa": 2345, "student": 2346, "menu": 2347, "forehand": 2348, "wii controllers": 2349, "acer": 2350, "wall st": 2351, "hundreds": 2352, "water ski": 2353, "furniture": 2354, "paisley": 2355, "pizza hut": 2356, "baseball game": 2357, "hill": 2358, "prom": 2359, "1 world": 2360, "tiara": 2361, "students": 2362, "information": 2363, "hazy": 2364, "nasa": 2365, "canon": 2366, "bird feeder": 2367, "crane": 2368, "dr pepper": 2369, "logitech": 2370, "2:10": 2371, "all of them": 2372, "utensils": 2373, "telephone": 2374, "converse": 2375, "bone": 2376, "jeep": 2377, "nursing": 2378, 
"krispy kreme": 2379, "cameraman": 2380, "pee": 2381, "ranch": 2382, "polka dots": 2383, "railroad crossing": 2384, "shirts": 2385, "feeder": 2386, "above toilet": 2387, "unclear": 2388, "below": 2389, "43": 2390, "spoons": 2391, "calendar": 2392, "vaio": 2393, "fox": 2394, "mint": 2395, "after": 2396, "spiderman": 2397, "lg": 2398, "concert": 2399, "on rock": 2400, "fluffy": 2401, "gray and black": 2402, "coats": 2403, "lady": 2404, "dodge": 2405, "easyjet": 2406, "pearl": 2407, "bunt": 2408, "flat screen": 2409, "10:30": 2410, "music": 2411, "polar bears": 2412, "riding horse": 2413, "lift": 2414, "angry": 2415, "cookies": 2416, "3:45": 2417, "buttons": 2418, "hot": 2419, "cute": 2420, "behind": 2421, "dole": 2422, "in motion": 2423, "26": 2424, "pans": 2425, "love": 2426, "winnie pooh": 2427, "pear": 2428, "copyright": 2429, "2 hours": 2430, "snowsuit": 2431, "kissing": 2432, "backhand": 2433, "to get to other side": 2434, "metro": 2435, "swans": 2436, "very fast": 2437, "can't see it": 2438, "nintendo": 2439, "direction": 2440, "waiting": 2441, "mohawk": 2442, "st patrick's day": 2443, "rail": 2444, "hoodie": 2445, "feet": 2446, "swirls": 2447, "muffins": 2448, "4:05": 2449, "106": 2450, "10:55": 2451, "coins": 2452, "mitt": 2453, "game controller": 2454, "room": 2455, "adults": 2456, "urinals": 2457, "cameras": 2458, "marker": 2459, "upright": 2460, "brass": 2461, "sled": 2462, "teacher": 2463, "conductor": 2464, "farmers market": 2465, "toiletries": 2466, "blue and black": 2467, "soccer field": 2468, "banana peel": 2469, "sprite": 2470, "doughnuts": 2471, "bank of america": 2472, "on his face": 2473, "heat": 2474, "emergency": 2475, "ski slope": 2476, "hard": 2477, "41": 2478, "6:00": 2479, "in his hand": 2480, "cluttered": 2481, "dog show": 2482, "on boat": 2483, "grizzly": 2484, "drums": 2485, "not": 2486, "in hand": 2487, "easy": 2488, "400": 2489, "under table": 2490, "d": 2491, "hitting ball": 2492, "photography": 2493, "intersection": 2494, "backwards": 2495, "crocs": 2496, "marina": 2497, "chips": 2498, "bible": 2499, "harry potter": 2500, "hawaii": 2501, "fanta": 2502, "half full": 2503, "carriage": 2504, "curious": 2505, "12:50": 2506, "black white": 2507, "geese": 2508, "pork": 2509, "mailbox": 2510, "l": 2511, "sidecar": 2512, "poop": 2513, "wings": 2514, "penguin": 2515, "to see": 2516, "pocket": 2517, "steps": 2518, "cubs": 2519, "junk": 2520, "deer": 2521, "ottoman": 2522, "salt": 2523, "condiments": 2524, "1:55": 2525, "post": 2526, "bulldog": 2527, "notebook": 2528, "no cat": 2529, "champagne": 2530, "jets": 2531, "knee pads": 2532, "throw frisbee": 2533, "drinks": 2534, "leopard": 2535, "taller": 2536, "cooler": 2537, "bundt": 2538, "monday": 2539, "grape": 2540, "wine tasting": 2541, "under": 2542, "baskets": 2543, "santa hat": 2544, "chest": 2545, "sewing": 2546, "on car": 2547, "sony ericsson": 2548, "peeing": 2549, "for photo": 2550, "tour": 2551, "few": 2552, "singapore": 2553, "fireman": 2554, "fire extinguisher": 2555, "wildebeest": 2556, "lemons": 2557, "peanuts": 2558, "babies": 2559, "wiimote": 2560, "guitar hero": 2561, "slide": 2562, "stopped": 2563, "library": 2564, "multi colored": 2565, "blue and pink": 2566, "choppy": 2567, "sailing": 2568, "brush": 2569, "grinding": 2570, "jelly": 2571, "dairy queen": 2572, "shaking hands": 2573, "ge": 2574, "tigers": 2575, "tokyo": 2576, "philadelphia": 2577, "ski boots": 2578, "buses": 2579, "11:45": 2580, "collage": 2581, "pink and blue": 2582, "jesus": 2583, "singles": 2584, "iron": 2585, "coffee table": 2586, "2 
years": 2587, "don't walk": 2588, "classroom": 2589, "on water": 2590, "potato salad": 2591, "posts": 2592, "harbor": 2593, "residential": 2594, "joshua": 2595, "uk": 2596, "burgers": 2597, "deli": 2598, "kicking": 2599, "lace": 2600, "overalls": 2601, "vehicles": 2602, "ram": 2603, "dancing": 2604, "47": 2605, "shed": 2606, "lid": 2607, "he's not": 2608, "fans": 2609, "amtrak": 2610, "space shuttle": 2611, "ostrich": 2612, "bathtub": 2613, "kneeling": 2614, "2:50": 2615, "mall": 2616, "yellow and orange": 2617, "gazebo": 2618, "wax": 2619, "slow down": 2620, "lays": 2621, "hammer time": 2622, "octopus": 2623, "crib": 2624, "banana split": 2625, "broadway": 2626, "pottery": 2627, "wavy": 2628, "farmers": 2629, "holding phone": 2630, "on phone": 2631, "squirrel": 2632, "wax paper": 2633, "tusks": 2634, "dining": 2635, "packing": 2636, "kangaroo": 2637, "dawn": 2638, "defense": 2639, "powdered": 2640, "thomas": 2641, "budweiser": 2642, "back left": 2643, "stir fry": 2644, "beijing": 2645, "11:10": 2646, "tripod": 2647, "wide": 2648, "slope": 2649, "black and gray": 2650, "planter": 2651, "chili": 2652, "siblings": 2653, "kayaking": 2654, "captivity": 2655, "opaque": 2656, "rack": 2657, "panda": 2658, "doorway": 2659, "wheelie": 2660, "pelicans": 2661, "genetics": 2662, "not in service": 2663, "volvo": 2664, "dachshund": 2665, "v": 2666, "on laptop": 2667, "western": 2668, "gone": 2669, "birthday party": 2670, "parking garage": 2671, "tying tie": 2672, "blueberry": 2673, "scale": 2674, "notes": 2675, "train car": 2676, "man made": 2677, "stability": 2678, "lily": 2679, "lying down": 2680, "pacific": 2681, "high heels": 2682, "pare": 2683, "checkerboard": 2684, "partly cloudy": 2685, "cool": 2686, "n": 2687, "toilets": 2688, "tree branch": 2689, "copper": 2690, "cycling": 2691, "5:50": 2692, "870": 2693, "shopping": 2694, "7:05": 2695, "zipper": 2696, "holding umbrella": 2697, "batman": 2698, "lotion": 2699, "1:25": 2700, "black and brown": 2701, "playing video game": 2702, "girl on right": 2703, "legos": 2704, "drinking water": 2705, "burrito": 2706, "plow": 2707, "jet ski": 2708, "spiral": 2709, "ibm": 2710, "tools": 2711, "flashlight": 2712, "cherries": 2713, "maple leaf": 2714, "mountainous": 2715, "under tree": 2716, "vines": 2717, "sushi": 2718, "baker": 2719, "snake": 2720, "globe": 2721, "target": 2722, "john": 2723, "pomeranian": 2724, "tuxedo": 2725, "hockey": 2726, "sleeve": 2727, "leaning": 2728, "wireless": 2729, "11:05": 2730, "compaq": 2731, "do not enter": 2732, "radish": 2733, "1:05": 2734, "dim": 2735, "advertisement": 2736, "movement": 2737, "model": 2738, "hammock": 2739, "swing": 2740, "sheet": 2741, "google": 2742, "boardwalk": 2743, "right 1": 2744, "haircut": 2745, "ankle": 2746, "3:30": 2747, "exit": 2748, "csx": 2749, "tim hortons": 2750, "lego": 2751, "cucumbers": 2752, "angel": 2753, "12:20": 2754, "racquet": 2755, "behind woman": 2756, "potato": 2757, "egg salad": 2758, "controllers": 2759, "recliner": 2760, "upside down": 2761, "mosaic": 2762, "before": 2763, "antenna": 2764, "3:50": 2765, "10:15": 2766, "lion": 2767, "camo": 2768, "fighter": 2769, "silver and red": 2770, "dirt bike": 2771, "playing video games": 2772, "used": 2773, "crates": 2774, "horizontally": 2775, "plunger": 2776, "refrigerators": 2777, "radiator": 2778, "stork": 2779, "in basket": 2780, "cap": 2781, "living": 2782, "married": 2783, "briefcase": 2784, "bottom left": 2785, "30 mph": 2786, "ascending": 2787, "flip phone": 2788, "101": 2789, "11:50": 2790, "gun": 2791, "arizona": 2792, "foam": 
2793, "serious": 2794, "y": 2795, "close up": 2796, "pancakes": 2797, "heineken": 2798, "paw": 2799, "cnn": 2800, "comforter": 2801, "sheets": 2802, "8:35": 2803, "driveway": 2804, "fair": 2805, "cleaner": 2806, "1 year": 2807, "delivery": 2808, "commuter": 2809, "apple and banana": 2810, "chase": 2811, "72": 2812, "safe": 2813, "trucks": 2814, "trunks": 2815, "spider": 2816, "64": 2817, "slacks": 2818, "meeting": 2819, "7:00": 2820, "skiers": 2821, "shaved": 2822, "carrot cake": 2823, "holding": 2824, "surfers": 2825, "giraffe and zebra": 2826, "7:45": 2827, "mississippi": 2828, "seaweed": 2829, "black and pink": 2830, "horse racing": 2831, "orchid": 2832, "rv": 2833, "tourist": 2834, "above door": 2835, "leaving": 2836, "pitch": 2837, "crest": 2838, "miami": 2839, "asics": 2840, "flood": 2841, "bus station": 2842, "take off": 2843, "amazon": 2844, "practice": 2845, "entering": 2846, "diesel": 2847, "pm": 2848, "wetsuits": 2849, "remodeling": 2850, "porch": 2851, "7:35": 2852, "tie dye": 2853, "baked": 2854, "life jacket": 2855, "cylinder": 2856, "grilled cheese": 2857, "meatballs": 2858, "paddling": 2859, "banana bread": 2860, "monster": 2861, "smiley face": 2862, "not high": 2863, "keys": 2864, "dreadlocks": 2865, "kitchenaid": 2866, "straight ahead": 2867, "badminton": 2868, "long sleeve": 2869, "sheepdog": 2870, "5:18": 2871, "end": 2872, "on shore": 2873, "scratching": 2874, "oriental": 2875, "5:05": 2876, "alligator": 2877, "city bus": 2878, "purple and white": 2879, "10:50": 2880, "each other": 2881, "weeds": 2882, "tinkerbell": 2883, "rottweiler": 2884, "apartments": 2885, "snowflakes": 2886, "stop light": 2887, "sweatshirt": 2888, "shore": 2889, "bidet": 2890, "switzerland": 2891, "stretching": 2892, "tv stand": 2893, "boundaries": 2894, "65": 2895, "bronze": 2896, "jar": 2897, "middle 1": 2898, "54": 2899, "skate": 2900, "easton": 2901, "turn right": 2902, "raspberries": 2903, "singing": 2904, "on bus": 2905, "carnations": 2906, "descending": 2907, "classic": 2908, "suspenders": 2909, "not long": 2910, "8:50": 2911, "father": 2912, "anniversary": 2913, "hsbc": 2914, "very long": 2915, "space needle": 2916, "skatepark": 2917, "fruit salad": 2918, "kenmore": 2919, "no water": 2920, "8:05": 2921, "db": 2922, "baby's breath": 2923, "shelter": 2924, "1980": 2925, "no left turn": 2926, "washington monument": 2927, "ham and cheese": 2928, "10 inches": 2929, "8:55": 2930, "savory": 2931, "6:35": 2932, "indians": 2933, "9:05": 2934, "fires": 2935, "pipes": 2936, "donkey": 2937, "cds": 2938, "mitsubishi": 2939, "tell time": 2940, "outfield": 2941, "christian": 2942, "puma": 2943, "parking meters": 2944, "cranes": 2945, "flip": 2946, "wine bottle": 2947, "stadium": 2948, "mouthwash": 2949, "heinz": 2950, "distance": 2951, "macaroni": 2952, "on plane": 2953, "triumph": 2954, "more": 2955, "4:50": 2956, "single engine": 2957, "disney": 2958, "on stove": 2959, "shih tzu": 2960, "fried": 2961, "to hit ball": 2962, "in her hand": 2963, "sunrise": 2964, "2nd": 2965, "elmo": 2966, "kite string": 2967, "suzuki": 2968, "traffic lights": 2969, "blt": 2970, "i": 2971, "hitting": 2972, "htc": 2973, "healthy": 2974, "current": 2975, "star alliance": 2976, "stomach": 2977, "watch tv": 2978, "tulip": 2979, "5:10": 2980, "right side": 2981, "4:40": 2982, "ginger": 2983, "on sign": 2984, "cushion": 2985, "5:30": 2986, "learning": 2987, "pencil": 2988, "maroon": 2989, "food processor": 2990, "5:40": 2991, "dog bed": 2992, "michigan": 2993, "close": 2994, "license plate": 2995, "crows": 2996, "right hand": 
2997, "normal": 2998, "green and brown": 2999, "1.00": 3000, "000": 3001, "1:40": 3002, "wing": 3003, "american airlines": 3004, "kodak": 3005, "mural": 3006, "sniffing": 3007, "1:15": 3008, "behind bench": 3009, "cardinal": 3010, "no light": 3011, "warmth": 3012, "paved": 3013, "skyscrapers": 3014, "swinging bat": 3015, "watermark": 3016, "in cup": 3017, "pizza box": 3018, "dough": 3019, "hiding": 3020, "goal": 3021, "no plate": 3022, "shower head": 3023, "ripe": 3024, "1:10": 3025, "1 in back": 3026, "older": 3027, "nest": 3028, "multiple": 3029, "cinnamon": 3030, "bin": 3031, "new orleans": 3032, "colored": 3033, "enclosure": 3034, "bride": 3035, "on dresser": 3036, "star wars": 3037, "in back": 3038, "triangles": 3039, "over easy": 3040, "cilantro": 3041, "statues": 3042, "sticks": 3043, "formica": 3044, "roundabout": 3045, "bowls": 3046, "ahead": 3047, "years": 3048, "drain": 3049, "veggies": 3050, "no shirt": 3051, "taking photo": 3052, "tugboat": 3053, "broke": 3054, "59": 3055, "cadillac": 3056, "prince": 3057, "left side": 3058, "1 in middle": 3059, "10:45": 3060, "drying": 3061, "11:25": 3062, "silk": 3063, "conference room": 3064, "buoys": 3065, "pockets": 3066, "daffodil": 3067, "6:40": 3068, "walgreens": 3069, "4 ft": 3070, "6:05": 3071, "virgin atlantic": 3072, "12:40": 3073, "digital": 3074, "ups": 3075, "westjet": 3076, "bikers": 3077, "us air force": 3078, "limes": 3079, "comcast": 3080, "dip": 3081, "7:55": 3082, "man in middle": 3083, "bus driver": 3084, "soon": 3085, "futon": 3086, "selling": 3087, "braid": 3088, "mariners": 3089, "wisconsin": 3090, "99": 3091, "citizen": 3092, "broccoli and carrots": 3093, "grocery store": 3094, "us airways": 3095, "49": 3096, "bored": 3097, "red velvet": 3098, "hotel room": 3099, "qantas": 3100, "tam": 3101, "korean air": 3102, "10:35": 3103, "whirlpool": 3104, "coffee cup": 3105, "hilly": 3106, "9:12": 3107, "whipped cream": 3108, "video": 3109, "finger": 3110, "competition": 3111, "hollywood": 3112, "sas": 3113, "backward": 3114, "beads": 3115, "cosmo": 3116, "10:08": 3117, "jal": 3118, "6:30": 3119, "100 year party ct": 3120, "hispanic": 3121, "in cabbage town": 3122, "opponent": 3123, "woodpecker": 3124, "visilab": 3125, "mt airy": 3126, "crosstown": 3127, "freightliner": 3128} \ No newline at end of file diff --git a/utils/const.py b/utils/const.py new file mode 100644 index 0000000..f82e205 --- /dev/null +++ b/utils/const.py @@ -0,0 +1,9 @@ +""" +Copyright (c) Microsoft Corporation. +Licensed under the MIT license. + +constants +""" +IMG_DIM = 2048 +IMG_LABEL_DIM = 1601 +BUCKET_SIZE = 8192 diff --git a/utils/distributed.py b/utils/distributed.py new file mode 100644 index 0000000..fce69ea --- /dev/null +++ b/utils/distributed.py @@ -0,0 +1,209 @@ +""" +Copyright (c) Microsoft Corporation. +Licensed under the MIT license. + +distributed API using Horovod +Modified from OpenNMT's native pytorch distributed utils +(https://github.com/OpenNMT/OpenNMT-py) +""" +import math +import pickle + +import torch +from horovod import torch as hvd + + +def all_reduce_and_rescale_tensors(tensors, rescale_denom): + """All-reduce and rescale tensors at once (as a flattened tensor) + + Args: + tensors: list of Tensors to all-reduce + rescale_denom: denominator for rescaling summed Tensors + """ + # buffer size in bytes, determine equiv. 
# of elements based on data type + sz = sum(t.numel() for t in tensors) + buffer_t = tensors[0].new(sz).zero_() + + # copy tensors into buffer_t + offset = 0 + for t in tensors: + numel = t.numel() + buffer_t[offset:offset+numel].copy_(t.view(-1)) + offset += numel + + # all-reduce and rescale + hvd.allreduce_(buffer_t[:offset]) + buffer_t.div_(rescale_denom) + + # copy all-reduced buffer back into tensors + offset = 0 + for t in tensors: + numel = t.numel() + t.view(-1).copy_(buffer_t[offset:offset+numel]) + offset += numel + + +def all_reduce_and_rescale_tensors_chunked(tensors, rescale_denom, + buffer_size=10485760): + """All-reduce and rescale tensors in chunks of the specified size. + + Args: + tensors: list of Tensors to all-reduce + rescale_denom: denominator for rescaling summed Tensors + buffer_size: all-reduce chunk size in bytes + """ + # buffer size in bytes, determine equiv. # of elements based on data type + buffer_t = tensors[0].new( + math.ceil(buffer_size / tensors[0].element_size())).zero_() + buffer = [] + + def all_reduce_buffer(): + # copy tensors into buffer_t + offset = 0 + for t in buffer: + numel = t.numel() + buffer_t[offset:offset+numel].copy_(t.view(-1)) + offset += numel + + # all-reduce and rescale + hvd.allreduce_(buffer_t[:offset]) + buffer_t.div_(rescale_denom) + + # copy all-reduced buffer back into tensors + offset = 0 + for t in buffer: + numel = t.numel() + t.view(-1).copy_(buffer_t[offset:offset+numel]) + offset += numel + + filled = 0 + for t in tensors: + sz = t.numel() * t.element_size() + if sz > buffer_size: + # tensor is bigger than buffer, all-reduce and rescale directly + hvd.allreduce_(t) + t.div_(rescale_denom) + elif filled + sz > buffer_size: + # buffer is full, all-reduce and replace buffer with grad + all_reduce_buffer() + buffer = [t] + filled = sz + else: + # add tensor to buffer + buffer.append(t) + filled += sz + + if len(buffer) > 0: + all_reduce_buffer() + + +def broadcast_tensors(tensors, root_rank, buffer_size=10485760): + """broadcast tensors in chunks of the specified size. + + Args: + tensors: list of Tensors to broadcast + root_rank: rank to broadcast + buffer_size: broadcast chunk size in bytes + """ + # buffer size in bytes, determine equiv. 
# of elements based on data type + buffer_t = tensors[0].new( + math.ceil(buffer_size / tensors[0].element_size())).zero_() + buffer = [] + + def broadcast_buffer(): + # copy tensors into buffer_t + offset = 0 + for t in buffer: + numel = t.numel() + buffer_t[offset:offset+numel].copy_(t.view(-1)) + offset += numel + + # broadcast + hvd.broadcast_(buffer_t[:offset], root_rank) + + # copy all-reduced buffer back into tensors + offset = 0 + for t in buffer: + numel = t.numel() + t.view(-1).copy_(buffer_t[offset:offset+numel]) + offset += numel + + filled = 0 + for t in tensors: + sz = t.numel() * t.element_size() + if sz > buffer_size: + # tensor is bigger than buffer, broadcast directly + hvd.broadcast_(t, root_rank) + elif filled + sz > buffer_size: + # buffer is full, broadcast and replace buffer with tensor + broadcast_buffer() + buffer = [t] + filled = sz + else: + # add tensor to buffer + buffer.append(t) + filled += sz + + if len(buffer) > 0: + broadcast_buffer() + + +def _encode(enc, max_size, use_max_size=False): + enc_size = len(enc) + enc_byte = max(math.floor(math.log(max_size, 256)+1), 1) + if use_max_size: + # this is used for broadcasting + buffer_ = torch.cuda.ByteTensor(max_size+enc_byte) + else: + buffer_ = torch.cuda.ByteTensor(enc_size+enc_byte) + remainder = enc_size + for i in range(enc_byte): + base = 256 ** (enc_byte-i-1) + buffer_[i] = remainder // base + remainder %= base + buffer_[enc_byte:enc_byte+enc_size] = torch.ByteTensor(list(enc)) + return buffer_, enc_byte + + +def _decode(buffer_, enc_byte): + size = sum(256 ** (enc_byte-i-1) * buffer_[i].item() + for i in range(enc_byte)) + bytes_list = bytes(buffer_[enc_byte:enc_byte+size].tolist()) + shift = size + enc_byte + return bytes_list, shift + + +_BUFFER_SIZE = 4096 + + +def all_gather_list(data): + """Gathers arbitrary data from all nodes into a list.""" + enc = pickle.dumps(data) + + enc_size = len(enc) + max_size = hvd.allgather(torch.tensor([enc_size]).cuda()).max().item() + in_buffer, enc_byte = _encode(enc, max_size) + + out_buffer = hvd.allgather(in_buffer[:enc_byte+enc_size]) + + results = [] + for _ in range(hvd.size()): + bytes_list, shift = _decode(out_buffer, enc_byte) + out_buffer = out_buffer[shift:] + result = pickle.loads(bytes_list) + results.append(result) + return results + + +def any_broadcast(data, root_rank): + """broadcast arbitrary data from root_rank to all nodes.""" + enc = pickle.dumps(data) + + max_size = hvd.allgather(torch.tensor([len(enc)]).cuda()).max().item() + buffer_, enc_byte = _encode(enc, max_size, use_max_size=True) + + hvd.broadcast_(buffer_, root_rank) + + bytes_list, _ = _decode(buffer_, enc_byte) + result = pickle.loads(bytes_list) + return result diff --git a/utils/itm_eval.py b/utils/itm_eval.py new file mode 100644 index 0000000..0c76a8e --- /dev/null +++ b/utils/itm_eval.py @@ -0,0 +1,136 @@ +""" +Copyright (c) Microsoft Corporation. +Licensed under the MIT license. 
+ +Image Text Retrieval evaluation helper +""" +from time import time +import numpy as np +import torch +from horovod import torch as hvd +from tqdm import tqdm + +from .logger import LOGGER +from .misc import NoOp +from .distributed import all_gather_list + + +@torch.no_grad() +def itm_eval(score_matrix, txt_ids, img_ids, txt2img, img2txts, id2len=None): + # image retrieval + img2j = {i: j for j, i in enumerate(img_ids)} + _, rank_txt = score_matrix.topk(10, dim=1) + + if id2len is not None: + # print('| itm_eval | ids: ', len(ids)) + # print(ids[-10:]) + top5_rank_text_ids = rank_txt[:, :5] + dataset_text_lens = [] + for text_ids in top5_rank_text_ids.data.cpu().numpy(): + text_lens = [] + for text_id in text_ids: + # print('| txt_id: ', txt_id, type(txt_id)) + text_id = int(text_id) + text_lens.append(id2len[txt_ids[text_id]]) + dataset_text_lens.append(text_lens) + dataset_text_lens = np.array(dataset_text_lens) + print('| mean of top5 texts: ', dataset_text_lens.mean()) + + gt_img_j = torch.LongTensor([img2j[txt2img[txt_id]] + for txt_id in txt_ids], + ).to(rank_txt.device + ).unsqueeze(1).expand_as(rank_txt) + rank = (rank_txt == gt_img_j).nonzero() + if rank.numel(): + ir_r1 = (rank < 1).sum().item() / len(txt_ids) + ir_r5 = (rank < 5).sum().item() / len(txt_ids) + ir_r10 = (rank < 10).sum().item() / len(txt_ids) + else: + ir_r1, ir_r5, ir_r10 = 0, 0, 0 + + # text retrieval + txt2i = {t: i for i, t in enumerate(txt_ids)} + _, rank_img = score_matrix.topk(10, dim=0) + tr_r1, tr_r5, tr_r10 = 0, 0, 0 + for j, img_id in enumerate(img_ids): + gt_is = [txt2i[t] for t in img2txts[img_id]] + ranks = [(rank_img[:, j] == i).nonzero() for i in gt_is] + rank = min([10] + [r.item() for r in ranks if r.numel()]) + if rank < 1: + tr_r1 += 1 + if rank < 5: + tr_r5 += 1 + if rank < 10: + tr_r10 += 1 + tr_r1 /= len(img_ids) + tr_r5 /= len(img_ids) + tr_r10 /= len(img_ids) + + tr_mean = (tr_r1 + tr_r5 + tr_r10) / 3 + ir_mean = (ir_r1 + ir_r5 + ir_r10) / 3 + r_mean = (tr_mean + ir_mean) / 2 + + eval_log = {'txt_r1': tr_r1, + 'txt_r5': tr_r5, + 'txt_r10': tr_r10, + 'txt_r_mean': tr_mean, + 'img_r1': ir_r1, + 'img_r5': ir_r5, + 'img_r10': ir_r10, + 'img_r_mean': ir_mean, + 'r_mean': r_mean} + return eval_log + + +@torch.no_grad() +def evaluate(model, eval_loader): + st = time() + LOGGER.info("start running Image/Text Retrieval evaluation ...") + score_matrix = inference(model, eval_loader) + # print('| score_matrix: ', type(score_matrix)) + dset = eval_loader.dataset + all_score = hvd.allgather(score_matrix) + all_txt_ids = [i for ids in all_gather_list(dset.ids) + for i in ids] + all_img_ids = dset.all_img_ids + assert all_score.size() == (len(all_txt_ids), len(all_img_ids)) + if hvd.rank() != 0: + return {} + + # NOTE: only use rank0 to compute final scores + eval_log = itm_eval(all_score, all_txt_ids, all_img_ids, + dset.txt2img, dset.img2txts) + + tot_time = time()-st + LOGGER.info(f"evaluation finished in {int(tot_time)} seconds") + return eval_log + + +@torch.no_grad() +def inference(model, eval_loader): + model.eval() + if hvd.rank() == 0: + pbar = tqdm(total=len(eval_loader)) + else: + pbar = NoOp() + score_matrix = torch.zeros(len(eval_loader.dataset), + len(eval_loader.dataset.all_img_ids), + device=torch.device("cuda"), + dtype=torch.float16) + + for i, mini_batches in enumerate(eval_loader): + j = 0 + for batch in mini_batches: + model_outputs = model(batch, compute_loss=False) + if isinstance(model_outputs, dict): + scores = model_outputs['rank_scores'] + else: + scores = model_outputs + bs 
= scores.size(0) + score_matrix.data[i, j:j+bs] = scores.data.squeeze(1).half() + j += bs + assert j == score_matrix.size(1) + pbar.update(1) + model.train() + pbar.close() + return score_matrix diff --git a/utils/logger.py b/utils/logger.py new file mode 100644 index 0000000..c0ddb4f --- /dev/null +++ b/utils/logger.py @@ -0,0 +1,108 @@ +""" +Copyright (c) Microsoft Corporation. +Licensed under the MIT license. + +helper for logging +NOTE: loggers are global objects use with caution +""" +import logging +import math + +import tensorboardX + + +_LOG_FMT = '%(asctime)s - %(levelname)s - %(name)s - %(message)s' +_DATE_FMT = '%m/%d/%Y %H:%M:%S' +logging.basicConfig(format=_LOG_FMT, datefmt=_DATE_FMT, level=logging.INFO) +LOGGER = logging.getLogger('__main__') # this is the global logger + + +def add_log_to_file(log_path): + fh = logging.FileHandler(log_path) + formatter = logging.Formatter(_LOG_FMT, datefmt=_DATE_FMT) + fh.setFormatter(formatter) + LOGGER.addHandler(fh) + + +class TensorboardLogger(object): + def __init__(self): + self._logger = None + self._global_step = 0 + + def create(self, path): + self._logger = tensorboardX.SummaryWriter(path) + + def noop(self, *args, **kwargs): + return + + def step(self): + self._global_step += 1 + + @property + def global_step(self): + return self._global_step + + def log_scaler_dict(self, log_dict, prefix=''): + """ log a dictionary of scalar values""" + if self._logger is None: + return + if prefix: + prefix = f'{prefix}_' + for name, value in log_dict.items(): + if isinstance(value, dict): + self.log_scaler_dict(value, self._global_step, + prefix=f'{prefix}{name}') + else: + self._logger.add_scalar(f'{prefix}{name}', value, + self._global_step) + + def log_histogram_dict(self, log_dict, prefix=''): + """ log a dictionary of scalar values""" + if self._logger is None: + return + if prefix: + prefix = f'{prefix}_' + for name, value in log_dict.items(): + if isinstance(value, dict): + self.log_histogram_dict(value, self._global_step, + prefix=f'{prefix}{name}') + else: + self._logger.add_histogram(f'{prefix}{name}', value, + self._global_step) + + def __getattr__(self, name): + if self._logger is None: + return self.noop + return self._logger.__getattribute__(name) + + +TB_LOGGER = TensorboardLogger() + + +class RunningMeter(object): + """ running meteor of a scalar value + (useful for monitoring training loss) + """ + def __init__(self, name, val=None, smooth=0.99): + self._name = name + self._sm = smooth + self._val = val + + def __call__(self, value): + val = (value if self._val is None + else value*(1-self._sm) + self._val*self._sm) + if not math.isnan(val): + self._val = val + + def __str__(self): + return f'{self._name}: {self._val:.4f}' + + @property + def val(self): + if self._val is None: + return 0 + return self._val + + @property + def name(self): + return self._name diff --git a/utils/misc.py b/utils/misc.py new file mode 100644 index 0000000..961322a --- /dev/null +++ b/utils/misc.py @@ -0,0 +1,70 @@ +""" +Copyright (c) Microsoft Corporation. +Licensed under the MIT license. 
+ +Misc utilities +""" +import json +import random +import sys + +import torch +import numpy as np + +from utils.logger import LOGGER + + +class NoOp(object): + """ useful for distributed training No-Ops """ + def __getattr__(self, name): + return self.noop + + def noop(self, *args, **kwargs): + return + + +def parse_with_config(parser): + args = parser.parse_args() + if args.config is not None: + config_args = json.load(open(args.config)) + override_keys = {arg[2:].split('=')[0] for arg in sys.argv[1:] + if arg.startswith('--')} + for k, v in config_args.items(): + if k not in override_keys: + setattr(args, k, v) + del args.config + return args + + +VE_ENT2IDX = { + 'contradiction': 0, + 'entailment': 1, + 'neutral': 2 +} + +VE_IDX2ENT = { + 0: 'contradiction', + 1: 'entailment', + 2: 'neutral' +} + + +class Struct(object): + def __init__(self, dict_): + self.__dict__.update(dict_) + + +def set_dropout(model, drop_p): + for name, module in model.named_modules(): + # we might want to tune dropout for smaller dataset + if isinstance(module, torch.nn.Dropout): + if module.p != drop_p: + module.p = drop_p + LOGGER.info(f'{name} set to {drop_p}') + + +def set_random_seed(seed): + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) diff --git a/utils/save.py b/utils/save.py new file mode 100644 index 0000000..8d4383b --- /dev/null +++ b/utils/save.py @@ -0,0 +1,77 @@ +""" +Copyright (c) Microsoft Corporation. +Licensed under the MIT license. + +saving utilities +""" +import json +import os +from os.path import abspath, dirname, exists, join +import subprocess + +import torch + +from utils.logger import LOGGER + + +def save_training_meta(args): + if args.rank > 0: + return + + if not exists(args.output_dir): + os.makedirs(join(args.output_dir, 'log')) + os.makedirs(join(args.output_dir, 'ckpt')) + if not exists(os.path.join(args.output_dir, 'log')): + os.makedirs(join(args.output_dir, 'log')) + if not exists(os.path.join(args.output_dir, 'ckpt')): + os.makedirs(join(args.output_dir, 'ckpt')) + + with open(join(args.output_dir, 'log', 'hps.json'), 'w') as writer: + json.dump(vars(args), writer, indent=4) + model_config = json.load(open(args.model_config)) + with open(join(args.output_dir, 'log', 'model.json'), 'w') as writer: + json.dump(model_config, writer, indent=4) + # git info + try: + LOGGER.info("Waiting on git info....") + c = subprocess.run(["git", "rev-parse", "--abbrev-ref", "HEAD"], + timeout=10, stdout=subprocess.PIPE) + git_branch_name = c.stdout.decode().strip() + LOGGER.info("Git branch: %s", git_branch_name) + c = subprocess.run(["git", "rev-parse", "HEAD"], + timeout=10, stdout=subprocess.PIPE) + git_sha = c.stdout.decode().strip() + LOGGER.info("Git SHA: %s", git_sha) + git_dir = abspath(dirname(__file__)) + git_status = subprocess.check_output( + ['git', 'status', '--short'], + cwd=git_dir, universal_newlines=True).strip() + with open(join(args.output_dir, 'log', 'git_info.json'), + 'w') as writer: + json.dump({'branch': git_branch_name, + 'is_dirty': bool(git_status), + 'status': git_status, + 'sha': git_sha}, + writer, indent=4) + except subprocess.TimeoutExpired as e: + LOGGER.exception(e) + LOGGER.warn("Git info not found. 
Moving right along...") + + +class ModelSaver(object): + def __init__(self, output_dir, prefix='model_step', suffix='pt'): + self.output_dir = output_dir + self.prefix = prefix + self.suffix = suffix + + def save(self, model, step, optimizer=None): + output_model_file = join(self.output_dir, + f"{self.prefix}_{step}.{self.suffix}") + state_dict = {k: v.cpu() if isinstance(v, torch.Tensor) else v + for k, v in model.state_dict().items()} + torch.save(state_dict, output_model_file) + if optimizer is not None: + dump = {'step': step, 'optimizer': optimizer.state_dict()} + if hasattr(optimizer, '_amp_stash'): + pass # TODO fp16 optimizer + torch.save(dump, f'{self.output_dir}/train_state_{step}.pt')
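
The utilities added above are consumed by the repo's training and evaluation scripts, which are not part of this diff. The short sketches below illustrate the intended usage; none of them is repository code, and all file names, paths, and dummy values in them are illustrative only.

`utils/distributed.py` flattens groups of small tensors into one pre-allocated buffer (10 MB by default) so that a whole chunk is covered by a single `hvd.allreduce_` call instead of one call per tensor, then divides the result by `rescale_denom` (for example a gradient-accumulation count). A minimal sketch of driving the chunked helper, assuming Horovod and a CUDA device are available (run with something like `horovodrun -np 2 python sketch.py`):

```python
# Illustrative sketch only (not part of this diff): combine per-worker
# tensors with the chunked all-reduce helper from utils/distributed.py.
import torch
from horovod import torch as hvd

from utils.distributed import all_reduce_and_rescale_tensors_chunked

hvd.init()
torch.cuda.set_device(hvd.local_rank())

# stand-ins for parameter gradients on this worker
grads = [torch.full((4,), float(hvd.rank()), device='cuda'),
         torch.full((3, 3), float(hvd.rank()), device='cuda')]

# the tensors are copied into one flat buffer, all-reduced across workers,
# divided by rescale_denom (1 here, i.e. no extra rescaling) and copied back
all_reduce_and_rescale_tensors_chunked(grads, rescale_denom=1)
print(hvd.rank(), grads[0])  # identical on every worker afterwards
```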
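
`all_gather_list` and `any_broadcast` move arbitrary picklable Python objects between workers by length-prefixing the pickled bytes into a CUDA `ByteTensor` (`_encode`/`_decode`) and running an ordinary Horovod collective on it; `utils/itm_eval.py` uses `all_gather_list` to collect the text ids held by each rank. A small sketch, again assuming Horovod has been initialised and a GPU is visible:

```python
# Illustrative sketch: exchanging plain Python objects across workers.
import torch
from horovod import torch as hvd

from utils.distributed import all_gather_list, any_broadcast

hvd.init()
torch.cuda.set_device(hvd.local_rank())

local_ids = list(range(hvd.rank() * 3, hvd.rank() * 3 + 3))

# every rank receives [ids_of_rank0, ids_of_rank1, ...]
gathered = all_gather_list(local_ids)
all_ids = [i for ids in gathered for i in ids]

# ship a Python dict from rank 0 to every other rank
cfg = any_broadcast({'topk': 10} if hvd.rank() == 0 else None, root_rank=0)
```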
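
`itm_eval` turns the caption-by-image score matrix into recall@K numbers: for image retrieval it ranks all images for each caption and checks where the ground-truth image lands; text retrieval does the same along the other axis. A toy, self-contained illustration of the image-retrieval half (3 captions, 3 images, caption *i* belongs to image *i*); this only demonstrates the idea and is not the repository code:

```python
# Toy illustration (not repo code): recall@K from a 3x3 caption-image
# score matrix where caption i matches image i.
import torch

score_matrix = torch.tensor([[0.9, 0.1, 0.0],
                             [0.2, 0.3, 0.8],
                             [0.1, 0.6, 0.7]])
_, rank_txt = score_matrix.topk(3, dim=1)           # images ordered per caption
gt = torch.arange(3).unsqueeze(1).expand_as(rank_txt)
rank = (rank_txt == gt).nonzero()[:, 1]             # position of the true image
ir_r1 = (rank < 1).float().mean().item()            # 2/3: captions 0 and 2 hit at rank 0
ir_r5 = (rank < 5).float().mean().item()            # 1.0: every true image is in the top 3
```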
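
`utils/logger.py` exposes a module-level `LOGGER`, a global `TB_LOGGER` wrapper around `tensorboardX.SummaryWriter`, and `RunningMeter`, an exponentially smoothed scalar tracker for monitoring training loss. A short usage sketch with made-up paths and loss values:

```python
# Illustrative sketch: wiring the logging helpers into a training loop.
from utils.logger import LOGGER, TB_LOGGER, RunningMeter, add_log_to_file

add_log_to_file('/tmp/train.log')      # mirror console logging to a file
TB_LOGGER.create('/tmp/tensorboard')   # usually done on rank 0 only

loss_meter = RunningMeter('loss')      # smoothed value: 0.99 * old + 0.01 * new
for loss in [2.0, 1.5, 1.2]:
    loss_meter(loss)
    TB_LOGGER.log_scaler_dict({'loss': loss_meter.val})
    TB_LOGGER.step()                   # advances the global step used as x-axis
LOGGER.info('final %s', loss_meter)
```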
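
`parse_with_config` in `utils/misc.py` fills the argparse namespace from a JSON config file while letting flags passed explicitly on the command line take priority. A sketch with hypothetical flag names (`--learning_rate`, `--seed` are not defined anywhere in this diff):

```python
# Illustrative sketch: JSON config plus CLI overrides via parse_with_config.
import argparse

from utils.misc import parse_with_config, set_random_seed

parser = argparse.ArgumentParser()
parser.add_argument('--config', default=None, help='path to a JSON config file')
parser.add_argument('--learning_rate', type=float, default=3e-5)
parser.add_argument('--seed', type=int, default=42)
args = parse_with_config(parser)

# e.g. `python train.py --config cfg.json --learning_rate 1e-5` keeps the CLI
# learning rate, while other fields present in cfg.json fill their defaults
set_random_seed(args.seed)
```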
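
`ModelSaver` writes the model's `state_dict` (moved to CPU first) to `model_step_{step}.pt` under the given directory and, when an optimizer is passed, the training state to `train_state_{step}.pt`. A minimal, self-contained sketch using a dummy model and a temporary directory:

```python
# Illustrative sketch: checkpointing with ModelSaver from utils/save.py.
import tempfile
from os.path import join

import torch

from utils.save import ModelSaver

ckpt_dir = tempfile.mkdtemp()
model = torch.nn.Linear(4, 2)

saver = ModelSaver(ckpt_dir)            # file pattern: model_step_{step}.pt
saver.save(model, step=1000)

# tensors were moved to CPU before saving, so this loads on a CPU-only machine
state_dict = torch.load(join(ckpt_dir, 'model_step_1000.pt'))
model.load_state_dict(state_dict)
```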