diff --git a/Dockerfile b/Dockerfile
new file mode 100644
index 0000000..36a2815
--- /dev/null
+++ b/Dockerfile
@@ -0,0 +1,38 @@
+FROM nvcr.io/nvidia/pytorch:19.05-py3
+
+# basic python packages
+RUN pip install pytorch-pretrained-bert==0.6.2 \
+ tensorboardX==1.7 ipdb==0.12 lz4==2.1.9 lmdb==0.97
+
+####### horovod for multi-GPU (distributed) training #######
+
+# update OpenMPI to avoid horovod bug
+RUN rm -r /usr/local/mpi &&\
+ wget https://download.open-mpi.org/release/open-mpi/v3.1/openmpi-3.1.4.tar.gz &&\
+ gunzip -c openmpi-3.1.4.tar.gz | tar xf - &&\
+ cd openmpi-3.1.4 &&\
+ ./configure --prefix=/usr/local/mpi --enable-orterun-prefix-by-default \
+ --with-verbs --disable-getpwuid &&\
+ make -j$(nproc) all && make install &&\
+ ldconfig &&\
+ cd - && rm -r openmpi-3.1.4 && rm openmpi-3.1.4.tar.gz
+
+ENV OPENMPI_VERSION=3.1.4
+
+# horovod
+RUN HOROVOD_GPU_ALLREDUCE=NCCL HOROVOD_NCCL_LINK=SHARED HOROVOD_WITH_PYTORCH=1 \
+ pip install --no-cache-dir horovod==0.16.4 &&\
+ ldconfig
+
+# ssh
+RUN apt-get update &&\
+ apt-get install -y --no-install-recommends openssh-client openssh-server &&\
+ mkdir -p /var/run/sshd
+
+# Allow OpenSSH to talk to containers without asking for confirmation
+RUN cat /etc/ssh/ssh_config | grep -v StrictHostKeyChecking > /etc/ssh/ssh_config.new && \
+ echo " StrictHostKeyChecking no" >> /etc/ssh/ssh_config.new && \
+ mv /etc/ssh/ssh_config.new /etc/ssh/ssh_config
+
+
+WORKDIR /src
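+
+# Optional: to build this image locally instead of pulling a prebuilt one, a command
+# along these lines should work (the tag name is only an example):
+#   docker build -t nsgdc .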
diff --git a/LICENSE b/LICENSE
new file mode 100644
index 0000000..9b7b0f7
--- /dev/null
+++ b/LICENSE
@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) 2019 Microsoft Corporation
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
diff --git a/README.md b/README.md
new file mode 100644
index 0000000..8e76355
--- /dev/null
+++ b/README.md
@@ -0,0 +1,146 @@
+# NSGDC
+
+Some code in this repo is copied/modified from open-source implementations made available by
+[UNITER](https://github.com/ChenRocks/UNITER),
+[PyTorch](https://github.com/pytorch/pytorch),
+[HuggingFace](https://github.com/huggingface/transformers),
+[OpenNMT](https://github.com/OpenNMT/OpenNMT-py),
+and [Nvidia](https://github.com/NVIDIA/DeepLearningExamples/tree/master/PyTorch).
+The image features are extracted using [BUTD](https://github.com/peteanderson80/bottom-up-attention).
+
+
+## Requirements
+This setup follows UNITER. We provide a Docker image for easier reproduction. Please install the following:
+ - [nvidia driver](https://docs.nvidia.com/cuda/cuda-installation-guide-linux/index.html#package-manager-installation) (418+),
+ - [Docker](https://docs.docker.com/install/linux/docker-ce/ubuntu/) (19.03+),
+ - [nvidia-container-toolkit](https://github.com/NVIDIA/nvidia-docker#quickstart).
+
+Our scripts require the user to have the [docker group membership](https://docs.docker.com/install/linux/linux-postinstall/)
+so that docker commands can be run without sudo.
+We only support Linux with NVIDIA GPUs. We have tested on Ubuntu 18.04 with V100 cards.
+We use mixed-precision training, hence GPUs with Tensor Cores are recommended.
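+
+As an optional sanity check, GPUs should be visible from inside a container once the above is set up; the image tag below is the one our Dockerfile builds on:
+```bash
+docker run --gpus all --rm nvcr.io/nvidia/pytorch:19.05-py3 nvidia-smi
+```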
+
+## Image-Text Retrieval
+### Download Data
+```
+bash scripts/download_itm.sh $PATH_TO_STORAGE
+```
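+
+After the script finishes, `$PATH_TO_STORAGE` is expected to contain sub-folders along the following lines (names taken from the launch command below; exact contents depend on what was downloaded):
+```
+$PATH_TO_STORAGE/
+├── txt_db/      # preprocessed text LMDBs
+├── img_db/      # extracted image feature LMDBs
+├── finetune/    # fine-tuning outputs
+└── pretrained/  # pretrained UNITER checkpoints
+```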
+
+### Launch the Docker Container
+```bash
+# docker image should be automatically pulled
+source launch_container.sh $PATH_TO_STORAGE/txt_db $PATH_TO_STORAGE/img_db \
+$PATH_TO_STORAGE/finetune $PATH_TO_STORAGE/pretrained
+```
+
+The downloaded databases are already preprocessed, so the steps above are sufficient unless you would like to reproduce the whole preprocessing pipeline.
+
+The launch script respects the `$CUDA_VISIBLE_DEVICES` environment variable.
+Note that the source code is mounted into the container under `/src` instead
+of built into the image, so that user modifications are reflected without
+re-building the image. (Data folders are mounted into the container separately
+for flexibility on folder structures.)
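+
+For example, to restrict training to the first two GPUs (a usage sketch of the behavior described above):
+```bash
+export CUDA_VISIBLE_DEVICES=0,1
+source launch_container.sh $PATH_TO_STORAGE/txt_db $PATH_TO_STORAGE/img_db \
+$PATH_TO_STORAGE/finetune $PATH_TO_STORAGE/pretrained
+```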
+
+
+### Image-Text Retrieval (Flickr30k)
+```
+# Train with the base setting
+bash run_cmds/tran_pnsgd_base_flickr.sh
+bash run_cmds/tran_pnsgd2_base_flickr.sh
+
+# Train with the large setting
+bash run_cmds/tran_pnsgd_large_flickr.sh
+bash run_cmds/tran_pnsgd2_large_flickr.sh
+```
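+
+Each of these scripts is expected to use the matching JSON file under `config/` (e.g. `config/train-itm-pnsgd-base-flickr.json` for the base Flickr30k setting, and the corresponding COCO configs for the section below); hyper-parameters such as `train_batch_size` and `num_train_steps` can be adjusted there.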
+
+### Image-Text Retrieval (COCO)
+```
+# Train with the base setting
+bash run_cmds/tran_pnsgd_base_coco.sh
+bash run_cmds/tran_pnsgd2_base_coco.sh
+
+# Train with the large setting
+bash run_cmds/tran_pnsgd_large_coco.sh
+bash run_cmds/tran_pnsgd2_large_coco.sh
+```
+
+### Run Inference
+```
+bash run_cmds/inf_nsgd.sh
+```
+
+## Results
+
+Our models achieve the following performance.
+
+### MS-COCO
+
+
+| Model | Image-to-Text R@1 | Image-to-Text R@5 | Image-to-Text R@10 | Text-to-Image R@1 | Text-to-Image R@5 | Text-to-Image R@10 |
+| --- | --- | --- | --- | --- | --- | --- |
+| NSGDC-Base | 66.6 | 88.6 | 94.0 | 51.6 | 79.1 | 87.5 |
+| NSGDC-Large | 67.8 | 89.6 | 94.2 | 53.3 | 80.0 | 88.0 |
+
+### Flickr30K
+
+
+| Model | Image-to-Text R@1 | Image-to-Text R@5 | Image-to-Text R@10 | Text-to-Image R@1 | Text-to-Image R@5 | Text-to-Image R@10 |
+| --- | --- | --- | --- | --- | --- | --- |
+| NSGDC-Base | 87.9 | 98.1 | 99.3 | 74.5 | 93.3 | 96.3 |
+| NSGDC-Large | 90.6 | 98.8 | 99.1 | 77.3 | 94.3 | 97.3 |
+
diff --git a/config/train-itm-pnsgd-base-coco.json b/config/train-itm-pnsgd-base-coco.json
new file mode 100644
index 0000000..9adfc0c
--- /dev/null
+++ b/config/train-itm-pnsgd-base-coco.json
@@ -0,0 +1,43 @@
+{
+ "compressed_db": false,
+ "checkpoint": "log/pretrained/uniter-base.pt",
+ "max_txt_len": 60,
+ "conf_th": 0.2,
+ "max_bb": 100,
+ "min_bb": 10,
+ "num_bb": 36,
+ "train_batch_size": 32,
+ "negative_size": 399,
+ "hard_neg_size": 31,
+ "inf_minibatch_size": 400,
+ "margin": 0.2,
+ "valid_steps": 500,
+ "num_train_steps": 5000,
+ "optim": "adamw",
+ "betas": [
+ 0.9,
+ 0.98
+ ],
+ "dropout": 0.1,
+ "weight_decay": 0.01,
+ "grad_norm": 2.0,
+ "warmup_steps": 500,
+ "seed": 42,
+ "full_val": true,
+ "fp16": true,
+ "n_workers": 4,
+ "pin_mem": true,
+ "train_txt_dbs": [
+ "itm-data/txt_db3/itm_coco_train.db",
+ "itm-data/txt_db3/itm_coco_restval.db"
+ ],
+ "train_img_dbs": [
+ "itm-data/img_db/coco_train2014/",
+ "itm-data/img_db/coco_val2014/"
+ ],
+ "val_txt_db": "itm-data/txt_db3/itm_coco_val.db",
+ "val_img_db": "itm-data/img_db/coco_val2014",
+ "test_txt_db": "itm-data/txt_db3/itm_coco_test.db",
+ "test_img_db": "itm-data/img_db/coco_val2014",
+ "model_config": "config/uniter-base.json"
+}
diff --git a/config/train-itm-pnsgd-base-flickr.json b/config/train-itm-pnsgd-base-flickr.json
new file mode 100644
index 0000000..7c9b705
--- /dev/null
+++ b/config/train-itm-pnsgd-base-flickr.json
@@ -0,0 +1,41 @@
+{
+ "compressed_db": false,
+ "checkpoint": "log/pretrained/uniter-base.pt",
+ "max_txt_len": 60,
+ "conf_th": 0.2,
+ "max_bb": 100,
+ "min_bb": 10,
+ "num_bb": 36,
+ "train_batch_size": 32,
+ "negative_size": 399,
+ "hard_neg_size": 31,
+ "inf_minibatch_size": 400,
+ "margin": 0.2,
+ "valid_steps": 500,
+ "num_train_steps": 5000,
+ "optim": "adamw",
+ "betas": [
+ 0.9,
+ 0.98
+ ],
+ "dropout": 0.1,
+ "weight_decay": 0.01,
+ "grad_norm": 2.0,
+ "warmup_steps": 500,
+ "seed": 42,
+ "full_val": true,
+ "fp16": true,
+ "n_workers": 4,
+ "pin_mem": true,
+ "train_txt_dbs": [
+ "itm-data/txt_db3/itm_flickr30k_train.db"
+ ],
+ "train_img_dbs": [
+ "itm-data/img_db/flickr30k/"
+ ],
+ "val_txt_db": "itm-data/txt_db3/itm_flickr30k_val.db",
+ "val_img_db": "itm-data/img_db/flickr30k/",
+ "test_txt_db": "itm-data/txt_db3/itm_flickr30k_test.db",
+ "test_img_db": "itm-data/img_db/flickr30k/",
+ "model_config": "config/uniter-base.json"
+}
diff --git a/config/train-itm-pnsgd-large-coco.json b/config/train-itm-pnsgd-large-coco.json
new file mode 100644
index 0000000..cab857f
--- /dev/null
+++ b/config/train-itm-pnsgd-large-coco.json
@@ -0,0 +1,43 @@
+{
+ "compressed_db": false,
+ "checkpoint": "log/pretrained/uniter-large.pt",
+ "max_txt_len": 60,
+ "conf_th": 0.2,
+ "max_bb": 100,
+ "min_bb": 10,
+ "num_bb": 36,
+ "train_batch_size": 32,
+ "negative_size": 399,
+ "hard_neg_size": 31,
+ "inf_minibatch_size": 400,
+ "margin": 0.2,
+ "valid_steps": 500,
+ "num_train_steps": 5000,
+ "optim": "adamw",
+ "betas": [
+ 0.9,
+ 0.98
+ ],
+ "dropout": 0.1,
+ "weight_decay": 0.01,
+ "grad_norm": 2.0,
+ "warmup_steps": 500,
+ "seed": 42,
+ "full_val": true,
+ "fp16": true,
+ "n_workers": 4,
+ "pin_mem": true,
+ "train_txt_dbs": [
+ "itm-data/txt_db3/itm_coco_train.db",
+ "itm-data/txt_db3/itm_coco_restval.db"
+ ],
+ "train_img_dbs": [
+ "itm-data/img_db/coco_train2014/",
+ "itm-data/img_db/coco_val2014/"
+ ],
+ "val_txt_db": "itm-data/txt_db3/itm_coco_val.db",
+ "val_img_db": "itm-data/img_db/coco_val2014",
+ "test_txt_db": "itm-data/txt_db3/itm_coco_test.db",
+ "test_img_db": "itm-data/img_db/coco_val2014",
+ "model_config": "config/uniter-large.json"
+}
diff --git a/config/train-itm-pnsgd-large-flickr.json b/config/train-itm-pnsgd-large-flickr.json
new file mode 100644
index 0000000..5862f15
--- /dev/null
+++ b/config/train-itm-pnsgd-large-flickr.json
@@ -0,0 +1,41 @@
+{
+ "compressed_db": false,
+ "checkpoint": "log/pretrained/uniter-large.pt",
+ "max_txt_len": 60,
+ "conf_th": 0.2,
+ "max_bb": 100,
+ "min_bb": 10,
+ "num_bb": 36,
+ "train_batch_size": 16,
+ "negative_size": 399,
+ "hard_neg_size": 31,
+ "inf_minibatch_size": 400,
+ "margin": 0.2,
+ "valid_steps": 500,
+ "num_train_steps": 5000,
+ "optim": "adamw",
+ "betas": [
+ 0.9,
+ 0.98
+ ],
+ "dropout": 0.1,
+ "weight_decay": 0.01,
+ "grad_norm": 2.0,
+ "warmup_steps": 500,
+ "seed": 42,
+ "full_val": true,
+ "fp16": true,
+ "n_workers": 4,
+ "pin_mem": true,
+ "train_txt_dbs": [
+ "itm-data/txt_db3/itm_flickr30k_train.db"
+ ],
+ "train_img_dbs": [
+ "itm-data/img_db/flickr30k/"
+ ],
+ "val_txt_db": "itm-data/txt_db3/itm_flickr30k_val.db",
+ "val_img_db": "itm-data/img_db/flickr30k/",
+ "test_txt_db": "itm-data/txt_db3/itm_flickr30k_test.db",
+ "test_img_db": "itm-data/img_db/flickr30k/",
+ "model_config": "config/uniter-large.json"
+}
diff --git a/config/train-itm-pnsgd2-base-coco.json b/config/train-itm-pnsgd2-base-coco.json
new file mode 100644
index 0000000..362a870
--- /dev/null
+++ b/config/train-itm-pnsgd2-base-coco.json
@@ -0,0 +1,41 @@
+{
+ "compressed_db": false,
+ "max_txt_len": 60,
+ "conf_th": 0.2,
+ "max_bb": 100,
+ "min_bb": 10,
+ "num_bb": 36,
+ "train_batch_size": 32,
+ "negative_size": 399,
+ "hard_neg_size": 31,
+ "inf_minibatch_size": 400,
+ "margin": 0.2,
+ "valid_steps": 150,
+ "num_train_steps": 1500,
+ "optim": "adamw",
+ "betas": [
+ 0.9,
+ 0.98
+ ],
+ "dropout": 0.1,
+ "weight_decay": 0.01,
+ "grad_norm": 2.0,
+ "warmup_steps": 150,
+ "seed": 42,
+ "full_val": true,
+ "n_workers": 4,
+ "pin_mem": true,
+ "train_txt_dbs": [
+ "itm-data/txt_db3/itm_coco_train.db",
+ "itm-data/txt_db3/itm_coco_restval.db"
+ ],
+ "train_img_dbs": [
+ "itm-data/img_db/coco_train2014/",
+ "itm-data/img_db/coco_val2014/"
+ ],
+ "val_txt_db": "itm-data/txt_db3/itm_coco_val.db",
+ "val_img_db": "itm-data/img_db/coco_val2014",
+ "test_txt_db": "itm-data/txt_db3/itm_coco_test.db",
+ "test_img_db": "itm-data/img_db/coco_val2014",
+ "model_config": "config/uniter-base.json"
+}
diff --git a/config/train-itm-pnsgd2-base-flickr.json b/config/train-itm-pnsgd2-base-flickr.json
new file mode 100644
index 0000000..9b791c6
--- /dev/null
+++ b/config/train-itm-pnsgd2-base-flickr.json
@@ -0,0 +1,39 @@
+{
+ "compressed_db": false,
+ "max_txt_len": 60,
+ "conf_th": 0.2,
+ "max_bb": 100,
+ "min_bb": 10,
+ "num_bb": 36,
+ "train_batch_size": 32,
+ "negative_size": 399,
+ "hard_neg_size": 31,
+ "inf_minibatch_size": 400,
+ "margin": 0.2,
+ "valid_steps": 500,
+ "num_train_steps": 5000,
+ "optim": "adamw",
+ "betas": [
+ 0.9,
+ 0.98
+ ],
+ "dropout": 0.1,
+ "weight_decay": 0.01,
+ "grad_norm": 2.0,
+ "warmup_steps": 500,
+ "seed": 42,
+ "full_val": true,
+ "n_workers": 4,
+ "pin_mem": true,
+ "train_txt_dbs": [
+ "itm-data/txt_db3/itm_flickr30k_train.db"
+ ],
+ "train_img_dbs": [
+ "itm-data/img_db/flickr30k/"
+ ],
+ "val_txt_db": "itm-data/txt_db3/itm_flickr30k_val.db",
+ "val_img_db": "itm-data/img_db/flickr30k/",
+ "test_txt_db": "itm-data/txt_db3/itm_flickr30k_test.db",
+ "test_img_db": "itm-data/img_db/flickr30k/",
+ "model_config": "config/uniter-base.json"
+}
diff --git a/config/train-itm-pnsgd2-large-coco.json b/config/train-itm-pnsgd2-large-coco.json
new file mode 100644
index 0000000..d24848e
--- /dev/null
+++ b/config/train-itm-pnsgd2-large-coco.json
@@ -0,0 +1,42 @@
+{
+ "compressed_db": false,
+ "max_txt_len": 60,
+ "conf_th": 0.2,
+ "max_bb": 100,
+ "min_bb": 10,
+ "num_bb": 36,
+ "train_batch_size": 32,
+ "negative_size": 399,
+ "hard_neg_size": 31,
+ "inf_minibatch_size": 400,
+ "margin": 0.2,
+ "valid_steps": 500,
+ "num_train_steps": 5000,
+ "optim": "adamw",
+ "betas": [
+ 0.9,
+ 0.98
+ ],
+ "dropout": 0.1,
+ "weight_decay": 0.01,
+ "grad_norm": 2.0,
+ "warmup_steps": 500,
+ "seed": 42,
+ "full_val": true,
+ "fp16": true,
+ "n_workers": 4,
+ "pin_mem": true,
+ "train_txt_dbs": [
+ "itm-data/txt_db3/itm_coco_train.db",
+ "itm-data/txt_db3/itm_coco_restval.db"
+ ],
+ "train_img_dbs": [
+ "itm-data/img_db/coco_train2014/",
+ "itm-data/img_db/coco_val2014/"
+ ],
+ "val_txt_db": "itm-data/txt_db3/itm_coco_val.db",
+ "val_img_db": "itm-data/img_db/coco_val2014",
+ "test_txt_db": "itm-data/txt_db3/itm_coco_test.db",
+ "test_img_db": "itm-data/img_db/coco_val2014",
+ "model_config": "config/uniter-large.json"
+}
diff --git a/config/train-itm-pnsgd2-large-flickr.json b/config/train-itm-pnsgd2-large-flickr.json
new file mode 100644
index 0000000..94b9dfe
--- /dev/null
+++ b/config/train-itm-pnsgd2-large-flickr.json
@@ -0,0 +1,39 @@
+{
+ "compressed_db": false,
+ "max_txt_len": 60,
+ "conf_th": 0.2,
+ "max_bb": 100,
+ "min_bb": 10,
+ "num_bb": 36,
+ "train_batch_size": 32,
+ "negative_size": 399,
+ "hard_neg_size": 31,
+ "inf_minibatch_size": 400,
+ "margin": 0.2,
+ "valid_steps": 150,
+ "num_train_steps": 1500,
+ "optim": "adamw",
+ "betas": [
+ 0.9,
+ 0.98
+ ],
+ "dropout": 0.1,
+ "weight_decay": 0.01,
+ "grad_norm": 2.0,
+ "warmup_steps": 150,
+ "seed": 42,
+ "full_val": true,
+ "n_workers": 4,
+ "pin_mem": true,
+ "train_txt_dbs": [
+ "itm-data/txt_db3/itm_flickr30k_train.db"
+ ],
+ "train_img_dbs": [
+ "itm-data/img_db/flickr30k/"
+ ],
+ "val_txt_db": "itm-data/txt_db3/itm_flickr30k_val.db",
+ "val_img_db": "itm-data/img_db/flickr30k/",
+ "test_txt_db": "itm-data/txt_db3/itm_flickr30k_test.db",
+ "test_img_db": "itm-data/img_db/flickr30k/",
+ "model_config": "config/uniter-large.json"
+}
diff --git a/config/uniter-base.json b/config/uniter-base.json
new file mode 100644
index 0000000..691dacf
--- /dev/null
+++ b/config/uniter-base.json
@@ -0,0 +1,13 @@
+{
+ "attention_probs_dropout_prob": 0.1,
+ "hidden_act": "gelu",
+ "hidden_dropout_prob": 0.1,
+ "hidden_size": 768,
+ "initializer_range": 0.02,
+ "intermediate_size": 3072,
+ "max_position_embeddings": 512,
+ "num_attention_heads": 12,
+ "num_hidden_layers": 12,
+ "type_vocab_size": 2,
+ "vocab_size": 28996
+}
diff --git a/config/uniter-large.json b/config/uniter-large.json
new file mode 100644
index 0000000..961e7ca
--- /dev/null
+++ b/config/uniter-large.json
@@ -0,0 +1,13 @@
+{
+ "attention_probs_dropout_prob": 0.1,
+ "hidden_act": "gelu",
+ "hidden_dropout_prob": 0.1,
+ "hidden_size": 1024,
+ "initializer_range": 0.02,
+ "intermediate_size": 4096,
+ "max_position_embeddings": 512,
+ "num_attention_heads": 16,
+ "num_hidden_layers": 24,
+ "type_vocab_size": 2,
+ "vocab_size": 28996
+}
diff --git a/data/__init__.py b/data/__init__.py
new file mode 100644
index 0000000..bf64845
--- /dev/null
+++ b/data/__init__.py
@@ -0,0 +1,17 @@
+"""
+Copyright (c) Microsoft Corporation.
+Licensed under the MIT license.
+
+"""
+from .data import (TxtTokLmdb, DetectFeatLmdb,
+ ImageLmdbGroup, ConcatDatasetWithLens)
+from .sampler import TokenBucketSampler
+from .loader import PrefetchLoader, MetaLoader
+from .itm import (TokenBucketSamplerForItm, ItmDataset,
+ itm_collate, itm_ot_collate,
+ ItmRankDataset, ItmValDataset, ItmEvalDataset, ItmAdvEvalDataset, ItmDCEvalDataset, ItmStaticDataAttackEvalDataset,
+ ItmRankDatasetHardNegFromImage,
+ ItmRankDatasetHardNegFromText,
+ itm_rank_collate, itm_val_collate, itm_eval_collate,
+ itm_rank_hn_collate)
+from .pnsgd import (PNSGDFromImage, PNSGDFromText, pnsgd_collate)
diff --git a/data/data.py b/data/data.py
new file mode 100644
index 0000000..fb02737
--- /dev/null
+++ b/data/data.py
@@ -0,0 +1,326 @@
+"""
+Copyright (c) Microsoft Corporation.
+Licensed under the MIT license.
+
+Dataset interfaces
+"""
+from collections import defaultdict
+from contextlib import contextmanager
+import io
+import json
+from os.path import exists
+
+import numpy as np
+import torch
+from torch.utils.data import Dataset, ConcatDataset
+import horovod.torch as hvd
+from tqdm import tqdm
+import lmdb
+from lz4.frame import compress, decompress
+
+import msgpack
+import msgpack_numpy
+msgpack_numpy.patch()
+
+
+def _fp16_to_fp32(feat_dict):
+ out = {k: arr.astype(np.float32)
+ if arr.dtype == np.float16 else arr
+ for k, arr in feat_dict.items()}
+ return out
+
+
+def compute_num_bb(confs, conf_th, min_bb, max_bb):
+ num_bb = max(min_bb, (confs > conf_th).sum())
+ num_bb = min(max_bb, num_bb)
+ return num_bb
+
+
+def _check_distributed():
+ try:
+ dist = hvd.size() != hvd.local_size()
+ except ValueError:
+ # not using horovod
+ dist = False
+ return dist
+
+
+class DetectFeatLmdb(object):
+ def __init__(self, img_dir, conf_th=0.2, max_bb=100, min_bb=10, num_bb=36,
+ compress=True):
+ self.img_dir = img_dir
+ # keep the thresholds on the instance; _compute_nbb() relies on them
+ # when the per-image box counts are not pre-computed
+ self.conf_th = conf_th
+ self.max_bb = max_bb
+ self.min_bb = min_bb
+ print('| image_dir: {}'.format(img_dir))
+ print('| conf_th: ', conf_th)
+ import os
+ print('| files: ', os.listdir(img_dir))
+ if conf_th == -1:
+ db_name = f'feat_numbb{num_bb}'
+ self.name2nbb = defaultdict(lambda: num_bb)
+ else:
+ db_name = f'feat_th{conf_th}_max{max_bb}_min{min_bb}'
+ nbb = f'nbb_th{conf_th}_max{max_bb}_min{min_bb}.json'
+ if not exists(f'{img_dir}/{nbb}'):
+ # nbb is not pre-computed
+ self.name2nbb = None
+ else:
+ self.name2nbb = json.load(open(f'{img_dir}/{nbb}'))
+ self.compress = compress
+ if compress:
+ db_name += '_compressed'
+
+ if self.name2nbb is None:
+ if compress:
+ db_name = 'all_compressed'
+ else:
+ db_name = 'all'
+ # only read ahead on single node training
+ self.env = lmdb.open(f'{img_dir}/{db_name}',
+ readonly=True, create=False,
+ readahead=not _check_distributed())
+ self.txn = self.env.begin(buffers=True)
+ if self.name2nbb is None:
+ self.name2nbb = self._compute_nbb()
+
+ def _compute_nbb(self):
+ name2nbb = {}
+ fnames = json.loads(self.txn.get(key=b'__keys__').decode('utf-8'))
+ for fname in tqdm(fnames, desc='reading images'):
+ dump = self.txn.get(fname.encode('utf-8'))
+ if self.compress:
+ with io.BytesIO(dump) as reader:
+ img_dump = np.load(reader, allow_pickle=True)
+ confs = img_dump['conf']
+ else:
+ img_dump = msgpack.loads(dump, raw=False)
+ confs = img_dump['conf']
+ name2nbb[fname] = compute_num_bb(confs, self.conf_th,
+ self.min_bb, self.max_bb)
+
+ return name2nbb
+
+ def __del__(self):
+ self.env.close()
+
+ def get_dump(self, file_name):
+ # hack for MRC
+ dump = self.txn.get(file_name.encode('utf-8'))
+ nbb = self.name2nbb[file_name]
+ if self.compress:
+ with io.BytesIO(dump) as reader:
+ img_dump = np.load(reader, allow_pickle=True)
+ img_dump = _fp16_to_fp32(img_dump)
+ else:
+ img_dump = msgpack.loads(dump, raw=False)
+ img_dump = _fp16_to_fp32(img_dump)
+ img_dump = {k: arr[:nbb, ...] for k, arr in img_dump.items()}
+ return img_dump
+
+ def __getitem__(self, file_name):
+ dump = self.txn.get(file_name.encode('utf-8'))
+ nbb = self.name2nbb[file_name]
+ if self.compress:
+ with io.BytesIO(dump) as reader:
+ img_dump = np.load(reader, allow_pickle=True)
+ img_dump = {'features': img_dump['features'],
+ 'norm_bb': img_dump['norm_bb']}
+ else:
+ img_dump = msgpack.loads(dump, raw=False)
+ img_feat = torch.tensor(img_dump['features'][:nbb, :]).float()
+ img_bb = torch.tensor(img_dump['norm_bb'][:nbb, :]).float()
+ return img_feat, img_bb
+
+
+@contextmanager
+def open_lmdb(db_dir, readonly=False):
+ db = TxtLmdb(db_dir, readonly)
+ try:
+ yield db
+ finally:
+ del db
+
+
+class TxtLmdb(object):
+ def __init__(self, db_dir, readonly=True):
+ self.readonly = readonly
+ if readonly:
+ # training
+ print('| db_dir: ', db_dir)
+ self.env = lmdb.open(db_dir,
+ readonly=True, create=False,
+ readahead=not _check_distributed())
+ self.txn = self.env.begin(buffers=True)
+ self.write_cnt = None
+ # cur = self.txn.cursor()  # create a cursor to iterate over all entries
+ # num = 0
+ # for key, value in cur:
+ # num += 1
+ #
+ # print('| number: ', num)
+ # raise Exception(" number: {}".format(num))
+
+ else:
+ # prepro
+ self.env = lmdb.open(db_dir, readonly=False, create=True,
+ map_size=4 * 1024**4)
+ self.txn = self.env.begin(write=True)
+ self.write_cnt = 0
+
+ def __del__(self):
+ if self.write_cnt:
+ self.txn.commit()
+ self.env.close()
+
+ def __getitem__(self, key):
+ # print('| key: ', key, type(key))
+ return msgpack.loads(decompress(self.txn.get(key.encode('utf-8'))),
+ raw=False)
+
+ def __setitem__(self, key, value):
+ # NOTE: not thread safe
+ if self.readonly:
+ raise ValueError('readonly text DB')
+ ret = self.txn.put(key.encode('utf-8'),
+ compress(msgpack.dumps(value, use_bin_type=True)))
+ self.write_cnt += 1
+ if self.write_cnt % 1000 == 0:
+ self.txn.commit()
+ self.txn = self.env.begin(write=True)
+ self.write_cnt = 0
+ return ret
+
+
+class TxtTokLmdb(object):
+ def __init__(self, db_dir, max_txt_len=60):
+ if max_txt_len == -1:
+ self.id2len = json.load(open(f'{db_dir}/id2len.json'))
+ else:
+ self.id2len = {
+ id_: len_
+ for id_, len_ in json.load(open(f'{db_dir}/id2len.json')
+ ).items()
+ if len_ <= max_txt_len
+ }
+ self.db_dir = db_dir
+ self.db = TxtLmdb(db_dir, readonly=True)
+ meta = json.load(open(f'{db_dir}/meta.json', 'r'))
+ self.cls_ = meta['CLS']
+ self.sep = meta['SEP']
+ self.mask = meta['MASK']
+ self.v_range = meta['v_range']
+
+ def __getitem__(self, id_):
+ txt_dump = self.db[id_]
+ return txt_dump
+
+ def combine_inputs(self, *inputs):
+ input_ids = [self.cls_]
+ for ids in inputs:
+ input_ids.extend(ids + [self.sep])
+ return torch.tensor(input_ids)
+
+ @property
+ def txt2img(self):
+ txt2img = json.load(open(f'{self.db_dir}/txt2img.json'))
+ return txt2img
+
+ @property
+ def img2txts(self):
+ img2txts = json.load(open(f'{self.db_dir}/img2txts.json'))
+ return img2txts
+
+
+def get_ids_and_lens(db):
+ assert isinstance(db, TxtTokLmdb)
+ lens = []
+ ids = []
+ for id_ in list(db.id2len.keys())[hvd.rank()::hvd.size()]:
+ lens.append(db.id2len[id_])
+ ids.append(id_)
+ return lens, ids
+
+
+class DetectFeatTxtTokDataset(Dataset):
+ def __init__(self, txt_db, img_db):
+ assert isinstance(txt_db, TxtTokLmdb)
+ assert isinstance(img_db, DetectFeatLmdb)
+ self.txt_db = txt_db
+ self.img_db = img_db
+ txt_lens, self.ids = get_ids_and_lens(txt_db)
+ # self.ids would be split by the GPUs
+ txt2img = txt_db.txt2img
+ self.lens = [tl + self.img_db.name2nbb[txt2img[id_]]
+ for tl, id_ in zip(txt_lens, self.ids)]
+
+ def __len__(self):
+ return len(self.ids)
+
+ def __getitem__(self, i):
+ id_ = self.ids[i]
+ example = self.txt_db[id_]
+ return example
+
+ def _get_img_feat(self, fname):
+ img_feat, bb = self.img_db[fname]
+ img_bb = torch.cat([bb, bb[:, 4:5]*bb[:, 5:]], dim=-1)
+ num_bb = img_feat.size(0)
+ return img_feat, img_bb, num_bb
+
+
+def pad_tensors(tensors, lens=None, pad=0):
+ """B x [T, ...]"""
+ if lens is None:
+ lens = [t.size(0) for t in tensors]
+ max_len = max(lens)
+ bs = len(tensors)
+ hid = tensors[0].size(-1)
+ dtype = tensors[0].dtype
+ output = torch.zeros(bs, max_len, hid, dtype=dtype)
+ if pad:
+ output.data.fill_(pad)
+ for i, (t, l) in enumerate(zip(tensors, lens)):
+ output.data[i, :l, ...] = t.data
+ return output
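+
+# Example (hypothetical shapes): pad_tensors([torch.ones(2, 4), torch.ones(3, 4)])
+# returns a (2, 3, 4) batch where the shorter tensor is zero-padded along dim 1.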
+
+
+def get_gather_index(txt_lens, num_bbs, batch_size, max_len, out_size):
+ assert len(txt_lens) == len(num_bbs) == batch_size
+ gather_index = torch.arange(0, out_size, dtype=torch.long,
+ ).unsqueeze(0).repeat(batch_size, 1)
+
+ for i, (tl, nbb) in enumerate(zip(txt_lens, num_bbs)):
+ gather_index.data[i, tl:tl+nbb] = torch.arange(max_len, max_len+nbb,
+ dtype=torch.long).data
+ return gather_index
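+
+# Worked example (hypothetical sizes): with txt_lens=[2], num_bbs=[3], batch_size=1,
+# max_len=4 and out_size=7, row 0 becomes [0, 1, 4, 5, 6, 5, 6]: the first tl entries
+# index the (padded) text embeddings and the next nbb entries index the image
+# embeddings stored after the text block; trailing positions are covered by the
+# attention-mask padding.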
+
+
+class ConcatDatasetWithLens(ConcatDataset):
+ """ A thin wrapper on pytorch concat dataset for lens batching """
+ def __init__(self, datasets):
+ super().__init__(datasets)
+ self.lens = [l for dset in datasets for l in dset.lens]
+
+ def __getattr__(self, name):
+ return self._run_method_on_all_dsets(name)
+
+ def _run_method_on_all_dsets(self, name):
+ def run_all(*args, **kwargs):
+ return [dset.__getattribute__(name)(*args, **kwargs)
+ for dset in self.datasets]
+ return run_all
+
+
+class ImageLmdbGroup(object):
+ def __init__(self, conf_th, max_bb, min_bb, num_bb, compress):
+ self.path2imgdb = {}
+ self.conf_th = conf_th
+ self.max_bb = max_bb
+ self.min_bb = min_bb
+ self.num_bb = num_bb
+ self.compress = compress
+
+ def __getitem__(self, path):
+ img_db = self.path2imgdb.get(path, None)
+ if img_db is None:
+ img_db = DetectFeatLmdb(path, self.conf_th, self.max_bb,
+ self.min_bb, self.num_bb, self.compress)
+ # cache the opened db so later lookups of the same path reuse it
+ self.path2imgdb[path] = img_db
+ return img_db
diff --git a/data/itm.py b/data/itm.py
new file mode 100644
index 0000000..c572e1d
--- /dev/null
+++ b/data/itm.py
@@ -0,0 +1,722 @@
+"""
+Copyright (c) Microsoft Corporation.
+Licensed under the MIT license.
+
+Itm dataset
+"""
+from collections import defaultdict
+import copy
+import random
+
+import torch
+from torch.nn.utils.rnn import pad_sequence
+from toolz.sandbox import unzip
+from cytoolz import concat
+import numpy as np
+
+from .data import (DetectFeatTxtTokDataset, DetectFeatLmdb, TxtTokLmdb,
+ pad_tensors, get_gather_index, get_ids_and_lens)
+from .sampler import TokenBucketSampler
+from collections import Counter
+
+
+def random_word(tokens, vocab_range, mask):
+ """
+ Masking some random tokens for Language Model task with probabilities as in
+ the original BERT paper.
+ :param tokens: list of int, tokenized sentence.
+ :param vocab_range: for choosing a random word
+ :return: (list of int, list of int), masked tokens and related labels for
+ LM prediction
+ """
+ output_label = []
+
+ for i, token in enumerate(tokens):
+ prob = random.random()
+ # mask token with 15% probability
+ if prob < 0.15:
+ prob /= 0.15
+
+ # 80% randomly change token to mask token
+ if prob < 0.8:
+ tokens[i] = mask
+
+ # 10% randomly change token to random token
+ elif prob < 0.9:
+ tokens[i] = random.choice(list(range(*vocab_range)))
+
+ # -> rest 10% randomly keep current token
+
+ # append current token to output (we will predict these later)
+ output_label.append(token)
+ else:
+ # no masking token (will be ignored by loss function later)
+ output_label.append(-1)
+ if all(o == -1 for o in output_label):
+ # at least mask 1
+ output_label[0] = tokens[0]
+ tokens[0] = mask
+
+ return tokens, output_label
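+
+# Usage sketch (token and mask ids are hypothetical): random_word([5, 6, 7], (999, 28996), 103)
+# might return ([5, 103, 7], [-1, 6, -1]); only positions selected for masking carry their
+# original token id as the MLM label, all other labels are -1 and are ignored by the loss.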
+
+
+def random_word2(tokens, vocab_range, mask, mlm_positions, position_to_prob):
+ """
+ Masking some random tokens for Language Model task with probabilities as in
+ the original BERT paper.
+ :param tokens: list of int, tokenized sentence.
+ :param vocab_range: for choosing a random word
+ :param mlm_positions: the positions sequence. if select, all tokens in the positions would be masked.
+ :param position_to_prob: the sampling probability of each token in the specific position
+ :return: (list of int, list of int), masked tokens and related labels for
+ LM prediction
+ """
+ random.shuffle(mlm_positions)
+ output_label = [-1] * len(tokens)
+ for positions in mlm_positions:
+ sample_prob = sum([position_to_prob.get(position, 0.15) for position in positions]) / max(len(positions), 1)
+ prob = random.random()
+ if prob < sample_prob:
+ prob /= sample_prob
+
+ for position in positions:
+ token = tokens[position]
+ if output_label[position] == -1:
+ prob2 = random.random()
+ if prob2 < 0.8:
+ tokens[position] = mask
+ elif prob2 < 0.9:
+ tokens[position] = random.choice(list(range(*vocab_range)))
+ output_label[position] = token
+
+ if all(o == -1 for o in output_label):
+ # at least mask 1
+ select_positions = mlm_positions[0]
+ for position in select_positions:
+ token = tokens[position]
+ prob2 = random.random()
+ if prob2 < 0.8:
+ tokens[position] = mask
+ elif prob2 < 0.9:
+ tokens[position] = random.choice(list(range(*vocab_range)))
+ output_label[position] = token
+
+ return tokens, output_label
+
+
+class TokenBucketSamplerForItm(TokenBucketSampler):
+ def __init__(self, dset, *args, **kwargs):
+ super().__init__(dset.lens, *args, **kwargs)
+ self.dset = dset
+
+ def __iter__(self):
+ it = super().__iter__()
+ self.dset.new_epoch()
+ self._lens = self.dset.lens
+ return it
+
+
+def _has_overlap(la, lb):
+ if len(la) < len(lb):
+ la, lb = lb, la
+ s = set(la)
+ return any(b in s for b in lb)
+
+
+def sample_negative(sample_pool, ground_truths, num_sample):
+ """ random and retry """
+ outputs = ground_truths[:1]
+ while _has_overlap(outputs, ground_truths):
+ outputs = random.sample(sample_pool, num_sample)
+ return outputs
+
+
+class ItmDataset(DetectFeatTxtTokDataset):
+ """ NOTE this Dataset handles distributed training itself
+ (for more efficient negative sampling) """
+ def __init__(self, txt_db, img_db, neg_sample_p=0.5):
+ assert isinstance(txt_db, TxtTokLmdb)
+ assert isinstance(img_db, DetectFeatLmdb)
+
+ self.txt_db = txt_db
+ self.img_db = img_db
+
+ self.txt_lens, self.ids = get_ids_and_lens(txt_db)
+ self.all_imgs = list(set(txt_db[id_]['img_fname'] for id_ in self.ids))
+
+ self.neg_sample_p = neg_sample_p
+ self.new_epoch()
+
+ def new_epoch(self):
+ """ should be called every epoch for more randomness"""
+ self.labels = np.random.choice(
+ [0, 1], size=len(self.ids),
+ p=[self.neg_sample_p, 1-self.neg_sample_p])
+
+ self.lens = []
+ self.train_imgs = []
+ for i, (id_, tl) in enumerate(zip(self.ids, self.txt_lens)):
+ img_fname = super().__getitem__(i)['img_fname']
+ if self.labels[i] == 0:
+ img_fname = sample_negative(self.all_imgs, [img_fname], 1)[0]
+ self.train_imgs.append(img_fname)
+ self.lens.append(tl + self.img_db.name2nbb[img_fname])
+
+ def __getitem__(self, i):
+ example = super().__getitem__(i)
+ # labels and negative images should be sampled every epoch
+ ground_truth_label = self.labels[i]
+ img_fname = self.train_imgs[i]
+ img_feat, img_pos_feat, num_bb = self._get_img_feat(img_fname)
+
+ # text input
+ input_ids = example['input_ids']
+ input_ids = self.txt_db.combine_inputs(input_ids)
+
+ attn_masks = torch.ones(len(input_ids) + num_bb, dtype=torch.long)
+ target = torch.Tensor(1).long()
+ target.data.fill_(ground_truth_label)
+
+ return input_ids, img_feat, img_pos_feat, attn_masks, target
+
+
+def itm_collate(inputs):
+ (input_ids, img_feats, img_pos_feats, attn_masks, targets
+ ) = map(list, unzip(inputs))
+
+ txt_lens = [i.size(0) for i in input_ids]
+
+ input_ids = pad_sequence(input_ids, batch_first=True, padding_value=0)
+ position_ids = torch.arange(0, input_ids.size(1), dtype=torch.long
+ ).unsqueeze(0)
+
+ num_bbs = [f.size(0) for f in img_feats]
+ img_feat = pad_tensors(img_feats, num_bbs)
+ img_pos_feat = pad_tensors(img_pos_feats, num_bbs)
+
+ attn_masks = pad_sequence(attn_masks, batch_first=True, padding_value=0)
+ targets = torch.cat(targets, dim=0)
+ bs, max_tl = input_ids.size()
+ out_size = attn_masks.size(1)
+ gather_index = get_gather_index(txt_lens, num_bbs, bs, max_tl, out_size)
+
+ batch = {'input_ids': input_ids,
+ 'position_ids': position_ids,
+ 'img_feat': img_feat,
+ 'img_pos_feat': img_pos_feat,
+ 'attn_masks': attn_masks,
+ 'gather_index': gather_index,
+ 'targets': targets}
+ return batch
+
+
+def _compute_ot_scatter(txt_lens, max_txt_len, joint_len):
+ ot_scatter = torch.arange(0, joint_len, dtype=torch.long
+ ).unsqueeze(0).repeat(len(txt_lens), 1)
+ for i, tl in enumerate(txt_lens):
+ max_ind = max_txt_len + (joint_len-tl)
+ ot_scatter.data[i, tl:] = torch.arange(max_txt_len, max_ind,
+ dtype=torch.long).data
+ return ot_scatter
+
+
+def _compute_pad(lens, max_len):
+ pad = torch.zeros(len(lens), max_len, dtype=torch.uint8)
+ for i, l in enumerate(lens):
+ pad.data[i, l:].fill_(1)
+ return pad
+
+
+def itm_ot_collate(inputs):
+ (input_ids, img_feats, img_pos_feats, attn_masks, targets
+ ) = map(list, unzip(inputs))
+
+ txt_lens = [i.size(0) for i in input_ids]
+
+ input_ids = pad_sequence(input_ids, batch_first=True, padding_value=0)
+ position_ids = torch.arange(0, input_ids.size(1), dtype=torch.long
+ ).unsqueeze(0)
+
+ num_bbs = [f.size(0) for f in img_feats]
+ img_feat = pad_tensors(img_feats, num_bbs)
+ img_pos_feat = pad_tensors(img_pos_feats, num_bbs)
+
+ attn_masks = pad_sequence(attn_masks, batch_first=True, padding_value=0)
+ targets = torch.cat(targets, dim=0)
+ bs, max_tl = input_ids.size()
+ out_size = attn_masks.size(1)
+ gather_index = get_gather_index(txt_lens, num_bbs, bs, max_tl, out_size)
+
+ # OT inputs
+ max_tl = max(txt_lens)
+ max_nbb = max(num_bbs)
+ ot_scatter = _compute_ot_scatter(txt_lens, max_tl, attn_masks.size(1))
+ txt_pad = _compute_pad(txt_lens, max_tl)
+ img_pad = _compute_pad(num_bbs, max_nbb)
+ ot_inputs = {'ot_scatter': ot_scatter,
+ 'scatter_max': ot_scatter.max().item(),
+ 'txt_pad': txt_pad,
+ 'img_pad': img_pad}
+
+ batch = {'input_ids': input_ids,
+ 'position_ids': position_ids,
+ 'img_feat': img_feat,
+ 'img_pos_feat': img_pos_feat,
+ 'attn_masks': attn_masks,
+ 'gather_index': gather_index,
+ 'targets': targets,
+ 'ot_inputs': ot_inputs}
+ return batch
+
+
+class ItmRankDataset(DetectFeatTxtTokDataset):
+ def __init__(self, txt_db, img_db, neg_sample_size=1):
+ assert neg_sample_size > 0, \
+ "ItmRankDataset need at least 1 negative sample"
+ super().__init__(txt_db, img_db)
+
+ txt2img = self.txt_db.txt2img
+ self.txt2img = {id_: txt2img[id_] for id_ in self.ids}
+ # images partitioned by rank
+ self.img2txts = defaultdict(list)
+ for id_, img in self.txt2img.items():
+ self.img2txts[img].append(id_)
+ self.img_name_list = list(self.img2txts.keys())
+
+ assert neg_sample_size > 0
+ self.neg_sample_size = neg_sample_size
+
+ def __getitem__(self, i):
+ gt_txt_id = self.ids[i]
+ gt_img_fname = self.txt2img[gt_txt_id]
+
+ id_pairs = [(gt_txt_id, gt_img_fname)]
+ # sample negatives
+ neg_sample_img_ids = sample_negative(
+ self.img_name_list, [gt_img_fname], self.neg_sample_size)
+ neg_sample_txt_ids = sample_negative(
+ self.ids, self.img2txts[gt_img_fname], self.neg_sample_size)
+ id_pairs.extend([(gt_txt_id, neg_img_id)
+ for neg_img_id in neg_sample_img_ids] +
+ [(neg_txt_id, gt_img_fname)
+ for neg_txt_id in neg_sample_txt_ids])
+ inputs = self._collect_inputs(id_pairs)
+ assert len(inputs) == (1 + 2*self.neg_sample_size)
+ return inputs
+
+ def _collect_inputs(self, id_pairs):
+ # create input features
+ inputs = []
+ for txt_id, img_id in id_pairs:
+ example = self.txt_db[txt_id]
+ # text input
+ input_ids = example['input_ids']
+ input_ids = self.txt_db.combine_inputs(input_ids)
+ # img input
+ img_feat, img_pos_feat, num_bb = self._get_img_feat(img_id)
+ # mask
+ attn_masks = torch.ones(len(input_ids) + num_bb, dtype=torch.long)
+
+ inputs.append((input_ids, img_feat, img_pos_feat, attn_masks))
+
+ return inputs
+
+
+def itm_rank_collate(inputs):
+ (input_ids, img_feats, img_pos_feats, attn_masks,
+ ) = map(list, unzip(concat(i for i in inputs)))
+
+ txt_lens = [i.size(0) for i in input_ids]
+ input_ids = pad_sequence(input_ids, batch_first=True, padding_value=0)
+ position_ids = torch.arange(0, input_ids.size(1), dtype=torch.long
+ ).unsqueeze(0)
+
+ num_bbs = [f.size(0) for f in img_feats]
+ img_feat = pad_tensors(img_feats, num_bbs)
+ img_pos_feat = pad_tensors(img_pos_feats, num_bbs)
+
+ attn_masks = pad_sequence(attn_masks, batch_first=True, padding_value=0)
+ sample_size = len(inputs[0])
+ assert all(sample_size == len(i) for i in inputs)
+
+ bs, max_tl = input_ids.size()
+ out_size = attn_masks.size(1)
+ gather_index = get_gather_index(txt_lens, num_bbs, bs, max_tl, out_size)
+
+ batch = {'input_ids': input_ids,
+ 'position_ids': position_ids,
+ 'img_feat': img_feat,
+ 'img_pos_feat': img_pos_feat,
+ 'attn_masks': attn_masks,
+ 'gather_index': gather_index,
+ 'sample_size': sample_size}
+ return batch
+
+
+class ItmRankDatasetHardNegFromText(DetectFeatTxtTokDataset):
+ def __init__(self, txt_db, img_db, neg_sample_size=1):
+ assert neg_sample_size > 0, "need at least 1 negative sample"
+ super().__init__(txt_db, img_db)
+
+ txt2img = self.txt_db.txt2img
+ self.txt2img = {id_: txt2img[id_] for id_ in self.ids}
+ self.img2txts = self.txt_db.img2txts
+ self.img_name_list = list(self.img2txts.keys())
+ self.neg_sample_size = neg_sample_size
+
+ def __getitem__(self, i):
+ gt_txt_id = self.ids[i]
+ gt_img_fname = self.txt2img[gt_txt_id]
+
+ input_ids = self.txt_db[gt_txt_id]['input_ids']
+ input_ids = self.txt_db.combine_inputs(input_ids)
+ input_ids = input_ids.unsqueeze(0)
+ position_ids = torch.arange(0, input_ids.size(1), dtype=torch.long
+ ).unsqueeze(0)
+
+ neg_img_ids = sample_negative(
+ self.img_name_list, [gt_img_fname], self.neg_sample_size)
+ img_ids = [gt_img_fname] + neg_img_ids
+ # process image features (gt always first)
+ img_feats, img_pos_feats, num_bbs = map(
+ list, unzip(map(self._get_img_feat, img_ids)))
+ img_feat = pad_tensors(img_feats, num_bbs)
+ img_pos_feat = pad_tensors(img_pos_feats, num_bbs)
+
+ tl = input_ids.size(1)
+ attn_masks = torch.zeros(len(img_ids), max(num_bbs) + tl).long()
+ for i, nbb in enumerate(num_bbs):
+ attn_masks.data[i, :tl+nbb].fill_(1)
+ out_size = attn_masks.size(1)
+ gather_index = get_gather_index([tl]*len(img_ids), num_bbs,
+ len(img_ids), tl, out_size)
+
+ batch = {'input_ids': input_ids,
+ 'position_ids': position_ids,
+ 'img_feat': img_feat,
+ 'img_pos_feat': img_pos_feat,
+ 'attn_masks': attn_masks,
+ 'gather_index': gather_index}
+ return batch
+
+
+class ItmRankDatasetHardNegFromImage(DetectFeatTxtTokDataset):
+ def __init__(self, txt_db, img_db, neg_sample_size=1):
+ assert neg_sample_size > 0, "need at least 1 negative sample"
+ super().__init__(txt_db, img_db)
+
+ txt2img = self.txt_db.txt2img
+ self.txt2img = {id_: txt2img[id_] for id_ in self.ids}
+ self.img2txts = self.txt_db.img2txts
+ self.txt_name_list = list(self.txt2img.keys())
+ self.neg_sample_size = neg_sample_size
+
+ def __getitem__(self, i):
+ gt_txt_id = self.ids[i]
+ gt_img_id = self.txt2img[gt_txt_id]
+ gt_txt_ids = self.img2txts[gt_img_id]
+
+ # process image features (gt always first)
+ img_feat, img_pos_feat, nbb = self._get_img_feat(gt_img_id)
+ img_feat = img_feat.unsqueeze(0)
+ img_pos_feat = img_pos_feat.unsqueeze(0)
+
+ # sample negative
+ neg_txt_ids = sample_negative(
+ self.txt_name_list, gt_txt_ids, self.neg_sample_size)
+ txt_ids = [gt_txt_id] + neg_txt_ids
+
+ # process text inputs
+ all_inputs = []
+ txt_lens = []
+ for txt_id in txt_ids:
+ input_ids = self.txt_db.combine_inputs(
+ self.txt_db[txt_id]['input_ids'])
+ all_inputs.append(input_ids)
+ txt_lens.append(len(input_ids))
+ input_ids = pad_sequence(all_inputs, batch_first=True, padding_value=0)
+ position_ids = torch.arange(0, input_ids.size(1), dtype=torch.long
+ ).unsqueeze(0)
+
+ attn_masks = torch.zeros(len(txt_ids), max(txt_lens) + nbb).long()
+ for i, tl in enumerate(txt_lens):
+ attn_masks.data[i, :tl+nbb].fill_(1)
+ out_size = attn_masks.size(1)
+ # image embeddings start after the padded text block, i.e. max(txt_lens)
+ gather_index = get_gather_index(txt_lens, [nbb]*len(txt_ids),
+ len(txt_ids), max(txt_lens), out_size)
+
+ batch = {'input_ids': input_ids,
+ 'position_ids': position_ids,
+ 'img_feat': img_feat,
+ 'img_pos_feat': img_pos_feat,
+ 'attn_masks': attn_masks,
+ 'gather_index': gather_index}
+ return batch
+
+
+def itm_rank_hn_collate(inputs):
+ assert len(inputs) == 1
+ return inputs[0]
+
+
+class ItmValDataset(DetectFeatTxtTokDataset):
+ """ For evaluating Image-Text-Retrieval task """
+ def __init__(self, db_dir, img_dir, mini_batch_size=400, mlm_sample_size=1):
+ super().__init__(db_dir, img_dir)
+
+ self.txt_lens = self.lens[:]
+ del self.lens
+ self.txt2img = self.txt_db.txt2img
+ self.img2txts = self.txt_db.img2txts
+ self.all_img_ids = list(self.img2txts.keys())
+ self.id2len = self.txt_db.id2len
+ self.i2len = {i: self.id2len.get(self.ids[i]) for i in range(len(self.ids))}
+
+ assert len(self.img2txts) >= mini_batch_size > 0
+ self.bs = mini_batch_size
+ self.mlm_sample_size = mlm_sample_size
+
+ def _get_batch_ids(self, i):
+ gt_txt_id = self.ids[i]
+ gt_img_id = self.txt2img[gt_txt_id]
+
+ # sample fixed negatives for each gt image
+ i = self.all_img_ids.index(gt_img_id)
+ neg_st = i+1
+ neg_end = neg_st+self.bs-1
+ if neg_end > len(self.all_img_ids):
+ # wrap around
+ neg_end -= len(self.all_img_ids)
+ neg_img_ids = (self.all_img_ids[neg_st:]
+ + self.all_img_ids[:neg_end])
+ else:
+ neg_img_ids = self.all_img_ids[neg_st:neg_end]
+
+ assert len(neg_img_ids) == (self.bs - 1),\
+ "Did not sample enough neg samples"
+
+ return gt_img_id, neg_img_ids
+
+ def __getitem__(self, i):
+ """ this returns list of mini-batches """
+ gt_img_id, neg_img_ids = self._get_batch_ids(i)
+ # NOTE 1st one is gt img
+ batch = self.get_batch(i, [gt_img_id] + neg_img_ids)
+ return batch
+
+ def create_mlm_io(self, input_ids, nsample=1):
+ mlm_input_ids, mlm_txt_labels = [], []
+ for i in range(nsample):
+ # tokens, vocab_range, mask, mlm_positions, position_to_prob
+ r_input_ids, txt_labels = random_word(
+ copy.copy(input_ids), self.txt_db.v_range, self.txt_db.mask)
+ mlm_input_ids.append(torch.tensor([self.txt_db.cls_] + r_input_ids + [self.txt_db.sep]))
+ mlm_txt_labels.append(torch.tensor([-1] + txt_labels + [-1]))
+ mlm_input_ids = torch.stack(mlm_input_ids, dim=0)
+ mlm_txt_labels = torch.stack(mlm_txt_labels, dim=0)
+ return mlm_input_ids, mlm_txt_labels
+
+ def create_mlm_io2(self, input_ids, tree=None, nsample=1):
+ mlm_input_ids, mlm_txt_labels = [], []
+ sample_prob = 0.15
+
+ mlm_positions = [[i] for i in range(len(input_ids))]
+ for struct_type in ['relation', 'attribute', 'node']:
+ struct_nodes = tree.get(struct_type)
+ for struct_node in struct_nodes:
+ positions = struct_node.get('ids')
+ if positions is not None:
+ mlm_positions.append(positions)
+ # mlm_positions = list(set(mlm_positions))
+ position_counter = Counter()
+ for positions in mlm_positions:
+ position_counter.update(positions)
+ position_to_prob = {position: sample_prob / max(freq, 1.0) for position, freq in position_counter.items()}
+
+ # print("| mlm_positions: ", mlm_positions)
+ for i in range(nsample):
+ r_input_ids, txt_labels = random_word2(
+ copy.copy(input_ids), self.txt_db.v_range, self.txt_db.mask,
+ mlm_positions=mlm_positions, position_to_prob=position_to_prob)
+ mlm_input_ids.append(torch.tensor([self.txt_db.cls_] + r_input_ids + [self.txt_db.sep]))
+ mlm_txt_labels.append(torch.tensor([-1] + txt_labels + [-1]))
+ mlm_input_ids = torch.stack(mlm_input_ids, dim=0)
+ mlm_txt_labels = torch.stack(mlm_txt_labels, dim=0)
+ return mlm_input_ids, mlm_txt_labels
+
+ def get_batch(self, i, img_ids, forward_mlm=False):
+ batch = {}
+
+ example = super().__getitem__(i)
+
+ input_ids = example['input_ids']
+
+ if forward_mlm:
+ gt_txt_id = self.ids[i]
+ gt_img_fname = self.txt2img[gt_txt_id]
+ img_ids = [gt_img_fname]
+ # Process Masked Text to Generate Adversarial Samples
+ # mlm_input_ids, mlm_txt_labels = self.create_mlm_io(input_ids, nsample=self.mlm_sample_size)
+
+ tree = self.txt_db[gt_txt_id]['tree']
+ mlm_input_ids, mlm_txt_labels = self.create_mlm_io2(input_ids, tree=tree, nsample=self.mlm_sample_size)
+
+ mlm_position_ids = torch.arange(0, mlm_input_ids.size(1), dtype=torch.long).\
+ unsqueeze(0).expand(self.mlm_sample_size, -1)
+ img_feat, img_pos_feat, num_bbs = self._get_img_feat(gt_img_fname)
+ mlm_img_feat = img_feat.unsqueeze(dim=0).expand(self.mlm_sample_size, *list(img_feat.size()))
+ mlm_img_pos_feat = img_pos_feat.unsqueeze(dim=0).expand(self.mlm_sample_size, *list(img_pos_feat.size()))
+ tl = mlm_input_ids.size(1)
+ mlm_attn_masks = torch.zeros(self.mlm_sample_size, tl+num_bbs).long()
+ mlm_attn_masks.data[:, :tl+num_bbs].fill_(1)
+ mlm_gather_index = get_gather_index(
+ [tl]*self.mlm_sample_size, [num_bbs]*self.mlm_sample_size, self.mlm_sample_size, tl, tl+num_bbs)
+
+ batch['mlm_input_ids'] = mlm_input_ids
+ batch['mlm_position_ids'] = mlm_position_ids
+ batch['mlm_img_feat'] = mlm_img_feat
+ batch['mlm_img_pos_feat'] = mlm_img_pos_feat
+ batch['mlm_attn_masks'] = mlm_attn_masks
+ batch['mlm_gather_index'] = mlm_gather_index
+ batch['mlm_txt_labels'] = mlm_txt_labels
+
+ input_ids = self.txt_db.combine_inputs(input_ids)
+ input_ids = input_ids.unsqueeze(0).expand(len(img_ids), -1).clone()
+ position_ids = torch.arange(0, input_ids.size(1), dtype=torch.long
+ ).unsqueeze(0)
+
+ # process image features (gt always first)
+ img_feats, img_pos_feats, num_bbs = map(
+ list, unzip(map(self._get_img_feat, img_ids)))
+ img_feat = pad_tensors(img_feats, num_bbs)
+ img_pos_feat = pad_tensors(img_pos_feats, num_bbs)
+
+ tl = input_ids.size(1)
+ attn_masks = torch.zeros(len(img_ids), max(num_bbs) + tl).long()
+ for i, nbb in enumerate(num_bbs):
+ attn_masks.data[i, :tl+nbb].fill_(1)
+ out_size = attn_masks.size(1)
+ gather_index = get_gather_index([tl]*len(img_ids), num_bbs,
+ len(img_ids), tl, out_size)
+
+ batch['input_ids'] = input_ids
+ batch['position_ids'] = position_ids
+ batch['img_feat'] = img_feat
+ batch['img_pos_feat'] = img_pos_feat
+ batch['attn_masks'] = attn_masks
+ batch['gather_index'] = gather_index
+ return batch
+
+
+def itm_val_collate(inputs):
+ assert len(inputs) == 1, "input batch size > 1"
+ return inputs[0]
+
+
+class ItmEvalDataset(ItmValDataset):
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ self.all_img_ids = sorted(copy.deepcopy(self.all_img_ids),
+ key=lambda i: self.img_db.name2nbb[i])
+
+ def __getitem__(self, i):
+ mini_batches = []
+ for st in range(0, len(self.all_img_ids), self.bs):
+ mini_batches.append(
+ self.get_batch(i, self.all_img_ids[st:st+self.bs]))
+ return mini_batches
+
+
+class ItmAdvEvalDataset(ItmValDataset):
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ self.all_img_ids = sorted(copy.deepcopy(self.all_img_ids),
+ key=lambda i: self.img_db.name2nbb[i])
+ self.all_img_ids = np.array(self.all_img_ids)
+
+ def __getitem__(self, i):
+ return self.get_batch(i, [], forward_mlm=True)
+
+
+class ItmDCEvalDataset(ItmValDataset):
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ self.all_img_ids = sorted(copy.deepcopy(self.all_img_ids),
+ key=lambda i: self.img_db.name2nbb[i])
+
+ def __getitem__(self, i):
+ return self.get_batch(i, [], forward_mlm=True)
+
+ def get_by_img_index(self, i, image_indices):
+ mini_batch = self.get_batch(i, self.all_img_ids[image_indices])
+ mini_batch_img_txt_input_ids = []
+ for image_index in image_indices:
+ img_id = self.all_img_ids[image_index]
+ txt_ids = self.img2txts[img_id]
+ img_txt_input_ids = []
+ for txt_id in txt_ids:
+ input_ids = self.txt_db[txt_id]['input_ids']
+ img_txt_input_ids.extend(input_ids)
+ mini_batch_img_txt_input_ids.append(torch.tensor(list(set(img_txt_input_ids))))
+ mini_batch['img_txt_input_ids'] = mini_batch_img_txt_input_ids
+ mini_batches = [mini_batch]
+ return mini_batches
+
+
+class ItmStaticDataAttackEvalDataset(DetectFeatTxtTokDataset):
+ def __init__(self, db_dir, img_dir, mini_batch_size=400, mlm_sample_size=1):
+ super().__init__(db_dir, img_dir)
+
+ self.txt_lens = self.lens[:]
+ del self.lens
+ self.txt2img = self.txt_db.txt2img
+ self.img2txts = self.txt_db.img2txts
+ self.all_img_ids = list(self.img2txts.keys())
+ self.id2len = self.txt_db.id2len
+ self.i2len = {i: self.id2len.get(self.ids[i]) for i in range(len(self.ids))}
+
+ assert len(self.img2txts) >= mini_batch_size > 0
+ self.bs = mini_batch_size
+ self.mlm_sample_size = mlm_sample_size
+
+ def __getitem__(self, i):
+ return self.get_batch(i, [], forward_mlm=True)
+
+ def get_batch(self, i, img_ids, forward_mlm=False):
+ batch = {}
+ example = super().__getitem__(i)
+ input_ids = example['input_ids']
+ all_attack_text_ids = example['attack_data'][0]
+ all_attack_text_ids = all_attack_text_ids[:400]
+
+ all_inputs = [self.txt_db.combine_inputs(token_ids) for token_ids in [input_ids] + all_attack_text_ids]
+ txt_lens = [len(token_ids) for token_ids in all_inputs]
+ input_ids = pad_sequence(all_inputs, batch_first=True, padding_value=0)
+ # print('| inputs_ids: ', len(all_attack_text_ids), input_ids.size())
+ position_ids = torch.arange(0, input_ids.size(1), dtype=torch.long).unsqueeze(0)
+
+ gt_txt_id = self.ids[i]
+ gt_img_id = self.txt2img[gt_txt_id]
+
+ # process image features (gt always first)
+ img_feat, img_pos_feat, nbb = self._get_img_feat(gt_img_id)
+ img_feat = img_feat.unsqueeze(0)
+ img_pos_feat = img_pos_feat.unsqueeze(0)
+ # print("| img_feat: ", img_feat.size(), img_pos_feat.size())
+
+ attn_masks = torch.zeros(len(txt_lens), max(txt_lens) + nbb).long()
+ for i, tl in enumerate(txt_lens):
+ attn_masks.data[i, :tl+nbb].fill_(1)
+ out_size = attn_masks.size(1)
+ # image embeddings start after the padded text block, i.e. max(txt_lens)
+ gather_index = get_gather_index(txt_lens, [nbb]*len(txt_lens), len(txt_lens), max(txt_lens), out_size)
+
+ batch['input_ids'] = input_ids
+ batch['position_ids'] = position_ids
+ batch['img_feat'] = img_feat
+ batch['img_pos_feat'] = img_pos_feat
+ batch['attn_masks'] = attn_masks
+ batch['gather_index'] = gather_index
+ return batch
+
+
+itm_eval_collate = itm_val_collate
diff --git a/data/loader.py b/data/loader.py
new file mode 100644
index 0000000..e177c07
--- /dev/null
+++ b/data/loader.py
@@ -0,0 +1,150 @@
+"""
+Copyright (c) Microsoft Corporation.
+Licensed under the MIT license.
+
+A prefetch loader to speedup data loading
+Modified from Nvidia Deep Learning Examples
+(https://github.com/NVIDIA/DeepLearningExamples/tree/master/PyTorch).
+"""
+import random
+
+import torch
+from torch.utils.data import DataLoader
+
+from utils.distributed import any_broadcast
+
+
+class MetaLoader(object):
+ """ wraps multiple data loaders """
+ def __init__(self, loaders, accum_steps=1, distributed=False):
+ assert isinstance(loaders, dict)
+ self.name2loader = {}
+ self.name2iter = {}
+ self.sampling_pools = []
+ for n, l in loaders.items():
+ if isinstance(l, tuple):
+ l, r = l
+ elif isinstance(l, DataLoader):
+ r = 1
+ else:
+ raise ValueError()
+ self.name2loader[n] = l
+ self.name2iter[n] = iter(l)
+ self.sampling_pools.extend([n]*r)
+
+ self.accum_steps = accum_steps
+ self.distributed = distributed
+ self.step = 0
+
+ def __iter__(self):
+ """ this iterator will run indefinitely """
+ task = self.sampling_pools[0]
+ while True:
+ if self.step % self.accum_steps == 0:
+ task = random.choice(self.sampling_pools)
+ if self.distributed:
+ # make sure all processes train on the same task
+ task = any_broadcast(task, 0)
+ self.step += 1
+ iter_ = self.name2iter[task]
+ try:
+ batch = next(iter_)
+ except StopIteration:
+ iter_ = iter(self.name2loader[task])
+ batch = next(iter_)
+ self.name2iter[task] = iter_
+
+ yield task, batch
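+
+# Minimal usage sketch (task names and loaders are hypothetical):
+#   meta = MetaLoader({'itm': (itm_loader, 2), 'mlm': mlm_loader})
+#   for task, batch in meta:  # 'itm' is drawn twice as often as 'mlm'; runs indefinitely
+#       ...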
+
+
+def move_to_cuda(batch):
+ if isinstance(batch, torch.Tensor):
+ try:
+ return batch.cuda(non_blocking=True)
+ except RuntimeError:
+ # print('| Error |', batch.size(), batch.dtype)
+ return batch.contiguous().cuda(non_blocking=True)
+ elif isinstance(batch, list):
+ new_batch = [move_to_cuda(t) for t in batch]
+ elif isinstance(batch, tuple):
+ new_batch = tuple(move_to_cuda(t) for t in batch)
+ elif isinstance(batch, dict):
+ new_batch = {}
+ for n, t in batch.items():
+ # print(n, t.size(), t[0])
+ new_batch[n] = move_to_cuda(t)
+ # new_batch = {n: move_to_cuda(t) for n, t in batch.items()}
+ else:
+ return batch
+ return new_batch
+
+
+def record_cuda_stream(batch):
+ if isinstance(batch, torch.Tensor):
+ batch.record_stream(torch.cuda.current_stream())
+ elif isinstance(batch, list) or isinstance(batch, tuple):
+ for t in batch:
+ record_cuda_stream(t)
+ elif isinstance(batch, dict):
+ for t in batch.values():
+ record_cuda_stream(t)
+ else:
+ pass
+
+
+class PrefetchLoader(object):
+ """
+ overlap compute and cuda data transfer
+ (copied and then modified from nvidia apex)
+ """
+ def __init__(self, loader):
+ self.loader = loader
+ self.stream = torch.cuda.Stream()
+
+ def __iter__(self):
+ loader_it = iter(self.loader)
+ self.preload(loader_it)
+ batch = self.next(loader_it)
+ while batch is not None:
+ yield batch
+ batch = self.next(loader_it)
+
+ def __len__(self):
+ return len(self.loader)
+
+ def preload(self, it):
+ try:
+ self.batch = next(it)
+ except StopIteration:
+ self.batch = None
+ return
+ # if record_stream() doesn't work, another option is to make sure
+ # device inputs are created on the main stream.
+ # self.next_input_gpu = torch.empty_like(self.next_input,
+ # device='cuda')
+ # self.next_target_gpu = torch.empty_like(self.next_target,
+ # device='cuda')
+ # Need to make sure the memory allocated for next_* is not still in use
+ # by the main stream at the time we start copying to next_*:
+ # self.stream.wait_stream(torch.cuda.current_stream())
+ with torch.cuda.stream(self.stream):
+ self.batch = move_to_cuda(self.batch)
+ # more code for the alternative if record_stream() doesn't work:
+ # copy_ will record the use of the pinned source tensor in this
+ # side stream.
+ # self.next_input_gpu.copy_(self.next_input, non_blocking=True)
+ # self.next_target_gpu.copy_(self.next_target, non_blocking=True)
+ # self.next_input = self.next_input_gpu
+ # self.next_target = self.next_target_gpu
+
+ def next(self, it):
+ torch.cuda.current_stream().wait_stream(self.stream)
+ batch = self.batch
+ if batch is not None:
+ record_cuda_stream(batch)
+ self.preload(it)
+ return batch
+
+ def __getattr__(self, name):
+ method = self.loader.__getattribute__(name)
+ return method
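+
+# Typical wrapping (sketch): PrefetchLoader(DataLoader(dataset, ...)) behaves like the
+# underlying loader, but each yielded batch has already been moved to the GPU on a side
+# CUDA stream so host-to-device copies overlap with compute.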
diff --git a/data/pnsgd.py b/data/pnsgd.py
new file mode 100644
index 0000000..c35ff26
--- /dev/null
+++ b/data/pnsgd.py
@@ -0,0 +1,616 @@
+"""
+Copyright (c) Microsoft Corporation.
+Licensed under the MIT license.
+
+PNSGD dataset (adapted from the ITM dataset)
+"""
+from collections import defaultdict
+import copy
+import random
+from collections import Counter
+import torch
+from torch.nn.utils.rnn import pad_sequence
+from toolz.sandbox import unzip
+from cytoolz import concat
+import numpy as np
+
+from .data import (DetectFeatTxtTokDataset, DetectFeatLmdb, TxtTokLmdb,
+ pad_tensors, get_gather_index, get_ids_and_lens)
+from .sampler import TokenBucketSampler
+
+
+def random_word(tokens, vocab_range, mask):
+ """
+ Masking some random tokens for Language Model task with probabilities as in
+ the original BERT paper.
+ :param tokens: list of int, tokenized sentence.
+ :param vocab_range: for choosing a random word
+ :return: (list of int, list of int), masked tokens and related labels for
+ LM prediction
+ """
+ output_label = []
+
+ for i, token in enumerate(tokens):
+ prob = random.random()
+ # mask token with 15% probability
+ if prob < 0.15:
+ prob /= 0.15
+
+ # 80% randomly change token to mask token
+ if prob < 0.8:
+ tokens[i] = mask
+
+ # 10% randomly change token to random token
+ elif prob < 0.9:
+ tokens[i] = random.choice(list(range(*vocab_range)))
+
+ # -> rest 10% randomly keep current token
+
+ # append current token to output (we will predict these later)
+ output_label.append(token)
+ else:
+ # no masking token (will be ignored by loss function later)
+ output_label.append(-1)
+ if all(o == -1 for o in output_label):
+ # at least mask 1
+ output_label[0] = tokens[0]
+ tokens[0] = mask
+
+ return tokens, output_label
+
+
+def random_word2(tokens, vocab_range, mask, mlm_positions, position_to_prob):
+ """
+ Masking some random tokens for Language Model task with probabilities as in
+ the original BERT paper.
+ :param tokens: list of int, tokenized sentence.
+ :param vocab_range: for choosing a random word
+ :param mlm_positions: the positions sequence. if select, all tokens in the positions would be masked.
+ :param position_to_prob: the sampling probability of each token in the specific position
+ :return: (list of int, list of int), masked tokens and related labels for
+ LM prediction
+ """
+ random.shuffle(mlm_positions)
+ output_label = [-1] * len(tokens)
+ for positions in mlm_positions:
+ sample_prob = sum([position_to_prob.get(position, 0.15) for position in positions]) / max(len(positions), 1)
+ prob = random.random()
+ if prob < sample_prob:
+ prob /= sample_prob
+
+ for position in positions:
+ token = tokens[position]
+ if output_label[position] == -1:
+ prob2 = random.random()
+ if prob2 < 0.8:
+ tokens[position] = mask
+ elif prob2 < 0.9:
+ tokens[position] = random.choice(list(range(*vocab_range)))
+ output_label[position] = token
+
+ if all(o == -1 for o in output_label):
+ # at least mask 1
+ select_positions = mlm_positions[0]
+ for position in select_positions:
+ token = tokens[position]
+ prob2 = random.random()
+ if prob2 < 0.8:
+ tokens[position] = mask
+ elif prob2 < 0.9:
+ tokens[position] = random.choice(list(range(*vocab_range)))
+ output_label[position] = token
+
+ return tokens, output_label
+
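+# Note on the probabilities used by random_word2 above: position_to_prob
+# divides the base rate (0.15) by the number of position groups a position
+# belongs to, so a token shared by e.g. three groups contributes ~0.05 per
+# group but still ends up masked roughly 15% of the time overall; each
+# group's sample_prob is the mean of these per-position values.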
+
+class TokenBucketSamplerForItm(TokenBucketSampler):
+ def __init__(self, dset, *args, **kwargs):
+ super().__init__(dset.lens, *args, **kwargs)
+ self.dset = dset
+
+ def __iter__(self):
+ it = super().__iter__()
+ self.dset.new_epoch()
+ self._lens = self.dset.lens
+ return it
+
+
+def _has_overlap(la, lb):
+ if len(la) < len(lb):
+ la, lb = lb, la
+ s = set(la)
+ return any(b in s for b in lb)
+
+
+def sample_negative(sample_pool, ground_truths, num_sample):
+ """ random and retry """
+ outputs = ground_truths[:1]
+ while _has_overlap(outputs, ground_truths):
+ outputs = random.sample(sample_pool, num_sample)
+ return outputs
+
+
+class ItmDataset(DetectFeatTxtTokDataset):
+ """ NOTE this Dataset handles distributed training itself
+ (for more efficient negative sampling) """
+ def __init__(self, txt_db, img_db, neg_sample_p=0.5):
+ assert isinstance(txt_db, TxtTokLmdb)
+ assert isinstance(img_db, DetectFeatLmdb)
+
+ self.txt_db = txt_db
+ self.img_db = img_db
+
+ self.txt_lens, self.ids = get_ids_and_lens(txt_db)
+ self.all_imgs = list(set(txt_db[id_]['img_fname'] for id_ in self.ids))
+
+ self.neg_sample_p = neg_sample_p
+ self.new_epoch()
+
+ def new_epoch(self):
+ """ should be called every epoch for more randomness"""
+ self.labels = np.random.choice(
+ [0, 1], size=len(self.ids),
+ p=[self.neg_sample_p, 1-self.neg_sample_p])
+
+ self.lens = []
+ self.train_imgs = []
+ for i, (id_, tl) in enumerate(zip(self.ids, self.txt_lens)):
+ img_fname = super().__getitem__(i)['img_fname']
+ if self.labels[i] == 0:
+ img_fname = sample_negative(self.all_imgs, [img_fname], 1)[0]
+ self.train_imgs.append(img_fname)
+ self.lens.append(tl + self.img_db.name2nbb[img_fname])
+
+ def __getitem__(self, i):
+ example = super().__getitem__(i)
+ # labels and negative images should be sampled every epoch
+ ground_truth_label = self.labels[i]
+ img_fname = self.train_imgs[i]
+ img_feat, img_pos_feat, num_bb = self._get_img_feat(img_fname)
+
+ # text input
+ input_ids = example['input_ids']
+ input_ids = self.txt_db.combine_inputs(input_ids)
+
+ attn_masks = torch.ones(len(input_ids) + num_bb, dtype=torch.long)
+ target = torch.Tensor(1).long()
+ target.data.fill_(ground_truth_label)
+
+ return input_ids, img_feat, img_pos_feat, attn_masks, target
+
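+# Note on ItmDataset above: new_epoch() re-draws a binary label per caption
+# (0 with probability neg_sample_p); when the label is 0 the caption is paired
+# with a freshly sampled non-matching image, so the negative pairs change
+# every epoch and `lens` (text length + number of image regions) is rebuilt.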
+
+def itm_collate(inputs):
+ (input_ids, img_feats, img_pos_feats, attn_masks, targets
+ ) = map(list, unzip(inputs))
+
+ txt_lens = [i.size(0) for i in input_ids]
+
+ input_ids = pad_sequence(input_ids, batch_first=True, padding_value=0)
+ position_ids = torch.arange(0, input_ids.size(1), dtype=torch.long
+ ).unsqueeze(0)
+
+ num_bbs = [f.size(0) for f in img_feats]
+ img_feat = pad_tensors(img_feats, num_bbs)
+ img_pos_feat = pad_tensors(img_pos_feats, num_bbs)
+
+ attn_masks = pad_sequence(attn_masks, batch_first=True, padding_value=0)
+ targets = torch.cat(targets, dim=0)
+ bs, max_tl = input_ids.size()
+ out_size = attn_masks.size(1)
+ gather_index = get_gather_index(txt_lens, num_bbs, bs, max_tl, out_size)
+
+ batch = {'input_ids': input_ids,
+ 'position_ids': position_ids,
+ 'img_feat': img_feat,
+ 'img_pos_feat': img_pos_feat,
+ 'attn_masks': attn_masks,
+ 'gather_index': gather_index,
+ 'targets': targets}
+ return batch
+
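+# Shapes produced by itm_collate above (B examples, d = visual feature dim,
+# p = width of the region position features):
+#   input_ids    (B, max_tl)              padded token ids
+#   position_ids (1, max_tl)              broadcast over the batch
+#   img_feat     (B, max_nbb, d)          padded region features
+#   img_pos_feat (B, max_nbb, p)          padded region position features
+#   attn_masks   (B, max(tl_i + nbb_i))   1 for real tokens/regions, 0 for pad
+#   gather_index (B, out_size)            packs text and image rows contiguously
+#   targets      (B,)                     1 = matched pair, 0 = negative pair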
+
+def _compute_ot_scatter(txt_lens, max_txt_len, joint_len):
+ ot_scatter = torch.arange(0, joint_len, dtype=torch.long
+ ).unsqueeze(0).repeat(len(txt_lens), 1)
+ for i, tl in enumerate(txt_lens):
+ max_ind = max_txt_len + (joint_len-tl)
+ ot_scatter.data[i, tl:] = torch.arange(max_txt_len, max_ind,
+ dtype=torch.long).data
+ return ot_scatter
+
+
+def _compute_pad(lens, max_len):
+ pad = torch.zeros(len(lens), max_len, dtype=torch.uint8)
+ for i, l in enumerate(lens):
+ pad.data[i, l:].fill_(1)
+ return pad
+
+
+def itm_ot_collate(inputs):
+ (input_ids, img_feats, img_pos_feats, attn_masks, targets
+ ) = map(list, unzip(inputs))
+
+ txt_lens = [i.size(0) for i in input_ids]
+
+ input_ids = pad_sequence(input_ids, batch_first=True, padding_value=0)
+ position_ids = torch.arange(0, input_ids.size(1), dtype=torch.long
+ ).unsqueeze(0)
+
+ num_bbs = [f.size(0) for f in img_feats]
+ img_feat = pad_tensors(img_feats, num_bbs)
+ img_pos_feat = pad_tensors(img_pos_feats, num_bbs)
+
+ attn_masks = pad_sequence(attn_masks, batch_first=True, padding_value=0)
+ targets = torch.cat(targets, dim=0)
+ bs, max_tl = input_ids.size()
+ out_size = attn_masks.size(1)
+ gather_index = get_gather_index(txt_lens, num_bbs, bs, max_tl, out_size)
+
+ # OT inputs
+ max_tl = max(txt_lens)
+ max_nbb = max(num_bbs)
+ ot_scatter = _compute_ot_scatter(txt_lens, max_tl, attn_masks.size(1))
+ txt_pad = _compute_pad(txt_lens, max_tl)
+ img_pad = _compute_pad(num_bbs, max_nbb)
+ ot_inputs = {'ot_scatter': ot_scatter,
+ 'scatter_max': ot_scatter.max().item(),
+ 'txt_pad': txt_pad,
+ 'img_pad': img_pad}
+
+ batch = {'input_ids': input_ids,
+ 'position_ids': position_ids,
+ 'img_feat': img_feat,
+ 'img_pos_feat': img_pos_feat,
+ 'attn_masks': attn_masks,
+ 'gather_index': gather_index,
+ 'targets': targets,
+ 'ot_inputs': ot_inputs}
+ return batch
+
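+# itm_ot_collate above additionally builds `ot_inputs`: ot_scatter re-indexes
+# the padded text+image sequence into a packed layout while txt_pad / img_pad
+# flag padding positions; these extras are presumably consumed by an
+# optimal-transport (word-region alignment) objective, whereas itm_collate is
+# the plain variant without that signal.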
+
+class ItmRankDataset(DetectFeatTxtTokDataset):
+ def __init__(self, txt_db, img_db, neg_sample_size=1):
+ assert neg_sample_size > 0, \
+ "ItmRankDataset need at least 1 negative sample"
+ super().__init__(txt_db, img_db)
+
+ txt2img = self.txt_db.txt2img
+ self.txt2img = {id_: txt2img[id_] for id_ in self.ids}
+        # map each image to all the texts (captions) that describe it
+ self.img2txts = defaultdict(list)
+ for id_, img in self.txt2img.items():
+ self.img2txts[img].append(id_)
+ self.img_name_list = list(self.img2txts.keys())
+
+ assert neg_sample_size > 0
+ self.neg_sample_size = neg_sample_size
+
+ def __getitem__(self, i):
+ gt_txt_id = self.ids[i]
+ gt_img_fname = self.txt2img[gt_txt_id]
+
+ id_pairs = [(gt_txt_id, gt_img_fname)]
+ # sample negatives
+ neg_sample_img_ids = sample_negative(
+ self.img_name_list, [gt_img_fname], self.neg_sample_size)
+ neg_sample_txt_ids = sample_negative(
+ self.ids, self.img2txts[gt_img_fname], self.neg_sample_size)
+ id_pairs.extend([(gt_txt_id, neg_img_id)
+ for neg_img_id in neg_sample_img_ids] +
+ [(neg_txt_id, gt_img_fname)
+ for neg_txt_id in neg_sample_txt_ids])
+ inputs = self._collect_inputs(id_pairs)
+ assert len(inputs) == (1 + 2*self.neg_sample_size)
+ return inputs
+
+ def _collect_inputs(self, id_pairs):
+ # create input features
+ inputs = []
+ for txt_id, img_id in id_pairs:
+ example = self.txt_db[txt_id]
+ # text input
+ input_ids = example['input_ids']
+ input_ids = self.txt_db.combine_inputs(input_ids)
+ # img input
+ img_feat, img_pos_feat, num_bb = self._get_img_feat(img_id)
+ # mask
+ attn_masks = torch.ones(len(input_ids) + num_bb, dtype=torch.long)
+
+ inputs.append((input_ids, img_feat, img_pos_feat, attn_masks))
+
+ return inputs
+
+
+def itm_rank_collate(inputs):
+ (input_ids, img_feats, img_pos_feats, attn_masks,
+ ) = map(list, unzip(concat(i for i in inputs)))
+
+ txt_lens = [i.size(0) for i in input_ids]
+ input_ids = pad_sequence(input_ids, batch_first=True, padding_value=0)
+ position_ids = torch.arange(0, input_ids.size(1), dtype=torch.long
+ ).unsqueeze(0)
+
+ num_bbs = [f.size(0) for f in img_feats]
+ img_feat = pad_tensors(img_feats, num_bbs)
+ img_pos_feat = pad_tensors(img_pos_feats, num_bbs)
+
+ attn_masks = pad_sequence(attn_masks, batch_first=True, padding_value=0)
+ sample_size = len(inputs[0])
+ assert all(sample_size == len(i) for i in inputs)
+
+ bs, max_tl = input_ids.size()
+ out_size = attn_masks.size(1)
+ gather_index = get_gather_index(txt_lens, num_bbs, bs, max_tl, out_size)
+
+ batch = {'input_ids': input_ids,
+ 'position_ids': position_ids,
+ 'img_feat': img_feat,
+ 'img_pos_feat': img_pos_feat,
+ 'attn_masks': attn_masks,
+ 'gather_index': gather_index,
+ 'sample_size': sample_size}
+ return batch
+
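+# Note on itm_rank_collate above: each dataset item is itself a list of
+# (1 + 2 * neg_sample_size) input tuples (the ground-truth pair, negatives
+# built from sampled images, negatives built from sampled texts); concat()
+# flattens these before padding, and `sample_size` lets the model regroup
+# consecutive rows back into per-example ranking groups.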
+
+class PNSGDFromText(DetectFeatTxtTokDataset):
+ def __init__(self, txt_db, img_db, neg_sample_size=1, mlm_sample_size=1):
+ assert neg_sample_size > 0, "need at least 1 negative sample"
+ super().__init__(txt_db, img_db)
+
+ txt2img = self.txt_db.txt2img
+ self.txt2img = {id_: txt2img[id_] for id_ in self.ids}
+ self.img2txts = self.txt_db.img2txts
+ self.img_name_list = list(self.img2txts.keys())
+ self.neg_sample_size = neg_sample_size
+ self.mlm_sample_size = mlm_sample_size
+
+ def create_mlm_io(self, input_ids, tree=None, nsample=1):
+ mlm_input_ids, mlm_txt_labels = [], []
+ sample_prob = 0.15
+
+ mlm_positions = []
+ for struct_type in ['relation', 'attribute', 'node']:
+ struct_nodes = tree.get(struct_type)
+ for struct_node in struct_nodes:
+ positions = struct_node.get('ids')
+ if positions is not None:
+ mlm_positions.append(positions)
+ if len(mlm_positions) < 1:
+ mlm_positions = [[i] for i in range(len(input_ids))]
+
+ # mlm_positions = list(set(mlm_positions))
+ position_counter = Counter()
+ for positions in mlm_positions:
+ position_counter.update(positions)
+ position_to_prob = {position: sample_prob / max(freq, 1.0) for position, freq in position_counter.items()}
+
+ # print("| mlm_positions: ", mlm_positions)
+ for i in range(nsample):
+ # tokens, vocab_range, mask, mlm_positions, position_to_prob
+ r_input_ids, txt_labels = random_word2(
+ copy.copy(input_ids), self.txt_db.v_range, self.txt_db.mask,
+ mlm_positions=mlm_positions, position_to_prob=position_to_prob)
+ mlm_input_ids.append(torch.tensor([self.txt_db.cls_] + r_input_ids + [self.txt_db.sep]))
+ mlm_txt_labels.append(torch.tensor([-1] + txt_labels + [-1]))
+ mlm_input_ids = torch.stack(mlm_input_ids, dim=0)
+ mlm_txt_labels = torch.stack(mlm_txt_labels, dim=0)
+ return mlm_input_ids, mlm_txt_labels
+
+ def __getitem__(self, i):
+ gt_txt_id = self.ids[i]
+ gt_img_fname = self.txt2img[gt_txt_id]
+
+ txt_ids = self.txt_db.img2txts[gt_img_fname]
+ tot_input_ids = []
+ for txt_id in txt_ids:
+ tot_input_ids.extend(self.txt_db[txt_id]['input_ids'])
+ tot_input_ids = torch.tensor(tot_input_ids)
+
+ input_ids = self.txt_db[gt_txt_id]['input_ids']
+ tree = self.txt_db[gt_txt_id]['tree']
+
+ mlm_input_ids, mlm_txt_labels = self.create_mlm_io(input_ids, tree=tree, nsample=self.mlm_sample_size)
+ mlm_position_ids = torch.arange(0, mlm_input_ids.size(1), dtype=torch.long).\
+ unsqueeze(0).expand(self.mlm_sample_size, -1)
+ img_feat, img_pos_feat, num_bbs = self._get_img_feat(gt_img_fname)
+ mlm_img_feat = img_feat.unsqueeze(dim=0).expand(self.mlm_sample_size, *list(img_feat.size()))
+ mlm_img_pos_feat = img_pos_feat.unsqueeze(dim=0).expand(self.mlm_sample_size, *list(img_pos_feat.size()))
+ tl = mlm_input_ids.size(1)
+ mlm_attn_masks = torch.zeros(self.mlm_sample_size, tl+num_bbs).long()
+ mlm_attn_masks.data[:, :tl+num_bbs].fill_(1)
+ mlm_gather_index = get_gather_index(
+ [tl]*self.mlm_sample_size, [num_bbs]*self.mlm_sample_size, self.mlm_sample_size, tl, tl+num_bbs)
+
+ # Process Text for Image and Hard Text Matching
+ input_ids = self.txt_db.combine_inputs(input_ids)
+ input_ids = input_ids.unsqueeze(0)
+ position_ids = torch.arange(0, input_ids.size(1), dtype=torch.long).unsqueeze(0)
+
+ neg_img_ids = sample_negative(
+ self.img_name_list, [gt_img_fname], self.neg_sample_size)
+ img_ids = [gt_img_fname] + neg_img_ids
+
+ # Process image features (gt always first)
+ img_feats, img_pos_feats, num_bbs = map(
+ list, unzip(map(self._get_img_feat, img_ids)))
+ img_feat = pad_tensors(img_feats, num_bbs)
+ img_pos_feat = pad_tensors(img_pos_feats, num_bbs)
+
+ tl = input_ids.size(1)
+ attn_masks = torch.zeros(len(img_ids), max(num_bbs) + tl).long()
+ for i, nbb in enumerate(num_bbs):
+ attn_masks.data[i, :tl+nbb].fill_(1)
+ out_size = attn_masks.size(1)
+ gather_index = get_gather_index([tl]*len(img_ids), num_bbs,
+ len(img_ids), tl, out_size)
+
+ batch = {
+ 'input_ids': input_ids,
+ 'position_ids': position_ids,
+ 'img_feat': img_feat,
+ 'img_pos_feat': img_pos_feat,
+ 'attn_masks': attn_masks,
+ 'gather_index': gather_index,
+ 'mlm_input_ids': mlm_input_ids,
+ 'mlm_position_ids': mlm_position_ids,
+ 'mlm_img_feat': mlm_img_feat,
+ 'mlm_img_pos_feat': mlm_img_pos_feat,
+ 'mlm_attn_masks': mlm_attn_masks,
+ 'mlm_gather_index': mlm_gather_index,
+ 'mlm_txt_labels': mlm_txt_labels,
+ 'input_dict': tot_input_ids
+ }
+ return batch
+
+
+class PNSGDFromImage(DetectFeatTxtTokDataset):
+ def __init__(self, txt_db, img_db, neg_sample_size=1):
+ assert neg_sample_size > 0, "need at least 1 negative sample"
+ super().__init__(txt_db, img_db)
+
+ txt2img = self.txt_db.txt2img
+ self.txt2img = {id_: txt2img[id_] for id_ in self.ids}
+ self.img2txts = self.txt_db.img2txts
+ self.txt_name_list = list(self.txt2img.keys())
+ self.neg_sample_size = neg_sample_size
+
+ def __getitem__(self, i):
+ gt_txt_id = self.ids[i]
+ gt_img_id = self.txt2img[gt_txt_id]
+ gt_txt_ids = self.img2txts[gt_img_id]
+
+ # process image features (gt always first)
+ img_feat, img_pos_feat, nbb = self._get_img_feat(gt_img_id)
+ img_feat = img_feat.unsqueeze(0)
+ img_pos_feat = img_pos_feat.unsqueeze(0)
+
+ # sample negative
+ neg_txt_ids = sample_negative(
+ self.txt_name_list, gt_txt_ids, self.neg_sample_size)
+ txt_ids = [gt_txt_id] + neg_txt_ids
+
+ # process text inputs
+ all_inputs = []
+ txt_lens = []
+ for txt_id in txt_ids:
+ input_ids = self.txt_db.combine_inputs(
+ self.txt_db[txt_id]['input_ids'])
+ all_inputs.append(input_ids)
+ txt_lens.append(len(input_ids))
+ input_ids = pad_sequence(all_inputs, batch_first=True, padding_value=0)
+ position_ids = torch.arange(0, input_ids.size(1), dtype=torch.long
+ ).unsqueeze(0)
+
+ attn_masks = torch.zeros(len(txt_ids), max(txt_lens) + nbb).long()
+ for i, tl in enumerate(txt_lens):
+ attn_masks.data[i, :tl+nbb].fill_(1)
+ out_size = attn_masks.size(1)
+ gather_index = get_gather_index(txt_lens, [nbb]*len(txt_ids),
+ len(txt_ids), tl, out_size)
+
+ batch = {'input_ids': input_ids,
+ 'position_ids': position_ids,
+ 'img_feat': img_feat,
+ 'img_pos_feat': img_pos_feat,
+ 'attn_masks': attn_masks,
+ 'gather_index': gather_index}
+ return batch
+
+
+def pnsgd_collate(inputs):
+ assert len(inputs) == 1
+ return inputs[0]
+
+
+class ItmValDataset(DetectFeatTxtTokDataset):
+ """ For evaluating Image-Text-Retrieval task """
+ def __init__(self, db_dir, img_dir, mini_batch_size=400):
+ super().__init__(db_dir, img_dir)
+ del self.lens
+ self.txt2img = self.txt_db.txt2img
+ self.img2txts = self.txt_db.img2txts
+ self.all_img_ids = list(self.img2txts.keys())
+
+ assert len(self.img2txts) >= mini_batch_size > 0
+ self.bs = mini_batch_size
+
+ def _get_batch_ids(self, i):
+ gt_txt_id = self.ids[i]
+ gt_img_id = self.txt2img[gt_txt_id]
+
+ # sample fixed negatives for each gt image
+ i = self.all_img_ids.index(gt_img_id)
+ neg_st = i+1
+ neg_end = neg_st+self.bs-1
+ if neg_end > len(self.all_img_ids):
+            # wrap around
+ neg_end -= len(self.all_img_ids)
+ neg_img_ids = (self.all_img_ids[neg_st:]
+ + self.all_img_ids[:neg_end])
+ else:
+ neg_img_ids = self.all_img_ids[neg_st:neg_end]
+
+ assert len(neg_img_ids) == (self.bs - 1),\
+ "Did not sample enough neg samples"
+
+ return gt_img_id, neg_img_ids
+
+ def __getitem__(self, i):
+ """ this returns list of mini-batches """
+ gt_img_id, neg_img_ids = self._get_batch_ids(i)
+ # NOTE 1st one is gt img
+ batch = self.get_batch(i, [gt_img_id] + neg_img_ids)
+ return batch
+
+ def get_batch(self, i, img_ids):
+ example = super().__getitem__(i)
+
+ input_ids = example['input_ids']
+ input_ids = self.txt_db.combine_inputs(input_ids)
+ input_ids = input_ids.unsqueeze(0).expand(len(img_ids), -1).clone()
+ position_ids = torch.arange(0, input_ids.size(1), dtype=torch.long
+ ).unsqueeze(0)
+
+ # process image features (gt always first)
+ img_feats, img_pos_feats, num_bbs = map(
+ list, unzip(map(self._get_img_feat, img_ids)))
+ img_feat = pad_tensors(img_feats, num_bbs)
+ img_pos_feat = pad_tensors(img_pos_feats, num_bbs)
+
+ tl = input_ids.size(1)
+ attn_masks = torch.zeros(len(img_ids), max(num_bbs) + tl).long()
+ for i, nbb in enumerate(num_bbs):
+ attn_masks.data[i, :tl+nbb].fill_(1)
+ out_size = attn_masks.size(1)
+ gather_index = get_gather_index([tl]*len(img_ids), num_bbs,
+ len(img_ids), tl, out_size)
+
+ batch = {'input_ids': input_ids,
+ 'position_ids': position_ids,
+ 'img_feat': img_feat,
+ 'img_pos_feat': img_pos_feat,
+ 'attn_masks': attn_masks,
+ 'gather_index': gather_index}
+ return batch
+
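+# Evaluation protocol of ItmValDataset above: for caption i the mini-batch
+# contains its ground-truth image first, followed by (mini_batch_size - 1)
+# consecutive images from all_img_ids (wrapping around at the end of the
+# list), so every caption is scored against a fixed-size candidate pool.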
+
+def itm_val_collate(inputs):
+ assert len(inputs) == 1, "input batch size > 1"
+ return inputs[0]
+
+
+class ItmEvalDataset(ItmValDataset):
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ self.all_img_ids = sorted(copy.deepcopy(self.all_img_ids),
+ key=lambda i: self.img_db.name2nbb[i])
+
+ def __getitem__(self, i):
+ mini_batches = []
+ for st in range(0, len(self.all_img_ids), self.bs):
+ mini_batches.append(
+ self.get_batch(i, self.all_img_ids[st:st+self.bs]))
+ return mini_batches
+
+
+itm_eval_collate = itm_val_collate
diff --git a/data/sampler.py b/data/sampler.py
new file mode 100644
index 0000000..22863ce
--- /dev/null
+++ b/data/sampler.py
@@ -0,0 +1,58 @@
+"""
+Copyright (c) Microsoft Corporation.
+Licensed under the MIT license.
+
+sampler for length bucketing (batch by tokens)
+"""
+import random
+
+from torch.utils.data import Sampler
+from cytoolz import partition_all
+
+
+class TokenBucketSampler(Sampler):
+ def __init__(self, lens, bucket_size, batch_size,
+ droplast=False, size_multiple=8):
+ self._lens = lens
+ self._max_tok = batch_size
+ self._bucket_size = bucket_size
+ self._droplast = droplast
+ self._size_mul = size_multiple
+
+ def _create_ids(self):
+ return list(range(len(self._lens)))
+
+ def _sort_fn(self, i):
+ return self._lens[i]
+
+ def __iter__(self):
+ ids = self._create_ids()
+ random.shuffle(ids)
+ buckets = [sorted(ids[i:i+self._bucket_size],
+ key=self._sort_fn, reverse=True)
+ for i in range(0, len(ids), self._bucket_size)]
+ # fill batches until max_token (include padding)
+ batches = []
+ for bucket in buckets:
+ max_len = 0
+ batch_indices = []
+ for indices in partition_all(self._size_mul, bucket):
+ max_len = max(max_len, max(self._lens[i] for i in indices))
+ if (max_len * (len(batch_indices) + self._size_mul)
+ > self._max_tok):
+ if not batch_indices:
+ raise ValueError(
+ "max_tokens too small / max_seq_len too long")
+ assert len(batch_indices) % self._size_mul == 0
+ batches.append(batch_indices)
+ batch_indices = list(indices)
+ else:
+ batch_indices.extend(indices)
+ if not self._droplast and batch_indices:
+ batches.append(batch_indices)
+ random.shuffle(batches)
+ return iter(batches)
+
+ def __len__(self):
+ raise ValueError("NOT supported. "
+ "This has some randomness across epochs")
diff --git a/inf_nsgd.py b/inf_nsgd.py
new file mode 100644
index 0000000..03ad27a
--- /dev/null
+++ b/inf_nsgd.py
@@ -0,0 +1,177 @@
+"""
+Copyright (c) Microsoft Corporation.
+Licensed under the MIT license.
+
+run inference for Image Text Retrieval
+"""
+import argparse
+import json
+import os
+from os.path import exists
+import pickle
+from time import time
+
+import torch
+from torch.utils.data import DataLoader
+
+from apex import amp
+from horovod import torch as hvd
+
+from data import (PrefetchLoader,
+ DetectFeatLmdb, TxtTokLmdb, ItmEvalDataset, itm_eval_collate)
+from model.nsgd import UniterForNSGD
+
+from utils.logger import LOGGER
+from utils.distributed import all_gather_list
+from utils.misc import Struct
+from utils.const import IMG_DIM
+from utils.itm_eval import inference, itm_eval
+
+
+def main(opts):
+ hvd.init()
+ n_gpu = hvd.size()
+ device = torch.device("cuda", hvd.local_rank())
+ torch.cuda.set_device(hvd.local_rank())
+ rank = hvd.rank()
+ LOGGER.info("device: {} n_gpu: {}, rank: {}, "
+ "16-bits training: {}".format(
+ device, n_gpu, hvd.rank(), opts.fp16))
+
+ if opts.train_config is not None:
+ train_opts = Struct(json.load(open(opts.train_config)))
+ opts.conf_th = train_opts.conf_th
+ opts.max_bb = train_opts.max_bb
+ opts.min_bb = train_opts.min_bb
+ opts.num_bb = train_opts.num_bb
+
+ # load DBs and image dirs
+ eval_img_db = DetectFeatLmdb(opts.img_db,
+ opts.conf_th, opts.max_bb,
+ opts.min_bb, opts.num_bb,
+ opts.compressed_db)
+ eval_txt_db = TxtTokLmdb(opts.txt_db, -1)
+ eval_dataset = ItmEvalDataset(eval_txt_db, eval_img_db, opts.batch_size)
+
+ # Prepare model
+ checkpoint = torch.load(opts.checkpoint)
+ model = UniterForNSGD.from_pretrained(
+ opts.model_config, checkpoint, img_dim=IMG_DIM)
+ if 'rank_output' not in checkpoint:
+ model.init_output() # zero shot setting
+
+ model.to(device)
+ model = amp.initialize(model, enabled=opts.fp16, opt_level='O2')
+
+ print('| eval_dataset id2len: ', len(eval_dataset.id2len))
+ print('| eval_dataset_id2len example key: ', list(eval_dataset.id2len.keys())[:5],
+ max([int(k) for k in list(eval_dataset.id2len.keys())]))
+ print('| eval_dataset_id2len example value: ', list(eval_dataset.id2len.values())[:5])
+ # for k in range(10):
+ # print('| example:', k, eval_dataset.i2len[k])
+ print('| i2len: ', len(eval_dataset.id2len), min(list(eval_dataset.id2len.keys())), max(eval_dataset.id2len.keys()),
+ min(list(eval_dataset.id2len.values())), max(eval_dataset.id2len.values()))
+
+ print('| mean of all_txt_lens:', sum(list(eval_dataset.id2len.values())) / float(len(list(eval_dataset.id2len.values()))))
+
+
+ eval_dataloader = DataLoader(eval_dataset, batch_size=1,
+ num_workers=opts.n_workers,
+ pin_memory=opts.pin_mem,
+ collate_fn=itm_eval_collate)
+ eval_dataloader = PrefetchLoader(eval_dataloader)
+
+ eval_log, results = evaluate(model, eval_dataloader)
+ if hvd.rank() == 0:
+ if not exists(opts.output_dir) and rank == 0:
+ os.makedirs(opts.output_dir)
+ with open(f'{opts.output_dir}/config.json', 'w') as f:
+ json.dump(vars(opts), f)
+ with open(f'{opts.output_dir}/results.bin', 'wb') as f:
+ pickle.dump(results, f)
+ with open(f'{opts.output_dir}/scores.json', 'w') as f:
+ json.dump(eval_log, f)
+ LOGGER.info(f'evaluation finished')
+ LOGGER.info(
+ f"======================== Results =========================\n"
+ f"image retrieval R1: {eval_log['img_r1']*100:.2f},\n"
+ f"image retrieval R5: {eval_log['img_r5']*100:.2f},\n"
+ f"image retrieval R10: {eval_log['img_r10']*100:.2f}\n"
+ f"text retrieval R1: {eval_log['txt_r1']*100:.2f},\n"
+ f"text retrieval R5: {eval_log['txt_r5']*100:.2f},\n"
+ f"text retrieval R10: {eval_log['txt_r10']*100:.2f}")
+ LOGGER.info("========================================================")
+
+
+@torch.no_grad()
+def evaluate(model, eval_loader):
+ model.eval()
+ st = time()
+ LOGGER.info("start running Image/Text Retrieval evaluation ...")
+ score_matrix = inference(model, eval_loader)
+ dset = eval_loader.dataset
+ all_score = hvd.allgather(score_matrix)
+ all_txt_ids = [i for ids in all_gather_list(dset.ids)
+ for i in ids]
+ # all_txt_lens = [l for lens in all_gather_list(dset.txt_lens) for l in lens]
+ print('| mean of all_txt_lens:', sum(list(dset.id2len.values())) / float(len(list(dset.id2len.values()))))
+ all_img_ids = dset.all_img_ids
+ assert all_score.size() == (len(all_txt_ids), len(all_img_ids))
+ if hvd.rank() != 0:
+ return {}, tuple()
+ # NOTE: only use rank0 to compute final scores
+ eval_log = itm_eval(all_score, all_txt_ids, all_img_ids,
+ dset.txt2img, dset.img2txts, dset.id2len)
+
+ results = (all_score, all_txt_ids, all_img_ids)
+ tot_time = time()-st
+ LOGGER.info(f"evaluation finished in {int(tot_time)} seconds, ")
+ return eval_log, results
+
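+# Example invocation (all paths below are placeholders; multi-GPU runs would
+# typically be launched through horovodrun/mpirun):
+#   python inf_nsgd.py --txt_db /txt/itm_test.db --img_db /img/feat_test \
+#       --checkpoint /storage/model.pt --model_config /src/config/uniter-base.json \
+#       --output_dir /storage/eval --train_config /storage/hps.json --fp16 --pin_mem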
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+
+ # Required parameters
+ parser.add_argument("--txt_db", default=None, type=str,
+ help="The input train corpus. (LMDB)")
+ parser.add_argument("--img_db", default=None, type=str,
+ help="The input train images.")
+ parser.add_argument("--checkpoint", default=None, type=str,
+ help="model checkpoint binary")
+ parser.add_argument("--model_config", default=None, type=str,
+ help="model config json")
+ parser.add_argument(
+ "--output_dir", default=None, type=str,
+ help="The output directory where the inference results will be "
+ "written.")
+
+ # optional parameters
+ parser.add_argument("--train_config", default=None, type=str,
+ help="hps.json from training (for prepro hps)")
+ parser.add_argument('--compressed_db', action='store_true',
+ help='use compressed LMDB')
+ parser.add_argument('--conf_th', type=float, default=0.2,
+ help='threshold for dynamic bounding boxes '
+ '(-1 for fixed)')
+ parser.add_argument('--max_bb', type=int, default=100,
+ help='max number of bounding boxes')
+ parser.add_argument('--min_bb', type=int, default=10,
+ help='min number of bounding boxes')
+ parser.add_argument('--num_bb', type=int, default=36,
+ help='static number of bounding boxes')
+ parser.add_argument("--batch_size", default=400, type=int,
+ help="number of tokens in a batch")
+
+ # device parameters
+ parser.add_argument('--fp16', action='store_true',
+ help="Whether to use 16-bit float precision instead "
+ "of 32-bit")
+ parser.add_argument('--n_workers', type=int, default=4,
+ help="number of data workers")
+ parser.add_argument('--pin_mem', action='store_true',
+ help="pin memory")
+
+ args = parser.parse_args()
+
+ main(args)
diff --git a/launch_container.sh b/launch_container.sh
new file mode 100644
index 0000000..45ceefb
--- /dev/null
+++ b/launch_container.sh
@@ -0,0 +1,21 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT license.
+
+TXT_DB=$1
+IMG_DIR=$2
+OUTPUT=$3
+PRETRAIN_DIR=$4
+
+if [ -z "$CUDA_VISIBLE_DEVICES" ]; then
+ CUDA_VISIBLE_DEVICES='all'
+fi
+
+
+docker run --gpus '"'device=$CUDA_VISIBLE_DEVICES'"' --ipc=host --rm -it \
+ --mount src=$(pwd),dst=/src,type=bind \
+ --mount src=$OUTPUT,dst=/storage,type=bind \
+ --mount src=$PRETRAIN_DIR,dst=/pretrain,type=bind,readonly \
+ --mount src=$TXT_DB,dst=/txt,type=bind,readonly \
+ --mount src=$IMG_DIR,dst=/img,type=bind,readonly \
+ -e NVIDIA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES \
+ -w /src chenrocks/uniter
\ No newline at end of file
diff --git a/model/attention.py b/model/attention.py
new file mode 100644
index 0000000..7d7e2b0
--- /dev/null
+++ b/model/attention.py
@@ -0,0 +1,402 @@
+"""
+copy multi-head attention code from pytorch
+(https://github.com/pytorch/pytorch),
+"""
+import warnings
+
+import torch
+from torch.nn import Module, Parameter, Linear
+from torch.nn.init import xavier_normal_, xavier_uniform_, constant_
+from torch.nn.functional import linear, softmax, dropout
+
+
+def multi_head_attention_forward(query, # type: Tensor
+ key, # type: Tensor
+ value, # type: Tensor
+ embed_dim_to_check, # type: int
+ num_heads, # type: int
+ in_proj_weight, # type: Tensor
+ in_proj_bias, # type: Tensor
+ bias_k, # type: Optional[Tensor]
+ bias_v, # type: Optional[Tensor]
+ add_zero_attn, # type: bool
+ dropout_p, # type: float
+ out_proj_weight, # type: Tensor
+ out_proj_bias, # type: Tensor
+ training=True, # type: bool
+ key_padding_mask=None, # type: Optional[Tensor]
+ need_weights=True, # type: bool
+ attn_mask=None, # type: Optional[Tensor]
+ use_separate_proj_weight=False, # type: bool
+ q_proj_weight=None, # type: Optional[Tensor]
+ k_proj_weight=None, # type: Optional[Tensor]
+ v_proj_weight=None, # type: Optional[Tensor]
+ static_k=None, # type: Optional[Tensor]
+ static_v=None # type: Optional[Tensor]
+ ):
+ # type: (...) -> Tuple[Tensor, Optional[Tensor]]
+ r"""
+ Args:
+ query, key, value: map a query and a set of key-value pairs to an output.
+ See "Attention Is All You Need" for more details.
+ embed_dim_to_check: total dimension of the model.
+ num_heads: parallel attention heads.
+ in_proj_weight, in_proj_bias: input projection weight and bias.
+ bias_k, bias_v: bias of the key and value sequences to be added at dim=0.
+ add_zero_attn: add a new batch of zeros to the key and
+ value sequences at dim=1.
+ dropout_p: probability of an element to be zeroed.
+ out_proj_weight, out_proj_bias: the output projection weight and bias.
+ training: apply dropout if is ``True``.
+ key_padding_mask: if provided, specified padding elements in the key will
+            be ignored by the attention. This is a binary mask. When the value is True,
+ the corresponding value on the attention layer will be filled with -inf.
+ need_weights: output attn_output_weights.
+ attn_mask: mask that prevents attention to certain positions. This is an additive mask
+ (i.e. the values will be added to the attention layer).
+        use_separate_proj_weight: the function accepts the projection weights for query, key,
+            and value in different forms. If false, in_proj_weight will be used, which is
+ a combination of q_proj_weight, k_proj_weight, v_proj_weight.
+ q_proj_weight, k_proj_weight, v_proj_weight, in_proj_bias: input projection weight and bias.
+ static_k, static_v: static key and value used for attention operators.
+
+
+ Shape:
+ Inputs:
+ - query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is
+ the embedding dimension.
+ - key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is
+ the embedding dimension.
+ - value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is
+ the embedding dimension.
+ - key_padding_mask: :math:`(N, S)`, ByteTensor, where N is the batch size, S is the source sequence length.
+ - attn_mask: :math:`(L, S)` where L is the target sequence length, S is the source sequence length.
+ - static_k: :math:`(N*num_heads, S, E/num_heads)`, where S is the source sequence length,
+ N is the batch size, E is the embedding dimension. E/num_heads is the head dimension.
+ - static_v: :math:`(N*num_heads, S, E/num_heads)`, where S is the source sequence length,
+ N is the batch size, E is the embedding dimension. E/num_heads is the head dimension.
+
+ Outputs:
+ - attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size,
+ E is the embedding dimension.
+ - attn_output_weights: :math:`(N, L, S)` where N is the batch size,
+ L is the target sequence length, S is the source sequence length.
+ """
+
+ qkv_same = torch.equal(query, key) and torch.equal(key, value)
+ kv_same = torch.equal(key, value)
+
+ tgt_len, bsz, embed_dim = query.size()
+ assert embed_dim == embed_dim_to_check
+ assert list(query.size()) == [tgt_len, bsz, embed_dim]
+ assert key.size() == value.size()
+
+ head_dim = embed_dim // num_heads
+ assert head_dim * num_heads == embed_dim, "embed_dim must be divisible by num_heads"
+ scaling = float(head_dim) ** -0.5
+
+ if use_separate_proj_weight is not True:
+ if qkv_same:
+ # self-attention
+ q, k, v = linear(query, in_proj_weight, in_proj_bias).chunk(3, dim=-1)
+
+ elif kv_same:
+ # encoder-decoder attention
+ # This is inline in_proj function with in_proj_weight and in_proj_bias
+ _b = in_proj_bias
+ _start = 0
+ _end = embed_dim
+ _w = in_proj_weight[_start:_end, :]
+ if _b is not None:
+ _b = _b[_start:_end]
+ q = linear(query, _w, _b)
+
+ if key is None:
+ assert value is None
+ k = None
+ v = None
+ else:
+
+ # This is inline in_proj function with in_proj_weight and in_proj_bias
+ _b = in_proj_bias
+ _start = embed_dim
+ _end = None
+ _w = in_proj_weight[_start:, :]
+ if _b is not None:
+ _b = _b[_start:]
+ k, v = linear(key, _w, _b).chunk(2, dim=-1)
+
+ else:
+ # This is inline in_proj function with in_proj_weight and in_proj_bias
+ _b = in_proj_bias
+ _start = 0
+ _end = embed_dim
+ _w = in_proj_weight[_start:_end, :]
+ if _b is not None:
+ _b = _b[_start:_end]
+ q = linear(query, _w, _b)
+
+ # This is inline in_proj function with in_proj_weight and in_proj_bias
+ _b = in_proj_bias
+ _start = embed_dim
+ _end = embed_dim * 2
+ _w = in_proj_weight[_start:_end, :]
+ if _b is not None:
+ _b = _b[_start:_end]
+ k = linear(key, _w, _b)
+
+ # This is inline in_proj function with in_proj_weight and in_proj_bias
+ _b = in_proj_bias
+ _start = embed_dim * 2
+ _end = None
+ _w = in_proj_weight[_start:, :]
+ if _b is not None:
+ _b = _b[_start:]
+ v = linear(value, _w, _b)
+ else:
+ q_proj_weight_non_opt = torch.jit._unwrap_optional(q_proj_weight)
+ len1, len2 = q_proj_weight_non_opt.size()
+ assert len1 == embed_dim and len2 == query.size(-1)
+
+ k_proj_weight_non_opt = torch.jit._unwrap_optional(k_proj_weight)
+ len1, len2 = k_proj_weight_non_opt.size()
+ assert len1 == embed_dim and len2 == key.size(-1)
+
+ v_proj_weight_non_opt = torch.jit._unwrap_optional(v_proj_weight)
+ len1, len2 = v_proj_weight_non_opt.size()
+ assert len1 == embed_dim and len2 == value.size(-1)
+
+ if in_proj_bias is not None:
+ q = linear(query, q_proj_weight_non_opt, in_proj_bias[0:embed_dim])
+ k = linear(key, k_proj_weight_non_opt, in_proj_bias[embed_dim:(embed_dim * 2)])
+ v = linear(value, v_proj_weight_non_opt, in_proj_bias[(embed_dim * 2):])
+ else:
+ q = linear(query, q_proj_weight_non_opt, in_proj_bias)
+ k = linear(key, k_proj_weight_non_opt, in_proj_bias)
+ v = linear(value, v_proj_weight_non_opt, in_proj_bias)
+ q = q * scaling
+
+ if bias_k is not None and bias_v is not None:
+ if static_k is None and static_v is None:
+ k = torch.cat([k, bias_k.repeat(1, bsz, 1)])
+ v = torch.cat([v, bias_v.repeat(1, bsz, 1)])
+ if attn_mask is not None:
+ attn_mask = torch.cat([attn_mask,
+ torch.zeros((attn_mask.size(0), 1),
+ dtype=attn_mask.dtype,
+ device=attn_mask.device)], dim=1)
+ if key_padding_mask is not None:
+ key_padding_mask = torch.cat(
+ [key_padding_mask, torch.zeros((key_padding_mask.size(0), 1),
+ dtype=key_padding_mask.dtype,
+ device=key_padding_mask.device)], dim=1)
+ else:
+ assert static_k is None, "bias cannot be added to static key."
+ assert static_v is None, "bias cannot be added to static value."
+ else:
+ assert bias_k is None
+ assert bias_v is None
+
+ q = q.contiguous().view(tgt_len, bsz * num_heads, head_dim).transpose(0, 1)
+ if k is not None:
+ k = k.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1)
+ if v is not None:
+ v = v.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1)
+
+ if static_k is not None:
+ assert static_k.size(0) == bsz * num_heads
+ assert static_k.size(2) == head_dim
+ k = static_k
+
+ if static_v is not None:
+ assert static_v.size(0) == bsz * num_heads
+ assert static_v.size(2) == head_dim
+ v = static_v
+
+ src_len = k.size(1)
+
+ if key_padding_mask is not None:
+ assert key_padding_mask.size(0) == bsz
+ assert key_padding_mask.size(1) == src_len
+
+ if add_zero_attn:
+ src_len += 1
+ k = torch.cat([k, torch.zeros((k.size(0), 1) + k.size()[2:], dtype=k.dtype, device=k.device)], dim=1)
+ v = torch.cat([v, torch.zeros((v.size(0), 1) + v.size()[2:], dtype=v.dtype, device=v.device)], dim=1)
+ if attn_mask is not None:
+ attn_mask = torch.cat([attn_mask, torch.zeros((attn_mask.size(0), 1),
+ dtype=attn_mask.dtype,
+ device=attn_mask.device)], dim=1)
+ if key_padding_mask is not None:
+ key_padding_mask = torch.cat(
+ [key_padding_mask, torch.zeros((key_padding_mask.size(0), 1),
+ dtype=key_padding_mask.dtype,
+ device=key_padding_mask.device)], dim=1)
+
+ attn_output_weights = torch.bmm(q, k.transpose(1, 2))
+ assert list(attn_output_weights.size()) == [bsz * num_heads, tgt_len, src_len]
+
+ if attn_mask is not None:
+ attn_mask = attn_mask.unsqueeze(0)
+ attn_output_weights += attn_mask
+
+ if key_padding_mask is not None:
+ attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len)
+ attn_output_weights = attn_output_weights.masked_fill(
+ key_padding_mask.unsqueeze(1).unsqueeze(2),
+ float('-inf'),
+ )
+ attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, src_len)
+
+ attn_output_weights = softmax(
+ attn_output_weights, dim=-1)
+ attn_output_weights = dropout(attn_output_weights, p=dropout_p, training=training)
+
+ attn_output = torch.bmm(attn_output_weights, v)
+ assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim]
+ attn_output = attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
+ attn_output = linear(attn_output, out_proj_weight, out_proj_bias)
+
+ if need_weights:
+ # average attention weights over heads
+ attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len)
+ return attn_output, attn_output_weights.sum(dim=1) / num_heads
+ else:
+ return attn_output, None
+
+
+class MultiheadAttention(Module):
+ r"""Allows the model to jointly attend to information
+ from different representation subspaces.
+ See reference: Attention Is All You Need
+
+ .. math::
+ \text{MultiHead}(Q, K, V) = \text{Concat}(head_1,\dots,head_h)W^O
+ \text{where} head_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)
+
+ Args:
+ embed_dim: total dimension of the model.
+ num_heads: parallel attention heads.
+ dropout: a Dropout layer on attn_output_weights. Default: 0.0.
+ bias: add bias as module parameter. Default: True.
+ add_bias_kv: add bias to the key and value sequences at dim=0.
+ add_zero_attn: add a new batch of zeros to the key and
+ value sequences at dim=1.
+ kdim: total number of features in key. Default: None.
+        vdim: total number of features in value. Default: None.
+
+ Note: if kdim and vdim are None, they will be set to embed_dim such that
+ query, key, and value have the same number of features.
+
+ Examples::
+
+ >>> multihead_attn = nn.MultiheadAttention(embed_dim, num_heads)
+ >>> attn_output, attn_output_weights = multihead_attn(query, key, value)
+ """
+
+ def __init__(self, embed_dim, num_heads, dropout=0., bias=True, add_bias_kv=False, add_zero_attn=False, kdim=None, vdim=None):
+ super(MultiheadAttention, self).__init__()
+ self.embed_dim = embed_dim
+ self.kdim = kdim if kdim is not None else embed_dim
+ self.vdim = vdim if vdim is not None else embed_dim
+ self._qkv_same_embed_dim = self.kdim == embed_dim and self.vdim == embed_dim
+
+ self.num_heads = num_heads
+ self.dropout = dropout
+ self.head_dim = embed_dim // num_heads
+ assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
+
+ self.in_proj_weight = Parameter(torch.empty(3 * embed_dim, embed_dim))
+
+ if self._qkv_same_embed_dim is False:
+ self.q_proj_weight = Parameter(torch.Tensor(embed_dim, embed_dim))
+ self.k_proj_weight = Parameter(torch.Tensor(embed_dim, self.kdim))
+ self.v_proj_weight = Parameter(torch.Tensor(embed_dim, self.vdim))
+
+ if bias:
+ self.in_proj_bias = Parameter(torch.empty(3 * embed_dim))
+ else:
+ self.register_parameter('in_proj_bias', None)
+ self.out_proj = Linear(embed_dim, embed_dim, bias=bias)
+
+ if add_bias_kv:
+ self.bias_k = Parameter(torch.empty(1, 1, embed_dim))
+ self.bias_v = Parameter(torch.empty(1, 1, embed_dim))
+ else:
+ self.bias_k = self.bias_v = None
+
+ self.add_zero_attn = add_zero_attn
+
+ self._reset_parameters()
+
+ def _reset_parameters(self):
+ if self._qkv_same_embed_dim:
+ xavier_uniform_(self.in_proj_weight)
+ else:
+ xavier_uniform_(self.q_proj_weight)
+ xavier_uniform_(self.k_proj_weight)
+ xavier_uniform_(self.v_proj_weight)
+
+ if self.in_proj_bias is not None:
+ constant_(self.in_proj_bias, 0.)
+ constant_(self.out_proj.bias, 0.)
+ if self.bias_k is not None:
+ xavier_normal_(self.bias_k)
+ if self.bias_v is not None:
+ xavier_normal_(self.bias_v)
+
+ def forward(self, query, key, value, key_padding_mask=None,
+ need_weights=True, attn_mask=None):
+ r"""
+ Args:
+ query, key, value: map a query and a set of key-value pairs to an output.
+ See "Attention Is All You Need" for more details.
+ key_padding_mask: if provided, specified padding elements in the key will
+            be ignored by the attention. This is a binary mask. When the value is True,
+ the corresponding value on the attention layer will be filled with -inf.
+ need_weights: output attn_output_weights.
+ attn_mask: mask that prevents attention to certain positions. This is an additive mask
+ (i.e. the values will be added to the attention layer).
+
+ Shape:
+ - Inputs:
+ - query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is
+ the embedding dimension.
+ - key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is
+ the embedding dimension.
+ - value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is
+ the embedding dimension.
+ - key_padding_mask: :math:`(N, S)`, ByteTensor, where N is the batch size, S is the source sequence length.
+ - attn_mask: :math:`(L, S)` where L is the target sequence length, S is the source sequence length.
+
+ - Outputs:
+ - attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size,
+ E is the embedding dimension.
+ - attn_output_weights: :math:`(N, L, S)` where N is the batch size,
+ L is the target sequence length, S is the source sequence length.
+ """
+ if hasattr(self, '_qkv_same_embed_dim') and self._qkv_same_embed_dim is False:
+ return multi_head_attention_forward(
+ query, key, value, self.embed_dim, self.num_heads,
+ self.in_proj_weight, self.in_proj_bias,
+ self.bias_k, self.bias_v, self.add_zero_attn,
+ self.dropout, self.out_proj.weight, self.out_proj.bias,
+ training=self.training,
+ key_padding_mask=key_padding_mask, need_weights=need_weights,
+ attn_mask=attn_mask, use_separate_proj_weight=True,
+ q_proj_weight=self.q_proj_weight, k_proj_weight=self.k_proj_weight,
+ v_proj_weight=self.v_proj_weight)
+ else:
+ if not hasattr(self, '_qkv_same_embed_dim'):
+ warnings.warn('A new version of MultiheadAttention module has been implemented. \
+ Please re-train your model with the new module',
+ UserWarning)
+
+ return multi_head_attention_forward(
+ query, key, value, self.embed_dim, self.num_heads,
+ self.in_proj_weight, self.in_proj_bias,
+ self.bias_k, self.bias_v, self.add_zero_attn,
+ self.dropout, self.out_proj.weight, self.out_proj.bias,
+ training=self.training,
+ key_padding_mask=key_padding_mask, need_weights=need_weights,
+ attn_mask=attn_mask)
diff --git a/model/itm.py b/model/itm.py
new file mode 100644
index 0000000..a7aefd0
--- /dev/null
+++ b/model/itm.py
@@ -0,0 +1,140 @@
+"""
+Copyright (c) Microsoft Corporation.
+Licensed under the MIT license.
+
+UNITER for ITM model
+"""
+from collections import defaultdict
+
+import torch
+from torch import nn
+from .model import UniterPreTrainedModel, UniterModel
+
+
+class UniterForImageTextRetrieval(UniterPreTrainedModel):
+ """ Finetune UNITER for image text retrieval
+ """
+ def __init__(self, config, img_dim, margin=0.2):
+ super().__init__(config)
+ self.uniter = UniterModel(config, img_dim)
+ self.itm_output = nn.Linear(config.hidden_size, 2)
+ self.rank_output = nn.Linear(config.hidden_size, 1)
+ self.margin = margin
+ self.apply(self.init_weights)
+
+ def init_output(self):
+ """ need to be called after from pretrained """
+ self.rank_output.weight.data = self.itm_output.weight.data[1:, :]
+ self.rank_output.bias.data = self.itm_output.bias.data[1:]
+
+ def forward(self, batch, compute_loss=True):
+ batch = defaultdict(lambda: None, batch)
+ input_ids = batch['input_ids']
+ position_ids = batch['position_ids']
+ img_feat = batch['img_feat']
+ img_pos_feat = batch['img_pos_feat']
+ attention_mask = batch['attn_masks']
+ gather_index = batch['gather_index']
+ sequence_output = self.uniter(input_ids, position_ids,
+ img_feat, img_pos_feat,
+ attention_mask, gather_index,
+ output_all_encoded_layers=False)
+ pooled_output = self.uniter.pooler(sequence_output)
+ rank_scores = self.rank_output(pooled_output)
+
+ if compute_loss:
+ # triplet loss
+ rank_scores_sigmoid = torch.sigmoid(rank_scores)
+ sample_size = batch['sample_size']
+ scores = rank_scores_sigmoid.contiguous().view(-1, sample_size)
+ pos = scores[:, :1]
+ neg = scores[:, 1:]
+ rank_loss = torch.clamp(self.margin + neg - pos, 0)
+ # print('| success ratio: ', neg.lt(pos).float().sum().div(neg.numel()))
+ return rank_loss
+ else:
+ return rank_scores
+
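+# The ranking loss above, for each group of `sample_size` scores s with the
+# ground-truth pair first: loss_j = max(0, margin + sigmoid(s_j) - sigmoid(s_0))
+# for every negative j, i.e. a margin-based triplet loss on sigmoid-squashed
+# rank scores.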
+
+class UniterForImageTextRetrievalHardNeg(UniterForImageTextRetrieval):
+ """ Finetune UNITER for image text retrieval
+ """
+ def __init__(self, config, img_dim, margin=0.2, hard_size=16):
+ super().__init__(config, img_dim, margin)
+ self.hard_size = hard_size
+
+ def forward(self, batch, sample_from='t', compute_loss=True):
+ # expect same input_ids for all pairs
+ batch_size = batch['attn_masks'].size(0)
+ input_ids = batch['input_ids']
+ img_feat = batch['img_feat']
+ img_pos_feat = batch['img_pos_feat']
+ if sample_from == 't':
+ if input_ids.size(0) == 1:
+ batch['input_ids'] = input_ids.expand(batch_size, -1)
+ elif sample_from == 'i':
+ if img_feat.size(0) == 1:
+ batch['img_feat'] = img_feat.expand(batch_size, -1, -1)
+ if img_pos_feat.size(0) == 1:
+ batch['img_pos_feat'] = img_pos_feat.expand(batch_size, -1, -1)
+ else:
+ raise ValueError()
+
+ if self.training and compute_loss:
+ with torch.no_grad():
+ self.eval()
+ scores = super().forward(batch, compute_loss=False)
+ hard_batch = self._get_hard_batch(batch, scores, sample_from)
+ self.train()
+ return super().forward(hard_batch, compute_loss=True)
+ else:
+ return super().forward(batch, compute_loss)
+
+ def _get_hard_batch(self, batch, scores, sample_from='t'):
+ batch = defaultdict(lambda: None, batch)
+ input_ids = batch['input_ids']
+ position_ids = batch['position_ids']
+ img_feat = batch['img_feat']
+ img_pos_feat = batch['img_pos_feat']
+ attention_mask = batch['attn_masks']
+ gather_index = batch['gather_index']
+ hard_batch = {'sample_size': self.hard_size + 1}
+
+ # NOTE first example is positive
+ hard_indices = scores.squeeze(-1)[1:].topk(
+ self.hard_size, sorted=False)[1] + 1
+ indices = torch.cat([torch.zeros(1, dtype=torch.long,
+ device=hard_indices.device),
+ hard_indices])
+
+ attention_mask = attention_mask.index_select(0, indices)
+ gather_index = gather_index.index_select(0, indices)
+ if position_ids.size(0) != 1:
+ position_ids = position_ids[:self.hard_size+1]
+
+ if sample_from == 't':
+ # cut to minimum padding
+ max_len = attention_mask.sum(dim=1).max().item()
+ max_i = max_len - input_ids.size(1)
+ attention_mask = attention_mask[:, :max_len]
+ gather_index = gather_index[:, :max_len]
+ img_feat = img_feat.index_select(0, indices)[:, :max_i, :]
+ img_pos_feat = img_pos_feat.index_select(0, indices)[:, :max_i, :]
+ # expect same input_ids for all pairs
+ input_ids = input_ids[:self.hard_size+1]
+ elif sample_from == 'i':
+ input_ids = input_ids.index_select(0, indices)
+ # expect same image features for all pairs
+ img_feat = img_feat[:self.hard_size+1]
+ img_pos_feat = img_pos_feat[:self.hard_size+1]
+ else:
+ raise ValueError()
+
+ hard_batch['input_ids'] = input_ids
+ hard_batch['position_ids'] = position_ids
+ hard_batch['img_feat'] = img_feat
+ hard_batch['img_pos_feat'] = img_pos_feat
+ hard_batch['attn_masks'] = attention_mask
+ hard_batch['gather_index'] = gather_index
+
+ return hard_batch
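+
+# Hard-negative mining flow (summary of the class above): during training the
+# model first scores every candidate in eval mode without gradients, keeps the
+# positive plus the `hard_size` highest-scoring negatives via _get_hard_batch,
+# then re-runs the forward pass with gradients on that smaller, harder batch.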
diff --git a/model/layer.py b/model/layer.py
new file mode 100644
index 0000000..ed0203f
--- /dev/null
+++ b/model/layer.py
@@ -0,0 +1,233 @@
+"""
+BERT layers from the huggingface implementation
+(https://github.com/huggingface/transformers)
+"""
+# coding=utf-8
+# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
+# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+import math
+
+import torch
+from torch import nn
+from apex.normalization.fused_layer_norm import FusedLayerNorm as BertLayerNorm
+
+
+logger = logging.getLogger(__name__)
+
+
+def gelu(x):
+ """Implementation of the gelu activation function.
+ For information: OpenAI GPT's gelu is slightly different (and gives slightly different results):
+ 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
+ Also see https://arxiv.org/abs/1606.08415
+ """
+ return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))
+
+
+def swish(x):
+ return x * torch.sigmoid(x)
+
+
+ACT2FN = {"gelu": gelu, "relu": torch.nn.functional.relu, "swish": swish}
+
+
+class GELU(nn.Module):
+ def forward(self, input_):
+ output = gelu(input_)
+ return output
+
+
+class BertSelfAttention(nn.Module):
+ def __init__(self, config):
+ super(BertSelfAttention, self).__init__()
+ if config.hidden_size % config.num_attention_heads != 0:
+ raise ValueError(
+ "The hidden size (%d) is not a multiple of the number of attention "
+ "heads (%d)" % (config.hidden_size, config.num_attention_heads))
+ self.num_attention_heads = config.num_attention_heads
+ self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
+ self.all_head_size = self.num_attention_heads * self.attention_head_size
+
+ self.query = nn.Linear(config.hidden_size, self.all_head_size)
+ self.key = nn.Linear(config.hidden_size, self.all_head_size)
+ self.value = nn.Linear(config.hidden_size, self.all_head_size)
+
+ self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
+
+ def transpose_for_scores(self, x):
+ new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
+ x = x.view(*new_x_shape)
+ return x.permute(0, 2, 1, 3)
+
+ def forward(self, hidden_states, attention_mask):
+ mixed_query_layer = self.query(hidden_states)
+ mixed_key_layer = self.key(hidden_states)
+ mixed_value_layer = self.value(hidden_states)
+
+ query_layer = self.transpose_for_scores(mixed_query_layer)
+ key_layer = self.transpose_for_scores(mixed_key_layer)
+ value_layer = self.transpose_for_scores(mixed_value_layer)
+
+ # Take the dot product between "query" and "key" to get the raw attention scores.
+ attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
+ attention_scores = attention_scores / math.sqrt(self.attention_head_size)
+ # Apply the attention mask is (precomputed for all layers in BertModel forward() function)
+ attention_scores = attention_scores + attention_mask
+
+ # Normalize the attention scores to probabilities.
+ attention_probs = nn.Softmax(dim=-1)(attention_scores)
+
+ # This is actually dropping out entire tokens to attend to, which might
+ # seem a bit unusual, but is taken from the original Transformer paper.
+ attention_probs = self.dropout(attention_probs)
+
+ context_layer = torch.matmul(attention_probs, value_layer)
+ context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
+ new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
+ context_layer = context_layer.view(*new_context_layer_shape)
+ return context_layer
+
+
+class BertSelfOutput(nn.Module):
+ def __init__(self, config):
+ super(BertSelfOutput, self).__init__()
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
+ self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12)
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
+
+ def forward(self, hidden_states, input_tensor):
+ hidden_states = self.dense(hidden_states)
+ hidden_states = self.dropout(hidden_states)
+ hidden_states = self.LayerNorm(hidden_states + input_tensor)
+ return hidden_states
+
+
+class BertAttention(nn.Module):
+ def __init__(self, config):
+ super(BertAttention, self).__init__()
+ self.self = BertSelfAttention(config)
+ self.output = BertSelfOutput(config)
+
+ def forward(self, input_tensor, attention_mask):
+ self_output = self.self(input_tensor, attention_mask)
+ attention_output = self.output(self_output, input_tensor)
+ return attention_output
+
+
+class BertIntermediate(nn.Module):
+ def __init__(self, config):
+ super(BertIntermediate, self).__init__()
+ self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
+ if isinstance(config.hidden_act, str):
+ self.intermediate_act_fn = ACT2FN[config.hidden_act]
+ else:
+ self.intermediate_act_fn = config.hidden_act
+
+ def forward(self, hidden_states):
+ hidden_states = self.dense(hidden_states)
+ hidden_states = self.intermediate_act_fn(hidden_states)
+ return hidden_states
+
+
+class BertOutput(nn.Module):
+ def __init__(self, config):
+ super(BertOutput, self).__init__()
+ self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
+ self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12)
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
+
+ def forward(self, hidden_states, input_tensor):
+ hidden_states = self.dense(hidden_states)
+ hidden_states = self.dropout(hidden_states)
+ hidden_states = self.LayerNorm(hidden_states + input_tensor)
+ return hidden_states
+
+
+class BertLayer(nn.Module):
+ def __init__(self, config):
+ super(BertLayer, self).__init__()
+ self.attention = BertAttention(config)
+ self.intermediate = BertIntermediate(config)
+ self.output = BertOutput(config)
+
+ def forward(self, hidden_states, attention_mask):
+ attention_output = self.attention(hidden_states, attention_mask)
+ intermediate_output = self.intermediate(attention_output)
+ layer_output = self.output(intermediate_output, attention_output)
+ return layer_output
+
+
+class BertPooler(nn.Module):
+ def __init__(self, config):
+ super(BertPooler, self).__init__()
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
+ self.activation = nn.Tanh()
+
+ def forward(self, hidden_states):
+ # We "pool" the model by simply taking the hidden state corresponding
+ # to the first token.
+ first_token_tensor = hidden_states[:, 0]
+ pooled_output = self.dense(first_token_tensor)
+ pooled_output = self.activation(pooled_output)
+ return pooled_output
+
+
+class BertPredictionHeadTransform(nn.Module):
+ def __init__(self, config):
+ super(BertPredictionHeadTransform, self).__init__()
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
+ if isinstance(config.hidden_act, str):
+ self.transform_act_fn = ACT2FN[config.hidden_act]
+ else:
+ self.transform_act_fn = config.hidden_act
+ self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12)
+
+ def forward(self, hidden_states):
+ hidden_states = self.dense(hidden_states)
+ hidden_states = self.transform_act_fn(hidden_states)
+ hidden_states = self.LayerNorm(hidden_states)
+ return hidden_states
+
+
+class BertLMPredictionHead(nn.Module):
+ def __init__(self, config, bert_model_embedding_weights):
+ super(BertLMPredictionHead, self).__init__()
+ self.transform = BertPredictionHeadTransform(config)
+
+ # The output weights are the same as the input embeddings, but there is
+ # an output-only bias for each token.
+ self.decoder = nn.Linear(bert_model_embedding_weights.size(1),
+ bert_model_embedding_weights.size(0),
+ bias=False)
+ self.decoder.weight = bert_model_embedding_weights
+ self.bias = nn.Parameter(
+ torch.zeros(bert_model_embedding_weights.size(0)))
+
+ def forward(self, hidden_states):
+ hidden_states = self.transform(hidden_states)
+ hidden_states = self.decoder(hidden_states) + self.bias
+ return hidden_states
+
+
+class BertOnlyMLMHead(nn.Module):
+ def __init__(self, config, bert_model_embedding_weights):
+ super(BertOnlyMLMHead, self).__init__()
+ self.predictions = BertLMPredictionHead(config,
+ bert_model_embedding_weights)
+
+ def forward(self, sequence_output):
+ prediction_scores = self.predictions(sequence_output)
+ return prediction_scores
diff --git a/model/model.py b/model/model.py
new file mode 100644
index 0000000..7b2e6a9
--- /dev/null
+++ b/model/model.py
@@ -0,0 +1,367 @@
+"""
+Copyright (c) Microsoft Corporation.
+Licensed under the MIT license.
+
+Pytorch modules
+some classes are modified from HuggingFace
+(https://github.com/huggingface/transformers)
+"""
+import copy
+import json
+import logging
+from io import open
+
+import torch
+from torch import nn
+from apex.normalization.fused_layer_norm import FusedLayerNorm
+
+from .layer import BertLayer, BertPooler
+
+
+logger = logging.getLogger(__name__)
+
+
+class UniterConfig(object):
+ """Configuration class to store the configuration of a `UniterModel`.
+ """
+ def __init__(self,
+ vocab_size_or_config_json_file,
+ hidden_size=768,
+ num_hidden_layers=12,
+ num_attention_heads=12,
+ intermediate_size=3072,
+ hidden_act="gelu",
+ hidden_dropout_prob=0.1,
+ attention_probs_dropout_prob=0.1,
+ max_position_embeddings=512,
+ type_vocab_size=2,
+ initializer_range=0.02):
+ """Constructs UniterConfig.
+ Args:
+ vocab_size_or_config_json_file: Vocabulary size of `input_ids` in
+ `UniterModel`.
+ hidden_size: Size of the encoder layers and the pooler layer.
+ num_hidden_layers: Number of hidden layers in the Transformer
+ encoder.
+ num_attention_heads: Number of attention heads for each attention
+ layer in the Transformer encoder.
+ intermediate_size: The size of the "intermediate" (i.e.
+ feed-forward) layer in the Transformer encoder.
+ hidden_act: The non-linear activation function (function or string)
+ in the encoder and pooler. If string, "gelu", "relu" and
+ "swish" are supported.
+ hidden_dropout_prob: The dropout probability for all fully
+ connected layers in the embeddings, encoder, and pooler.
+ attention_probs_dropout_prob: The dropout ratio for the attention
+ probabilities.
+ max_position_embeddings: The maximum sequence length that this
+ model might ever be used with. Typically set this to something
+ large just in case (e.g., 512 or 1024 or 2048).
+ type_vocab_size: The vocabulary size of the `token_type_ids` passed
+ into `UniterModel`.
+ initializer_range: The standard deviation of the truncated_normal_initializer
+ for initializing all weight matrices.
+ """
+ if isinstance(vocab_size_or_config_json_file, str):
+ with open(vocab_size_or_config_json_file,
+ "r", encoding='utf-8') as reader:
+ json_config = json.loads(reader.read())
+ for key, value in json_config.items():
+ self.__dict__[key] = value
+ elif isinstance(vocab_size_or_config_json_file, int):
+ self.vocab_size = vocab_size_or_config_json_file
+ self.hidden_size = hidden_size
+ self.num_hidden_layers = num_hidden_layers
+ self.num_attention_heads = num_attention_heads
+ self.hidden_act = hidden_act
+ self.intermediate_size = intermediate_size
+ self.hidden_dropout_prob = hidden_dropout_prob
+ self.attention_probs_dropout_prob = attention_probs_dropout_prob
+ self.max_position_embeddings = max_position_embeddings
+ self.type_vocab_size = type_vocab_size
+ self.initializer_range = initializer_range
+ else:
+ raise ValueError("First argument must be either a vocabulary size "
+ "(int) or the path to a pretrained model config "
+ "file (str)")
+
+ @classmethod
+ def from_dict(cls, json_object):
+ """Constructs a `UniterConfig` from a
+ Python dictionary of parameters."""
+ config = UniterConfig(vocab_size_or_config_json_file=-1)
+ for key, value in json_object.items():
+ config.__dict__[key] = value
+ return config
+
+ @classmethod
+ def from_json_file(cls, json_file):
+ """Constructs a `UniterConfig` from a json file of parameters."""
+ with open(json_file, "r", encoding='utf-8') as reader:
+ text = reader.read()
+ return cls.from_dict(json.loads(text))
+
+ def __repr__(self):
+ return str(self.to_json_string())
+
+ def to_dict(self):
+ """Serializes this instance to a Python dictionary."""
+ output = copy.deepcopy(self.__dict__)
+ return output
+
+ def to_json_string(self):
+ """Serializes this instance to a JSON string."""
+ return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n"
+
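+# Illustrative usage (a sketch): the single positional argument selects the mode,
+#   UniterConfig(vocab_size_or_config_json_file=30522)        # int -> default hyper-parameters
+#   UniterConfig.from_json_file('config/uniter-base.json')    # path -> every field read from JSON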
+
+class UniterPreTrainedModel(nn.Module):
+ """ An abstract class to handle weights initialization and
+ a simple interface for downloading and loading pretrained models.
+ """
+ def __init__(self, config, *inputs, **kwargs):
+ super().__init__()
+ if not isinstance(config, UniterConfig):
+ raise ValueError(
+ "Parameter config in `{}(config)` should be an instance of "
+ "class `UniterConfig`. To create a model from a Google "
+ "pretrained model use "
+ "`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`".format(
+ self.__class__.__name__, self.__class__.__name__
+ ))
+ self.config = config
+
+ def init_weights(self, module):
+ """ Initialize the weights.
+ """
+ if isinstance(module, (nn.Linear, nn.Embedding)):
+ # Slightly different from the TF version which uses
+ # truncated_normal for initialization
+ # cf https://github.com/pytorch/pytorch/pull/5617
+ module.weight.data.normal_(mean=0.0,
+ std=self.config.initializer_range)
+ elif isinstance(module, FusedLayerNorm):
+ module.bias.data.zero_()
+ module.weight.data.fill_(1.0)
+ if isinstance(module, nn.Linear) and module.bias is not None:
+ module.bias.data.zero_()
+
+ @classmethod
+ def from_pretrained(cls, config_file, state_dict, *inputs, **kwargs):
+ """
+ Instantiate a UniterPreTrainedModel from a pre-trained model file or a
+ pytorch state dict.
+ Params:
+ config_file: config json file
+ state_dict: a state dictionary
+ *inputs, **kwargs: additional input for the specific Uniter class
+ """
+ # Load config
+ config = UniterConfig.from_json_file(config_file)
+ logger.info("Model config {}".format(config))
+ # Instantiate model.
+ model = cls(config, *inputs, **kwargs)
+ # Load from a PyTorch state_dict
+ old_keys = []
+ new_keys = []
+ for key in state_dict.keys():
+ new_key = None
+ if 'gamma' in key:
+ new_key = key.replace('gamma', 'weight')
+ if 'beta' in key:
+ new_key = key.replace('beta', 'bias')
+ if new_key:
+ old_keys.append(key)
+ new_keys.append(new_key)
+ for old_key, new_key in zip(old_keys, new_keys):
+ state_dict[new_key] = state_dict.pop(old_key)
+
+ missing_keys = []
+ unexpected_keys = []
+ error_msgs = []
+ # copy state_dict so _load_from_state_dict can modify it
+ metadata = getattr(state_dict, '_metadata', None)
+ state_dict = state_dict.copy()
+ if metadata is not None:
+ state_dict._metadata = metadata
+
+ def load(module, prefix=''):
+ local_metadata = ({} if metadata is None
+ else metadata.get(prefix[:-1], {}))
+ module._load_from_state_dict(
+ state_dict, prefix, local_metadata, True, missing_keys,
+ unexpected_keys, error_msgs)
+ for name, child in module._modules.items():
+ if child is not None:
+ load(child, prefix + name + '.')
+ start_prefix = ''
+ if not hasattr(model, 'bert') and any(s.startswith('bert.')
+ for s in state_dict.keys()):
+ start_prefix = 'bert.'
+ load(model, prefix=start_prefix)
+ if len(missing_keys) > 0:
+ logger.info("Weights of {} not initialized from "
+ "pretrained model: {}".format(
+ model.__class__.__name__, missing_keys))
+ if len(unexpected_keys) > 0:
+ logger.info("Weights from pretrained model not used in "
+ "{}: {}".format(
+ model.__class__.__name__, unexpected_keys))
+ if len(error_msgs) > 0:
+ raise RuntimeError('Error(s) in loading state_dict for '
+ '{}:\n\t{}'.format(
+ model.__class__.__name__,
+ "\n\t".join(error_msgs)))
+ return model
+
+
+class UniterTextEmbeddings(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.word_embeddings = nn.Embedding(config.vocab_size,
+ config.hidden_size, padding_idx=0)
+ self.position_embeddings = nn.Embedding(config.max_position_embeddings,
+ config.hidden_size)
+ self.token_type_embeddings = nn.Embedding(config.type_vocab_size,
+ config.hidden_size)
+
+ # self.LayerNorm is not snake-cased to stick with TensorFlow model
+ # variable name and be able to load any TensorFlow checkpoint file
+ self.LayerNorm = FusedLayerNorm(config.hidden_size, eps=1e-12)
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
+
+ def forward(self, input_ids, position_ids, token_type_ids=None):
+ if token_type_ids is None:
+ token_type_ids = torch.zeros_like(input_ids)
+
+ words_embeddings = self.word_embeddings(input_ids)
+ position_embeddings = self.position_embeddings(position_ids)
+ token_type_embeddings = self.token_type_embeddings(token_type_ids)
+
+ embeddings = (words_embeddings
+ + position_embeddings
+ + token_type_embeddings)
+ embeddings = self.LayerNorm(embeddings)
+ embeddings = self.dropout(embeddings)
+ return embeddings
+
+
+class UniterImageEmbeddings(nn.Module):
+ def __init__(self, config, img_dim):
+ super().__init__()
+ self.img_linear = nn.Linear(img_dim, config.hidden_size)
+ self.img_layer_norm = FusedLayerNorm(config.hidden_size, eps=1e-12)
+ self.pos_layer_norm = FusedLayerNorm(config.hidden_size, eps=1e-12)
+ self.pos_linear = nn.Linear(7, config.hidden_size)
+ self.mask_embedding = nn.Embedding(2, img_dim, padding_idx=0)
+
+ # tf naming convention for layer norm
+ self.LayerNorm = FusedLayerNorm(config.hidden_size, eps=1e-12)
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
+
+ def forward(self, img_feat, img_pos_feat, type_embeddings, img_masks=None):
+ if img_masks is not None:
+ self.mask_embedding.weight.data[0, :].fill_(0)
+ mask = self.mask_embedding(img_masks.long())
+ img_feat = img_feat + mask
+
+ transformed_im = self.img_layer_norm(self.img_linear(img_feat))
+ transformed_pos = self.pos_layer_norm(self.pos_linear(img_pos_feat))
+ embeddings = transformed_im + transformed_pos + type_embeddings
+ embeddings = self.LayerNorm(embeddings)
+ embeddings = self.dropout(embeddings)
+ return embeddings
+
+
+class UniterEncoder(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ layer = BertLayer(config)
+ self.layer = nn.ModuleList([copy.deepcopy(layer)
+ for _ in range(config.num_hidden_layers)])
+
+ def forward(self, input_, attention_mask,
+ output_all_encoded_layers=True):
+ all_encoder_layers = []
+ hidden_states = input_
+ for layer_module in self.layer:
+ hidden_states = layer_module(hidden_states, attention_mask)
+ if output_all_encoded_layers:
+ all_encoder_layers.append(hidden_states)
+ if not output_all_encoded_layers:
+ all_encoder_layers.append(hidden_states)
+ return all_encoder_layers
+
+
+class UniterModel(UniterPreTrainedModel):
+ """ Modification for Joint Vision-Language Encoding
+ """
+ def __init__(self, config, img_dim):
+ super().__init__(config)
+ self.embeddings = UniterTextEmbeddings(config)
+ self.img_embeddings = UniterImageEmbeddings(config, img_dim)
+ self.encoder = UniterEncoder(config)
+ self.pooler = BertPooler(config)
+ self.apply(self.init_weights)
+
+ def _compute_txt_embeddings(self, input_ids, position_ids,
+ txt_type_ids=None):
+ output = self.embeddings(input_ids, position_ids, txt_type_ids)
+ return output
+
+ def _compute_img_embeddings(self, img_feat, img_pos_feat, img_masks=None,
+ img_type_ids=None):
+ if img_type_ids is None:
+ img_type_ids = torch.ones_like(img_feat[:, :, 0].long())
+ img_type_embeddings = self.embeddings.token_type_embeddings(
+ img_type_ids)
+ output = self.img_embeddings(img_feat, img_pos_feat,
+ img_type_embeddings, img_masks)
+ return output
+
+ def _compute_img_txt_embeddings(self, input_ids, position_ids,
+ img_feat, img_pos_feat,
+ gather_index, img_masks=None,
+ txt_type_ids=None, img_type_ids=None):
+ txt_emb = self._compute_txt_embeddings(
+ input_ids, position_ids, txt_type_ids)
+ img_emb = self._compute_img_embeddings(
+ img_feat, img_pos_feat, img_masks, img_type_ids)
+ # align back to most compact input
+ gather_index = gather_index.unsqueeze(-1).expand(
+ -1, -1, self.config.hidden_size)
+ txt_img_emb = torch.cat([txt_emb, img_emb], dim=1)
+ embedding_output = torch.gather(txt_img_emb, dim=1, index=gather_index)
+ return embedding_output
+
+ def forward(self, input_ids, position_ids,
+ img_feat, img_pos_feat,
+ attention_mask, gather_index=None, img_masks=None,
+ output_all_encoded_layers=True,
+ txt_type_ids=None, img_type_ids=None):
+ # compute self-attention mask
+ extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
+ extended_attention_mask = extended_attention_mask.to(
+ dtype=next(self.parameters()).dtype) # fp16 compatibility
+ extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
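+ # masked-out positions (attention_mask == 0) receive a large negative additive
+ # bias, so they get ~zero weight after the softmax inside the attention layers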
+
+ # embedding layer
+ if input_ids is None:
+ # image only
+ embedding_output = self._compute_img_embeddings(
+ img_feat, img_pos_feat, img_masks, img_type_ids)
+ elif img_feat is None:
+ # text only
+ embedding_output = self._compute_txt_embeddings(
+ input_ids, position_ids, txt_type_ids)
+ else:
+ embedding_output = self._compute_img_txt_embeddings(
+ input_ids, position_ids,
+ img_feat, img_pos_feat,
+ gather_index, img_masks, txt_type_ids, img_type_ids)
+
+ encoded_layers = self.encoder(
+ embedding_output, extended_attention_mask,
+ output_all_encoded_layers=output_all_encoded_layers)
+ if not output_all_encoded_layers:
+ encoded_layers = encoded_layers[-1]
+ return encoded_layers
diff --git a/model/nsgd.py b/model/nsgd.py
new file mode 100644
index 0000000..a9bc915
--- /dev/null
+++ b/model/nsgd.py
@@ -0,0 +1,358 @@
+"""
+Copyright (c) Microsoft Corporation.
+Licensed under the MIT license.
+
+UNITER for ITM model
+"""
+from collections import defaultdict
+
+import torch
+from torch import nn
+import torch.nn.functional as F
+from .model import UniterPreTrainedModel, UniterModel
+from .layer import GELU, BertOnlyMLMHead
+
+
+def repeat_interleave(x, n_repeat=1, dim=0):
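+ # Equivalent to torch.repeat_interleave with a scalar repeat count, e.g.
+ # repeat_interleave(torch.tensor([1, 2]), n_repeat=3, dim=0) -> tensor([1, 1, 1, 2, 2, 2])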
+ repeat_list = [1] * (dim + 1) + [n_repeat] + [1] * (x.dim() - dim - 1)
+ x_size = list(x.size())
+ x_size[dim] = x_size[dim] * n_repeat
+ x = x.unsqueeze(dim+1).repeat(*repeat_list).view(*x_size)
+ return x
+
+
+class UniterForNSGD(UniterPreTrainedModel):
+ """ Finetune UNITER for image text retrieval
+ """
+ def __init__(self, config, img_dim, margin=0.2, hard_size=16, nsgd_sample_size=16, nsgd_sample_temperature=1.0):
+ super().__init__(config)
+ self.uniter = UniterModel(config, img_dim)
+ self.cls = BertOnlyMLMHead(
+ config, self.uniter.embeddings.word_embeddings.weight)
+ self.itm_output = nn.Linear(config.hidden_size, 2)
+ self.rank_output = nn.Linear(config.hidden_size, 1)
+ self.margin = margin
+ self.apply(self.init_weights)
+ self.hard_size = hard_size
+ self.nsgd_sample_size = nsgd_sample_size
+ self.nsgd_sample_temperature = nsgd_sample_temperature
+
+ def init_output(self):
+ """ need to be called after from pretrained """
+ self.rank_output.weight.data = self.itm_output.weight.data[1:, :]
+ self.rank_output.bias.data = self.itm_output.bias.data[1:]
+
+ def forward_uniter(
+ self,
+ sample_size=None,
+ input_ids=None, position_ids=None,
+ img_feat=None, img_pos_feat=None,
+ attn_masks=None, gather_index=None,
+ compute_loss=True, sigmoid_norm=False
+ ):
+ model_outputs = {}
+ sequence_output = self.uniter(
+ input_ids, position_ids,
+ img_feat, img_pos_feat,
+ attn_masks, gather_index,
+ output_all_encoded_layers=False
+ )
+ pooled_output = self.uniter.pooler(sequence_output)
+ rank_scores = self.rank_output(pooled_output)
+ model_outputs['rank_scores'] = rank_scores
+ if compute_loss:
+ # triplet loss
+ rank_scores_sigmoid = torch.sigmoid(rank_scores)
+ # sample_size = batch['sample_size']
+ scores = rank_scores_sigmoid.contiguous().view(-1, sample_size)
+ pos = scores[:, :1]
+ neg = scores[:, 1:]
+ rank_loss = torch.clamp(self.margin + neg - pos, 0)
+ rank_corrects = neg.sub(pos).le(0).float()
+ model_outputs['rank_loss'] = rank_loss
+ model_outputs['rank_corrects'] = rank_corrects
+ if sigmoid_norm:
+ rank_scores_sigmoid = torch.sigmoid(rank_scores)
+ model_outputs['rank_scores_sigmoid'] = rank_scores_sigmoid
+ # sample_size = batch['sample_size']
+ return model_outputs
+
+ def _compute_masked_hidden(self, hidden, mask):
+ """ get only the masked region (don't compute unnecessary hiddens) """
+ mask = mask.unsqueeze(-1).expand_as(hidden)
+ hidden_masked = hidden[mask].contiguous().view(-1, hidden.size(-1))
+ return hidden_masked
+
+ def forward_mlm(
+ self,
+ input_ids=None, position_ids=None,
+ img_feat=None, img_pos_feat=None,
+ attn_masks=None, gather_index=None,
+ txt_labels=None, compute_loss=True, sampling=True
+ ):
+ model_outputs = {}
+ sequence_output = self.uniter(
+ input_ids, position_ids,
+ img_feat, img_pos_feat,
+ attn_masks, gather_index,
+ output_all_encoded_layers=False)
+ # get only the text part
+ sequence_output = sequence_output[:, :input_ids.size(1), :]
+
+ if compute_loss:
+ # only compute masked tokens for better efficiency
+ masked_output = self._compute_masked_hidden(sequence_output, txt_labels != -1)
+ prediction_scores = self.cls(masked_output)
+ masked_lm_loss = F.cross_entropy(prediction_scores, txt_labels[txt_labels != -1], reduction='none')
+ model_outputs['masked_lm_loss'] = masked_lm_loss
+ model_outputs['mlm_corrects'] = prediction_scores.max(-1)[1].eq(txt_labels[txt_labels != -1]).float()
+ if sampling:
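+ # For every token position, nsgd_sample_size candidates are drawn from the
+ # temperature-scaled MLM distribution; unmasked positions (txt_labels == -1)
+ # keep the original token while masked positions take the sampled one, giving
+ # (bsz * nsgd_sample_size, caption_len) synthetic captions as 'fill_input_ids'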
+ bsz, caption_len = input_ids.size(0), input_ids.size(1)
+ prediction_scores = self.cls(sequence_output)
+ sample_caption_tokens = torch.multinomial(
+ prediction_scores.div(self.nsgd_sample_temperature).softmax(-1).view(-1, prediction_scores.size(-1)),
+ num_samples=self.nsgd_sample_size,
+ replacement=True,
+ ).view(bsz, caption_len, self.nsgd_sample_size).permute(0, 2, 1)
+ mask_indicator = txt_labels.ne(-1).long().unsqueeze(1)
+ synthetic_input_ids = input_ids.unsqueeze(1).mul(1-mask_indicator).\
+ add(sample_caption_tokens.mul(mask_indicator)).reshape(-1, caption_len)
+ model_outputs['fill_input_ids'] = synthetic_input_ids
+ return model_outputs
+
+ def forward(self, batch, sample_from='t', compute_loss=True, compute_mlm=False):
+ # expect same input_ids for all pairs
+ model_outputs = {}
+ if not sample_from.startswith('g'):
+ batch_size = batch['attn_masks'].size(0)
+ input_ids = batch['input_ids']
+ img_feat = batch['img_feat']
+ img_pos_feat = batch['img_pos_feat']
+ if sample_from == 't':
+ if input_ids.size(0) == 1:
+ batch['input_ids'] = input_ids.expand(batch_size, -1)
+ elif sample_from == 'i':
+ if img_feat.size(0) == 1:
+ batch['img_feat'] = img_feat.expand(batch_size, -1, -1)
+ if img_pos_feat.size(0) == 1:
+ batch['img_pos_feat'] = img_pos_feat.expand(batch_size, -1, -1)
+ else:
+ raise ValueError()
+
+ if self.training and compute_loss:
+ with torch.no_grad():
+ self.eval()
+ # print(f'| is_training: {self.training} | compute_loss: {compute_loss} |')
+ if torch.isnan(batch['input_ids']).sum().item() > 0 or \
+ torch.isnan(batch['position_ids']).sum().item() > 0 or \
+ torch.isnan(batch['img_feat']).sum().item() > 0 or \
+ torch.isnan(batch['img_pos_feat']).sum().item() > 0 or \
+ torch.isnan(batch['attn_masks']).sum().item() > 0 or \
+ torch.isnan(batch['gather_index']).sum().item() > 0:
+ print(' | nan appear!')
+ if torch.isinf(batch['input_ids']).sum().item() > 0 or \
+ torch.isinf(batch['position_ids']).sum().item() > 0 or \
+ torch.isinf(batch['img_feat']).sum().item() > 0 or \
+ torch.isinf(batch['img_pos_feat']).sum().item() > 0 or \
+ torch.isinf(batch['attn_masks']).sum().item() > 0 or \
+ torch.isinf(batch['gather_index']).sum().item() > 0:
+ print(' | inf appear!')
+
+ forward_uniter_outputs = self.forward_uniter(
+ input_ids=batch['input_ids'], position_ids=batch['position_ids'],
+ img_feat=batch['img_feat'], img_pos_feat=batch['img_pos_feat'],
+ attn_masks=batch['attn_masks'], gather_index=batch['gather_index'],
+ compute_loss=False,
+ )
+ hard_batch = self._get_hard_batch(batch, forward_uniter_outputs['rank_scores'], sample_from)
+ self.train()
+ forward_uniter_outputs = self.forward_uniter(
+ sample_size=hard_batch['sample_size'],
+ input_ids=hard_batch['input_ids'], position_ids=hard_batch['position_ids'],
+ img_feat=hard_batch['img_feat'], img_pos_feat=hard_batch['img_pos_feat'],
+ attn_masks=hard_batch['attn_masks'], gather_index=hard_batch['gather_index'],
+ compute_loss=True)
+ model_outputs['rank_loss'] = forward_uniter_outputs['rank_loss']
+ model_outputs['rank_corrects'] = forward_uniter_outputs['rank_corrects']
+ else:
+ forward_uniter_outputs = self.forward_uniter(
+ input_ids=batch['input_ids'], position_ids=batch['position_ids'],
+ img_feat=batch['img_feat'], img_pos_feat=batch['img_pos_feat'],
+ attn_masks=batch['attn_masks'], gather_index=batch['gather_index'],
+ compute_loss=compute_loss, sigmoid_norm=True)
+ model_outputs.update(forward_uniter_outputs)
+ return model_outputs
+ else:
+ if compute_mlm:
+ self.train()
+ mlm_outputs = self.forward_mlm(
+ input_ids=batch['mlm_input_ids'], position_ids=batch['mlm_position_ids'],
+ img_feat=batch['mlm_img_feat'], img_pos_feat=batch['mlm_img_pos_feat'],
+ attn_masks=batch['mlm_attn_masks'], gather_index=batch['mlm_gather_index'],
+ txt_labels=batch['mlm_txt_labels'], compute_loss=True, sampling=True
+ )
+ model_outputs['masked_lm_loss'] = mlm_outputs['masked_lm_loss']
+ # model_outputs['effect_nsgd_number'] = mlm_outputs['effect_nsgd_number']
+ model_outputs['mlm_corrects'] = mlm_outputs['mlm_corrects']
+ else:
+ with torch.no_grad():
+ self.eval()
+ # print('| mlm_inference | mlm_input_ids: ', batch['mlm_input_ids'].size())
+ mlm_outputs = self.forward_mlm(
+ input_ids=batch['mlm_input_ids'], position_ids=batch['mlm_position_ids'],
+ img_feat=batch['mlm_img_feat'], img_pos_feat=batch['mlm_img_pos_feat'],
+ attn_masks=batch['mlm_attn_masks'], gather_index=batch['mlm_gather_index'],
+ txt_labels=batch['mlm_txt_labels'], compute_loss=False, sampling=True
+ )
+ with torch.no_grad():
+ # select_indices = mlm_outputs['select_indices']
+ nsgd_batch = {}
+ nsgd_batch['txt_ids'] = batch['input_ids']
+ nsgd_batch['mlm_sample_size'] = len(batch['mlm_position_ids'])
+ nsgd_batch['nsgd_sample_size'] = self.nsgd_sample_size
+ nsgd_batch['input_ids'] = torch.cat(
+ [batch['input_ids'], mlm_outputs['fill_input_ids']], dim=0)
+ nsgd_batch['position_ids'] = torch.cat(
+ [batch['mlm_position_ids'][:1],
+ repeat_interleave(batch['mlm_position_ids'], self.nsgd_sample_size, dim=0)], dim=0)
+ nsgd_batch['img_feat'] = torch.cat(
+ [batch['mlm_img_feat'][:1],
+ repeat_interleave(batch['mlm_img_feat'], self.nsgd_sample_size, dim=0)], dim=0)
+ nsgd_batch['img_pos_feat'] = torch.cat(
+ [batch['mlm_img_pos_feat'][:1],
+ repeat_interleave(batch['mlm_img_pos_feat'], self.nsgd_sample_size, dim=0)], dim=0)
+ nsgd_batch['attn_masks'] = torch.cat(
+ [batch['mlm_attn_masks'][:1],
+ repeat_interleave(batch['mlm_attn_masks'], self.nsgd_sample_size, dim=0)], dim=0)
+ nsgd_batch['gather_index'] = torch.cat(
+ [batch['mlm_gather_index'][:1],
+ repeat_interleave(batch['mlm_gather_index'], self.nsgd_sample_size, dim=0)], dim=0)
+ self.eval()
+ forward_uniter_outputs = self.forward_uniter(
+ input_ids=nsgd_batch['input_ids'], position_ids=nsgd_batch['position_ids'],
+ img_feat=nsgd_batch['img_feat'], img_pos_feat=nsgd_batch['img_pos_feat'],
+ attn_masks=nsgd_batch['attn_masks'], gather_index=nsgd_batch['gather_index'],
+ compute_loss=False, sigmoid_norm=True
+ )
+ nsgd_batch = self._get_nsgd_batch(
+ nsgd_batch, scores=forward_uniter_outputs['rank_scores'],
+ clean=compute_loss)
+ self.train()
+ assert batch['input_ids'].ne(nsgd_batch['input_ids'][0]).long().sum().item() == 0
+ model_outputs['effect_nsgd_number'] = nsgd_batch['effect_num']
+ model_outputs['rank_adv_scores'] = forward_uniter_outputs['rank_scores_sigmoid']
+ if compute_loss:
+ forward_uniter_outputs = self.forward_uniter(
+ sample_size=nsgd_batch['sample_size'],
+ input_ids=nsgd_batch['input_ids'], position_ids=nsgd_batch['position_ids'],
+ img_feat=nsgd_batch['img_feat'], img_pos_feat=nsgd_batch['img_pos_feat'],
+ attn_masks=nsgd_batch['attn_masks'], gather_index=nsgd_batch['gather_index'],
+ compute_loss=True
+ )
+ model_outputs.update(forward_uniter_outputs)
+ else:
+ model_outputs['nsgd_adv_batch'] = nsgd_batch
+ return model_outputs
+
+ def _get_nsgd_batch(self, batch, scores, clean=True):
+ batch = defaultdict(lambda: None, batch)
+ txt_ids = batch['txt_ids']
+ mlm_sample_size, nsgd_sample_size = batch['mlm_sample_size'], batch['nsgd_sample_size']
+
+ input_ids = batch['input_ids']
+ position_ids = batch['position_ids']
+ img_feat = batch['img_feat']
+ img_pos_feat = batch['img_pos_feat']
+ attention_mask = batch['attn_masks']
+ gather_index = batch['gather_index']
+ hard_batch = {'sample_size': self.hard_size + 1}
+
+ # NOTE first example is positive
+ if clean:
+ # print('| clean:', clean)
+ c1_penalty_scores = repeat_interleave(txt_ids, mlm_sample_size * nsgd_sample_size + 1, dim=0).\
+ ne(input_ids).float().sum(-1).le(0).float().mul(3)
+ c2_penalty_scores = input_ids.unsqueeze(1).ne(input_ids.unsqueeze(0)).\
+ float().sum(-1).le(0).float().\
+ triu(diagonal=1).sum(-1).gt(0).float().mul(2)
+ effect_num = c1_penalty_scores.add(c2_penalty_scores).eq(0).long().sum()
+ hard_batch['effect_num'] = effect_num
+ hard_scores = scores.squeeze(-1).sub(c1_penalty_scores.add(c2_penalty_scores).type_as(scores))
+ else:
+ hard_batch['effect_num'] = len(scores)
+ hard_scores = scores.squeeze(-1)
+ # print('| hard_scores: ', hard_scores.size(), self.hard_size)
+ hard_indices = hard_scores[1:].topk(self.hard_size, sorted=True)[1] + 1
+ # hard_indices = hard_indices[256:]
+ hard_size = len(hard_indices)
+ indices = torch.cat([torch.zeros(1, dtype=torch.long, device=hard_indices.device),
+ hard_indices])
+ # indices = hard_indices
+
+ attention_mask = attention_mask.index_select(0, indices)
+ gather_index = gather_index.index_select(0, indices)
+ if position_ids.size(0) != 1:
+ position_ids = position_ids[:hard_size+1]
+
+ input_ids = input_ids.index_select(0, indices)
+ # expect same image features for all pairs
+ img_feat = img_feat[:hard_size+1]
+ img_pos_feat = img_pos_feat[:hard_size+1]
+
+ hard_batch['input_ids'] = input_ids
+ hard_batch['position_ids'] = position_ids
+ hard_batch['img_feat'] = img_feat
+ hard_batch['img_pos_feat'] = img_pos_feat
+ hard_batch['attn_masks'] = attention_mask
+ hard_batch['gather_index'] = gather_index
+
+ return hard_batch
+
+ def _get_hard_batch(self, batch, scores, sample_from='t'):
+ batch = defaultdict(lambda: None, batch)
+ input_ids = batch['input_ids']
+ position_ids = batch['position_ids']
+ img_feat = batch['img_feat']
+ img_pos_feat = batch['img_pos_feat']
+ attention_mask = batch['attn_masks']
+ gather_index = batch['gather_index']
+ hard_batch = {'sample_size': self.hard_size + 1}
+
+ # NOTE first example is positive
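+ # scores holds one ranking score per candidate pair; index 0 is the matched
+ # (positive) pair, so the hard_size highest-scoring negatives are picked from
+ # scores[1:] and the positive is prepended back through `indices`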
+ hard_indices = scores.squeeze(-1)[1:].topk(
+ self.hard_size, sorted=False)[1] + 1
+ indices = torch.cat([torch.zeros(1, dtype=torch.long,
+ device=hard_indices.device),
+ hard_indices])
+
+ attention_mask = attention_mask.index_select(0, indices)
+ gather_index = gather_index.index_select(0, indices)
+ if position_ids.size(0) != 1:
+ position_ids = position_ids[:self.hard_size+1]
+
+ if sample_from == 't':
+ # cut to minimum padding
+ max_len = attention_mask.sum(dim=1).max().item()
+ max_i = max_len - input_ids.size(1)
+ attention_mask = attention_mask[:, :max_len]
+ gather_index = gather_index[:, :max_len]
+ img_feat = img_feat.index_select(0, indices)[:, :max_i, :]
+ img_pos_feat = img_pos_feat.index_select(0, indices)[:, :max_i, :]
+ # expect same input_ids for all pairs
+ input_ids = input_ids[:self.hard_size+1]
+ elif sample_from == 'i':
+ input_ids = input_ids.index_select(0, indices)
+ # expect same image features for all pairs
+ img_feat = img_feat[:self.hard_size+1]
+ img_pos_feat = img_pos_feat[:self.hard_size+1]
+ else:
+ raise ValueError()
+
+ hard_batch['input_ids'] = input_ids
+ hard_batch['position_ids'] = position_ids
+ hard_batch['img_feat'] = img_feat
+ hard_batch['img_pos_feat'] = img_pos_feat
+ hard_batch['attn_masks'] = attention_mask
+ hard_batch['gather_index'] = gather_index
+
+ return hard_batch
diff --git a/model/nsgd2.py b/model/nsgd2.py
new file mode 100644
index 0000000..334a7e2
--- /dev/null
+++ b/model/nsgd2.py
@@ -0,0 +1,456 @@
+"""
+Copyright (c) Microsoft Corporation.
+Licensed under the MIT license.
+
+UNITER for ITM model
+"""
+from collections import defaultdict
+
+import torch
+from torch import nn
+import torch.nn.functional as F
+from .model import UniterPreTrainedModel, UniterModel
+from .layer import GELU, BertOnlyMLMHead
+
+
+def repeat_interleave(x, n_repeat=1, dim=0):
+ repeat_list = [1] * (dim + 1) + [n_repeat] + [1] * (x.dim() - dim - 1)
+ x_size = list(x.size())
+ x_size[dim] = x_size[dim] * n_repeat
+ x = x.unsqueeze(dim+1).repeat(*repeat_list).view(*x_size)
+ return x
+
+
+class UniterForNSGD2(UniterPreTrainedModel):
+ """ Finetune UNITER for image text retrieval
+ """
+ def __init__(self, config, img_dim, margin=0.2, hard_size=16, nsgd_sample_size=16, nsgd_sample_temperature=1.0,
+ disc_weights=(1.0, 0.11)):
+ super().__init__(config)
+ self.uniter = UniterModel(config, img_dim)
+ self.cls = BertOnlyMLMHead(
+ config, self.uniter.embeddings.word_embeddings.weight)
+ self.itm_output = nn.Linear(config.hidden_size, 2)
+ self.rank_output = nn.Linear(config.hidden_size, 1)
+ self.disc_output = nn.Linear(config.hidden_size, 2)
+ self.correction_output = BertOnlyMLMHead(
+ config, self.uniter.embeddings.word_embeddings.weight)
+
+ self.margin = margin
+ self.apply(self.init_weights)
+ self.hard_size = hard_size
+ self.nsgd_sample_size = nsgd_sample_size
+ self.nsgd_sample_temperature = nsgd_sample_temperature
+ self.disc_weights = torch.tensor(disc_weights)
+ print('| disc_weights: ', self.disc_weights)
+ print('-'*100)
+
+ def init_output(self):
+ """ need to be called after from pretrained """
+ self.rank_output.weight.data = self.itm_output.weight.data[1:, :]
+ self.rank_output.bias.data = self.itm_output.bias.data[1:]
+ self.disc_output.weight.data.copy_(torch.tensor(self.itm_output.weight.data.cpu().numpy()))
+ self.disc_output.bias.data.copy_(torch.tensor(self.itm_output.bias.data.cpu().numpy()))
+ for name, param in self.correction_output.named_parameters():
+ cls_param = self.cls.state_dict().get(name)
+ param.data.copy_(torch.tensor(cls_param.data.cpu().numpy()))
+
+ def forward_uniter(
+ self,
+ sample_size=None,
+ input_ids=None, position_ids=None,
+ img_feat=None, img_pos_feat=None,
+ attn_masks=None, gather_index=None,
+ compute_loss=True, sigmoid_norm=False
+ ):
+ model_outputs = {}
+ sequence_output = self.uniter(
+ input_ids, position_ids,
+ img_feat, img_pos_feat,
+ attn_masks, gather_index,
+ output_all_encoded_layers=False
+ )
+ pooled_output = self.uniter.pooler(sequence_output)
+ rank_scores = self.rank_output(pooled_output)
+ model_outputs['rank_scores'] = rank_scores
+ if compute_loss:
+ # triplet loss
+ rank_scores_sigmoid = torch.sigmoid(rank_scores)
+ # sample_size = batch['sample_size']
+ scores = rank_scores_sigmoid.contiguous().view(-1, sample_size)
+ pos = scores[:, :1]
+ neg = scores[:, 1:]
+ rank_loss = torch.clamp(self.margin + neg - pos, 0)
+ rank_corrects = neg.sub(pos).le(0).float()
+ model_outputs['rank_loss'] = rank_loss
+ model_outputs['rank_corrects'] = rank_corrects
+ if sigmoid_norm:
+ rank_scores_sigmoid = torch.sigmoid(rank_scores)
+ model_outputs['rank_scores_sigmoid'] = rank_scores_sigmoid
+ return model_outputs
+
+ def _compute_masked_hidden(self, hidden, mask):
+ """ get only the masked region (don't compute unnecessary hiddens) """
+ mask = mask.unsqueeze(-1).expand_as(hidden)
+ hidden_masked = hidden[mask].contiguous().view(-1, hidden.size(-1))
+ return hidden_masked
+
+ def forward_mlm(
+ self,
+ input_ids=None, position_ids=None,
+ img_feat=None, img_pos_feat=None,
+ attn_masks=None, gather_index=None,
+ txt_labels=None, compute_loss=True, sampling=True
+ ):
+ model_outputs = {}
+ if gather_index.size(0) == 1 and gather_index.size(0) < input_ids.size(0):
+ rep_in_other_dimension_size = [1] * (len(gather_index.size()) - 1)
+ gather_index = gather_index.repeat(input_ids.size(0), *rep_in_other_dimension_size)
+
+ sequence_output = self.uniter(
+ input_ids, position_ids,
+ img_feat, img_pos_feat,
+ attn_masks, gather_index,
+ output_all_encoded_layers=False)
+ # get only the text part
+ sequence_output = sequence_output[:, :input_ids.size(1), :]
+
+ if compute_loss:
+ # only compute masked tokens for better efficiency
+ masked_output = self._compute_masked_hidden(sequence_output, txt_labels != -1)
+ # print('| forward_mlm: ', masked_output.min(), masked_output.max())
+ prediction_scores = self.cls(masked_output)
+ masked_lm_loss = F.cross_entropy(prediction_scores, txt_labels[txt_labels != -1], reduction='none')
+ model_outputs['masked_lm_loss'] = masked_lm_loss
+ model_outputs['mlm_corrects'] = prediction_scores.max(-1)[1].eq(txt_labels[txt_labels != -1]).float()
+ if sampling:
+ bsz, caption_len = input_ids.size(0), input_ids.size(1)
+ prediction_scores = self.cls(sequence_output)
+ sample_caption_tokens = torch.multinomial(
+ prediction_scores.div(self.nsgd_sample_temperature).softmax(-1).view(-1, prediction_scores.size(-1)),
+ num_samples=self.nsgd_sample_size,
+ replacement=True,
+ ).view(bsz, caption_len, self.nsgd_sample_size).permute(0, 2, 1)
+ mask_indicator = txt_labels.ne(-1).long().unsqueeze(1)
+ synthetic_input_ids = input_ids.unsqueeze(1).mul(1-mask_indicator).\
+ add(sample_caption_tokens.mul(mask_indicator)).reshape(-1, caption_len)
+ model_outputs['fill_input_ids'] = synthetic_input_ids
+ return model_outputs
+
+ def forward_gd(
+ self,
+ syn_input_ids=None, syn_position_ids=None,
+ img_feat=None, img_pos_feat=None,
+ attn_masks=None, gather_index=None,
+ txt_labels=None,
+ gt_input_ids=None, compute_loss=True, compute_all_correction_output=False,
+ ):
+ model_outputs = {}
+ if gather_index.size(0) == 1 and gather_index.size(0) < syn_input_ids.size(0):
+ rep_in_other_dimension_size = [1] * (len(gather_index.size()) - 1)
+ gather_index = gather_index.repeat(syn_input_ids.size(0), *rep_in_other_dimension_size)
+
+ sequence_output = self.uniter(
+ syn_input_ids, syn_position_ids,
+ img_feat, img_pos_feat,
+ attn_masks, gather_index,
+ output_all_encoded_layers=False)
+ # get only the text part
+ sequence_output = sequence_output[:, :syn_input_ids.size(1), :]
+
+ # disc loss
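+ # token-level discriminator: label 1 where the synthetic token equals the
+ # ground-truth caption token, 0 where it was replaced; self.disc_weights
+ # down-weights the (far more frequent) unchanged class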
+ disc_p = self.disc_output(sequence_output).softmax(-1)
+ model_outputs['disc_p'] = disc_p
+ if compute_loss:
+ disc_labels = syn_input_ids.eq(gt_input_ids).long()
+ # print('| disc_labels: ', disc_labels[:3])
+ disc_loss = F.nll_loss(
+ disc_p.clamp(1e-10).log().reshape(-1, disc_p.size(-1)),
+ disc_labels.view(-1),
+ reduction='none',
+ weight=self.disc_weights.type_as(disc_p))
+ model_outputs['disc_loss'] = disc_loss
+ model_outputs['disc_pos_corrects'] = disc_p.max(-1)[1].eq(disc_labels).float().\
+ mul(disc_labels.eq(1).float()).sum()
+ model_outputs['disc_pos_samples'] = disc_pos_ntokens = disc_labels.eq(1).sum()
+ model_outputs['disc_neg_corrects'] = disc_p.max(-1)[1].eq(disc_labels).float().\
+ mul(disc_labels.eq(0).float()).sum()
+ model_outputs['disc_neg_samples'] = disc_neg_ntokens = disc_labels.eq(0).sum()
+ model_outputs['disc_corrects'] = disc_p.max(-1)[1].eq(disc_labels).float().sum()
+ model_outputs['disc_samples'] = syn_input_ids.numel()
+
+ if compute_loss:
+ masked_output = self._compute_masked_hidden(sequence_output, txt_labels != -1)
+ logits = self.correction_output(masked_output) # .div(3.0).clamp(-8.0, 8.0)
+ # print('| logits: ', logits.min(), logits.max())
+ overall_p = correction_p = logits.softmax(-1)
+ masked_lm_loss = F.nll_loss(
+ overall_p.clamp(1e-10).log(),
+ txt_labels[txt_labels != -1],
+ reduction='none')
+
+ model_outputs['correction_loss'] = masked_lm_loss
+ model_outputs['correction_corrects'] = correction_corrects = \
+ overall_p.max(-1)[1].eq(txt_labels[txt_labels != -1]).float()
+
+ if compute_all_correction_output:
+ logits = self.correction_output(sequence_output) # .div(3.0).clamp(-8.0, 8.0)
+ model_outputs['correction_texts'] = logits.max(-1)[1]
+ model_outputs['all_correction_corrects'] = all_correction_corrects = \
+ logits.max(-1)[1][..., 1:-1].eq(gt_input_ids[..., 1:-1]).float()
+
+ # print('| correction_corrects: {} | all_correction_corrects: {}'.format(
+ # correction_corrects.size(), all_correction_corrects.size()))
+ # print(correction_corrects.sum().item(), correction_corrects.numel(),
+ # all_correction_corrects.sum().item(), all_correction_corrects.numel())
+ return model_outputs
+
+ def forward(self, batch, sample_from='t', compute_loss=True, compute_mlm=False, compute_all_correction_output=False):
+ # expect same input_ids for all pairs
+ model_outputs = {}
+ if not sample_from.startswith('g'):
+ batch_size = batch['attn_masks'].size(0)
+ input_ids = batch['input_ids']
+ img_feat = batch['img_feat']
+ img_pos_feat = batch['img_pos_feat']
+ if sample_from == 't':
+ if input_ids.size(0) == 1:
+ batch['input_ids'] = input_ids.expand(batch_size, -1)
+ elif sample_from == 'i':
+ if img_feat.size(0) == 1:
+ batch['img_feat'] = img_feat.expand(batch_size, -1, -1)
+ if img_pos_feat.size(0) == 1:
+ batch['img_pos_feat'] = img_pos_feat.expand(batch_size, -1, -1)
+ else:
+ raise ValueError()
+
+ if self.training and compute_loss:
+ with torch.no_grad():
+ self.eval()
+ forward_uniter_outputs = self.forward_uniter(
+ input_ids=batch['input_ids'], position_ids=batch['position_ids'],
+ img_feat=batch['img_feat'], img_pos_feat=batch['img_pos_feat'],
+ attn_masks=batch['attn_masks'], gather_index=batch['gather_index'],
+ compute_loss=False,
+ )
+ hard_batch = self._get_hard_batch(batch, forward_uniter_outputs['rank_scores'], sample_from)
+ self.train()
+ forward_uniter_outputs = self.forward_uniter(
+ sample_size=hard_batch['sample_size'],
+ input_ids=hard_batch['input_ids'], position_ids=hard_batch['position_ids'],
+ img_feat=hard_batch['img_feat'], img_pos_feat=hard_batch['img_pos_feat'],
+ attn_masks=hard_batch['attn_masks'], gather_index=hard_batch['gather_index'],
+ compute_loss=True)
+ model_outputs['rank_loss'] = forward_uniter_outputs['rank_loss']
+ model_outputs['rank_corrects'] = forward_uniter_outputs['rank_corrects']
+ else:
+ forward_uniter_outputs = self.forward_uniter(
+ input_ids=batch['input_ids'], position_ids=batch['position_ids'],
+ img_feat=batch['img_feat'], img_pos_feat=batch['img_pos_feat'],
+ attn_masks=batch['attn_masks'], gather_index=batch['gather_index'],
+ compute_loss=compute_loss)
+ model_outputs.update(forward_uniter_outputs)
+ return model_outputs
+ elif sample_from == 'gsynt':
+ batch_size = batch['syn_attn_masks'].size(0)
+ if batch['gt_input_ids'].size(0) == 1:
+ batch['gt_input_ids'] = batch['gt_input_ids'].expand(batch_size, -1)
+ forward_uniter_outputs = self.forward_gd(
+ syn_input_ids=batch['syn_input_ids'], syn_position_ids=batch['syn_position_ids'],
+ img_feat=batch['syn_img_feat'], img_pos_feat=batch['syn_img_pos_feat'],
+ attn_masks=batch['syn_attn_masks'], gather_index=batch['syn_gather_index'],
+ txt_labels=batch['syn_txt_labels'],
+ gt_input_ids=batch['gt_input_ids'], compute_all_correction_output=compute_all_correction_output)
+ model_outputs.update(forward_uniter_outputs)
+ return model_outputs
+ else:
+ if compute_mlm:
+ self.train()
+ mlm_outputs = self.forward_mlm(
+ input_ids=batch['mlm_input_ids'], position_ids=batch['mlm_position_ids'],
+ img_feat=batch['mlm_img_feat'], img_pos_feat=batch['mlm_img_pos_feat'],
+ attn_masks=batch['mlm_attn_masks'], gather_index=batch['mlm_gather_index'],
+ txt_labels=batch['mlm_txt_labels'], compute_loss=True, sampling=True
+ )
+ model_outputs['masked_lm_loss'] = mlm_outputs['masked_lm_loss']
+ # model_outputs['effect_nsgd_number'] = mlm_outputs['effect_nsgd_number']
+ model_outputs['mlm_corrects'] = mlm_outputs['mlm_corrects']
+ else:
+ with torch.no_grad():
+ self.eval()
+ mlm_outputs = self.forward_mlm(
+ input_ids=batch['mlm_input_ids'], position_ids=batch['mlm_position_ids'],
+ img_feat=batch['mlm_img_feat'], img_pos_feat=batch['mlm_img_pos_feat'],
+ attn_masks=batch['mlm_attn_masks'], gather_index=batch['mlm_gather_index'],
+ txt_labels=batch['mlm_txt_labels'], compute_loss=False, sampling=True
+ )
+ with torch.no_grad():
+ # select_indices = mlm_outputs['select_indices']
+ nsgd_batch = {}
+ nsgd_batch['txt_ids'] = batch['input_ids']
+ nsgd_batch['mlm_sample_size'] = len(batch['mlm_position_ids'])
+ nsgd_batch['nsgd_sample_size'] = self.nsgd_sample_size
+ nsgd_batch['input_dict'] = batch['input_dict']
+ nsgd_batch['input_ids'] = torch.cat(
+ [batch['input_ids'], mlm_outputs['fill_input_ids']], dim=0)
+ nsgd_batch['position_ids'] = torch.cat(
+ [batch['mlm_position_ids'][:1],
+ repeat_interleave(batch['mlm_position_ids'], self.nsgd_sample_size, dim=0)], dim=0)
+ nsgd_batch['img_feat'] = torch.cat(
+ [batch['mlm_img_feat'][:1],
+ repeat_interleave(batch['mlm_img_feat'], self.nsgd_sample_size, dim=0)], dim=0)
+ nsgd_batch['img_pos_feat'] = torch.cat(
+ [batch['mlm_img_pos_feat'][:1],
+ repeat_interleave(batch['mlm_img_pos_feat'], self.nsgd_sample_size, dim=0)], dim=0)
+ nsgd_batch['attn_masks'] = torch.cat(
+ [batch['mlm_attn_masks'][:1],
+ repeat_interleave(batch['mlm_attn_masks'], self.nsgd_sample_size, dim=0)], dim=0)
+ nsgd_batch['gather_index'] = torch.cat(
+ [batch['mlm_gather_index'][:1],
+ repeat_interleave(batch['mlm_gather_index'], self.nsgd_sample_size, dim=0)], dim=0)
+ nsgd_batch['txt_labels'] = torch.cat(
+ [batch['mlm_txt_labels'][:1],
+ repeat_interleave(batch['mlm_txt_labels'], self.nsgd_sample_size, dim=0)], dim=0)
+ self.eval()
+ forward_uniter_outputs = self.forward_uniter(
+ input_ids=nsgd_batch['input_ids'], position_ids=nsgd_batch['position_ids'],
+ img_feat=nsgd_batch['img_feat'], img_pos_feat=nsgd_batch['img_pos_feat'],
+ attn_masks=nsgd_batch['attn_masks'], gather_index=nsgd_batch['gather_index'],
+ compute_loss=False, sigmoid_norm=True
+ )
+ nsgd_batch = self._get_nsgd_batch(
+ nsgd_batch, scores=forward_uniter_outputs['rank_scores_sigmoid'])
+ assert batch['input_ids'].ne(nsgd_batch['input_ids'][0]).long().sum().item() == 0
+ # print('| gt_input_ids: {} | syn_input_ids[:2]: {} | txt_labels[:2]: {}'.
+ # format(batch['input_ids'], nsgd_batch['input_ids'][:3], nsgd_batch['txt_labels'][:3]))
+ model_outputs['syn_input_ids'] = nsgd_batch['input_ids'][1:]
+ model_outputs['syn_position_ids'] = nsgd_batch['position_ids'][1:]
+ model_outputs['syn_img_feat'] = nsgd_batch['img_feat'][1:]
+ model_outputs['syn_img_pos_feat'] = nsgd_batch['img_pos_feat'][1:]
+ model_outputs['syn_attn_masks'] = nsgd_batch['attn_masks'][1:]
+ model_outputs['syn_gather_index'] = nsgd_batch['gather_index'][1:]
+ model_outputs['syn_txt_labels'] = nsgd_batch['txt_labels'][1:]
+ model_outputs['gt_input_ids'] = batch['input_ids'][0].unsqueeze(0)
+ model_outputs['effect_nsgd_number'] = nsgd_batch['effect_num']
+ model_outputs['rank_adv_scores'] = nsgd_batch['rank_scores_sigmoid']
+
+ if compute_loss:
+ self.train()
+ forward_uniter_outputs = self.forward_uniter(
+ sample_size=nsgd_batch['sample_size'],
+ input_ids=nsgd_batch['input_ids'], position_ids=nsgd_batch['position_ids'],
+ img_feat=nsgd_batch['img_feat'], img_pos_feat=nsgd_batch['img_pos_feat'],
+ attn_masks=nsgd_batch['attn_masks'], gather_index=nsgd_batch['gather_index'],
+ compute_loss=True
+ )
+ model_outputs.update(forward_uniter_outputs)
+ else:
+ model_outputs['nsgd_adv_batch'] = nsgd_batch
+ return model_outputs
+
+ def _get_nsgd_batch(self, batch, scores, clean=True):
+ batch = defaultdict(lambda: None, batch)
+ # txt_ids = batch['txt_ids']
+ input_dict = batch['input_dict']
+
+ mlm_sample_size, nsgd_sample_size = batch['mlm_sample_size'], batch['nsgd_sample_size']
+
+ input_ids = batch['input_ids']
+ position_ids = batch['position_ids']
+ img_feat = batch['img_feat']
+ img_pos_feat = batch['img_pos_feat']
+ attention_mask = batch['attn_masks']
+ gather_index = batch['gather_index']
+ txt_labels = batch['txt_labels']
+ hard_batch = {'sample_size': self.hard_size + 1}
+
+ # NOTE first example is positive
+ if clean:
+ # c1_penalty_scores = repeat_interleave(txt_ids, mlm_sample_size * nsgd_sample_size + 1, dim=0).\
+ # ne(input_ids).float().sum(-1).le(0).float().mul(3)
+ c1_penalty_scores = repeat_interleave(input_dict, mlm_sample_size * nsgd_sample_size + 1, dim=0).\
+ ne(input_ids).float().sum(-1).le(0).float().mul(3)
+
+ c2_penalty_scores = input_ids.unsqueeze(1).ne(input_ids.unsqueeze(0)).\
+ float().sum(-1).le(0).float().\
+ triu(diagonal=1).sum(-1).gt(0).float().mul(2)
+ effect_num = c1_penalty_scores.add(c2_penalty_scores).eq(0).long().sum()
+ hard_batch['effect_num'] = effect_num
+ hard_scores = scores.squeeze(-1).sub(c1_penalty_scores.add(c2_penalty_scores).type_as(scores))
+ else:
+ hard_batch['effect_num'] = len(scores)
+ hard_scores = scores.squeeze(-1)
+ hard_indices = hard_scores[1:].topk(self.hard_size, sorted=True)[1] + 1
+ hard_size = len(hard_indices)
+ indices = torch.cat([torch.zeros(1, dtype=torch.long, device=hard_indices.device),
+ hard_indices])
+
+ attention_mask = attention_mask.index_select(0, indices)
+ gather_index = gather_index.index_select(0, indices)
+ txt_labels = txt_labels.index_select(0, indices)
+ if position_ids.size(0) != 1:
+ position_ids = position_ids[:hard_size+1]
+
+ input_ids = input_ids.index_select(0, indices)
+ # expect same image features for all pairs
+ img_feat = img_feat[:hard_size+1]
+ img_pos_feat = img_pos_feat[:hard_size+1]
+
+ hard_batch['input_ids'] = input_ids
+ hard_batch['position_ids'] = position_ids
+ hard_batch['img_feat'] = img_feat
+ hard_batch['img_pos_feat'] = img_pos_feat
+ hard_batch['attn_masks'] = attention_mask
+ hard_batch['gather_index'] = gather_index
+ hard_batch['txt_labels'] = txt_labels
+ hard_batch['rank_scores_sigmoid'] = torch.cat([scores[0], hard_scores.index_select(0, hard_indices)], dim=0)
+ return hard_batch
+
+ def _get_hard_batch(self, batch, scores, sample_from='t'):
+ batch = defaultdict(lambda: None, batch)
+ input_ids = batch['input_ids']
+ position_ids = batch['position_ids']
+ img_feat = batch['img_feat']
+ img_pos_feat = batch['img_pos_feat']
+ attention_mask = batch['attn_masks']
+ gather_index = batch['gather_index']
+ hard_batch = {'sample_size': self.hard_size + 1}
+
+ # NOTE first example is positive
+ hard_indices = scores.squeeze(-1)[1:].topk(
+ self.hard_size, sorted=False)[1] + 1
+ indices = torch.cat([torch.zeros(1, dtype=torch.long,
+ device=hard_indices.device),
+ hard_indices])
+
+ attention_mask = attention_mask.index_select(0, indices)
+ gather_index = gather_index.index_select(0, indices)
+ if position_ids.size(0) != 1:
+ position_ids = position_ids[:self.hard_size+1]
+
+ if sample_from == 't':
+ # cut to minimum padding
+ max_len = attention_mask.sum(dim=1).max().item()
+ max_i = max_len - input_ids.size(1)
+ attention_mask = attention_mask[:, :max_len]
+ gather_index = gather_index[:, :max_len]
+ img_feat = img_feat.index_select(0, indices)[:, :max_i, :]
+ img_pos_feat = img_pos_feat.index_select(0, indices)[:, :max_i, :]
+ # expect same input_ids for all pairs
+ input_ids = input_ids[:self.hard_size+1]
+ elif sample_from == 'i':
+ input_ids = input_ids.index_select(0, indices)
+ # expect same image features for all pairs
+ img_feat = img_feat[:self.hard_size+1]
+ img_pos_feat = img_pos_feat[:self.hard_size+1]
+ else:
+ raise ValueError()
+
+ hard_batch['input_ids'] = input_ids
+ hard_batch['position_ids'] = position_ids
+ hard_batch['img_feat'] = img_feat
+ hard_batch['img_pos_feat'] = img_pos_feat
+ hard_batch['attn_masks'] = attention_mask
+ hard_batch['gather_index'] = gather_index
+
+ return hard_batch
diff --git a/optim/__init__.py b/optim/__init__.py
new file mode 100644
index 0000000..3c21fa9
--- /dev/null
+++ b/optim/__init__.py
@@ -0,0 +1,7 @@
+"""
+Copyright (c) Microsoft Corporation.
+Licensed under the MIT license.
+
+"""
+from .sched import noam_schedule, warmup_linear, vqa_schedule, get_lr_sched
+from .adamw import AdamW
diff --git a/optim/adamw.py b/optim/adamw.py
new file mode 100644
index 0000000..f8e346c
--- /dev/null
+++ b/optim/adamw.py
@@ -0,0 +1,103 @@
+"""
+AdamW optimizer (weight decay fix)
+copied from HuggingFace (https://github.com/huggingface/transformers).
+"""
+import math
+
+import torch
+from torch.optim import Optimizer
+
+
+class AdamW(Optimizer):
+ """ Implements Adam algorithm with weight decay fix.
+ Parameters:
+ lr (float): learning rate. Default 1e-3.
+ betas (tuple of 2 floats): Adam's beta parameters (b1, b2).
+ Default: (0.9, 0.999)
+ eps (float): Adam's epsilon. Default: 1e-6
+ weight_decay (float): Weight decay. Default: 0.0
+ correct_bias (bool): can be set to False to avoid correcting bias
+ in Adam (e.g. like in Bert TF repository). Default True.
+ """
+ def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-6,
+ weight_decay=0.0, correct_bias=True):
+ if lr < 0.0:
+ raise ValueError(
+ "Invalid learning rate: {} - should be >= 0.0".format(lr))
+ if not 0.0 <= betas[0] < 1.0:
+ raise ValueError("Invalid beta parameter: {} - "
+ "should be in [0.0, 1.0[".format(betas[0]))
+ if not 0.0 <= betas[1] < 1.0:
+ raise ValueError("Invalid beta parameter: {} - "
+ "should be in [0.0, 1.0[".format(betas[1]))
+ if not 0.0 <= eps:
+ raise ValueError("Invalid epsilon value: {} - "
+ "should be >= 0.0".format(eps))
+ defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay,
+ correct_bias=correct_bias)
+ super(AdamW, self).__init__(params, defaults)
+
+ def step(self, closure=None):
+ """Performs a single optimization step.
+ Arguments:
+ closure (callable, optional): A closure that reevaluates the model
+ and returns the loss.
+ """
+ loss = None
+ if closure is not None:
+ loss = closure()
+
+ for group in self.param_groups:
+ for p in group['params']:
+ if p.grad is None:
+ continue
+ grad = p.grad.data
+ if grad.is_sparse:
+ raise RuntimeError(
+ 'Adam does not support sparse '
+ 'gradients, please consider SparseAdam instead')
+
+ state = self.state[p]
+
+ # State initialization
+ if len(state) == 0:
+ state['step'] = 0
+ # Exponential moving average of gradient values
+ state['exp_avg'] = torch.zeros_like(p.data)
+ # Exponential moving average of squared gradient values
+ state['exp_avg_sq'] = torch.zeros_like(p.data)
+
+ exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
+ beta1, beta2 = group['betas']
+
+ state['step'] += 1
+
+ # Decay the first and second moment running average coefficient
+ # In-place operations to update the averages at the same time
+ exp_avg.mul_(beta1).add_(1.0 - beta1, grad)
+ exp_avg_sq.mul_(beta2).addcmul_(1.0 - beta2, grad, grad)
+ denom = exp_avg_sq.sqrt().add_(group['eps'])
+
+ step_size = group['lr']
+ if group['correct_bias']: # No bias correction for Bert
+ bias_correction1 = 1.0 - beta1 ** state['step']
+ bias_correction2 = 1.0 - beta2 ** state['step']
+ step_size = (step_size * math.sqrt(bias_correction2)
+ / bias_correction1)
+
+ p.data.addcdiv_(-step_size, exp_avg, denom)
+
+ # Just adding the square of the weights to the loss function is
+ # *not* the correct way of using L2 regularization/weight decay
+ # with Adam, since that will interact with the m and v
+ # parameters in strange ways.
+ #
+ # Instead we want to decay the weights in a manner that doesn't
+ # interact with the m/v parameters. This is equivalent to
+ # adding the square of the weights to the loss with plain
+ # (non-momentum) SGD.
+ # Add weight decay at the end (fixed version)
+ if group['weight_decay'] > 0.0:
+ p.data.add_(-group['lr'] * group['weight_decay'], p.data)
+
+ return loss
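+
+
+# Minimal usage sketch (illustrative only): weight decay is decoupled from the
+# gradient update, i.e. applied directly to the parameters after the Adam step:
+#   optimizer = AdamW(model.parameters(), lr=1e-4, weight_decay=0.01)
+#   loss.backward(); optimizer.step(); optimizer.zero_grad()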
diff --git a/optim/misc.py b/optim/misc.py
new file mode 100644
index 0000000..1368ab4
--- /dev/null
+++ b/optim/misc.py
@@ -0,0 +1,35 @@
+"""
+Copyright (c) Microsoft Corporation.
+Licensed under the MIT license.
+
+Misc lr helper
+"""
+from torch.optim import Adam, Adamax
+
+from .adamw import AdamW
+
+
+def build_optimizer(model, opts):
+ param_optimizer = list(model.named_parameters())
+ no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
+ optimizer_grouped_parameters = [
+ {'params': [p for n, p in param_optimizer
+ if not any(nd in n for nd in no_decay)],
+ 'weight_decay': opts.weight_decay},
+ {'params': [p for n, p in param_optimizer
+ if any(nd in n for nd in no_decay)],
+ 'weight_decay': 0.0}
+ ]
+
+ # Adam-family optimizers only (adam, adamax, adamw)
+ if opts.optim == 'adam':
+ OptimCls = Adam
+ elif opts.optim == 'adamax':
+ OptimCls = Adamax
+ elif opts.optim == 'adamw':
+ OptimCls = AdamW
+ else:
+ raise ValueError('invalid optimizer')
+ optimizer = OptimCls(optimizer_grouped_parameters,
+ lr=opts.learning_rate, betas=opts.betas)
+ return optimizer
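+
+
+# Typical call (sketch): optimizer = build_optimizer(model, opts), where opts
+# provides optim ('adam' | 'adamax' | 'adamw'), learning_rate, betas and
+# weight_decay; biases and LayerNorm parameters get no weight decay above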
diff --git a/optim/sched.py b/optim/sched.py
new file mode 100644
index 0000000..a4c46d6
--- /dev/null
+++ b/optim/sched.py
@@ -0,0 +1,46 @@
+"""
+Copyright (c) Microsoft Corporation.
+Licensed under the MIT license.
+
+optimizer learning rate scheduling helpers
+"""
+from math import ceil
+
+
+def noam_schedule(step, warmup_step=4000):
+ """ original Transformer schedule"""
+ if step <= warmup_step:
+ return step / warmup_step
+ return (warmup_step ** 0.5) * (step ** -0.5)
+
+
+def warmup_linear(step, warmup_step, tot_step):
+ """ BERT schedule """
+ if step < warmup_step:
+ return step / warmup_step
+ return max(0, (tot_step-step)/(tot_step-warmup_step))
+
+
+def vqa_schedule(step, warmup_interval, decay_interval,
+ decay_start, decay_rate):
+ """ VQA schedule from MCAN """
+ if step < warmup_interval:
+ return 1/4
+ elif step < 2 * warmup_interval:
+ return 2/4
+ elif step < 3 * warmup_interval:
+ return 3/4
+ elif step >= decay_start:
+ num_decay = ceil((step - decay_start) / decay_interval)
+ return decay_rate ** num_decay
+ else:
+ return 1
+
+
+def get_lr_sched(global_step, opts):
+ # learning rate scheduling
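+ # e.g. with learning_rate=1e-5, warmup_steps=150, num_train_steps=1500 (the
+ # values used in run_cmds/): step 75 -> 5e-6, step 150 -> 1e-5, step 825 -> 5e-6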
+ lr_this_step = opts.learning_rate * warmup_linear(
+ global_step, opts.warmup_steps, opts.num_train_steps)
+ if lr_this_step <= 0:
+ lr_this_step = 1e-8
+ return lr_this_step
diff --git a/prepro.py b/prepro.py
new file mode 100644
index 0000000..39236cf
--- /dev/null
+++ b/prepro.py
@@ -0,0 +1,101 @@
+"""
+Copyright (c) Microsoft Corporation.
+Licensed under the MIT license.
+
+preprocess NLVR2 annotations into LMDB
+"""
+import argparse
+import json
+import os
+from os.path import exists
+
+from cytoolz import curry
+from tqdm import tqdm
+from pytorch_pretrained_bert import BertTokenizer
+
+from data.data import open_lmdb
+
+
+@curry
+def bert_tokenize(tokenizer, text):
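+ # Illustrative example (assuming an uncased tokenizer): "A dog runs ." maps to
+ # the WordPiece ids of ['a', 'dog', 'runs', '.']; no [CLS]/[SEP] ids are added here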
+ ids = []
+ for word in text.strip().split():
+ ws = tokenizer.tokenize(word)
+ if not ws:
+ # some special char
+ continue
+ ids.extend(tokenizer.convert_tokens_to_ids(ws))
+ return ids
+
+
+def process_nlvr2(jsonl, db, tokenizer, missing=None):
+ id2len = {}
+ txt2img = {} # not sure if useful
+ for line in tqdm(jsonl, desc='processing NLVR2'):
+ example = json.loads(line)
+ id_ = example['identifier']
+ img_id = '-'.join(id_.split('-')[:-1])
+ img_fname = (f'nlvr2_{img_id}-img0.npz', f'nlvr2_{img_id}-img1.npz')
+ if missing and (img_fname[0] in missing or img_fname[1] in missing):
+ continue
+ input_ids = tokenizer(example['sentence'])
+ if 'label' in example:
+ target = 1 if example['label'] == 'True' else 0
+ else:
+ target = None
+ txt2img[id_] = img_fname
+ id2len[id_] = len(input_ids)
+ example['input_ids'] = input_ids
+ example['img_fname'] = img_fname
+ example['target'] = target
+ db[id_] = example
+ return id2len, txt2img
+
+
+def main(opts):
+ if not exists(opts.output):
+ os.makedirs(opts.output)
+ else:
+ raise ValueError('Found existing DB. Please explicitly remove '
+ 'for re-processing')
+ meta = vars(opts)
+ meta['tokenizer'] = opts.toker
+ toker = BertTokenizer.from_pretrained(
+ opts.toker, do_lower_case='uncased' in opts.toker)
+ tokenizer = bert_tokenize(toker)
+ meta['UNK'] = toker.convert_tokens_to_ids(['[UNK]'])[0]
+ meta['CLS'] = toker.convert_tokens_to_ids(['[CLS]'])[0]
+ meta['SEP'] = toker.convert_tokens_to_ids(['[SEP]'])[0]
+ meta['MASK'] = toker.convert_tokens_to_ids(['[MASK]'])[0]
+ meta['v_range'] = (toker.convert_tokens_to_ids('!')[0],
+ len(toker.vocab))
+ with open(f'{opts.output}/meta.json', 'w') as f:
+ json.dump(vars(opts), f, indent=4)
+
+ open_db = curry(open_lmdb, opts.output, readonly=False)
+ with open_db() as db:
+ with open(opts.annotation) as ann:
+ if opts.missing_imgs is not None:
+ missing_imgs = set(json.load(open(opts.missing_imgs)))
+ else:
+ missing_imgs = None
+ id2lens, txt2img = process_nlvr2(ann, db, tokenizer, missing_imgs)
+
+ with open(f'{opts.output}/id2len.json', 'w') as f:
+ json.dump(id2lens, f)
+ with open(f'{opts.output}/txt2img.json', 'w') as f:
+ json.dump(txt2img, f)
+
+
+if __name__ == '__main__':
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--annotation', required=True,
+ help='annotation JSON')
+ parser.add_argument('--missing_imgs',
+ help='some training image features are corrupted')
+ parser.add_argument('--output', required=True,
+ help='output dir of DB')
+ parser.add_argument('--toker', default='bert-base-cased',
+ help='which BERT tokenizer to use')
+ args = parser.parse_args()
+ main(args)
diff --git a/run_cmds/inf_nsgd.sh b/run_cmds/inf_nsgd.sh
new file mode 100644
index 0000000..972dacb
--- /dev/null
+++ b/run_cmds/inf_nsgd.sh
@@ -0,0 +1,15 @@
+#!/usr/bin/env bash
+
+export NGPU=8
+export DATA_ROOT=./UNITER/itm-data
+
+export model_version=1
+export arch=2000
+export MODEL_ROOT=/mnt/Projects/UNITER/log/itm_nsgd/uniter_nsgd_base_v${model_version}/ckpt
+export OUTPUT_DIR=./model_log/uniter_nsgd_base_v${model_version}/step${arch}
+mkdir -p ${OUTPUT_DIR}
+
+horovodrun -np $NGPU python inf_nsgd.py \
+ --txt_db ${DATA_ROOT}/txt_db3/itm_flickr30k_test.db --img_db ${DATA_ROOT}/img_db/flickr30k \
+ --checkpoint ${MODEL_ROOT}/model_step_${arch}.pt --model_config ./config/uniter-base.json \
+ --output_dir ${OUTPUT_DIR} --fp16 --pin_mem
diff --git a/run_cmds/train_pnsgd2_base_coco.sh b/run_cmds/train_pnsgd2_base_coco.sh
new file mode 100644
index 0000000..202cb3b
--- /dev/null
+++ b/run_cmds/train_pnsgd2_base_coco.sh
@@ -0,0 +1,38 @@
+#!/usr/bin/env bash
+
+pwd
+
+export ROOT_DIR=.
+
+export NGPU=8
+export LR=1e-5
+export NSGD_SAMPLE_TEMPERATURE=1.5
+export NSGD_RANK_LAMBDA=0.0001
+export MLM_LAMBDA=0.05
+export DISC_LAMBDA=5e-4 # 05
+export CORRECTION_LAMBDA=5e-4
+export OUTPUT_DIR=${ROOT_DIR}/log/coco/itm_pnsgd2_base_v1
+export CHECKPOINT=${ROOT_DIR}/log/coco/itm_pnsgd_base_v1/ckpt/model_step_2000.pt
+
+export NEGATIVE_SIZE=399
+export HARD_NEG_SIZE=31
+export MLM_SAMPLE_SIZE=20
+export NSGD_SAMPLE_SIZE=20
+export TRAIN_BATCH_SIZE=32
+export MARGIN=0.2
+
+export VALID_STEPS=150
+export NUM_TRAIN_STEPS=1500
+export WARMUP_STEPS=150
+
+rm ${OUTPUT_DIR} -rf
+ls -lh ${OUTPUT_DIR}
+mkdir -p ${OUTPUT_DIR}
+
+horovodrun -np ${NGPU} python train_pnsgd2.py --config config/train-itm-pnsgd2-base-coco.json \
+ --output_dir ${OUTPUT_DIR} --learning_rate ${LR} --negative_size ${NEGATIVE_SIZE} \
+ --hard_neg_size ${HARD_NEG_SIZE} --mlm_sample_size ${MLM_SAMPLE_SIZE} --nsgd_sample_size ${NSGD_SAMPLE_SIZE} \
+ --nsgd_sample_temperature ${NSGD_SAMPLE_TEMPERATURE} --train_batch_size ${TRAIN_BATCH_SIZE} \
+ --mlm_lambda ${MLM_LAMBDA} --nsgd_rank_lambda ${NSGD_RANK_LAMBDA} --margin ${MARGIN} \
+ --disc_lambda ${DISC_LAMBDA} --correction_lambda ${CORRECTION_LAMBDA} --checkpoint ${CHECKPOINT} \
+| tee -a ${OUTPUT_DIR}/train_log.txt
diff --git a/run_cmds/train_pnsgd2_base_flickr.sh b/run_cmds/train_pnsgd2_base_flickr.sh
new file mode 100644
index 0000000..d5deb59
--- /dev/null
+++ b/run_cmds/train_pnsgd2_base_flickr.sh
@@ -0,0 +1,38 @@
+#!/usr/bin/env bash
+
+pwd
+
+export ROOT_DIR=.
+
+export NGPU=8
+export LR=1e-5
+export NSGD_SAMPLE_TEMPERATURE=1.5
+export NSGD_RANK_LAMBDA=5e-4
+export MLM_LAMBDA=0.1
+export DISC_LAMBDA=5e-4
+export CORRECTION_LAMBDA=5e-4
+export OUTPUT_DIR=${ROOT_DIR}/log/flickr/itm_pnsgd2_base_v1
+export CHECKPOINT=${ROOT_DIR}/log/flickr/itm_pnsgd_base_v1/ckpt/model_step_2000.pt
+
+export NEGATIVE_SIZE=399
+export HARD_NEG_SIZE=31
+export MLM_SAMPLE_SIZE=20
+export NSGD_SAMPLE_SIZE=20
+export TRAIN_BATCH_SIZE=32
+export MARGIN=0.2
+
+export VALID_STEPS=150
+export NUM_TRAIN_STEPS=1500
+export WARMUP_STEPS=150
+
+rm ${OUTPUT_DIR} -rf
+ls -lh ${OUTPUT_DIR}
+mkdir -p ${OUTPUT_DIR}
+
+horovodrun -np ${NGPU} python train_pnsgd2.py --config config/train-itm-pnsgd2-base-flickr.json \
+ --output_dir ${OUTPUT_DIR} --learning_rate ${LR} --negative_size ${NEGATIVE_SIZE} \
+ --hard_neg_size ${HARD_NEG_SIZE} --mlm_sample_size ${MLM_SAMPLE_SIZE} --nsgd_sample_size ${NSGD_SAMPLE_SIZE} \
+ --nsgd_sample_temperature ${NSGD_SAMPLE_TEMPERATURE} --train_batch_size ${TRAIN_BATCH_SIZE} \
+ --mlm_lambda ${MLM_LAMBDA} --nsgd_rank_lambda ${NSGD_RANK_LAMBDA} --margin ${MARGIN} \
+ --disc_lambda ${DISC_LAMBDA} --correction_lambda ${CORRECTION_LAMBDA} --checkpoint ${CHECKPOINT} \
+| tee -a ${OUTPUT_DIR}/train_log.txt
diff --git a/run_cmds/train_pnsgd2_large_coco.sh b/run_cmds/train_pnsgd2_large_coco.sh
new file mode 100644
index 0000000..511e46b
--- /dev/null
+++ b/run_cmds/train_pnsgd2_large_coco.sh
@@ -0,0 +1,38 @@
+#!/usr/bin/env bash
+
+pwd
+
+export ROOT_DIR=.
+
+export NGPU=8
+export LR=1e-5
+export NSGD_SAMPLE_TEMPERATURE=1.5
+export NSGD_RANK_LAMBDA=0.0001
+export MLM_LAMBDA=0.05
+export DISC_LAMBDA=5e-4 # 05
+export CORRECTION_LAMBDA=5e-4
+export OUTPUT_DIR=${ROOT_DIR}/log/coco/itm_pnsgd2_large_v1
+export CHECKPOINT=${ROOT_DIR}/log/coco/itm_pnsgd_large_v1/ckpt/model_step_2000.pt
+
+export NEGATIVE_SIZE=399
+export HARD_NEG_SIZE=31
+export MLM_SAMPLE_SIZE=20
+export NSGD_SAMPLE_SIZE=20
+export TRAIN_BATCH_SIZE=32
+export MARGIN=0.2
+
+export VALID_STEPS=150
+export NUM_TRAIN_STEPS=1500
+export WARMUP_STEPS=150
+
+rm ${OUTPUT_DIR} -rf
+ls -lh ${OUTPUT_DIR}
+mkdir -p ${OUTPUT_DIR}
+
+horovodrun -np ${NGPU} python train_pnsgd2.py --config config/train-itm-pnsgd2-large-coco.json \
+ --output_dir ${OUTPUT_DIR} --learning_rate ${LR} --negative_size ${NEGATIVE_SIZE} \
+ --hard_neg_size ${HARD_NEG_SIZE} --mlm_sample_size ${MLM_SAMPLE_SIZE} --nsgd_sample_size ${NSGD_SAMPLE_SIZE} \
+ --nsgd_sample_temperature ${NSGD_SAMPLE_TEMPERATURE} --train_batch_size ${TRAIN_BATCH_SIZE} \
+ --mlm_lambda ${MLM_LAMBDA} --nsgd_rank_lambda ${NSGD_RANK_LAMBDA} --margin ${MARGIN} \
+ --disc_lambda ${DISC_LAMBDA} --correction_lambda ${CORRECTION_LAMBDA} --checkpoint ${CHECKPOINT} \
+| tee -a ${OUTPUT_DIR}/train_log.txt
diff --git a/run_cmds/train_pnsgd2_large_flickr.sh b/run_cmds/train_pnsgd2_large_flickr.sh
new file mode 100644
index 0000000..41451fe
--- /dev/null
+++ b/run_cmds/train_pnsgd2_large_flickr.sh
@@ -0,0 +1,39 @@
+#!/usr/bin/env bash
+
+pwd
+
+export ROOT_DIR=.
+
+export NGPU=8
+export LR=1e-5
+export NSGD_SAMPLE_TEMPERATURE=1.5
+export NSGD_RANK_LAMBDA=5e-4
+export MLM_LAMBDA=0.05
+export DISC_LAMBDA=5e-4
+export CORRECTION_LAMBDA=5e-4
+
+export OUTPUT_DIR=${ROOT_DIR}/log/flickr/itm_pnsgd2_large_v1
+export CHECKPOINT=${ROOT_DIR}/log/flickr/itm_pnsgd_large_v1/ckpt/model_step_2000.pt
+
+export NEGATIVE_SIZE=399
+export HARD_NEG_SIZE=23
+export MLM_SAMPLE_SIZE=20
+export NSGD_SAMPLE_SIZE=20
+export TRAIN_BATCH_SIZE=32
+export MARGIN=0.2
+
+export VALID_STEPS=500
+export NUM_TRAIN_STEPS=5000
+export WARMUP_STEPS=500
+
+rm ${OUTPUT_DIR} -rf
+ls -lh ${OUTPUT_DIR}
+mkdir -p ${OUTPUT_DIR}
+
+horovodrun -np ${NGPU} python train_pnsgd2.py --config config/train-itm-pnsgd2-large-flickr.json \
+ --output_dir ${OUTPUT_DIR} --learning_rate ${LR} --negative_size ${NEGATIVE_SIZE} \
+ --hard_neg_size ${HARD_NEG_SIZE} --mlm_sample_size ${MLM_SAMPLE_SIZE} --nsgd_sample_size ${NSGD_SAMPLE_SIZE} \
+ --nsgd_sample_temperature ${NSGD_SAMPLE_TEMPERATURE} --train_batch_size ${TRAIN_BATCH_SIZE} \
+ --mlm_lambda ${MLM_LAMBDA} --nsgd_rank_lambda ${NSGD_RANK_LAMBDA} --margin ${MARGIN} \
+ --disc_lambda ${DISC_LAMBDA} --correction_lambda ${CORRECTION_LAMBDA} --checkpoint ${CHECKPOINT} \
+| tee -a ${OUTPUT_DIR}/train_log.txt
diff --git a/run_cmds/train_pnsgd_base_coco.sh b/run_cmds/train_pnsgd_base_coco.sh
new file mode 100644
index 0000000..82dd7b5
--- /dev/null
+++ b/run_cmds/train_pnsgd_base_coco.sh
@@ -0,0 +1,34 @@
+#!/usr/bin/env bash
+
+pwd
+
+export ROOT_DIR=.
+
+export NGPU=8
+export LR=5e-5
+export NSGD_SAMPLE_TEMPERATURE=1.5
+export NSGD_RANK_LAMBDA=5e-5
+export MLM_LAMBDA=0.01
+export OUTPUT_DIR=${ROOT_DIR}/log/coco/itm_pnsgd_base_v1
+
+export NEGATIVE_SIZE=399
+export HARD_NEG_SIZE=31
+export MLM_SAMPLE_SIZE=20
+export NSGD_SAMPLE_SIZE=20
+export TRAIN_BATCH_SIZE=32
+export MARGIN=0.2
+
+export VALID_STEPS=500
+export NUM_TRAIN_STEPS=5000
+export WARMUP_STEPS=500
+
+rm ${OUTPUT_DIR} -rf
+ls -lh ${OUTPUT_DIR}
+mkdir -p ${OUTPUT_DIR}
+
+horovodrun -np ${NGPU} python train_pnsgd.py --config config/train-itm-pnsgd-base-coco.json \
+ --output_dir ${OUTPUT_DIR} --learning_rate ${LR} --negative_size ${NEGATIVE_SIZE} \
+ --hard_neg_size ${HARD_NEG_SIZE} --mlm_sample_size ${MLM_SAMPLE_SIZE} --nsgd_sample_size ${NSGD_SAMPLE_SIZE} \
+ --nsgd_sample_temperature ${NSGD_SAMPLE_TEMPERATURE} --train_batch_size ${TRAIN_BATCH_SIZE} \
+ --mlm_lambda ${MLM_LAMBDA} --nsgd_rank_lambda ${NSGD_RANK_LAMBDA} --margin ${MARGIN} \
+| tee -a ${OUTPUT_DIR}/train_log.txt
diff --git a/run_cmds/train_pnsgd_base_flickr.sh b/run_cmds/train_pnsgd_base_flickr.sh
new file mode 100644
index 0000000..77faa17
--- /dev/null
+++ b/run_cmds/train_pnsgd_base_flickr.sh
@@ -0,0 +1,34 @@
+#!/usr/bin/env bash
+
+pwd
+
+export ROOT_DIR=.
+
+export NGPU=8
+export LR=5e-5
+export NSGD_SAMPLE_TEMPERATURE=2.0
+export NSGD_RANK_LAMBDA=0.0005
+export MLM_LAMBDA=0.05
+export OUTPUT_DIR=${ROOT_DIR}/log/flickr/itm_pnsgd_base_v1
+
+export NEGATIVE_SIZE=399
+export HARD_NEG_SIZE=31
+export MLM_SAMPLE_SIZE=20
+export NSGD_SAMPLE_SIZE=20
+export TRAIN_BATCH_SIZE=16
+export MARGIN=0.2
+
+export VALID_STEPS=500
+export NUM_TRAIN_STEPS=5000
+export WARMUP_STEPS=500
+
+rm ${OUTPUT_DIR} -rf
+ls -lh ${OUTPUT_DIR}
+mkdir -p ${OUTPUT_DIR}
+
+horovodrun -np ${NGPU} python train_pnsgd.py --config config/train-itm-pnsgd-base-flickr.json \
+ --output_dir ${OUTPUT_DIR} --learning_rate ${LR} --negative_size ${NEGATIVE_SIZE} \
+ --hard_neg_size ${HARD_NEG_SIZE} --mlm_sample_size ${MLM_SAMPLE_SIZE} --nsgd_sample_size ${NSGD_SAMPLE_SIZE} \
+ --nsgd_sample_temperature ${NSGD_SAMPLE_TEMPERATURE} --train_batch_size ${TRAIN_BATCH_SIZE} \
+ --mlm_lambda ${MLM_LAMBDA} --nsgd_rank_lambda ${NSGD_RANK_LAMBDA} --margin ${MARGIN} \
+| tee -a ${OUTPUT_DIR}/train_log.txt
diff --git a/run_cmds/train_pnsgd_large_coco.sh b/run_cmds/train_pnsgd_large_coco.sh
new file mode 100644
index 0000000..db9b4ba
--- /dev/null
+++ b/run_cmds/train_pnsgd_large_coco.sh
@@ -0,0 +1,35 @@
+#!/usr/bin/env bash
+
+pwd
+
+export ROOT_DIR=.
+
+export NGPU=8
+export LR=4e-5
+export NSGD_SAMPLE_TEMPERATURE=1.5
+export NSGD_RANK_LAMBDA=1e-3
+export MLM_LAMBDA=0.05
+export OUTPUT_DIR=${ROOT_DIR}/log/coco/itm_pnsgd_large_v1
+
+export NEGATIVE_SIZE=399
+export HARD_NEG_SIZE=23
+export MLM_SAMPLE_SIZE=20
+export NSGD_SAMPLE_SIZE=20
+export TRAIN_BATCH_SIZE=32
+export MARGIN=0.2
+
+
+export VALID_STEPS=500
+export NUM_TRAIN_STEPS=5000
+export WARMUP_STEPS=500
+
+rm ${OUTPUT_DIR} -rf
+ls -lh ${OUTPUT_DIR}
+mkdir -p ${OUTPUT_DIR}
+
+horovodrun -np ${NGPU} python train_pnsgd.py --config config/train-itm-pnsgd-large-coco.json \
+ --output_dir ${OUTPUT_DIR} --learning_rate ${LR} --negative_size ${NEGATIVE_SIZE} \
+ --hard_neg_size ${HARD_NEG_SIZE} --mlm_sample_size ${MLM_SAMPLE_SIZE} --nsgd_sample_size ${NSGD_SAMPLE_SIZE} \
+ --nsgd_sample_temperature ${NSGD_SAMPLE_TEMPERATURE} --train_batch_size ${TRAIN_BATCH_SIZE} \
+ --mlm_lambda ${MLM_LAMBDA} --nsgd_rank_lambda ${NSGD_RANK_LAMBDA} --margin ${MARGIN} \
+| tee -a ${OUTPUT_DIR}/train_log.txt
diff --git a/run_cmds/train_pnsgd_large_flickr.sh b/run_cmds/train_pnsgd_large_flickr.sh
new file mode 100644
index 0000000..7c913a9
--- /dev/null
+++ b/run_cmds/train_pnsgd_large_flickr.sh
@@ -0,0 +1,34 @@
+#!/usr/bin/env bash
+
+pwd
+
+export ROOT_DIR=.
+
+export NGPU=8
+export LR=5e-5
+export NSGD_SAMPLE_TEMPERATURE=2.0
+export NSGD_RANK_LAMBDA=0.001
+export MLM_LAMBDA=0.05
+export OUTPUT_DIR=${ROOT_DIR}/log/flickr/itm_pnsgd_large_v1
+
+export NEGATIVE_SIZE=399
+export HARD_NEG_SIZE=23
+export MLM_SAMPLE_SIZE=20
+export NSGD_SAMPLE_SIZE=20
+export TRAIN_BATCH_SIZE=32
+export MARGIN=0.2
+
+export VALID_STEPS=500
+export NUM_TRAIN_STEPS=5000
+export WARMUP_STEPS=500
+
+rm ${OUTPUT_DIR} -rf
+ls -lh ${OUTPUT_DIR}
+mkdir -p ${OUTPUT_DIR}
+
+horovodrun -np ${NGPU} python train_pnsgd.py --config config/train-itm-pnsgd-large-flickr.json \
+ --output_dir ${OUTPUT_DIR} --learning_rate ${LR} --negative_size ${NEGATIVE_SIZE} \
+ --hard_neg_size ${HARD_NEG_SIZE} --mlm_sample_size ${MLM_SAMPLE_SIZE} --nsgd_sample_size ${NSGD_SAMPLE_SIZE} \
+ --nsgd_sample_temperature ${NSGD_SAMPLE_TEMPERATURE} --train_batch_size ${TRAIN_BATCH_SIZE} \
+ --mlm_lambda ${MLM_LAMBDA} --nsgd_rank_lambda ${NSGD_RANK_LAMBDA} --margin ${MARGIN} \
+| tee -a ${OUTPUT_DIR}/train_log.txt
diff --git a/scripts/convert_ckpt.py b/scripts/convert_ckpt.py
new file mode 100644
index 0000000..9bb8a4e
--- /dev/null
+++ b/scripts/convert_ckpt.py
@@ -0,0 +1,13 @@
+import sys
+from collections import OrderedDict
+
+import torch
+
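+# Usage (assumed): python scripts/convert_ckpt.py <bert_ckpt.pt> <output_ckpt.pt>
+# Renames checkpoint keys from 'bert*' to 'uniter*' so a stock BERT
+# checkpoint can be used to initialize UNITER.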
+bert_ckpt, output_ckpt = sys.argv[1:]
+
+bert = torch.load(bert_ckpt)
+uniter = OrderedDict()
+for k, v in bert.items():
+ uniter[k.replace('bert', 'uniter')] = v
+
+torch.save(uniter, output_ckpt)
diff --git a/scripts/convert_imgdir.py b/scripts/convert_imgdir.py
new file mode 100644
index 0000000..ca4440c
--- /dev/null
+++ b/scripts/convert_imgdir.py
@@ -0,0 +1,142 @@
+"""
+Copyright (c) Microsoft Corporation.
+Licensed under the MIT license.
+
+convert image npz to LMDB
+"""
+import argparse
+import glob
+import io
+import json
+import multiprocessing as mp
+import os
+from os.path import basename, exists
+
+from cytoolz import curry
+import numpy as np
+from tqdm import tqdm
+import lmdb
+
+import msgpack
+import msgpack_numpy
+msgpack_numpy.patch()
+
+
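+# Number of boxes to keep for an image: detections above conf_th,
+# clamped to the range [min_bb, max_bb].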
+def _compute_nbb(img_dump, conf_th, max_bb, min_bb, num_bb):
+ num_bb = max(min_bb, (img_dump['conf'] > conf_th).sum())
+ num_bb = min(max_bb, num_bb)
+ return int(num_bb)
+
+
+@curry
+def load_npz(conf_th, max_bb, min_bb, num_bb, fname, keep_all=False):
+ try:
+ img_dump = np.load(fname, allow_pickle=True)
+ if keep_all:
+ nbb = None
+ else:
+ nbb = _compute_nbb(img_dump, conf_th, max_bb, min_bb, num_bb)
+ dump = {}
+ for key, arr in img_dump.items():
+ if arr.dtype == np.float32:
+ arr = arr.astype(np.float16)
+ if arr.ndim == 2:
+ dump[key] = arr[:nbb, :]
+ elif arr.ndim == 1:
+ dump[key] = arr[:nbb]
+ else:
+ raise ValueError('wrong ndim')
+ except Exception as e:
+ # corrupted file
+ print(f'corrupted file {fname}', e)
+ dump = {}
+ nbb = 0
+
+ name = basename(fname)
+ return name, dump, nbb
+
+
+def dumps_npz(dump, compress=False):
+ with io.BytesIO() as writer:
+ if compress:
+ np.savez_compressed(writer, **dump, allow_pickle=True)
+ else:
+ np.savez(writer, **dump, allow_pickle=True)
+ return writer.getvalue()
+
+
+def dumps_msgpack(dump):
+ return msgpack.dumps(dump, use_bin_type=True)
+
+
+def main(opts):
+ if opts.img_dir[-1] == '/':
+ opts.img_dir = opts.img_dir[:-1]
+ split = basename(opts.img_dir)
+ if opts.keep_all:
+ db_name = 'all'
+ else:
+ if opts.conf_th == -1:
+ db_name = f'feat_numbb{opts.num_bb}'
+ else:
+ db_name = (f'feat_th{opts.conf_th}_max{opts.max_bb}'
+ f'_min{opts.min_bb}')
+ if opts.compress:
+ db_name += '_compressed'
+ if not exists(f'{opts.output}/{split}'):
+ os.makedirs(f'{opts.output}/{split}')
+ env = lmdb.open(f'{opts.output}/{split}/{db_name}', map_size=1024**4)
+ txn = env.begin(write=True)
+ files = glob.glob(f'{opts.img_dir}/*.npz')
+ load = load_npz(opts.conf_th, opts.max_bb, opts.min_bb, opts.num_bb,
+ keep_all=opts.keep_all)
+ name2nbb = {}
+ with mp.Pool(opts.nproc) as pool, tqdm(total=len(files)) as pbar:
+ for i, (fname, features, nbb) in enumerate(
+ pool.imap_unordered(load, files, chunksize=128)):
+ if not features:
+ continue # corrupted feature
+ if opts.compress:
+ dump = dumps_npz(features, compress=True)
+ else:
+ dump = dumps_msgpack(features)
+ txn.put(key=fname.encode('utf-8'), value=dump)
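+ # commit periodically so the write transaction stays small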
+ if i % 1000 == 0:
+ txn.commit()
+ txn = env.begin(write=True)
+ name2nbb[fname] = nbb
+ pbar.update(1)
+ txn.put(key=b'__keys__',
+ value=json.dumps(list(name2nbb.keys())).encode('utf-8'))
+ txn.commit()
+ env.close()
+ if opts.conf_th != -1 and not opts.keep_all:
+ with open(f'{opts.output}/{split}/'
+ f'nbb_th{opts.conf_th}_'
+ f'max{opts.max_bb}_min{opts.min_bb}.json', 'w') as f:
+ json.dump(name2nbb, f)
+
+
+if __name__ == '__main__':
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--img_dir", default=None, type=str,
+ help="The input images.")
+ parser.add_argument("--output", default=None, type=str,
+ help="output lmdb")
+ parser.add_argument('--nproc', type=int, default=8,
+ help='number of cores used')
+ parser.add_argument('--compress', action='store_true',
+ help='compress the tensors')
+ parser.add_argument('--keep_all', action='store_true',
+ help='keep all features, overrides all following args')
+ parser.add_argument('--conf_th', type=float, default=0.2,
+ help='threshold for dynamic bounding boxes '
+ '(-1 for fixed)')
+ parser.add_argument('--max_bb', type=int, default=100,
+ help='max number of bounding boxes')
+ parser.add_argument('--min_bb', type=int, default=10,
+ help='min number of bounding boxes')
+ parser.add_argument('--num_bb', type=int, default=100,
+ help='number of bounding boxes (fixed)')
+ args = parser.parse_args()
+ main(args)
diff --git a/scripts/create_imgdb.sh b/scripts/create_imgdb.sh
new file mode 100644
index 0000000..21264bf
--- /dev/null
+++ b/scripts/create_imgdb.sh
@@ -0,0 +1,21 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT license.
+
+IMG_NPY=$1
+OUT_DIR=$2
+
+set -e
+
+echo "converting image features ..."
+if [ ! -d $OUT_DIR ]; then
+ mkdir -p $OUT_DIR
+fi
+NAME=$(basename $IMG_NPY)
+docker run --ipc=host --rm -it \
+ --mount src=$(pwd),dst=/src,type=bind \
+ --mount src=$OUT_DIR,dst=/img_db,type=bind \
+ --mount src=$IMG_NPY,dst=/$NAME,type=bind,readonly \
+ -w /src chenrocks/uniter \
+ python scripts/convert_imgdir.py --img_dir /$NAME --output /img_db
+
+echo "done"
diff --git a/scripts/create_txtdb.sh b/scripts/create_txtdb.sh
new file mode 100644
index 0000000..6789ef3
--- /dev/null
+++ b/scripts/create_txtdb.sh
@@ -0,0 +1,40 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT license.
+
+OUT_DIR=$1
+ANN_DIR=$2
+
+set -e
+
+URL='https://raw.githubusercontent.com/lil-lab/nlvr/master/nlvr2/data'
+if [ ! -d $OUT_DIR ]; then
+ mkdir -p $OUT_DIR
+fi
+if [ ! -d $ANN_DIR ]; then
+ mkdir -p $ANN_DIR
+fi
+
+BLOB='https://convaisharables.blob.core.windows.net/uniter'
+MISSING=$BLOB/ann/missing_nlvr2_imgs.json
+if [ ! -f $ANN_DIR/missing_nlvr2_imgs.json ]; then
+ wget $MISSING -O $ANN_DIR/missing_nlvr2_imgs.json
+fi
+
+for SPLIT in 'train' 'dev' 'test1'; do
+ if [ ! -f $ANN_DIR/$SPLIT.json ]; then
+ echo "downloading ${SPLIT} annotations..."
+ wget $URL/$SPLIT.json -O $ANN_DIR/$SPLIT.json
+ fi
+
+ echo "preprocessing ${SPLIT} annotations..."
+ docker run --ipc=host --rm -it \
+ --mount src=$(pwd),dst=/src,type=bind \
+ --mount src=$OUT_DIR,dst=/txt_db,type=bind \
+ --mount src=$ANN_DIR,dst=/ann,type=bind,readonly \
+ -w /src chenrocks/uniter \
+ python prepro.py --annotation /ann/$SPLIT.json \
+ --missing_imgs /ann/missing_nlvr2_imgs.json \
+ --output /txt_db/nlvr2_${SPLIT}.db
+done
+
+echo "done"
diff --git a/scripts/download_itm.sh b/scripts/download_itm.sh
new file mode 100644
index 0000000..77ac0ea
--- /dev/null
+++ b/scripts/download_itm.sh
@@ -0,0 +1,39 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT license.
+
+DOWNLOAD=$1
+
+for FOLDER in 'img_db' 'txt_db' 'pretrained' 'finetune'; do
+ if [ ! -d $DOWNLOAD/$FOLDER ] ; then
+ mkdir -p $DOWNLOAD/$FOLDER
+ fi
+done
+
+BLOB='https://convaisharables.blob.core.windows.net/uniter'
+
+# image dbs
+for SPLIT in 'train2014' 'val2014'; do
+ if [ ! -d $DOWNLOAD/img_db/coco_$SPLIT ] ; then
+ wget $BLOB/img_db/coco_$SPLIT.tar -P $DOWNLOAD/img_db/
+ tar -xvf $DOWNLOAD/img_db/coco_$SPLIT.tar -C $DOWNLOAD/img_db
+ fi
+done
+if [ ! -d $DOWNLOAD/img_db/flickr30k ] ; then
+ wget $BLOB/img_db/flickr30k.tar -P $DOWNLOAD/img_db/
+ tar -xvf $DOWNLOAD/img_db/flickr30k.tar -C $DOWNLOAD/img_db
+fi
+
+# text dbs
+for SPLIT in 'train' 'restval' 'val' 'test'; do
+ wget $BLOB/txt_db/itm_coco_$SPLIT.db.tar -P $DOWNLOAD/txt_db/
+ tar -xvf $DOWNLOAD/txt_db/itm_coco_$SPLIT.db.tar -C $DOWNLOAD/txt_db
+done
+for SPLIT in 'train' 'val' 'test'; do
+ wget $BLOB/txt_db/itm_flickr30k_$SPLIT.db.tar -P $DOWNLOAD/txt_db/
+ tar -xvf $DOWNLOAD/txt_db/itm_flickr30k_$SPLIT.db.tar -C $DOWNLOAD/txt_db
+done
+
+if [ ! -f $DOWNLOAD/pretrained/uniter-base.pt ] ; then
+ wget $BLOB/pretrained/uniter-base.pt -P $DOWNLOAD/pretrained/
+fi
+
diff --git a/scripts/eval_nlvr2.py b/scripts/eval_nlvr2.py
new file mode 100644
index 0000000..2e48569
--- /dev/null
+++ b/scripts/eval_nlvr2.py
@@ -0,0 +1,64 @@
+"""
+copied from official NLVR2 github
+(https://github.com/lil-lab/nlvr/tree/master/nlvr2)
+
+python scripts/eval_nlvr2.py
+"""
+import json
+import sys
+
+# Load the predictions file. Assume it is a CSV.
+predictions = { }
+for line in open(sys.argv[1]).readlines():
+ if line:
+ splits = line.strip().split(",")
+ # We assume identifiers are in the format "split-####-#-#.png".
+ identifier = splits[0]
+ prediction = splits[1]
+ predictions[identifier] = prediction
+
+# Load the labeled examples.
+labeled_examples = [json.loads(line) for line in open(sys.argv[2]).readlines() if line]
+
+# Make sure there is a prediction for every labeled example.
+# If not, identify the ones that are missing, and exit.
+total_num = len(labeled_examples)
+if len(predictions) < total_num:
+ print("Some predictions are missing!")
+ print("Got " + str(len(predictions)) + " predictions but expected " + str(total_num))
+
+ for example in labeled_examples:
+ lookup = example["identifier"]
+ if not lookup in predictions:
+ print("Missing prediction for item " + str(lookup))
+ exit()
+
+# Get the precision by iterating through the examples and checking the value
+# that was predicted.
+# Also update the "consistency" dictionary that keeps track of whether all
+# predictions for a given sentence were correct.
+num_correct = 0.
+consistency_dict = { }
+
+for example in labeled_examples:
+ anon_label = example["identifier"].split("-")
+ anon_label[2] = ''
+ anon_label = '-'.join(anon_label)
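+ # blanking the third field groups together the examples that share a
+ # sentence, which is what the consistency metric is computed over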
+ if not anon_label in consistency_dict:
+ consistency_dict[anon_label] = True
+ lookup = example["identifier"]
+ prediction = predictions[lookup]
+ if prediction.lower() == example["label"].lower():
+ num_correct += 1.
+ else:
+ consistency_dict[anon_label] = False
+
+# Calculate consistency.
+num_consistent = 0.
+unique_sentence = len(consistency_dict)
+for identifier, consistent in consistency_dict.items():
+ if consistent:
+ num_consistent += 1
+
+# Report values.
+print("accuracy=" + str(num_correct / total_num))
+print("consistency=" + str(num_consistent / unique_sentence))
diff --git a/scripts/eval_zs_itm_flickr.sh b/scripts/eval_zs_itm_flickr.sh
new file mode 100644
index 0000000..572beae
--- /dev/null
+++ b/scripts/eval_zs_itm_flickr.sh
@@ -0,0 +1,5 @@
+#!/usr/bin/env bash
+
+horovodrun -np $NGPU python inf_itm.py \
+ --txt_db ${TXT_DB} --img_db ${IMG_DB} --checkpoint ${PRETRAINED_MODEL} \
+ --model_config ${MODEL_CONFIG} --output_dir $ZS_ITM_RESULT --fp16 --pin_mem
\ No newline at end of file
diff --git a/scripts/extract_imgfeat.sh b/scripts/extract_imgfeat.sh
new file mode 100644
index 0000000..849589b
--- /dev/null
+++ b/scripts/extract_imgfeat.sh
@@ -0,0 +1,19 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT license.
+
+IMG_DIR=$1
+OUT_DIR=$2
+
+set -e
+
+echo "extracting image features..."
+if [ ! -d $OUT_DIR ]; then
+ mkdir -p $OUT_DIR
+fi
+docker run --gpus '"'device=$CUDA_VISIBLE_DEVICES'"' --ipc=host --rm \
+ --mount src=$IMG_DIR,dst=/img,type=bind,readonly \
+ --mount src=$OUT_DIR,dst=/output,type=bind \
+ -w /src chenrocks/butd-caffe:nlvr2 \
+ bash -c "python tools/generate_npz.py --gpu 0"
+
+echo "done"
diff --git a/scripts/train_itm.sh b/scripts/train_itm.sh
new file mode 100644
index 0000000..4aa70b8
--- /dev/null
+++ b/scripts/train_itm.sh
@@ -0,0 +1,15 @@
+#!/usr/bin/env bash
+
+ngpu=1
+db_root=${db_root} # ../itm-data
+txt_dir=${db_root}/text_db
+img_dir=${db_root}/img_db
+model_dir=${db_root}/pretrained
+config_dir=.
+zs_itm_result=./log
+mkdir -p ${zs_itm_result}
+
+horovodrun -np $ngpu python inf_itm.py \
+ --txt_db ${txt_dir}/itm_flickr30k_test.db --img_db ${img_dir}/img/flickr30k \
+ --checkpoint ${model_dir}/uniter-base.pt --model_config ${config_dir}/config/uniter-base.json \
+ --output_dir ${zs_itm_result} --fp16 --pin_mem
\ No newline at end of file
diff --git a/setup_container.sh b/setup_container.sh
new file mode 100644
index 0000000..4b3245c
--- /dev/null
+++ b/setup_container.sh
@@ -0,0 +1,27 @@
+#!/usr/bin/env bash
+
+sudo apt-get remove docker docker-engine docker.io containerd runc
+
+sudo apt-get update
+
+sudo apt-get install apt-transport-https ca-certificates curl gnupg
+
+curl -fsSL https://download.docker.com/linux/ubuntu/gpg | sudo gpg --dearmor -o /usr/share/keyrings/docker-archive-keyring.gpg
+
+echo "deb [arch=amd64 signed-by=/usr/share/keyrings/docker-archive-keyring.gpg] https://download.docker.com/linux/ubuntu \
+ $(lsb_release -cs) stable" | sudo tee /etc/apt/sources.list.d/docker.list > /dev/null
+
+sudo apt-get update
+
+sudo apt-get install docker-ce docker-ce-cli containerd.io
+
+apt-cache madison docker-ce
+
+sudo apt-get install docker-ce=5:19.03.12~3-0~ubuntu-bionic docker-ce-cli=5:19.03.12~3-0~ubuntu-bionic containerd.io
+
+sudo docker run hello-world
+
+
+#INSTALL_DIR=~
+#source launch_container.sh ${INSTALL_DIR}/txt_db ${INSTALL_DIR}/img_db \
+# ${INSTALL_DIR}/finetune ${INSTALL_DIR}/pretrained
diff --git a/support/case_study.py b/support/case_study.py
new file mode 100644
index 0000000..064b7a2
--- /dev/null
+++ b/support/case_study.py
@@ -0,0 +1,49 @@
+import os
+import json
+import numpy as np
+
+
+def main():
+ log_dir = './log/adv_dc_visual2'
+ for filename in os.listdir(log_dir):
+ cont = False
+ for cand_imgname in [202175131, 539676201, 4373983146]:
+ if str(cand_imgname) in filename:
+ cont = True
+ if not cont:
+ continue
+ log = json.load(open(os.path.join(log_dir, filename), 'r'))
+ print('='*100)
+ print(log)
+ # print('='*100)
+ # print(log['idx'], log['imgname'])
+ # print("| gt_text: ", log['gt_text'])
+ gt_text = log['gt_text']
+ syn_texts = log['syn_texts']
+ disc_ps = log['disc_p']
+ correction_texts = log['correction_texts']
+ syn_scores = log['syn_scores']
+
+ i = 0
+ for syn_score, syn_text, disc_p, correction_text in zip(syn_scores, syn_texts, disc_ps, correction_texts):
+ syn_idt = (np.array(syn_text) == np.array(gt_text)).astype(np.int64)
+ new_correction_text = []
+ for idt, gt_token, correction_token in zip(syn_idt, gt_text, correction_text):
+ if idt == 1:
+ new_correction_text.append(gt_token)
+ else:
+ new_correction_text.append(correction_token)
+ correction_text = new_correction_text
+ # correction_text = (syn_idt * gt_text) + (correction_text * (1-syn_idt))
+
+ correction_idt = (np.array(correction_text) == np.array(gt_text)).astype(np.int64)
+ # print(syn_idt, correction_idt)
+ if correction_idt.sum() >= syn_idt.sum() + 2: # and (1 - correction_idt).sum() == 0:
+ print(i, syn_score, log['imgname'])
+ print(list(zip(gt_text, syn_text, correction_text, disc_p)))
+
+ i += 1
+
+
+if __name__ == '__main__':
+ main()
diff --git a/support/plot.py b/support/plot.py
new file mode 100644
index 0000000..3242e4d
--- /dev/null
+++ b/support/plot.py
@@ -0,0 +1,118 @@
+import os
+import json
+import math
+import numpy as np
+import matplotlib.pyplot as plt
+
+plt.style.use('ggplot')
+np.random.seed(19680801)
+
+
+def main():
+ data_dir = "/Users/fanzhihao/Documents/Research/NIPS2021"
+ kwargs = dict(alpha=0.2, bins=2000, density=True)
+
+ max_d = 4000.0
+ data0 = [math.exp(x) for x in json.load(open(os.path.join(data_dir, 'dataset_{}.json'.format(0)), 'r')) if math.exp(x) < max_d]
+ print('data0', len(data0))
+ plt.hist(data0, color='g', **kwargs)
+ data1 = [math.exp(x) for x in json.load(open(os.path.join(data_dir, 'dataset_{}.json'.format(1)), 'r')) if math.exp(x) < max_d]
+ print('data1', len(data1))
+ plt.hist(data1, color='b', **kwargs)
+ data2 = [math.exp(x) for x in json.load(open(os.path.join(data_dir, 'dataset_{}.json'.format(2)), 'r')) if math.exp(x) < max_d]
+ print('data2', len(data2))
+ plt.hist(data2, color='r', **kwargs)
+ # plt.gca().set(title='Frequency Histogram of Diamond Depths', ylabel='Frequency')
+ plt.xlim(0, 250)
+ plt.show()
+
+
+def main2():
+ data_dir = "/Users/fanzhihao/Documents/Research/NIPS2021"
+ import seaborn as sns
+ # white, dark, whitegrid, darkgrid, ticks
+
+ sns.set_style("darkgrid")
+ # sns.set_style("ticks")
+
+ # Import data
+ # df = pd.read_csv('https://raw.githubusercontent.com/selva86/datasets/master/diamonds.csv')
+ # x1 = df.loc[df.cut == 'Ideal', 'depth']
+ # x2 = df.loc[df.cut == 'Fair', 'depth']
+ # x3 = df.loc[df.cut == 'Good', 'depth']
+
+ # Plot
+ # kwargs = dict(hist_kws={'alpha': .6}, kde_kws={'linewidth': 2})
+ #
+ # plt.figure(figsize=(10, 7), dpi=80)
+ # sns.distplot(x1, color="dodgerblue", label="Compact", **kwargs)
+ # sns.distplot(x2, color="orange", label="SUV", **kwargs)
+ # sns.distplot(x3, color="deeppink", label="minivan", **kwargs)
+ # plt.xlim(50, 75)
+ # plt.legend();
+
+ kwargs = dict(alpha=0.8, bins=2000)
+
+ # plt.rcParams['figure.figsize'] = (40.0, 8.0)
+ max_d = 1500.0
+
+ kwargs = dict(hist_kws={'alpha': .2}, kde_kws={'linewidth': 2, 'shade': True}, bins=5000)
+
+ plt.figure(figsize=(9.2, 6), dpi=80)
+ # data0 = [math.exp(x) for x in json.load(open(os.path.join(data_dir, 'dataset_{}.json'.format(0)), 'r')) if math.exp(x) < max_d]
+ data0 = [x for x in json.load(open(os.path.join(data_dir, 'dataset_{}.json'.format(0)), 'r')) if x < max_d]
+ print('data0', len(data0))
+ line0 = sns.distplot(data0, **kwargs, color='g', hist=False, kde=True, label="Positive Text")
+ # data1 = [math.exp(x)+20 for x in json.load(open(os.path.join(data_dir, 'dataset_{}.json'.format(1)), 'r')) if math.exp(x) < max_d]
+ data1 = [x for x in json.load(open(os.path.join(data_dir, 'dataset_{}.json'.format(1)), 'r')) if x < max_d]
+ print('data1', len(data1))
+ line1 = sns.distplot(data1, **kwargs, color='b', hist=False, kde=True, label="Synthetic Text")
+ # data2 = [math.exp(x)+10 for x in json.load(open(os.path.join(data_dir, 'dataset_{}.json'.format(2)), 'r')) if math.exp(x) < max_d]
+ data2 = [x for x in json.load(open(os.path.join(data_dir, 'dataset_{}.json'.format(2)), 'r')) if x < max_d]
+ line2 = sns.distplot(data2, **kwargs, color='r', hist=False, kde=True, label="Corrected Text")
+ plt.xlim(0, 300)
+
+ plt.legend(loc="upper right", fontsize=20)
+
+ # plt.xticks(np.arange(-5, 5, 0.5), fontproperties='Times New Roman', size=10)
+ # plt.yticks(np.arange(-2, 2, 0.3), fontproperties='Times New Roman', size=10)
+
+ plt.xlabel('', fontdict={'family': 'Times New Roman', 'size': 20})
+ plt.ylabel('', fontdict={'family': 'Times New Roman', 'size': 20})
+
+ plt.yticks(np.arange(0.0, 0.012, 0.002), fontproperties='Times New Roman', size=20)
+ plt.xticks(np.arange(0, 300, 50), fontproperties='Times New Roman', size=20)
+ plt.legend(prop={'family': 'Times New Roman', 'size': 20})
+ # plt.rcParams['xtick.direction'] = 'out'
+ # plt.rcParams['ytick.direction'] = 'out'
+ # plt.bar([0, 50, 100, 150, 200, 250], [0.0001, 0.0001, 0.0001, 0.0001, 0.0001, 0.0001])
+
+ # plt.minorticks_on()
+ # plt.tick_params(which='major', width=4, length=10, direction="inout")
+
+ plt.show()
+
+ # plt.gca().set(title='Frequency Histogram of Diamond Depths', ylabel='Frequency')
+ # plt.show()
+ # plt.legend()
+
+
+def main3():
+ import matplotlib.pyplot as plt
+ import numpy as np
+ from scipy.stats import gaussian_kde
+
+ data_dir = "/Users/fanzhihao/Documents/Research/NIPS2021"
+ max_d = 4000.0
+ data0 = [math.exp(x) for x in json.load(open(os.path.join(data_dir, 'dataset_{}.json'.format(0)), 'r')) if math.exp(x) < max_d]
+
+ density = gaussian_kde(data0)
+ xs = np.linspace(0, 8, 200)
+ density.covariance_factor = lambda: .25
+ density._compute_covariance()
+ plt.plot(xs, density(xs))
+ plt.show()
+
+
+if __name__ == '__main__':
+ main2()
diff --git a/support/ppl_solver.py b/support/ppl_solver.py
new file mode 100644
index 0000000..32bf5cd
--- /dev/null
+++ b/support/ppl_solver.py
@@ -0,0 +1,140 @@
+import os
+import json
+
+import torch
+from transformers import GPT2LMHeadModel, GPT2TokenizerFast
+# from nlp import load_dataset
+from tqdm import tqdm
+import numpy as np
+import string
+
+
+def load_data():
+ puncs = list(set(list(string.punctuation)) - set([',', '.', '-', '?', '!', '\'']))
+ puncs = [punc + ' ' for punc in puncs]
+
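+ # rejoin BERT wordpieces ('##' continuations) into plain text and strip
+ # stray punctuation so GPT-2 scores natural-looking sentences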
+ def bpe2token(bpes):
+ text = ' '.join(bpes).replace(' ##', '').replace(" ,", ',').replace(" .", ".").replace(' - ', '-').\
+ replace(' ?', '?').replace(' !', '!').replace(' \' ', '\'')
+ for punc in puncs:
+ text = text.replace(punc, ' ')
+ text = text.strip()
+ return text
+
+ log_dir = './log/adv_dc_visual2'
+ dataset_gt_texts, dataset_syn_texts, dataset_correct_texts = [], [], []
+
+ max_num = 1
+ for filename in tqdm(os.listdir(log_dir)):
+ log = json.load(open(os.path.join(log_dir, filename), 'r'))
+ gt_text = log['gt_text'][1:-1]
+ syn_texts = [text[1:-1] for text in log['syn_texts'][:max_num]]
+ correct_texts = [text[1:-1] for text in log['correction_texts'][:max_num]]
+ # print("| gt_text: ", gt_text)
+ # print("| syn_text: ", syn_texts[0])
+ # print("| correct_text: ", correct_texts[0])
+ dataset_gt_texts.append(bpe2token(gt_text))
+ dataset_syn_texts.extend([bpe2token(syn_text) for syn_text in syn_texts])
+ dataset_correct_texts.extend([bpe2token(correct_text) for correct_text in correct_texts])
+ return dataset_gt_texts, dataset_syn_texts, dataset_correct_texts
+
+
+def main():
+ # device = 'cuda'
+ # from transformers import GPT2Tokenizer, GPT2LMHeadModel
+ # # Load pre-trained model (weights)
+ # with torch.no_grad():
+ # model = GPT2LMHeadModel.from_pretrained('gpt2').to(device)
+ # model.eval()
+ # # Load pre-trained model tokenizer (vocabulary)
+ # tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
+
+ device = 'cuda'
+ model_id = 'gpt2-large'
+ model = GPT2LMHeadModel.from_pretrained(model_id).to(device)
+ model.eval()
+ tokenizer = GPT2TokenizerFast.from_pretrained(model_id)
+
+ # text = "From large scale power generators to the basic cooking in our homes, fuel is essential for all of these to happen and work."
+ # text = "A woman plays with finger puppets as a small child in a costume walks by."
+ # # "From large scale power generators to the basic cooking in our homes, fuel is essential for all of these to happen and work."
+ # input_ids = torch.tensor(tokenizer([text])["input_ids"]).to(device)
+ # target_ids = input_ids.clone()
+ # with torch.no_grad():
+ # log_likelihood = model(input_ids, labels=target_ids)[0]
+ # ppl = torch.exp(log_likelihood)
+ # print(np.exp(log_likelihood.data.cpu().numpy()))
+ # print('1: ', log_likelihood, ppl)
+ #
+ # raise Exception
+
+ # test = load_dataset('wikitext', 'wikitext-2-raw-v1', split='test')
+ # encodings = tokenizer('\n\n'.join(test['text']), return_tensors='pt')
+
+ # max_length = model.config.n_positions
+ # stride = 512
+
+ dataset_gt_texts, dataset_syn_texts, dataset_correct_texts = load_data()
+ for i, dataset_texts in enumerate([dataset_gt_texts, dataset_syn_texts, dataset_correct_texts]):
+ lls = []
+ end_loc = 0
+ lls2 = []
+ end_loc2 = 0
+ lls_save = []
+ ppls = []
+ for text in tqdm(dataset_texts):
+ # print(" text: ", text, type(text))
+ # print('| text: ', text[0], text[1], text[-1])
+ input_ids = torch.tensor(tokenizer([text])["input_ids"]).to(device)
+ # print('| ', text, input_ids)
+ trg_len = input_ids.size(-1)
+ end_loc += 1
+ end_loc2 += trg_len
+ # print(type(input_ids))
+ # print(input_ids)
+ # print('| input_ids: ', input_ids.size(), input_ids)
+ input_ids = input_ids
+ target_ids = input_ids.clone()
+ # print('| input_ids', input_ids.size(), target_ids.size())
+ # print(input_ids.size(), target_ids.size())
+ # target_ids[-1] = -100
+ with torch.no_grad():
+ outputs = model(input_ids, labels=target_ids)
+ ll = outputs[0].item()
+ log_likelihood = outputs[0]
+ ppl = torch.exp(log_likelihood)
+ if ppl > 200.0:
+ print('| large ppl: ', text)
+ # print(ll, ppl)
+
+ ppls.append(ppl)
+ # print('| log_likelihood: ', log_likelihood, ppl)
+ lls.append(log_likelihood)
+ lls2.append(log_likelihood * trg_len)
+ lls_save.append(ppl.item())
+
+ ppl_mean = torch.stack(ppls).mean()
+ ppl = torch.exp(torch.stack(lls).sum() / end_loc)
+ ppl2 = torch.exp(torch.stack(lls2).sum() / end_loc2)
+ print('| ppl: ', ppl_mean.item(), ppl.item(), ppl2.item())
+
+ save_path = os.path.join('log', 'dataset_{}.json'.format(i))
+ with open(save_path, 'w') as fw:
+ json.dump(lls_save, fw)
+ print('ppl: ', ppl)
+
+ # for i in tqdm(range(0, encodings.input_ids.size(1), stride)):
+ # begin_loc = max(i + stride - max_length, 0)
+ # end_loc = min(i + stride, encodings.input_ids.size(1))
+ # trg_len = end_loc - i # may be different from stride on last loop
+ # input_ids = encodings.input_ids[:,begin_loc:end_loc].to(device)
+ # target_ids = input_ids.clone()
+ # target_ids[:,:-trg_len] = -100
+ #
+ # with torch.no_grad():
+ # outputs = model(input_ids, labels=target_ids)
+ # log_likelihood = outputs[0] * trg_len
+
+if __name__ == '__main__':
+ main()
diff --git a/support/tensorboard_reload.py b/support/tensorboard_reload.py
new file mode 100644
index 0000000..ab30c27
--- /dev/null
+++ b/support/tensorboard_reload.py
@@ -0,0 +1,18 @@
+import os
+from tensorboard.backend.event_processing import event_accumulator
+
+
+def main():
+ data_dir = "/Users/fanzhihao/Documents/Research/NIPS2021"
+ tensorboard_path = os.path.join(data_dir, "events.out.tfevents.uniter.base.pm.dm")
+ ea = event_accumulator.EventAccumulator(tensorboard_path)
+ ea.Reload()
+ print(ea.scalars.Keys())
+
+ # val_psnr = ea.scalars.Items('val_psnr')
+ # print(len(val_psnr))
+ # print([(i.step, i.value) for i in val_psnr])
+
+
+if __name__ == "__main__":
+ main()
diff --git a/train_pnsgd.py b/train_pnsgd.py
new file mode 100644
index 0000000..7be121e
--- /dev/null
+++ b/train_pnsgd.py
@@ -0,0 +1,578 @@
+"""
+Copyright (c) Microsoft Corporation.
+Licensed under the MIT license.
+
+UNITER finetuning for Image-Text Retrieval with hard negatives
+"""
+import argparse
+import os
+from os.path import exists, join
+from time import time
+import math
+
+import torch
+from torch.nn.utils import clip_grad_norm_
+from torch.utils.data import DataLoader, ConcatDataset
+from torch.optim.lr_scheduler import LambdaLR
+from apex import amp
+from horovod import torch as hvd
+from tqdm import tqdm
+
+from data import (PrefetchLoader, TxtTokLmdb, ImageLmdbGroup,
+ PNSGDFromText, PNSGDFromImage, pnsgd_collate, itm_rank_hn_collate,
+ ItmValDataset, itm_val_collate,
+ ItmEvalDataset, itm_eval_collate)
+from model.nsgd import UniterForNSGD
+from optim import get_lr_sched
+from optim.misc import build_optimizer
+
+from utils.logger import LOGGER, TB_LOGGER, RunningMeter, add_log_to_file
+from utils.distributed import (all_reduce_and_rescale_tensors, all_gather_list,
+ broadcast_tensors)
+from utils.save import ModelSaver, save_training_meta
+from utils.misc import NoOp, parse_with_config, set_dropout, set_random_seed
+from utils.const import IMG_DIM
+from utils.itm_eval import evaluate
+
+
+def build_dataloader(dataset, collate_fn, is_train, opts):
+ dataloader = DataLoader(dataset, batch_size=1,
+ shuffle=is_train, drop_last=is_train,
+ num_workers=opts.n_workers,
+ pin_memory=opts.pin_mem, collate_fn=collate_fn)
+ dataloader = PrefetchLoader(dataloader)
+ return dataloader
+
+
+def build_lr_scheduler(opts, num_training_steps, optimizer):
+ num_warmup_steps = (
+ opts.warmup_steps
+ if opts.warmup_steps > 0
+ else math.ceil(num_training_steps * opts.warmup_ratio)
+ )
+
+ def lr_lambda(current_step: int):
+ if current_step < num_warmup_steps:
+ return float(current_step) / float(max(1, num_warmup_steps))
+ return max(
+ 0.0, float(num_training_steps - current_step) / float(max(1, num_training_steps - num_warmup_steps))
+ )
+ return LambdaLR(optimizer, lr_lambda)
+
+
+def load_optimizer_and_scheduler(optimizer, lr_scheduler, checkpoint):
+ """If optimizer and scheduler states exist, load them."""
+ if checkpoint is None:
+ return
+
+ if os.path.isfile(os.path.join(checkpoint, "optimizer.pt")) and os.path.isfile(
+ os.path.join(checkpoint, "scheduler.pt")
+ ):
+ # Load in optimizer and scheduler states
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
+ optimizer.load_state_dict(
+ torch.load(os.path.join(checkpoint, "optimizer.pt"), map_location=device)
+ )
+ lr_scheduler.load_state_dict(torch.load(os.path.join(checkpoint, "scheduler.pt")))
+ return optimizer, lr_scheduler
+
+
+def main(opts):
+ hvd.init()
+ n_gpu = hvd.size()
+ device = torch.device("cuda", hvd.local_rank())
+ torch.cuda.set_device(hvd.local_rank())
+ rank = hvd.rank()
+ opts.rank = rank
+ LOGGER.info("device: {} n_gpu: {}, rank: {}, "
+ "16-bits training: {}".format(
+ device, n_gpu, hvd.rank(), opts.fp16))
+
+ set_random_seed(opts.seed)
+
+ if hvd.rank() == 0:
+ save_training_meta(opts)
+ TB_LOGGER.create(join(opts.output_dir, 'log'))
+ pbar = tqdm(total=opts.num_train_steps)
+ model_saver = ModelSaver(join(opts.output_dir, 'ckpt'))
+ add_log_to_file(join(opts.output_dir, 'log', 'log.txt'))
+ # store ITM predictions
+ os.makedirs(join(opts.output_dir, 'results_val'))
+ os.makedirs(join(opts.output_dir, 'results_test'))
+ os.makedirs(join(opts.output_dir, 'results_train'))
+ else:
+ LOGGER.disabled = True
+ pbar = NoOp()
+ model_saver = NoOp()
+
+ # train_examples = None
+ LOGGER.info(f"Loading Train Dataset {opts.train_txt_dbs}, "
+ f"{opts.train_img_dbs}")
+ # check multiple DBs
+ assert len(opts.train_txt_dbs) == len(opts.train_img_dbs), \
+ "train txt_db and img_db have different length"
+
+ # load DBs and image dirs
+ all_img_dbs = ImageLmdbGroup(opts.conf_th, opts.max_bb, opts.min_bb,
+ opts.num_bb, opts.compressed_db)
+ # train
+ LOGGER.info(f"Loading Train Dataset "
+ f"{opts.train_txt_dbs}, {opts.train_img_dbs}")
+ train_datasets_t = []
+ train_datasets_i = []
+ for txt_path, img_path in zip(opts.train_txt_dbs, opts.train_img_dbs):
+ img_db = all_img_dbs[img_path]
+ txt_db = TxtTokLmdb(txt_path, opts.max_txt_len)
+ train_datasets_t.append(
+ PNSGDFromText(txt_db, img_db, opts.negative_size, opts.mlm_sample_size))
+ train_datasets_i.append(
+ PNSGDFromImage(txt_db, img_db, opts.negative_size))
+ train_dataset_t = ConcatDataset(train_datasets_t)
+ train_dataset_i = ConcatDataset(train_datasets_i)
+ train_dataloader_t = build_dataloader(
+ train_dataset_t, pnsgd_collate, True, opts)
+ train_dataloader_i = build_dataloader(
+ train_dataset_i, pnsgd_collate, True, opts)
+
+ # val
+ LOGGER.info(f"Loading Val Dataset {opts.val_txt_db}, {opts.val_img_db}")
+ val_img_db = all_img_dbs[opts.val_img_db]
+ val_txt_db = TxtTokLmdb(opts.val_txt_db, -1)
+ val_dataset = ItmValDataset(val_txt_db, val_img_db,
+ opts.inf_minibatch_size)
+ val_dataloader = build_dataloader(val_dataset, itm_val_collate,
+ False, opts)
+ # eval
+ LOGGER.info(f"Loading val, test Dataset for full evaluation: "
+ f"{opts.val_txt_db}, {opts.val_img_db}"
+ f"{opts.test_txt_db}, {opts.test_img_db}")
+ eval_dataset_val = ItmEvalDataset(val_txt_db, val_img_db,
+ opts.inf_minibatch_size)
+ eval_loader_val = build_dataloader(eval_dataset_val, itm_eval_collate,
+ False, opts)
+ test_img_db = all_img_dbs[opts.test_img_db]
+ test_txt_db = TxtTokLmdb(opts.test_txt_db, -1)
+ eval_dataset_test = ItmEvalDataset(test_txt_db, test_img_db,
+ opts.inf_minibatch_size)
+ eval_loader_test = build_dataloader(eval_dataset_test, itm_eval_collate,
+ False, opts)
+
+ # Prepare model
+ if opts.checkpoint:
+ checkpoint = torch.load(opts.checkpoint)
+ else:
+ checkpoint = {}
+
+ model = UniterForNSGD.from_pretrained(
+ opts.model_config, state_dict=checkpoint,
+ img_dim=IMG_DIM, margin=opts.margin, hard_size=opts.hard_neg_size,
+ nsgd_sample_size=opts.nsgd_sample_size, nsgd_sample_temperature=opts.nsgd_sample_temperature)
+ model.init_output() # pretrain ITM head is different from ranking head
+ model.to(device)
+ # make sure every process has same model parameters in the beginning
+ broadcast_tensors([p.data for p in model.parameters()], 0)
+ set_dropout(model, opts.dropout)
+
+ # Prepare optimizer
+ optimizer = build_optimizer(model, opts)
+ # print('| optimizer file:', os.path.join(checkpoint, "optimizer.pt"))
+ # if os.path.exists(os.path.join(checkpoint, "optimizer.pt")) and \
+ # os.path.isfile(os.path.join(checkpoint, "optimizer.pt")):
+ # # Load in optimizer and scheduler states
+ # device = 'gpu' if torch.cuda.is_available() else 'cpu'
+ # optimizer.load_state_dict(torch.load(os.path.join(checkpoint, "optimizer.pt"), map_location=device))
+ model, optimizer = amp.initialize(model, optimizer, enabled=opts.fp16, opt_level='O2')
+ # Prepare scheduler
+ # lr_scheduler = build_lr_scheduler(opts, opts.num_train_steps, optimizer)
+ # print('| scheduler file:', os.path.join(checkpoint, "scheduler.pt"))
+ # if os.path.exists(os.path.join(checkpoint, "scheduler.pt")) and \
+ # os.path.isfile(os.path.join(checkpoint, "scheduler.pt")):
+ # lr_scheduler.load_state_dict(torch.load(os.path.join(checkpoint, 'scheduler.pt')))
+
+ LOGGER.info(f"***** Running training on {n_gpu} GPUs *****")
+ LOGGER.info(" Num examples = %d",
+ sum(all_gather_list(len(train_dataset_t))))
+ LOGGER.info(" Batch size = %d", opts.train_batch_size)
+ LOGGER.info(" Num steps = %d", opts.num_train_steps)
+
+ # running_loss = RunningMeter('loss')
+ running_loss_dict = {loss: RunningMeter(loss) for loss in
+ ['mlm_loss', 'nsgd_rank_loss', 'i2t_hn_rank_loss', 't2i_hn_rank_loss']}
+ model.train()
+
+ global_step = 0
+ step = 0
+ n_examples = 0
+ n_hard_ex = 0
+ start = time()
+ train_iter_i = iter(train_dataloader_i)
+ # quick hack for amp delay_unscale bug
+ optimizer.zero_grad()
+ optimizer.step()
+ effect_nsgd_number, mlm_corrects, nsgd_rank_corrects, i2t_hn_rank_corrects, t2i_hn_rank_corrects = \
+ [], [[], []], [[], []], [[], []], [[], []]
+ incremental_adv_scores = []
+ while True:
+ for batch in train_dataloader_t:
+
+ # hard text from image
+ try:
+ batch_i = next(train_iter_i)
+ except StopIteration:
+ train_iter_i = iter(train_dataloader_i)
+ batch_i = next(train_iter_i)
+
+ n_examples += batch['mlm_attn_masks'].size(0)
+ model_outputs = model(
+ batch, sample_from='gi', compute_loss=True,
+ compute_mlm=isinstance(opts.mlm_lambda, float) and opts.mlm_lambda > 0)
+ loss = 0.0
+ if model_outputs.get('masked_lm_loss') is not None:
+ mlm_loss = model_outputs.get('masked_lm_loss').mean()
+ # print('| mlm_loss: ', mlm_loss)
+ loss += opts.mlm_lambda * mlm_loss
+ running_loss_dict.get('mlm_loss')(mlm_loss.item())
+ mlm_corrects[0].append(model_outputs['mlm_corrects'].sum().item())
+ mlm_corrects[1].append(model_outputs['mlm_corrects'].numel())
+ # print('mlm_corrects:', mlm_corrects[0][-1]/mlm_corrects[1][-1])
+ effect_nsgd_number.append(model_outputs['effect_nsgd_number'])
+ if model_outputs.get('rank_loss') is not None:
+ rank_adv_scores = model_outputs['rank_adv_scores'].squeeze(-1)
+ # print('| rank_adv_scores: ', rank_adv_scores.min(), rank_adv_scores.max())
+ incremental_adv_score = rank_adv_scores[1:] - rank_adv_scores[0]
+ incremental_adv_scores.append(incremental_adv_score)
+ nsgd_rank_loss = model_outputs.get('rank_loss')
+ rank_corrects = model_outputs.get('rank_corrects')
+ nsgd_hard_ex = rank_corrects.numel()
+ nsgd_rank_corrects[0].append(rank_corrects.sum().item())
+ nsgd_rank_corrects[1].append(nsgd_hard_ex)
+ # print('nsgd_rank_corrects:', nsgd_rank_corrects[0][-1]/nsgd_rank_corrects[1][-1])
+ n_hard_ex += nsgd_hard_ex
+ nsgd_rank_loss = nsgd_rank_loss.mean()
+ running_loss_dict.get('nsgd_rank_loss')(nsgd_rank_loss.item())
+ loss += opts.nsgd_rank_lambda * nsgd_rank_loss
+ if isinstance(loss, torch.Tensor):
+ loss = loss / opts.train_batch_size
+ with amp.scale_loss(loss, optimizer, delay_unscale=True, # loss_id=0
+ ) as scaled_loss:
+ scaled_loss.backward()
+
+ n_examples += batch_i['attn_masks'].size(0)
+ model_outputs = model(batch_i, sample_from='i', compute_loss=True)
+ loss = model_outputs.get('rank_loss')
+ rank_corrects = model_outputs.get('rank_corrects')
+ i2t_hard_ex = rank_corrects.numel()
+ i2t_hn_rank_corrects[0].append(rank_corrects.sum().item())
+ i2t_hn_rank_corrects[1].append(i2t_hard_ex)
+ n_hard_ex += i2t_hard_ex
+ loss = loss.mean() / opts.train_batch_size
+ with amp.scale_loss(loss, optimizer, delay_unscale=True, # loss_id=1
+ ) as scaled_loss:
+ scaled_loss.backward()
+ running_loss_dict.get('t2i_hn_rank_loss')(loss.item())
+
+ # hard image from text
+ n_examples += batch['attn_masks'].size(0)
+ model_outputs = model(batch, sample_from='t', compute_loss=True)
+ loss = model_outputs.get('rank_loss')
+ rank_corrects = model_outputs.get('rank_corrects')
+ t2i_hard_ex = rank_corrects.numel()
+ t2i_hn_rank_corrects[0].append(rank_corrects.sum().item())
+ t2i_hn_rank_corrects[1].append(t2i_hard_ex)
+ n_hard_ex += t2i_hard_ex
+ # NOTE we use gradient accumulation to implemented train_batch_size
+ loss = loss.mean() / opts.train_batch_size
+
+ step += 1
+ delay_unscale = step % opts.train_batch_size != 0
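+ # unscale and all-reduce only once every train_batch_size forward
+ # passes, i.e. gradient accumulation over the "batch"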
+ with amp.scale_loss(loss, optimizer, delay_unscale=delay_unscale, #loss_id=2
+ ) as scaled_loss:
+ scaled_loss.backward()
+ if not delay_unscale:
+ # gather gradients from every processes
+ # do this before unscaling to make sure every process uses
+ # the same gradient scale
+ grads = [p.grad.data for p in model.parameters()
+ if p.requires_grad and p.grad is not None]
+ all_reduce_and_rescale_tensors(grads, float(1))
+ running_loss_dict.get('i2t_hn_rank_loss')(loss.item())
+
+ if step % opts.train_batch_size == 0:
+ global_step += 1
+
+ # learning rate scheduling
+ lr_this_step = get_lr_sched(global_step, opts)
+ for param_group in optimizer.param_groups:
+ param_group['lr'] = lr_this_step
+ TB_LOGGER.add_scalar('lr', lr_this_step, global_step)
+
+ # log loss
+ # NOTE: not gathered across GPUs for efficiency
+ incremental_adv_scores = torch.cat(incremental_adv_scores, dim=0)
+ TB_LOGGER.add_histogram(
+ 'incremental_adv_score', incremental_adv_scores.data.cpu().numpy(), global_step)
+ incremental_adv_scores = []
+
+ TB_LOGGER.add_scalar(
+ 'mlm_loss', running_loss_dict['mlm_loss'].val, global_step)
+ TB_LOGGER.add_scalar(
+ 'nsgd_rank_loss', running_loss_dict['nsgd_rank_loss'].val, global_step)
+ TB_LOGGER.add_scalar(
+ 'i2t_hn_rank_loss', running_loss_dict['i2t_hn_rank_loss'].val, global_step)
+ TB_LOGGER.add_scalar(
+ 't2i_hn_rank_loss', running_loss_dict['t2i_hn_rank_loss'].val, global_step)
+ TB_LOGGER.add_scalar(
+ 'nsgd_rank_cr', sum(nsgd_rank_corrects[0]) / float(sum(nsgd_rank_corrects[1])), global_step)
+ TB_LOGGER.add_scalar(
+ 'i2t_hn_rank_cr', sum(i2t_hn_rank_corrects[0]) / float(sum(i2t_hn_rank_corrects[1])), global_step)
+ TB_LOGGER.add_scalar(
+ 't2i_hn_rank_cr', sum(t2i_hn_rank_corrects[0]) / float(sum(t2i_hn_rank_corrects[1])), global_step)
+ TB_LOGGER.add_scalar(
+ 'effect_nsgd_number', sum(effect_nsgd_number) / float(len(effect_nsgd_number)), global_step)
+ TB_LOGGER.step()
+ effect_nsgd_number, mlm_corrects, nsgd_rank_corrects, i2t_hn_rank_corrects, t2i_hn_rank_corrects = \
+ [], [[], []], [[], []], [[], []], [[], []]
+
+ # update model params
+ if opts.grad_norm != -1:
+ grad_norm = clip_grad_norm_(amp.master_params(optimizer),
+ opts.grad_norm)
+ TB_LOGGER.add_scalar('grad_norm', grad_norm, global_step)
+ optimizer.step()
+ optimizer.zero_grad()
+ pbar.update(1)
+
+ if global_step % 50 == 0:
+ # monitor training throughput
+ LOGGER.info(f'------------Step {global_step}-------------')
+ tot_ex = sum(all_gather_list(n_examples))
+ ex_per_sec = int(tot_ex / (time()-start))
+ tot_hn = sum(all_gather_list(n_hard_ex))
+ hn_per_sec = int(tot_hn / (time()-start))
+ LOGGER.info(f'{tot_ex} ({tot_hn}) examples (hard) '
+ f'trained at {ex_per_sec} ({hn_per_sec}) ex/s')
+ TB_LOGGER.add_scalar('perf/ex_per_s',
+ ex_per_sec, global_step)
+ TB_LOGGER.add_scalar('perf/hn_per_s',
+ hn_per_sec, global_step)
+ LOGGER.info(f'-------------------------------------------')
+
+ if global_step % opts.valid_steps == 0:
+ if opts.full_val:
+ LOGGER.info(
+ f"========================== Step {global_step} "
+ f"==========================")
+ val_log = evaluate(model, eval_loader_val)
+ try:
+ TB_LOGGER.log_scaler_dict(
+ {f"valid/{k}": v for k, v in val_log.items()})
+ LOGGER.info(f"image retrieval R1: "
+ f"{val_log['img_r1']*100:.2f},\n"
+ f"image retrieval R5: "
+ f"{val_log['img_r5']*100:.2f},\n"
+ f"image retrieval R10: "
+ f"{val_log['img_r10']*100:.2f}\n"
+ f"text retrieval R1: "
+ f"{val_log['txt_r1']*100:.2f},\n"
+ f"text retrieval R5: "
+ f"{val_log['txt_r5']*100:.2f},\n"
+ f"text retrieval R10: "
+ f"{val_log['txt_r10']*100:.2f}")
+ LOGGER.info("================================="
+ "=================================")
+ except KeyError:
+ pass
+ else:
+ val_log = validate(model, val_dataloader)
+ TB_LOGGER.log_scaler_dict(val_log)
+ model_saver.save(model, global_step, optimizer=optimizer)
+ # torch.save(optimizer.state_dict(), os.path.join())
+ # torch.save(lr_scheduler.state_dict(), os.path.join())
+
+ if global_step >= opts.num_train_steps:
+ break
+
+ if global_step >= opts.num_train_steps:
+ break
+
+ pbar.close()
+ # final validation
+ val_log = validate(model, val_dataloader)
+ TB_LOGGER.log_scaler_dict(val_log)
+ model_saver.save(model, f'{global_step}_final')
+
+ # evaluation
+ for split, loader in [('val', eval_loader_val),
+ ('test', eval_loader_test)]:
+ eval_log = evaluate(model, loader)
+ TB_LOGGER.log_scaler_dict({f"eval/{split}_{k}": v
+ for k, v in eval_log.items()})
+ if hvd.rank() != 0:
+ continue
+ LOGGER.info(
+ f"========================= {split} ===========================\n"
+ f"image retrieval R1: {eval_log['img_r1']*100:.2f},\n"
+ f"image retrieval R5: {eval_log['img_r5']*100:.2f},\n"
+ f"image retrieval R10: {eval_log['img_r10']*100:.2f}\n"
+ f"text retrieval R1: {eval_log['txt_r1']*100:.2f},\n"
+ f"text retrieval R5: {eval_log['txt_r5']*100:.2f},\n"
+ f"text retrieval R10: {eval_log['txt_r10']*100:.2f}")
+ LOGGER.info("=========================================================")
+
+
+@torch.no_grad()
+def validate(model, val_loader):
+ if hvd.rank() == 0:
+ pbar = tqdm(total=len(val_loader))
+ else:
+ pbar = NoOp()
+ LOGGER.info("start running Image Retrieval validation ...")
+ model.eval()
+ n_ex = 0
+ st = time()
+
+ recall_at_1, recall_at_5, recall_at_10 = 0, 0, 0
+ for batch in val_loader:
+ model_outputs = model(batch, compute_loss=False)
+ if isinstance(model_outputs, dict):
+ scores = model_outputs['rank_scores']
+ else:
+ scores = model_outputs
+ _, indices = scores.squeeze(1).topk(10, dim=0)
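+ # the ground-truth candidate is assumed to sit at index 0, so its
+ # rank within the top-10 gives recall@{1,5,10}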
+ rank = (indices == 0).nonzero()
+ if rank.numel():
+ rank = rank.item()
+ if rank < 1:
+ recall_at_1 += 1
+ if rank < 5:
+ recall_at_5 += 1
+ if rank < 10:
+ recall_at_10 += 1
+ n_ex += 1
+ pbar.update(1)
+ n_ex = sum(all_gather_list(n_ex))
+ recall_at_1 = sum(all_gather_list(recall_at_1)) / n_ex
+ recall_at_5 = sum(all_gather_list(recall_at_5)) / n_ex
+ recall_at_10 = sum(all_gather_list(recall_at_10)) / n_ex
+ tot_time = time()-st
+ val_log = {'valid/ex_per_s': n_ex/tot_time,
+ 'valid/recall_1': recall_at_1,
+ 'valid/recall_5': recall_at_5,
+ 'valid/recall_10': recall_at_10}
+ model.train()
+ LOGGER.info(f"validation finished in {int(tot_time)} seconds, "
+ f"recall_1: {recall_at_1*100:.2f}, "
+ f"recall_5: {recall_at_5*100:.2f}, "
+ f"recall_10: {recall_at_10*100:.2f}")
+ pbar.close()
+ return val_log
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+
+ # Required parameters
+
+ parser.add_argument('--compressed_db', action='store_true',
+ help='use compressed LMDB')
+ parser.add_argument("--checkpoint",
+ default=None, type=str,
+ help="pretrained MLM")
+
+ parser.add_argument("--output_dir", default=None, type=str,
+ help="The output directory where the model "
+ "checkpoints will be written.")
+
+ # Prepro parameters
+ parser.add_argument('--max_txt_len', type=int, default=60,
+ help='max number of tokens in text (BERT BPE)')
+ parser.add_argument('--conf_th', type=float, default=0.2,
+ help='threshold for dynamic bounding boxes '
+ '(-1 for fixed)')
+ parser.add_argument('--max_bb', type=int, default=100,
+ help='max number of bounding boxes')
+ parser.add_argument('--min_bb', type=int, default=10,
+ help='min number of bounding boxes')
+ parser.add_argument('--num_bb', type=int, default=36,
+ help='static number of bounding boxes')
+
+ # training parameters
+ parser.add_argument("--train_batch_size", default=32, type=int,
+ help="batch size (# positive examples) for training. "
+ "(implemented with gradient accumulation)")
+
+ parser.add_argument("--negative_size", default=511, type=int,
+ help="Number of negative samples per positive sample"
+ "(forward only)")
+ parser.add_argument("--hard_neg_size", default=31, type=int,
+ help="Number of hard negative samples "
+ "per positive sample (acutally used to train)")
+ parser.add_argument("--mlm_sample_size", default=22, type=int,
+ help="Number of samples following masked language masking"
+ "per positive sample (acutally used to train)")
+ parser.add_argument("--nsgd_sample_size", default=22, type=int,
+ help="Number of NSGD for each mlm sample"
+ "per positive sample (acutally used to train)")
+ parser.add_argument("--nsgd_sample_temperature", default=2.0, type=float,
+ help="sampling temperature of NSGD sampling. ")
+
+ parser.add_argument("--mlm_lambda", default=0.1, type=float,
+ help="lambda in training of mask language modeling")
+ parser.add_argument("--nsgd_rank_lambda", default=1.0, type=float,
+ help="lambda in training of NSGD ranking loss")
+ parser.add_argument("--margin", default=0.2, type=float,
+ help="margin of ranking loss")
+ parser.add_argument("--learning_rate", default=3e-5, type=float,
+ help="The initial learning rate for Adam.")
+ parser.add_argument("--valid_steps", default=1000, type=int,
+ help="Run validation every X steps")
+ parser.add_argument("--num_train_steps", default=100000, type=int,
+ help="Total number of training updates to perform.")
+ parser.add_argument("--optim", default='adam',
+ choices=['adam', 'adamax', 'adamw'],
+ help="optimizer")
+ parser.add_argument("--betas", default=[0.9, 0.98], nargs='+',
+ help="beta for adam optimizer")
+ parser.add_argument("--dropout", default=0.1, type=float,
+ help="tune dropout regularization")
+ parser.add_argument("--weight_decay", default=0.01, type=float,
+ help="weight decay (L2) regularization")
+ parser.add_argument("--grad_norm", default=0.25, type=float,
+ help="gradient clipping (-1 for no clipping)")
+ parser.add_argument("--warmup_steps", default=4000, type=int,
+ help="Number of training steps to perform linear "
+ "learning rate warmup for.")
+
+ # device parameters
+ parser.add_argument('--seed', type=int, default=42,
+ help="random seed for initialization")
+ parser.add_argument('--full_val', action='store_true',
+ help="Always run full evaluation during training")
+ parser.add_argument('--fp16', action='store_true',
+ help="Whether to use 16-bit float precision instead "
+ "of 32-bit")
+ parser.add_argument('--n_workers', type=int, default=4,
+ help="number of data workers")
+ parser.add_argument('--pin_mem', action='store_true',
+ help="pin memory")
+
+ # can use config files
+ parser.add_argument('--config', help='JSON config files')
+
+ args = parse_with_config(parser)
+
+ # if exists(args.output_dir) and os.listdir(args.output_dir):
+ # raise ValueError("Output directory ({}) already exists and is not "
+ # "empty.".format(args.output_dir))
+
+ # options safe guard
+ if args.conf_th == -1:
+ assert args.max_bb + args.max_txt_len + 2 <= 512
+ else:
+ assert args.num_bb + args.max_txt_len + 2 <= 512
+
+ # for tensor core
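+ # e.g. with the defaults negative_size=511 and hard_neg_size=31, the per-query
+ # candidate pools have sizes 512 and 32, both multiples of 8 (tensor-core friendly)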
+ assert (args.negative_size+1) % 8 == (args.hard_neg_size+1) % 8 == 0
+
+ print('| args: ', args)
+ main(args)
diff --git a/train_pnsgd2.py b/train_pnsgd2.py
new file mode 100644
index 0000000..0ddb3b9
--- /dev/null
+++ b/train_pnsgd2.py
@@ -0,0 +1,630 @@
+"""
+Copyright (c) Microsoft Corporation.
+Licensed under the MIT license.
+
+UNITER finetuning for Image-Text Retrieval with hard negatives
+"""
+import argparse
+import os
+from os.path import exists, join
+from time import time
+import math
+
+import torch
+from torch.nn.utils import clip_grad_norm_
+from torch.utils.data import DataLoader, ConcatDataset
+from torch.optim.lr_scheduler import LambdaLR
+from apex import amp
+from horovod import torch as hvd
+from tqdm import tqdm
+
+from data import (PrefetchLoader, TxtTokLmdb, ImageLmdbGroup,
+ PNSGDFromText, PNSGDFromImage, pnsgd_collate, itm_rank_hn_collate,
+ ItmValDataset, itm_val_collate,
+ ItmEvalDataset, itm_eval_collate)
+from model.nsgd2 import UniterForNSGD2
+from optim import get_lr_sched
+from optim.misc import build_optimizer
+
+from utils.logger import LOGGER, TB_LOGGER, RunningMeter, add_log_to_file
+from utils.distributed import (all_reduce_and_rescale_tensors, all_gather_list,
+ broadcast_tensors)
+from utils.save import ModelSaver, save_training_meta
+from utils.misc import NoOp, parse_with_config, set_dropout, set_random_seed
+from utils.const import IMG_DIM
+from utils.itm_eval import evaluate
+
+
+def build_dataloader(dataset, collate_fn, is_train, opts):
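+ # NOTE: batch_size=1 here; each dataset item is expected to already pack one
+ # positive together with its sampled negatives, and the effective train_batch_size
+ # is implemented via gradient accumulation in the training loop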
+ dataloader = DataLoader(dataset, batch_size=1,
+ shuffle=is_train, drop_last=is_train,
+ num_workers=opts.n_workers,
+ pin_memory=opts.pin_mem, collate_fn=collate_fn)
+ dataloader = PrefetchLoader(dataloader)
+ return dataloader
+
+
+def build_lr_scheduler(opts, num_training_steps, optimizer):
+ num_warmup_steps = (
+ opts.warmup_steps
+ if opts.warmup_steps > 0
+ else math.ceil(num_training_steps * opts.warmup_ratio)
+ )
+
+ def lr_lambda(current_step: int):
+ if current_step < num_warmup_steps:
+ return float(current_step) / float(max(1, num_warmup_steps))
+ return max(
+ 0.0, float(num_training_steps - current_step) / float(max(1, num_training_steps - num_warmup_steps))
+ )
+ return LambdaLR(optimizer, lr_lambda)
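+# worked example (hypothetical numbers): with num_training_steps=100000 and
+# num_warmup_steps=4000, lr_lambda(2000) == 0.5 (halfway through linear warmup)
+# and lr_lambda(52000) == 0.5 (halfway down the linear decay towards 0)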
+
+
+def load_optimizer_and_scheduler(optimizer, lr_scheduler, checkpoint):
+ """If optimizer and scheduler states exist, load them."""
+ if checkpoint is None:
+ return optimizer, lr_scheduler
+
+ if os.path.isfile(os.path.join(checkpoint, "optimizer.pt")) and os.path.isfile(
+ os.path.join(checkpoint, "scheduler.pt")
+ ):
+ # Load in optimizer and scheduler states
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
+ optimizer.load_state_dict(
+ torch.load(os.path.join(checkpoint, "optimizer.pt"), map_location=device)
+ )
+ lr_scheduler.load_state_dict(torch.load(os.path.join(checkpoint, "scheduler.pt")))
+ return optimizer, lr_scheduler
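+# usage sketch (hypothetical; this helper is not invoked below): pass a checkpoint
+# directory containing optimizer.pt / scheduler.pt, e.g.
+# optimizer, lr_scheduler = load_optimizer_and_scheduler(optimizer, lr_scheduler, ckpt_dir)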
+
+
+def main(opts):
+ hvd.init()
+ n_gpu = hvd.size()
+ device = torch.device("cuda", hvd.local_rank())
+ torch.cuda.set_device(hvd.local_rank())
+ rank = hvd.rank()
+ opts.rank = rank
+ LOGGER.info("device: {} n_gpu: {}, rank: {}, "
+ "16-bits training: {}".format(
+ device, n_gpu, hvd.rank(), opts.fp16))
+
+ set_random_seed(opts.seed)
+
+ if hvd.rank() == 0:
+ save_training_meta(opts)
+ TB_LOGGER.create(join(opts.output_dir, 'log'))
+ pbar = tqdm(total=opts.num_train_steps)
+ model_saver = ModelSaver(join(opts.output_dir, 'ckpt'))
+ add_log_to_file(join(opts.output_dir, 'log', 'log.txt'))
+ # store ITM predictions
+ os.makedirs(join(opts.output_dir, 'results_val'))
+ os.makedirs(join(opts.output_dir, 'results_test'))
+ os.makedirs(join(opts.output_dir, 'results_train'))
+ else:
+ LOGGER.disabled = True
+ pbar = NoOp()
+ model_saver = NoOp()
+
+ # train_examples = None
+ LOGGER.info(f"Loading Train Dataset {opts.train_txt_dbs}, "
+ f"{opts.train_img_dbs}")
+ # check multiple DBs
+ assert len(opts.train_txt_dbs) == len(opts.train_img_dbs), \
+ "train txt_db and img_db have different length"
+
+ # load DBs and image dirs
+ all_img_dbs = ImageLmdbGroup(opts.conf_th, opts.max_bb, opts.min_bb,
+ opts.num_bb, opts.compressed_db)
+ # train
+ LOGGER.info(f"Loading Train Dataset "
+ f"{opts.train_txt_dbs}, {opts.train_img_dbs}")
+ train_datasets_t = []
+ train_datasets_i = []
+ for txt_path, img_path in zip(opts.train_txt_dbs, opts.train_img_dbs):
+ img_db = all_img_dbs[img_path]
+ txt_db = TxtTokLmdb(txt_path, opts.max_txt_len)
+ train_datasets_t.append(
+ PNSGDFromText(txt_db, img_db, opts.negative_size, opts.mlm_sample_size))
+ train_datasets_i.append(
+ PNSGDFromImage(txt_db, img_db, opts.negative_size))
+ train_dataset_t = ConcatDataset(train_datasets_t)
+ train_dataset_i = ConcatDataset(train_datasets_i)
+ train_dataloader_t = build_dataloader(
+ train_dataset_t, pnsgd_collate, True, opts)
+ train_dataloader_i = build_dataloader(
+ train_dataset_i, pnsgd_collate, True, opts)
+
+ # val
+ LOGGER.info(f"Loading Val Dataset {opts.val_txt_db}, {opts.val_img_db}")
+ val_img_db = all_img_dbs[opts.val_img_db]
+ val_txt_db = TxtTokLmdb(opts.val_txt_db, -1)
+ val_dataset = ItmValDataset(val_txt_db, val_img_db,
+ opts.inf_minibatch_size)
+ val_dataloader = build_dataloader(val_dataset, itm_val_collate,
+ False, opts)
+ # eval
+ LOGGER.info(f"Loading val, test Dataset for full evaluation: "
+ f"{opts.val_txt_db}, {opts.val_img_db}"
+ f"{opts.test_txt_db}, {opts.test_img_db}")
+ eval_dataset_val = ItmEvalDataset(val_txt_db, val_img_db,
+ opts.inf_minibatch_size)
+ eval_loader_val = build_dataloader(eval_dataset_val, itm_eval_collate,
+ False, opts)
+ test_img_db = all_img_dbs[opts.test_img_db]
+ test_txt_db = TxtTokLmdb(opts.test_txt_db, -1)
+ eval_dataset_test = ItmEvalDataset(test_txt_db, test_img_db,
+ opts.inf_minibatch_size)
+ eval_loader_test = build_dataloader(eval_dataset_test, itm_eval_collate,
+ False, opts)
+
+ # Prepare model
+ if opts.checkpoint:
+ checkpoint = torch.load(opts.checkpoint)
+ else:
+ checkpoint = {}
+
+ model = UniterForNSGD2.from_pretrained(
+ opts.model_config, state_dict=checkpoint,
+ img_dim=IMG_DIM, margin=opts.margin, hard_size=opts.hard_neg_size,
+ nsgd_sample_size=opts.nsgd_sample_size, nsgd_sample_temperature=opts.nsgd_sample_temperature)
+ model.init_output() # the pretrained ITM head differs from the ranking head
+ model.to(device)
+ # for name, param in model.named_parameters():
+ # print(name, param.size())
+ # make sure every process has same model parameters in the beginning
+ broadcast_tensors([p.data for p in model.parameters()], 0)
+ set_dropout(model, opts.dropout)
+
+ # Prepare optimizer
+ optimizer = build_optimizer(model, opts)
+ # print('| optimizer file:', os.path.join(checkpoint, "optimizer.pt"))
+ # if os.path.exists(os.path.join(checkpoint, "optimizer.pt")) and \
+ # os.path.isfile(os.path.join(checkpoint, "optimizer.pt")):
+ # # Load in optimizer and scheduler states
+ # device = 'gpu' if torch.cuda.is_available() else 'cpu'
+ # optimizer.load_state_dict(torch.load(os.path.join(checkpoint, "optimizer.pt"), map_location=device))
+ model, optimizer = amp.initialize(model, optimizer, enabled=opts.fp16, opt_level='O0') # , num_losses=4)
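+ # NOTE: apex amp opt_level 'O0' runs pure FP32, so --fp16 has no practical effect
+ # here; raising it to 'O1'/'O2' would enable actual mixed-precision training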
+ # Prepare scheduler
+ # lr_scheduler = build_lr_scheduler(opts, opts.num_train_steps, optimizer)
+ # print('| scheduler file:', os.path.join(checkpoint, "scheduler.pt"))
+ # if os.path.exists(os.path.join(checkpoint, "scheduler.pt")) and \
+ # os.path.isfile(os.path.join(checkpoint, "scheduler.pt")):
+ # lr_scheduler.load_state_dict(torch.load(os.path.join(checkpoint, 'scheduler.pt')))
+
+ LOGGER.info(f"***** Running training on {n_gpu} GPUs *****")
+ LOGGER.info(" Num examples = %d",
+ sum(all_gather_list(len(train_dataset_t))))
+ LOGGER.info(" Batch size = %d", opts.train_batch_size)
+ LOGGER.info(" Num steps = %d", opts.num_train_steps)
+
+ # running_loss = RunningMeter('loss')
+ running_loss_dict = {loss: RunningMeter(loss) for loss in
+ ['mlm_loss', 'nsgd_rank_loss', 'i2t_hn_rank_loss', 't2i_hn_rank_loss',
+ 'disc_loss', 'correction_loss']}
+ model.train()
+
+ global_step = 0
+ step = 0
+ n_examples = 0
+ n_hard_ex = 0
+ start = time()
+ train_iter_i = iter(train_dataloader_i)
+ # quick hack for amp delay_unscale bug
+ optimizer.zero_grad()
+ optimizer.step()
+ effect_nsgd_number, mlm_corrects, nsgd_rank_corrects, i2t_hn_rank_corrects, t2i_hn_rank_corrects = \
+ [], [[], []], [[], []], [[], []], [[], []]
+ correction_corrects = [[], []]
+ disc_pos_corrects, disc_neg_corrects = [[], []], [[], []]
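+ # each *_corrects accumulator stores [per-iteration correct counts, total counts];
+ # they are reduced to "*_cr" accuracy scalars for TensorBoard and reset at every
+ # optimizer update (every train_batch_size steps)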
+ while True:
+ for batch in train_dataloader_t:
+
+ # hard text from image
+ try:
+ batch_i = next(train_iter_i)
+ except StopIteration:
+ train_iter_i = iter(train_dataloader_i)
+ batch_i = next(train_iter_i)
+
+ n_examples += batch['mlm_attn_masks'].size(0)
+ model_outputs = model(
+ batch, sample_from='gi', compute_loss=True,
+ compute_mlm=isinstance(opts.mlm_lambda, float) and opts.mlm_lambda > 0)
+ loss = 0.0
+ if model_outputs.get('masked_lm_loss') is not None:
+ mlm_loss = model_outputs.get('masked_lm_loss').mean()
+ loss += opts.mlm_lambda * mlm_loss
+ running_loss_dict.get('mlm_loss')(mlm_loss.item() / opts.train_batch_size)
+ # print('| mlm_loss:', mlm_loss.item() / opts.train_batch_size)
+ mlm_corrects[0].append(model_outputs['mlm_corrects'].sum().item())
+ mlm_corrects[1].append(model_outputs['mlm_corrects'].numel())
+ # print('mlm_corrects:', mlm_corrects[0][-1]/mlm_corrects[1][-1])
+ effect_nsgd_number.append(model_outputs['effect_nsgd_number'])
+ if model_outputs.get('rank_loss') is not None:
+ nsgd_rank_loss = model_outputs.get('rank_loss').mean()
+ running_loss_dict.get('nsgd_rank_loss')(nsgd_rank_loss.item() / opts.train_batch_size)
+ # print('| nsgd_rank_loss:', nsgd_rank_loss.item() / opts.train_batch_size)
+ loss += opts.nsgd_rank_lambda * nsgd_rank_loss
+ rank_corrects = model_outputs.get('rank_corrects')
+ nsgd_hard_ex = rank_corrects.numel()
+ nsgd_rank_corrects[0].append(rank_corrects.sum().item())
+ nsgd_rank_corrects[1].append(nsgd_hard_ex)
+ n_hard_ex += nsgd_hard_ex
+ if isinstance(loss, torch.Tensor):
+ loss = loss / opts.train_batch_size
+ with amp.scale_loss(loss, optimizer, delay_unscale=False, # loss_id=0
+ ) as scaled_loss:
+ scaled_loss.backward()
+
+ # discrimination and correction of synthetic text samples
+ batch['syn_input_ids'] = model_outputs['syn_input_ids']
+ batch['syn_position_ids'] = model_outputs['syn_position_ids']
+ batch['syn_img_feat'] = model_outputs['syn_img_feat']
+ batch['syn_img_pos_feat'] = model_outputs['syn_img_pos_feat']
+ batch['syn_attn_masks'] = model_outputs['syn_attn_masks']
+ batch['syn_gather_index'] = model_outputs['syn_gather_index']
+ batch['syn_txt_labels'] = model_outputs['syn_txt_labels']
+ batch['gt_input_ids'] = model_outputs['gt_input_ids']
+ n_examples += batch['syn_input_ids'].size(0)
+ model_outputs = model(
+ batch, sample_from='gsynt', compute_loss=True)
+ loss = 0.0
+ if model_outputs.get('disc_loss') is not None:
+ disc_loss = model_outputs.get('disc_loss').mean()
+ loss += opts.disc_lambda * disc_loss
+ # print('| disc_loss:', disc_loss.item() / opts.train_batch_size)
+ running_loss_dict.get('disc_loss')(disc_loss.item() / opts.train_batch_size)
+ disc_pos_corrects[0].append(model_outputs['disc_pos_corrects'].item())
+ disc_pos_corrects[1].append(model_outputs['disc_pos_samples'].item())
+ disc_neg_corrects[0].append(model_outputs['disc_neg_corrects'].item())
+ disc_neg_corrects[1].append(model_outputs['disc_neg_samples'].item())
+ if model_outputs.get('correction_loss') is not None:
+ correction_loss = model_outputs.get('correction_loss').mean()
+ loss += opts.correction_lambda * correction_loss
+ # print('| correction_loss:', correction_loss.item() / opts.train_batch_size)
+ running_loss_dict.get('correction_loss')(correction_loss.item() / opts.train_batch_size)
+ correction_corrects[0].append(model_outputs['correction_corrects'].sum().item())
+ correction_corrects[1].append(model_outputs['correction_corrects'].numel())
+ if isinstance(loss, torch.Tensor):
+ loss = loss / opts.train_batch_size
+ with amp.scale_loss(loss, optimizer, delay_unscale=False, # loss_id=1
+ ) as scaled_loss:
+ scaled_loss.backward()
+
+ # ITM training with hard negative samples
+ n_examples += batch_i['attn_masks'].size(0)
+ model_outputs = model(batch_i, sample_from='i', compute_loss=True)
+ loss = model_outputs.get('rank_loss')
+ rank_corrects = model_outputs.get('rank_corrects')
+ i2t_hard_ex = rank_corrects.numel()
+ i2t_hn_rank_corrects[0].append(rank_corrects.sum().item())
+ i2t_hn_rank_corrects[1].append(i2t_hard_ex)
+ n_hard_ex += i2t_hard_ex
+ loss = loss.mean() / opts.train_batch_size
+
+ with amp.scale_loss(loss, optimizer, delay_unscale=False, # loss_id=2
+ ) as scaled_loss:
+ scaled_loss.backward()
+ running_loss_dict.get('t2i_hn_rank_loss')(loss.item())
+
+ # hard image from text
+ n_examples += batch['attn_masks'].size(0)
+ model_outputs = model(batch, sample_from='t', compute_loss=True)
+ loss = model_outputs.get('rank_loss')
+ rank_corrects = model_outputs.get('rank_corrects')
+ t2i_hard_ex = rank_corrects.numel()
+ t2i_hn_rank_corrects[0].append(rank_corrects.sum().item())
+ t2i_hn_rank_corrects[1].append(t2i_hard_ex)
+ n_hard_ex += t2i_hard_ex
+ # NOTE: we use gradient accumulation to implement train_batch_size
+ loss = loss.mean() / opts.train_batch_size
+ running_loss_dict.get('i2t_hn_rank_loss')(loss.item())
+ # print('| t2i_triplet_loss:', loss.item())
+
+ step += 1
+ delay_unscale = step % opts.train_batch_size != 0
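+ # e.g. with train_batch_size=32, gradients from 32 consecutive iterations are
+ # accumulated (delay_unscale=True) before the all-reduce and optimizer update below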
+ with amp.scale_loss(loss, optimizer, delay_unscale=delay_unscale, # loss_id=3
+ ) as scaled_loss:
+ scaled_loss.backward()
+ if not delay_unscale:
+ # gather gradients from every processes
+ # do this before unscaling to make sure every process uses
+ # the same gradient scale
+ grads = [p.grad.data for p in model.parameters()
+ if p.requires_grad and p.grad is not None]
+ all_reduce_and_rescale_tensors(grads, float(1))
+
+ if step % opts.train_batch_size == 0:
+ global_step += 1
+
+ # learning rate scheduling
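+ # (assumption: get_lr_sched in optim/ returns the warmup/decay-adjusted LR for
+ # this update, consistent with the --warmup_steps option)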
+ lr_this_step = get_lr_sched(global_step, opts)
+ for param_group in optimizer.param_groups:
+ param_group['lr'] = lr_this_step
+ TB_LOGGER.add_scalar('lr', lr_this_step, global_step)
+
+ # log loss
+ # NOTE: not gathered across GPUs for efficiency
+ TB_LOGGER.add_scalar(
+ 'mlm_loss', running_loss_dict['mlm_loss'].val, global_step)
+ TB_LOGGER.add_scalar(
+ 'nsgd_rank_loss', running_loss_dict['nsgd_rank_loss'].val, global_step)
+ TB_LOGGER.add_scalar(
+ 'i2t_hn_rank_loss', running_loss_dict['i2t_hn_rank_loss'].val, global_step)
+ TB_LOGGER.add_scalar(
+ 't2i_hn_rank_loss', running_loss_dict['t2i_hn_rank_loss'].val, global_step)
+ TB_LOGGER.add_scalar(
+ 'disc_loss', running_loss_dict['disc_loss'].val, global_step)
+ TB_LOGGER.add_scalar(
+ 'correction_loss', running_loss_dict['correction_loss'].val, global_step)
+ TB_LOGGER.add_scalar(
+ 'nsgd_rank_cr', sum(nsgd_rank_corrects[0]) / float(sum(nsgd_rank_corrects[1])), global_step)
+ TB_LOGGER.add_scalar(
+ 'i2t_hn_rank_cr', sum(i2t_hn_rank_corrects[0]) / float(sum(i2t_hn_rank_corrects[1])), global_step)
+ TB_LOGGER.add_scalar(
+ 't2i_hn_rank_cr', sum(t2i_hn_rank_corrects[0]) / float(sum(t2i_hn_rank_corrects[1])), global_step)
+ TB_LOGGER.add_scalar(
+ 'effect_nsgd_number', sum(effect_nsgd_number) / float(len(effect_nsgd_number)), global_step)
+ TB_LOGGER.add_scalar(
+ 'correction_cr', sum(correction_corrects[0]) / max(float(sum(correction_corrects[1])), 1.0), global_step)
+ TB_LOGGER.add_scalar(
+ 'disc_pos_ntokens', sum(disc_pos_corrects[1]), global_step)
+ TB_LOGGER.add_scalar(
+ 'disc_neg_ntokens', sum(disc_neg_corrects[1]), global_step)
+ TB_LOGGER.add_scalar(
+ 'disc_pos_cr', sum(disc_pos_corrects[0]) / max(float(sum(disc_pos_corrects[1])), 1.0), global_step)
+ TB_LOGGER.add_scalar(
+ 'disc_neg_cr', sum(disc_neg_corrects[0]) / max(float(sum(disc_neg_corrects[1])), 1.0), global_step)
+ TB_LOGGER.step()
+ effect_nsgd_number, mlm_corrects, nsgd_rank_corrects, i2t_hn_rank_corrects, t2i_hn_rank_corrects = \
+ [], [[], []], [[], []], [[], []], [[], []]
+ correction_corrects = [[], []]
+ disc_pos_corrects, disc_neg_corrects = [[], []], [[], []]
+
+ # update model params
+ if opts.grad_norm != -1:
+ grad_norm = clip_grad_norm_(amp.master_params(optimizer),
+ opts.grad_norm)
+ TB_LOGGER.add_scalar('grad_norm', grad_norm, global_step)
+ optimizer.step()
+ optimizer.zero_grad()
+ pbar.update(1)
+
+ if global_step % 50 == 0:
+ # monitor training throughput
+ LOGGER.info(f'------------Step {global_step}-------------')
+ tot_ex = sum(all_gather_list(n_examples))
+ ex_per_sec = int(tot_ex / (time()-start))
+ tot_hn = sum(all_gather_list(n_hard_ex))
+ hn_per_sec = int(tot_hn / (time()-start))
+ LOGGER.info(f'{tot_ex} ({tot_hn}) examples (hard) '
+ f'trained at {ex_per_sec} ({hn_per_sec}) ex/s')
+ TB_LOGGER.add_scalar('perf/ex_per_s',
+ ex_per_sec, global_step)
+ TB_LOGGER.add_scalar('perf/hn_per_s',
+ hn_per_sec, global_step)
+ LOGGER.info(f'-------------------------------------------')
+
+ if global_step % opts.valid_steps == 0:
+ if opts.full_val:
+ LOGGER.info(
+ f"========================== Step {global_step} "
+ f"==========================")
+ val_log = evaluate(model, eval_loader_val)
+ try:
+ TB_LOGGER.log_scaler_dict(
+ {f"valid/{k}": v for k, v in val_log.items()})
+ LOGGER.info(f"image retrieval R1: "
+ f"{val_log['img_r1']*100:.2f},\n"
+ f"image retrieval R5: "
+ f"{val_log['img_r5']*100:.2f},\n"
+ f"image retrieval R10: "
+ f"{val_log['img_r10']*100:.2f}\n"
+ f"text retrieval R1: "
+ f"{val_log['txt_r1']*100:.2f},\n"
+ f"text retrieval R5: "
+ f"{val_log['txt_r5']*100:.2f},\n"
+ f"text retrieval R10: "
+ f"{val_log['txt_r10']*100:.2f}")
+ LOGGER.info("================================="
+ "=================================")
+ except KeyError:
+ pass
+ else:
+ val_log = validate(model, val_dataloader)
+ TB_LOGGER.log_scaler_dict(val_log)
+ model_saver.save(model, global_step, optimizer=optimizer)
+ # torch.save(optimizer.state_dict(), os.path.join())
+ # torch.save(lr_scheduler.state_dict(), os.path.join())
+
+ if global_step >= opts.num_train_steps:
+ break
+
+ if global_step >= opts.num_train_steps:
+ break
+
+ pbar.close()
+ # final validation
+ val_log = validate(model, val_dataloader)
+ TB_LOGGER.log_scaler_dict(val_log)
+ model_saver.save(model, f'{global_step}_final')
+
+ # evaluation
+ for split, loader in [('val', eval_loader_val),
+ ('test', eval_loader_test)]:
+ eval_log = evaluate(model, loader)
+ TB_LOGGER.log_scaler_dict({f"eval/{split}_{k}": v
+ for k, v in eval_log.items()})
+ if hvd.rank() != 0:
+ continue
+ LOGGER.info(
+ f"========================= {split} ===========================\n"
+ f"image retrieval R1: {eval_log['img_r1']*100:.2f},\n"
+ f"image retrieval R5: {eval_log['img_r5']*100:.2f},\n"
+ f"image retrieval R10: {eval_log['img_r10']*100:.2f}\n"
+ f"text retrieval R1: {eval_log['txt_r1']*100:.2f},\n"
+ f"text retrieval R5: {eval_log['txt_r5']*100:.2f},\n"
+ f"text retrieval R10: {eval_log['txt_r10']*100:.2f}")
+ LOGGER.info("=========================================================")
+
+
+@torch.no_grad()
+def validate(model, val_loader):
+ if hvd.rank() == 0:
+ pbar = tqdm(total=len(val_loader))
+ else:
+ pbar = NoOp()
+ LOGGER.info("start running Image Retrieval validation ...")
+ model.eval()
+ n_ex = 0
+ st = time()
+
+ recall_at_1, recall_at_5, recall_at_10 = 0, 0, 0
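+ # each val batch scores one query against a pool of candidates in which index 0
+ # is the ground-truth match, so its rank within the top-10 yields R@1/5/10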
+ for batch in val_loader:
+ model_outputs = model(batch, compute_loss=False)
+ if isinstance(model_outputs, dict):
+ scores = model_outputs['rank_scores']
+ else:
+ scores = model_outputs
+ _, indices = scores.squeeze(1).topk(10, dim=0)
+ rank = (indices == 0).nonzero()
+ if rank.numel():
+ rank = rank.item()
+ if rank < 1:
+ recall_at_1 += 1
+ if rank < 5:
+ recall_at_5 += 1
+ if rank < 10:
+ recall_at_10 += 1
+ n_ex += 1
+ pbar.update(1)
+ n_ex = sum(all_gather_list(n_ex))
+ recall_at_1 = sum(all_gather_list(recall_at_1)) / n_ex
+ recall_at_5 = sum(all_gather_list(recall_at_5)) / n_ex
+ recall_at_10 = sum(all_gather_list(recall_at_10)) / n_ex
+ tot_time = time()-st
+ val_log = {'valid/ex_per_s': n_ex/tot_time,
+ 'valid/recall_1': recall_at_1,
+ 'valid/recall_5': recall_at_5,
+ 'valid/recall_10': recall_at_10}
+ model.train()
+ LOGGER.info(f"validation finished in {int(tot_time)} seconds, "
+ f"recall_1: {recall_at_1*100:.2f}, "
+ f"recall_5: {recall_at_5*100:.2f}, "
+ f"recall_10: {recall_at_10*100:.2f}")
+ pbar.close()
+ return val_log
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+
+ # Required parameters
+
+ parser.add_argument('--compressed_db', action='store_true',
+ help='use compressed LMDB')
+ parser.add_argument("--checkpoint",
+ default=None, type=str,
+ help="pretrained MLM")
+
+ parser.add_argument("--output_dir", default=None, type=str,
+ help="The output directory where the model "
+ "checkpoints will be written.")
+
+ # Prepro parameters
+ parser.add_argument('--max_txt_len', type=int, default=60,
+ help='max number of tokens in text (BERT BPE)')
+ parser.add_argument('--conf_th', type=float, default=0.2,
+ help='threshold for dynamic bounding boxes '
+ '(-1 for fixed)')
+ parser.add_argument('--max_bb', type=int, default=100,
+ help='max number of bounding boxes')
+ parser.add_argument('--min_bb', type=int, default=10,
+ help='min number of bounding boxes')
+ parser.add_argument('--num_bb', type=int, default=36,
+ help='static number of bounding boxes')
+
+ # training parameters
+ parser.add_argument("--train_batch_size", default=32, type=int,
+ help="batch size (# positive examples) for training. "
+ "(implemented with gradient accumulation)")
+
+ parser.add_argument("--negative_size", default=511, type=int,
+ help="Number of negative samples per positive sample"
+ "(forward only)")
+ parser.add_argument("--hard_neg_size", default=31, type=int,
+ help="Number of hard negative samples "
+ "per positive sample (acutally used to train)")
+ parser.add_argument("--mlm_sample_size", default=22, type=int,
+ help="Number of samples following masked language masking"
+ "per positive sample (acutally used to train)")
+ parser.add_argument("--nsgd_sample_size", default=22, type=int,
+ help="Number of NSGD for each mlm sample"
+ "per positive sample (acutally used to train)")
+ parser.add_argument("--nsgd_sample_temperature", default=2.0, type=float,
+ help="sampling temperature of NSGD sampling. ")
+
+ parser.add_argument("--disc_lambda", default=0.1, type=float,
+ help="lambda in training of discrimination")
+ parser.add_argument("--correction_lambda", default=0.1, type=float,
+ help="lambda in training of correction")
+ parser.add_argument("--mlm_lambda", default=0.1, type=float,
+ help="lambda in training of mask language modeling")
+ parser.add_argument("--nsgd_rank_lambda", default=1.0, type=float,
+ help="lambda in training of NSGD ranking loss")
+ parser.add_argument("--margin", default=0.2, type=float,
+ help="margin of ranking loss")
+ parser.add_argument("--learning_rate", default=3e-5, type=float,
+ help="The initial learning rate for Adam.")
+ parser.add_argument("--valid_steps", default=1000, type=int,
+ help="Run validation every X steps")
+ parser.add_argument("--num_train_steps", default=100000, type=int,
+ help="Total number of training updates to perform.")
+ parser.add_argument("--optim", default='adam',
+ choices=['adam', 'adamax', 'adamw'],
+ help="optimizer")
+ parser.add_argument("--betas", default=[0.9, 0.98], nargs='+',
+ help="beta for adam optimizer")
+ parser.add_argument("--dropout", default=0.1, type=float,
+ help="tune dropout regularization")
+ parser.add_argument("--weight_decay", default=0.01, type=float,
+ help="weight decay (L2) regularization")
+ parser.add_argument("--grad_norm", default=0.25, type=float,
+ help="gradient clipping (-1 for no clipping)")
+ parser.add_argument("--warmup_steps", default=4000, type=int,
+ help="Number of training steps to perform linear "
+ "learning rate warmup for.")
+
+ # device parameters
+ parser.add_argument('--seed', type=int, default=42,
+ help="random seed for initialization")
+ parser.add_argument('--full_val', action='store_true',
+ help="Always run full evaluation during training")
+ parser.add_argument('--fp16', action='store_true',
+ help="Whether to use 16-bit float precision instead "
+ "of 32-bit")
+ parser.add_argument('--n_workers', type=int, default=4,
+ help="number of data workers")
+ parser.add_argument('--pin_mem', action='store_true',
+ help="pin memory")
+
+ # can use config files
+ parser.add_argument('--config', help='JSON config files')
+
+ args = parse_with_config(parser)
+
+ # if exists(args.output_dir) and os.listdir(args.output_dir):
+ # raise ValueError("Output directory ({}) already exists and is not "
+ # "empty.".format(args.output_dir))
+
+ # options safe guard
+ if args.conf_th == -1:
+ assert args.max_bb + args.max_txt_len + 2 <= 512
+ else:
+ assert args.num_bb + args.max_txt_len + 2 <= 512
+
+ # for tensor core
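+ # e.g. with the defaults negative_size=511 and hard_neg_size=31, the candidate
+ # pool sizes 512 and 32 are both multiples of 8 (tensor-core friendly)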
+ assert (args.negative_size+1) % 8 == (args.hard_neg_size+1) % 8 == 0
+
+ print('| args: ', args)
+ main(args)
diff --git a/utils/__init__.py b/utils/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/utils/ans2label.json b/utils/ans2label.json
new file mode 100644
index 0000000..9fc717b
--- /dev/null
+++ b/utils/ans2label.json
@@ -0,0 +1 @@
+{"net": 0, "pitcher": 1, "orange": 2, "yes": 3, "white": 4, "skiing": 5, "red": 6, "frisbee": 7, "brushing teeth": 8, "no": 9, "black and white": 10, "skateboard": 11, "1": 12, "blue": 13, "green": 14, "motorcycle": 15, "gray": 16, "2": 17, "purse": 18, "skis": 19, "poles": 20, "surfboard": 21, "dog": 22, "on": 23, "office": 24, "large": 25, "very big": 26, "laptop": 27, "vent": 28, "computer": 29, "black": 30, "bear": 31, "3": 32, "wii": 33, "glasses": 34, "tree": 35, "eating": 36, "log": 37, "5": 38, "raft": 39, "left": 40, "living room": 41, "pink": 42, "right": 43, "railing": 44, "grass": 45, "wire": 46, "10 years": 47, "knife": 48, "cake": 49, "banana": 50, "chef": 51, "vanilla": 52, "4": 53, "outdoor": 54, "mustard": 55, "bun": 56, "clouds": 57, "dock": 58, "brown": 59, "silver": 60, "refrigerator": 61, "square": 62, "teddy": 63, "elm": 64, "stripes": 65, "baseball": 66, "catcher": 67, "beer": 68, "bottom": 69, "north": 70, "nike": 71, "yellow and white": 72, "morning": 73, "elephant": 74, "red and white": 75, "propeller": 76, "tan": 77, "wall": 78, "rolex": 79, "clock": 80, "table": 81, "0": 82, "wood": 83, "christmas": 84, "spinach": 85, "thick": 86, "bag": 87, "leaves": 88, "necklace": 89, "6": 90, "bathroom": 91, "shower": 92, "towel": 93, "solid": 94, "referee": 95, "wilson": 96, "8:00": 97, "e": 98, "24": 99, "hat": 100, "grazing": 101, "sheep": 102, "10": 103, "tag": 104, "spanish": 105, "hot dog": 106, "plate": 107, "lunch": 108, "butter": 109, "peppers": 110, "onions": 111, "very": 112, "mayonnaise": 113, "mayo": 114, "sweet potato": 115, "pig": 116, "sweet": 117, "flowers": 118, "floral": 119, "yellow": 120, "window": 121, "7": 122, "pizza": 123, "car": 124, "": 125, "cargo": 126, "stairs": 127, "abstract": 128, "rug": 129, "baseball cap": 130, "texting": 131, "pole": 132, "crosswalk": 133, "nothing": 134, "urban": 135, "bus": 136, "light": 137, "afternoon": 138, "boat": 139, "cheese": 140, "paper": 141, "real": 142, "sun": 143, "birthday": 144, "words": 145, "inside": 146, "shadows": 147, "tomato": 148, "evergreen": 149, "100 feet": 150, "shingles": 151, "trees": 152, "building": 153, "hay": 154, "ski pole": 155, "patterned": 156, "walking": 157, "ice": 158, "laundry": 159, "pepsi": 160, "good": 161, "1:50": 162, "purple": 163, "13": 164, "africa": 165, "teddy bears": 166, "socks": 167, "giraffe": 168, "soccer": 169, "blue and yellow": 170, "zebras": 171, "cupcake": 172, "broccoli": 173, "soldier": 174, "parking lot": 175, "cows": 176, "herding": 177, "on table": 178, "fish": 179, "nightstand": 180, "50": 181, "overcast": 182, "cross": 183, "toaster oven": 184, "tile": 185, "11:55": 186, "red and yellow": 187, "nowhere": 188, "hair dryer": 189, "truck": 190, "11": 191, "people": 192, "rectangle": 193, "hot dogs": 194, "party": 195, "12:55": 196, "apron": 197, "kitchen": 198, "cooking": 199, "ring": 200, "1 way": 201, "stop": 202, "neither": 203, "many": 204, "female": 205, "brushing": 206, "tie": 207, "tennis racket": 208, "knife and fork": 209, "restaurant": 210, "cat": 211, "bed": 212, "sand": 213, "ocean": 214, "cold": 215, "kites": 216, "cumulus": 217, "standing": 218, "male": 219, "star": 220, "tracks": 221, "chocolate": 222, "round": 223, "fork and knife": 224, "yankees": 225, "pictures": 226, "dots": 227, "bird": 228, "parrot": 229, "red white and blue": 230, "man": 231, "metal": 232, "fence": 233, "snowboarding": 234, "pine": 235, "snow": 236, "shorts": 237, "swim": 238, "wine": 239, "brick": 240, "no parking": 241, "children": 242, "beef": 243, "phone": 244, 
"english": 245, "cell phone": 246, "pink and yellow": 247, "clear": 248, "watermelon": 249, "bedroom": 250, "fork": 251, "cow": 252, "rackets": 253, "tennis rackets": 254, "8": 255, "collar": 256, "tennis": 257, "1950s": 258, "playing tennis": 259, "skirt": 260, "30": 261, "polka dot": 262, "beach": 263, "horse": 264, "grill": 265, "african american": 266, "down": 267, "street": 268, "in air": 269, "sweater": 270, "yellow and blue": 271, "park": 272, "backyard": 273, "spectators": 274, "parasailing": 275, "31": 276, "river": 277, "55": 278, "shadow": 279, "winter": 280, "chicken": 281, "tea": 282, "evening": 283, "dusk": 284, "ski resort": 285, "helmet": 286, "penne": 287, "bench": 288, "resting": 289, "elephants": 290, "southwest": 291, "usa": 292, "cars": 293, "town": 294, "bananas": 295, "umbrella": 296, "container": 297, "woman": 298, "on counter": 299, "salad": 300, "striped": 301, "motel": 302, "vertical": 303, "oranges": 304, "hot sauce": 305, "bottle": 306, "juice": 307, "eyes": 308, "ground": 309, "backpack": 310, "black and yellow": 311, "forward": 312, "jackets": 313, "1 on right": 314, "green and yellow": 315, "playing baseball": 316, "riding": 317, "sitting": 318, "carrot": 319, "basket": 320, "seagull": 321, "ski poles": 322, "p": 323, "parking": 324, "street light": 325, "mets": 326, "strap": 327, "bike": 328, "riding bike": 329, "poodle": 330, "shoes": 331, "carpet": 332, "lettuce": 333, "food": 334, "1 foot": 335, "roses": 336, "mountains": 337, "scissors": 338, "camera": 339, "beige": 340, "beard": 341, "cutting": 342, "baby": 343, "tape": 344, "watch": 345, "never": 346, "taking picture": 347, "eggs": 348, "syrup": 349, "sandwich": 350, "water skiing": 351, "microphone": 352, "back": 353, "bears": 354, "donuts": 355, "w": 356, "sky": 357, "double decker": 358, "england": 359, "surfing": 360, "running": 361, "shirt": 362, "barn": 363, "weather vane": 364, "white and blue": 365, "fishing": 366, "bridge": 367, "los angeles": 368, "open": 369, "red sox": 370, "bat": 371, "plane": 372, "white and green": 373, "transportation": 374, "sunny": 375, "bus stop": 376, "city": 377, "brown and white": 378, "bicycle": 379, "crow": 380, "magazines": 381, "daisy": 382, "14": 383, "old": 384, "curtains": 385, "jumped": 386, "snowboard": 387, "dinosaur": 388, "racing": 389, "asphalt": 390, "court": 391, "plastic": 392, "circle": 393, "red and blue": 394, "zebra": 395, "12": 396, "biplane": 397, "shallow": 398, "brazil": 399, "logo": 400, "2:20": 401, "electric": 402, "night time": 403, "motion": 404, "toothbrushes": 405, "orange and white": 406, "66": 407, "spoon": 408, "toyota": 409, "tennis shoes": 410, "46": 411, "second": 412, "no 1": 413, "iphone": 414, "friend": 415, "apple": 416, "carnation": 417, "15": 418, "tiger": 419, "glove": 420, "airplane": 421, "bow": 422, "air france": 423, "passengers": 424, "tv": 425, "on building": 426, "3:55": 427, "victorian": 428, "steeple": 429, "happy": 430, "skateboarding": 431, "fruit": 432, "cutting board": 433, "cantaloupe": 434, "kiwi": 435, "sliced": 436, "heart": 437, "water": 438, "rainy": 439, "carrots": 440, "giraffes": 441, "eat": 442, "ramp": 443, "lab": 444, "field": 445, "horizontal": 446, "birds": 447, "home": 448, "shrimp": 449, "12 feet": 450, "girl": 451, "modern": 452, "turtle": 453, "dell": 454, "boots": 455, "sunglasses": 456, "black and orange": 457, "yellow and black": 458, "gloves": 459, "hp": 460, "desk": 461, "both": 462, "sign": 463, "on street": 464, "2000": 465, "cirrus": 466, "to dry": 467, "ceiling": 468, 
"fluorescent": 469, "up": 470, "9": 471, "boys": 472, "playing soccer": 473, "american": 474, "passenger": 475, "turn": 476, "palm": 477, "no train": 478, "wedding": 479, "branch": 480, "parrots": 481, "air force": 482, "on tracks": 483, "small": 484, "tank": 485, "dirty": 486, "france": 487, "honda": 488, "2.00": 489, "whale": 490, "vase": 491, "flying": 492, "professional": 493, "driving": 494, "tissue": 495, "protest": 496, "corona": 497, "for balance": 498, "twin": 499, "clothes": 500, "t shirt": 501, "window sill": 502, "wild": 503, "noon": 504, "caution": 505, "spring": 506, "raining": 507, "cane": 508, "school": 509, "windsurfing": 510, "parachute": 511, "black and red": 512, "25": 513, "background": 514, "toaster": 515, "planes": 516, "yellow and red": 517, "spatula": 518, "10:10": 519, "ivory": 520, "train": 521, "welcome": 522, "highway": 523, "off": 524, "on track": 525, "electricity": 526, "italy": 527, "dinner": 528, "sink": 529, "squares": 530, "5 ft": 531, "parked": 532, "store": 533, "dress": 534, "signs": 535, "meow": 536, "football": 537, "rugby": 538, "stainless steel": 539, "la": 540, "dirt": 541, "blue and white": 542, "klm": 543, "house": 544, "unknown": 545, "ford": 546, "reading": 547, "chair": 548, "mountain": 549, "alive": 550, "water skis": 551, "picture": 552, "parade": 553, "slippers": 554, "trailer": 555, "boating": 556, "holding it": 557, "shade": 558, "cloth": 559, "6:20": 560, "candle": 561, "hose": 562, "hand": 563, "3:25": 564, "on sidewalk": 565, "poster": 566, "downhill": 567, "68": 568, "reflection": 569, "summer": 570, "pickles": 571, "halloween": 572, "bats": 573, "london": 574, "zoo": 575, "surfer": 576, "racket": 577, "flickr": 578, "cutting hair": 579, "strawberries": 580, "mushroom": 581, "teddy bear": 582, "big": 583, "suitcase": 584, "veggie": 585, "pepper": 586, "houses": 587, "70": 588, "toshiba": 589, "triangle": 590, "boxes": 591, "photograph": 592, "smoke": 593, "engine": 594, "camel": 595, "sidewalk": 596, "left 1": 597, "red and green": 598, "4:35": 599, "on couch": 600, "candy": 601, "minnie mouse": 602, "homemade": 603, "mouse": 604, "box": 605, "movie": 606, "45": 607, "strawberry": 608, "fridge": 609, "full": 610, "vegetables": 611, "bright": 612, "play": 613, "remote": 614, "pond": 615, "savannah": 616, "celery": 617, "concrete": 618, "semi": 619, "dump": 620, "scania": 621, "safety": 622, "posing": 623, "fabric": 624, "laying": 625, "couch": 626, "blueberries": 627, "handle": 628, "pipe": 629, "stick": 630, "parmesan": 631, "steak": 632, "chain link": 633, "catch": 634, "barbed wire": 635, "mozzarella": 636, "soda": 637, "fire hydrant": 638, "cat food": 639, "pepperoni": 640, "lot": 641, "licking": 642, "red and black": 643, "clay": 644, "tennis court": 645, "jumping": 646, "potatoes": 647, "toothbrush": 648, "kite": 649, "not at all": 650, "flying kite": 651, "broken": 652, "black and silver": 653, "lap": 654, "outside": 655, "44": 656, "delta": 657, "greyhound": 658, "ring finger": 659, "talking on phone": 660, "bad": 661, "kettle": 662, "35": 663, "motorcycles": 664, "produce": 665, "comfort": 666, "steering wheel": 667, "18": 668, "humans": 669, "coffee": 670, "white and brown": 671, "fall": 672, "bread": 673, "cherry": 674, "4:30": 675, "flag": 676, "night": 677, "lamp": 678, "cucumber": 679, "can't see": 680, "porcelain": 681, "oval": 682, "museum": 683, "rain": 684, "sprinkles": 685, "20": 686, "kids": 687, "bracelet": 688, "sneakers": 689, "mask": 690, "mickey mouse": 691, "twins": 692, "very high": 693, "costume": 694, 
"cabbage": 695, "paint": 696, "lighting": 697, "young": 698, "air conditioner": 699, "wooden": 700, "board": 701, "someone": 702, "beets": 703, "16": 704, "day time": 705, "4 inches": 706, "lights": 707, "ladder": 708, "glass": 709, "ferris wheel": 710, "fries": 711, "steamed": 712, "shepherd": 713, "cotton": 714, "suit": 715, "goatee": 716, "on his head": 717, "print": 718, "happy birthday": 719, "forks": 720, "travel": 721, "maple": 722, "200": 723, "oil": 724, "jeans": 725, "can": 726, "chopsticks": 727, "on wall": 728, "construction": 729, "mack": 730, "36": 731, "chinese": 732, "moped": 733, "festival": 734, "gas": 735, "throwing": 736, "circus": 737, "wires": 738, "not possible": 739, "plates": 740, "sugar": 741, "in": 742, "women's": 743, "door": 744, "no man": 745, "volleyball": 746, "serving": 747, "ponytail": 748, "business": 749, "decoration": 750, "santa": 751, "flat": 752, "barrel": 753, "12:15": 754, "candles": 755, "atv": 756, "free": 757, "hair": 758, "waffle": 759, "ball": 760, "stop sign": 761, "wetsuit": 762, "very deep": 763, "swimsuit": 764, "green and black": 765, "foreground": 766, "stands": 767, "china airlines": 768, "flower": 769, "300": 770, "lobster": 771, "on bench": 772, "plaster": 773, "phones": 774, "sailboat": 775, "apples": 776, "road": 777, "recently": 778, "cones": 779, "cactus": 780, "rice": 781, "vegetarian": 782, "donut": 783, "ketchup": 784, "police": 785, "mirror": 786, "rock": 787, "meat": 788, "blinds": 789, "cell phones": 790, "china": 791, "rust": 792, "7:25": 793, "stone": 794, "vans": 795, "middle": 796, "eagle": 797, "9:30": 798, "ping pong": 799, "microwave": 800, "gmc": 801, "umbrellas": 802, "wrist": 803, "cuddling": 804, "laughing": 805, "boy": 806, "next to toilet": 807, "tabby": 808, "petting": 809, "south": 810, "40": 811, "name tag": 812, "checkered": 813, "name": 814, "slow": 815, "cardboard": 816, "windows": 817, "croissant": 818, "plain": 819, "cookie": 820, "on ground": 821, "low": 822, "water bottle": 823, "goggles": 824, "turkey": 825, "pull": 826, "shut": 827, "kite flying": 828, "bowl": 829, "smile": 830, "in bowl": 831, "bush": 832, "cloudy": 833, "top left": 834, "skateboarder": 835, "coca cola": 836, "pan": 837, "drinking": 838, "short": 839, "floor": 840, "thanksgiving": 841, "radio": 842, "drink": 843, "on toilet": 844, "bike rack": 845, "bleachers": 846, "train tracks": 847, "horses": 848, "far": 849, "top": 850, "toilet": 851, "in water": 852, "private": 853, "nature": 854, "checkers": 855, "commercial": 856, "stroller": 857, "power": 858, "stuffed animals": 859, "uniforms": 860, "japan": 861, "liquor": 862, "faucet": 863, "green and orange": 864, "corn": 865, "sub": 866, "white and yellow": 867, "mercedes": 868, "in sky": 869, "tarp": 870, "indian": 871, "counter": 872, "multicolored": 873, "polar": 874, "go": 875, "now": 876, "no number": 877, "swimming": 878, "bridle": 879, "cowboy": 880, "union station": 881, "salt and pepper": 882, "olives": 883, "pizza cutter": 884, "british airways": 885, "nighttime": 886, "domestic": 887, "trolley": 888, "australia": 889, "tiles": 890, "pug": 891, "wicker": 892, "british": 893, "us airways express": 894, "burton": 895, "christmas tree": 896, "napkin": 897, "writing": 898, "rocks": 899, "hello kitty": 900, "lacoste": 901, "gold": 902, "fan": 903, "skateboards": 904, "day": 905, "on floor": 906, "2008": 907, "dark": 908, "flying kites": 909, "rural": 910, "olympics": 911, "bmw": 912, "34": 913, "factory": 914, "denim": 915, "typing": 916, "for fun": 917, "steel": 918, "watching 
tv": 919, "chevron": 920, "driver": 921, "baggage claim": 922, "grapes": 923, "f": 924, "angels": 925, "roof": 926, "handlebars": 927, "train station": 928, "public": 929, "oak": 930, "sleeping": 931, "canada": 932, "on runway": 933, "air canada": 934, "on top": 935, "tired": 936, "blonde": 937, "cups": 938, "little": 939, "adidas": 940, "10 feet": 941, "white and gray": 942, "leaf": 943, "fisheye": 944, "forest": 945, "war": 946, "octagon": 947, "raspberry": 948, "helmets": 949, "united states": 950, "29": 951, "noodles": 952, "van": 953, "long": 954, "traveling": 955, "luggage": 956, "airport": 957, "single": 958, "pitching": 959, "dugout": 960, "garbage": 961, "in street": 962, "happiness": 963, "cigarette": 964, "on tower": 965, "antelope": 966, "graffiti": 967, "skating": 968, "on road": 969, "curved": 970, "red light": 971, "washington": 972, "ski lift": 973, "athletics": 974, "brace": 975, "squatting": 976, "catching": 977, "batter": 978, "batting": 979, "game": 980, "towards": 981, "33": 982, "sliding": 983, "makeup": 984, "japanese": 985, "person": 986, "pirates": 987, "plaid": 988, "rose": 989, "daytime": 990, "keyboard": 991, "surfboards": 992, "hummingbird": 993, "ollie": 994, "11:30": 995, "clock tower": 996, "5:55": 997, "san francisco": 998, "stopping": 999, "tags": 1000, "samsung": 1001, "computers": 1002, "cabinets": 1003, "talking": 1004, "cage": 1005, "asparagus": 1006, "5 years": 1007, "hanger": 1008, "adult": 1009, "rabbit": 1010, "empty": 1011, "softball": 1012, "1st": 1013, "playing": 1014, "chairs": 1015, "farm": 1016, "cross country": 1017, "dump truck": 1018, "women": 1019, "snowboarder": 1020, "tall": 1021, "monkey": 1022, "mantle": 1023, "fire": 1024, "books": 1025, "quilt": 1026, "cessna": 1027, "chandelier": 1028, "dunkin donuts": 1029, "beans": 1030, "relish": 1031, "no flag": 1032, "parking meter": 1033, "spots": 1034, "ducks": 1035, "sandals": 1036, "doughnut": 1037, "lighthouse": 1038, "yacht": 1039, "german shepherd": 1040, "in middle": 1041, "raw": 1042, "chain": 1043, "2 feet": 1044, "pedestal": 1045, "sauerkraut": 1046, "bagels": 1047, "mutt": 1048, "dog and cat": 1049, "race": 1050, "poor": 1051, "cat and dog": 1052, "station": 1053, "printer": 1054, "daisies": 1055, "front": 1056, "gravel": 1057, "rear": 1058, "grassy": 1059, "pigeons": 1060, "dogs": 1061, "in car": 1062, "life": 1063, "wii remotes": 1064, "suv": 1065, "leather": 1066, "bottom right": 1067, "peace": 1068, "facebook": 1069, "blanket": 1070, "fountain": 1071, "frisbees": 1072, "12:30": 1073, "am": 1074, "scooter": 1075, "going": 1076, "analog": 1077, "america": 1078, "pitbull": 1079, "relaxing": 1080, "paddle boarding": 1081, "white and pink": 1082, "shampoo": 1083, "alps": 1084, "ride": 1085, "side": 1086, "mane": 1087, "on desk": 1088, "on chair": 1089, "2012": 1090, "multi": 1091, "straight": 1092, "big ben": 1093, "closed": 1094, "frosted": 1095, "3 feet": 1096, "waves": 1097, "buoy": 1098, "life vest": 1099, "trash can": 1100, "medium": 1101, "boxer": 1102, "very tall": 1103, "yamaha": 1104, "sunlight": 1105, "hit ball": 1106, "dry": 1107, "coke": 1108, "gym": 1109, "orange and black": 1110, "center": 1111, "rope": 1112, "flip flops": 1113, "4th of july": 1114, "siamese": 1115, "crafts": 1116, "color": 1117, "italian": 1118, "playing frisbee": 1119, "skate park": 1120, "orange juice": 1121, "windowsill": 1122, "corgi": 1123, "thumb": 1124, "peanut butter": 1125, "pie": 1126, "toast": 1127, "no hat": 1128, "benches": 1129, "diamond": 1130, "blender": 1131, "avocado": 1132, 
"television": 1133, "speakers": 1134, "pony": 1135, "baseball field": 1136, "pavement": 1137, "sydney": 1138, "not there": 1139, "diamonds": 1140, "4 feet": 1141, "goalie": 1142, "soccer ball": 1143, "runway": 1144, "video game": 1145, "gaming": 1146, "casual": 1147, "green and white": 1148, "toilet brush": 1149, "working": 1150, "pickup": 1151, "girls": 1152, "remotes": 1153, "pasta": 1154, "hood": 1155, "braves": 1156, "skier": 1157, "motorola": 1158, "17": 1159, "b": 1160, "100": 1161, "diet coke": 1162, "hospital": 1163, "wagon": 1164, "milk": 1165, "ferry": 1166, "rainbow": 1167, "on bed": 1168, "toward": 1169, "1:30": 1170, "19": 1171, "security": 1172, "herself": 1173, "mercedes benz": 1174, "supreme": 1175, "thin": 1176, "platform": 1177, "gray and red": 1178, "thai": 1179, "storage": 1180, "thailand": 1181, "swan": 1182, "peach": 1183, "10:05": 1184, "dome": 1185, "chiquita": 1186, "2:00": 1187, "mountain dew": 1188, "23": 1189, "knives": 1190, "street sign": 1191, "on beach": 1192, "playing wii": 1193, "using laptop": 1194, "stickers": 1195, "yogurt": 1196, "on grass": 1197, "9:50": 1198, "9:45": 1199, "sweat": 1200, "gatorade": 1201, "umpire": 1202, "37": 1203, "transport": 1204, "desktop": 1205, "desserts": 1206, "main": 1207, "boston": 1208, "fell": 1209, "top right": 1210, "case": 1211, "asleep": 1212, "over": 1213, "9:55": 1214, "grapefruit": 1215, "breakfast": 1216, "headphones": 1217, "freight": 1218, "cup": 1219, "sweatband": 1220, "nobody": 1221, "lamps": 1222, "9:25": 1223, "scarf": 1224, "on fridge": 1225, "main st": 1226, "moving": 1227, "confused": 1228, "fresh": 1229, "kiting": 1230, "blue jay": 1231, "flats": 1232, "long time": 1233, "chihuahua": 1234, "ceramic": 1235, "mushrooms": 1236, "on plate": 1237, "human": 1238, "power lines": 1239, "hotel": 1240, "map": 1241, "earring": 1242, "boarding": 1243, "display": 1244, "warm": 1245, "napkins": 1246, "brown and black": 1247, "broom": 1248, "basketball": 1249, "papers": 1250, "holding baby": 1251, "sad": 1252, "kickstand": 1253, "60": 1254, "shoulder": 1255, "sleep": 1256, "footprints": 1257, "tunnel": 1258, "1990": 1259, "hats": 1260, "6 inches": 1261, "ham": 1262, "bacon": 1263, "church": 1264, "53": 1265, "pineapple": 1266, "at camera": 1267, "red bull": 1268, "pilot": 1269, "tattoo": 1270, "work": 1271, "polar bear": 1272, "taking off": 1273, "website": 1274, "22": 1275, "4:00": 1276, "coffee maker": 1277, "fast": 1278, "fur": 1279, "rubber": 1280, "tongs": 1281, "german": 1282, "germany": 1283, "3 inches": 1284, "toy": 1285, "3:20": 1286, "calm": 1287, "pots": 1288, "balloons": 1289, "fruits": 1290, "9:20": 1291, "drawer": 1292, "oven": 1293, "soup": 1294, "stove": 1295, "heels": 1296, "wind": 1297, "island": 1298, "blood": 1299, "leg": 1300, "theater": 1301, "tennis racquet": 1302, "21": 1303, "gothic": 1304, "2:35": 1305, "wii remote": 1306, "turning": 1307, "20 feet": 1308, "pink and black": 1309, "ears": 1310, "fun": 1311, "wreath": 1312, "to right": 1313, "child": 1314, "fly": 1315, "head": 1316, "drywall": 1317, "shorter": 1318, "pier": 1319, "feeding giraffe": 1320, "in vase": 1321, "burger": 1322, "easter": 1323, "onion": 1324, "uniform": 1325, "remote control": 1326, "guitar": 1327, "time": 1328, "verizon": 1329, "tomatoes": 1330, "ship": 1331, "tulips": 1332, "glaze": 1333, "on suitcase": 1334, "tent": 1335, "1:45": 1336, "market": 1337, "bnsf": 1338, "bandana": 1339, "still": 1340, "don't know": 1341, "piano": 1342, "mouth": 1343, "run": 1344, "sparrow": 1345, "throw": 1346, "lines": 1347, "vest": 
1348, "1950": 1349, "jet": 1350, "sepia": 1351, "2015": 1352, "busy": 1353, "lighter": 1354, "dessert": 1355, "bending": 1356, "75": 1357, "finch": 1358, "pastries": 1359, "outdoors": 1360, "bakery": 1361, "clean": 1362, "ipod": 1363, "tablecloth": 1364, "cigarettes": 1365, "looking at phone": 1366, "in front": 1367, "food truck": 1368, "face": 1369, "swinging": 1370, "safari": 1371, "500": 1372, "volkswagen": 1373, "2010": 1374, "shape": 1375, "shelves": 1376, "riding horses": 1377, "2016": 1378, "behind bus": 1379, "towels": 1380, "lemon": 1381, "straw": 1382, "bamboo": 1383, "5 feet": 1384, "hardwood": 1385, "oregon": 1386, "schnauzer": 1387, "organic": 1388, "h": 1389, "kid": 1390, "meter": 1391, "61": 1392, "charging": 1393, "bald": 1394, "caucasian": 1395, "man on left": 1396, "stand": 1397, "27": 1398, "dining room": 1399, "sandwiches": 1400, "32": 1401, "apartment": 1402, "tower": 1403, "virgin": 1404, "out": 1405, "white and red": 1406, "2:05": 1407, "i don't know": 1408, "chains": 1409, "legs": 1410, "age": 1411, "goats": 1412, "s": 1413, "congratulations": 1414, "dresser": 1415, "camper": 1416, "half": 1417, "silverware": 1418, "decorative": 1419, "hawaiian": 1420, "petting horse": 1421, "wheel": 1422, "florida": 1423, "reds": 1424, "washington dc": 1425, "moon": 1426, "conference": 1427, "screen": 1428, "controller": 1429, "robin": 1430, "men": 1431, "protection": 1432, "roll": 1433, "harley davidson": 1434, "coal": 1435, "mustache": 1436, "smiling": 1437, "pedestrians": 1438, "88": 1439, "me": 1440, "tray": 1441, "males": 1442, "monitor": 1443, "bell": 1444, "landscape": 1445, "club": 1446, "toothpick": 1447, "seagulls": 1448, "bowtie": 1449, "lake": 1450, "steam": 1451, "surf": 1452, "baseball glove": 1453, "blinders": 1454, "woods": 1455, "stuffed": 1456, "sunbathing": 1457, "shearing": 1458, "dad": 1459, "mixer": 1460, "pot": 1461, "blending": 1462, "identification": 1463, "owl": 1464, "wine glass": 1465, "on bike": 1466, "billabong": 1467, "new york": 1468, "yarn": 1469, "tube": 1470, "tennis ball": 1471, "2:55": 1472, "ice cream": 1473, "chevrolet": 1474, "shirt and tie": 1475, "taking selfie": 1476, "blue and green": 1477, "he isn't": 1478, "cutting cake": 1479, "east": 1480, "setting": 1481, "brewers": 1482, "riding bikes": 1483, "7 eleven": 1484, "stars": 1485, "jockey": 1486, "jacket": 1487, "standing still": 1488, "book": 1489, "gray and white": 1490, "pen": 1491, "red white blue": 1492, "above": 1493, "alaska": 1494, "tongue": 1495, "feathers": 1496, "k": 1497, "camping": 1498, "pasture": 1499, "corner": 1500, "away": 1501, "ski": 1502, "texas": 1503, "fire truck": 1504, "sailboats": 1505, "jump": 1506, "walk": 1507, "spray paint": 1508, "loading": 1509, "united": 1510, "1000": 1511, "brushing his teeth": 1512, "roman numerals": 1513, "garlic": 1514, "surprise": 1515, "3rd": 1516, "first": 1517, "side of road": 1518, "dodgers": 1519, "airplanes": 1520, "unsure": 1521, "russian": 1522, "wet": 1523, "skyscraper": 1524, "5 star": 1525, "brushing her teeth": 1526, "blankets": 1527, "natural": 1528, "across street": 1529, "smartphone": 1530, "duck": 1531, "sausage": 1532, "paris": 1533, "newspaper": 1534, "pants": 1535, "spices": 1536, "pillow": 1537, "to left": 1538, "snowboards": 1539, "colgate": 1540, "on elephant": 1541, "string": 1542, "horns": 1543, "2:40": 1544, "men's": 1545, "cobblestone": 1546, "regular": 1547, "staring": 1548, "28": 1549, "barber shop": 1550, "linoleum": 1551, "grind": 1552, "cut": 1553, "x": 1554, "above sink": 1555, "above stove": 1556, 
"dishes": 1557, "dalmatian": 1558, "watching": 1559, "glazed": 1560, "5:25": 1561, "j": 1562, "messy": 1563, "wallet": 1564, "tuna": 1565, "toasted": 1566, "grilled": 1567, "french": 1568, "green and blue": 1569, "sunflowers": 1570, "to catch frisbee": 1571, "wool": 1572, "sprint": 1573, "no grass": 1574, "cabinet": 1575, "shell": 1576, "foil": 1577, "bottles": 1578, "bar": 1579, "king": 1580, "paper towels": 1581, "friends": 1582, "beagle": 1583, "school bus": 1584, "laptops": 1585, "snowing": 1586, "cement": 1587, "pc": 1588, "accident": 1589, "stuffed animal": 1590, "wakeboard": 1591, "balance": 1592, "in suitcase": 1593, "white and black": 1594, "nikon": 1595, "cleats": 1596, "on sink": 1597, "pool": 1598, "mom": 1599, "downtown": 1600, "asian": 1601, "heater": 1602, "bathing": 1603, "193": 1604, "against wall": 1605, "canopy": 1606, "jungle": 1607, "berries": 1608, "military": 1609, "pickle": 1610, "clams": 1611, "seafood": 1612, "in box": 1613, "boats": 1614, "tables": 1615, "lizard": 1616, "lemonade": 1617, "m": 1618, "soft": 1619, "illinois": 1620, "country": 1621, "for sale": 1622, "arm": 1623, "listening": 1624, "curly": 1625, "play tennis": 1626, "hands": 1627, "cereal": 1628, "blue and red": 1629, "robe": 1630, "around neck": 1631, "red and silver": 1632, "soap": 1633, "trains": 1634, "throwing frisbee": 1635, "smoking": 1636, "india": 1637, "headband": 1638, "not very": 1639, "westin": 1640, "serve": 1641, "bicycles": 1642, "can't tell": 1643, "to catch ball": 1644, "visibility": 1645, "ana": 1646, "reins": 1647, "rodeo": 1648, "boot": 1649, "on horse": 1650, "12:35": 1651, "riding motorcycle": 1652, "mexico": 1653, "mother": 1654, "african": 1655, "left and right": 1656, "button": 1657, "earrings": 1658, "blackberry": 1659, "cell": 1660, "10:00": 1661, "harness": 1662, "pillows": 1663, "vegetable": 1664, "tablet": 1665, "fern": 1666, "cats": 1667, "golden retriever": 1668, "goat": 1669, "tractor": 1670, "valentine's day": 1671, "hearts": 1672, "khaki": 1673, "man on right": 1674, "mcdonald's": 1675, "player": 1676, "arriving": 1677, "husky": 1678, "on skateboard": 1679, "vases": 1680, "coat": 1681, "beanie": 1682, "coming": 1683, "granite": 1684, "shopping cart": 1685, "it's raining": 1686, "sports": 1687, "leash": 1688, "balls": 1689, "blurry": 1690, "baseball bat": 1691, "team": 1692, "mango": 1693, "mug": 1694, "eiffel tower": 1695, "worms": 1696, "trash": 1697, "robot": 1698, "show": 1699, "terrier": 1700, "painting": 1701, "rooster": 1702, "42": 1703, "jones": 1704, "state farm": 1705, "balloon": 1706, "trunk": 1707, "coach": 1708, "t": 1709, "playing game": 1710, "fireplace": 1711, "behind clouds": 1712, "uphill": 1713, "motocross": 1714, "sony": 1715, "magazine": 1716, "kitesurfing": 1717, "catching frisbee": 1718, "catch frisbee": 1719, "bud light": 1720, "drive": 1721, "fighting": 1722, "1 on left": 1723, "very old": 1724, "hallway": 1725, "lexus": 1726, "wii controller": 1727, "9:15": 1728, "fast food": 1729, "5:45": 1730, "catholic": 1731, "muffin": 1732, "traffic light": 1733, "band": 1734, "button up": 1735, "grocery": 1736, "shelf": 1737, "2:25": 1738, "honey": 1739, "plants": 1740, "oars": 1741, "foggy": 1742, "nathan's": 1743, "cord": 1744, "yard": 1745, "48": 1746, "donut shop": 1747, "chimney": 1748, "calico": 1749, "suits": 1750, "sideways": 1751, "animals": 1752, "black and blue": 1753, "bikini": 1754, "photographer": 1755, "700": 1756, "queen": 1757, "1:00": 1758, "12:05": 1759, "horseback riding": 1760, "awake": 1761, "bunny": 1762, "12:00": 1763, 
"continental": 1764, "flamingo": 1765, "rye": 1766, "family": 1767, "lots": 1768, "owner": 1769, "stew": 1770, "palm tree": 1771, "cruise ship": 1772, "56": 1773, "design": 1774, "ny": 1775, "far right": 1776, "tire": 1777, "younger": 1778, "biking": 1779, "at&t": 1780, "giants": 1781, "marshmallows": 1782, "caramel": 1783, "polo": 1784, "emirates": 1785, "salon": 1786, "focus": 1787, "on motorcycle": 1788, "magnets": 1789, "mat": 1790, "ivy": 1791, "cakes": 1792, "chrome": 1793, "bob": 1794, "asia": 1795, "graduation": 1796, "cauliflower": 1797, "in snow": 1798, "c": 1799, "rough": 1800, "vacation": 1801, "air": 1802, "windy": 1803, "victoria": 1804, "4:45": 1805, "trick": 1806, "coconut": 1807, "labrador": 1808, "on left": 1809, "yellow and green": 1810, "butterfly": 1811, "fake": 1812, "on napkin": 1813, "bricks": 1814, "wine glasses": 1815, "detroit": 1816, "man's": 1817, "parsley": 1818, "art": 1819, "subway": 1820, "wave": 1821, "placemat": 1822, "hydrant": 1823, "sofa": 1824, "pigeon": 1825, "riding elephant": 1826, "all": 1827, "branches": 1828, "plant": 1829, "to eat": 1830, "zucchini": 1831, "feta": 1832, "neon": 1833, "mouse pad": 1834, "cloud": 1835, "toilet paper": 1836, "pumpkin": 1837, "rowing": 1838, "toronto": 1839, "handicap": 1840, "seeds": 1841, "fly kite": 1842, "chicago": 1843, "marble": 1844, "frame": 1845, "150": 1846, "rocky": 1847, "give way": 1848, "sauce": 1849, "it's not": 1850, "control": 1851, "high chair": 1852, "playstation": 1853, "xbox": 1854, "not likely": 1855, "roman": 1856, "land": 1857, "1:35": 1858, "lifeguard": 1859, "on pizza": 1860, "size": 1861, "bull": 1862, "dandelions": 1863, "equestrian": 1864, "goose": 1865, "8 feet": 1866, "recessed": 1867, "statue": 1868, "index": 1869, "phillies": 1870, "strike": 1871, "mirrors": 1872, "pointing": 1873, "farmer": 1874, "collie": 1875, "motorbike": 1876, "lanes": 1877, "bikes": 1878, "biker": 1879, "arrows": 1880, "gas station": 1881, "logs": 1882, "smaller": 1883, "desert": 1884, "yield": 1885, "flags": 1886, "stool": 1887, "kitten": 1888, "doll": 1889, "daffodils": 1890, "letters": 1891, "dishwasher": 1892, "first base": 1893, "nuts": 1894, "2013": 1895, "persian": 1896, "swim trunks": 1897, "deep": 1898, "o": 1899, "doubles": 1900, "toothpicks": 1901, "in field": 1902, "wristband": 1903, "wheels": 1904, "baking": 1905, "4:15": 1906, "11:00": 1907, "ear": 1908, "2007": 1909, "51": 1910, "chevy": 1911, "using computer": 1912, "frog": 1913, "storm": 1914, "boogie board": 1915, "hungry": 1916, "by window": 1917, "ambulance": 1918, "pigtails": 1919, "audi": 1920, "microsoft": 1921, "on man": 1922, "cannot tell": 1923, "stained glass": 1924, "hugging": 1925, "laying down": 1926, "3:00": 1927, "taxi": 1928, "pedestrian": 1929, "landing": 1930, "numbers": 1931, "38": 1932, "stones": 1933, "on tree": 1934, "clocks": 1935, "new": 1936, "picnic": 1937, "fog": 1938, "buffalo": 1939, "under armour": 1940, "cocker spaniel": 1941, "orioles": 1942, "no sign": 1943, "telling time": 1944, "bags": 1945, "golden gate": 1946, "cover": 1947, "castle": 1948, "canoe": 1949, "selfie": 1950, "cream": 1951, "floating": 1952, "indoor": 1953, "antique": 1954, "aluminum": 1955, "silver and black": 1956, "cast iron": 1957, "peas": 1958, "sun hat": 1959, "on right": 1960, "swiss": 1961, "flour": 1962, "under sink": 1963, "fashion": 1964, "fedora": 1965, "shells": 1966, "1 hour": 1967, "puppy": 1968, "in stands": 1969, "not here": 1970, "motor": 1971, "thousands": 1972, "120": 1973, "sail": 1974, "butt": 1975, "mexican": 1976, "dead 
end": 1977, "paddle": 1978, "bathing suit": 1979, "shop": 1980, "onion rings": 1981, "boxing": 1982, "birthday cake": 1983, "chalk": 1984, "scenery": 1985, "style": 1986, "nissan": 1987, "sticker": 1988, "on rack": 1989, "1 4": 1990, "woman's": 1991, "surprised": 1992, "north face": 1993, "squash": 1994, "not sure": 1995, "email": 1996, "spotted": 1997, "seat": 1998, "himself": 1999, "circles": 2000, "san diego": 2001, "kia": 2002, "mattress": 2003, "obama": 2004, "lamb": 2005, "american flag": 2006, "climbing": 2007, "skull and crossbones": 2008, "roast beef": 2009, "visor": 2010, "herd": 2011, "double": 2012, "52": 2013, "high": 2014, "stagecoach": 2015, "cart": 2016, "feeding": 2017, "eaten": 2018, "cone": 2019, "11:15": 2020, "smoothie": 2021, "golf": 2022, "colorado": 2023, "electronics": 2024, "5:15": 2025, "bowling": 2026, "players": 2027, "ketchup and mustard": 2028, "styrofoam": 2029, "6 feet": 2030, "hawk": 2031, "cheddar": 2032, "12:28": 2033, "arabic": 2034, "12:25": 2035, "12:10": 2036, "shower curtain": 2037, "army": 2038, "salmon": 2039, "10:40": 2040, "hanging": 2041, "whole": 2042, "behind fence": 2043, "bars": 2044, "moss": 2045, "no dog": 2046, "traffic": 2047, "10:25": 2048, "r": 2049, "countryside": 2050, "machine": 2051, "directions": 2052, "cooked": 2053, "aa": 2054, "6:45": 2055, "4 way": 2056, "stripe": 2057, "brand": 2058, "baseball player": 2059, "bunk": 2060, "coleslaw": 2061, "fishing boat": 2062, "at table": 2063, "europe": 2064, "dead": 2065, "arch": 2066, "scrambled": 2067, "clothing": 2068, "closet": 2069, "egg": 2070, "suitcases": 2071, "indoors": 2072, "coffee pot": 2073, "tires": 2074, "lilies": 2075, "cafe": 2076, "9:35": 2077, "teal": 2078, "toothpaste": 2079, "in background": 2080, "tarmac": 2081, "painted": 2082, "sunset": 2083, "orange and yellow": 2084, "oar": 2085, "peaches": 2086, "zebra and giraffe": 2087, "ladybug": 2088, "20 ft": 2089, "sesame seeds": 2090, "hills": 2091, "2:30": 2092, "stucco": 2093, "tail": 2094, "couple": 2095, "kawasaki": 2096, "smooth": 2097, "powdered sugar": 2098, "pedestrian crossing": 2099, "french fries": 2100, "picnic table": 2101, "teeth": 2102, "ribbon": 2103, "saddle": 2104, "15 feet": 2105, "earbuds": 2106, "on train": 2107, "39": 2108, "curb": 2109, "tow": 2110, "shark": 2111, "white and orange": 2112, "6:25": 2113, "gravy": 2114, "fork and spoon": 2115, "pooping": 2116, "curtain": 2117, "lime": 2118, "skull": 2119, "crossing": 2120, "speed limit": 2121, "peacock": 2122, "boredom": 2123, "neck": 2124, "hit": 2125, "dragon": 2126, "tissues": 2127, "basil": 2128, "waving": 2129, "blue team": 2130, "rectangles": 2131, "helicopter": 2132, "mud": 2133, "us": 2134, "balcony": 2135, "red and gray": 2136, "firefighter": 2137, "sunflower": 2138, "wallpaper": 2139, "best buy": 2140, "11:20": 2141, "public market center": 2142, "seattle": 2143, "bookshelf": 2144, "looking": 2145, "1 inch": 2146, "harley": 2147, "urinal": 2148, "cartoon": 2149, "t shirt and jeans": 2150, "navy": 2151, "fedex": 2152, "rays": 2153, "deck": 2154, "coaster": 2155, "1:20": 2156, "50 feet": 2157, "4:20": 2158, "us open": 2159, "looking at camera": 2160, "600": 2161, "national express": 2162, "white house": 2163, "5:00": 2164, "jp morgan": 2165, "palm trees": 2166, "tub": 2167, "pens": 2168, "soldiers": 2169, "2 people": 2170, "animal": 2171, "speaker": 2172, "hamburger": 2173, "spaghetti": 2174, "green beans": 2175, "it isn't": 2176, "10:20": 2177, "buildings": 2178, "on shelf": 2179, "baseball uniform": 2180, "tiled": 2181, "orange and blue": 
2182, "90": 2183, "north america": 2184, "arrow": 2185, "news": 2186, "tropicana": 2187, "formal": 2188, "in grass": 2189, "thumbs up": 2190, "clip": 2191, "gate": 2192, "tennis player": 2193, "lilac": 2194, "pastry": 2195, "nose": 2196, "pacifier": 2197, "11:35": 2198, "different teams": 2199, "cardinals": 2200, "exhaust": 2201, "hauling": 2202, "on tray": 2203, "bagel": 2204, "huge": 2205, "out of focus": 2206, "cook": 2207, "wheat": 2208, "photo": 2209, "ghost": 2210, "sedan": 2211, "qatar": 2212, "zig zag": 2213, "lanyard": 2214, "pink and white": 2215, "sesame": 2216, "space": 2217, "no clock": 2218, "warning": 2219, "snowy": 2220, "tater tots": 2221, "tropical": 2222, "grandfather": 2223, "mac": 2224, "magnet": 2225, "photoshop": 2226, "pajamas": 2227, "350": 2228, "casserole": 2229, "4:55": 2230, "pelican": 2231, "2009": 2232, "clydesdale": 2233, "tow truck": 2234, "belt": 2235, "west": 2236, "omelet": 2237, "heavy": 2238, "crown": 2239, "in corner": 2240, "hexagon": 2241, "mound": 2242, "iris": 2243, "g": 2244, "12:45": 2245, "2:15": 2246, "3:10": 2247, "drawing": 2248, "only": 2249, "little girl": 2250, "washing": 2251, "nokia": 2252, "windsor": 2253, "2 men": 2254, "parmesan cheese": 2255, "on woman": 2256, "freezer": 2257, "icing": 2258, "venice": 2259, "dairy": 2260, "several": 2261, "concentration": 2262, "3:15": 2263, "no smoking": 2264, "kayak": 2265, "frosting": 2266, "jetblue": 2267, "thoroughbred": 2268, "parakeet": 2269, "shoe": 2270, "skeleton": 2271, "britain": 2272, "ties": 2273, "in sink": 2274, "patio": 2275, "bank": 2276, "camouflage": 2277, "privacy": 2278, "bib": 2279, "blue and gray": 2280, "looking out window": 2281, "falling": 2282, "bucket": 2283, "cupcakes": 2284, "throw ball": 2285, "garden": 2286, "almonds": 2287, "ducati": 2288, "ireland": 2289, "plastic wrap": 2290, "starbucks": 2291, "all way": 2292, "bark": 2293, "home plate": 2294, "base": 2295, "dog food": 2296, "toys": 2297, "blue and orange": 2298, "1 in front": 2299, "foot": 2300, "dc": 2301, "california": 2302, "towing": 2303, "cheesecake": 2304, "bushes": 2305, "bow tie": 2306, "millions": 2307, "down street": 2308, "2011": 2309, "police officer": 2310, "windmill": 2311, "taking pictures": 2312, "street name": 2313, "cleaning": 2314, "on pole": 2315, "russia": 2316, "main street": 2317, "catch ball": 2318, "mario": 2319, "pirate": 2320, "track": 2321, "garage": 2322, "7:10": 2323, "they aren't": 2324, "mother and child": 2325, "tents": 2326, "fancy": 2327, "tattoos": 2328, "alcohol": 2329, "2:45": 2330, "wheelchair": 2331, "money": 2332, "top hat": 2333, "willow": 2334, "cd": 2335, "brushing hair": 2336, "pancake": 2337, "80": 2338, "listening to music": 2339, "green and red": 2340, "barrier": 2341, "vests": 2342, "hiking": 2343, "tank top": 2344, "lufthansa": 2345, "student": 2346, "menu": 2347, "forehand": 2348, "wii controllers": 2349, "acer": 2350, "wall st": 2351, "hundreds": 2352, "water ski": 2353, "furniture": 2354, "paisley": 2355, "pizza hut": 2356, "baseball game": 2357, "hill": 2358, "prom": 2359, "1 world": 2360, "tiara": 2361, "students": 2362, "information": 2363, "hazy": 2364, "nasa": 2365, "canon": 2366, "bird feeder": 2367, "crane": 2368, "dr pepper": 2369, "logitech": 2370, "2:10": 2371, "all of them": 2372, "utensils": 2373, "telephone": 2374, "converse": 2375, "bone": 2376, "jeep": 2377, "nursing": 2378, "krispy kreme": 2379, "cameraman": 2380, "pee": 2381, "ranch": 2382, "polka dots": 2383, "railroad crossing": 2384, "shirts": 2385, "feeder": 2386, "above toilet": 2387, 
"unclear": 2388, "below": 2389, "43": 2390, "spoons": 2391, "calendar": 2392, "vaio": 2393, "fox": 2394, "mint": 2395, "after": 2396, "spiderman": 2397, "lg": 2398, "concert": 2399, "on rock": 2400, "fluffy": 2401, "gray and black": 2402, "coats": 2403, "lady": 2404, "dodge": 2405, "easyjet": 2406, "pearl": 2407, "bunt": 2408, "flat screen": 2409, "10:30": 2410, "music": 2411, "polar bears": 2412, "riding horse": 2413, "lift": 2414, "angry": 2415, "cookies": 2416, "3:45": 2417, "buttons": 2418, "hot": 2419, "cute": 2420, "behind": 2421, "dole": 2422, "in motion": 2423, "26": 2424, "pans": 2425, "love": 2426, "winnie pooh": 2427, "pear": 2428, "copyright": 2429, "2 hours": 2430, "snowsuit": 2431, "kissing": 2432, "backhand": 2433, "to get to other side": 2434, "metro": 2435, "swans": 2436, "very fast": 2437, "can't see it": 2438, "nintendo": 2439, "direction": 2440, "waiting": 2441, "mohawk": 2442, "st patrick's day": 2443, "rail": 2444, "hoodie": 2445, "feet": 2446, "swirls": 2447, "muffins": 2448, "4:05": 2449, "106": 2450, "10:55": 2451, "coins": 2452, "mitt": 2453, "game controller": 2454, "room": 2455, "adults": 2456, "urinals": 2457, "cameras": 2458, "marker": 2459, "upright": 2460, "brass": 2461, "sled": 2462, "teacher": 2463, "conductor": 2464, "farmers market": 2465, "toiletries": 2466, "blue and black": 2467, "soccer field": 2468, "banana peel": 2469, "sprite": 2470, "doughnuts": 2471, "bank of america": 2472, "on his face": 2473, "heat": 2474, "emergency": 2475, "ski slope": 2476, "hard": 2477, "41": 2478, "6:00": 2479, "in his hand": 2480, "cluttered": 2481, "dog show": 2482, "on boat": 2483, "grizzly": 2484, "drums": 2485, "not": 2486, "in hand": 2487, "easy": 2488, "400": 2489, "under table": 2490, "d": 2491, "hitting ball": 2492, "photography": 2493, "intersection": 2494, "backwards": 2495, "crocs": 2496, "marina": 2497, "chips": 2498, "bible": 2499, "harry potter": 2500, "hawaii": 2501, "fanta": 2502, "half full": 2503, "carriage": 2504, "curious": 2505, "12:50": 2506, "black white": 2507, "geese": 2508, "pork": 2509, "mailbox": 2510, "l": 2511, "sidecar": 2512, "poop": 2513, "wings": 2514, "penguin": 2515, "to see": 2516, "pocket": 2517, "steps": 2518, "cubs": 2519, "junk": 2520, "deer": 2521, "ottoman": 2522, "salt": 2523, "condiments": 2524, "1:55": 2525, "post": 2526, "bulldog": 2527, "notebook": 2528, "no cat": 2529, "champagne": 2530, "jets": 2531, "knee pads": 2532, "throw frisbee": 2533, "drinks": 2534, "leopard": 2535, "taller": 2536, "cooler": 2537, "bundt": 2538, "monday": 2539, "grape": 2540, "wine tasting": 2541, "under": 2542, "baskets": 2543, "santa hat": 2544, "chest": 2545, "sewing": 2546, "on car": 2547, "sony ericsson": 2548, "peeing": 2549, "for photo": 2550, "tour": 2551, "few": 2552, "singapore": 2553, "fireman": 2554, "fire extinguisher": 2555, "wildebeest": 2556, "lemons": 2557, "peanuts": 2558, "babies": 2559, "wiimote": 2560, "guitar hero": 2561, "slide": 2562, "stopped": 2563, "library": 2564, "multi colored": 2565, "blue and pink": 2566, "choppy": 2567, "sailing": 2568, "brush": 2569, "grinding": 2570, "jelly": 2571, "dairy queen": 2572, "shaking hands": 2573, "ge": 2574, "tigers": 2575, "tokyo": 2576, "philadelphia": 2577, "ski boots": 2578, "buses": 2579, "11:45": 2580, "collage": 2581, "pink and blue": 2582, "jesus": 2583, "singles": 2584, "iron": 2585, "coffee table": 2586, "2 years": 2587, "don't walk": 2588, "classroom": 2589, "on water": 2590, "potato salad": 2591, "posts": 2592, "harbor": 2593, "residential": 2594, "joshua": 2595, "uk": 
2596, "burgers": 2597, "deli": 2598, "kicking": 2599, "lace": 2600, "overalls": 2601, "vehicles": 2602, "ram": 2603, "dancing": 2604, "47": 2605, "shed": 2606, "lid": 2607, "he's not": 2608, "fans": 2609, "amtrak": 2610, "space shuttle": 2611, "ostrich": 2612, "bathtub": 2613, "kneeling": 2614, "2:50": 2615, "mall": 2616, "yellow and orange": 2617, "gazebo": 2618, "wax": 2619, "slow down": 2620, "lays": 2621, "hammer time": 2622, "octopus": 2623, "crib": 2624, "banana split": 2625, "broadway": 2626, "pottery": 2627, "wavy": 2628, "farmers": 2629, "holding phone": 2630, "on phone": 2631, "squirrel": 2632, "wax paper": 2633, "tusks": 2634, "dining": 2635, "packing": 2636, "kangaroo": 2637, "dawn": 2638, "defense": 2639, "powdered": 2640, "thomas": 2641, "budweiser": 2642, "back left": 2643, "stir fry": 2644, "beijing": 2645, "11:10": 2646, "tripod": 2647, "wide": 2648, "slope": 2649, "black and gray": 2650, "planter": 2651, "chili": 2652, "siblings": 2653, "kayaking": 2654, "captivity": 2655, "opaque": 2656, "rack": 2657, "panda": 2658, "doorway": 2659, "wheelie": 2660, "pelicans": 2661, "genetics": 2662, "not in service": 2663, "volvo": 2664, "dachshund": 2665, "v": 2666, "on laptop": 2667, "western": 2668, "gone": 2669, "birthday party": 2670, "parking garage": 2671, "tying tie": 2672, "blueberry": 2673, "scale": 2674, "notes": 2675, "train car": 2676, "man made": 2677, "stability": 2678, "lily": 2679, "lying down": 2680, "pacific": 2681, "high heels": 2682, "pare": 2683, "checkerboard": 2684, "partly cloudy": 2685, "cool": 2686, "n": 2687, "toilets": 2688, "tree branch": 2689, "copper": 2690, "cycling": 2691, "5:50": 2692, "870": 2693, "shopping": 2694, "7:05": 2695, "zipper": 2696, "holding umbrella": 2697, "batman": 2698, "lotion": 2699, "1:25": 2700, "black and brown": 2701, "playing video game": 2702, "girl on right": 2703, "legos": 2704, "drinking water": 2705, "burrito": 2706, "plow": 2707, "jet ski": 2708, "spiral": 2709, "ibm": 2710, "tools": 2711, "flashlight": 2712, "cherries": 2713, "maple leaf": 2714, "mountainous": 2715, "under tree": 2716, "vines": 2717, "sushi": 2718, "baker": 2719, "snake": 2720, "globe": 2721, "target": 2722, "john": 2723, "pomeranian": 2724, "tuxedo": 2725, "hockey": 2726, "sleeve": 2727, "leaning": 2728, "wireless": 2729, "11:05": 2730, "compaq": 2731, "do not enter": 2732, "radish": 2733, "1:05": 2734, "dim": 2735, "advertisement": 2736, "movement": 2737, "model": 2738, "hammock": 2739, "swing": 2740, "sheet": 2741, "google": 2742, "boardwalk": 2743, "right 1": 2744, "haircut": 2745, "ankle": 2746, "3:30": 2747, "exit": 2748, "csx": 2749, "tim hortons": 2750, "lego": 2751, "cucumbers": 2752, "angel": 2753, "12:20": 2754, "racquet": 2755, "behind woman": 2756, "potato": 2757, "egg salad": 2758, "controllers": 2759, "recliner": 2760, "upside down": 2761, "mosaic": 2762, "before": 2763, "antenna": 2764, "3:50": 2765, "10:15": 2766, "lion": 2767, "camo": 2768, "fighter": 2769, "silver and red": 2770, "dirt bike": 2771, "playing video games": 2772, "used": 2773, "crates": 2774, "horizontally": 2775, "plunger": 2776, "refrigerators": 2777, "radiator": 2778, "stork": 2779, "in basket": 2780, "cap": 2781, "living": 2782, "married": 2783, "briefcase": 2784, "bottom left": 2785, "30 mph": 2786, "ascending": 2787, "flip phone": 2788, "101": 2789, "11:50": 2790, "gun": 2791, "arizona": 2792, "foam": 2793, "serious": 2794, "y": 2795, "close up": 2796, "pancakes": 2797, "heineken": 2798, "paw": 2799, "cnn": 2800, "comforter": 2801, "sheets": 2802, "8:35": 2803, 
"driveway": 2804, "fair": 2805, "cleaner": 2806, "1 year": 2807, "delivery": 2808, "commuter": 2809, "apple and banana": 2810, "chase": 2811, "72": 2812, "safe": 2813, "trucks": 2814, "trunks": 2815, "spider": 2816, "64": 2817, "slacks": 2818, "meeting": 2819, "7:00": 2820, "skiers": 2821, "shaved": 2822, "carrot cake": 2823, "holding": 2824, "surfers": 2825, "giraffe and zebra": 2826, "7:45": 2827, "mississippi": 2828, "seaweed": 2829, "black and pink": 2830, "horse racing": 2831, "orchid": 2832, "rv": 2833, "tourist": 2834, "above door": 2835, "leaving": 2836, "pitch": 2837, "crest": 2838, "miami": 2839, "asics": 2840, "flood": 2841, "bus station": 2842, "take off": 2843, "amazon": 2844, "practice": 2845, "entering": 2846, "diesel": 2847, "pm": 2848, "wetsuits": 2849, "remodeling": 2850, "porch": 2851, "7:35": 2852, "tie dye": 2853, "baked": 2854, "life jacket": 2855, "cylinder": 2856, "grilled cheese": 2857, "meatballs": 2858, "paddling": 2859, "banana bread": 2860, "monster": 2861, "smiley face": 2862, "not high": 2863, "keys": 2864, "dreadlocks": 2865, "kitchenaid": 2866, "straight ahead": 2867, "badminton": 2868, "long sleeve": 2869, "sheepdog": 2870, "5:18": 2871, "end": 2872, "on shore": 2873, "scratching": 2874, "oriental": 2875, "5:05": 2876, "alligator": 2877, "city bus": 2878, "purple and white": 2879, "10:50": 2880, "each other": 2881, "weeds": 2882, "tinkerbell": 2883, "rottweiler": 2884, "apartments": 2885, "snowflakes": 2886, "stop light": 2887, "sweatshirt": 2888, "shore": 2889, "bidet": 2890, "switzerland": 2891, "stretching": 2892, "tv stand": 2893, "boundaries": 2894, "65": 2895, "bronze": 2896, "jar": 2897, "middle 1": 2898, "54": 2899, "skate": 2900, "easton": 2901, "turn right": 2902, "raspberries": 2903, "singing": 2904, "on bus": 2905, "carnations": 2906, "descending": 2907, "classic": 2908, "suspenders": 2909, "not long": 2910, "8:50": 2911, "father": 2912, "anniversary": 2913, "hsbc": 2914, "very long": 2915, "space needle": 2916, "skatepark": 2917, "fruit salad": 2918, "kenmore": 2919, "no water": 2920, "8:05": 2921, "db": 2922, "baby's breath": 2923, "shelter": 2924, "1980": 2925, "no left turn": 2926, "washington monument": 2927, "ham and cheese": 2928, "10 inches": 2929, "8:55": 2930, "savory": 2931, "6:35": 2932, "indians": 2933, "9:05": 2934, "fires": 2935, "pipes": 2936, "donkey": 2937, "cds": 2938, "mitsubishi": 2939, "tell time": 2940, "outfield": 2941, "christian": 2942, "puma": 2943, "parking meters": 2944, "cranes": 2945, "flip": 2946, "wine bottle": 2947, "stadium": 2948, "mouthwash": 2949, "heinz": 2950, "distance": 2951, "macaroni": 2952, "on plane": 2953, "triumph": 2954, "more": 2955, "4:50": 2956, "single engine": 2957, "disney": 2958, "on stove": 2959, "shih tzu": 2960, "fried": 2961, "to hit ball": 2962, "in her hand": 2963, "sunrise": 2964, "2nd": 2965, "elmo": 2966, "kite string": 2967, "suzuki": 2968, "traffic lights": 2969, "blt": 2970, "i": 2971, "hitting": 2972, "htc": 2973, "healthy": 2974, "current": 2975, "star alliance": 2976, "stomach": 2977, "watch tv": 2978, "tulip": 2979, "5:10": 2980, "right side": 2981, "4:40": 2982, "ginger": 2983, "on sign": 2984, "cushion": 2985, "5:30": 2986, "learning": 2987, "pencil": 2988, "maroon": 2989, "food processor": 2990, "5:40": 2991, "dog bed": 2992, "michigan": 2993, "close": 2994, "license plate": 2995, "crows": 2996, "right hand": 2997, "normal": 2998, "green and brown": 2999, "1.00": 3000, "000": 3001, "1:40": 3002, "wing": 3003, "american airlines": 3004, "kodak": 3005, "mural": 3006, 
"sniffing": 3007, "1:15": 3008, "behind bench": 3009, "cardinal": 3010, "no light": 3011, "warmth": 3012, "paved": 3013, "skyscrapers": 3014, "swinging bat": 3015, "watermark": 3016, "in cup": 3017, "pizza box": 3018, "dough": 3019, "hiding": 3020, "goal": 3021, "no plate": 3022, "shower head": 3023, "ripe": 3024, "1:10": 3025, "1 in back": 3026, "older": 3027, "nest": 3028, "multiple": 3029, "cinnamon": 3030, "bin": 3031, "new orleans": 3032, "colored": 3033, "enclosure": 3034, "bride": 3035, "on dresser": 3036, "star wars": 3037, "in back": 3038, "triangles": 3039, "over easy": 3040, "cilantro": 3041, "statues": 3042, "sticks": 3043, "formica": 3044, "roundabout": 3045, "bowls": 3046, "ahead": 3047, "years": 3048, "drain": 3049, "veggies": 3050, "no shirt": 3051, "taking photo": 3052, "tugboat": 3053, "broke": 3054, "59": 3055, "cadillac": 3056, "prince": 3057, "left side": 3058, "1 in middle": 3059, "10:45": 3060, "drying": 3061, "11:25": 3062, "silk": 3063, "conference room": 3064, "buoys": 3065, "pockets": 3066, "daffodil": 3067, "6:40": 3068, "walgreens": 3069, "4 ft": 3070, "6:05": 3071, "virgin atlantic": 3072, "12:40": 3073, "digital": 3074, "ups": 3075, "westjet": 3076, "bikers": 3077, "us air force": 3078, "limes": 3079, "comcast": 3080, "dip": 3081, "7:55": 3082, "man in middle": 3083, "bus driver": 3084, "soon": 3085, "futon": 3086, "selling": 3087, "braid": 3088, "mariners": 3089, "wisconsin": 3090, "99": 3091, "citizen": 3092, "broccoli and carrots": 3093, "grocery store": 3094, "us airways": 3095, "49": 3096, "bored": 3097, "red velvet": 3098, "hotel room": 3099, "qantas": 3100, "tam": 3101, "korean air": 3102, "10:35": 3103, "whirlpool": 3104, "coffee cup": 3105, "hilly": 3106, "9:12": 3107, "whipped cream": 3108, "video": 3109, "finger": 3110, "competition": 3111, "hollywood": 3112, "sas": 3113, "backward": 3114, "beads": 3115, "cosmo": 3116, "10:08": 3117, "jal": 3118, "6:30": 3119, "100 year party ct": 3120, "hispanic": 3121, "in cabbage town": 3122, "opponent": 3123, "woodpecker": 3124, "visilab": 3125, "mt airy": 3126, "crosstown": 3127, "freightliner": 3128}
\ No newline at end of file
diff --git a/utils/const.py b/utils/const.py
new file mode 100644
index 0000000..f82e205
--- /dev/null
+++ b/utils/const.py
@@ -0,0 +1,9 @@
+"""
+Copyright (c) Microsoft Corporation.
+Licensed under the MIT license.
+
+constants
+"""
+IMG_DIM = 2048
+IMG_LABEL_DIM = 1601
+BUCKET_SIZE = 8192
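+
+# Annotation (not in the original file): IMG_DIM matches the 2048-d BUTD
+# region features mentioned in the README; IMG_LABEL_DIM appears to be the
+# detector label vocabulary (1600 object classes plus background); BUCKET_SIZE
+# is presumably the length-bucketing size used by the batching/sampling code.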
diff --git a/utils/distributed.py b/utils/distributed.py
new file mode 100644
index 0000000..fce69ea
--- /dev/null
+++ b/utils/distributed.py
@@ -0,0 +1,209 @@
+"""
+Copyright (c) Microsoft Corporation.
+Licensed under the MIT license.
+
+distributed API using Horovod
+Modified from OpenNMT's native pytorch distributed utils
+(https://github.com/OpenNMT/OpenNMT-py)
+"""
+import math
+import pickle
+
+import torch
+from horovod import torch as hvd
+
+
+def all_reduce_and_rescale_tensors(tensors, rescale_denom):
+ """All-reduce and rescale tensors at once (as a flattened tensor)
+
+ Args:
+ tensors: list of Tensors to all-reduce
+ rescale_denom: denominator for rescaling summed Tensors
+ """
+ # allocate a flat buffer large enough to hold all tensors
+ sz = sum(t.numel() for t in tensors)
+ buffer_t = tensors[0].new(sz).zero_()
+
+ # copy tensors into buffer_t
+ offset = 0
+ for t in tensors:
+ numel = t.numel()
+ buffer_t[offset:offset+numel].copy_(t.view(-1))
+ offset += numel
+
+ # all-reduce and rescale
+ hvd.allreduce_(buffer_t[:offset])
+ buffer_t.div_(rescale_denom)
+
+ # copy all-reduced buffer back into tensors
+ offset = 0
+ for t in tensors:
+ numel = t.numel()
+ t.view(-1).copy_(buffer_t[offset:offset+numel])
+ offset += numel
+
+
+def all_reduce_and_rescale_tensors_chunked(tensors, rescale_denom,
+ buffer_size=10485760):
+ """All-reduce and rescale tensors in chunks of the specified size.
+
+ Args:
+ tensors: list of Tensors to all-reduce
+ rescale_denom: denominator for rescaling summed Tensors
+ buffer_size: all-reduce chunk size in bytes
+ """
+ # buffer size in bytes, determine equiv. # of elements based on data type
+ buffer_t = tensors[0].new(
+ math.ceil(buffer_size / tensors[0].element_size())).zero_()
+ buffer = []
+
+ def all_reduce_buffer():
+ # copy tensors into buffer_t
+ offset = 0
+ for t in buffer:
+ numel = t.numel()
+ buffer_t[offset:offset+numel].copy_(t.view(-1))
+ offset += numel
+
+ # all-reduce and rescale
+ hvd.allreduce_(buffer_t[:offset])
+ buffer_t.div_(rescale_denom)
+
+ # copy all-reduced buffer back into tensors
+ offset = 0
+ for t in buffer:
+ numel = t.numel()
+ t.view(-1).copy_(buffer_t[offset:offset+numel])
+ offset += numel
+
+ filled = 0
+ for t in tensors:
+ sz = t.numel() * t.element_size()
+ if sz > buffer_size:
+ # tensor is bigger than buffer, all-reduce and rescale directly
+ hvd.allreduce_(t)
+ t.div_(rescale_denom)
+ elif filled + sz > buffer_size:
+ # buffer is full, all-reduce and replace buffer with tensor
+ all_reduce_buffer()
+ buffer = [t]
+ filled = sz
+ else:
+ # add tensor to buffer
+ buffer.append(t)
+ filled += sz
+
+ if len(buffer) > 0:
+ all_reduce_buffer()
+
+
+def broadcast_tensors(tensors, root_rank, buffer_size=10485760):
+ """broadcast tensors in chunks of the specified size.
+
+ Args:
+ tensors: list of Tensors to broadcast
+ root_rank: rank to broadcast from
+ buffer_size: broadcast chunk size in bytes
+ """
+ # buffer size in bytes, determine equiv. # of elements based on data type
+ buffer_t = tensors[0].new(
+ math.ceil(buffer_size / tensors[0].element_size())).zero_()
+ buffer = []
+
+ def broadcast_buffer():
+ # copy tensors into buffer_t
+ offset = 0
+ for t in buffer:
+ numel = t.numel()
+ buffer_t[offset:offset+numel].copy_(t.view(-1))
+ offset += numel
+
+ # broadcast
+ hvd.broadcast_(buffer_t[:offset], root_rank)
+
+ # copy broadcast buffer back into tensors
+ offset = 0
+ for t in buffer:
+ numel = t.numel()
+ t.view(-1).copy_(buffer_t[offset:offset+numel])
+ offset += numel
+
+ filled = 0
+ for t in tensors:
+ sz = t.numel() * t.element_size()
+ if sz > buffer_size:
+ # tensor is bigger than buffer, broadcast directly
+ hvd.broadcast_(t, root_rank)
+ elif filled + sz > buffer_size:
+ # buffer is full, broadcast and replace buffer with tensor
+ broadcast_buffer()
+ buffer = [t]
+ filled = sz
+ else:
+ # add tensor to buffer
+ buffer.append(t)
+ filled += sz
+
+ if len(buffer) > 0:
+ broadcast_buffer()
+
+
+def _encode(enc, max_size, use_max_size=False):
+ enc_size = len(enc)
+ enc_byte = max(math.floor(math.log(max_size, 256)+1), 1)
+ if use_max_size:
+ # this is used for broadcasting
+ buffer_ = torch.cuda.ByteTensor(max_size+enc_byte)
+ else:
+ buffer_ = torch.cuda.ByteTensor(enc_size+enc_byte)
+ remainder = enc_size
+ for i in range(enc_byte):
+ base = 256 ** (enc_byte-i-1)
+ buffer_[i] = remainder // base
+ remainder %= base
+ buffer_[enc_byte:enc_byte+enc_size] = torch.ByteTensor(list(enc))
+ return buffer_, enc_byte
+
+
+def _decode(buffer_, enc_byte):
+ size = sum(256 ** (enc_byte-i-1) * buffer_[i].item()
+ for i in range(enc_byte))
+ bytes_list = bytes(buffer_[enc_byte:enc_byte+size].tolist())
+ shift = size + enc_byte
+ return bytes_list, shift
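+
+# Worked example (annotation, not in the original file): with max_size=70000,
+# enc_byte = floor(log_256(70000)) + 1 = 3, so a 5-byte payload is laid out as
+# [0, 0, 5, b0, b1, b2, b3, b4]; _decode reads the 3 header bytes back as
+# 0*256**2 + 0*256 + 5 = 5 and returns the 5-byte payload plus the shift of 8.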
+
+
+_BUFFER_SIZE = 4096
+
+
+def all_gather_list(data):
+ """Gathers arbitrary data from all nodes into a list."""
+ enc = pickle.dumps(data)
+
+ enc_size = len(enc)
+ max_size = hvd.allgather(torch.tensor([enc_size]).cuda()).max().item()
+ in_buffer, enc_byte = _encode(enc, max_size)
+
+ out_buffer = hvd.allgather(in_buffer[:enc_byte+enc_size])
+
+ results = []
+ for _ in range(hvd.size()):
+ bytes_list, shift = _decode(out_buffer, enc_byte)
+ out_buffer = out_buffer[shift:]
+ result = pickle.loads(bytes_list)
+ results.append(result)
+ return results
+
+
+def any_broadcast(data, root_rank):
+ """broadcast arbitrary data from root_rank to all nodes."""
+ enc = pickle.dumps(data)
+
+ max_size = hvd.allgather(torch.tensor([len(enc)]).cuda()).max().item()
+ buffer_, enc_byte = _encode(enc, max_size, use_max_size=True)
+
+ hvd.broadcast_(buffer_, root_rank)
+
+ bytes_list, _ = _decode(buffer_, enc_byte)
+ result = pickle.loads(bytes_list)
+ return result
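+
+
+# Hedged usage sketch (annotation, not part of the original module); assumes a
+# Horovod run where hvd.init() and torch.cuda.set_device(hvd.local_rank())
+# have already been called and `model` is an arbitrary torch.nn.Module:
+#
+#   # sync parameters from rank 0 before training starts
+#   broadcast_tensors([p.data for p in model.parameters()], root_rank=0)
+#
+#   # after loss.backward(), average gradients across workers
+#   grads = [p.grad.data for p in model.parameters() if p.grad is not None]
+#   all_reduce_and_rescale_tensors_chunked(grads, rescale_denom=hvd.size())
+#
+#   # gather arbitrary python objects (e.g. per-rank metrics) on every worker
+#   all_metrics = all_gather_list({'rank': hvd.rank(), 'n_ex': 128})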
diff --git a/utils/itm_eval.py b/utils/itm_eval.py
new file mode 100644
index 0000000..0c76a8e
--- /dev/null
+++ b/utils/itm_eval.py
@@ -0,0 +1,136 @@
+"""
+Copyright (c) Microsoft Corporation.
+Licensed under the MIT license.
+
+Image Text Retrieval evaluation helper
+"""
+from time import time
+import numpy as np
+import torch
+from horovod import torch as hvd
+from tqdm import tqdm
+
+from .logger import LOGGER
+from .misc import NoOp
+from .distributed import all_gather_list
+
+
+@torch.no_grad()
+def itm_eval(score_matrix, txt_ids, img_ids, txt2img, img2txts, id2len=None):
+ # image retrieval
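+ # score_matrix has one row per text and one column per image; topk along
+ # dim=1 gives, for each text, the column indices of its 10 best images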
+ img2j = {i: j for j, i in enumerate(img_ids)}
+ _, rank_txt = score_matrix.topk(10, dim=1)
+
+ if id2len is not None:
+ # report the mean stored length of the top-5 ranked entries per query
+ top5_rank_text_ids = rank_txt[:, :5]
+ dataset_text_lens = []
+ for text_ids in top5_rank_text_ids.data.cpu().numpy():
+ text_lens = []
+ for text_id in text_ids:
+ text_id = int(text_id)
+ text_lens.append(id2len[txt_ids[text_id]])
+ dataset_text_lens.append(text_lens)
+ dataset_text_lens = np.array(dataset_text_lens)
+ LOGGER.info('mean of top5 texts: %s', dataset_text_lens.mean())
+
+ gt_img_j = torch.LongTensor([img2j[txt2img[txt_id]]
+ for txt_id in txt_ids],
+ ).to(rank_txt.device
+ ).unsqueeze(1).expand_as(rank_txt)
+ rank = (rank_txt == gt_img_j).nonzero()
+ if rank.numel():
+ ir_r1 = (rank < 1).sum().item() / len(txt_ids)
+ ir_r5 = (rank < 5).sum().item() / len(txt_ids)
+ ir_r10 = (rank < 10).sum().item() / len(txt_ids)
+ else:
+ ir_r1, ir_r5, ir_r10 = 0, 0, 0
+
+ # text retrieval
+ txt2i = {t: i for i, t in enumerate(txt_ids)}
+ _, rank_img = score_matrix.topk(10, dim=0)
+ tr_r1, tr_r5, tr_r10 = 0, 0, 0
+ for j, img_id in enumerate(img_ids):
+ gt_is = [txt2i[t] for t in img2txts[img_id]]
+ ranks = [(rank_img[:, j] == i).nonzero() for i in gt_is]
+ rank = min([10] + [r.item() for r in ranks if r.numel()])
+ if rank < 1:
+ tr_r1 += 1
+ if rank < 5:
+ tr_r5 += 1
+ if rank < 10:
+ tr_r10 += 1
+ tr_r1 /= len(img_ids)
+ tr_r5 /= len(img_ids)
+ tr_r10 /= len(img_ids)
+
+ tr_mean = (tr_r1 + tr_r5 + tr_r10) / 3
+ ir_mean = (ir_r1 + ir_r5 + ir_r10) / 3
+ r_mean = (tr_mean + ir_mean) / 2
+
+ eval_log = {'txt_r1': tr_r1,
+ 'txt_r5': tr_r5,
+ 'txt_r10': tr_r10,
+ 'txt_r_mean': tr_mean,
+ 'img_r1': ir_r1,
+ 'img_r5': ir_r5,
+ 'img_r10': ir_r10,
+ 'img_r_mean': ir_mean,
+ 'r_mean': r_mean}
+ return eval_log
+
+
+@torch.no_grad()
+def evaluate(model, eval_loader):
+ st = time()
+ LOGGER.info("start running Image/Text Retrieval evaluation ...")
+ score_matrix = inference(model, eval_loader)
+ dset = eval_loader.dataset
+ all_score = hvd.allgather(score_matrix)
+ all_txt_ids = [i for ids in all_gather_list(dset.ids)
+ for i in ids]
+ all_img_ids = dset.all_img_ids
+ assert all_score.size() == (len(all_txt_ids), len(all_img_ids))
+ if hvd.rank() != 0:
+ return {}
+
+ # NOTE: only use rank0 to compute final scores
+ eval_log = itm_eval(all_score, all_txt_ids, all_img_ids,
+ dset.txt2img, dset.img2txts)
+
+ tot_time = time()-st
+ LOGGER.info(f"evaluation finished in {int(tot_time)} seconds")
+ return eval_log
+
+
+@torch.no_grad()
+def inference(model, eval_loader):
+ model.eval()
+ if hvd.rank() == 0:
+ pbar = tqdm(total=len(eval_loader))
+ else:
+ pbar = NoOp()
+ score_matrix = torch.zeros(len(eval_loader.dataset),
+ len(eval_loader.dataset.all_img_ids),
+ device=torch.device("cuda"),
+ dtype=torch.float16)
+
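+ # each item from eval_loader is a list of mini-batches whose scores together
+ # fill one full row (all images) of score_matrix, as the assert below checks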
+ for i, mini_batches in enumerate(eval_loader):
+ j = 0
+ for batch in mini_batches:
+ model_outputs = model(batch, compute_loss=False)
+ if isinstance(model_outputs, dict):
+ scores = model_outputs['rank_scores']
+ else:
+ scores = model_outputs
+ bs = scores.size(0)
+ score_matrix.data[i, j:j+bs] = scores.data.squeeze(1).half()
+ j += bs
+ assert j == score_matrix.size(1)
+ pbar.update(1)
+ model.train()
+ pbar.close()
+ return score_matrix
diff --git a/utils/logger.py b/utils/logger.py
new file mode 100644
index 0000000..c0ddb4f
--- /dev/null
+++ b/utils/logger.py
@@ -0,0 +1,108 @@
+"""
+Copyright (c) Microsoft Corporation.
+Licensed under the MIT license.
+
+helper for logging
+NOTE: loggers are global objects; use with caution
+"""
+import logging
+import math
+
+import tensorboardX
+
+
+_LOG_FMT = '%(asctime)s - %(levelname)s - %(name)s - %(message)s'
+_DATE_FMT = '%m/%d/%Y %H:%M:%S'
+logging.basicConfig(format=_LOG_FMT, datefmt=_DATE_FMT, level=logging.INFO)
+LOGGER = logging.getLogger('__main__') # this is the global logger
+
+
+def add_log_to_file(log_path):
+ fh = logging.FileHandler(log_path)
+ formatter = logging.Formatter(_LOG_FMT, datefmt=_DATE_FMT)
+ fh.setFormatter(formatter)
+ LOGGER.addHandler(fh)
+
+
+class TensorboardLogger(object):
+ def __init__(self):
+ self._logger = None
+ self._global_step = 0
+
+ def create(self, path):
+ self._logger = tensorboardX.SummaryWriter(path)
+
+ def noop(self, *args, **kwargs):
+ return
+
+ def step(self):
+ self._global_step += 1
+
+ @property
+ def global_step(self):
+ return self._global_step
+
+ def log_scaler_dict(self, log_dict, prefix=''):
+ """ log a dictionary of scalar values"""
+ if self._logger is None:
+ return
+ if prefix:
+ prefix = f'{prefix}_'
+ for name, value in log_dict.items():
+ if isinstance(value, dict):
+ self.log_scaler_dict(value, prefix=f'{prefix}{name}')
+ else:
+ self._logger.add_scalar(f'{prefix}{name}', value,
+ self._global_step)
+
+ def log_histogram_dict(self, log_dict, prefix=''):
+ """ log a dictionary of scalar values"""
+ if self._logger is None:
+ return
+ if prefix:
+ prefix = f'{prefix}_'
+ for name, value in log_dict.items():
+ if isinstance(value, dict):
+ self.log_histogram_dict(value, prefix=f'{prefix}{name}')
+ else:
+ self._logger.add_histogram(f'{prefix}{name}', value,
+ self._global_step)
+
+ def __getattr__(self, name):
+ if self._logger is None:
+ return self.noop
+ return self._logger.__getattribute__(name)
+
+
+TB_LOGGER = TensorboardLogger()
+
+
+class RunningMeter(object):
+ """ running meteor of a scalar value
+ (useful for monitoring training loss)
+ """
+ def __init__(self, name, val=None, smooth=0.99):
+ self._name = name
+ self._sm = smooth
+ self._val = val
+
+ def __call__(self, value):
+ val = (value if self._val is None
+ else value*(1-self._sm) + self._val*self._sm)
+ if not math.isnan(val):
+ self._val = val
+
+ def __str__(self):
+ return f'{self._name}: {self._val:.4f}'
+
+ @property
+ def val(self):
+ if self._val is None:
+ return 0
+ return self._val
+
+ @property
+ def name(self):
+ return self._name
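+
+
+# Hedged usage sketch (annotation, not part of the original module):
+#   TB_LOGGER.create(f'{output_dir}/log')     # typically on rank 0 only
+#   loss_meter = RunningMeter('loss')
+#   loss_meter(loss.item())                   # EMA: 0.01*new + 0.99*old
+#   TB_LOGGER.log_scaler_dict({'loss': loss_meter.val})
+#   TB_LOGGER.step()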
diff --git a/utils/misc.py b/utils/misc.py
new file mode 100644
index 0000000..961322a
--- /dev/null
+++ b/utils/misc.py
@@ -0,0 +1,70 @@
+"""
+Copyright (c) Microsoft Corporation.
+Licensed under the MIT license.
+
+Misc utilities
+"""
+import json
+import random
+import sys
+
+import torch
+import numpy as np
+
+from utils.logger import LOGGER
+
+
+class NoOp(object):
+ """ useful for distributed training No-Ops """
+ def __getattr__(self, name):
+ return self.noop
+
+ def noop(self, *args, **kwargs):
+ return
+
+
+def parse_with_config(parser):
+ args = parser.parse_args()
+ if args.config is not None:
+ config_args = json.load(open(args.config))
+ override_keys = {arg[2:].split('=')[0] for arg in sys.argv[1:]
+ if arg.startswith('--')}
+ for k, v in config_args.items():
+ if k not in override_keys:
+ setattr(args, k, v)
+ del args.config
+ return args
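+
+# Example (annotation, assuming the parser defines --lr): with cfg.json =
+# {"lr": 1e-4, "batch_size": 32} and a command line of
+# `--config cfg.json --lr 5e-5`, the returned args carry lr=5e-5 (explicit CLI
+# flags win) and batch_size=32 (filled from the JSON config).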
+
+
+VE_ENT2IDX = {
+ 'contradiction': 0,
+ 'entailment': 1,
+ 'neutral': 2
+}
+
+VE_IDX2ENT = {
+ 0: 'contradiction',
+ 1: 'entailment',
+ 2: 'neutral'
+}
+
+
+class Struct(object):
+ def __init__(self, dict_):
+ self.__dict__.update(dict_)
+
+
+def set_dropout(model, drop_p):
+ for name, module in model.named_modules():
+ # we might want to tune dropout for smaller dataset
+ if isinstance(module, torch.nn.Dropout):
+ if module.p != drop_p:
+ module.p = drop_p
+ LOGGER.info(f'{name} set to {drop_p}')
+
+
+def set_random_seed(seed):
+ random.seed(seed)
+ np.random.seed(seed)
+ torch.manual_seed(seed)
+ torch.cuda.manual_seed_all(seed)
diff --git a/utils/save.py b/utils/save.py
new file mode 100644
index 0000000..8d4383b
--- /dev/null
+++ b/utils/save.py
@@ -0,0 +1,77 @@
+"""
+Copyright (c) Microsoft Corporation.
+Licensed under the MIT license.
+
+saving utilities
+"""
+import json
+import os
+from os.path import abspath, dirname, exists, join
+import subprocess
+
+import torch
+
+from utils.logger import LOGGER
+
+
+def save_training_meta(args):
+ if args.rank > 0:
+ return
+
+ if not exists(join(args.output_dir, 'log')):
+ os.makedirs(join(args.output_dir, 'log'))
+ if not exists(join(args.output_dir, 'ckpt')):
+ os.makedirs(join(args.output_dir, 'ckpt'))
+
+ with open(join(args.output_dir, 'log', 'hps.json'), 'w') as writer:
+ json.dump(vars(args), writer, indent=4)
+ model_config = json.load(open(args.model_config))
+ with open(join(args.output_dir, 'log', 'model.json'), 'w') as writer:
+ json.dump(model_config, writer, indent=4)
+ # git info
+ try:
+ LOGGER.info("Waiting on git info....")
+ c = subprocess.run(["git", "rev-parse", "--abbrev-ref", "HEAD"],
+ timeout=10, stdout=subprocess.PIPE)
+ git_branch_name = c.stdout.decode().strip()
+ LOGGER.info("Git branch: %s", git_branch_name)
+ c = subprocess.run(["git", "rev-parse", "HEAD"],
+ timeout=10, stdout=subprocess.PIPE)
+ git_sha = c.stdout.decode().strip()
+ LOGGER.info("Git SHA: %s", git_sha)
+ git_dir = abspath(dirname(__file__))
+ git_status = subprocess.check_output(
+ ['git', 'status', '--short'],
+ cwd=git_dir, universal_newlines=True).strip()
+ with open(join(args.output_dir, 'log', 'git_info.json'),
+ 'w') as writer:
+ json.dump({'branch': git_branch_name,
+ 'is_dirty': bool(git_status),
+ 'status': git_status,
+ 'sha': git_sha},
+ writer, indent=4)
+ except subprocess.TimeoutExpired as e:
+ LOGGER.exception(e)
+ LOGGER.warn("Git info not found. Moving right along...")
+
+
+class ModelSaver(object):
+ def __init__(self, output_dir, prefix='model_step', suffix='pt'):
+ self.output_dir = output_dir
+ self.prefix = prefix
+ self.suffix = suffix
+
+ def save(self, model, step, optimizer=None):
+ output_model_file = join(self.output_dir,
+ f"{self.prefix}_{step}.{self.suffix}")
+ state_dict = {k: v.cpu() if isinstance(v, torch.Tensor) else v
+ for k, v in model.state_dict().items()}
+ torch.save(state_dict, output_model_file)
+ if optimizer is not None:
+ dump = {'step': step, 'optimizer': optimizer.state_dict()}
+ if hasattr(optimizer, '_amp_stash'):
+ pass # TODO fp16 optimizer
+ torch.save(dump, f'{self.output_dir}/train_state_{step}.pt')
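+
+
+# Hedged usage sketch (annotation, not part of the original module):
+#   saver = ModelSaver(join(args.output_dir, 'ckpt'))
+#   saver.save(model, global_step, optimizer)
+# writes ckpt/model_step_{global_step}.pt (CPU state_dict) and
+# ckpt/train_state_{global_step}.pt (step plus optimizer state).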