fix support for Chinese
beyondguo committed Nov 24, 2022
1 parent e29dc32 commit 46c9204
Showing 6 changed files with 24 additions and 181 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -85,6 +85,7 @@ bert-base-cased/


playground.ipynb
aplayground.ipynb



10 changes: 8 additions & 2 deletions README.md
@@ -2,7 +2,7 @@

**A sketch-based text generation model**

- **Paper: [GENIUS: Sketch-based Language Model Pre-training via Extreme and Selective Masking for Text Generation and Augmentation](https://github.com/beyondguo/genius/blob/master/GENIUS_gby_arxiv.pdf)**
- **Paper: [GENIUS: Sketch-based Language Model Pre-training via Extreme and Selective Masking for Text Generation and Augmentation](https://arxiv.org/abs/2211.10330)**

💡**GENIUS** is a powerful conditional text generation model using sketches as input, which can fill in the missing contexts for a given **sketch** (key information consisting of textual spans, phrases, or words, concatenated by mask tokens). GENIUS is pre-trained on a large-scale textual corpus with a novel *reconstruction from sketch* objective using an *extreme and selective masking* strategy, enabling it to generate diverse and high-quality texts given sketches.
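
For a quick sense of how sketches drive generation, here is a minimal usage sketch. It assumes the `beyondguo/genius-large` checkpoint on the Hugging Face Hub and standard `transformers` pipeline behavior; adjust the model id to whichever GENIUS checkpoint you use.

```python
from transformers import pipeline

# Assumed checkpoint id; swap in your own GENIUS checkpoint if it differs.
genius = pipeline("text2text-generation", model="beyondguo/genius-large")

sketch = "<mask> Conference on Empirical Methods <mask> submission of research papers <mask> Deep Learning <mask>"
result = genius(sketch, num_beams=3, max_length=200)
print(result[0]["generated_text"])
```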

@@ -121,5 +121,11 @@ Out-of-distribution (OOD) evaluations:
| **GeniusAug-f** | **76.18** | 66.89 | **77.45** | **80.36** | **75.22** |

### BibTeX entry and citation info
TBD
If you find our paper/code/demo useful, please cite our paper:
@article{guo2022genius,
  title={GENIUS: Sketch-based Language Model Pre-training via Extreme and Selective Masking for Text Generation and Augmentation},
  author={Guo, Biyang and Gong, Yeyun and Shen, Yelong and Han, Songqiao and Huang, Hailiang and Duan, Nan and Chen, Weizhu},
  journal={arXiv preprint arXiv:2211.10330},
  year={2022}
}

11 changes: 10 additions & 1 deletion genius_utils.py
@@ -27,12 +27,17 @@ class SketchExtractor:
def __init__(self, model='yake'):
assert model in ['yake', 'bert', 'jieba'], '`model` only supports `yake`, `bert` or `jieba`'
self.model = model
self.mask = '<mask>'
self.sep = ' '
if model == 'yake': # for English
self.extractor = None
if model == 'bert': # for English
self.extractor = AspectKeyBERT(model='all-MiniLM-L6-v2') # alternative: paraphrase-MiniLM-L3-v2 (the fastest LM)
if model == 'jieba': # for Chinese
print('You are using the Chinese version.\n--mask token: "[MASK]"\n--sep: ""')
self.extractor = jieba.analyse
self.mask = '[MASK]'
self.sep = ''


def get_kws(self, s, max_ngram=3, top=10, aspect_keywords=None, use_aspect_as_doc_embedding=False):
@@ -48,7 +53,7 @@ def get_kws(self, s, max_ngram=3, top=10, aspect_keywords=None, use_aspect_as_doc_embedding=False):
return [], kws
return kws_pairs, [p[0] for p in kws_pairs]

def get_sketch_from_kws(self, s, kws, template=4, mask='<mask>', sep=' '):
def get_sketch_from_kws(self, s, kws, template=4):
"""
TODO: keywords extracted by YAKE may not always match the original text, e.g. "IBM's" is extracted as "IBM".
For templates 3/4, a workaround is to split keywords into single words, then match.
@@ -58,6 +63,8 @@ def get_sketch_from_kws(self, s, kws, template=4, mask='<mask>', sep=' '):
3 --> keywords ordered by their original order and frequencies in `s`, joined by a single space
4 --> same as above, but joined by a single <mask> token (the default GENIUS mode)
"""
mask = self.mask
sep = self.sep
if template == 1:
return ' '.join(kws)
if template == 2:
@@ -113,6 +120,8 @@ def get_sketch_from_kws(self, s, kws, template=4, mask='<mask>', sep=' '):
masked_text.append(f'{mask}{sep}')
if sep == ' ' and id - all_ids[i-1] == 2 and s[id-1] == ' ': # a space in between
masked_text.append(' ')
if sep == '' and id - all_ids[i-1] == 2:
masked_text.append(f'{sep}{mask}{sep}')
if id - all_ids[i-1] > 2: # something in between
masked_text.append(f'{sep}{mask}{sep}')
masked_text.append(s[id])
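To illustrate the refactor above, a small usage sketch (it assumes `genius_utils.py` from this repo is importable; the generated sketches are illustrative, since output depends on the matched keyword positions):

```python
from genius_utils import SketchExtractor

# English extractors keep the old defaults: mask='<mask>', sep=' '
en = SketchExtractor(model='yake')
s = 'The quick brown fox jumps over the lazy dog'
print(en.get_sketch_from_kws(s, ['brown fox', 'lazy dog']))
# e.g. '<mask> brown fox <mask> lazy dog' (illustrative)

# The Chinese extractor now switches the defaults itself,
# so callers no longer pass mask/sep explicitly.
zh = SketchExtractor(model='jieba')
print(zh.mask, repr(zh.sep))  # [MASK] ''
```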
171 changes: 0 additions & 171 deletions playground.ipynb

This file was deleted.

10 changes: 4 additions & 6 deletions pre_training/genius_pretrain_chinese.py
@@ -3,7 +3,7 @@
nltk.download('stopwords')
nltk.download('punkt')
from nltk.tokenize import sent_tokenize
import transformers
from rouge_chinese import Rouge as RougeChinese
from transformers import BertTokenizer, AutoModel, AutoConfig, AutoModelForSeq2SeqLM
from transformers import Seq2SeqTrainingArguments, Seq2SeqTrainer, DataCollatorForSeq2Seq
from datasets import load_dataset, load_metric
@@ -52,7 +52,7 @@ def preprocess_function(examples):


# ROUGE metric:
rouge_score = load_metric("rouge")
rouge_score = RougeChinese()
def compute_metrics(eval_pred):
predictions, labels = eval_pred
# Decode generated summaries into text
@@ -65,11 +65,9 @@ def compute_metrics(eval_pred):
decoded_preds = ["\n".join(sent_tokenize(pred.strip())) for pred in decoded_preds]
decoded_labels = ["\n".join(sent_tokenize(label.strip())) for label in decoded_labels]
# Compute ROUGE scores
result = rouge_score.compute(
predictions=decoded_preds, references=decoded_labels, use_stemmer=True
)
result = rouge_score.get_scores(decoded_preds, decoded_labels, avg=True)
# Extract the F1 scores
result = {key: value.mid.fmeasure * 100 for key, value in result.items()}
result = {key: value['f'] * 100 for key, value in result.items()}
return {k: round(v, 4) for k, v in result.items()}


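Since `rouge_chinese` scores whitespace-separated tokens, Chinese hypotheses and references should be word-segmented (e.g. with jieba) before calling `get_scores`; with `avg=True` it returns mean precision/recall/F1 per metric. A minimal sketch:

```python
import jieba
from rouge_chinese import Rouge

rouge = Rouge()
# rouge_chinese expects space-separated tokens, so segment with jieba first.
hyp = ' '.join(jieba.cut('今天天气很好'))
ref = ' '.join(jieba.cut('今天的天气非常好'))
scores = rouge.get_scores([hyp], [ref], avg=True)
# scores: {'rouge-1': {'r': ..., 'p': ..., 'f': ...}, 'rouge-2': {...}, 'rouge-l': {...}}
print({k: round(v['f'] * 100, 4) for k, v in scores.items()})
```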
2 changes: 1 addition & 1 deletion pre_training/prepare_genius_pretrain_data_chinese.py
@@ -60,7 +60,7 @@ def add_sketch_to_dataset(examples):
res['text'].append(p)
_, kws = sketch_extractor.get_kws(p, top=max(len(jieba.lcut(p))//5,1))
# we plan to use `fnlp/bart-large-chinese` for pre-training, the mask token is `[MASK]`
sketch = sketch_extractor.get_sketch_from_kws(p, kws, mask='[MASK]',sep='')
sketch = sketch_extractor.get_sketch_from_kws(p, kws)
res['sketch'].append(sketch)
return res

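A minimal sketch of the simplified per-paragraph logic after this change (the paragraph `p` is a stand-in, and the extracted sketch depends on jieba's keyword scores):

```python
import jieba
from genius_utils import SketchExtractor

sketch_extractor = SketchExtractor(model='jieba')
p = '自然语言处理是人工智能领域的一个重要研究方向。'
# Same top-k heuristic as the script: roughly one keyword per five tokens.
_, kws = sketch_extractor.get_kws(p, top=max(len(jieba.lcut(p)) // 5, 1))
sketch = sketch_extractor.get_sketch_from_kws(p, kws)  # picks up mask='[MASK]', sep=''
print(sketch)  # e.g. '[MASK]自然语言处理[MASK]' (illustrative)
```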
