Commit e9b68a4

Initial Multilingual Recipe
1 parent bbce306 commit e9b68a4

2 files changed: 212 additions & 0 deletions

training/multilingual/all_datasets.py

Lines changed: 66 additions & 0 deletions
@@ -0,0 +1,66 @@
from typing import List
from dataclasses import dataclass

from datasets import load_dataset, concatenate_datasets


@dataclass
class NusaX:
    # Regional languages paired with Indonesian (ind) in NusaX-MT
    target_languages = ["ace", "ban", "bjn", "bbc", "bug", "jav", "mad", "min", "nij", "sun"]

    # Download and merge every ind-{lang} subset; this runs once, at class-definition time
    train_datasets, validation_datasets = [], []
    for lang in target_languages:
        ds = load_dataset("indonlp/NusaX-MT", f"ind-{lang}", trust_remote_code=True)
        train_datasets.append(ds["train"])
        validation_datasets.append(ds["validation"])

    dataset = {"train": concatenate_datasets(train_datasets), "validation": concatenate_datasets(validation_datasets)}

    @staticmethod
    def train_samples() -> List[List[str]]:
        train_samples = []

        for datum in NusaX.dataset["train"]:
            train_samples.append([datum["text_1"], datum["text_2"]])

        return train_samples

    @staticmethod
    def validation_samples() -> List[List[str]]:
        validation_samples = []

        for datum in NusaX.dataset["validation"]:
            validation_samples.append([datum["text_1"], datum["text_2"]])

        return validation_samples


@dataclass
class NusaTranslation:
    # Regional languages paired with Indonesian (ind) in NusaTranslation
    target_languages = ["abs", "btk", "bew", "bhp", "jav", "mad", "mak", "min", "mui", "rej", "sun"]

    # Download and merge every Indonesian-to-regional-language subset; runs at class-definition time
    train_datasets, validation_datasets = [], []
    for lang in target_languages:
        ds = load_dataset(
            "indonlp/nusatranslation_mt", f"nusatranslation_mt_ind_{lang}_nusantara_t2t", trust_remote_code=True
        )
        train_datasets.append(ds["train"])
        validation_datasets.append(ds["validation"])

    dataset = {"train": concatenate_datasets(train_datasets), "validation": concatenate_datasets(validation_datasets)}

    @staticmethod
    def train_samples() -> List[List[str]]:
        train_samples = []

        for datum in NusaTranslation.dataset["train"]:
            train_samples.append([datum["text_1"], datum["text_2"]])

        return train_samples

    @staticmethod
    def validation_samples() -> List[List[str]]:
        validation_samples = []

        for datum in NusaTranslation.dataset["validation"]:
            validation_samples.append([datum["text_1"], datum["text_2"]])

        return validation_samples
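
For reference, a minimal usage sketch of these helpers (the concatenation and the print statements below are illustrative and not part of the commit; importing all_datasets triggers the dataset downloads shown above):

# Usage sketch: each sample is a [text_1, text_2] translation pair of plain strings.
from all_datasets import NusaX, NusaTranslation

train_pairs = NusaX.train_samples() + NusaTranslation.train_samples()
dev_pairs = NusaX.validation_samples() + NusaTranslation.validation_samples()

print(len(train_pairs), len(dev_pairs))
print(train_pairs[0])  # e.g. ["<Indonesian sentence>", "<regional-language sentence>"]
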
Lines changed: 146 additions & 0 deletions
@@ -0,0 +1,146 @@
from dataclasses import dataclass
from itertools import chain
import math

from datargs import parse
from datasets import load_dataset
from torch.utils.data import DataLoader
from sentence_transformers import SentenceTransformer, InputExample, models, losses
from sentence_transformers.datasets import ParallelSentencesDataset
from sentence_transformers.evaluation import (
    MSEEvaluator,
    TranslationEvaluator,
    EmbeddingSimilarityEvaluator,
    SequentialEvaluator,
)

import numpy as np
import torch.nn as nn

from all_datasets import NusaX, NusaTranslation


@dataclass
class Args:
    # model args
    student_model_name: str = "LazarusNLP/NusaBERT-base"
    teacher_model_name: str = "sentence-transformers/paraphrase-multilingual-mpnet-base-v2"
    max_seq_length: int = 128
    # test args
    test_dataset_name: str = "LazarusNLP/stsb_mt_id"
    test_dataset_split: str = "validation"
    test_text_column_1: str = "text_1"
    test_text_column_2: str = "text_2"
    test_label_column: str = "correlation"
    # training args
    num_epochs: int = 20
    train_batch_size: int = 128
    test_batch_size: int = 128
    learning_rate: float = 2e-5
    warmup_ratio: float = 0.1
    output_path: str = "exp/all-indobert-base"
    use_amp: bool = True
    # huggingface hub args
    hub_model_id: str = "LazarusNLP/all-indobert-base"
    hub_private_repo: bool = True


def main(args: Args):
    # Load parallel (Indonesian <-> regional language) translation datasets
    raw_datasets = {
        "indonlp/NusaX-MT": NusaX,
        "indonlp/nusatranslation_mt": NusaTranslation,
    }

    train_ds = [ds.train_samples() for ds in raw_datasets.values()]
    train_ds = list(chain.from_iterable(train_ds))  # flatten multiple datasets

    dev_ds = [ds.validation_samples() for ds in raw_datasets.values()]
    dev_ds = list(chain.from_iterable(dev_ds))

    test_ds = load_dataset(args.test_dataset_name, split=args.test_dataset_split)

    # Load teacher model
    teacher_model = SentenceTransformer(args.teacher_model_name)
    teacher_dimension = teacher_model.get_sentence_embedding_dimension()

    # Initialize student model with mean pooling
    word_embedding_model = models.Transformer(args.student_model_name, max_seq_length=args.max_seq_length)
    dimension = word_embedding_model.get_word_embedding_dimension()
    pooling_model = models.Pooling(dimension, pooling_mode="mean")
    # Project the student's pooled output to the teacher's embedding dimension
    dense_model = models.Dense(
        in_features=dimension,
        out_features=teacher_dimension,
        activation_function=nn.Tanh(),
    )
    student_model = SentenceTransformer(modules=[word_embedding_model, pooling_model, dense_model])

    # Prepare parallel sentences dataset (teacher embeddings are cached)
    parallel_ds = ParallelSentencesDataset(
        student_model=student_model,
        teacher_model=teacher_model,
        batch_size=args.test_batch_size,
        use_embedding_cache=True,
    )
    parallel_ds.add_dataset(train_ds, max_sentence_length=args.max_seq_length)

    # DataLoader to batch the parallel data
    train_dataloader = DataLoader(parallel_ds, batch_size=args.train_batch_size)

    warmup_steps = math.ceil(
        len(train_dataloader) * args.num_epochs * args.warmup_ratio
    )  # 10% of total training steps used for warm-up

    # Split validation translation pairs into two aligned lists
    source_sentences, target_sentences = map(list, zip(*dev_ds))

    # MSE evaluation
    mse_evaluator = MSEEvaluator(
        source_sentences, target_sentences, teacher_model=teacher_model, batch_size=args.test_batch_size
    )

    # Translation evaluation
    trans_evaluator = TranslationEvaluator(source_sentences, target_sentences, batch_size=args.test_batch_size)

    # STS evaluation
    test_data = [
        InputExample(
            texts=[data[args.test_text_column_1], data[args.test_text_column_2]],
            label=float(data[args.test_label_column]) / 5.0,
        )
        for data in test_ds
    ]

    sts_evaluator = EmbeddingSimilarityEvaluator.from_input_examples(test_data, batch_size=args.test_batch_size)

    # Use MSE cross-lingual distillation loss
    train_loss = losses.MSELoss(model=student_model)

    # Call the fit method
    student_model.fit(
        train_objectives=[(train_dataloader, train_loss)],
        evaluator=SequentialEvaluator(
            [mse_evaluator, trans_evaluator, sts_evaluator], main_score_function=lambda scores: np.mean(scores)
        ),
        epochs=args.num_epochs,
        warmup_steps=warmup_steps,
        show_progress_bar=True,
        optimizer_params={"lr": args.learning_rate, "eps": 1e-6},
        output_path=args.output_path,
        save_best_model=True,
        use_amp=args.use_amp,
    )

    # Save model to HuggingFace Hub
    student_model.save_to_hub(
        args.hub_model_id,
        private=args.hub_private_repo,
        train_datasets=list(raw_datasets.keys()),
    )


if __name__ == "__main__":
    args = parse(Args)
    main(args)
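
This script appears to follow the multilingual knowledge-distillation setup from sentence-transformers: each parallel pair is labeled with the teacher's embedding of the Indonesian source sentence, and losses.MSELoss trains the student so that its embeddings of both the source and the regional-language target move toward that vector. A minimal, self-contained sketch of that objective under those assumptions (distillation_mse, the batch size, and the random 768-dimensional tensors are placeholders, not part of the commit; the library handles the actual per-sentence batching and averaging):

import torch
import torch.nn.functional as F


def distillation_mse(student_src: torch.Tensor, student_tgt: torch.Tensor, teacher_src: torch.Tensor) -> torch.Tensor:
    # Both student views (source- and target-sentence embeddings) are regressed
    # onto the teacher's embedding of the source sentence.
    return F.mse_loss(student_src, teacher_src) + F.mse_loss(student_tgt, teacher_src)


# Toy check with random "embeddings" at the teacher's output dimension (768 for mpnet-base).
batch_size, dim = 4, 768
loss = distillation_mse(torch.randn(batch_size, dim), torch.randn(batch_size, dim), torch.randn(batch_size, dim))
print(loss.item())
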
