from dataclasses import dataclass
from itertools import chain
import math

from datargs import parse
from datasets import load_dataset
from torch.utils.data import DataLoader
from sentence_transformers import SentenceTransformer, InputExample, models, losses
from sentence_transformers.datasets import ParallelSentencesDataset
from sentence_transformers.evaluation import (
    MSEEvaluator,
    TranslationEvaluator,
    EmbeddingSimilarityEvaluator,
    SequentialEvaluator,
)

import numpy as np
import torch.nn as nn

from all_datasets import NusaX, NusaTranslation


@dataclass
class Args:
    # model args
    student_model_name: str = "LazarusNLP/NusaBERT-base"
    teacher_model_name: str = "sentence-transformers/paraphrase-multilingual-mpnet-base-v2"
    max_seq_length: int = 128
    # test dataset args
    test_dataset_name: str = "LazarusNLP/stsb_mt_id"
    test_dataset_split: str = "validation"
    test_text_column_1: str = "text_1"
    test_text_column_2: str = "text_2"
    test_label_column: str = "correlation"
    # training args
    num_epochs: int = 20
    train_batch_size: int = 128
    test_batch_size: int = 128
    learning_rate: float = 2e-5
    warmup_ratio: float = 0.1
    output_path: str = "exp/all-indobert-base"
    use_amp: bool = True
    # huggingface hub args
    hub_model_id: str = "LazarusNLP/all-indobert-base"
    hub_private_repo: bool = True


def main(args: Args):
    # Load datasets
    raw_datasets = {
        "indonlp/NusaX-MT": NusaX,
        "indonlp/nusatranslation_mt": NusaTranslation,
    }

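    # Both wrappers are expected to expose their train/validation splits as (source, target) parallel sentence pairs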
    train_ds = [ds.train_samples() for ds in raw_datasets.values()]
    train_ds = list(chain.from_iterable(train_ds))  # flatten multiple datasets

    dev_ds = [ds.validation_samples() for ds in raw_datasets.values()]
    dev_ds = list(chain.from_iterable(dev_ds))

    test_ds = load_dataset(args.test_dataset_name, split=args.test_dataset_split)

    # Load teacher model
    teacher_model = SentenceTransformer(args.teacher_model_name)
    teacher_dimension = teacher_model.get_sentence_embedding_dimension()

    # Initialize student model with mean pooling
    word_embedding_model = models.Transformer(args.student_model_name, max_seq_length=args.max_seq_length)
    dimension = word_embedding_model.get_word_embedding_dimension()
    pooling_model = models.Pooling(dimension, pooling_mode="mean")
    # project student's output pooling to teacher's output dimension
    dense_model = models.Dense(
        in_features=dimension,
        out_features=teacher_dimension,
        activation_function=nn.Tanh(),
    )
    student_model = SentenceTransformer(modules=[word_embedding_model, pooling_model, dense_model])

    # Prepare Parallel Sentences Dataset
    parallel_ds = ParallelSentencesDataset(
        student_model=student_model,
        teacher_model=teacher_model,
        batch_size=args.test_batch_size,
        use_embedding_cache=True,
    )
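    # For every (source, target) pair, the teacher embedding of the source sentence serves as the
    # regression target for both sentences, so the student learns to map translations onto the same
    # point in the teacher's embedding space; the embedding cache avoids recomputing teacher vectors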
    parallel_ds.add_dataset(train_ds, max_sentence_length=args.max_seq_length)

    # DataLoader to batch your data
    train_dataloader = DataLoader(parallel_ds, batch_size=args.train_batch_size)

    warmup_steps = math.ceil(
        len(train_dataloader) * args.num_epochs * args.warmup_ratio
    )  # warm up over the first 10% of total training steps

    # Flatten validation translation pairs into two separate lists
    source_sentences, target_sentences = map(list, zip(*dev_ds))

    # MSE evaluation
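    # (reports the negative mean squared error between teacher embeddings of the source sentences
    # and student embeddings of their translations, so higher is better)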
    mse_evaluator = MSEEvaluator(
        source_sentences, target_sentences, teacher_model=teacher_model, batch_size=args.test_batch_size
    )

    # Translation evaluation
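    # (accuracy of matching each source sentence to its translation via highest cosine similarity)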
    trans_evaluator = TranslationEvaluator(source_sentences, target_sentences, batch_size=args.test_batch_size)

    # STS evaluation
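    # (gold scores in [0, 5] are rescaled to [0, 1]; the evaluator reports Spearman correlation
    # between those scores and the cosine similarity of the student embeddings)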
    test_data = [
        InputExample(
            texts=[data[args.test_text_column_1], data[args.test_text_column_2]],
            label=float(data[args.test_label_column]) / 5.0,
        )
        for data in test_ds
    ]

    sts_evaluator = EmbeddingSimilarityEvaluator.from_input_examples(test_data, batch_size=args.test_batch_size)

    # Use MSE crosslingual distillation loss
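    # MSELoss regresses student embeddings onto the teacher embeddings that
    # ParallelSentencesDataset provides as labels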
    train_loss = losses.MSELoss(model=student_model)

    # Call the fit method
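    # SequentialEvaluator runs all three evaluators; the mean of their scores selects the best checkpoint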
    student_model.fit(
        train_objectives=[(train_dataloader, train_loss)],
        evaluator=SequentialEvaluator(
            [mse_evaluator, trans_evaluator, sts_evaluator], main_score_function=lambda scores: np.mean(scores)
        ),
        epochs=args.num_epochs,
        warmup_steps=warmup_steps,
        show_progress_bar=True,
        optimizer_params={"lr": args.learning_rate, "eps": 1e-6},
        output_path=args.output_path,
        save_best_model=True,
        use_amp=args.use_amp,
    )

    # Save model to HuggingFace Hub
    student_model.save_to_hub(
        args.hub_model_id,
        private=args.hub_private_repo,
        train_datasets=list(raw_datasets.keys()),
    )

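# datargs builds an argparse-based CLI from the Args dataclass; every field above has a
# default, so the script can also run without any command-line flags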
if __name__ == "__main__":
    args = parse(Args)
    main(args)