Skip to content

Commit 1cb3954

Browse files
author
Karan Desai
committed
Merge branch 'decoding' to 'master'
2 parents 341aef7 + f33ef5a commit 1cb3954

File tree

10 files changed

+345
-181
lines changed

10 files changed

+345
-181
lines changed

CHANGELOG.md

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,22 @@
1-
ArXiv v1 -> v2 CHANGELOG
2-
=========================
1+
CHANGELOG
2+
=========
33

4-
[ArXiv v1](https://arxiv.org/abs/2006.06666v1) was our ECCV 2020 submission (reject). [ArXiv v2](https://arxiv.org/abs/2006.06666v2) is out CVPR 2021 submission (accept). The repository snapshots for these two versions are tagged at [`v0.9`](https://github.com/kdexd/virtex/releases/tag/v0.9) and [`v1.0`](https://github.com/kdexd/virtex/releases/tag/v1.0).
4+
This CHANGELOG file records changes between different arXiv versions of our paper, and the version of this codebase which should be used to reproduce the results in the corresponding arXiv version. View changes between code versions on the [Releases page](https://github.com/kdexd/virtex/releases).
5+
6+
ArXiv v2 -> v3
7+
==============
8+
9+
**Code version:** `v1.2`.
10+
11+
Fix image captioning results with a modified beam search implementation. _Rest of the downstream task results and pre-trained models are unchanged._
12+
13+
14+
ArXiv v1 -> v2
15+
==============
16+
17+
**Code version:** `v1.0` or `v1.1`.
18+
19+
[ArXiv v1](https://arxiv.org/abs/2006.06666v1) was our ECCV 2020 submission (reject). [ArXiv v2](https://arxiv.org/abs/2006.06666v2) is our CVPR 2021 submission (accept). The repository snapshots for these two versions are tagged at [`v0.9`](https://github.com/kdexd/virtex/releases/tag/v0.9) and [`v1.0`](https://github.com/kdexd/virtex/releases/tag/v1.0).
520

621
While the core motivation and approach are the same, we have made some minor changes in our experiments and evaluation setup. These slightly improve model performances across the board (within decimals). New models are available in [`v1.0` model zoo](http://kdexd.github.io/virtex/virtex/usage/model_zoo.html), however links to old models in `v0.9` will be active till June 30, 2021. We encourage you to use the new models!
722

configs/_base_bicaptioning_R_50_L1_H1024.yaml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,14 +35,20 @@ DATA:
3535

3636
MODEL:
3737
NAME: "virtex"
38+
3839
VISUAL:
3940
NAME: "torchvision::resnet50"
4041
PRETRAINED: false
4142
FROZEN: false
43+
4244
TEXTUAL:
4345
NAME: "transdec_postnorm::L1_H1024_A16_F4096"
4446
DROPOUT: 0.1
4547

48+
DECODER:
49+
NAME: "beam_search"
50+
BEAM_SIZE: 5
51+
4652
OPTIM:
4753
OPTIMIZER_NAME: "sgd"
4854
SGD_MOMENTUM: 0.9

docs/conf.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
author = "Karan Desai"
2525

2626
# The full version, including alpha/beta/rc tags
27-
release = "1.1"
27+
release = "1.2"
2828

2929

3030
# -- General configuration ---------------------------------------------------

scripts/eval_captioning.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
evaluate pretrained model on COCO Captions val2017 split."""
2222
)
2323
parser.add_argument(
24-
"--data-root", default=None,
24+
"--images", "--data-root", default=None,
2525
help="""Path to a directory containing image files to generate captions for.
2626
Default: COCO val2017 image directory as expected relative to project root."""
2727
)
@@ -89,6 +89,10 @@ def main(_A: argparse.Namespace):
8989
}
9090
)
9191

92+
logger.info("Displaying first 25 caption predictions:")
93+
for pred in predictions[:25]:
94+
logger.info(f"{pred['image_id']} :: {pred['caption']}")
95+
9296
# Save predictions as a JSON file if specified.
9397
if _A.output is not None:
9498
os.makedirs(os.path.dirname(_A.output), exist_ok=True)

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ def get_model_zoo_configs() -> List[str]:
4141

4242
setup(
4343
name="virtex",
44-
version="1.1.0",
44+
version="1.2.0",
4545
author="Karan Desai and Justin Johnson",
4646
description="VirTex: Learning Visual Representations with Textual Annotations",
4747
package_data={"virtex.model_zoo": get_model_zoo_configs()},

virtex/config.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -158,6 +158,20 @@ def __init__(
158158
# Dropout probability for embedding, hidden features in textual head.
159159
_C.MODEL.TEXTUAL.DROPOUT = 0.1
160160

161+
_C.MODEL.DECODER = CN()
162+
# What algorithm to use for decoding. Supported values: {"beam_search",
163+
# "nucleus_sampling"}.
164+
_C.MODEL.DECODER.NAME = "beam_search"
165+
# Number of beams to decode (1 = greedy decoding). Ignored when decoding
166+
# through nucleus sampling.
167+
_C.MODEL.DECODER.BEAM_SIZE = 5
168+
# Size of nucleus for sampling predictions. Ignored when decoding through
169+
# beam search.
170+
_C.MODEL.DECODER.NUCLEUS_SIZE = 0.9
171+
# Maximum length of decoded caption. Decoding may end earlier when [EOS]
172+
# token is sampled.
173+
_C.MODEL.DECODER.MAX_DECODING_STEPS = _C.DATA.MAX_CAPTION_LENGTH
174+
161175
# ---------------------------------------------------------------------
162176
# Optimization hyper-parameters, default values are for pretraining
163177
# our best model on bicaptioning task (COCO Captions).

virtex/factories.py

Lines changed: 43 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,21 +18,24 @@
1818
signature of underlying class; or config hierarchy. Refer description of
1919
specific factories for more details.
2020
"""
21-
from functools import partial
2221
import re
22+
from functools import partial
2323
from typing import Any, Callable, Dict, Iterable, List
2424

2525
import albumentations as alb
2626
from torch import nn, optim
2727

28-
from virtex.config import Config
2928
import virtex.data as vdata
29+
import virtex.models as vmodels
30+
from virtex.config import Config
3031
from virtex.data import transforms as T
3132
from virtex.data.tokenizers import SentencePieceBPETokenizer
32-
import virtex.models as vmodels
3333
from virtex.modules import visual_backbones, textual_heads
3434
from virtex.optim import Lookahead, lr_scheduler
3535

36+
from virtex.utils.beam_search import AutoRegressiveBeamSearch
37+
from virtex.utils.nucleus_sampling import AutoRegressiveNucleusSampling
38+
3639

3740
class Factory(object):
3841
r"""
@@ -460,9 +463,9 @@ def from_config(cls, config: Config) -> nn.Module:
460463
# for matching kwargs here.
461464
if _C.MODEL.NAME in {"virtex", "captioning", "bicaptioning"}:
462465
kwargs = {
463-
"max_decoding_steps": _C.DATA.MAX_CAPTION_LENGTH,
464466
"sos_index": _C.DATA.SOS_INDEX,
465467
"eos_index": _C.DATA.EOS_INDEX,
468+
"decoder": CaptionDecoderFactory.from_config(_C),
466469
}
467470

468471
elif _C.MODEL.NAME == "token_classification":
@@ -482,6 +485,42 @@ def from_config(cls, config: Config) -> nn.Module:
482485
return cls.create(_C.MODEL.NAME, visual, textual, **kwargs)
483486

484487

488+
class CaptionDecoderFactory(Factory):
489+
r"""
490+
Factory to create decoders for predicting captions from a VirTex model.
491+
492+
Possible choices: ``{"beam_search", "nucleus_sampling"}``.
493+
"""
494+
495+
PRODUCTS: Dict[str, Callable] = {
496+
"beam_search": AutoRegressiveBeamSearch,
497+
"nucleus_sampling": AutoRegressiveNucleusSampling,
498+
}
499+
500+
@classmethod
501+
def from_config(cls, config: Config) -> nn.Module:
502+
r"""
503+
Create a model directly from config.
504+
505+
Parameters
506+
----------
507+
config: virtex.config.Config
508+
Config object with all the parameters.
509+
"""
510+
511+
_C = config
512+
kwargs = {
513+
"eos_index": _C.DATA.EOS_INDEX,
514+
"max_steps": _C.MODEL.DECODER.MAX_DECODING_STEPS,
515+
}
516+
if _C.MODEL.DECODER.NAME == "beam_search":
517+
kwargs["beam_size"] = _C.MODEL.DECODER.BEAM_SIZE
518+
elif _C.MODEL.DECODER.NAME == "nucleus_sampling":
519+
kwargs["nucleus_size"] = _C.MODEL.DECODER.NUCLEUS_SIZE
520+
521+
return cls.create(_C.MODEL.DECODER.NAME, **kwargs)
522+
523+
485524
class OptimizerFactory(Factory):
486525
r"""Factory to create optimizers. Possible choices: ``{"sgd", "adamw"}``."""
487526

0 commit comments

Comments
 (0)