Commit

thesis code
rrichajalota committed Aug 2, 2023
1 parent 0bb5b98 commit 2f2705a
Showing 19 changed files with 661 additions and 66 deletions.
2 changes: 1 addition & 1 deletion evaluation/extract_ref_hyp.py
@@ -7,7 +7,7 @@
if __name__ == "__main__":
parser = argparse.ArgumentParser(description='extract reference and hypothesis from model generation')
parser.add_argument("--file", default="/netscratch/anonymous/results/generations/unsup/motra-old/699517/generate-test.txt")
parser.add_argument("--out_dir", default="/netscratch/anonymous/datasets/motra-preprocessed/en_de/test/src_hyp/")
parser.add_argument("--out_dir", default="/netscratch/jalota/datasets/motra-preprocessed/en_de/test/src_hyp/")
parser.add_argument("--name", default="699517.tsv")
args = parser.parse_args()
contains_dup = False
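The script above post-processes fairseq-generate logs into reference/hypothesis TSVs. For context, a minimal sketch of the parsing such a script performs, assuming the standard fairseq-generate output where lines are prefixed S-&lt;id&gt; (source), T-&lt;id&gt; (target) and H-&lt;id&gt; (hypothesis with a score column); the helper below is illustrative, not the committed code:

def parse_fairseq_generate(path):
    # Collect source/target/hypothesis triples keyed by sentence id.
    src, ref, hyp = {}, {}, {}
    with open(path, encoding="utf-8") as f:
        for line in f:
            if line.startswith("S-"):
                tag, text = line.rstrip("\n").split("\t", 1)
                src[tag[2:]] = text
            elif line.startswith("T-"):
                tag, text = line.rstrip("\n").split("\t", 1)
                ref[tag[2:]] = text
            elif line.startswith("H-"):
                tag, _score, text = line.rstrip("\n").split("\t", 2)
                hyp[tag[2:]] = text
    return src, ref, hyp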
4 changes: 2 additions & 2 deletions evaluation/gen_fsq_ppl_data.py
@@ -10,8 +10,8 @@
sed '1,56d' gen.txt > new_gen.txt
"""
parser = argparse.ArgumentParser(description='generate test data for binary classification from fairseq-generate output')
parser.add_argument("--file", default="/home/anonymous/gen_w_threshold_translated_test.txt")
parser.add_argument("--out_dir", default="/netscratch/anonymous/test_perplexity/")
parser.add_argument("--file", default="/home/jalota/gen_w_threshold_translated_test.txt")
parser.add_argument("--out_dir", default="/netscratch/jalota/test_perplexity/")
parser.add_argument("--name", default="test")
parser.add_argument("--exp", default="712684")
args = parser.parse_args()
1 change: 1 addition & 0 deletions evaluation/gen_test_data.py
@@ -38,6 +38,7 @@
# if tr.strip() == "!" or tr.strip() == "co-rapporteur ." or tr.strip() == "Thank you very much for your attention .":
# print(tr)
# continue
# if len(line.split()) < 510:
of.write(f"{line}\t1")
of.write("\n")
count += 1
4 changes: 2 additions & 2 deletions evaluation/qualitative_analysis.py
@@ -4,8 +4,8 @@
import pandas as pd

allowed_postags=['NOUN', 'ADJ', 'VERB', 'ADV'] # or any other types
# nlp = spacy.load('en_core_web_trf')
nlp = spacy.load('de_dep_news_trf')
nlp = spacy.load('en_core_web_trf')
# nlp = spacy.load('de_dep_news_trf')

def token_filter(token):
return (token.pos_ in allowed_postags) & (not (token.is_punct | token.is_space |
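The edit above swaps the pipeline back from German (de_dep_news_trf) to English (en_core_web_trf). A self-contained sketch of how such a POS-based content-word filter is typically applied (assumes the spaCy model is installed; the example sentence and output are illustrative):

import spacy

nlp = spacy.load("en_core_web_trf")
allowed_postags = ['NOUN', 'ADJ', 'VERB', 'ADV']

doc = nlp("The committee has adopted the proposed amendments quickly.")
content = [t.lemma_ for t in doc
           if t.pos_ in allowed_postags and not (t.is_punct or t.is_space)]
print(content)  # e.g. ['committee', 'adopt', 'propose', 'amendment', 'quickly']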
@@ -14,9 +14,11 @@
from fairseq.criterions import register_criterion
from torch.nn.functional import gumbel_softmax
from torch.distributions import Categorical
from torch.utils.data import Dataset
from fairseq.data import LMContextWindowDataset, MonolingualDataset
import evaluate
import random
from fairseq.lm_perplexity import LanguageModel
from fairseq.lm_perplexity import LanguageModel, LanguageModelValidation

from fairseq.criterions.label_smoothed_cross_entropy import (
LabelSmoothedCrossEntropyCriterion,
@@ -36,6 +38,16 @@
logger = logging.getLogger("fairseq.criterion.UnsupervisedAugmentedCrossEntropyLoss")


class torchDataset(Dataset):
def __init__(self, data_list):
self.data_list = data_list

def __getitem__(self, index):
return self.data_list[index]

def __len__(self):
return len(self.data_list)

def cross_entropy(pred, soft_targets):
# logger.info(f"pred.size(): {pred.size()}")
# logger.info(f"soft_targets.size(): {soft_targets.size()}")
@@ -83,6 +95,27 @@ class UnsupervisedAugmentedLabelSmoothedCrossEntropyCriterionConfig(
default=0.5,
metadata={"help": "supervised loss weightage"},
)
pretrained_lm: str = field(
default="/netscratch/jalota/checkpoints/transformer_en_hansard/",
metadata={
"help": "pretrained fairseq LM model to evaluate PPL during unsupervised training."
},
)
pretrained_lm_dict_path: str = field(
default="/netscratch/jalota/datasets/data-bin/canadianHansard/lm/",
metadata={
"help": "dict path for pretrained fairseq LM model to evaluate PPL during unsupervised training."
},
)
lm_context_window: int = field(
default=5, metadata={"help": "context window size for evaluating PPL"}
)
bertscore_model: str = field(
default="roberta-base",
metadata={
"help": "which model to use for evaluating semantic similarity. for EN: roberta-base, DE: t5-base"
},
)


@register_criterion(
@@ -99,40 +132,53 @@ def __init__(
label_smoothing,
ignore_prefix_size,
report_accuracy,
lm_weight=0.5,
lm_weight=1,
cosine_weight=1,
unsupervised_weight=0.5,
unsupervised_weight=1,
supervised_weight=1,
bert_model='t5-base',
tau_gumbel_softmax=1.0, #TODO: change to 0.1 to enforce sparsity
bertscore_model='roberta-base',
lm_context_window=5,
pretrained_lm_dict_path="/netscratch/jalota/datasets/data-bin/canadianHansard/lm/",
pretrained_lm="/netscratch/jalota/checkpoints/transformer_en_hansard/",
tau_gumbel_softmax=0.1,
hard_gumbel_softmax=False,
eps_gumbel_softmax=1e-10,
soft_bert_score=False
):
# 'microsoft/deberta-v3-base'
# 'microsoft/deberta-v3-base' t5-base
# roberta-base for EN
super().__init__(
task, sentence_avg, label_smoothing, ignore_prefix_size, report_accuracy
)
self.lm_weight = torch.tensor(1)
self.cosine_weight = torch.tensor(1)
self.unsupervised_weight = torch.tensor(1.0)
self.supervised_weight = torch.tensor(1.0)
self.unsupervised_weight = torch.tensor(0.3)
self.supervised_weight = torch.tensor(0.7)
self.perplexity = Perplexity()
self.cosine_sim = CosineSimilarity()
self.mse_loss = MSELoss(reduction='mean')
self.bertscore_model = bert_model
self.bertscore_model = bertscore_model

self.tau_gumbel_softmax = tau_gumbel_softmax
self.hard_gumbel_softmax = hard_gumbel_softmax
self.eps_gumbel_softmax = eps_gumbel_softmax

self.pretrained_lm = pretrained_lm
self.pretrained_lm_dict_path = pretrained_lm_dict_path
self.lm_context_window = lm_context_window

# self.bert_scorer = BERTScorer(self.bert_model, soft_bert_score=soft_bert_score) # , device='cpu')
# self.pad_token_id = self.bert_scorer._tokenizer.convert_tokens_to_ids('[PAD]')
# hansard: /netscratch/jalota/checkpoints/transformer_en_hansard/
# hansard_data: /netscratch/jalota/datasets/data-bin/canadianHansard/lm/
# de: /netscratch/jalota/checkpoints/transformer_lm_de_finetuned/
# de_data: /netscratch/jalota/datasets/motra-sst/de/unsup_setup_raw/lm_finetuning/
self.bertscore = evaluate.load("bertscore")
self.lm = LanguageModel(path='/netscratch/jalota/checkpoints/transformer_lm_de_finetuned/',tgt_dict=task.tgt_dict)
self.lm = LanguageModel(path=self.pretrained_lm,tgt_dict=task.tgt_dict,data_name_or_path=self.pretrained_lm_dict_path)
self.val_lm = LanguageModelValidation(path=self.pretrained_lm,tgt_dict=task.tgt_dict, context_window=self.lm_context_window,data_name_or_path=self.pretrained_lm_dict_path)
# /netscratch/jalota/datasets/motra-sst/de/unsup_setup_raw/lm_finetuning/
# DE: /netscratch/jalota/checkpoints/transformer_lm_de_finetuned/
# EN: /netscratch/jalota/checkpoints/transformer_lm_en_finetuned/
# data_name_or_path='/netscratch/jalota/datasets/motra-sst/ppd_w_europarl-motra-10k_no_dups/en_es_de/unsup_setup/lm_finetune/'

#load("perplexity", module_type="measurement")

Expand All @@ -148,6 +194,7 @@ def forward(self, model, sample, seqeunce_generator=None, tgt_dict=None,reduce=T
sample_size = (
sample['sup']["target"].size(0) if self.sentence_avg else sample['sup']["ntokens"]
)
# logger.info(f'sample["sup"]["net_input"]["prev_output_tokens"]: {sample["sup"]["net_input"]["prev_output_tokens"]}')
## take the mean of loss and nll_loss here and convert them from log base e to 2
loss = loss_sum / sample_size / math.log(2)
nll_loss = nll_loss_sum / sample['sup']["ntokens"] / math.log(2)
@@ -184,42 +231,79 @@ def decode(toks, src=False, escape_unk=False):
).replace("<pad>", "").rstrip()

with torch.no_grad():
if any(l > 510 for l in sample["net_input"]["src_lengths"]):
logger.info(f'sample["net_input"]["src_lengths"]: {sample["net_input"]["src_lengths"]}')
gen_out = seqeunce_generator.generate(
[model], sample, prefix_tokens=None, constraints=None)

# logger.info(f"gen_out: {gen_out}")
hyps, hyps_tok = [], []
for i in range(len(gen_out)):
s = decode(gen_out[i][0]["tokens"]).strip()
if len(s) > 0:
hyps_tok.append(s)
# s = decode(gen_out[i][0]["tokens"]).strip()
# if len(s) > 0:
# hyps_tok.append(s)
hyps.append(gen_out[i][0]["tokens"])

hyps = collate_tokens(hyps, src_dict.pad(), src_dict.eos(), left_pad=False, pad_to_length=None,pad_to_bsz=None)
msize = max(v.size(0) for v in hyps)
msize = msize if msize <= 512 else 512
# logger.info(f"msize: {msize}")

hyps = collate_tokens(hyps, src_dict.pad(), src_dict.eos(), move_eos_to_beginning=False, left_pad=False, pad_to_length=512,pad_to_bsz=None)

batch_size = len(hyps)

if not train:
# calculate bertscore and PPL straight-away!
refs_list = []
hyps_tok = []
refs = sample['net_input']['src_tokens']
for i in range(len(refs)):
s = decode(refs[i]).strip()
refs_list.append(s)
hs = decode(gen_out[i][0]["tokens"]).strip()
if len(s.split()) > 2 and len(hs.split()) > 1:
hyps_tok.append(hs)
refs_list.append(s)

# refs_list.append(s)

# logger.info(f"len(refs_list): {len(refs_list)}")
# logger.info(f"len(hyps_tok): {len(hyps_tok)}")

# logger.info(f"refs_list: {refs_list[0:3]}")
# logger.info(f"hyps_tok: {hyps_tok[0:3]}")
# logger.info(f"refs_list: {refs_list}")
# logger.info(f"hyps_tok: {hyps_tok}")

sim_loss, _ = self.compute_bertLoss(hyps_tok, refs_list)

ppl_results = self.perplexity.compute(data=hyps_tok, model_id='/netscratch/jalota/checkpoints/gpt2-finetuned-motra-de-40epochs/', batch_size=len(hyps_tok), add_start_token=True)
# ppl_results = self.perplexity.compute(data=hyps_tok, model_id='/netscratch/jalota/checkpoints/gpt2-finetuned-motra/', batch_size=len(hyps_tok), add_start_token=True)
hyps_cpu, gen_sizes = [], []
for h in hyps:
# if h.size(0) <= 512:
hyps_cpu.append(h.cpu())
gen_sizes.append(msize)

# hyps = [h.cpu() for h in hyps]
# logger.info(f"len(hyps_cpu): {len(hyps_cpu)}")
# logger.info(f"gen_sizes: {gen_sizes}")

genData = torchDataset(data_list=hyps_cpu)
# gen_sizes = [msize for _ in range(len(genData))]
gen_data = MonolingualDataset(genData, gen_sizes, src_vocab=tgt_dict, fixed_pad_length=512)

ppl_results = self.val_lm.get_lm_perplexity(gen_data, batch_size)

# logger.info(f"ppl: {ppl_results['mean_perplexity']}")
# gpt2-finetuned-motra-de-40epochs/ - DE
# gpt2-finetuned-motra/ - EN

mean_per_word_entropy = math.log2(ppl_results['mean_perplexity'])
mean_per_word_entropy = ppl_results['loss']
# math.log2(ppl_results['mean_perplexity'])

unsupervised_loss = 1.0 * sim_loss + 1.0 * mean_per_word_entropy
loss += self.unsupervised_weight * unsupervised_loss
logging_output["loss"] = loss
logging_output["sim_loss"] = sim_loss
logging_output["mean_per_word_entropy"] = mean_per_word_entropy
logging_output["lm_ppl"] = ppl_results['mean_perplexity']
logging_output["lm_ppl"] = ppl_results['perplexity']
logging_output["unsupervised_loss"] = unsupervised_loss

else:
@@ -277,12 +361,15 @@ def get_similarity_loss(self, model, preds_tensor, sample, pad_token_id):
batch_size, max_seq_len, vocab_size = preds_tensor.size()
emb_size = emb_matrix.size()[-1]

preds_tensor_embs = torch.mm(preds_tensor.contiguous().view(-1, vocab_size), emb_matrix)
preds_tensor_embs = preds_tensor_embs.view(-1, max_seq_len, emb_size)
with torch.autocast("cuda"):
preds_tensor_embs = torch.mm(preds_tensor.contiguous().view(-1, vocab_size), emb_matrix)
preds_tensor_embs = preds_tensor_embs.view(-1, max_seq_len, emb_size)

with torch.no_grad():
source_emb = model.encoder.forward(sample['net_input']['src_tokens'].cuda())
preds_enc_emb = model.encoder.forward(preds_tensor_embs.cuda())
# logger.info(f"preds_tensor_embs: {preds_tensor_embs.dtype}")

with torch.no_grad():
source_emb = model.encoder.forward(sample['net_input']['src_tokens'])
preds_enc_emb = model.encoder.forward(preds_tensor_embs)

source_sent_repr = torch.sum(source_emb['encoder_out'][0], dim=0)
output_sent_repr = torch.sum(preds_enc_emb['encoder_out'][0], dim=0)
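get_similarity_loss feeds soft predictions through the encoder by multiplying the predicted distribution with the embedding matrix, so gradients can flow through discrete-looking outputs. A self-contained sketch of that Gumbel-softmax trick (shapes and values are illustrative):

import torch
import torch.nn.functional as F

batch, seq_len, vocab, emb = 2, 5, 100, 16
logits = torch.randn(batch, seq_len, vocab)
emb_matrix = torch.randn(vocab, emb)

# Differentiable "sampling": unlike argmax, Gumbel-softmax keeps a
# gradient path to the logits; hard=False yields soft one-hot rows.
soft_one_hot = F.gumbel_softmax(logits, tau=0.1, hard=False)

# Soft embedding lookup: (batch*seq, vocab) @ (vocab, emb) -> (batch, seq, emb)
soft_embs = soft_one_hot.view(-1, vocab).mm(emb_matrix).view(batch, seq_len, emb)
print(soft_embs.shape)  # torch.Size([2, 5, 16])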
@@ -329,15 +416,18 @@ def prepare_second_pass_input(self, sample, tgt_dict, src_dict, hyps):
tgt_dict.eos(),
left_pad=False,
move_eos_to_beginning=True,
pad_to_length=None,
pad_to_length=512,
pad_to_multiple=1
)
# logger.info(f"prev_output_tokens: {prev_output_tokens}")

# logger.info(f"tgt_dict.eos():{tgt_dict.eos()}")

src_lengths = sample["net_input"]["src_lengths"]
src_tokens = sample["net_input"]["src_tokens"]
# logger.info(f"src_lengths: {src_lengths}")

# sort by descending src lengths
src_lengths, sort_order = src_lengths.sort(descending=True)

sample['id'] = sample['id'].index_select(0, sort_order)
Expand All @@ -348,6 +438,8 @@ def prepare_second_pass_input(self, sample, tgt_dict, src_dict, hyps):
return sample

def compute_bertLoss(self, preds_list, refs_list, reduce=True):
# logger.info(f"len(refs_list): {len(refs_list)}")
# logger.info(f"len(preds_list): {len(preds_list)}")
results = self.bertscore.compute(predictions=preds_list, references=refs_list, model_type=self.bertscore_model)
avg_f1 = sum(results['f1'])/len(results['f1'])
bert_loss = 1-avg_f1
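compute_bertLoss turns mean BERTScore F1 into a loss. A minimal standalone sketch of the same evaluate API (the scoring model is downloaded on first use; the sentences are illustrative):

import evaluate

bertscore = evaluate.load("bertscore")
preds = ["the committee adopted the report"]
refs = ["the committee has adopted the report"]
results = bertscore.compute(predictions=preds, references=refs,
                            model_type="roberta-base")
avg_f1 = sum(results["f1"]) / len(results["f1"])
bert_loss = 1 - avg_f1  # higher similarity -> lower loss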
@@ -377,8 +469,11 @@ def compute_cosineSimilarityLoss(self, model, sample, hyps, train):
gen_out_emb = model.encoder.forward(hyps)

source_sent_repr = torch.sum(source_emb['encoder_out'][0], dim=0)
# logger.info(f"source_sent_repr: {source_sent_repr}")

output_sent_repr = torch.sum(gen_out_emb['encoder_out'][0], dim=0).cuda()

# logger.info(f"output_sent_repr: {output_sent_repr}")
target_labels = torch.ones(source_sent_repr.shape[0], dtype=source_sent_repr.dtype).cuda()
# cosineLoss = torch.nn.CosineEmbeddingLoss(reduction='mean')
# cos_sim_loss = cosineLoss(source_sent_repr, output_sent_repr, target_labels)
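The commented-out lines point to the torch.nn.CosineEmbeddingLoss variant; for reference, a sketch of how it would apply to the two summed sentence representations (dimensions are illustrative):

import torch

src_repr = torch.randn(4, 512)  # one row per source sentence
out_repr = torch.randn(4, 512)  # one row per generated sentence
# target = 1 pulls each pair's cosine similarity toward 1.
target = torch.ones(src_repr.shape[0])
cos_loss = torch.nn.CosineEmbeddingLoss(reduction='mean')(src_repr, out_repr, target)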
@@ -412,9 +507,9 @@ def reduce_metrics(cls, logging_outputs) -> None:
metrics.log_derived(
"ppl", lambda meters: utils.get_perplexity(meters["nll_loss"].avg)
)
metrics.log_derived(
"lm_ppl", lm_ppl, unsup_nsentences, round=3
)
# metrics.log_derived(
# "lm_ppl", lm_ppl, unsup_nsentences,
# )
metrics.log_scalar(
"sim_loss", sim_loss, unsup_nsentences, round=3
)
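An aside on the perplexity bookkeeping in the criterion above: the earlier code derived mean per-word entropy as log2 of the mean perplexity, while the new code reads the LM loss directly. The two views are log/exp counterparts (assuming a bits-based loss; a nats-based loss differs by a factor of ln 2):

import math

ppl = 32.0
entropy_bits = math.log2(ppl)  # PPL = 2**H  =>  H = log2(PPL) = 5.0
nll_nats = math.log(ppl)       # natural-log NLL, as PyTorch losses report
assert abs(entropy_bits - nll_nats / math.log(2)) < 1e-12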
8 changes: 8 additions & 0 deletions fairseq/data/data_utils.py
@@ -47,7 +47,11 @@ def collate_tokens(
):
"""Convert a list of 1d tensors into a padded 2d tensor."""
size = max(v.size(0) for v in values)
if size < 512 and pad_to_length is not None:
pad_to_length = size
size = size if pad_to_length is None else max(size, pad_to_length)
if size >= 512:
logger.info(f"size!: {size}")
if pad_to_multiple != 1 and size % pad_to_multiple != 0:
size = int(((size - 0.1) // pad_to_multiple + 1) * pad_to_multiple)
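The patch above caps pad_to_length at the batch's true maximum when it is under 512. For intuition, a minimal stand-in for what collate_tokens produces in the left_pad=False case (a reimplementation sketch, not the fairseq function itself):

import torch

def pad_batch(values, pad_idx, pad_to_length=None):
    size = max(v.size(0) for v in values)
    if pad_to_length is not None:
        size = max(size, pad_to_length)
    out = values[0].new_full((len(values), size), pad_idx)
    for i, v in enumerate(values):
        out[i, :v.size(0)] = v  # right-pad each row with pad_idx
    return out

batch = pad_batch([torch.tensor([5, 6]), torch.tensor([7])], pad_idx=1, pad_to_length=4)
print(batch)  # tensor([[5, 6, 1, 1], [7, 1, 1, 1]])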

@@ -161,6 +165,7 @@ def collect_filtered(function, iterable, filtered):


def _filter_by_size_dynamic(indices, size_fn, max_positions, raise_exception=False):
# logger.info(f"size_fn: {size_fn}")
def compare_leq(a, b):
return a <= b if not isinstance(a, tuple) else max(a) <= b

@@ -211,6 +216,8 @@ def filter_by_size(indices, dataset, max_positions, raise_exception=False):
"Use `FairseqDataset::filter_indices_by_size` instead.",
stacklevel=2,
)
# logger.info(f"max_positions: {max_positions}")
# logger.info(f"dataset: {dataset}")
if isinstance(max_positions, float) or isinstance(max_positions, int):
if hasattr(dataset, "sizes") and isinstance(dataset.sizes, np.ndarray):
ignored = indices[dataset.sizes[indices] > max_positions].tolist()
@@ -355,6 +362,7 @@ def batch_by_size(
max_sentences,
bsz_mult,
)
#logger.info(f"b: {b}")

if bsz_mult > 1 and len(b[-1]) % bsz_mult != 0:
b = b[:-1]
8 changes: 8 additions & 0 deletions fairseq/data/lm_context_window_dataset.py
@@ -67,7 +67,15 @@ def collater(self, samples) -> Dict:
if extra > 0:
self.prev_tokens = self.prev_tokens[extra:]
pads = np.full(self.context_window - len(self.prev_tokens), pad)
# if toks[i].get_device() != -1:
# toks[i] = toks[i].cpu().data # move the tensor to cpu
# print(f"self.prev_tokens: {self.prev_tokens}")
# print(f"toks[i]: {toks[i]}")
new_toks[i] = np.concatenate([self.prev_tokens, toks[i].numpy(), pads])
# print(f"new_toks[i]: {new_toks[i]}")
# print(f"tgt[i]: {tgt[i]}")
tgt[i] = tgt[i].cpu().data #.numpy()
# print(f"tgt[i]: {tgt[i]}")
new_tgt[
i, len(self.prev_tokens) : len(self.prev_tokens) + len(tgt[i])
] = tgt[i]
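For intuition about the collater above: each item is prefixed with up to context_window tokens carried over from the previous batch, then right-padded. A small NumPy sketch with illustrative values:

import numpy as np

pad, context_window = 1, 5
prev_tokens = np.array([11, 12, 13])  # carried over from the previous batch
toks = np.array([21, 22, 23, 24])     # current sample
pads = np.full(context_window - len(prev_tokens), pad)

new_toks = np.concatenate([prev_tokens, toks, pads])
print(new_toks)  # [11 12 13 21 22 23 24  1  1]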