diff --git a/onmt/models/model.py b/onmt/models/model.py
index 920adcc981..a3ce348413 100644
--- a/onmt/models/model.py
+++ b/onmt/models/model.py
@@ -1,5 +1,6 @@
 """ Onmt NMT Model base class definition """
 import torch.nn as nn
+import torch


 class NMTModel(nn.Module):
@@ -17,7 +18,8 @@ def __init__(self, encoder, decoder):
         self.encoder = encoder
         self.decoder = decoder

-    def forward(self, src, tgt, lengths, bptt=False, with_align=False):
+    def forward(self, src, tgt, lengths, bptt=False,
+                with_align=False, encode_tgt=False):
         """Forward propagate a `src` and `tgt` pair for training.
         Possible initialized with a beginning decoder state.

@@ -46,9 +48,20 @@ def forward(self, src, tgt, lengths, bptt=False, with_align=False):

         if bptt is False:
             self.decoder.init_state(src, memory_bank, enc_state)
+
         dec_out, attns = self.decoder(dec_in, memory_bank,
                                       memory_lengths=lengths,
                                       with_align=with_align)
+
+        if encode_tgt:
+            # tgt for zero shot alignment loss
+            tgt_lengths = torch.Tensor(tgt.size(1))\
+                .type_as(memory_bank) \
+                .long() \
+                .fill_(tgt.size(0))
+            embs_tgt, memory_bank_tgt, ltgt = self.encoder(tgt, tgt_lengths)
+            return dec_out, attns, memory_bank, memory_bank_tgt
+
         return dec_out, attns

     def update_dropout(self, dropout):
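
The new `encode_tgt` branch re-runs the shared encoder on the gold target and returns both memory banks, so the loss can compare sentence-level representations of `src` and `tgt`. A minimal standalone sketch of that path follows; `run_encoders` and the `(state, memory_bank, lengths)` return convention are assumptions for illustration, not code from this patch.

import torch

def run_encoders(encoder, src, src_lengths, tgt):
    # encode the source as usual: memory_bank is [src_len, batch, hidden]
    _, memory_bank, _ = encoder(src, src_lengths)
    # the gold target arrives as a padded [tgt_len, batch, ...] tensor, so
    # every example is treated as having the full padded length tgt.size(0)
    tgt_lengths = torch.full((tgt.size(1),), tgt.size(0),
                             dtype=torch.long, device=tgt.device)
    _, memory_bank_tgt, _ = encoder(tgt, tgt_lengths)
    return memory_bank, memory_bank_tgt
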
diff --git a/onmt/modules/copy_generator.py b/onmt/modules/copy_generator.py
index 900096cf4d..b5959ce92e 100644
--- a/onmt/modules/copy_generator.py
+++ b/onmt/modules/copy_generator.py
@@ -186,14 +186,15 @@ def __init__(self, criterion, generator, tgt_vocab, normalize_by_length,
         self.tgt_vocab = tgt_vocab
         self.normalize_by_length = normalize_by_length

-    def _make_shard_state(self, batch, output, range_, attns):
+    def _make_shard_state(self, batch, output, enc_src, enc_tgt,
+                          range_, attns):
         """See base class for args description."""
         if getattr(batch, "alignment", None) is None:
             raise AssertionError("using -copy_attn you need to pass in "
                                  "-dynamic_dict during preprocess stage.")

         shard_state = super(CopyGeneratorLossCompute, self)._make_shard_state(
-            batch, output, range_, attns)
+            batch, output, enc_src, enc_tgt, range_, attns)

         shard_state.update({
             "copy_attn": attns.get("copy"),
@@ -201,7 +202,8 @@ def _make_shard_state(self, batch, output, range_, attns):
         })
         return shard_state

-    def _compute_loss(self, batch, output, target, copy_attn, align,
+    def _compute_loss(self, batch, normalization, output, target,
+                      copy_attn, align, enc_src=None, enc_tgt=None,
                       std_attn=None, coverage_attn=None):
         """Compute the loss.
@@ -244,8 +246,18 @@ def _compute_loss(self, batch, output, target, copy_attn, align,
             offset_align = align[correct_mask] + len(self.tgt_vocab)
             target_data[correct_mask] += offset_align

+        if self.lambda_cosine != 0.0:
+            cosine_loss, num_ex = self._compute_cosine_loss(enc_src, enc_tgt)
+            loss += self.lambda_cosine * (cosine_loss / num_ex)
+        else:
+            cosine_loss = None
+            num_ex = 0
+
         # Compute sum of perplexities for stats
-        stats = self._stats(loss.sum().clone(), scores_data, target_data)
+        stats = self._stats(loss.sum().clone(),
+                            cosine_loss.clone() if cosine_loss is not None
+                            else cosine_loss,
+                            scores_data, target_data, num_ex)

         # this part looks like it belongs in CopyGeneratorLoss
         if self.normalize_by_length:
diff --git a/onmt/opts.py b/onmt/opts.py
index e61724b26c..191fc5c151 100644
--- a/onmt/opts.py
+++ b/onmt/opts.py
@@ -193,6 +193,9 @@ def model_opts(parser):
               help='Train a coverage attention layer.')
     group.add('--lambda_coverage', '-lambda_coverage', type=float, default=0.0,
               help='Lambda value for coverage loss of See et al (2017)')
+    group.add('--lambda_cosine', '-lambda_cosine', type=float, default=0.0,
+              help='Lambda value for cosine alignment loss '
+                   'of https://arxiv.org/abs/1903.07091 ')
     group.add('--loss_scale', '-loss_scale', type=float, default=0,
               help="For FP16 training, the static loss scale to use. If not "
                    "set, the loss scale is dynamically computed.")
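
`-lambda_cosine` weights the new term relative to the usual objective: the cross-entropy part keeps its token/sentence normalization, while the cosine part is averaged over the examples in the batch (and, per the check added in onmt/utils/parse.py further below, the option requires `-max_generator_batches 0`). A rough sketch of the mixing; `mix_losses` is an illustrative helper, not part of this patch.

def mix_losses(nll_sum, normalization, cosine_sum, num_ex, lambda_cosine):
    # nll_sum: summed cross-entropy over the shard
    # cosine_sum / num_ex: average (1 - cosine similarity) over the batch
    loss = nll_sum / float(normalization)
    if lambda_cosine != 0.0:
        loss = loss + lambda_cosine * (cosine_sum / num_ex)
    return loss
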
diff --git a/onmt/trainer.py b/onmt/trainer.py
index 4328ca52ea..d60c1a0ddb 100644
--- a/onmt/trainer.py
+++ b/onmt/trainer.py
@@ -70,7 +70,8 @@ def build_trainer(opt, device_id, model, fields, optim, model_saver=None):
                            model_dtype=opt.model_dtype,
                            earlystopper=earlystopper,
                            dropout=dropout,
-                           dropout_steps=dropout_steps)
+                           dropout_steps=dropout_steps,
+                           encode_tgt=True if opt.lambda_cosine > 0 else False)
     return trainer

@@ -107,7 +108,8 @@ def __init__(self, model, train_loss, valid_loss, optim,
                  n_gpu=1, gpu_rank=1, gpu_verbose_level=0,
                  report_manager=None, with_align=False, model_saver=None,
                  average_decay=0, average_every=1, model_dtype='fp32',
-                 earlystopper=None, dropout=[0.3], dropout_steps=[0]):
+                 earlystopper=None, dropout=[0.3], dropout_steps=[0],
+                 encode_tgt=False):
         # Basic attributes.
         self.model = model
         self.train_loss = train_loss
@@ -132,6 +134,7 @@ def __init__(self, model, train_loss, valid_loss, optim,
         self.earlystopper = earlystopper
         self.dropout = dropout
         self.dropout_steps = dropout_steps
+        self.encode_tgt = encode_tgt

         for i in range(len(self.accum_count_l)):
             assert self.accum_count_l[i] > 0
@@ -314,11 +317,21 @@ def validate(self, valid_iter, moving_average=None):
                 tgt = batch.tgt

                 # F-prop through the model.
-                outputs, attns = valid_model(src, tgt, src_lengths,
-                                             with_align=self.with_align)
+                if self.encode_tgt:
+                    outputs, attns, enc_src, enc_tgt = valid_model(
+                        src, tgt, src_lengths,
+                        with_align=self.with_align,
+                        encode_tgt=self.encode_tgt)
+                else:
+                    outputs, attns = valid_model(
+                        src, tgt, src_lengths,
+                        with_align=self.with_align)
+                    enc_src, enc_tgt = None, None

                 # Compute loss.
-                _, batch_stats = self.valid_loss(batch, outputs, attns)
+                _, batch_stats = self.valid_loss(
+                    batch, outputs, attns,
+                    enc_src=enc_src, enc_tgt=enc_tgt)

                 # Update statistics.
                 stats.update(batch_stats)
@@ -361,8 +374,16 @@ def _gradient_accumulation(self, true_batches, normalization, total_stats,
             if self.accum_count == 1:
                 self.optim.zero_grad()

-            outputs, attns = self.model(src, tgt, src_lengths, bptt=bptt,
-                                        with_align=self.with_align)
+            if self.encode_tgt:
+                outputs, attns, enc_src, enc_tgt = self.model(
+                    src, tgt, src_lengths, bptt=bptt,
+                    with_align=self.with_align, encode_tgt=self.encode_tgt)
+            else:
+                outputs, attns = self.model(
+                    src, tgt, src_lengths, bptt=bptt,
+                    with_align=self.with_align)
+                enc_src, enc_tgt = None, None
+
             bptt = True

             # 3. Compute loss.
@@ -371,6 +392,8 @@ def _gradient_accumulation(self, true_batches, normalization, total_stats,
                     batch,
                     outputs,
                     attns,
+                    enc_src=enc_src,
+                    enc_tgt=enc_tgt,
                     normalization=normalization,
                     shard_size=self.shard_size,
                     trunc_start=j,
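
Taken together, the trainer now conditions the whole per-batch flow on whether `encode_tgt` was enabled by `build_trainer`: the model returns the two extra memory banks and they are handed on to the loss object. A condensed sketch of that flow, with `training_step`, `model` and `loss_fn` as placeholder names rather than the trainer's actual attributes:

def training_step(model, loss_fn, batch, src, tgt, src_lengths,
                  normalization, lambda_cosine):
    encode_tgt = lambda_cosine > 0  # mirrors build_trainer above
    if encode_tgt:
        outputs, attns, enc_src, enc_tgt = model(
            src, tgt, src_lengths, encode_tgt=True)
    else:
        outputs, attns = model(src, tgt, src_lengths)
        enc_src, enc_tgt = None, None
    # the loss object receives the extra encoder states only when they exist
    return loss_fn(batch, outputs, attns,
                   enc_src=enc_src, enc_tgt=enc_tgt,
                   normalization=normalization)
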
diff --git a/onmt/utils/loss.py b/onmt/utils/loss.py
index c48f0d3d21..f185e1a567 100644
--- a/onmt/utils/loss.py
+++ b/onmt/utils/loss.py
@@ -58,7 +58,7 @@ def build_loss_compute(model, tgt_field, opt, train=True):
     else:
         compute = NMTLossCompute(
             criterion, loss_gen, lambda_coverage=opt.lambda_coverage,
-            lambda_align=opt.lambda_align)
+            lambda_align=opt.lambda_align, lambda_cosine=opt.lambda_cosine)
     compute.to(device)

     return compute
@@ -92,7 +92,8 @@ def __init__(self, criterion, generator):
     def padding_idx(self):
         return self.criterion.ignore_index

-    def _make_shard_state(self, batch, output, range_, attns=None):
+    def _make_shard_state(self, batch, output, enc_src, enc_tgt,
+                          range_, attns=None):
         """
         Make shard state dictionary for shards() to return iterable
         shards for efficient loss computation. Subclass must define
@@ -123,6 +124,8 @@ def __call__(self,
                  batch,
                  output,
                  attns,
+                 enc_src=None,
+                 enc_tgt=None,
                  normalization=1.0,
                  shard_size=0,
                  trunc_start=0,
@@ -157,18 +160,20 @@ def __call__(self,
         if trunc_size is None:
             trunc_size = batch.tgt.size(0) - trunc_start
         trunc_range = (trunc_start, trunc_start + trunc_size)
-        shard_state = self._make_shard_state(batch, output, trunc_range, attns)
+        shard_state = self._make_shard_state(
+            batch, output, enc_src, enc_tgt, trunc_range, attns)
         if shard_size == 0:
-            loss, stats = self._compute_loss(batch, **shard_state)
-            return loss / float(normalization), stats
+            loss, stats = self._compute_loss(batch, normalization,
+                                             **shard_state)
+            return loss, stats
         batch_stats = onmt.utils.Statistics()
         for shard in shards(shard_state, shard_size):
-            loss, stats = self._compute_loss(batch, **shard)
-            loss.div(float(normalization)).backward()
+            loss, stats = self._compute_loss(batch, normalization, **shard)
+            loss.backward()
             batch_stats.update(stats)
         return None, batch_stats

-    def _stats(self, loss, scores, target):
+    def _stats(self, loss, cosine_loss, scores, target, num_ex):
         """
         Args:
             loss (:obj:`FloatTensor`): the loss computed by the loss criterion.
@@ -182,7 +187,9 @@ def _stats(self, loss, scores, target):
         non_padding = target.ne(self.padding_idx)
         num_correct = pred.eq(target).masked_select(non_padding).sum().item()
         num_non_padding = non_padding.sum().item()
-        return onmt.utils.Statistics(loss.item(), num_non_padding, num_correct)
+        return onmt.utils.Statistics(
+            loss.item(), cosine_loss.item() if cosine_loss is not None else 0,
+            num_non_padding, num_correct, num_ex)

     def _bottle(self, _v):
         return _v.view(-1, _v.size(2))
@@ -227,15 +234,17 @@ class NMTLossCompute(LossComputeBase):
     """

     def __init__(self, criterion, generator, normalization="sents",
-                 lambda_coverage=0.0, lambda_align=0.0):
+                 lambda_coverage=0.0, lambda_align=0.0, lambda_cosine=0.0):
         super(NMTLossCompute, self).__init__(criterion, generator)
         self.lambda_coverage = lambda_coverage
         self.lambda_align = lambda_align
+        self.lambda_cosine = lambda_cosine

-    def _make_shard_state(self, batch, output, range_, attns=None):
+    def _make_shard_state(self, batch, output, enc_src, enc_tgt,
+                          range_, attns=None):
         shard_state = {
             "output": output,
-            "target": batch.tgt[range_[0] + 1: range_[1], :, 0],
+            "target": batch.tgt[range_[0] + 1: range_[1], :, 0]
         }
         if self.lambda_coverage != 0.0:
             coverage = attns.get("coverage", None)
@@ -273,9 +282,15 @@ def _make_shard_state(self, batch, output, range_, attns=None):
                 "align_head": attn_align,
                 "ref_align": ref_align[:, range_[0] + 1: range_[1], :]
             })
+        if self.lambda_cosine != 0.0:
+            shard_state.update({
+                "enc_src": enc_src,
+                "enc_tgt": enc_tgt
+            })
         return shard_state

-    def _compute_loss(self, batch, output, target, std_attn=None,
+    def _compute_loss(self, batch, normalization, output, target,
+                      enc_src=None, enc_tgt=None, std_attn=None,
                       coverage_attn=None, align_head=None, ref_align=None):

         bottled_output = self._bottle(output)
@@ -284,6 +299,7 @@ def _compute_loss(self, batch, output, target, std_attn=None,
         gtruth = target.view(-1)

         loss = self.criterion(scores, gtruth)
+
         if self.lambda_coverage != 0.0:
             coverage_loss = self._compute_coverage_loss(
                 std_attn=std_attn, coverage_attn=coverage_attn)
@@ -296,7 +312,20 @@ def _compute_loss(self, batch, output, target, std_attn=None,
             align_loss = self._compute_alignement_loss(
                 align_head=align_head, ref_align=ref_align)
             loss += align_loss
-        stats = self._stats(loss.clone(), scores, gtruth)
+
+        loss = loss / float(normalization)
+
+        if self.lambda_cosine != 0.0:
+            cosine_loss, num_ex = self._compute_cosine_loss(enc_src, enc_tgt)
+            loss += self.lambda_cosine * (cosine_loss / num_ex)
+        else:
+            cosine_loss = None
+            num_ex = 0
+
+        stats = self._stats(loss.clone() * normalization,
+                            cosine_loss.clone() if cosine_loss is not None
+                            else cosine_loss,
+                            scores, gtruth, num_ex)

         return loss, stats

@@ -305,6 +334,15 @@ def _compute_coverage_loss(self, std_attn, coverage_attn):
         covloss *= self.lambda_coverage
         return covloss

+    def _compute_cosine_loss(self, enc_src, enc_tgt):
+        max_src = enc_src.max(axis=0)[0]
+        max_tgt = enc_tgt.max(axis=0)[0]
+        cosine_loss = torch.nn.functional.cosine_similarity(
+            max_src.float(), max_tgt.float(), dim=1)
+        cosine_loss = 1 - cosine_loss
+        num_ex = cosine_loss.size(0)
+        return cosine_loss.sum(), num_ex
+
     def _compute_alignement_loss(self, align_head, ref_align):
         """Compute loss between 2 partial alignment matrix."""
         # align_head contains value in [0, 1) presenting attn prob,
@@ -368,7 +406,7 @@ def shards(state, shard_size, eval_only=False):
         # over the shards, not over the keys: therefore, the values need
         # to be re-zipped by shard and then each shard can be paired
         # with the keys.
-        for shard_tensors in zip(*values):
+        for i, shard_tensors in enumerate(zip(*values)):
             yield dict(zip(keys, shard_tensors))

         # Assumed backprop'd
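
`_compute_cosine_loss` max-pools the encoder states over time to get one vector per sentence and penalizes `1 - cos(src, tgt)`, summed over the batch; the caller then divides by the number of examples and scales by `lambda_cosine`. The same computation in isolation, assuming `[seq_len, batch, hidden]` encoder outputs; the function name, example shapes and the `.mean()` shorthand are illustrative only.

import torch
import torch.nn.functional as F

def cosine_alignment_loss(enc_src, enc_tgt, lambda_cosine=1.0):
    # enc_src, enc_tgt: [seq_len, batch, hidden] encoder outputs
    src_vec = enc_src.max(dim=0)[0]   # [batch, hidden], max-pooled over time
    tgt_vec = enc_tgt.max(dim=0)[0]
    dist = 1 - F.cosine_similarity(src_vec.float(), tgt_vec.float(), dim=1)
    return lambda_cosine * dist.mean()  # mean == sum() / num_ex

# e.g. cosine_alignment_loss(torch.randn(12, 4, 256), torch.randn(9, 4, 256))
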
diff --git a/onmt/utils/parse.py b/onmt/utils/parse.py
index 273dae3dba..ac6ddf6820 100644
--- a/onmt/utils/parse.py
+++ b/onmt/utils/parse.py
@@ -120,6 +120,10 @@ def validate_train_opts(cls, opt):
         assert len(opt.attention_dropout) == len(opt.dropout_steps), \
             "Number of attention_dropout values must match accum_steps values"

+        assert not(opt.max_generator_batches > 0 and opt.lambda_cosine != 0), \
+            "-lambda_cosine loss is not implemented " \
+            "for max_generator_batches > 0."
+
     @classmethod
     def validate_translate_opts(cls, opt):
         if opt.beam_size != 1 and opt.random_sampling_topk != 1:
diff --git a/onmt/utils/statistics.py b/onmt/utils/statistics.py
index 896d98c74d..87a1e7f8f1 100644
--- a/onmt/utils/statistics.py
+++ b/onmt/utils/statistics.py
@@ -17,12 +17,15 @@ class Statistics(object):
     * elapsed time
     """

-    def __init__(self, loss=0, n_words=0, n_correct=0):
+    def __init__(self, loss=0, cosine_loss=0, n_words=0,
+                 n_correct=0, num_ex=0):
         self.loss = loss
         self.n_words = n_words
         self.n_correct = n_correct
         self.n_src_words = 0
         self.start_time = time.time()
+        self.cosine_loss = cosine_loss
+        self.num_ex = num_ex

     @staticmethod
     def all_gather_stats(stat, max_size=4096):
@@ -81,6 +84,8 @@ def update(self, stat, update_n_src_words=False):
         self.loss += stat.loss
         self.n_words += stat.n_words
         self.n_correct += stat.n_correct
+        self.cosine_loss += stat.cosine_loss
+        self.num_ex += stat.num_ex

         if update_n_src_words:
             self.n_src_words += stat.n_src_words
@@ -97,6 +102,10 @@ def ppl(self):
         """ compute perplexity """
         return math.exp(min(self.loss / self.n_words, 100))

+    def cos(self):
+        """ normalize cosine distance per example """
+        return self.cosine_loss / self.num_ex
+
     def elapsed_time(self):
         """ compute elapsed time """
         return time.time() - self.start_time
@@ -113,8 +122,12 @@ def output(self, step, num_steps, learning_rate, start):
         step_fmt = "%2d" % step
         if num_steps > 0:
             step_fmt = "%s/%5d" % (step_fmt, num_steps)
+        if self.cosine_loss != 0:
+            cos_log = "cos: %4.2f; " % (self.cos())
+        else:
+            cos_log = ""
         logger.info(
-            ("Step %s; acc: %6.2f; ppl: %5.2f; xent: %4.2f; " +
+            ("Step %s; acc: %6.2f; ppl: %5.2f; xent: %4.2f; " + cos_log +
              "lr: %7.5f; %3.0f/%3.0f tok/s; %6.0f sec")
             % (step_fmt, self.accuracy(),