hello Saturday
beyondguo committed Nov 19, 2022
1 parent 3f8e0cf commit e29dc32
Showing 6 changed files with 55 additions and 50 deletions.
19 changes: 12 additions & 7 deletions README.md
@@ -1,21 +1,19 @@
# 💡GENIUS: Generating text using sketches as input
# 💡GENIUS – generating text using sketches!

**A sketch-based text generation model**

- Paper: [GENIUS: Sketch-based Language Model Pre-training via Extreme and Selective Masking for Text Generation and Augmentation](https://github.com/beyondguo/genius/blob/master/GENIUS_gby_arxiv.pdf)
- **Paper: [GENIUS: Sketch-based Language Model Pre-training via Extreme and Selective Masking for Text Generation and Augmentation](https://github.com/beyondguo/genius/blob/master/GENIUS_gby_arxiv.pdf)**

💡**GENIUS** is a powerful conditional text generation model using sketches as input, which can fill in the missing contexts for a given **sketch** (key information consisting of textual spans, phrases, or words, concatenated by mask tokens). GENIUS is pre-trained on a large-scale textual corpus with a novel *reconstruction from sketch* objective using an *extreme and selective masking* strategy, enabling it to generate diverse and high-quality texts given sketches.

**Example 1:**

- sketch: `__machine learning__my research interest__data science__`
- BART: `The machine learning aspect of my research interest in data science.`
- sketch: `__ machine learning __ my research interest __ data science __`
- **GENIUS**: `I am a Ph.D. student in machine learning, and my research interest is in data science. I am interested in understanding how humans and machines interact and how we can improve the quality of life for people around the world.`

**Example 2:**

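- sketch: `自然语言处理__谷歌__通用人工智能__` (English: `natural language processing __ Google __ artificial general intelligence __`)
- BART: `自然语言处理是谷歌的通用人工智能技术` (English: `Natural language processing is Google's artificial general intelligence technology`)
- **GENIUS**: `自然语言处理是谷歌在通用人工智能领域的一个重要研究方向,其目的是为了促进人类智能的发展。 ` (English: `Natural language processing is one of Google's important research directions in the field of artificial general intelligence, and its aim is to promote the development of human intelligence.`)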
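
To make the sketch format in the examples above concrete, here is a minimal illustration of joining key phrases with a mask token. The helper name `build_sketch` is hypothetical and is not part of this repository, and the default `<mask>` token is an assumption (the README examples render the mask as `__`). The repository's own `SketchExtractor` is not reproduced here.

```python
def build_sketch(phrases, mask_token="<mask>"):
    """Join key phrases with a mask token, placing a mask at both ends.

    Hypothetical helper for illustration only; not the repository's API.
    """
    # Drop empty phrases and surrounding whitespace.
    cleaned = [p.strip() for p in phrases if p.strip()]
    if not cleaned:
        # Nothing to keep: the whole text is masked.
        return mask_token
    return f"{mask_token} " + f" {mask_token} ".join(cleaned) + f" {mask_token}"


# Reproduce the layout of Example 1, using "__" as the visual mask.
sketch = build_sketch(
    ["machine learning", "my research interest", "data science"],
    mask_token="__",
)
print(sketch)
```

A string of this shape is what the model is asked to fill in; which mask token a given checkpoint expects should be checked against its tokenizer.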


@@ -33,13 +31,20 @@

| Model | #params | Language | comment|
|------------------------|--------------------------------|-------|---------|
| [`genius-large`](https://huggingface.co/beyond/genius-large) | 406M | English | The version used in paper |
| [`genius-large`](https://huggingface.co/beyond/genius-large) | 406M | English | The version used in the **paper** (recommended) |
| [`genius-large-k2t`](https://huggingface.co/beyond/genius-large-k2t) | 406M | English | keywords-to-text |
| [`genius-base`](https://huggingface.co/beyond/genius-base) | 139M | English | smaller version |
| [`genius-base-ps`](https://huggingface.co/beyond/genius-base) | 139M | English | pre-trained on both paragraphs and short sentences |
| [`genius-base-chinese`](https://huggingface.co/beyond/genius-base-chinese) | 116M | Chinese | pre-trained on 10 million clean Chinese paragraphs |

<img src="https://cdn.jsdelivr.net/gh/beyondguo/mdnice_pictures/typora/sega-hf-api.jpg" width="50%" />
![image-20221119191940969](https://cdn.jsdelivr.net/gh/beyondguo/mdnice_pictures/typora/202211191919005.png)




More Examples:

![image-20221119184950762](https://cdn.jsdelivr.net/gh/beyondguo/mdnice_pictures/typora/202211191849815.png)

## Usage

File renamed without changes.
20 changes: 10 additions & 10 deletions augmentation_clf/README.md
@@ -52,34 +52,34 @@ python conditional_clm_finetune.py --clm_model_path gpt2 --dataset_name sst2new_
python conditional_clm_clf.py --dataset_name sst2new_50 --clm_model_path ../saved_models/c-clm/sst2new_50_gpt2
```
SEGA (Sketch-based Generative Augmentation, Ours):
GeniusAug (Sketch-based Generative Augmentation, Ours):
```shell
python sega_clf.py \
python genius_clf.py \
--dataset_name ng_50 \
--sega_model_path ../saved_models/bart-base-c4-realnewslike-4templates-passage-and-sent-max15sents_2-sketch4/checkpoint-215625 \
--genius_model_path beyond/genius-large \
--template 4 \
--sega_version sega-base-t4 \
--genius_version genius-base-t4 \
--n_aug 4 \
--add_prompt
```
SEGA fine-tune on downstream datasets:
GeniusAug fine-tune on downstream datasets:
```shell
CUDA_VISIBLE_DEVICES=0 python sega_finetune.py \
CUDA_VISIBLE_DEVICES=0 python genius_finetune.py \
--dataset_name yahooA10k_200 \
--checkpoint ../saved_models/bart-large-c4-l_50_200-d_13799838-yake_mask-t_3900800/checkpoint-152375 \
--checkpoint beyond/genius-large \
--max_num_sent 5 \
--num_train_epochs 10 \
--batch_size 16
```
SEGA-mixup
GeniusAug-mixup
```shell
python sega_mixup_clf.py \
python genius_mixup_clf.py \
--dataset_name imdb_50 \
--max_ngram 3 \
--sketch_n_kws 15 \
--extract_global_kws \
--sega_version sega-mixup \
--genius_version genius-mixup \
--n_aug 4
```
26 changes: 13 additions & 13 deletions augmentation_clf/sega_clf.py → augmentation_clf/genius_clf.py
@@ -1,35 +1,35 @@
"""
SEGA (Sketch-based Generative Augmentation) for Classification
GeniusAug for Classification
example script:
python sega_clf.py \
python genius_clf.py \
--dataset_name imdb_50 \
--sega_model_path ../saved_models/bart-base-c4-realnewslike-4templates-passage-and-sent-max15sents_2-sketch4/checkpoint-215625 \
--genius_model_path beyond/genius-large \
--template 4 \
--add_prompt \
--sega_version sega-base-t4 \
--genius_version genius-base-t4 \
--n_aug 2
"""
import sys
sys.path.append('../')
from transformers import pipeline
import pandas as pd
from sega_utils import SketchExtractor, List2Dataset, setup_seed, get_stopwords
from genius_utils import SketchExtractor, List2Dataset, setup_seed, get_stopwords
setup_seed(5)
from tqdm import tqdm
import argparse
parser = argparse.ArgumentParser(allow_abbrev=False)
parser.add_argument('--dataset_name', type=str, default='bbc_50', help='dataset dir name')
parser.add_argument('--sega_model_path', type=str, default=None, help='sega model path')
parser.add_argument('--genius_model_path', type=str, default=None, help='genius model path')
parser.add_argument('--template', type=int, help='1,2,3,4')
parser.add_argument('--max_ngram', type=int,default=3, help='3 for normal passages. If the text is too short, can be set smaller')
parser.add_argument('--aspect_only', action='store_true',default=False, help='')
parser.add_argument('--no_aspect', action='store_true',default=False, help='')
parser.add_argument('--max_length', type=int, default=200, help='')
parser.add_argument('--add_prompt', action='store_true', default=True, help='if set, will prepend label prefix to sketches')
parser.add_argument('--n_aug', type=int, default=1, help='how many times to augment')
parser.add_argument('--sega_version', type=str, help='to custom output filename')
parser.add_argument('--genius_version', type=str, help='to custom output filename')
parser.add_argument('--device', type=int, default=0, help='cuda device index, if not found, will switch to cpu')
args = parser.parse_args()
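
One thing to note about the `--add_prompt` argument above: with `action='store_true'` and `default=True`, the value is `True` whether or not the flag is passed, so there is no way to switch it off from the command line. A possible alternative, shown only as a sketch and not as the author's intended behavior, is `argparse.BooleanOptionalAction` (available in Python 3.9 and later), which also generates a `--no-add_prompt` form.

```python
import argparse

parser = argparse.ArgumentParser(allow_abbrev=False)
# BooleanOptionalAction registers both --add_prompt and --no-add_prompt.
parser.add_argument(
    '--add_prompt',
    action=argparse.BooleanOptionalAction,
    default=True,
    help='prepend label prefix to sketches (use --no-add_prompt to disable)',
)

# Parse explicit argument lists instead of sys.argv to show both cases.
default_args = parser.parse_args([])
disabled_args = parser.parse_args(['--no-add_prompt'])
print(default_args.add_prompt, disabled_args.add_prompt)
```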

@@ -47,12 +47,12 @@
label2desc = {label:label for label in set(labels)}
print(label2desc)

if args.sega_model_path is not None:
checkpoint = args.sega_model_path
if args.genius_model_path is not None:
checkpoint = args.genius_model_path
else:
checkpoint = ''
print('sega checkpoint:', checkpoint)
sega = pipeline('text2text-generation', model=checkpoint, device=args.device, framework='pt')
print('genius checkpoint:', checkpoint)
genius = pipeline('text2text-generation', model=checkpoint, device=args.device, framework='pt')


sketcher = SketchExtractor(model='bert')
@@ -92,7 +92,7 @@ def my_topk(text):
print('Generating new samples...')
new_contents = []
for _ in range(args.n_aug):
for out in tqdm(sega(
for out in tqdm(genius(
sketch_dataset, num_beams=3, do_sample=True,
num_return_sequences=1, max_length=args.max_length,
batch_size=50, truncation=True)):
@@ -111,7 +111,7 @@ def my_topk(text):
corresponding_sketches = ['ORIGINAL-SAMPLE'] * len(labels) + all_sketches
assert len(augmented_contents) == len(augmented_labels), 'wrong num'
assert len(augmented_contents) == len(corresponding_sketches), 'wrong num'
args.output_name = f"sega_prompt{args.add_prompt}_asonly_{args.aspect_only}_{args.sega_version}_aug{args.n_aug}"
args.output_name = f"genius_prompt{args.add_prompt}_asonly_{args.aspect_only}_{args.genius_version}_aug{args.n_aug}"
augmented_dataset = pd.DataFrame({'content':augmented_contents, 'sketch':corresponding_sketches, 'label':augmented_labels})
augmented_dataset.to_csv(f'../data_clf/{args.dataset_name}/{args.output_name}.csv')
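
The CSV written above has `content`, `sketch`, and `label` columns, and original samples carry the sketch value `ORIGINAL-SAMPLE`. The following is a minimal sketch of separating original rows from generated ones when reading such a file back. The sample rows and the `split_original_and_generated` helper are made up for illustration, and the standard `csv` module is used instead of pandas only to keep the example dependency-free.

```python
import csv
import io

# Sketch value the script above assigns to the original (non-generated) samples.
ORIGINAL_MARKER = "ORIGINAL-SAMPLE"


def split_original_and_generated(csv_file):
    """Return (original_rows, generated_rows) from an augmented CSV.

    Hypothetical helper for illustration; not part of the repository.
    """
    original, generated = [], []
    for row in csv.DictReader(csv_file):
        target = original if row["sketch"] == ORIGINAL_MARKER else generated
        target.append(row)
    return original, generated


# Made-up rows in the same column layout; pandas writes the index
# as an unnamed first column, hence the leading comma in the header.
sample = io.StringIO(
    ",content,sketch,label\n"
    "0,A great film.,ORIGINAL-SAMPLE,pos\n"
    "1,A great film with a moving story.,<mask> great film <mask>,pos\n"
)
original_rows, generated_rows = split_original_and_generated(sample)
print(len(original_rows), len(generated_rows))
```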

@@ -1,18 +1,18 @@
"""
SEGA Fine-tuning on target dataset
GENIUS Fine-tuning on target dataset
To make SEGA better suited for downstream tasks, we finetune the model with prompts.
To make GENIUS better suited for downstream tasks, we finetune the model with prompts.
- cut the original text into smaller chunks (less than 15 sentences)
- extract the label-aware sketch from the chunk, and prepend the corresponding label prefix to the sketch
- prepend the corresponding label prefix to each chunk
- prompt + sketch --> SEGA --> prompt + content
- prompt + sketch --> GENIUS --> prompt + content
example script:
CUDA_VISIBLE_DEVICES=0
python sega_finetune.py \
python genius_finetune.py \
--dataset_name bbc_50 \
--checkpoint ../saved_models/bart-large-c4-l_50_200-d_13799838-yake_mask-t_3900800/checkpoint-152375 \
--checkpoint beyond/genius-large \
--max_num_sent 15 \
--num_train_epochs 10 \
--batch_size 16
@@ -28,7 +28,7 @@
nltk.download('stopwords')
nltk.download('punkt')
from nltk.tokenize import sent_tokenize
from sega_utils import SketchExtractor, List2Dataset, setup_seed, get_stopwords
from genius_utils import SketchExtractor, List2Dataset, setup_seed, get_stopwords
from rouge_score import rouge_scorer
from datasets import load_metric
import pandas as pd
@@ -40,7 +40,7 @@
import argparse
parser = argparse.ArgumentParser(allow_abbrev=False)
parser.add_argument('--dataset_name', type=str, default='bbc_50', help='dataset dir name')
parser.add_argument('--checkpoint', type=str, default='', help='sega checkpoint')
parser.add_argument('--checkpoint', type=str, default='', help='genius checkpoint')
parser.add_argument('--aspect_only', action='store_true', default=False, help='')
parser.add_argument('--template', type=int, default=4, help='')
parser.add_argument('--num_train_epochs', type=int, default=10, help='train epochs')
@@ -164,7 +164,7 @@ def compute_metrics(eval_pred):
##################################################################


output_dir = f"../saved_models/sega_finetuned_for_{args.dataset_name}{args.comment}"
output_dir = f"../saved_models/genius_finetuned_for_{args.dataset_name}{args.comment}"

training_args = Seq2SeqTrainingArguments(
output_dir=output_dir,
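
The fine-tuning docstring at the top of this file describes cutting each text into chunks of limited sentence count and prepending a label prefix. A rough sketch of those two steps follows. The function names, the naive regex sentence splitter (a stand-in for the `sent_tokenize` call the script imports from nltk), and the `label: text` prefix format are all assumptions made for illustration. Sketch extraction is not shown.

```python
import re


def split_sentences(text):
    """Naive splitter on '.', '!' and '?'; a stand-in for nltk's sent_tokenize."""
    return [s for s in re.split(r"(?<=[.!?])\s+", text.strip()) if s]


def chunk_text(text, max_num_sent=15):
    """Group consecutive sentences into chunks of at most max_num_sent sentences."""
    sents = split_sentences(text)
    return [
        " ".join(sents[i:i + max_num_sent])
        for i in range(0, len(sents), max_num_sent)
    ]


def add_label_prefix(label, text):
    """Prepend a label prefix; the 'label: text' format is an assumption."""
    return f"{label}: {text}"


chunks = chunk_text("One. Two! Three? Four.", max_num_sent=2)
prompted = [add_label_prefix("sport", chunk) for chunk in chunks]
print(prompted)
```

In the described procedure the same prefix would go on both the sketch (model input) and the chunk (target), so the model learns to map prompt + sketch to prompt + content.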
@@ -1,13 +1,13 @@
"""
SEGA-mixup for Classification
GeniusAug-mixup for Classification
example script:
python sega_mixup_clf.py \
python genius_mixup_clf.py \
--dataset_name imdb_50 \
--max_ngram 3 \
--sketch_n_kws 15 \
--extract_global_kws \
--sega_version sega-mixup \
--genius_version genius-mixup \
--n_aug 4
"""
@@ -17,13 +17,13 @@
from collections import defaultdict
import pandas as pd
import random
from sega_utils import SketchExtractor, List2Dataset, setup_seed, get_stopwords
from genius_utils import SketchExtractor, List2Dataset, setup_seed, get_stopwords
setup_seed(5)
from tqdm import tqdm
import argparse
parser = argparse.ArgumentParser(allow_abbrev=False)
parser.add_argument('--dataset_name', type=str, default='yahooA10k_50', help='dataset dir name')
parser.add_argument('--sega_model_path', type=str, default=None, help='sega model path')
parser.add_argument('--genius_model_path', type=str, default=None, help='genius model path')
parser.add_argument('--template', type=int, default=4, help='1,2,3,4')
parser.add_argument('--max_ngram', type=int,default=3, help='3 for normal passages. If the text is too short, can be set smaller')
parser.add_argument('--sketch_n_kws', type=int,default=15, help='how many keywords to form the sketch')
@@ -32,7 +32,7 @@
parser.add_argument('--add_prompt', action='store_true', default=True, help='if set, will prepend label prefix to sketches')
parser.add_argument('--max_length', type=int, default=200, help='')
parser.add_argument('--n_aug', type=int, default=1, help='how many times to augment')
parser.add_argument('--sega_version', type=str, default='sega-mixup', help='to custom output filename')
parser.add_argument('--genius_version', type=str, default='genius-mixup', help='to custom output filename')
parser.add_argument('--device', type=int, default=0, help='cuda device index, if not found, will switch to cpu')
args = parser.parse_args()

@@ -50,14 +50,14 @@
label2desc = {label:label for label in set(labels)}
print(label2desc)

if args.sega_model_path is not None:
checkpoint = args.sega_model_path
if args.genius_model_path is not None:
checkpoint = args.genius_model_path
else:
# checkpoint = '../saved_models/bart-base-c4-realnewslike-4templates-passage-and-sent-max15sents_2-sketch4/checkpoint-215625'
checkpoint = '../saved_models/bart-large-c4-l_50_200-d_13799838-yake_mask-t_3900800/checkpoint-152375'

print('sega checkpoint:', checkpoint)
sega = pipeline('text2text-generation', model=checkpoint, device=args.device)
print('genius checkpoint:', checkpoint)
genius = pipeline('text2text-generation', model=checkpoint, device=args.device)



@@ -128,7 +128,7 @@ def contain_alpha(string):


print('Generating new samples...')
for out in tqdm(sega(
for out in tqdm(genius(
sketch_dataset, num_beams=3, do_sample=True,
num_return_sequences=1, max_length=args.max_length,
batch_size=32, truncation=True)):
@@ -146,7 +146,7 @@ def contain_alpha(string):
corresponding_sketches = ['ORIGINAL-SAMPLE'] * len(labels) + mixup_sketches
assert len(augmented_contents) == len(augmented_labels), 'wrong num'
assert len(augmented_contents) == len(corresponding_sketches), 'wrong num'
args.output_name = f"segaMix_prompt{args.add_prompt}_asonly_{args.aspect_only}_{args.sega_version}_aug{args.n_aug}"
args.output_name = f"geniusMix_prompt{args.add_prompt}_asonly_{args.aspect_only}_{args.genius_version}_aug{args.n_aug}"
augmented_dataset = pd.DataFrame({'content':augmented_contents, 'sketch':corresponding_sketches, 'label':augmented_labels})
augmented_dataset.to_csv(f'../data_clf/{args.dataset_name}/{args.output_name}.csv')
