-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
2 changed files
with
173 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1,63 @@ | ||
# On-grammar-improvements-of-GPT-2-generation | ||
# On-grammar-improvements-of-GPT-2-generation | ||
|
||
This repository includes the complete code for the paper XXXX. The purpose of this repository is to make experiments reproducable and give advanced insights into the experiments that were conducted. | ||
|
||
<img src="architecture.png" alt="drawing" width="200"/> | ||
|
||
## Getting Started | ||
|
||
This repository includes the following things: | ||
|
||
- Documentation of the Dataset building process | ||
- Finetuning, Grammar Correction and Generation scripts that were used during this research project | ||
- Documentation of the complete evaluation process | ||
- A mountain of generated samples that was used during evaluation | ||
- Documentation of the model combination evaluation | ||
- Documentation of generating samples referenced in our paper | ||
|
||
|
||
## Build on | ||
|
||
* [HuggingFace -Transformers](https://github.com/huggingface/transformers) | ||
|
||
### Installation | ||
|
||
To install dependencies simply run | ||
|
||
``` | ||
pip install -r requirements.txt | ||
``` | ||
You should be good to go. | ||
|
||
|
||
|
||
|
||
## The following external resources should be added in order to retrace all steps: | ||
|
||
- LAMBADA data files should be extracted to the LAMBADA folder (downloadable from https://wiki.cimec.unitn.it/tiki-index.php?page=CLIC) | ||
- The GPT-2 generation datasets should be extracted to the original_data folder (downloadable from https://github.com/openai/gpt-2-output-dataset). For the purpose of this paper, only the small-117M datasets are needed. | ||
- The model checkpoints are needed to retrace everything. If needed they are available by contacting Gideon-Stein. | ||
|
||
|
||
## Generation script usage: | ||
``` | ||
python transgenerator_translation.py --model_path=../trained_models/the_dream_final_3/checkpoint-257616/pytorch_model.bin --text_path ../build_data/EOS_new_filter_700.txt --n_data 1000 --save_path the_dream_filter_700_3_1.p | ||
python run_generation_edited.py --model_name_or_path=model_save/only_correctedAll/pytorch_model.bin --save_name oootf | ||
``` | ||
Parameters can be added and changed accordingly to the script. | ||
|
||
## Finetuning script usage: | ||
``` | ||
python run_lm_finetuning_frozen_3.py --output_dir=model_save/the_dream_finetune_3 --model_type=gpt2 --model_name_or_path=gpt2 --do_train --train_data_file=classic_finetune_train.txt --per_gpu_train_batch_size 1 --gradient_accumulation_steps 4 --save_steps 41599 --save_total_limit 20 --num_train_epochs 20 | ||
python run_generation_edited.py --model_name_or_path=model_save/the_dream_classic_finetune_2/first/checkpoint-41599/pytorch_model.bin --save_name generate_cf_1 --max_length 1024 --n_sentences 100 | ||
``` | ||
|
||
|
||
|
||
|
||
## Authors | ||
|
||
* **Gideon Stein** - *Initial work* - [Github](https://github.com/Gideon-Stein) | ||
|
||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,110 @@ | ||
absl-py==0.10.0 | ||
argon2-cffi==20.1.0 | ||
astunparse==1.6.3 | ||
async-generator==1.10 | ||
attrs==20.2.0 | ||
backcall==0.2.0 | ||
bleach==3.2.1 | ||
boto3==1.16.51 | ||
botocore==1.19.51 | ||
cachetools==4.1.1 | ||
certifi==2020.6.20 | ||
cffi==1.14.3 | ||
chardet==3.0.4 | ||
click==7.1.2 | ||
colorama==0.4.3 | ||
cycler==0.10.0 | ||
decorator==4.4.2 | ||
defusedxml==0.6.0 | ||
entrypoints==0.3 | ||
filelock==3.0.12 | ||
future==0.18.2 | ||
gast==0.3.3 | ||
google-auth==1.22.1 | ||
google-auth-oauthlib==0.4.1 | ||
google-pasta==0.2.0 | ||
grpcio==1.32.0 | ||
h5py==2.10.0 | ||
idna==2.10 | ||
importlib-metadata==2.0.0 | ||
ipykernel==5.3.4 | ||
ipython==7.18.1 | ||
ipython-genutils==0.2.0 | ||
ipywidgets==7.5.1 | ||
jedi==0.17.2 | ||
Jinja2==2.11.2 | ||
jmespath==0.10.0 | ||
joblib==0.17.0 | ||
json5==0.9.5 | ||
jsonschema==3.2.0 | ||
jupyter-client==6.1.7 | ||
jupyter-core==4.6.3 | ||
jupyterlab==2.2.8 | ||
jupyterlab-pygments==0.1.2 | ||
jupyterlab-server==1.2.0 | ||
Keras-Preprocessing==1.1.2 | ||
kiwisolver==1.3.1 | ||
language-check==1.1 | ||
Markdown==3.2.2 | ||
MarkupSafe==1.1.1 | ||
matplotlib==3.3.3 | ||
mistune==0.8.4 | ||
nbclient==0.5.0 | ||
nbconvert==6.0.7 | ||
nbformat==5.0.7 | ||
nest-asyncio==1.4.1 | ||
notebook==6.1.4 | ||
numpy==1.18.5 | ||
oauthlib==3.1.0 | ||
opt-einsum==3.3.0 | ||
packaging==20.4 | ||
pandas==1.1.5 | ||
pandocfilters==1.4.2 | ||
parso==0.7.1 | ||
pickleshare==0.7.5 | ||
Pillow==7.2.0 | ||
prometheus-client==0.8.0 | ||
prompt-toolkit==3.0.7 | ||
protobuf==3.13.0 | ||
pyasn1==0.4.8 | ||
pyasn1-modules==0.2.8 | ||
pycparser==2.20 | ||
Pygments==2.7.1 | ||
pyparsing==2.4.7 | ||
pyrsistent==0.17.3 | ||
python-dateutil==2.8.1 | ||
pytz==2020.4 | ||
pywin32==228 | ||
pywinpty==0.5.7 | ||
pyzmq==19.0.2 | ||
regex==2020.9.27 | ||
requests==2.24.0 | ||
requests-oauthlib==1.3.0 | ||
rsa==4.6 | ||
s3transfer==0.3.3 | ||
sacremoses==0.0.43 | ||
Send2Trash==1.5.0 | ||
sentencepiece==0.1.91 | ||
six==1.15.0 | ||
tensorboard==2.3.0 | ||
tensorboard-plugin-wit==1.7.0 | ||
tensorflow==2.3.1 | ||
tensorflow-estimator==2.3.0 | ||
termcolor==1.1.0 | ||
terminado==0.9.1 | ||
testpath==0.4.4 | ||
tokenizers==0.5.2 | ||
torch==1.6.0+cu101 | ||
torchtext==0.6.0 | ||
torchvision==0.7.0+cu101 | ||
tornado==6.0.4 | ||
tqdm==4.50.0 | ||
traitlets==5.0.4 | ||
transformers==2.8.0 | ||
urllib3==1.25.10 | ||
wcwidth==0.2.5 | ||
webencodings==0.5.1 | ||
Werkzeug==1.0.1 | ||
widgetsnbextension==3.5.1 | ||
wrapt==1.12.1 | ||
zipp==3.3.0 |