-
Notifications
You must be signed in to change notification settings - Fork 4
/
transformer_base.py
executable file
·179 lines (155 loc) · 6.68 KB
/
transformer_base.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from typing import Dict, List, Optional, Tuple
import torch
import torch.nn as nn
from torch import Tensor
from fairseq import utils
from fairseq.dataclass.utils import gen_parser_from_dataclass
from fairseq.distributed import fsdp_wrap
from fairseq.models import FairseqEncoderDecoderModel
# from fairseq.models.transformer import (
# TransformerConfig,
# TransformerDecoderBase,
# TransformerEncoderBase,
# )
from .transformer_config import TransformerConfig
from .transformer_encoder import TransformerEncoderBase
from .transformer_decoder import TransformerDecoderBase
class TransformerModelBase(FairseqEncoderDecoderModel):
    """
    Transformer model from `"Attention Is All You Need" (Vaswani, et al, 2017)
    <https://arxiv.org/abs/1706.03762>`_.

    Args:
        cfg: the model configuration (kept on ``self.cfg``)
        encoder (TransformerEncoder): the encoder
        decoder (TransformerDecoder): the decoder

    The Transformer model provides the following named architectures and
    command-line arguments:

    .. argparse::
        :ref: fairseq.models.transformer_parser
        :prog:
    """

    def __init__(self, cfg, encoder, decoder):
        """Wrap ``encoder``/``decoder`` and remember the config.

        Also sets ``supports_align_args`` so callers know this model accepts
        the alignment arguments of :meth:`forward`.
        """
        super().__init__(encoder, decoder)
        self.cfg = cfg
        self.supports_align_args = True

    @classmethod
    def add_args(cls, parser):
        """Add model-specific arguments to the parser.

        Every field of a default :class:`TransformerConfig` is registered,
        walking nested sub-configs, with ``delete_default=False`` and no
        prefix.
        """
        gen_parser_from_dataclass(
            parser, TransformerConfig(), delete_default=False, with_prefix=""
        )

    @classmethod
    def build_model(cls, cfg, task):
        """Build a new model instance from ``cfg`` and the task's dictionaries.

        Note that ``cfg`` is modified in place:

        * ``decoder.input_dim`` / ``decoder.output_dim`` are coerced to ``int``;
        * ``encoder.layers`` / ``decoder.layers`` are overwritten with the
          number of comma-separated entries in ``layers_to_keep`` when set;
        * ``share_decoder_input_output_embed`` is forced to ``True`` when
          ``share_all_embeddings`` is set;
        * ``checkpoint_activations`` is forced to ``True`` when
          ``offload_activations`` is set.

        Raises:
            ValueError: if ``share_all_embeddings`` is requested but the
                source/target dictionaries differ, the encoder/decoder
                embedding dims differ, or a distinct decoder embedding path
                is given.
        """
        enc_cfg = cfg.encoder
        dec_cfg = cfg.decoder

        # -- TODO T96535332
        # bug caused by interaction between OmegaConf II and argparsing
        dec_cfg.input_dim = int(dec_cfg.input_dim)
        dec_cfg.output_dim = int(dec_cfg.output_dim)
        # --

        # An explicit list of kept layers overrides the configured depth
        # (encoder first, then decoder).
        for sub_cfg in (enc_cfg, dec_cfg):
            if sub_cfg.layers_to_keep:
                sub_cfg.layers = len(sub_cfg.layers_to_keep.split(","))

        src_dict = task.source_dictionary
        tgt_dict = task.target_dictionary

        if not cfg.share_all_embeddings:
            # Independent tables for the source and target vocabularies.
            encoder_embed_tokens = cls.build_embedding(
                cfg, src_dict, enc_cfg.embed_dim, enc_cfg.embed_path
            )
            decoder_embed_tokens = cls.build_embedding(
                cfg, tgt_dict, dec_cfg.embed_dim, dec_cfg.embed_path
            )
        else:
            # One table shared by encoder input, decoder input and decoder
            # output; validate that sharing is actually possible first.
            if src_dict != tgt_dict:
                raise ValueError("--share-all-embeddings requires a joined dictionary")
            if enc_cfg.embed_dim != dec_cfg.embed_dim:
                raise ValueError(
                    "--share-all-embeddings requires --encoder-embed-dim to match --decoder-embed-dim"
                )
            decoder_path = dec_cfg.embed_path
            if decoder_path and decoder_path != enc_cfg.embed_path:
                raise ValueError(
                    "--share-all-embeddings not compatible with --decoder-embed-path"
                )
            shared_embed_tokens = cls.build_embedding(
                cfg, src_dict, enc_cfg.embed_dim, enc_cfg.embed_path
            )
            encoder_embed_tokens = shared_embed_tokens
            decoder_embed_tokens = shared_embed_tokens
            cfg.share_decoder_input_output_embed = True

        # Offloading activations is only meaningful with checkpointing on.
        if cfg.offload_activations:
            cfg.checkpoint_activations = True

        encoder = cls.build_encoder(cfg, src_dict, encoder_embed_tokens)
        decoder = cls.build_decoder(cfg, tgt_dict, decoder_embed_tokens)
        return cls(cfg, encoder, decoder)

    @classmethod
    def build_embedding(cls, cfg, dictionary, embed_dim, path=None):
        """Create an embedding table with one row per symbol in ``dictionary``.

        The dictionary's pad index becomes the padding row. If ``path`` is
        truthy, vectors parsed from that file with ``utils.parse_embedding``
        are copied in with ``utils.load_embedding``. ``cfg`` is not read here.
        """
        emb = Embedding(len(dictionary), embed_dim, dictionary.pad())
        if not path:
            return emb
        utils.load_embedding(utils.parse_embedding(path), dictionary, emb)
        return emb

    @classmethod
    def build_encoder(cls, cfg, src_dict, embed_tokens):
        """Construct the encoder over ``src_dict`` using ``embed_tokens``."""
        return TransformerEncoderBase(cfg, src_dict, embed_tokens)

    @classmethod
    def build_decoder(cls, cfg, tgt_dict, embed_tokens):
        """Construct the decoder; cross-attention is off when ``cfg.no_cross_attention``."""
        return TransformerDecoderBase(
            cfg,
            tgt_dict,
            embed_tokens,
            no_encoder_attn=cfg.no_cross_attention,
        )

    # TorchScript cannot compile variadic ``**kwargs``, so the signature below
    # spells out the union of every argument that subclasses need.
    def forward(
        self,
        src_tokens,
        src_lengths,
        prev_output_tokens,
        return_all_hiddens: bool = True,
        features_only: bool = False,
        alignment_layer: Optional[int] = None,
        alignment_heads: Optional[int] = None,
    ):
        """
        Encode ``src_tokens`` and decode ``prev_output_tokens`` against it.

        Same as the base-class forward pass, except that every argument is
        listed explicitly (no ``**kwargs``) so the method stays scriptable.
        Returns whatever the decoder returns.
        """
        encoded = self.encoder(
            src_tokens,
            src_lengths=src_lengths,
            return_all_hiddens=return_all_hiddens,
        )
        return self.decoder(
            prev_output_tokens,
            encoder_out=encoded,
            features_only=features_only,
            alignment_layer=alignment_layer,
            alignment_heads=alignment_heads,
            src_lengths=src_lengths,
            return_all_hiddens=return_all_hiddens,
        )

    # The inherited get_normalized_probs lives on a non-scriptable fairseq
    # model class, so it is re-declared here (and exported to TorchScript)
    # as a thin call into the base class's scriptable helper.
    @torch.jit.export
    def get_normalized_probs(
        self,
        net_output: Tuple[Tensor, Optional[Dict[str, List[Optional[Tensor]]]]],
        log_probs: bool,
        sample: Optional[Dict[str, Tensor]] = None,
    ):
        """Get normalized probabilities (or log probs) from a net's output."""
        return self.get_normalized_probs_scriptable(
            net_output, log_probs, sample
        )
def Embedding(num_embeddings, embedding_dim, padding_idx):
    """Create an ``nn.Embedding`` with the Transformer initialisation.

    Weights are drawn from a normal distribution with mean 0 and standard
    deviation ``embedding_dim ** -0.5``. The row at ``padding_idx`` is then
    set to zero.

    Args:
        num_embeddings: number of rows in the table (vocabulary size).
        embedding_dim: size of each embedding vector.
        padding_idx: index of the padding row, or ``None`` for no padding row
            (as accepted by ``nn.Embedding``).

    Returns:
        The initialised ``nn.Embedding`` module.
    """
    m = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx)
    nn.init.normal_(m.weight, mean=0, std=embedding_dim**-0.5)
    # Indexing with ``None`` yields an unsqueezed view of the *whole* table,
    # so zeroing it unguarded would wipe every row when there is no padding.
    if padding_idx is not None:
        nn.init.constant_(m.weight[padding_idx], 0)
    return m