Commit 66b3e22

Added All v4 Dataset Results and CachedMNRL Loss Training
1 parent: dab15bb

File tree

8 files changed: +272 additions, −21 deletions

README.md

Lines changed: 19 additions & 1 deletion
Large diffs are not rendered by default.

docs/index.md

Lines changed: 33 additions & 2 deletions
Large diffs are not rendered by default.

docs/training/all.md

Lines changed: 34 additions & 15 deletions
````diff
@@ -4,21 +4,25 @@ Inspired by [all-MiniLM-L6-v2](https://huggingface.co/sentence-transformers/all-

 ## Training Data

 | Dataset | Task | Data Instance | Number of Training Tuples |
 | ------- | :--: | :-----------: | :-----------------------: |
 | [indonli](https://huggingface.co/datasets/indonli) | Natural Language Inference | `(premise, entailment, contradiction)` | 3,914 |
 | [indolem/indo_story_cloze](https://huggingface.co/datasets/indolem/indo_story_cloze) | Commonsense Reasoning | `(context, correct ending, incorrect ending)` | 1,000 |
 | [unicamp-dl/mmarco](https://huggingface.co/datasets/unicamp-dl/mmarco) | Passage Retrieval | `(query, positive passage, negative passage)` | 100,000 |
 | [miracl/miracl](https://huggingface.co/datasets/miracl/miracl) | Passage Retrieval | `(query, positive passage, negative passage)` | 8,086 |
 | [SEACrowd/wrete](https://huggingface.co/datasets/SEACrowd/wrete) | Textual Entailment | `(sentenceA, sentenceB)` | 183 |
 | [SEACrowd/indolem_ntp](https://huggingface.co/datasets/SEACrowd/indolem_ntp) | Textual Entailment | `(tweet, next tweet)` | 5,681 |
 | [khalidalt/tydiqa-goldp](https://huggingface.co/datasets/khalidalt/tydiqa-goldp) | Extractive Question-Answering | `(question, passage)`, `(question, answer)` | 11,404 |
 | [SEACrowd/facqa](https://huggingface.co/datasets/SEACrowd/facqa) | Extractive Question-Answering | `(question, passage)`, `(question, answer)` | 4,990 |
 | *included in v2* |
 | [indonesian-nlp/lfqa_id](https://huggingface.co/datasets/indonesian-nlp/lfqa_id) | Open-domain Question-Answering | `(question, answer)` | 226,147 |
 | [jakartaresearch/indoqa](https://huggingface.co/datasets/jakartaresearch/indoqa) | Extractive Question-Answering | `(question, passage)`, `(question, answer)` | 6,498 |
 | [jakartaresearch/id-paraphrase-detection](https://huggingface.co/datasets/jakartaresearch/id-paraphrase-detection) | Paraphrase | `(sentence, rephrased sentence)` | 4,076 |
-| **Total** | | | **371,979** |
+| *included in v3* |
+| [LazarusNLP/multilingual-NLI-26lang-2mil7-id](https://huggingface.co/datasets/LazarusNLP/multilingual-NLI-26lang-2mil7-id) | Natural Language Inference | `(premise, entailment, hypothesis)` | 41,924 |
+| *included in v4* |
+| [nthakur/swim-ir-monolingual](https://huggingface.co/datasets/nthakur/swim-ir-monolingual) | Passage Retrieval | `(query, positive passage, negative passage)` | 227,145 |
+| **Total** | | | **641,048** |

 ## All Supervised Datasets with MultipleNegativesRankingLoss

@@ -46,6 +50,21 @@ python train_all_mnrl.py \
     --learning-rate 2e-5
 ```

+## All Supervised Datasets with CachedMultipleNegativesRankingLoss
+
+### IndoBERT Base
+
+```sh
+python train_all_mnrl.py \
+    --model-name indobenchmark/indobert-base-p1 \
+    --max-seq-length 128 \
+    --num-epochs 5 \
+    --train-batch-size-pairs 384 \
+    --train-batch-size-triplets 256 \
+    --mini-batch-size 320 \
+    --learning-rate 2e-5
+```
+
 ## References

 ```bibtex
````
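The CachedMultipleNegativesRankingLoss run shares everything with the MNRL run except the loss object. Below is a minimal sketch of the swap, mirroring the model setup used in the training scripts; the comparison comments are our own gloss, not part of the commit:

```python
from sentence_transformers import SentenceTransformer, models, losses

# Mean-pooled IndoBERT, as in the training scripts
word_embedding_model = models.Transformer("indobenchmark/indobert-base-p1", max_seq_length=128)
pooling_model = models.Pooling(word_embedding_model.get_word_embedding_dimension(), pooling_mode="mean")
model = SentenceTransformer(modules=[word_embedding_model, pooling_model])

# Plain in-batch negatives: the whole batch (384 pairs / 256 triplets)
# must be embedded, and fit on the GPU, in a single forward pass
mnrl = losses.MultipleNegativesRankingLoss(model)

# Cached variant (GradCache): embeds the batch in mini-batches of 320 and
# accumulates gradients, so memory is bounded by mini_batch_size while the
# loss still sees the full batch of in-batch negatives
cmnrl = losses.CachedMultipleNegativesRankingLoss(model, mini_batch_size=320)
```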

requirements.txt

Lines changed: 2 additions & 1 deletion
```diff
@@ -4,4 +4,5 @@ git+https://github.com/w11wo/SCT.git
 git+https://github.com/embeddings-benchmark/mteb.git
 datasets
 scikit-learn
-datargs
+datargs
+nusacrowd
```

training/all/README.md

Lines changed: 19 additions & 2 deletions
````diff
@@ -19,8 +19,10 @@ Inspired by [all-MiniLM-L6-v2](https://huggingface.co/sentence-transformers/all-
 | [jakartaresearch/indoqa](https://huggingface.co/datasets/jakartaresearch/indoqa) | Extractive Question-Answering | `(question, passage)`, `(question, answer)` | 6,498 |
 | [jakartaresearch/id-paraphrase-detection](https://huggingface.co/datasets/jakartaresearch/id-paraphrase-detection) | Paraphrase | `(sentence, rephrased sentence)` | 4,076 |
 | *included in v3* |
-| [LazarusNLP/multilingual-NLI-26lang-2mil7-id](https://huggingface.co/datasets/LazarusNLP/multilingual-NLI-26lang-2mil7-id) | Natural Language Inference | `(premise, entailement hypothesis)` | 41,924 |
-| **Total** | | | **413,903** |
+| [LazarusNLP/multilingual-NLI-26lang-2mil7-id](https://huggingface.co/datasets/LazarusNLP/multilingual-NLI-26lang-2mil7-id) | Natural Language Inference | `(premise, entailment, hypothesis)` | 41,924 |
+| *included in v4* |
+| [nthakur/swim-ir-monolingual](https://huggingface.co/datasets/nthakur/swim-ir-monolingual) | Passage Retrieval | `(query, positive passage, negative passage)` | 227,145 |
+| **Total** | | | **641,048** |

 ## All Supervised Datasets with MultipleNegativesRankingLoss

@@ -48,6 +50,21 @@ python train_all_mnrl.py \
     --learning-rate 2e-5
 ```

+## All Supervised Datasets with CachedMultipleNegativesRankingLoss
+
+### IndoBERT Base
+
+```sh
+python train_all_mnrl.py \
+    --model-name indobenchmark/indobert-base-p1 \
+    --max-seq-length 128 \
+    --num-epochs 5 \
+    --train-batch-size-pairs 384 \
+    --train-batch-size-triplets 256 \
+    --mini-batch-size 320 \
+    --learning-rate 2e-5
+```
+
 ## References

 ```bibtex
````
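Both losses accept a mix of example shapes, which is why the scripts expose separate `--train-batch-size-pairs` and `--train-batch-size-triplets` flags. A toy sketch of the two shapes fed to `MultipleNegativesRankingLoss` follows; the sentences are invented for illustration, and any text beyond the first two is treated as an extra hard negative:

```python
from sentence_transformers import InputExample

# 2-text pair: (anchor, positive); the rest of the batch acts as negatives
pair = InputExample(texts=[
    "Siapa presiden pertama Indonesia?",
    "Soekarno adalah presiden pertama Indonesia.",
])

# 3-text triplet: (anchor, positive, hard negative); the explicit negative
# is scored alongside the in-batch negatives
triplet = InputExample(texts=[
    "Di mana Candi Borobudur berada?",
    "Candi Borobudur terletak di Magelang, Jawa Tengah.",
    "Komodo adalah kadal terbesar di dunia.",
])
```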

training/all/all_datasets.py

Lines changed: 36 additions & 0 deletions
```diff
@@ -196,6 +196,42 @@ def train_samples() -> List[InputExample]:
         return train_samples


+@dataclass
+class SwimIR:
+    dataset = load_dataset("nthakur/swim-ir-monolingual", "id", split="train")
+
+    @staticmethod
+    def train_samples() -> List[InputExample]:
+        train_data = {}
+        train_samples = []
+
+        for datum in SwimIR.dataset:
+            query = datum["query"].strip()
+            answer = datum["text"].strip()
+            title = datum["title"].strip()
+
+            if title not in train_data:
+                train_data[title] = {query: [answer]}
+            elif title in train_data and query not in train_data[title]:
+                train_data[title][query] = [answer]
+            else:
+                train_data[title][query].append(answer)
+
+        for title, queries in train_data.items():
+            passage_queries = list(queries.keys())
+            # cannot get a negative sample if only 1 query in that passage
+            if len(passage_queries) > 1:
+                for query, answers in queries.items():
+                    positive = random.choice(answers)
+                    # get random negative sample, from different query
+                    random_query = random.choice([q for q in passage_queries if q != query])
+                    negative = random.choice(queries[random_query])
+
+                    train_samples.append(InputExample(texts=[query, positive, negative]))
+
+        return train_samples
+
+
 @dataclass
 class IndoStoryCloze:
     dataset = load_dataset("indolem/indo_story_cloze", split="train", trust_remote_code=True)
```
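SwimIR triplets are mined by grouping rows by article title: each query's positive is one of its own answers, and its negative is an answer to a different query on the same article, which makes the negative topically close and therefore hard. A toy run of the same grouping logic on invented rows:

```python
import random

# Hypothetical rows standing in for nthakur/swim-ir-monolingual records
rows = [
    {"title": "Borobudur", "query": "Di mana Borobudur berada?", "text": "Borobudur terletak di Magelang."},
    {"title": "Borobudur", "query": "Kapan Borobudur dibangun?", "text": "Candi ini dibangun pada abad ke-9."},
    {"title": "Komodo", "query": "Apa itu komodo?", "text": "Komodo adalah kadal terbesar di dunia."},
]

# Same grouping as SwimIR.train_samples: title -> query -> [answers]
grouped = {}
for r in rows:
    grouped.setdefault(r["title"], {}).setdefault(r["query"], []).append(r["text"])

# Only "Borobudur" has more than one query, so only it can yield triplets;
# the negative answers a *different* query about the same article
for title, queries in grouped.items():
    if len(queries) > 1:
        for query, answers in queries.items():
            other = random.choice([q for q in queries if q != query])
            print((query, random.choice(answers), random.choice(queries[other])))
```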
Lines changed: 127 additions & 0 deletions
```python
from dataclasses import dataclass
import math

from datargs import parse
from datasets import load_dataset
from sentence_transformers import SentenceTransformer, InputExample, models, losses
from sentence_transformers.evaluation import EmbeddingSimilarityEvaluator

from all_datasets import (
    IndoNLI,
    IndoStoryCloze,
    mMARCO,
    MIRACL,
    SwimIR,
    MultilingualNLI,
    WReTE,
    IndoLEMNTP,
    TyDiQA,
    FacQA,
    LFQAID,
    IndoQA,
    ParaphraseDetection,
)
from MultiDatasetDataLoader import MultiDatasetDataLoader


@dataclass
class Args:
    # data args
    model_name: str = "indobenchmark/indobert-base-p1"
    # train
    max_seq_length: int = 128
    # test
    test_dataset_name: str = "LazarusNLP/stsb_mt_id"
    test_dataset_split: str = "validation"
    test_text_column_1: str = "text_1"
    test_text_column_2: str = "text_2"
    test_label_column: str = "correlation"
    # training args
    num_epochs: int = 5
    train_batch_size_pairs: int = 384
    train_batch_size_triplets: int = 256
    test_batch_size: int = 32
    mini_batch_size: int = 128
    learning_rate: float = 2e-5
    warmup_ratio: float = 0.1
    output_path: str = "exp/all-indobert-base"
    use_amp: bool = True
    # huggingface hub args
    hub_model_id: str = "LazarusNLP/all-indobert-base"
    hub_private_repo: bool = True


def main(args: Args):
    # Load datasets
    raw_datasets = {
        "indonli": IndoNLI,
        "indolem/indo_story_cloze": IndoStoryCloze,
        "unicamp-dl/mmarco": mMARCO,
        "miracl/miracl": MIRACL,
        "nthakur/swim-ir-monolingual": SwimIR,
        "LazarusNLP/multilingual-NLI-26lang-2mil7-id": MultilingualNLI,
        "SEACrowd/wrete": WReTE,
        "SEACrowd/indolem_ntp": IndoLEMNTP,
        "khalidalt/tydiqa-goldp": TyDiQA,
        "SEACrowd/facqa": FacQA,
        "indonesian-nlp/lfqa_id": LFQAID,
        "jakartaresearch/indoqa": IndoQA,
        "jakartaresearch/id-paraphrase-detection": ParaphraseDetection,
    }

    train_ds = [ds.train_samples() for ds in raw_datasets.values()]
    test_ds = load_dataset(args.test_dataset_name, split=args.test_dataset_split)

    # Initialize model with mean pooling
    word_embedding_model = models.Transformer(args.model_name, max_seq_length=args.max_seq_length)
    dimension = word_embedding_model.get_word_embedding_dimension()
    pooling_model = models.Pooling(dimension, pooling_mode="mean")
    model = SentenceTransformer(modules=[word_embedding_model, pooling_model])

    # DataLoader to batch your data
    train_dataloader = MultiDatasetDataLoader(
        train_ds, batch_size_pairs=args.train_batch_size_pairs, batch_size_triplets=args.train_batch_size_triplets
    )

    warmup_steps = math.ceil(
        len(train_dataloader) * args.num_epochs * args.warmup_ratio
    )  # 10% of train data for warm-up

    # Setup test data for evaluation
    test_data = [
        InputExample(
            texts=[data[args.test_text_column_1], data[args.test_text_column_2]],
            label=float(data[args.test_label_column]) / 5.0,
        )
        for data in test_ds
    ]

    evaluator = EmbeddingSimilarityEvaluator.from_input_examples(test_data, batch_size=args.test_batch_size)

    # Use the cached multiple negatives ranking loss
    train_loss = losses.CachedMultipleNegativesRankingLoss(model, mini_batch_size=args.mini_batch_size)

    # Call the fit method
    model.fit(
        train_objectives=[(train_dataloader, train_loss)],
        evaluator=evaluator,
        epochs=args.num_epochs,
        warmup_steps=warmup_steps,
        show_progress_bar=True,
        optimizer_params={"lr": args.learning_rate, "eps": 1e-6},
        output_path=args.output_path,
        save_best_model=True,
        use_amp=args.use_amp,
    )

    # Save model to HuggingFace Hub
    model.save_to_hub(
        args.hub_model_id,
        private=args.hub_private_repo,
        train_datasets=list(raw_datasets.keys()),
    )


if __name__ == "__main__":
    args = parse(Args)
    main(args)
```
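After training, the best checkpoint is saved to `output_path` and optionally pushed to the Hub under `hub_model_id`. A quick sanity-check sketch, assuming the run above completed; the query and passages are invented:

```python
from sentence_transformers import SentenceTransformer, util

model = SentenceTransformer("exp/all-indobert-base")  # or the hub id "LazarusNLP/all-indobert-base"

query = "Siapa penemu bola lampu pijar?"
passages = [
    "Thomas Edison mematenkan lampu pijar komersial pertama pada tahun 1879.",
    "Candi Borobudur adalah candi Buddha terbesar di dunia.",
]

query_emb = model.encode(query, convert_to_tensor=True)
passage_embs = model.encode(passages, convert_to_tensor=True)

# Cosine similarity; the relevant passage should score higher
print(util.cos_sim(query_emb, passage_embs))
```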

training/all/train_all_mnrl.py

Lines changed: 2 additions & 0 deletions
```diff
@@ -11,6 +11,7 @@
     IndoStoryCloze,
     mMARCO,
     MIRACL,
+    SwimIR,
     MultilingualNLI,
     WReTE,
     IndoLEMNTP,
@@ -56,6 +57,7 @@ def main(args: Args):
         "indolem/indo_story_cloze": IndoStoryCloze,
         "unicamp-dl/mmarco": mMARCO,
         "miracl/miracl": MIRACL,
+        "nthakur/swim-ir-monolingual": SwimIR,
         "LazarusNLP/multilingual-NLI-26lang-2mil7-id": MultilingualNLI,
         "SEACrowd/wrete": WReTE,
         "SEACrowd/indolem_ntp": IndoLEMNTP,
```
