diff --git a/docs/conf.py b/docs/conf.py index 70baa6ee..3c6a9aab 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -20,11 +20,11 @@ # -- Project information ----------------------------------------------------- project = "virtex" -copyright = "2020, Karan Desai and Justin Johnson" +copyright = "2021, Karan Desai and Justin Johnson" author = "Karan Desai" # The full version, including alpha/beta/rc tags -release = "0.9" +release = "1.0" # -- General configuration --------------------------------------------------- @@ -62,9 +62,9 @@ # built documents. # # This version is used underneath the title on the index page. -version = "0.9" +version = "1.0" # The following is used if you need to also include a more detailed version. -release = "0.9" +release = "1.0" # The language for content autogenerated by Sphinx. Refer to documentation # for a list of supported languages. diff --git a/docs/virtex/data.datasets.rst b/docs/virtex/data.datasets.rst index ebb31fd6..686a974d 100644 --- a/docs/virtex/data.datasets.rst +++ b/docs/virtex/data.datasets.rst @@ -10,7 +10,7 @@ Pretraining Datasets .. automodule:: virtex.data.datasets.captioning -.. automodule:: virtex.data.datasets.multilabel +.. automodule:: virtex.data.datasets.classification ------------------------------------------------------------------------------ diff --git a/docs/virtex/models.rst b/docs/virtex/models.rst index 93a8f68b..fbf75c79 100644 --- a/docs/virtex/models.rst +++ b/docs/virtex/models.rst @@ -16,6 +16,10 @@ Pretraining Models ------------------------------------------------------------------------------- +.. automodule:: virtex.models.masked_lm + +------------------------------------------------------------------------------- + Downstream Models ----------------- diff --git a/virtex/data/__init__.py b/virtex/data/__init__.py index 8de9bac5..a941f623 100644 --- a/virtex/data/__init__.py +++ b/virtex/data/__init__.py @@ -1,6 +1,9 @@ from .datasets.captioning import CaptioningDataset +from .datasets.classification import ( + TokenClassificationDataset, + MultiLabelClassificationDataset, +) from .datasets.masked_lm import MaskedLmDataset -from .datasets.multilabel import MultiLabelClassificationDataset from .datasets.downstream import ( ImageNetDataset, INaturalist2018Dataset, @@ -10,8 +13,9 @@ __all__ = [ "CaptioningDataset", - "MaskedLmDataset", + "TokenClassificationDataset", "MultiLabelClassificationDataset", + "MaskedLmDataset", "ImageDirectoryDataset", "ImageNetDataset", "INaturalist2018Dataset", diff --git a/virtex/data/datasets/multilabel.py b/virtex/data/datasets/classification.py similarity index 53% rename from virtex/data/datasets/multilabel.py rename to virtex/data/datasets/classification.py index e0010233..886033e9 100644 --- a/virtex/data/datasets/multilabel.py +++ b/virtex/data/datasets/classification.py @@ -2,16 +2,106 @@ import glob import json import os -from typing import Callable, Dict, List, Tuple +import random +from typing import Any, Callable, Dict, List, Tuple +import albumentations as alb import cv2 import numpy as np import torch from torch.utils.data import Dataset +from virtex.data.readers import LmdbReader +from virtex.data.tokenizers import SentencePieceBPETokenizer from virtex.data import transforms as T +class TokenClassificationDataset(Dataset): + r""" + A dataset which provides image-labelset pairs from a serialized LMDB file + (COCO Captions in this codebase). the set of caption tokens (unordered) + is treated as a labelset. Used for token classification pretraining task. + + Parameters + ---------- + data_root: str, optional (default = "datasets/coco") + Path to the dataset root directory. This must contain the serialized + LMDB files (for COCO ``train2017`` and ``val2017`` splits). + split: str, optional (default = "train") + Which split (from COCO 2017 version) to read. One of ``{"train", "val"}``. + tokenizer: virtex.data.tokenizers.SentencePieceBPETokenizer + A tokenizer which has the mapping between word tokens and their + integer IDs. + image_tranform: Callable, optional (default = virtex.data.transforms.DEFAULT_IMAGE_TRANSFORM) + A list of transformations, from either `albumentations + `_ or :mod:`virtex.data.transforms` + to be applied on the image. + max_caption_length: int, optional (default = 30) + Maximum number of tokens to keep in output caption tokens. Extra tokens + will be trimmed from the right end of the token list. + """ + + def __init__( + self, + data_root: str, + split: str, + tokenizer: SentencePieceBPETokenizer, + image_transform: Callable = T.DEFAULT_IMAGE_TRANSFORM, + max_caption_length: int = 30, + ): + lmdb_path = os.path.join(data_root, f"serialized_{split}.lmdb") + self.reader = LmdbReader(lmdb_path) + + self.image_transform = image_transform + self.caption_transform = alb.Compose( + [ + T.NormalizeCaption(), + T.TokenizeCaption(tokenizer), + T.TruncateCaptionTokens(max_caption_length), + ] + ) + self.padding_idx = tokenizer.token_to_id("") + + def __len__(self): + return len(self.reader) + + def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]: + + image_id, image, captions = self.reader[idx] + + # Pick a random caption and then transform it. + caption = random.choice(captions) + + # Transform image-caption pair and convert image from HWC to CHW format. + # Pass in caption to image_transform due to paired horizontal flip. + # Caption won't be tokenized/processed here. + image_caption = self.image_transform(image=image, caption=caption) + image, caption = image_caption["image"], image_caption["caption"] + image = np.transpose(image, (2, 0, 1)) + + caption_tokens = self.caption_transform(caption=caption)["caption"] + return { + "image_id": torch.tensor(image_id, dtype=torch.long), + "image": torch.tensor(image, dtype=torch.float), + "labels": torch.tensor(caption_tokens, dtype=torch.long), + } + + def collate_fn( + self, data: List[Dict[str, torch.Tensor]] + ) -> Dict[str, torch.Tensor]: + + labels = torch.nn.utils.rnn.pad_sequence( + [d["labels"] for d in data], + batch_first=True, + padding_value=self.padding_idx, + ) + return { + "image_id": torch.stack([d["image_id"] for d in data], dim=0), + "image": torch.stack([d["image"] for d in data], dim=0), + "labels": labels, + } + + class MultiLabelClassificationDataset(Dataset): r""" A dataset which provides image-labelset pairs from COCO instance annotation @@ -56,7 +146,7 @@ def __init__( } # Mapping from image ID to list of unique category IDs (indices as above) # in corresponding image. - self._labels = defaultdict(list) + self._labels: Dict[str, Any] = defaultdict(list) for ann in _annotations["annotations"]: self._labels[ann["image_id"]].append(_category_ids[ann["category_id"]]) diff --git a/virtex/data/transforms.py b/virtex/data/transforms.py index 7e51c798..d4141c62 100644 --- a/virtex/data/transforms.py +++ b/virtex/data/transforms.py @@ -74,6 +74,8 @@ class TokenizeCaption(CaptionOnlyTransform): tokenizer: virtex.data.tokenizers.SentencePieceBPETokenizer A :class:`~virtex.data.tokenizers.SentencePieceBPETokenizer` which encodes a caption into tokens. + add_boundaries: bool, optional (defalult = True) + Whether to add ``[SOS]`` and ``[EOS]`` boundary tokens from tokenizer. Examples -------- diff --git a/virtex/factories.py b/virtex/factories.py index b1bc3a98..62bb89f3 100644 --- a/virtex/factories.py +++ b/virtex/factories.py @@ -191,7 +191,7 @@ class PretrainingDatasetFactory(Factory): "bicaptioning": vdata.CaptioningDataset, "captioning": vdata.CaptioningDataset, "masked_lm": vdata.MaskedLmDataset, - "token_classification": vdata.CaptioningDataset, + "token_classification": vdata.TokenClassificationDataset, "multilabel_classification": vdata.MultiLabelClassificationDataset, } @@ -235,9 +235,12 @@ def from_config(cls, config: Config, split: str = "train"): kwargs.update( tokenizer=tokenizer, max_caption_length=_C.DATA.MAX_CAPTION_LENGTH, - use_single_caption=_C.DATA.USE_SINGLE_CAPTION, - percentage=_C.DATA.USE_PERCENTAGE if split == "train" else 100.0, ) + if _C.MODEL.NAME != "token_classification": + kwargs.update( + use_single_caption=_C.DATA.USE_SINGLE_CAPTION, + percentage=_C.DATA.USE_PERCENTAGE if split == "train" else 100.0, + ) if _C.MODEL.NAME == "masked_lm": kwargs.update(