diff --git a/docs/conf.py b/docs/conf.py
index 70baa6ee..3c6a9aab 100644
--- a/docs/conf.py
+++ b/docs/conf.py
@@ -20,11 +20,11 @@
# -- Project information -----------------------------------------------------
project = "virtex"
-copyright = "2020, Karan Desai and Justin Johnson"
+copyright = "2021, Karan Desai and Justin Johnson"
author = "Karan Desai"
# The full version, including alpha/beta/rc tags
-release = "0.9"
+release = "1.0"
# -- General configuration ---------------------------------------------------
@@ -62,9 +62,9 @@
# built documents.
#
# This version is used underneath the title on the index page.
-version = "0.9"
+version = "1.0"
# The following is used if you need to also include a more detailed version.
-release = "0.9"
+release = "1.0"
# The language for content autogenerated by Sphinx. Refer to documentation
# for a list of supported languages.
diff --git a/docs/virtex/data.datasets.rst b/docs/virtex/data.datasets.rst
index ebb31fd6..686a974d 100644
--- a/docs/virtex/data.datasets.rst
+++ b/docs/virtex/data.datasets.rst
@@ -10,7 +10,7 @@ Pretraining Datasets
.. automodule:: virtex.data.datasets.captioning
-.. automodule:: virtex.data.datasets.multilabel
+.. automodule:: virtex.data.datasets.classification
------------------------------------------------------------------------------
diff --git a/docs/virtex/models.rst b/docs/virtex/models.rst
index 93a8f68b..fbf75c79 100644
--- a/docs/virtex/models.rst
+++ b/docs/virtex/models.rst
@@ -16,6 +16,10 @@ Pretraining Models
-------------------------------------------------------------------------------
+.. automodule:: virtex.models.masked_lm
+
+-------------------------------------------------------------------------------
+
Downstream Models
-----------------
diff --git a/virtex/data/__init__.py b/virtex/data/__init__.py
index 8de9bac5..a941f623 100644
--- a/virtex/data/__init__.py
+++ b/virtex/data/__init__.py
@@ -1,6 +1,9 @@
from .datasets.captioning import CaptioningDataset
+from .datasets.classification import (
+ TokenClassificationDataset,
+ MultiLabelClassificationDataset,
+)
from .datasets.masked_lm import MaskedLmDataset
-from .datasets.multilabel import MultiLabelClassificationDataset
from .datasets.downstream import (
ImageNetDataset,
INaturalist2018Dataset,
@@ -10,8 +13,9 @@
__all__ = [
"CaptioningDataset",
- "MaskedLmDataset",
+ "TokenClassificationDataset",
"MultiLabelClassificationDataset",
+ "MaskedLmDataset",
"ImageDirectoryDataset",
"ImageNetDataset",
"INaturalist2018Dataset",
diff --git a/virtex/data/datasets/multilabel.py b/virtex/data/datasets/classification.py
similarity index 53%
rename from virtex/data/datasets/multilabel.py
rename to virtex/data/datasets/classification.py
index e0010233..886033e9 100644
--- a/virtex/data/datasets/multilabel.py
+++ b/virtex/data/datasets/classification.py
@@ -2,16 +2,106 @@
import glob
import json
import os
-from typing import Callable, Dict, List, Tuple
+import random
+from typing import Any, Callable, Dict, List, Tuple
+import albumentations as alb
import cv2
import numpy as np
import torch
from torch.utils.data import Dataset
+from virtex.data.readers import LmdbReader
+from virtex.data.tokenizers import SentencePieceBPETokenizer
from virtex.data import transforms as T
+class TokenClassificationDataset(Dataset):
+ r"""
+ A dataset which provides image-labelset pairs from a serialized LMDB file
+ (COCO Captions in this codebase). The set of caption tokens (unordered)
+ is treated as a labelset. Used for token classification pretraining task.
+
+ Parameters
+ ----------
+ data_root: str, optional (default = "datasets/coco")
+ Path to the dataset root directory. This must contain the serialized
+ LMDB files (for COCO ``train2017`` and ``val2017`` splits).
+ split: str, optional (default = "train")
+ Which split (from COCO 2017 version) to read. One of ``{"train", "val"}``.
+ tokenizer: virtex.data.tokenizers.SentencePieceBPETokenizer
+ A tokenizer which has the mapping between word tokens and their
+ integer IDs.
+ image_transform: Callable, optional (default = virtex.data.transforms.DEFAULT_IMAGE_TRANSFORM)
+ A list of transformations, from either `albumentations
+ <https://albumentations.readthedocs.io/en/latest/>`_ or :mod:`virtex.data.transforms`
+ to be applied on the image.
+ max_caption_length: int, optional (default = 30)
+ Maximum number of tokens to keep in output caption tokens. Extra tokens
+ will be trimmed from the right end of the token list.
+ """
+
+ def __init__(
+ self,
+ data_root: str,
+ split: str,
+ tokenizer: SentencePieceBPETokenizer,
+ image_transform: Callable = T.DEFAULT_IMAGE_TRANSFORM,
+ max_caption_length: int = 30,
+ ):
+ lmdb_path = os.path.join(data_root, f"serialized_{split}.lmdb")
+ self.reader = LmdbReader(lmdb_path)
+
+ self.image_transform = image_transform
+ self.caption_transform = alb.Compose(
+ [
+ T.NormalizeCaption(),
+ T.TokenizeCaption(tokenizer),
+ T.TruncateCaptionTokens(max_caption_length),
+ ]
+ )
+ self.padding_idx = tokenizer.token_to_id("")
+
+ def __len__(self):
+ return len(self.reader)
+
+ def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
+
+ image_id, image, captions = self.reader[idx]
+
+ # Pick a random caption and then transform it.
+ caption = random.choice(captions)
+
+ # Transform image-caption pair and convert image from HWC to CHW format.
+ # Pass in caption to image_transform due to paired horizontal flip.
+ # Caption won't be tokenized/processed here.
+ image_caption = self.image_transform(image=image, caption=caption)
+ image, caption = image_caption["image"], image_caption["caption"]
+ image = np.transpose(image, (2, 0, 1))
+
+ caption_tokens = self.caption_transform(caption=caption)["caption"]
+ return {
+ "image_id": torch.tensor(image_id, dtype=torch.long),
+ "image": torch.tensor(image, dtype=torch.float),
+ "labels": torch.tensor(caption_tokens, dtype=torch.long),
+ }
+
+ def collate_fn(
+ self, data: List[Dict[str, torch.Tensor]]
+ ) -> Dict[str, torch.Tensor]:
+
+ labels = torch.nn.utils.rnn.pad_sequence(
+ [d["labels"] for d in data],
+ batch_first=True,
+ padding_value=self.padding_idx,
+ )
+ return {
+ "image_id": torch.stack([d["image_id"] for d in data], dim=0),
+ "image": torch.stack([d["image"] for d in data], dim=0),
+ "labels": labels,
+ }
+
+
class MultiLabelClassificationDataset(Dataset):
r"""
A dataset which provides image-labelset pairs from COCO instance annotation
@@ -56,7 +146,7 @@ def __init__(
}
# Mapping from image ID to list of unique category IDs (indices as above)
# in corresponding image.
- self._labels = defaultdict(list)
+ self._labels: Dict[str, Any] = defaultdict(list)
for ann in _annotations["annotations"]:
self._labels[ann["image_id"]].append(_category_ids[ann["category_id"]])
diff --git a/virtex/data/transforms.py b/virtex/data/transforms.py
index 7e51c798..d4141c62 100644
--- a/virtex/data/transforms.py
+++ b/virtex/data/transforms.py
@@ -74,6 +74,8 @@ class TokenizeCaption(CaptionOnlyTransform):
tokenizer: virtex.data.tokenizers.SentencePieceBPETokenizer
A :class:`~virtex.data.tokenizers.SentencePieceBPETokenizer` which encodes
a caption into tokens.
+ add_boundaries: bool, optional (default = True)
+ Whether to add ``[SOS]`` and ``[EOS]`` boundary tokens from tokenizer.
Examples
--------
diff --git a/virtex/factories.py b/virtex/factories.py
index b1bc3a98..62bb89f3 100644
--- a/virtex/factories.py
+++ b/virtex/factories.py
@@ -191,7 +191,7 @@ class PretrainingDatasetFactory(Factory):
"bicaptioning": vdata.CaptioningDataset,
"captioning": vdata.CaptioningDataset,
"masked_lm": vdata.MaskedLmDataset,
- "token_classification": vdata.CaptioningDataset,
+ "token_classification": vdata.TokenClassificationDataset,
"multilabel_classification": vdata.MultiLabelClassificationDataset,
}
@@ -235,9 +235,12 @@ def from_config(cls, config: Config, split: str = "train"):
kwargs.update(
tokenizer=tokenizer,
max_caption_length=_C.DATA.MAX_CAPTION_LENGTH,
- use_single_caption=_C.DATA.USE_SINGLE_CAPTION,
- percentage=_C.DATA.USE_PERCENTAGE if split == "train" else 100.0,
)
+ if _C.MODEL.NAME != "token_classification":
+ kwargs.update(
+ use_single_caption=_C.DATA.USE_SINGLE_CAPTION,
+ percentage=_C.DATA.USE_PERCENTAGE if split == "train" else 100.0,
+ )
if _C.MODEL.NAME == "masked_lm":
kwargs.update(