diff --git a/.args.py.swn b/.args.py.swn new file mode 100644 index 0000000..7ed3bf7 Binary files /dev/null and b/.args.py.swn differ diff --git a/.args.py.swo b/.args.py.swo new file mode 100644 index 0000000..ae5f46e Binary files /dev/null and b/.args.py.swo differ diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..a2294f6 --- /dev/null +++ b/.gitignore @@ -0,0 +1,51 @@ +# Byte-compiled / optimized / DLL files +__pycache__ +**/__pycache__ +*.py[cod] +*$py.class +.idea +*.swp +# C extensions +*.so +*.pyc +*._ +*.png +__pycache__ +/venv +/.idea +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# Datasets, pretrained models, checkpoints and preprocessed files +data/ +!visdialch/data/ +checkpoints +checkpoints/* +logs/ +results/ +log/ +launcher.sh +data +# IPython Notebook +.ipynb_checkpoints +data +# virtualenv +venv/ +.vscode +.swp diff --git a/args.py b/args.py new file mode 100644 index 0000000..0291066 --- /dev/null +++ b/args.py @@ -0,0 +1,91 @@ +import argparse + + +def get_args(parser, description='MILNCE'): + if parser is None: + parser = argparse.ArgumentParser(description=description) + + parser.add_argument_group('Input modalites arguments') + + parser.add_argument('-input_type', default='Q_DH_V', + choices=['Q_only', 'Q_DH', 'Q_A', 'Q_I', 'Q_V', 'Q_C_I', 'Q_DH_V', 'Q_DH_I', 'Q_V_A', 'Q_DH_V_A'], help='Specify the inputs') + + parser.add_argument_group('Encoder Decoder choice arguments') + parser.add_argument('-encoder', default='lf-ques-im-hist', + choices=['lf-ques-im-hist'], help='Encoder to use for training') + parser.add_argument('-concat_history', default=True, + help='True for lf encoding') + parser.add_argument('-decoder', default='disc', + choices=['disc'], help='Decoder to use for training') + parser.add_argument('-finetune_textEncoder', default=0, type=int, + help= 'Finetune the text encoder') + parser.add_argument_group('Optimization related arguments') + parser.add_argument('-num_epochs', default=45, type=int, help='Epochs') + parser.add_argument('-batch_size', default=12, type=int, help='Batch size') + parser.add_argument('-lr', default=1e-4, type=float, help='Learning rate') + parser.add_argument('-lr_decay_rate', default=0.9, + type=float, help='Decay for lr') + parser.add_argument('-min_lr', default=5e-5, type=float, + help='Minimum learning rate') + parser.add_argument('-weight_init', default='xavier', + choices=['xavier', 'kaiming'], help='Weight initialization strategy') + parser.add_argument('-weight_decay', default=5e-4, + help='Weight decay for l2 regularization') + parser.add_argument('-overfit', action='store_true', + help='Overfit on 5 examples, meant for debugging') + parser.add_argument('-gpuid', default=0, type=int, help='GPU id to use') + + parser.add_argument_group('Checkpointing related arguments') + parser.add_argument('-load_path', default='', + help='Checkpoint to load path from') + parser.add_argument('-save_path', default='checkpoints/', + help='Path to save checkpoints') + parser.add_argument('-save_step', default=4, type=int, + help='Save checkpoint after every save_step epochs') + parser.add_argument('-eval_step', default=100, type=int, + help='Run validation after every eval_step iterations') + parser.add_argument('-input_vid', default="data/charades_s3d_mixed_5c_fps_16_num_frames_40_original_scaled", + help=".h5 file path for the charades s3d features.") + 
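For reference, a minimal sketch of how this argument module is meant to be driven; it mirrors the parser wiring used in evaluate.py later in this patch, and the exact call site is an assumption, not code from the patch:

    import argparse

    from args import get_args
    from dataloader import VisDialDataset
    from encoders import LateFusionEncoder

    # Build one parser, let the dataloader and encoder register their own
    # argument groups, then let get_args add the modality / optimization /
    # checkpointing flags defined above and parse everything at once.
    parser = argparse.ArgumentParser()
    VisDialDataset.add_cmdline_args(parser)
    LateFusionEncoder.add_cmdline_args(parser)
    args = get_args(parser)
    print(args.input_type, args.batch_size, args.input_vid)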
parser.add_argument('-finetune', default=0, type=int, + help="When set true, the model finetunes the s3dg model for video") + + # S3DG parameters and dataloader + parser.add_argument('-num_frames', type=int, default=40, + help='num_frame') + parser.add_argument('-video_size', type=int, default=224, + help='random seed') + parser.add_argument('-fps', type=int, default=16, help='') + parser.add_argument('-crop_only', type=int, default=1, + help='random seed') + parser.add_argument('-center_crop', type=int, default=0, + help='random seed') + parser.add_argument('-random_flip', type=int, default=0, + help='random seed') + parser.add_argument('-video_root', default='data/videos') + parser.add_argument('-unfreeze_layers', default=1, type=int, + help="if 1, unfreezes _5 layers, if 2 unfreezes _4 and _5 layers, if 0, unfreezes all layers") + parser.add_argument("-text_encoder", default="lstm", + help="lstm or transformer", type=str) + parser.add_argument("-use_npy", default=1, type=int, + help="Uses npy instead of reading from videos") + parser.add_argument("-numpy_path", default="data/charades") + parser.add_argument("-num_workers", default=8, type=int) + + parser.add_argument_group('Visualzing related arguments') + parser.add_argument('-enableVis', type=int, default=1) + parser.add_argument('-visEnvName', type=str, default='s3d_Nofinetune') + parser.add_argument('-server', type=str, default='127.0.0.1') + parser.add_argument('-serverPort', type=int, default=8855) + parser.add_argument('-set_cuda_device', type=str, default='') + parser.add_argument("-seed", type=int, default=1, + help="random seed for initialization") + + parser.add_argument('-save_ranks', action='store_true', help='Whether to save retrieved ranks') + parser.add_argument('-use_gt', action='store_true', help='Whether to use ground truth for retriveing ranks') + parser.add_argument('--split', default='test', choices=['val', 'test', 'train'], help='Split to evaluate on') + # ---------------------------------------------------------------------------- + # input arguments and options + # ---------------------------------------------------------------------------- + + args = parser.parse_args() + return args diff --git a/create_npy.py b/create_npy.py new file mode 100644 index 0000000..3249198 --- /dev/null +++ b/create_npy.py @@ -0,0 +1,208 @@ +import argparse +import os +import random + +import cv2 +import ffmpeg +import h5py +import numpy as np +import pandas as pd +import torch +import torch as th +import torch.nn.functional as F +from torch.utils.data import Dataset +from torchvision import io, transforms +from tqdm import tqdm + +random.seed(42) +np.random.seed(42) + + +class Transform(object): + + def __init__(self): + self.mean = torch.tensor([0.485, 0.456, 0.406], dtype=torch.float32) + self.std = torch.tensor([0.229, 0.224, 0.225], dtype=torch.float32) + + def __call__(self, add_jitter=False, crop_size=224): + transform = transforms.Compose([ + self.random_crop(crop_size), + ]) + return transform + + def to_tensor(self): + return transforms.ToTensor() + + def random_crop(self, size): + return transforms.RandomCrop(size, pad_if_needed=True) + + def colorJitter(self): + return transforms.ColorJitter(0.4, 0.2, 0.2, 0.1) + + +class CustomDataset(Dataset): + + def __init__(self, args, path): + """Initialize the dataset with splits given by 'subsets', where + subsets is taken from ['train', 'val', 'test'] + """ + super(CustomDataset, self).__init__() + self.args = args + self.path = path + self.fl_list = self.get_filenames( + 
os.path.join(args.video_root, path)) + self.transform = Transform() + + def __len__(self): + return len(self.fl_list) + + def _get_opencv_video(self, video_path): + cap = cv2.VideoCapture(video_path) + cap.set(cv2.CAP_PROP_FPS, 30) + ret, frame = cap.read() + frames = [frame] + while ret: + ret, frame = cap.read() + if frame is not None: + frames.append(frame) + cap.release() + frames_array = np.concatenate(np.expand_dims(frames, 0)) + return frames_array + + def get_filenames(self, path): + results = [] + results += [each for each in os.listdir(path) if each.endswith('.mp4')] + return results + + def _get_video_torch(self, video_path): + vframes, _, vmeta = io.read_video(video_path) + vframes = vframes.permute(0, 3, 1, 2) + vframes = self.transform(self.args.video_size)(vframes) + if vframes.shape[0] < self.args.num_frames: + zeros = th.zeros( + (3, self.args.num_frames - video.shape[0], self.args.video_size, self.args.video_size), dtype=th.uint8) + vframes = th.cat((vframes, zeros), axis=0) + # Gets n_frames from tne entire video, linearly spaced + vid_indices = np.linspace( + 0, vframes.shape[0] - 1, self.args.num_frames, dtype=int) + vid = vframes[vid_indices, :].permute(1, 0, 2, 3) + for i in range(3): + for j in range(vid.shape[1]): + if vid[i, j, :, :].sum() == 0: + print(i, j) + return vid + + def _get_video(self, video_path, start=0, end=0): + ''' + :param video_path: Path of the video file + start: Start time for the video + end: End time. + :return: video: video_frames. + ''' + # start_seek = random.randint(start, int(max(start, end - self.num_sec))) + start_seek = 0 + cmd = ( + ffmpeg + .input(video_path) + .filter('fps', fps=self.args.fps) + ) + if self.args.center_crop: + aw, ah = 0.5, 0.5 + else: + aw, ah = random.uniform(0, 1), random.uniform(0, 1) + if self.args.crop_only: + ''' + Changes from the original code, because we have few videos that have <224 resolution and needs to be scaled up after cropping, and cropping needs to take care of the size of the image which it did not before. 
+ cmd = (cmd.crop('(iw - {})*{}'.format(self.args.video_size, aw), + '(ih - {})*{}'.format(self.args.video_size, ah), + str(self.args.video_size), str(self.args.video_size)) + )''' + cmd = ( + cmd.crop('max(0, (iw-{}))*{}'.format(self.args.video_size, aw), + 'max(0, (ih-{}))*{}'.format(self.args.video_size, ah), + 'min(iw, {})'.format(self.args.video_size), + 'min(ih, {})'.format(self.args.video_size)) + .filter('scale', self.args.video_size, self.args.video_size) + ) + else: + cmd = ( + cmd.crop('(iw - max(0, min(iw,ih)))*{}'.format(aw), + '(ih - max(0, min(iw,ih)))*{}'.format(ah), + 'min(iw,ih)', + 'min(iw,ih)') + .filter('scale', self.args.video_size, self.args.video_size) + ) + if self.args.random_flip and random.uniform(0, 1) > 0.5: + cmd = cmd.hflip() + out, _ = ( + cmd.output('pipe:', format='rawvideo', pix_fmt='rgb24') + .run(capture_stdout=True, quiet=True) + ) + + video = np.frombuffer(out, np.uint8).reshape( + [-1, self.args.video_size, self.args.video_size, 3]) + video = th.from_numpy(video) + video = video.permute(3, 0, 1, 2) + if video.shape[1] < self.args.num_frames: + zeros = th.zeros( + (3, self.args.num_frames - video.shape[1], self.args.video_size, self.args.video_size), dtype=th.uint8) + video = th.cat((video, zeros), axis=1) + # Gets n_frames from tne entire video, linearly spaced + vid_indices = np.linspace( + 0, video.shape[1]-1, self.args.num_frames, dtype=int) + return video[:, vid_indices] + + def __getitem__(self, idx): + video_file = self.fl_list[idx] + write_file = os.path.join( + self.args.write_path, video_file.replace(".mp4", ".npy")) + video_path = os.path.join( + self.args.video_root, self.path, video_file) + vid = self._get_video_torch(video_path) + np.save(write_file, vid) + return video_file + + +def main(args): + dataloader = torch.utils.data.DataLoader( + CustomDataset(args, args.train_val_path), + batch_size=1, + shuffle=False, drop_last=True) + + dataloader_val = torch.utils.data.DataLoader( + CustomDataset(args, args.test_path), + batch_size=1, + shuffle=False, drop_last=True) + + if args.train: + for i, batch in tqdm(enumerate(dataloader)): + print("train ", batch) + if args.test: + for i, batch in tqdm(enumerate(dataloader_val)): + print("val ", batch) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument('--num_frames', type=int, default=40, + help='num_frame') + parser.add_argument('--video_root', default='./data/charades/videos') + + parser.add_argument('--write_path', default="./data/charades") + parser.add_argument('--video_size', type=int, default=224, + help='random seed') + parser.add_argument('--fps', type=int, default=16, help='') + parser.add_argument('--crop_only', type=int, default=1, + help='random seed') + parser.add_argument('--center_crop', type=int, default=0, + help='random seed') + parser.add_argument('--random_flip', type=int, default=0, + help='random seed') + parser.add_argument('--train', default=1) + parser.add_argument('--test', default=1) + args = parser.parse_args() + args.train_val_path = "train_val" + args.test_path = "test" + args.write_path += "/num_frames_{}".format(args.num_frames) + os.makedirs(args.write_path, exist_ok=True) + main(args) diff --git a/dataloader.py b/dataloader.py index 9e0e61e..0554f78 100644 --- a/dataloader.py +++ b/dataloader.py @@ -1,15 +1,43 @@ -import os import json -from six import iteritems +import os +import pdb +import random from random import shuffle +from transformers import BertTokenizer +import ffmpeg import h5py +import hdfdict import numpy 
as np -from tqdm import tqdm - import torch +import torch as th import torch.nn.functional as F +from six import iteritems from torch.utils.data import Dataset +from torchvision import io, transforms +from tqdm import tqdm + + +class Transform(object): + + def __init__(self): + self.mean = torch.tensor([0.485, 0.456, 0.406], dtype=torch.float32) + self.std = torch.tensor([0.229, 0.224, 0.225], dtype=torch.float32) + + def __call__(self, add_jitter=False, crop_size=224): + transform = transforms.Compose([ + self.random_crop(crop_size) + ]) + return transform + + def to_tensor(self): + return transforms.ToTensor() + + def random_crop(self, size): + return transforms.RandomCrop(size, pad_if_needed=True) + + def colorJitter(self): + return transforms.ColorJitter(0.4, 0.2, 0.2, 0.1) class VisDialDataset(Dataset): @@ -17,12 +45,18 @@ class VisDialDataset(Dataset): @staticmethod def add_cmdline_args(parser): parser.add_argument_group('Dataloader specific arguments') - parser.add_argument('-input_img', default='data/data_img.h5', help='HDF5 file with image features') - parser.add_argument('-input_vid', default='data/data_video.h5', help='HDF5 file with image features') - parser.add_argument('-input_audio', default='data/data_audio.h5', help='HDF5 file with audio features') - parser.add_argument('-input_ques', default='data/dialogs.h5', help='HDF5 file with preprocessed questions') - parser.add_argument('-input_json', default='data/params.json', help='JSON file with image paths and vocab') - parser.add_argument('-img_norm', default=1, choices=[1, 0], help='normalize the image feature. 1=yes, 0=no') + parser.add_argument( + '-input_img', default='data/data_img.h5', help='HDF5 file with image features') + # parser.add_argument( + # '-input_vid', default='data/data_video.h5', help='HDF5 file with video features') + parser.add_argument( + '-input_audio', default='data/data_audio.h5', help='HDF5 file with audio features') + parser.add_argument('-input_ques', default='data/dialogs.h5', + help='HDF5 file with preprocessed questions') + parser.add_argument('-input_json', default='data/params.json', + help='JSON file with image paths and vocab') + parser.add_argument( + '-img_norm', default=1, choices=[1, 0], help='normalize the image feature. 
1=yes, 0=no') return parser def __init__(self, args, subsets): @@ -32,6 +66,7 @@ def __init__(self, args, subsets): super(VisDialDataset, self).__init__() self.args = args self.subsets = tuple(subsets) + self.transform = Transform() print("Dataloader loading json file: {}".format(args.input_json)) with open(args.input_json, 'r') as info_file: @@ -60,16 +95,15 @@ def __init__(self, args, subsets): print("Dataloader loading h5 file: {}".format(args.input_ques)) ques_file = h5py.File(args.input_ques, 'r') - - if 'image' in args.input_type: + if 'I' in args.input_type: print("Dataloader loading h5 file: {}".format(args.input_img)) img_file = h5py.File(args.input_img, 'r') - if 'video' in args.input_type: + if 'V' in args.input_type: print("Dataloader loading h5 file: {}".format(args.input_vid)) - vid_file = h5py.File(args.input_vid, 'r') + vid_file = args.input_vid - if 'audio' in args.input_type: + if 'A' in args.input_type: print("Dataloader loading h5 file: {}".format(args.input_audio)) audio_file = h5py.File(args.input_audio, 'r') @@ -102,18 +136,26 @@ def __init__(self, args, subsets): self.data[save_label.format(dtype)] = torch.from_numpy( np.array(ques_file[load_label.format(dtype)], dtype='int64')) - if 'video' in args.input_type: + if 'V' in args.input_type: print("Reading video features...") - vid_feats = torch.from_numpy(np.array(vid_file['images_' + dtype])) + + # Charades dataset features are all saved in one h5 file as a key, feat dictionary + # vid_feats = hdfdict.load( + # args.input_vid + "_{0}.h5".format(dtype)) + # If this throws an error because it cannot find the video filename,uncomment below + vid_feats = hdfdict.load( + args.input_vid + "_{0}.h5".format("train")) + vid_feats.update(hdfdict.load( + args.input_vid + "_{0}.h5".format("test"))) img_fnames = getattr(self, 'unique_img_' + dtype) self.data[dtype + '_img_fnames'] = img_fnames self.data[dtype + '_vid_fv'] = vid_feats - - if 'image' in args.input_type: + if 'I' in args.input_type: print("Reading image features...") - img_feats = torch.from_numpy(np.array(img_file['images_' + dtype])) + img_feats = torch.from_numpy( + np.array(img_file['images_' + dtype])) if args.img_norm: print("Normalizing image features...") @@ -123,9 +165,10 @@ def __init__(self, args, subsets): self.data[dtype + '_img_fnames'] = img_fnames self.data[dtype + '_img_fv'] = img_feats - if 'audio' in args.input_type: + if 'A' in args.input_type: print("Reading audio features...") - audio_feats = torch.from_numpy(np.array(audio_file['images_' + dtype])) + audio_feats = torch.from_numpy( + np.array(audio_file['images_' + dtype])) audio_feats = F.normalize(audio_feats, dim=1, p=2) self.data[dtype + '_audio_fv'] = audio_feats @@ -139,20 +182,18 @@ def __init__(self, args, subsets): self.max_ans_len = self.data[dtype + '_ans'].size(2) # reduce amount of data for preprocessing in fast mode - #TODO - if args.overfit: - print('\n \n \n ---------->> NOT IMPLEMENTED OVERFIT CASE <-----\n \n \n ') - + # TODO self.num_data_points = {} for dtype in subsets: self.num_data_points[dtype] = len(self.data[dtype + '_ques']) - print("[{0}] no. of threads: {1}".format(dtype, self.num_data_points[dtype])) + print("[{0}] no. of threads: {1}".format( + dtype, self.num_data_points[dtype])) print("\tMax no. 
of rounds: {}".format(self.max_ques_count)) print("\tMax ques len: {}".format(self.max_ques_len)) print("\tMax ans len: {}".format(self.max_ans_len)) # prepare history - if 'dialog' in args.input_type or 'caption' in args.input_type: + if 'DH' in args.input_type or 'C' in args.input_type: for dtype in subsets: self._process_history(dtype) @@ -168,6 +209,10 @@ def __init__(self, args, subsets): else: self._split = subsets[0] + if args.overfit: + self.num_data_points['train'] = 5 + self.num_data_points['val'] = 5 + @property def split(self): return self._split @@ -184,27 +229,121 @@ def split(self, split): def __len__(self): return self.num_data_points[self._split] + def _get_video_torch(self, video_path): + vframes, _, vmeta = io.read_video(video_path) + vframes = vframes.permute(0, 3, 1, 2) + vframes = self.transform(self.args.video_size)(vframes) + if vframes.shape[0] < self.args.num_frames: + zeros = th.zeros( + (3, self.args.num_frames - video.shape[0], self.args.video_size, self.args.video_size), dtype=th.uint8) + vframes = th.cat((vframes, zeros), axis=0) + # Gets n_frames from tne entire video, linearly spaced + vid_indices = np.linspace( + 0, vframes.shape[0] - 1, self.args.num_frames, dtype=int) + vid = vframes[vid_indices, :].permute(1, 0, 2, 3) + return vid + + def _get_video(self, video_path, start=0, end=0): + ''' + :param video_path: Path of the video file + start: Start time for the video + end: End time. + :return: video: video_frames. + ''' + # start_seek = random.randint(start, int(max(start, end - self.num_sec))) + start_seek = 0 + cmd = ( + ffmpeg + .input(video_path) + .filter('fps', fps=self.args.fps) + ) + if self.args.center_crop: + aw, ah = 0.5, 0.5 + else: + aw, ah = random.uniform(0, 1), random.uniform(0, 1) + if self.args.crop_only: + ''' + Changes from the original code, because we have few videos that have <224 resolution and needs to be scaled up after cropping, and cropping needs to take care of the size of the image which it did not before. 
+ cmd = (cmd.crop('(iw - {})*{}'.format(self.args.video_size, aw), + '(ih - {})*{}'.format(self.args.video_size, ah), + str(self.args.video_size), str(self.args.video_size)) + )''' + cmd = ( + cmd.crop('max(0, (iw - {}))*{}'.format(self.args.video_size, aw), + 'max(0, (ih - {}))*{}'.format(self.args.video_size, ah), + 'min(iw, {})'.format(self.args.video_size), + 'min(ih, {})'.format(self.args.video_size)) + .filter('scale', self.args.video_size, self.args.video_size) + ) + else: + cmd = ( + cmd.crop('(iw - max(0, min(iw,ih)))*{}'.format(aw), + '(ih - max(0, min(iw,ih)))*{}'.format(ah), + 'min(iw,ih)', + 'min(iw,ih)') + .filter('scale', self.args.video_size, self.args.video_size) + ) + if self.args.random_flip and random.uniform(0, 1) > 0.5: + cmd = cmd.hflip() + out, _ = ( + cmd.output('pipe:', format='rawvideo', pix_fmt='rgb24') + .run(capture_stdout=True, quiet=True) + ) + + video = np.frombuffer(out, np.uint8).reshape( + [-1, self.args.video_size, self.args.video_size, 3]) + video = th.from_numpy(video) + video = video.permute(3, 0, 1, 2) + if video.shape[1] < self.args.num_frames: + zeros = th.zeros( + (3, self.args.num_frames - video.shape[1], self.args.video_size, self.args.video_size), dtype=th.uint8) + video = th.cat((video, zeros), axis=1) + # Gets n_frames from tne entire video, linearly spaced + vid_indices = np.linspace( + 0, video.shape[1]-1, self.args.num_frames, dtype=int) + return video[:, vid_indices] + def __getitem__(self, idx): + dtype = self._split item = {'index': idx} item['num_rounds'] = self.data[dtype + '_num_rounds'][idx] # get video features - if 'video' in self.args.input_type: - item['vid_feat'] = self.data[dtype + '_vid_fv'][idx] + if 'V' in self.args.input_type: item['img_fnames'] = self.data[dtype + '_img_fnames'][idx] + # item['img_fnames'] is as train_val/vid_id.jpg hence the splits + vid_id = item['img_fnames'].split("/")[-1].split(".")[0] + if ".mp4" not in vid_id: + vid_id = vid_id + ".mp4" + + if self.args.finetune: + f_dtype = "train_val" + if dtype == "test": + f_dtype = "test" + if self.args.use_npy: + video_path = os.path.join(self.args.numpy_path, vid_id) + item['vid_feat'] = torch.from_numpy(np.load( + video_path.replace(".mp4", ".npy"))) + else: + video_path = os.path.join( + self.args.video_root, f_dtype, vid_id) + item['vid_feat'] = self._get_video()(video_path) + else: + item['vid_feat'] = torch.from_numpy( + self.data[dtype + '_vid_fv'][vid_id]).reshape(-1) # get image features - if 'image' in self.args.input_type: + if 'I' in self.args.input_type: item['img_feat'] = self.data[dtype + '_img_fv'][idx] - item['img_fnames'] = self.data[dtype + '_img_fnames'][idx] + item['img_fnames'] = [self.data[dtype + '_img_fnames'][idx]] # get audio features - if 'audio' in self.args.input_type: + if 'A' in self.args.input_type: item['audio_feat'] = self.data[dtype + '_audio_fv'][idx] # get history tokens - if 'dialog' in self.args.input_type or 'caption' in self.args.input_type: + if 'DH' in self.args.input_type or 'caption' in self.args.input_type: item['hist_len'] = self.data[dtype + '_hist_len'][idx] item['hist_len'][item['hist_len'] == 0] += 1 item['hist'] = self.data[dtype + '_hist'][idx] @@ -226,20 +365,20 @@ def __getitem__(self, idx): item['opt'] = option_in item['opt_len'] = opt_len - #if dtype != 'test': + # if dtype != 'test': ans_ind = self.data[dtype + '_ans_ind'][idx] item['ans_ind'] = ans_ind.view(-1) # convert zero length sequences to one length # this is for handling empty rounds of v1.0 test, they will be dropped anyway - #if dtype == 
'test': + # if dtype == 'test': item['ques_len'][item['ques_len'] == 0] += 1 item['opt_len'][item['opt_len'] == 0] += 1 return item - #------------------------------------------------------------------------- + # ------------------------------------------------------------------------- # collate function utilized by dataloader for batching - #------------------------------------------------------------------------- + # ------------------------------------------------------------------------- def collate_fn(self, batch): dtype = self._split @@ -254,15 +393,18 @@ def collate_fn(self, batch): out[key] = torch.stack(merged_batch[key], 0) # Dynamic shaping of padded batch if 'hist' in out: - out['hist'] = out['hist'][:, :, :torch.max(out['hist_len'])].contiguous() - out['ques'] = out['ques'][:, :, :torch.max(out['ques_len'])].contiguous() - out['opt'] = out['opt'][:, :, :, :torch.max(out['opt_len'])].contiguous() + out['hist'] = out['hist'][:, :, :torch.max( + out['hist_len'])].contiguous() + out['ques'] = out['ques'][:, :, :torch.max( + out['ques_len'])].contiguous() + out['opt'] = out['opt'][:, :, :, :torch.max( + out['opt_len'])].contiguous() return out - #------------------------------------------------------------------------- + # ------------------------------------------------------------------------- # preprocessing functions - #------------------------------------------------------------------------- + # ------------------------------------------------------------------------- def _process_history(self, dtype): """Process caption as well as history. Optionally, concatenate history @@ -278,13 +420,16 @@ def _process_history(self, dtype): num_convs, num_rounds, max_ans_len = answers.size() if self.args.concat_history: - self.max_hist_len = min(num_rounds * (max_ques_len + max_ans_len), 400) - history = torch.zeros(num_convs, num_rounds, self.max_hist_len).long() + self.max_hist_len = min( + num_rounds * (max_ques_len + max_ans_len), 400) + history = torch.zeros(num_convs, num_rounds, + self.max_hist_len).long() else: - history = torch.zeros(num_convs, num_rounds, max_ques_len + max_ans_len).long() + history = torch.zeros(num_convs, num_rounds, + max_ques_len + max_ans_len).long() hist_len = torch.zeros(num_convs, num_rounds).long() - if 'dialog' in self.args.input_type: + if 'DH' in self.args.input_type: # go over each question and append it with answer for th_id in range(num_convs): clen = cap_len[th_id] @@ -319,13 +464,14 @@ def _process_history(self, dtype): hlen = alen + qlen # save the history length hist_len[th_id][round_id] = hlen - else: # -- caption only + else: # -- caption only # go over each question and append it with answer for th_id in range(num_convs): clen = cap_len[th_id] hlen = min(clen, max_ques_len + max_ans_len) for round_id in range(num_rounds): - history[th_id][round_id][:max_ques_len + max_ans_len] = captions[th_id][:max_ques_len + max_ans_len] + history[th_id][round_id][:max_ques_len + + max_ans_len] = captions[th_id][:max_ques_len + max_ans_len] hist_len[th_id][round_id] = hlen self.data[dtype + '_hist'] = history diff --git a/decoders/__init__.pyc b/decoders/__init__.pyc deleted file mode 100644 index a164c31..0000000 Binary files a/decoders/__init__.pyc and /dev/null differ diff --git a/decoders/__pycache__/__init__.cpython-35.pyc b/decoders/__pycache__/__init__.cpython-35.pyc deleted file mode 100644 index faff7c2..0000000 Binary files a/decoders/__pycache__/__init__.cpython-35.pyc and /dev/null differ diff --git 
a/decoders/__pycache__/__init__.cpython-36.pyc b/decoders/__pycache__/__init__.cpython-36.pyc deleted file mode 100644 index c38bdf3..0000000 Binary files a/decoders/__pycache__/__init__.cpython-36.pyc and /dev/null differ diff --git a/decoders/__pycache__/disc.cpython-35.pyc b/decoders/__pycache__/disc.cpython-35.pyc deleted file mode 100644 index b212b0c..0000000 Binary files a/decoders/__pycache__/disc.cpython-35.pyc and /dev/null differ diff --git a/decoders/__pycache__/disc.cpython-36.pyc b/decoders/__pycache__/disc.cpython-36.pyc deleted file mode 100644 index 9004661..0000000 Binary files a/decoders/__pycache__/disc.cpython-36.pyc and /dev/null differ diff --git a/decoders/__pycache__/disc_realdialogs.cpython-36.pyc b/decoders/__pycache__/disc_realdialogs.cpython-36.pyc deleted file mode 100644 index 52b7415..0000000 Binary files a/decoders/__pycache__/disc_realdialogs.cpython-36.pyc and /dev/null differ diff --git a/decoders/disc.py b/decoders/disc.py index 3fadae3..1f9da43 100644 --- a/decoders/disc.py +++ b/decoders/disc.py @@ -10,7 +10,8 @@ def __init__(self, args, encoder): self.args = args # share word embedding self.word_embed = encoder.word_embed - self.option_rnn = nn.LSTM(args.embed_size, args.rnn_hidden_size, batch_first=True) + self.option_rnn = nn.LSTM( + args.embed_size, args.rnn_hidden_size, batch_first=True) self.log_softmax = nn.LogSoftmax(dim=1) # options are variable length padded sequences, use DynamicRNN @@ -30,22 +31,35 @@ def forward(self, enc_out, batch): options = batch['opt'] options_len = batch['opt_len'] # word embed options - options = options.view(options.size(0) * options.size(1), options.size(2), -1) - options_len = options_len.view(options_len.size(0) * options_len.size(1), -1) - batch_size, num_options, max_opt_len = options.size() - options = options.contiguous().view(-1, num_options * max_opt_len) - options = self.word_embed(options) - options = options.view(batch_size, num_options, max_opt_len, -1) + if self.args.text_encoder == 'BERT': + batch_size, rounds, num_options, num_words = options.size() + options_embeds = torch.zeros([batch_size * rounds, num_options, num_words, self.args.embed_size], + dtype=torch.float) + options = options.view(batch_size*rounds, num_options, -1) + for i in range(batch_size*rounds): + options_embeds[i,:] = self.word_embed(options[i])['last_hidden_state'] + #options_embeds[i, :] = opt_embed + options_embeds = options_embeds.view(batch_size * rounds, num_options, num_words, -1) + + else: + options = options.view(options.size(0) * options.size(1), options.size(2), -1) + batch_size, num_options, max_opt_len = options.size() + options = options.contiguous().view(-1, num_options * max_opt_len) + options_embeds = self.word_embed(options) + options_embeds = options_embeds.view(batch_size, num_options, max_opt_len, -1) + + options_len = options_len.view(options_len.size(0) * options_len.size(1), -1) # score each option scores = [] for opt_id in range(num_options): - opt = options[:, opt_id, :, :] + opt = options_embeds[:, opt_id, :, :] opt_len = options_len[:, opt_id] - opt_embed = self.option_rnn(opt, opt_len) + device = opt_len.device + opt_embed = self.option_rnn(opt.to(device), opt_len) scores.append(torch.sum(opt_embed * enc_out, 1)) scores = torch.stack(scores, 1) return scores #log_probs = self.log_softmax(scores) - #return log_probs + # return log_probs diff --git a/decoders/disc.pyc b/decoders/disc.pyc deleted file mode 100644 index 8c992fa..0000000 Binary files a/decoders/disc.pyc and /dev/null differ diff --git 
a/encoders/__init__.py b/encoders/__init__.py index 6dccef4..4c92256 100644 --- a/encoders/__init__.py +++ b/encoders/__init__.py @@ -1,3 +1,4 @@ +from .s3dg_video import S3D from .lf import LateFusionEncoder @@ -6,4 +7,3 @@ def Encoder(model_args): 'lf-ques-im-hist': LateFusionEncoder } return name_enc_map[model_args.encoder](model_args) - diff --git a/encoders/__init__.pyc b/encoders/__init__.pyc deleted file mode 100644 index 1f9c758..0000000 Binary files a/encoders/__init__.pyc and /dev/null differ diff --git a/encoders/__pycache__/__init__.cpython-36.pyc b/encoders/__pycache__/__init__.cpython-36.pyc deleted file mode 100644 index 0661356..0000000 Binary files a/encoders/__pycache__/__init__.cpython-36.pyc and /dev/null differ diff --git a/encoders/__pycache__/lf.cpython-36.pyc b/encoders/__pycache__/lf.cpython-36.pyc deleted file mode 100644 index 7cc826a..0000000 Binary files a/encoders/__pycache__/lf.cpython-36.pyc and /dev/null differ diff --git a/encoders/lf.py b/encoders/lf.py index bd6af23..76eb4c0 100644 --- a/encoders/lf.py +++ b/encoders/lf.py @@ -1,22 +1,30 @@ import torch from torch import nn from torch.nn import functional as F - from utils import DynamicRNN - +from utils import DynamicRNN +from encoders.s3dg_video import S3D +from transformers import BertTokenizer, BertModel class LateFusionEncoder(nn.Module): @staticmethod def add_cmdline_args(parser): parser.add_argument_group('Encoder specific arguments') - parser.add_argument('-img_feature_size', default=4096, help='Channel size of image feature') - parser.add_argument('-vid_feature_size', default=4096, help='Channel size of video feature') - parser.add_argument('-audio_feature_size', default=4096, help='Channel size of audio feature') - parser.add_argument('-embed_size', default=300, help='Size of the input word embedding') - parser.add_argument('-rnn_hidden_size', default=512, help='Size of the multimodal embedding') - parser.add_argument('-num_layers', default=2, help='Number of layers in LSTM') - parser.add_argument('-max_history_len', default=60, help='Size of the multimodal embedding') + parser.add_argument('-img_feature_size', default=1024, + help='Channel size of image feature') + parser.add_argument('-vid_feature_size', default=1024, + help='Channel size of video feature') + parser.add_argument('-audio_feature_size', default=1024, + help='Channel size of audio feature') + parser.add_argument('-embed_size', default=300, + help='Size of the input word embedding') + parser.add_argument('-rnn_hidden_size', default=512, + help='Size of the multimodal embedding') + parser.add_argument('-num_layers', default=2, + help='Number of layers in LSTM') + parser.add_argument('-max_history_len', default=60, + help='Size of the multimodal embedding') parser.add_argument('-dropout', default=0.5, help='Dropout') return parser @@ -24,102 +32,160 @@ def __init__(self, args): super().__init__() self.args = args - self.word_embed = nn.Embedding(args.vocab_size, args.embed_size, padding_idx=0) - - if 'dialog' in args.input_type or 'caption' in args.input_type: - self.hist_rnn = nn.LSTM(args.embed_size, args.rnn_hidden_size, args.num_layers, - batch_first=True, dropout=args.dropout) + if args.text_encoder == 'lstm': + args.embed_size = 300 + self.word_embed = nn.Embedding( + args.vocab_size, args.embed_size, padding_idx=0) + else: + freeze_embeddings = True + args.embed_size = 768 + self.word_embed = BertModel.from_pretrained('bert-base-uncased') + # Freeze all the layers and use bert to encode the text for now + if not 
self.args.finetune_textEncoder: + print('Freezing all bert layers') + for param in self.word_embed.parameters(): + param.requires_grad = False + else: + print('Finetuning text encoder layers') + for param in list(self.word_embed.embeddings.parameters()): + param.requires_grad = False + print("Froze Embedding Layer") + freeze_layers = "0,1,2,3,4,5,6,7,8,9" + layer_indexes = [int(x) for x in freeze_layers.split(",")] + for layer_idx in layer_indexes: + for param in list(self.word_embed.encoder.layer[layer_idx].parameters()): + param.requires_grad = False + print("Froze Layer: ", layer_idx) + for name, param in self.word_embed.named_parameters(): + print(name, param.requires_grad) + if self.args.finetune: + self.video_embed = S3D( + dict_path='data/s3d_dict.npy', space_to_depth=True) + self.video_embed.load_state_dict( + torch.load('data/s3d_howto100m.pth'), strict=False) + + self.video_embed.train() + if self.args.unfreeze_layers: + self.__freeze_s3dg_layers() + + if 'DH' in args.input_type or 'C' in args.input_type: + self.hist_rnn = nn.LSTM(args.embed_size, args.rnn_hidden_size, + args.num_layers, batch_first=True, dropout=args.dropout) self.hist_rnn = DynamicRNN(self.hist_rnn) - - self.ques_rnn = nn.LSTM(args.embed_size, args.rnn_hidden_size, args.num_layers, - batch_first=True, dropout=args.dropout) + + self.ques_rnn = nn.LSTM(args.embed_size, args.rnn_hidden_size, + args.num_layers, batch_first=True, + dropout=args.dropout) # questions and history are right padded sequences of variable length # use the DynamicRNN utility module to handle them properly self.ques_rnn = DynamicRNN(self.ques_rnn) self.dropout = nn.Dropout(p=args.dropout) # fusion layer - if args.input_type == 'question_only': + if args.input_type == 'Q_only': fusion_size = args.rnn_hidden_size - if args.input_type == 'question_dialog': - fusion_size =args.rnn_hidden_size * 2 - if args.input_type == 'question_audio': - fusion_size =args.rnn_hidden_size + args.audio_feature_size - if args.input_type == 'question_image' or args.input_type=='question_video': - fusion_size = args.img_feature_size + args.rnn_hidden_size - if args.input_type == 'question_caption_image' or args.input_type=='question_dialog_video' or args.input_type=='question_dialog_image': + if args.input_type == 'Q_DH': + fusion_size = args.rnn_hidden_size * 2 + if args.input_type == 'Q_A': + fusion_size = args.rnn_hidden_size + args.audio_feature_size + if args.input_type == 'Q_I' or args.input_type == 'Q_V': + fusion_size = args.img_feature_size + args.rnn_hidden_size + if args.input_type == 'Q_C_I' or args.input_type == 'Q_DH_V' or args.input_type == 'Q_DH_I': fusion_size = args.img_feature_size + args.rnn_hidden_size * 2 - if args.input_type == 'question_video_audio': - fusion_size = args.img_feature_size + args.rnn_hidden_size + args.audio_feature_size - if args.input_type == 'question_dialog_video_audio': - fusion_size = args.img_feature_size + args.rnn_hidden_size * 2 + args.audio_feature_size - + if args.input_type == 'Q_V_A': + fusion_size = args.img_feature_size + \ + args.rnn_hidden_size + args.audio_feature_size + if args.input_type == 'Q_DH_V_A': + fusion_size = args.img_feature_size + \ + args.rnn_hidden_size * 2 + args.audio_feature_size + self.fusion = nn.Linear(fusion_size, args.rnn_hidden_size) if args.weight_init == 'xavier': - nn.init.xavier_uniform(self.fusion.weight.data) + nn.init.xavier_uniform_(self.fusion.weight.data) elif args.weight_init == 'kaiming': - nn.init.kaiming_uniform(self.fusion.weight.data) - 
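As a quick sanity check on the fusion sizes selected above, using the new defaults in this patch (img_feature_size = vid_feature_size = audio_feature_size = 1024, matching the 1024-d S3D mixed_5c feature, and rnn_hidden_size = 512), the input width of the fusion layer works out as follows; this is an illustrative calculation, not code from the patch:

    rnn_hidden_size = 512
    feature_size = 1024   # image/video/audio feature channels (S3D mixed_5c is 1024-d)

    fusion_sizes = {
        'Q_only':   rnn_hidden_size,                                      # 512
        'Q_DH':     rnn_hidden_size * 2,                                  # 1024
        'Q_V':      feature_size + rnn_hidden_size,                       # 1536
        'Q_DH_V':   feature_size + rnn_hidden_size * 2,                   # 2048
        'Q_DH_V_A': feature_size + rnn_hidden_size * 2 + feature_size,    # 3072
    }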
nn.init.constant(self.fusion.bias.data, 0) + nn.init.kaiming_uniform_(self.fusion.weight.data) + nn.init.constant_(self.fusion.bias.data, 0) + + def __freeze_s3dg_layers(self): + # Only train _4 and _5 layers + layers = ["mixed_5c"] + if self.args.unfreeze_layers == 2: + layers = ["mixed_5b", "mixed_5c"] + for name, param in self.video_embed.named_parameters(): + param.requires_grad = False + if any(l in name for l in layers): + param.requires_grad = True def forward(self, batch): - if 'image' in self.args.input_type: + if 'I' in self.args.input_type: img = batch['img_feat'] # repeat image feature vectors to be provided for every round img = img.view(-1, 1, self.args.img_feature_size) img = img.repeat(1, self.args.max_ques_count, 1) img = img.view(-1, self.args.img_feature_size) - - if 'audio' in self.args.input_type: + + if 'A' in self.args.input_type: audio = batch['audio_feat'] # repeat audio feature vectors to be provided for every round audio = audio.view(-1, 1, self.args.audio_feature_size) audio = audio.repeat(1, self.args.max_ques_count, 1) audio = audio.view(-1, self.args.audio_feature_size) - if 'video' in self.args.input_type: - vid = batch['vid_feat'] + if 'V' in self.args.input_type: + if self.args.finetune: + # In this case, vid_feat has video frames.Multiplication by 255 because s3d video frames are normalised + vid = self.video_embed(batch['vid_feat'].float())[ + "mixed_5c"] * 255.0 + else: + vid = batch['vid_feat'] * 255.0 # repeat image feature vectors to be provided for every round vid = vid.view(-1, 1, self.args.vid_feature_size) vid = vid.repeat(1, self.args.max_ques_count, 1) vid = vid.view(-1, self.args.vid_feature_size) - - if 'dialog' in self.args.input_type or 'caption' in self.args.input_type: + + if 'DH' in self.args.input_type or 'C' in self.args.input_type: hist = batch['hist'] # embed history hist = hist.view(-1, hist.size(2)) hist_embed = self.word_embed(hist) + + if self.args.text_encoder == 'BERT': + hist_embed = hist_embed['last_hidden_state'] + hist_embed = self.hist_rnn(hist_embed, batch['hist_len']) - + ques = batch['ques'] # embed questions ques = ques.view(-1, ques.size(2)) ques_embed = self.word_embed(ques) + if self.args.text_encoder == 'BERT': + ques_embed = ques_embed['last_hidden_state'] ques_embed = self.ques_rnn(ques_embed, batch['ques_len']) - - if self.args.input_type == 'question_only': + + if self.args.input_type == 'Q_only': fused_vector = ques_embed - if self.args.input_type == 'question_dialog': + if self.args.input_type == 'Q_DH': fused_vector = torch.cat((ques_embed, hist_embed), 1) - if self.args.input_type == 'question_audio': + if self.args.input_type == 'Q_A': fused_vector = torch.cat((audio, ques_embed), 1) - if self.args.input_type == 'question_image': + if self.args.input_type == 'Q_I': fused_vector = torch.cat((img, ques_embed), 1) - if self.args.input_type=='question_video': + if self.args.input_type == 'Q_V': fused_vector = torch.cat((vid, ques_embed), 1) - if self.args.input_type=='question_dialog_image': + if self.args.input_type == 'Q_DH_I': fused_vector = torch.cat((img, ques_embed, hist_embed), 1) - if self.args.input_type == 'question_dialog_video': + if self.args.input_type == 'Q_DH_V': fused_vector = torch.cat((vid, ques_embed, hist_embed), 1) - if self.args.input_type == 'question_caption_image': + if self.args.input_type == 'Q_C_I': fused_vector = torch.cat((img, ques_embed, hist_embed), 1) - if self.args.input_type == 'question_video_audio': + if self.args.input_type == 'Q_V_A': fused_vector = torch.cat((vid, 
audio, ques_embed), 1) - if self.args.input_type == 'question_dialog_video_audio': + if self.args.input_type == 'Q_DH_V_A': fused_vector = torch.cat((vid, audio, ques_embed, hist_embed), 1) - + fused_vector = self.dropout(fused_vector) - fused_embedding = F.tanh(self.fusion(fused_vector)) + fused_embedding = torch.tanh(self.fusion(fused_vector)) return fused_embedding diff --git a/encoders/lf.pyc b/encoders/lf.pyc deleted file mode 100644 index b64fd6b..0000000 Binary files a/encoders/lf.pyc and /dev/null differ diff --git a/encoders/s3dg_video.py b/encoders/s3dg_video.py new file mode 100644 index 0000000..dbc2665 --- /dev/null +++ b/encoders/s3dg_video.py @@ -0,0 +1,295 @@ + +"""Contains a PyTorch definition for Gated Separable 3D network (S3D-G) +with a text module for computing joint text-video embedding from raw text +and video input. The following code will enable you to load the HowTo100M +pretrained S3D Text-Video model from: + A. Miech, J.-B. Alayrac, L. Smaira, I. Laptev, J. Sivic and A. Zisserman, + End-to-End Learning of Visual Representations from Uncurated Instructional Videos. + https://arxiv.org/abs/1912.06430. + +S3D-G was proposed by: + S. Xie, C. Sun, J. Huang, Z. Tu and K. Murphy, + Rethinking Spatiotemporal Feature Learning For Video Understanding. + https://arxiv.org/abs/1712.04851. + Tensorflow code: https://github.com/tensorflow/models/blob/master/research/slim/nets/s3dg.py + +The S3D architecture was slightly modified with a space to depth trick for TPU +optimization. +""" + +import os +import re + +import numpy as np +import torch as th +import torch.nn as nn +import torch.nn.functional as F + + +class InceptionBlock(nn.Module): + def __init__( + self, + input_dim, + num_outputs_0_0a, + num_outputs_1_0a, + num_outputs_1_0b, + num_outputs_2_0a, + num_outputs_2_0b, + num_outputs_3_0b, + gating=True, + ): + super(InceptionBlock, self).__init__() + self.conv_b0 = STConv3D(input_dim, num_outputs_0_0a, [1, 1, 1]) + self.conv_b1_a = STConv3D(input_dim, num_outputs_1_0a, [1, 1, 1]) + self.conv_b1_b = STConv3D( + num_outputs_1_0a, num_outputs_1_0b, [3, 3, 3], padding=1, separable=True + ) + self.conv_b2_a = STConv3D(input_dim, num_outputs_2_0a, [1, 1, 1]) + self.conv_b2_b = STConv3D( + num_outputs_2_0a, num_outputs_2_0b, [3, 3, 3], padding=1, separable=True + ) + self.maxpool_b3 = th.nn.MaxPool3d((3, 3, 3), stride=1, padding=1) + self.conv_b3_b = STConv3D(input_dim, num_outputs_3_0b, [1, 1, 1]) + self.gating = gating + self.output_dim = ( + num_outputs_0_0a + num_outputs_1_0b + num_outputs_2_0b + num_outputs_3_0b + ) + if gating: + self.gating_b0 = SelfGating(num_outputs_0_0a) + self.gating_b1 = SelfGating(num_outputs_1_0b) + self.gating_b2 = SelfGating(num_outputs_2_0b) + self.gating_b3 = SelfGating(num_outputs_3_0b) + + def forward(self, input): + """Inception block + """ + b0 = self.conv_b0(input) + b1 = self.conv_b1_a(input) + b1 = self.conv_b1_b(b1) + b2 = self.conv_b2_a(input) + b2 = self.conv_b2_b(b2) + b3 = self.maxpool_b3(input) + b3 = self.conv_b3_b(b3) + if self.gating: + b0 = self.gating_b0(b0) + b1 = self.gating_b1(b1) + b2 = self.gating_b2(b2) + b3 = self.gating_b3(b3) + return th.cat((b0, b1, b2, b3), dim=1) + + +class SelfGating(nn.Module): + def __init__(self, input_dim): + super(SelfGating, self).__init__() + self.fc = nn.Linear(input_dim, input_dim) + + def forward(self, input_tensor): + """Feature gating as used in S3D-G. 
+ """ + spatiotemporal_average = th.mean(input_tensor, dim=[2, 3, 4]) + weights = self.fc(spatiotemporal_average) + weights = th.sigmoid(weights) + return weights[:, :, None, None, None] * input_tensor + + +class STConv3D(nn.Module): + def __init__( + self, input_dim, output_dim, kernel_size, stride=1, padding=0, separable=False + ): + super(STConv3D, self).__init__() + self.separable = separable + self.relu = nn.ReLU(inplace=True) + assert len(kernel_size) == 3 + if separable and kernel_size[0] != 1: + spatial_kernel_size = [1, kernel_size[1], kernel_size[2]] + temporal_kernel_size = [kernel_size[0], 1, 1] + if isinstance(stride, list) and len(stride) == 3: + spatial_stride = [1, stride[1], stride[2]] + temporal_stride = [stride[0], 1, 1] + else: + spatial_stride = [1, stride, stride] + temporal_stride = [stride, 1, 1] + if isinstance(padding, list) and len(padding) == 3: + spatial_padding = [0, padding[1], padding[2]] + temporal_padding = [padding[0], 0, 0] + else: + spatial_padding = [0, padding, padding] + temporal_padding = [padding, 0, 0] + if separable: + self.conv1 = nn.Conv3d( + input_dim, + output_dim, + kernel_size=spatial_kernel_size, + stride=spatial_stride, + padding=spatial_padding, + bias=False, + ) + self.bn1 = nn.BatchNorm3d(output_dim) + self.conv2 = nn.Conv3d( + output_dim, + output_dim, + kernel_size=temporal_kernel_size, + stride=temporal_stride, + padding=temporal_padding, + bias=False, + ) + self.bn2 = nn.BatchNorm3d(output_dim) + else: + self.conv1 = nn.Conv3d( + input_dim, + output_dim, + kernel_size=kernel_size, + stride=stride, + padding=padding, + bias=False, + ) + self.bn1 = nn.BatchNorm3d(output_dim) + + def forward(self, input): + out = self.relu(self.bn1(self.conv1(input))) + if th.isnan(self.conv1(input)).any(): + print("conv1 is the issue") + if self.separable: + out = self.relu(self.bn2(self.conv2(out))) + if th.isnan(out).any(): + print("isnan") + return out + + +class MaxPool3dTFPadding(th.nn.Module): + def __init__(self, kernel_size, stride=None, padding="SAME"): + super(MaxPool3dTFPadding, self).__init__() + if padding == "SAME": + padding_shape = self._get_padding_shape(kernel_size, stride) + self.padding_shape = padding_shape + self.pad = th.nn.ConstantPad3d(padding_shape, 0) + self.pool = th.nn.MaxPool3d(kernel_size, stride, ceil_mode=True) + + def _get_padding_shape(self, filter_shape, stride): + def _pad_top_bottom(filter_dim, stride_val): + pad_along = max(filter_dim - stride_val, 0) + pad_top = pad_along // 2 + pad_bottom = pad_along - pad_top + return pad_top, pad_bottom + + padding_shape = [] + for filter_dim, stride_val in zip(filter_shape, stride): + pad_top, pad_bottom = _pad_top_bottom(filter_dim, stride_val) + padding_shape.append(pad_top) + padding_shape.append(pad_bottom) + depth_top = padding_shape.pop(0) + depth_bottom = padding_shape.pop(0) + padding_shape.append(depth_top) + padding_shape.append(depth_bottom) + return tuple(padding_shape) + + def forward(self, inp): + inp = self.pad(inp) + out = self.pool(inp) + return out + + +class S3D(nn.Module): + def __init__(self, dict_path, num_classes=512, gating=True, space_to_depth=True): + super(S3D, self).__init__() + self.num_classes = num_classes + self.gating = gating + self.space_to_depth = space_to_depth + if space_to_depth: + self.conv1 = STConv3D( + 24, 64, [2, 4, 4], stride=1, padding=(1, 2, 2), separable=False + ) + else: + self.conv1 = STConv3D( + 3, 64, [3, 7, 7], stride=2, padding=(1, 3, 3), separable=False + ) + self.conv_2b = STConv3D(64, 64, [1, 1, 1], separable=False) + 
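The 24-input-channel conv1 defined above reflects the space-to-depth trick implemented further down in this file: each 2x2x2 spatio-temporal block of the RGB clip is folded into the channel dimension before the first convolution. A quick shape check, written here only for illustration:

    import torch

    clip = torch.randn(2, 3, 40, 224, 224)        # (B, C, T, H, W): num_frames=40, video_size=224
    B, C, T, H, W = clip.shape
    x = clip.view(B, C, T // 2, 2, H // 2, 2, W // 2, 2)
    x = x.permute(0, 3, 5, 7, 1, 2, 4, 6).contiguous()
    x = x.view(B, 8 * C, T // 2, H // 2, W // 2)
    print(x.shape)                                 # torch.Size([2, 24, 20, 112, 112])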
self.conv_2c = STConv3D(64, 192, [3, 3, 3], padding=1, separable=True) + self.gating = SelfGating(192) + self.maxpool_2a = MaxPool3dTFPadding( + kernel_size=(1, 3, 3), stride=(1, 2, 2), padding="SAME" + ) + self.maxpool_3a = MaxPool3dTFPadding( + kernel_size=(1, 3, 3), stride=(1, 2, 2), padding="SAME" + ) + self.mixed_3b = InceptionBlock(192, 64, 96, 128, 16, 32, 32) + self.mixed_3c = InceptionBlock( + self.mixed_3b.output_dim, 128, 128, 192, 32, 96, 64 + ) + self.maxpool_4a = MaxPool3dTFPadding( + kernel_size=(3, 3, 3), stride=(2, 2, 2), padding="SAME" + ) + self.mixed_4b = InceptionBlock( + self.mixed_3c.output_dim, 192, 96, 208, 16, 48, 64 + ) + self.mixed_4c = InceptionBlock( + self.mixed_4b.output_dim, 160, 112, 224, 24, 64, 64 + ) + self.mixed_4d = InceptionBlock( + self.mixed_4c.output_dim, 128, 128, 256, 24, 64, 64 + ) + self.mixed_4e = InceptionBlock( + self.mixed_4d.output_dim, 112, 144, 288, 32, 64, 64 + ) + self.mixed_4f = InceptionBlock( + self.mixed_4e.output_dim, 256, 160, 320, 32, 128, 128 + ) + self.maxpool_5a = self.maxPool3d_5a_2x2 = MaxPool3dTFPadding( + kernel_size=(2, 2, 2), stride=(2, 2, 2), padding="SAME" + ) + self.mixed_5b = InceptionBlock( + self.mixed_4f.output_dim, 256, 160, 320, 32, 128, 128 + ) + self.mixed_5c = InceptionBlock( + self.mixed_5b.output_dim, 384, 192, 384, 48, 128, 128 + ) + + ''' + if init == 'kaiming_normal': + for m in self.modules(): + if isinstance(m, nn.Conv3d): + nn.init.kaiming_normal_(m.weight, + mode='fan_in', + nonlinearity='relu') + elif isinstance(m, nn.BatchNorm3d): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + ''' + + def _space_to_depth(self, input): + """3D space to depth trick for TPU optimization. + """ + B, C, T, H, W = input.shape + input = input.view(B, C, T // 2, 2, H // 2, 2, W // 2, 2) + input = input.permute(0, 3, 5, 7, 1, 2, 4, 6) + input = input.contiguous().view(B, 8 * C, T // 2, H // 2, W // 2) + return input + + def forward(self, inputs): + """Defines the S3DG base architecture. 
+ """ + if self.space_to_depth: + inputs = self._space_to_depth(inputs) + net = self.conv1(inputs) + if self.space_to_depth: + # we need to replicate 'SAME' tensorflow padding + net = net[:, :, 1:, 1:, 1:] + net = self.maxpool_2a(net) + net = self.conv_2b(net) + net = self.conv_2c(net) + if self.gating: + net = self.gating(net) + net = self.maxpool_3a(net) + net = self.mixed_3b(net) + net = self.mixed_3c(net) + net = self.maxpool_4a(net) + net = self.mixed_4b(net) + net = self.mixed_4c(net) + net = self.mixed_4d(net) + net = self.mixed_4e(net) + net = self.mixed_4f(net) + net = self.maxpool_5a(net) + net = self.mixed_5b(net) + net = self.mixed_5c(net) + net = th.mean(net, dim=[2, 3, 4]) + return {'mixed_5c': net} diff --git a/env.yml b/env.yml index f862d9d..a1c6348 100644 --- a/env.yml +++ b/env.yml @@ -1,49 +1,40 @@ -name: visdial-chal -channels: - - pytorch - - defaults -dependencies: - - blas=1.0=mkl - - ca-certificates=2018.03.07=0 - - certifi=2018.4.16=py36_0 - - cffi=1.11.5=py36h9745a5d_0 - - cudatoolkit=8.0=3 - - freetype=2.8=hab7d2ae_1 - - intel-openmp=2018.0.3=0 - - jpeg=9b=h024ee3a_2 - - libedit=3.1.20170329=h6b74fdf_2 - - libffi=3.2.1=hd88cf55_4 - - libgcc-ng=7.2.0=hdf63c60_3 - - libgfortran-ng=7.2.0=hdf63c60_3 - - libpng=1.6.34=hb9fc6fc_0 - - libstdcxx-ng=7.2.0=hdf63c60_3 - - libtiff=4.0.9=he85c1e1_1 - - mkl=2018.0.3=1 - - mkl_fft=1.0.1=py36h3010b51_0 - - mkl_random=1.0.1=py36h629b387_0 - - ncurses=6.1=hf484d3e_0 - - ninja=1.8.2=py36h6bb024c_1 - - numpy=1.14.5=py36hcd700cb_3 - - numpy-base=1.14.5=py36hdbf6ddf_3 - - olefile=0.45.1=py36_0 - - openssl=1.0.2o=h20670df_0 - - pillow=5.1.0=py36h3deb7b8_0 - - pip=10.0.1=py36_0 - - pycparser=2.18=py36hf9f622e_1 - - python=3.6.6=hc3d631a_0 - - readline=7.0=ha6073c6_4 - - setuptools=39.2.0=py36_0 - - six=1.11.0=py36h372c433_1 - - sqlite=3.24.0=h84994c4_0 - - tk=8.6.7=hc745277_3 - - wheel=0.31.1=py36_0 - - xz=5.2.4=h14c3975_4 - - zlib=1.2.11=ha838bed_2 - - pytorch=0.3.0=py36_cuda8.0.61_cudnn7.0.3h37a80b5_4 - - torchvision=0.2.0=py36h17b6947_1 - - pip: - - cython==0.28.3 - - h5py==2.8.0 - - nltk==3.3 - - torch==0.3.0.post4 - - tqdm==4.23.4 +orflow 2.0.0 +tensorflow-estimator 2.0.1 +termcolor 1.1.0 +terminado 0.8.2 +testpath 0.4.2 +torch 1.3.1 +torchvision 0.2.0 +tornado 6.0.3 +tqdm 4.36.1 +traitlets 4.3.2 +transformers 2.0.0 +ubuntu-drivers-common 0.0.0 +ufw 0.35 +unattended-upgrades 0.1 +unity-scope-calculator 0.1 +unity-scope-chromiumbookmarks 0.1 +unity-scope-colourlovers 0.1 +unity-scope-devhelp 0.1 +unity-scope-firefoxbookmarks 0.1 +unity-scope-gdrive 0.7 +unity-scope-manpages 0.1 +unity-scope-openclipart 0.1 +unity-scope-texdoc 0.1 +unity-scope-tomboy 0.1 +unity-scope-virtualbox 0.1 +unity-scope-yelp 0.1 +unity-scope-zotero 0.1 +urllib3 1.25.6 +usb-creator 0.3.0 +wcwidth 0.1.7 +webencodings 0.5.1 +Werkzeug 0.16.0 +wheel 0.29.0 +widgetsnbextension 3.5.1 +wrapt 1.11.2 +xdiagnose 3.8.4.1 +xkit 0.0.0 +XlsxWriter 0.7.3 +zipp 0.6.0 +~ diff --git a/envn.yml b/envn.yml new file mode 100644 index 0000000..015f01a --- /dev/null +++ b/envn.yml @@ -0,0 +1,94 @@ +name: visdial-bert +channels: + - pytorch + - defaults +dependencies: + - _libgcc_mutex=0.1=main + - _pytorch_select=0.2=gpu_0 + - blas=1.0=mkl + - ca-certificates=2020.10.14=0 + - certifi=2020.11.8=py38h06a4308_0 + - cffi=1.14.3=py38he30daa8_0 + - cudatoolkit=10.2.89=hfd86e86_1 + - freetype=2.10.4=h5ab3b9f_0 + - intel-openmp=2020.2=254 + - jpeg=9b=h024ee3a_2 + - lcms2=2.11=h396b838_0 + - ld_impl_linux-64=2.33.1=h53a641e_7 + - libedit=3.1.20191231=h14c3975_1 + - libffi=3.3=he6710b0_2 + - 
libgcc-ng=9.1.0=hdf63c60_0 + - libpng=1.6.37=hbc83047_0 + - libstdcxx-ng=9.1.0=hdf63c60_0 + - libtiff=4.1.0=h2733197_1 + - libuv=1.40.0=h7b6447c_0 + - lz4-c=1.9.2=heb0550a_3 + - mkl=2020.2=256 + - mkl-service=2.3.0=py38he904b0f_0 + - mkl_fft=1.2.0=py38h23d657b_0 + - mkl_random=1.1.1=py38h0573a6f_0 + - ncurses=6.2=he6710b0_1 + - ninja=1.10.1=py38hfd86e86_0 + - numpy=1.19.2=py38h54aff64_0 + - numpy-base=1.19.2=py38hfa32c7d_0 + - olefile=0.46=py_0 + - openssl=1.1.1h=h7b6447c_0 + - pillow=8.0.1=py38he98fc37_0 + - pip=20.2.4=py38h06a4308_0 + - pycparser=2.20=py_2 + - python=3.8.5=h7579374_1 + - pytorch=1.7.0=py3.8_cuda10.2.89_cudnn7.6.5_0 + - readline=8.0=h7b6447c_0 + - setuptools=50.3.1=py38h06a4308_1 + - six=1.15.0=py38h06a4308_0 + - sqlite=3.33.0=h62c20be_0 + - tk=8.6.10=hbc83047_0 + - torchaudio=0.7.0=py38 + - torchvision=0.8.1=py38_cu102 + - typing_extensions=3.7.4.3=py_0 + - wheel=0.35.1=py_0 + - xz=5.2.5=h7b6447c_0 + - zlib=1.2.11=h7b6447c_3 + - zstd=1.4.5=h9ceee32_0 + - pip: + - chardet==3.0.4 + - click==7.1.2 + - configparser==5.0.1 + - docker-pycreds==0.4.0 + - ffmpeg-python==0.2.0 + - future==0.18.2 + - gitdb==4.0.5 + - gitpython==3.1.11 + - h5py==3.1.0 + - h5py-wrapper==1.1.0 + - hdfdict==0.3.1 + - idna==2.10 + - install==1.3.4 + - jsonpatch==1.26 + - jsonpointer==2.0 + - pandas==1.1.4 + - pathtools==0.1.2 + - promise==2.3 + - protobuf==3.13.0 + - psutil==5.7.3 + - pytest-runner==5.2 + - python-dateutil==2.8.1 + - pytz==2020.4 + - pyyaml==5.3.1 + - pyzmq==20.0.0 + - requests==2.25.0 + - scipy==1.5.4 + - sentry-sdk==0.19.3 + - shortuuid==1.0.1 + - smmap==3.0.4 + - subprocess32==3.5.4 + - torchfile==0.1.0 + - tornado==6.1 + - tqdm==4.51.0 + - urllib3==1.26.2 + - visdom==0.1.8.9 + - wandb==0.10.10 + - watchdog==0.10.3 + - websocket-client==0.57.0 +prefix: /nethome/halamri3/anaconda3/envs/visdial-bert + diff --git a/evaluate.py b/evaluate.py index 0cebf24..e5e361a 100644 --- a/evaluate.py +++ b/evaluate.py @@ -2,79 +2,72 @@ import datetime import gc import json +import logging import math import os -from tqdm import tqdm import torch from torch.autograd import Variable from torch.utils.data import DataLoader +from tqdm import tqdm +from args import get_args from dataloader import VisDialDataset -from encoders import Encoder, LateFusionEncoder from decoders import Decoder -from utils import process_ranks, scores_to_ranks, get_gt_ranks +from encoders import Encoder, LateFusionEncoder +from models import AVSD +from utils import get_gt_ranks, process_ranks, scores_to_ranks, visualize +import pprint parser = argparse.ArgumentParser() VisDialDataset.add_cmdline_args(parser) LateFusionEncoder.add_cmdline_args(parser) -parser.add_argument('-input_type', default='question_dialog_video_audio', choices=['question_only', - 'question_dialog', - 'question_audio', - 'question_image', - 'question_video', - 'question_caption_image', - 'question_dialog_video', - 'question_dialog_image', - 'question_video_audio', - 'question_dialog_video_audio'], help='Specify the inputs') - -parser.add_argument_group('Evaluation related arguments') -parser.add_argument('-load_path', default='checkpoints/13-Jun-2019-16:22:48/model_epoch_14.pth', help='Checkpoint to load path from') -parser.add_argument('-split', default='test', choices=['val', 'test', 'train'], help='Split to evaluate on') -parser.add_argument('-use_gt', action='store_true', help='Whether to use ground truth for retrieving ranks') -parser.add_argument('-batch_size', default=12, type=int, help='Batch size') -parser.add_argument('-gpuid', default=0, 
type=int, help='GPU id to use') -parser.add_argument('-overfit', action='store_true', help='Use a batch of only 5 examples, useful for debugging') - -parser.add_argument_group('Submission related arguments') -parser.add_argument('-save_ranks', action='store_true', help='Whether to save retrieved ranks') -parser.add_argument('-save_path', default='logs/ranks.json', help='Path of json file to save ranks') +args = get_args(parser) + +viz = visualize.VisdomVisualize( + env_name=args.visEnvName, + server=args.server, + port=args.serverPort) + +# seed for reproducibility +torch.manual_seed(1234) +torch.backends.cudnn.deterministic = True +torch.autograd.set_detect_anomaly(True) + # ---------------------------------------------------------------------------- # input arguments and options # ---------------------------------------------------------------------------- args = parser.parse_args() +model_args = args +''' +log_path = os.path.join(args.load_path, 'eval_results.log') +logging.basicConfig(filename='eval_results.log') +''' -# seed for reproducibility -torch.manual_seed(1234) +cur = os.getcwd() +os.chdir(args.load_path) +checkpoints = sorted( + filter(os.path.isfile, os.listdir('.')), key=os.path.getmtime) +checkpoints = [file for file in checkpoints if file.endswith(".pth")] +logging.info("Evaluate the following checkpoints: %s", args.load_path) +os.chdir(cur) # set device and default tensor type +device = "cpu" if args.gpuid >= 0: torch.cuda.manual_seed_all(1234) - torch.cuda.set_device(args.gpuid) - -# ---------------------------------------------------------------------------- -# read saved model and args -# ---------------------------------------------------------------------------- - -components = torch.load(args.load_path) -model_args = components['model_args'] -model_args.gpuid = args.gpuid -model_args.batch_size = args.batch_size + args.num_gpu = torch.cuda.device_count() + device = "cuda" # set this because only late fusion encoder is supported yet args.concat_history = True -for arg in vars(args): - print('{:<20}: {}'.format(arg, getattr(args, arg))) - # ---------------------------------------------------------------------------- # loading dataset wrapping with a dataloader # ---------------------------------------------------------------------------- - dataset = VisDialDataset(args, [args.split]) dataloader = DataLoader(dataset, batch_size=args.batch_size, @@ -82,49 +75,67 @@ collate_fn=dataset.collate_fn) # iterations per epoch -setattr(args, 'iter_per_epoch', math.ceil(dataset.num_data_points[args.split] / args.batch_size)) +setattr(args, 'iter_per_epoch', math.ceil( + dataset.num_data_points[args.split] / args.batch_size)) print("{} iter per epoch.".format(args.iter_per_epoch)) # ---------------------------------------------------------------------------- # setup the model # ---------------------------------------------------------------------------- - -encoder = Encoder(model_args) -encoder.load_state_dict(components['encoder']) - -decoder = Decoder(model_args, encoder) -decoder.load_state_dict(components['decoder']) +''' +model = AVSD(model_args) +model._load_state_dict_(components) print("Loaded model from {}".format(args.load_path)) if args.gpuid >= 0: - encoder = encoder.cuda() - decoder = decoder.cuda() - + model = torch.nn.DataParallel(model, output_device=0, dim=0) + model = model.to(device) +''' # ---------------------------------------------------------------------------- # evaluation # ---------------------------------------------------------------------------- 
-print("Evaluation start time: {}".format( - datetime.datetime.strftime(datetime.datetime.utcnow(), '%d-%b-%Y-%H:%M:%S'))) -encoder.eval() -decoder.eval() +def convert_list_to_tensor(batch): + new_batch = {} + for k, v in batch.items(): + # tensor of list of strings isn't possible, hence removing the image fnames from the batch sent into the training module. + if isinstance(v, list) and not (k == "img_fnames"): + new_batch[k] = torch.Tensor(v) + elif isinstance(v, torch.Tensor): + new_batch[k] = v + return new_batch +def repeat_tensors(batch, num_repeat): + """In the last iterations, when the number of samples are not multiple of the num_gpu, this function will repeat the last few samples""" + new_batch = batch.copy() + for i in range(num_repeat): + for k, v in batch.items(): + if isinstance(v, list): + new_batch[k].append(v[-1]) + elif isinstance(v, torch.Tensor): + new_batch[k] = torch.cat((new_batch[k], v[-1].unsqueeze(0)), 0) + return new_batch +''' if args.use_gt: # ------------------------------------------------------------------------ # calculate automatic metrics and finish # ------------------------------------------------------------------------ all_ranks = [] - for i, batch in enumerate(tqdm(dataloader)): + for i, batch in tqdm(enumerate(tqdm(dataloader))): for key in batch: if not isinstance(batch[key], list): batch[key] = Variable(batch[key], volatile=True) if args.gpuid >= 0: batch[key] = batch[key].cuda() - enc_out = encoder(batch) - dec_out = decoder(enc_out, batch) + if not batch["vid_feat"].shape[0] % args.num_gpu == 0: + num_repeat = args.num_gpu - \ + batch["vid_feat"].shape[0] % args.num_gpu + batch = repeat_tensors(batch, num_repeat) + new_batch = convert_list_to_tensor(batch) + dec_out, _ = model(new_batch) ranks = scores_to_ranks(dec_out.data) gt_ranks = get_gt_ranks(ranks, batch['ans_ind'].data) all_ranks.append(gt_ranks) @@ -136,15 +147,19 @@ # prepare json for submission # ------------------------------------------------------------------------ ranks_json = [] - for i, batch in enumerate(tqdm(dataloader)): + for i, batch in tqdm(enumerate(tqdm(dataloader))): for key in batch: if not isinstance(batch[key], list): batch[key] = Variable(batch[key], volatile=True) if args.gpuid >= 0: batch[key] = batch[key].cuda() - enc_out = encoder(batch) - dec_out = decoder(enc_out, batch) + if not batch["vid_feat"].shape[0] % args.num_gpu == 0: + num_repeat = args.num_gpu - \ + batch["vid_feat"].shape[0] % args.num_gpu + batch = repeat_tensors(batch, num_repeat) + new_batch = convert_list_to_tensor(batch) + dec_out, _ = model(new_batch) ranks = scores_to_ranks(dec_out.data) ranks = ranks.view(-1, 10, 100) @@ -158,14 +173,121 @@ }) else: for j in range(batch['num_rounds'][i]): + + # read saved model and args + # ---------------------------------------------------------------------------- +''' + +print("Evaluation start time: {}".format( + datetime.datetime.strftime(datetime.datetime.utcnow(), '%d-%b-%Y-%H:%M:%S'))) + +i=0 + +for checkpoint in checkpoints: + + print('checkpoint:', checkpoint) + model_path = os.path.join(args.load_path, checkpoint) + components = torch.load(model_path) + model_args = components['model_args'] + if i == 0: + viz.showText(pprint.pformat(args, indent=4)) + i +=1 + model = AVSD(model_args) + model._load_state_dict_(components) + model_args.gpuid = args.gpuid + model_args.batch_size = args.batch_size + + + for arg in vars(args): + print('{:<20}: {}'.format(arg, getattr(args, arg))) + + # 
---------------------------------------------------------------------------- + # setup the model + # ---------------------------------------------------------------------------- + + + print("Loaded model from {}".format(args.load_path)) + + if args.gpuid >= 0: + model = torch.nn.DataParallel(model, output_device=0, dim=0) + model = model.to(device) + + # ---------------------------------------------------------------------------- + # evaluation + # ---------------------------------------------------------------------------- + + print("Evaluation start time: {}".format( + datetime.datetime.strftime(datetime.datetime.utcnow(), '%d-%b-%Y-%H:%M:%S'))) + model.eval() + + if args.use_gt: + viz.save() + # ------------------------------------------------------------------------ + # calculate automatic metrics and finish + # ------------------------------------------------------------------------ + all_ranks = [] + for i, batch in tqdm(enumerate(tqdm(dataloader))): + for key in batch: + if not isinstance(batch[key], list): + batch[key] = Variable(batch[key], volatile=True) + if args.gpuid >= 0: + batch[key] = batch[key].cuda() + + new_batch = convert_list_to_tensor(batch) + dec_out, _ = model(new_batch) + ranks = scores_to_ranks(dec_out.data) + gt_ranks = get_gt_ranks(ranks, batch['ans_ind'].data) + all_ranks.append(gt_ranks) + all_ranks = torch.cat(all_ranks, 0) + + all_metrics = process_ranks(all_ranks, args.load_path, checkpoint[6:-4]) + iter_id = checkpoint[6:-4] + for metric_name, metric_value in all_metrics.items(): + print(f"{metric_name}: {metric_value}") + if 'round' in metric_name: + viz.linePlot(iter_id, metric_value, 'Retrieval Round Val Metrics Round -' + metric_name.split('_')[-1], + metric_name) + else: + viz.linePlot(iter_id.split('_')[1], metric_value, 'Retrieval Val Metrics', metric_name) + gc.collect() + else: + # ------------------------------------------------------------------------ + # prepare json for submission + # ------------------------------------------------------------------------ + ranks_json = [] + for i, batch in tqdm(enumerate(tqdm(dataloader))): + for key in batch: + if not isinstance(batch[key], list): + batch[key] = Variable(batch[key], volatile=True) + if args.gpuid >= 0: + batch[key] = batch[key].cuda() + + + new_batch = convert_list_to_tensor(batch) + dec_out, _ = model(new_batch) + ranks = scores_to_ranks(dec_out.data) + ranks = ranks.view(-1, 10, 100) + + for i in range(len(batch['img_fnames'])): + # cast into types explicitly to ensure no errors in schema + if args.split == 'test': ranks_json.append({ 'image_id': int(batch['img_fnames'][i][-16:-4]), - 'round_id': int(j + 1), - 'ranks': list(ranks[i][j]) + 'round_id': int(batch['num_rounds'][i]), + 'ranks': list(ranks[i][batch['num_rounds'][i] - 1]) }) - gc.collect() + else: + for j in range(batch['num_rounds'][i]): + ranks_json.append({ + 'image_id': int(batch['img_fnames'][i][-16:-4]), + 'round_id': int(j + 1), + 'ranks': list(ranks[i][j]) + }) + gc.collect() + + if args.save_ranks: + print("Writing ranks to {}".format(args.save_path)) + os.makedirs(os.path.dirname(args.save_path), exist_ok=True) + json.dump(ranks_json, open(args.save_path, 'w')) + -if args.save_ranks: - print("Writing ranks to {}".format(args.save_path)) - os.makedirs(os.path.dirname(args.save_path), exist_ok=True) - json.dump(ranks_json, open(args.save_path, 'w')) diff --git a/models.py b/models.py new file mode 100644 index 0000000..1f0aaa7 --- /dev/null +++ b/models.py @@ -0,0 +1,26 @@ +import torch +import torch.nn as nn +import 
torch.nn.functional as F + +from decoders import Decoder +from encoders import Encoder, LateFusionEncoder + + +class AVSD(nn.Module): + def __init__(self, args): + super().__init__() + self.args = args + self.encoder = Encoder(args) + self.decoder = Decoder(args, self.encoder) + self.criterion = nn.CrossEntropyLoss() + + def _load_state_dict_(self, components): + self.encoder.load_state_dict(components['encoder']) + self.decoder.load_state_dict(components['decoder']) + + def forward(self, batch): + enc_out = self.encoder(batch) + dec_out = self.decoder(enc_out, batch) + cur_loss = self.criterion(dec_out, batch['ans_ind'].view(-1)) + + return dec_out, cur_loss diff --git a/train.py b/train.py index 61dd67f..419c124 100644 --- a/train.py +++ b/train.py @@ -3,71 +3,60 @@ import gc import math import os -import numpy as np +import random +import numpy as np import torch from torch import nn, optim from torch.autograd import Variable from torch.optim import lr_scheduler from torch.utils.data import DataLoader +from tqdm import tqdm +from args import get_args from dataloader import VisDialDataset -from encoders import Encoder, LateFusionEncoder from decoders import Decoder +from encoders import Encoder, LateFusionEncoder +from models import AVSD +from utils import get_gt_ranks, process_ranks, scores_to_ranks, visualize parser = argparse.ArgumentParser() VisDialDataset.add_cmdline_args(parser) LateFusionEncoder.add_cmdline_args(parser) -parser.add_argument_group('Input modalites arguments') -parser.add_argument('-input_type', default='question_dialog_video_audio', choices=['question_only', - 'question_dialog', - 'question_audio', - 'question_image', - 'question_video', - 'question_caption_image', - 'question_dialog_video', - 'question_dialog_image', - 'question_video_audio', - 'question_dialog_video_audio'], help='Specify the inputs') - -parser.add_argument_group('Encoder Decoder choice arguments') -parser.add_argument('-encoder', default='lf-ques-im-hist', choices=['lf-ques-im-hist'], help='Encoder to use for training') -parser.add_argument('-concat_history', default=True, help='True for lf encoding') -parser.add_argument('-decoder', default='disc', choices=['disc'], help='Decoder to use for training') - -parser.add_argument_group('Optimization related arguments') -parser.add_argument('-num_epochs', default=20, type=int, help='Epochs') -parser.add_argument('-batch_size', default=12, type=int, help='Batch size') -parser.add_argument('-lr', default=1e-3, type=float, help='Learning rate') -parser.add_argument('-lr_decay_rate', default=0.9997592083, type=float, help='Decay for lr') -parser.add_argument('-min_lr', default=5e-5, type=float, help='Minimum learning rate') -parser.add_argument('-weight_init', default='xavier', choices=['xavier', 'kaiming'], help='Weight initialization strategy') -parser.add_argument('-weight_decay', default=0.00075, help='Weight decay for l2 regularization') -parser.add_argument('-overfit', action='store_true', help='Overfit on 5 examples, meant for debugging') -parser.add_argument('-gpuid', default=0, type=int, help='GPU id to use') - -parser.add_argument_group('Checkpointing related arguments') -parser.add_argument('-load_path', default='', help='Checkpoint to load path from') -parser.add_argument('-save_path', default='checkpoints/', help='Path to save checkpoints') -parser.add_argument('-save_step', default=2, type=int, help='Save checkpoint after every save_step epochs') +args = get_args(parser) +args.numpy_path += "/num_frames_{}".format(args.num_frames) 
+start_time = datetime.datetime.strftime( + datetime.datetime.utcnow(), '%d-%b-%Y-%H:%M:%S') -# ---------------------------------------------------------------------------- -# input arguments and options -# ---------------------------------------------------------------------------- - -args = parser.parse_args() -start_time = datetime.datetime.strftime(datetime.datetime.utcnow(), '%d-%b-%Y-%H:%M:%S') if args.save_path == 'checkpoints/': - args.save_path += start_time + # args.save_path += start_time + args.save_path += 'input_type_{0}_s3d_mixed_5c_fps_{1}_num_frames_{2}_text_encoder_{3}_lr_{4}_unfreeze_layer_{5}_finetune_{6}_use_npy_{7}_batch_size_{8}_finetuneBert_{9}'.format( + args.input_type, args.fps, args.num_frames, args.text_encoder, args.lr, args.unfreeze_layers, args.finetune, args.use_npy, args.batch_size, args.finetune_textEncoder) +# ------------------------------------------------------------------------------------- +# setting visdom args +# ------------------------------------------------------------------------------------- +viz = visualize.VisdomLinePlot( + env_name=args.visEnvName, + server=args.server, + port=args.serverPort) # seed for reproducibility +random.seed(args.seed) +np.random.seed(args.seed) torch.manual_seed(1234) +torch.backends.cudnn.deterministic = True +torch.autograd.set_detect_anomaly(True) + +if args.set_cuda_device != '': + os.environ["CUDA_VISIBLE_DEVICES"] = args.set_cuda_device # set device and default tensor type +device = "cpu" if args.gpuid >= 0: torch.cuda.manual_seed_all(1234) - torch.cuda.set_device(args.gpuid) + args.num_gpu = torch.cuda.device_count() + device = "cuda" # transfer all options to model model_args = args @@ -88,6 +77,7 @@ for arg in vars(args): print('{:<20}: {}'.format(arg, getattr(args, arg))) +viz.writeText(args) # ---------------------------------------------------------------------------- # loading dataset wrapping with a dataloader # ---------------------------------------------------------------------------- @@ -96,14 +86,23 @@ dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True, + num_workers=args.num_workers, + drop_last=True, collate_fn=dataset.collate_fn) dataset_val = VisDialDataset(args, ['val']) dataloader_val = DataLoader(dataset_val, batch_size=args.batch_size, + num_workers=args.num_workers, shuffle=False, + drop_last=True, collate_fn=dataset.collate_fn) +dataset_test = VisDialDataset(args, ['test']) +dataloader_test = DataLoader(dataset_test, + batch_size=args.batch_size, + shuffle=False, + collate_fn=dataset.collate_fn) # ---------------------------------------------------------------------------- # setting model args # ---------------------------------------------------------------------------- @@ -114,47 +113,82 @@ setattr(model_args, key, getattr(dataset, key)) # iterations per epoch -setattr(args, 'iter_per_epoch', math.ceil(dataset.num_data_points['train'] / args.batch_size)) +setattr(args, 'iter_per_epoch', math.ceil( + dataset.num_data_points['train'] / args.batch_size)) print("{} iter per epoch.".format(args.iter_per_epoch)) # ---------------------------------------------------------------------------- # setup the model # ---------------------------------------------------------------------------- -encoder = Encoder(model_args) -decoder = Decoder(model_args, encoder) -optimizer = optim.Adam(list(encoder.parameters()) + list(decoder.parameters()), lr=args.lr, weight_decay=args.weight_decay) -criterion = nn.CrossEntropyLoss() -scheduler = lr_scheduler.StepLR(optimizer, 
step_size=1, gamma=args.lr_decay_rate) +model = AVSD(model_args) +total_params = sum(p.numel() for p in model.parameters() if p.requires_grad) +print("Total number of model params {0}".format(total_params)) +if args.finetune: + total_params = sum(p.numel() + for p in model.encoder.video_embed.parameters() if p.requires_grad) + print("Total number of s3dg params {0}".format(total_params)) +optimizer = optim.Adam(list(model.parameters()), + lr=args.lr, weight_decay=args.weight_decay) + +scheduler = lr_scheduler.StepLR( + optimizer, step_size=1, gamma=args.lr_decay_rate) if args.load_path != '': - encoder.load_state_dict(components['encoder']) - decoder.load_state_dict(components['decoder']) + model._load_state_dict_(components) print("Loaded model from {}".format(args.load_path)) print("Encoder: {}".format(args.encoder)) print("Decoder: {}".format(args.decoder)) if args.gpuid >= 0: - encoder = encoder.cuda() - decoder = decoder.cuda() - criterion = criterion.cuda() + model = torch.nn.DataParallel(model, output_device=0, dim=0) + model = model.to(device) # ---------------------------------------------------------------------------- # training # ---------------------------------------------------------------------------- -encoder.train() -decoder.train() +model.train() os.makedirs(args.save_path, exist_ok=True) +with open(os.path.join(args.save_path, "args_{0}.txt".format(start_time)), "w") as f: + f.write(str(args)) +f.close() running_loss = 0.0 train_begin = datetime.datetime.utcnow() -print("Training start time: {}".format(datetime.datetime.strftime(train_begin, '%d-%b-%Y-%H:%M:%S'))) +print("Training start time: {}".format( + datetime.datetime.strftime(train_begin, '%d-%b-%Y-%H:%M:%S'))) + + +def convert_list_to_tensor(batch): + new_batch = {} + for k, v in batch.items(): + # tensor of list of strings isn't possible, hence removing the image fnames from the batch sent into the training module. 
+ if isinstance(v, list) and not (k == "img_fnames"): + new_batch[k] = torch.Tensor(v) + elif isinstance(v, torch.Tensor): + new_batch[k] = v + return new_batch + + +def repeat_tensors(batch, num_repeat): + """In the last iterations, when the number of samples are not multiple of the num_gpu, this function will repeat the last few samples""" + new_batch = batch.copy() + for i in range(num_repeat): + for k, v in batch.items(): + if isinstance(v, list): + new_batch[k].append(v[-1]) + elif isinstance(v, torch.Tensor): + new_batch[k] = torch.cat((new_batch[k], v[-1].unsqueeze(0)), 0) + return new_batch + log_loss = [] for epoch in range(1, model_args.num_epochs + 1): - for i, batch in enumerate(dataloader): + for i, batch in tqdm(enumerate(dataloader)): optimizer.zero_grad() + model.train() + model.zero_grad() for key in batch: if not isinstance(batch[key], list): batch[key] = Variable(batch[key]) @@ -164,10 +198,12 @@ # -------------------------------------------------------------------- # forward-backward pass and optimizer step # -------------------------------------------------------------------- - enc_out = encoder(batch) - dec_out = decoder(enc_out, batch) - - cur_loss = criterion(dec_out, batch['ans_ind'].view(-1)) + # if not batch["vid_feat"].shape[0] % args.num_gpu == 0: + # num_repeat = args.num_gpu - batch["vid_feat"].shape[0] % args.num_gpu + # batch = repeat_tensors(batch, num_repeat) + new_batch = convert_list_to_tensor(batch) + _, cur_loss = model(new_batch) + cur_loss = cur_loss.mean() cur_loss.backward() optimizer.step() @@ -176,35 +212,43 @@ # -------------------------------------------------------------------- # update running loss and decay learning rates # -------------------------------------------------------------------- - train_loss = cur_loss.data[0] + train_loss = cur_loss.item() + #import pdb + # pdb.set_trace() + if running_loss > 0.0: - running_loss = 0.95 * running_loss + 0.05 * cur_loss.data[0] + running_loss = 0.95 * running_loss + 0.05 * cur_loss.item() else: - running_loss = cur_loss.data[0] + running_loss = cur_loss.item() if optimizer.param_groups[0]['lr'] > args.min_lr: scheduler.step() - # -------------------------------------------------------------------- # print after every few iterations - # -------------------------------------------------------------------- - if i % 100 == 0: + + if (i + 1) % args.eval_step == 0: + print("Running validation") validation_losses = [] + model.eval() + model.zero_grad() for _, val_batch in enumerate(dataloader_val): for key in val_batch: if not isinstance(val_batch[key], list): val_batch[key] = Variable(val_batch[key]) if args.gpuid >= 0: val_batch[key] = val_batch[key].cuda() - enc_out = encoder(val_batch) - dec_out = decoder(enc_out, val_batch) - cur_loss = criterion(dec_out, val_batch['ans_ind'].view(-1)) - validation_losses.append(cur_loss.data[0]) + # if not val_batch["vid_feat"].shape[0] % args.num_gpu == 0: + # num_repeat = args.num_gpu - val_batch["vid_feat"].shape[0] % args.num_gpu + # val_batch = repeat_tensors(val_batch, num_repeat) + # print(val_batch["img_fnames"]) + new_batch_v = convert_list_to_tensor(val_batch) + _, cur_loss = model(new_batch_v) + cur_loss = cur_loss.mean() + validation_losses.append(cur_loss.item()) validation_loss = np.mean(validation_losses) - iteration = (epoch - 1) * args.iter_per_epoch + i log_loss.append((epoch, @@ -217,26 +261,49 @@ # print current time, running average, learning rate, iteration, epoch print("[{}][Epoch: {:3d}][Iter: {:6d}][Loss: {:6f}][val loss: {:6f}][lr: 
{:7f}]".format( datetime.datetime.utcnow() - train_begin, epoch, - iteration, running_loss, validation_loss, - optimizer.param_groups[0]['lr'])) + iteration, running_loss, validation_loss, + optimizer.param_groups[0]['lr'])) + viz.plotLine('Loss', 'Train', 'LOSS', iteration, train_loss) + viz.plotLine('Loss', 'Val', 'LOSS', iteration, validation_loss) # ------------------------------------------------------------------------ # save checkpoints and final model # ------------------------------------------------------------------------ if epoch % args.save_step == 0: torch.save({ - 'encoder': encoder.state_dict(), - 'decoder': decoder.state_dict(), + 'encoder': model.module.encoder.state_dict(), + 'decoder': model.module.decoder.state_dict(), 'optimizer': optimizer.state_dict(), - 'model_args': encoder.args + 'model_args': model.module.args }, os.path.join(args.save_path, 'model_epoch_{}.pth'.format(epoch))) + print('Running evaluation for checkpoint:', epoch) + model.eval() + all_ranks = [] + for i, batch in tqdm(enumerate(tqdm(dataloader))): + for key in batch: + if not isinstance(batch[key], list): + batch[key] = Variable(batch[key], volatile=True) + if args.gpuid >= 0: + batch[key] = batch[key].cuda() + + new_batch = convert_list_to_tensor(batch) + dec_out, _ = model(new_batch) + ranks = scores_to_ranks(dec_out.data) + gt_ranks = get_gt_ranks(ranks, batch['ans_ind'].data) + all_ranks.append(gt_ranks) + + all_ranks = torch.cat(all_ranks, 0) + process_ranks(all_ranks, args.save_path, epoch) + + f.close() + gc.collect() + model.train() torch.save({ - 'encoder': encoder.state_dict(), - 'decoder': decoder.state_dict(), + 'encoder': model.module.encoder.state_dict(), + 'decoder': model.module.decoder.state_dict(), 'optimizer': optimizer.state_dict(), - 'model_args': encoder.args + 'model_args': model.module.args }, os.path.join(args.save_path, 'model_final.pth')) np.save(os.path.join(args.save_path, 'log_loss'), log_loss) - diff --git a/utils/__init__.pyc b/utils/__init__.pyc deleted file mode 100644 index 0218991..0000000 Binary files a/utils/__init__.pyc and /dev/null differ diff --git a/utils/__pycache__/__init__.cpython-36.pyc b/utils/__pycache__/__init__.cpython-36.pyc deleted file mode 100644 index 6e55746..0000000 Binary files a/utils/__pycache__/__init__.cpython-36.pyc and /dev/null differ diff --git a/utils/__pycache__/dynamic_rnn.cpython-36.pyc b/utils/__pycache__/dynamic_rnn.cpython-36.pyc deleted file mode 100644 index daab028..0000000 Binary files a/utils/__pycache__/dynamic_rnn.cpython-36.pyc and /dev/null differ diff --git a/utils/__pycache__/eval_utils.cpython-36.pyc b/utils/__pycache__/eval_utils.cpython-36.pyc deleted file mode 100644 index a0be25b..0000000 Binary files a/utils/__pycache__/eval_utils.cpython-36.pyc and /dev/null differ diff --git a/utils/dynamic_rnn.py b/utils/dynamic_rnn.py index 696c83a..d925e10 100644 --- a/utils/dynamic_rnn.py +++ b/utils/dynamic_rnn.py @@ -39,6 +39,7 @@ def forward(self, seq_input, seq_lens, initial_state=None): assert hx[0].size(0) == self.rnn_model.num_layers else: hx = None + self.rnn_model.flatten_parameters() _, (h_n, c_n) = self.rnn_model(packed_seq_input, hx) rnn_output = h_n[-1].index_select(dim=0, index=bwd_order) @@ -46,7 +47,8 @@ def forward(self, seq_input, seq_lens, initial_state=None): @staticmethod def _get_sorted_order(lens): - sorted_len, fwd_order = torch.sort(lens.contiguous().view(-1), 0, descending=True) + sorted_len, fwd_order = torch.sort( + lens.contiguous().view(-1), 0, descending=True) _, bwd_order = 
torch.sort(fwd_order) if isinstance(sorted_len, Variable): sorted_len = sorted_len.data diff --git a/utils/dynamic_rnn.pyc b/utils/dynamic_rnn.pyc deleted file mode 100644 index 633e332..0000000 Binary files a/utils/dynamic_rnn.pyc and /dev/null differ diff --git a/utils/eval_utils.py b/utils/eval_utils.py index 568773d..0d69338 100644 --- a/utils/eval_utils.py +++ b/utils/eval_utils.py @@ -1,5 +1,5 @@ import torch - +import os def get_gt_ranks(ranks, ans_ind): ans_ind = ans_ind.view(-1) @@ -8,8 +8,7 @@ def get_gt_ranks(ranks, ans_ind): gt_ranks[i] = int(ranks[i, ans_ind[i]]) return gt_ranks - -def process_ranks(ranks): +def process_ranks(ranks, save_path, epoch): num_ques = ranks.size(0) num_opts = 100 @@ -24,11 +23,29 @@ def process_ranks(ranks): num_ge = torch.sum(ranks.ge(num_opts + 1)) print("Warning: some of ranks > 100: {}".format(num_ge)) ranks = ranks[ranks.le(num_opts + 1)] - + ranks = ranks.float() num_r1 = float(torch.sum(torch.le(ranks, 1))) num_r5 = float(torch.sum(torch.le(ranks, 5))) num_r10 = float(torch.sum(torch.le(ranks, 10))) + + with open(os.path.join(save_path, "ranks_results.txt"), "a+") as f: + f.write("Epoch: {}".format(epoch)) + f.write("\tNo. questions: {}\n".format(num_ques)) + f.write("\tr@1: {}\n".format(num_r1 / num_ques)) + f.write("\tr@5: {}\n".format(num_r5 / num_ques)) + f.write("\tr@10: {}\n".format(num_r10 / num_ques)) + f.write("\tmeanR: {}\n".format(torch.mean(ranks))) + f.write("\tmeanRR: {}\n".format(torch.mean(ranks.reciprocal()))) + f.write('\n') + f.close() + + metrics = { "r@1": num_r1 / num_ques, + "r@5": num_r5 / num_ques, + "r@10": num_r10 / num_ques, + "mean": torch.mean(ranks), + "mrr": torch.mean(ranks.reciprocal()) } + print("\tNo. questions: {}".format(num_ques)) print("\tr@1: {}".format(num_r1 / num_ques)) print("\tr@5: {}".format(num_r5 / num_ques)) @@ -36,6 +53,7 @@ def process_ranks(ranks): print("\tmeanR: {}".format(torch.mean(ranks))) print("\tmeanRR: {}".format(torch.mean(ranks.reciprocal()))) + return metrics def scores_to_ranks(scores): # sort in descending order - largest score gets highest rank diff --git a/utils/eval_utils.pyc b/utils/eval_utils.pyc deleted file mode 100644 index 8f65dae..0000000 Binary files a/utils/eval_utils.pyc and /dev/null differ diff --git a/utils/visualize.py b/utils/visualize.py new file mode 100644 index 0000000..97ca0fb --- /dev/null +++ b/utils/visualize.py @@ -0,0 +1,134 @@ +import os.path as pth +import json +from visdom import Visdom +import numpy as np + +class VisdomLinePlot(): + + def __init__(self, env_name='main', server="0.0.0.0", port=8899): + self.viz = Visdom( + port=port, + env=env_name, + server=server + ) + self.plot_list = {} + self.env = env_name + self.is_enabled = True + + def plotLine(self, scalar_name, split, title_name, x ,y): + + if scalar_name not in self.plot_list: + + self.plot_list[scalar_name] = self.viz.line( X=np.array([x,x]), Y=np.array([y,y]), env=self.env, + opts=dict(legend=[split], + title=title_name, + xlabel='Epochs', + ylabel= scalar_name)) + else: + + self.viz.line(X=np.array([x]), Y=np.array([y]), + env=self.env, + win=self.plot_list[scalar_name], + name=split, update='append') + + def writeText(self, dict): + output = '' + for arg in vars(dict): + output=output+('{:<20}: {}{}'.format(arg, getattr(dict, arg),"\n")) + self.viz.text(output) + + +class VisdomVisualize(): + def __init__(self, + env_name='main', + server="http://127.0.0.1", + port=8855, + enable=True): + ''' + Initialize a visdom server on server:port + ''' + print("Initializing visdom env [%s]" 
% env_name) + self.is_enabled = enable + self.env_name = env_name + if self.is_enabled: + self.viz = Visdom( + port=port, + env=env_name, + server=server, + ) + else: + self.viz = None + self.wins = {} + + def linePlot(self, x, y, key, line_name, xlabel="Epochs"): + ''' + Add or update a line plot on the visdom server self.viz + Argumens: + x : Scalar -> X-coordinate on plot + y : Scalar -> Value at x + key : Name of plot/graph + line_name : Name of line within plot/graph + xlabel : Label for x-axis (default: # Iterations) + Plots and lines are created if they don't exist, otherwise + they are updated. + ''' + key = str(key) + if self.is_enabled: + if key in self.wins.keys(): + self.viz.line( + X = np.array([x]), + Y = np.array([y]), + win = self.wins[key], + update = 'append', + name = line_name, + opts = dict(showlegend=True), + ) + else: + self.wins[key] = self.viz.line( + X = np.array([x]), + Y = np.array([y]), + win = key, + name = line_name, + opts = { + 'xlabel': xlabel, + 'ylabel': key, + 'title': key, + 'showlegend': True, + # 'legend': [line_name], + } + ) + + def showText(self, text, key): + ''' + Created a named text window or updates an existing one with + the name == key + ''' + key = str(key) + if self.is_enabled: + win = self.wins[key] if key in self.wins else None + self.wins[key] = self.viz.text(text, win=win) + + def addText(self, text): + ''' + Adds an unnamed text window without keeping track of win id + ''' + if self.is_enabled: + self.viz.text(text) + + def save(self): + if self.is_enabled: + self.viz.save([self.env_name]) + + def histPlot(self, x, key): + key = str(key) + if self.is_enabled: + if key in self.wins.keys(): + self.viz.histogram( + X = x.cpu().numpy(), + win = self.wins[key], + ) + else: + self.wins[key] = self.viz.histogram( + X = x.cpu().numpy(), + win = key + )
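The patch above introduces three reusable pieces: the AVSD wrapper in models.py, process_ranks() in utils/eval_utils.py (which now returns the retrieval metrics as a dict in addition to printing and logging them), and the visdom helpers in utils/visualize.py. The sketch below is illustrative only and not part of the patch; it shows how these pieces are meant to fit together for a validation pass. The helper name evaluate_epoch, the dataloader_val argument, and the checkpoint-loading omission are assumptions layered on top of the train.py/evaluate.py setup shown above.

# Illustrative sketch only -- not part of the diff. Assumes model_args has been
# populated from VisDialDataset (as train.py does), that batches come from a
# DataLoader built with dataset.collate_fn, and that list-valued entries have
# been converted to tensors as convert_list_to_tensor() does above.
import torch

from models import AVSD
from utils import get_gt_ranks, process_ranks, scores_to_ranks


def evaluate_epoch(model_args, dataloader_val, save_path, epoch, viz=None):
    """Run one validation pass and push the summary metrics to visdom."""
    model = AVSD(model_args)      # encoder + decoder + CrossEntropyLoss in one module
    model.eval()
    all_ranks = []
    with torch.no_grad():
        for batch in dataloader_val:
            dec_out, _ = model(batch)          # AVSD.forward() returns (scores, loss)
            ranks = scores_to_ranks(dec_out.data)
            all_ranks.append(get_gt_ranks(ranks, batch['ans_ind'].data))
    all_ranks = torch.cat(all_ranks, 0)
    # process_ranks prints r@1/r@5/r@10/meanR/meanRR, appends them to
    # ranks_results.txt under save_path, and returns them as a dict.
    metrics = process_ranks(all_ranks, save_path, epoch)
    if viz is not None:           # e.g. utils.visualize.VisdomLinePlot(env_name='eval')
        for name, value in metrics.items():
            # VisdomLinePlot.plotLine(scalar_name, split, title_name, x, y)
            viz.plotLine(name, 'val', 'Retrieval Val Metrics', epoch, float(value))
    return metrics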