Add TokenClassificationDataset and update docs.
kdexd committed Mar 7, 2021
1 parent 5b2f12c commit ad6e91b
Showing 7 changed files with 115 additions and 12 deletions.
8 changes: 4 additions & 4 deletions docs/conf.py
@@ -20,11 +20,11 @@
# -- Project information -----------------------------------------------------

project = "virtex"
copyright = "2020, Karan Desai and Justin Johnson"
copyright = "2021, Karan Desai and Justin Johnson"
author = "Karan Desai"

# The full version, including alpha/beta/rc tags
release = "0.9"
release = "1.0"


# -- General configuration ---------------------------------------------------
@@ -62,9 +62,9 @@
# built documents.
#
# This version is used underneath the title on the index page.
version = "0.9"
version = "1.0"
# The following is used if you need to also include a more detailed version.
release = "0.9"
release = "1.0"

# The language for content autogenerated by Sphinx. Refer to documentation
# for a list of supported languages.
2 changes: 1 addition & 1 deletion docs/virtex/data.datasets.rst
@@ -10,7 +10,7 @@ Pretraining Datasets

.. automodule:: virtex.data.datasets.captioning

-.. automodule:: virtex.data.datasets.multilabel
+.. automodule:: virtex.data.datasets.classification

------------------------------------------------------------------------------

4 changes: 4 additions & 0 deletions docs/virtex/models.rst
@@ -16,6 +16,10 @@ Pretraining Models

-------------------------------------------------------------------------------

+.. automodule:: virtex.models.masked_lm
+
+-------------------------------------------------------------------------------
+
Downstream Models
-----------------

8 changes: 6 additions & 2 deletions virtex/data/__init__.py
@@ -1,6 +1,9 @@
from .datasets.captioning import CaptioningDataset
+from .datasets.classification import (
+    TokenClassificationDataset,
+    MultiLabelClassificationDataset,
+)
from .datasets.masked_lm import MaskedLmDataset
-from .datasets.multilabel import MultiLabelClassificationDataset
from .datasets.downstream import (
    ImageNetDataset,
    INaturalist2018Dataset,
@@ -10,8 +13,9 @@

__all__ = [
    "CaptioningDataset",
-    "MaskedLmDataset",
+    "TokenClassificationDataset",
    "MultiLabelClassificationDataset",
+    "MaskedLmDataset",
    "ImageDirectoryDataset",
    "ImageNetDataset",
    "INaturalist2018Dataset",
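With these re-exports, the new dataset is importable from the package root alongside the existing ones; an illustrative check (not part of the commit):

from virtex.data import (
    MaskedLmDataset,
    MultiLabelClassificationDataset,
    TokenClassificationDataset,
)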
94 changes: 92 additions & 2 deletions virtex/data/datasets/{multilabel.py → classification.py}
@@ -2,16 +2,106 @@
import glob
import json
import os
-from typing import Callable, Dict, List, Tuple
+import random
+from typing import Any, Callable, Dict, List, Tuple

import albumentations as alb
import cv2
import numpy as np
import torch
from torch.utils.data import Dataset

from virtex.data.readers import LmdbReader
from virtex.data.tokenizers import SentencePieceBPETokenizer
from virtex.data import transforms as T


+class TokenClassificationDataset(Dataset):
+    r"""
+    A dataset which provides image-labelset pairs from a serialized LMDB file
+    (COCO Captions in this codebase). The set of caption tokens (unordered)
+    is treated as a labelset. Used for the token classification pretraining task.
+
+    Parameters
+    ----------
+    data_root: str, optional (default = "datasets/coco")
+        Path to the dataset root directory. This must contain the serialized
+        LMDB files (for COCO ``train2017`` and ``val2017`` splits).
+    split: str, optional (default = "train")
+        Which split (from COCO 2017 version) to read. One of ``{"train", "val"}``.
+    tokenizer: virtex.data.tokenizers.SentencePieceBPETokenizer
+        A tokenizer which has the mapping between word tokens and their
+        integer IDs.
+    image_transform: Callable, optional (default = virtex.data.transforms.DEFAULT_IMAGE_TRANSFORM)
+        A list of transformations, from either `albumentations
+        <https://albumentations.readthedocs.io/en/latest/>`_ or :mod:`virtex.data.transforms`
+        to be applied on the image.
+    max_caption_length: int, optional (default = 30)
+        Maximum number of tokens to keep in output caption tokens. Extra tokens
+        will be trimmed from the right end of the token list.
+    """
+
+    def __init__(
+        self,
+        data_root: str,
+        split: str,
+        tokenizer: SentencePieceBPETokenizer,
+        image_transform: Callable = T.DEFAULT_IMAGE_TRANSFORM,
+        max_caption_length: int = 30,
+    ):
+        lmdb_path = os.path.join(data_root, f"serialized_{split}.lmdb")
+        self.reader = LmdbReader(lmdb_path)
+
+        self.image_transform = image_transform
+        self.caption_transform = alb.Compose(
+            [
+                T.NormalizeCaption(),
+                T.TokenizeCaption(tokenizer),
+                T.TruncateCaptionTokens(max_caption_length),
+            ]
+        )
+        self.padding_idx = tokenizer.token_to_id("<unk>")
+
+    def __len__(self):
+        return len(self.reader)
+
+    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
+
+        image_id, image, captions = self.reader[idx]
+
+        # Pick a random caption and then transform it.
+        caption = random.choice(captions)
+
+        # Transform image-caption pair and convert image from HWC to CHW format.
+        # Pass in caption to image_transform due to paired horizontal flip.
+        # Caption won't be tokenized/processed here.
+        image_caption = self.image_transform(image=image, caption=caption)
+        image, caption = image_caption["image"], image_caption["caption"]
+        image = np.transpose(image, (2, 0, 1))
+
+        caption_tokens = self.caption_transform(caption=caption)["caption"]
+        return {
+            "image_id": torch.tensor(image_id, dtype=torch.long),
+            "image": torch.tensor(image, dtype=torch.float),
+            "labels": torch.tensor(caption_tokens, dtype=torch.long),
+        }
+
+    def collate_fn(
+        self, data: List[Dict[str, torch.Tensor]]
+    ) -> Dict[str, torch.Tensor]:
+
+        labels = torch.nn.utils.rnn.pad_sequence(
+            [d["labels"] for d in data],
+            batch_first=True,
+            padding_value=self.padding_idx,
+        )
+        return {
+            "image_id": torch.stack([d["image_id"] for d in data], dim=0),
+            "image": torch.stack([d["image"] for d in data], dim=0),
+            "labels": labels,
+        }


class MultiLabelClassificationDataset(Dataset):
r"""
A dataset which provides image-labelset pairs from COCO instance annotation
@@ -56,7 +146,7 @@ def __init__(
        }
        # Mapping from image ID to list of unique category IDs (indices as above)
        # in corresponding image.
-        self._labels = defaultdict(list)
+        self._labels: Dict[str, Any] = defaultdict(list)

        for ann in _annotations["annotations"]:
            self._labels[ann["image_id"]].append(_category_ids[ann["category_id"]])
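For context, a minimal sketch of wiring the new dataset into a PyTorch DataLoader. The vocabulary path and tokenizer constructor argument are assumptions (they depend on your local setup); the collate_fn hookup and the returned keys follow the class above.

from torch.utils.data import DataLoader

from virtex.data import TokenClassificationDataset
from virtex.data.tokenizers import SentencePieceBPETokenizer

# Hypothetical vocabulary model path; substitute your own.
tokenizer = SentencePieceBPETokenizer("datasets/vocab/coco_10k.model")

dataset = TokenClassificationDataset(
    data_root="datasets/coco", split="train", tokenizer=tokenizer
)
# collate_fn pads every "labels" tensor in the batch to the longest
# caption in that batch, using the <unk> token ID as padding_value.
loader = DataLoader(
    dataset, batch_size=32, shuffle=True, collate_fn=dataset.collate_fn
)
batch = next(iter(loader))
# batch["image"]: (32, 3, H, W); batch["labels"]: (32, longest_caption_len)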
2 changes: 2 additions & 0 deletions virtex/data/transforms.py
@@ -74,6 +74,8 @@ class TokenizeCaption(CaptionOnlyTransform):
    tokenizer: virtex.data.tokenizers.SentencePieceBPETokenizer
        A :class:`~virtex.data.tokenizers.SentencePieceBPETokenizer` which encodes
        a caption into tokens.
+    add_boundaries: bool, optional (default = True)
+        Whether to add ``[SOS]`` and ``[EOS]`` boundary tokens from tokenizer.

    Examples
    --------
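A sketch of the new flag in use, assuming add_boundaries is a keyword argument of TokenizeCaption (as the docstring above suggests) and reusing the alb.Compose(caption=...) calling convention from TokenClassificationDataset; `tokenizer` is the same hypothetical one as in the earlier sketch.

import albumentations as alb

from virtex.data import transforms as T

# `tokenizer` as built in the earlier dataset sketch (hypothetical path).
with_bounds = alb.Compose([T.TokenizeCaption(tokenizer)])
no_bounds = alb.Compose([T.TokenizeCaption(tokenizer, add_boundaries=False)])

tokens = with_bounds(caption="a dog chasing a frisbee")["caption"]
# With boundaries, the first/last IDs are the [SOS]/[EOS] tokens;
# the `no_bounds` variant yields only the caption tokens themselves.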
9 changes: 6 additions & 3 deletions virtex/factories.py
@@ -191,7 +191,7 @@ class PretrainingDatasetFactory(Factory):
"bicaptioning": vdata.CaptioningDataset,
"captioning": vdata.CaptioningDataset,
"masked_lm": vdata.MaskedLmDataset,
"token_classification": vdata.CaptioningDataset,
"token_classification": vdata.TokenClassificationDataset,
"multilabel_classification": vdata.MultiLabelClassificationDataset,
}

@@ -235,9 +235,12 @@ def from_config(cls, config: Config, split: str = "train"):
        kwargs.update(
            tokenizer=tokenizer,
            max_caption_length=_C.DATA.MAX_CAPTION_LENGTH,
-            use_single_caption=_C.DATA.USE_SINGLE_CAPTION,
-            percentage=_C.DATA.USE_PERCENTAGE if split == "train" else 100.0,
        )
+        if _C.MODEL.NAME != "token_classification":
+            kwargs.update(
+                use_single_caption=_C.DATA.USE_SINGLE_CAPTION,
+                percentage=_C.DATA.USE_PERCENTAGE if split == "train" else 100.0,
+            )

        if _C.MODEL.NAME == "masked_lm":
            kwargs.update(
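For context, a sketch of driving the updated factory from a config. The YAML filename is hypothetical; from_config, the "token_classification" product key, and the guarded kwargs match the diff above.

from virtex.config import Config
from virtex.factories import PretrainingDatasetFactory

# Hypothetical config file; MODEL.NAME = "token_classification" selects
# vdata.TokenClassificationDataset from PRODUCTS above.
config = Config("configs/token_classification.yaml")

train_dataset = PretrainingDatasetFactory.from_config(config, split="train")
# For token classification, use_single_caption and percentage are not
# passed, per the new MODEL.NAME guard in from_config.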
