# ODE Transformer: An Ordinary Differential Equation-Inspired Model for Sequence Generation

This is the code repository for the ACL 2022 paper "ODE Transformer: An Ordinary Differential Equation-Inspired Model for Sequence Generation", which redesigns the Transformer architecture from the ODE perspective by using high-order ODE solvers to enhance the residual connections.

This code is based on Fairseq v0.6.2.
## Requirements and Installation

- PyTorch version >= 1.2.0
- Python version >= 3.6
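A minimal installation sketch, assuming this repository follows the standard Fairseq v0.6.2 setup (run from the repository root after installing a matching PyTorch build):

```
# install dependencies, then build Fairseq in development mode
pip install -r requirements.txt
python setup.py build develop
```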

## Prepare Data

### For Machine Translation

#### 1. Download [WMT14' En-De](https://drive.google.com/uc?export=download&id=0B_bZck-ksdkpM25jRUN2X2UxMm8) and [WMT14' En-Fr](https://github.com/pytorch/fairseq/blob/master/examples/translation/prepare-wmt14en2fr.sh)

#### 2. Preprocess the dataset
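A preprocessing sketch, assuming Fairseq v0.6.2's `preprocess.py` and BPE-tokenized files named `train/valid/test.{en,de}` (the file prefixes and the `data-bin/wmt-en2de` destination are placeholders chosen to match the evaluation commands below):

```
# binarize the tokenized data with a joined source/target dictionary
python3 preprocess.py --source-lang en --target-lang de \
  --trainpref train --validpref valid --testpref test \
  --destdir data-bin/wmt-en2de \
  --joined-dictionary --workers 8
```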

### For Abstractive Summarization Task

#### 1. Download the [CNN dataset](https://drive.google.com/uc?export=download&id=0BwmD_VLjROrfTHk4NFg2SndKcjQ) and [Daily Mail dataset](https://drive.google.com/uc?export=download&id=0BwmD_VLjROrfM1BxdkxVaTY2bWs)

#### 2. Generate the binary dataset ```data-bin/cnndm```

```bash preprocess_cnndaily_bin.sh path/to/cnndm_raw_data```

### For Grammatical Error Correction Task

#### 1. Download the [FCE v2.1 dataset](https://www.cl.cam.ac.uk/research/nl/bea2019st/data/fce_v2.1.bea19.tar.gz), [Lang-8 Corpus of Learner English dataset](https://docs.google.com/forms/d/e/1FAIpQLSflRX3h5QYxegivjHN7SJ194OxZ4XN_7Rt0cNpR2YbmNV-7Ag/viewform), [NUCLE dataset](https://sterling8.d2.comp.nus.edu.sg/nucle_download/nucle.php), and [W&I+LOCNESS v2.1 dataset](https://www.cl.cam.ac.uk/research/nl/bea2019st/data/wi+locness_v2.1.bea19.tar.gz)

#### 2. Get the CoNLL-14 test set

```bash prepare_conll14_test_data.sh```

#### 3. Preprocess the dataset

```bash preprocess_gec.sh```

#### 4. Generate the binary dataset ```data-bin/BEA```

```bash preprocess_gec_bin.sh```

## Train

### For WMT'14 En-De Task

#### Train an RK2-block (learnable $\gamma_i$) model (6-layer Big model)
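For reference, a sketch of the update computed by an RK2 block with learnable coefficients, based on the paper's description (here $F$ is the layer function, $y_t$ the input of block $t$, and $\gamma_1, \gamma_2$ are learnable scalars):

$$F_1 = F(y_t, \theta_t), \qquad F_2 = F(y_t + F_1, \theta_t), \qquad y_{t+1} = y_t + \gamma_1 F_1 + \gamma_2 F_2$$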

```bash train_wmt_en_de.sh```

```
python3 -u train.py data-bin/$data_dir \
  --distributed-world-size 8 -s src -t tgt \
  --arch transformer_ode_t2t_wmt_en_de_big \
  --share-all-embeddings \
  --optimizer adam --clip-norm 0.0 \
  --adam-betas '(0.9, 0.997)' \
  --lr-scheduler inverse_sqrt --warmup-init-lr 1e-07 --warmup-updates 16000 \
  --lr 0.002 --min-lr 1e-09 \
  --criterion label_smoothed_cross_entropy --label-smoothing 0.1 \
  --max-tokens 4096 \
  --update-freq 4 \
  --max-epoch 20 \
  --dropout 0.3 --attention-dropout 0.1 --relu-dropout 0.1 \
  --no-progress-bar \
  --log-interval 100 \
  --ddp-backend no_c10d \
  --seed 1 \
  --save-dir $save_dir \
  --keep-last-epochs 10
```

### For WMT'14 En-Fr Task

#### Train an RK2-block (learnable $\gamma_i$) model

```bash train_wmt_en_fr.sh```

```
python3 -u train.py data-bin/$data_dir \
  --distributed-world-size 8 -s src -t tgt \
  --arch transformer_ode_t2t_wmt_en_de_big \
  --share-all-embeddings \
  --optimizer adam --clip-norm 0.0 \
  --adam-betas '(0.9, 0.997)' \
  --lr-scheduler inverse_sqrt --warmup-init-lr 1e-07 --warmup-updates 16000 \
  --lr 0.002 --min-lr 1e-09 \
  --criterion label_smoothed_cross_entropy --label-smoothing 0.1 \
  --max-tokens 4096 \
  --update-freq 8 \
  --max-epoch 20 \
  --dropout 0.1 --attention-dropout 0.1 --relu-dropout 0.1 \
  --no-progress-bar \
  --log-interval 100 \
  --ddp-backend no_c10d \
  --seed 1 \
  --save-dir $save_dir \
  --keep-last-epochs 10
```

### For Abstractive Summarization Task

#### Train an RK2-block (learnable $\gamma_i$) model

```bash train_cnn_daily.sh```

```
python3 -u train.py data-bin/$data_dir \
  --distributed-world-size 8 -s src -t tgt \
  --arch transformer_ode_t2t_wmt_en_de \
  --share-all-embeddings \
  --optimizer adam --clip-norm 0.0 \
  --adam-betas '(0.9, 0.997)' \
  --lr-scheduler inverse_sqrt --warmup-init-lr 1e-07 --warmup-updates 8000 \
  --lr 0.002 --min-lr 1e-09 \
  --weight-decay 0.0001 \
  --criterion label_smoothed_cross_entropy --label-smoothing 0.1 \
  --max-tokens 4096 \
  --update-freq 4 \
  --max-epoch 20 \
  --dropout 0.1 --attention-dropout 0.1 --relu-dropout 0.1 \
  --truncate-source --skip-invalid-size-inputs-valid-test --max-source-positions 500 \
  --no-progress-bar \
  --log-interval 100 \
  --ddp-backend no_c10d \
  --seed 1 \
  --save-dir $save_dir \
  --keep-last-epochs 10
```

### For Grammatical Error Correction Task

#### Train an RK2-block (learnable $\gamma_i$) model

```bash train_gec.sh```

```
python3 -u train.py data-bin/$data_dir \
  --distributed-world-size 8 -s src -t tgt \
  --arch transformer_ode_t2t_wmt_en_de \
  --share-all-embeddings \
  --optimizer adam --clip-norm 0.0 \
  --adam-betas '(0.9, 0.98)' \
  --lr-scheduler inverse_sqrt --warmup-init-lr 1e-07 --warmup-updates 4000 \
  --lr 0.0015 --min-lr 1e-09 \
  --weight-decay 0.0001 \
  --criterion label_smoothed_cross_entropy --label-smoothing 0.1 \
  --max-tokens 4096 \
  --update-freq 2 \
  --max-epoch 55 \
  --dropout 0.2 --attention-dropout 0.1 --relu-dropout 0.1 \
  --no-progress-bar \
  --log-interval 100 \
  --ddp-backend no_c10d \
  --seed 1 \
  --save-dir $save_dir \
  --keep-last-epochs 10 \
  --tensorboard-logdir $save_dir
```

## Evaluation

### For WMT'14 En-De Task

We measure performance with multi-bleu and sacreBLEU.

```
python3 generate.py \
data-bin/wmt-en2de \
--path $model_dir/$checkpoint \
--gen-subset test \
--batch-size 64 \
--beam 4 \
--lenpen 0.6 \
--output hypo.txt \
--quiet \
--remove-bpe
```
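A possible scoring recipe (a sketch: `hypo.txt` comes from the command above, `ref.txt` and `hypo.detok.txt` are placeholders for the tokenized reference and the detokenized hypotheses, and `multi-bleu.perl` is the Moses script):

```
# tokenized BLEU with Moses' multi-bleu.perl
perl multi-bleu.perl ref.txt < hypo.txt

# detokenized BLEU against the official WMT14 reference
sacrebleu -t wmt14 -l en-de < hypo.detok.txt
```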

### For WMT'14 En-Fr Task

We measure performance with multi-bleu and sacreBLEU.

```
python3 generate.py \
data-bin/wmt-en2fr \
--path $model_dir/$checkpoint \
--gen-subset test \
--batch-size 64 \
--beam 4 \
--lenpen 0.6 \
--output hypo.txt \
--quiet \
--remove-bpe
```

### For Abstractive Summarization Task

We use pyrouge as the scoring script.

```
python3 generate.py \
data-bin/$data_dir \
--path $model_dir/$checkpoint \
--gen-subset test \
--truncate-source \
--batch-size 32 \
--lenpen 2.0 \
--min-len 55 \
--max-len-b 140 \
--max-source-positions 500 \
--beam 4 \
--no-repeat-ngram-size 3 \
--remove-bpe

python3 get_rouge.py --decodes_filename $model_dir/hypo.sorted.tok --targets_filename cnndm.test.target.tok
```
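`generate.py` prints hypotheses out of test-set order; a common recipe to recover the sorted, tokenized hypotheses (a sketch, assuming the generation log was redirected to a placeholder file `gen.out` and the default `H-<id>` output format):

```
# keep hypothesis lines, sort them by sentence id, and keep only the text column
grep ^H- gen.out | sed 's/^H-//' | sort -n | cut -f3 > $model_dir/hypo.sorted.tok
```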

### For Grammatical Error Correction Task

We use m2scorer as the scoring script.

```
python3 generate.py \
data-bin/$data_dir \
--path $model_dir/$checkpoint \
--gen-subset test \
--batch-size 64 \
--beam 4 \
--lenpen 2.0 \
--output hypo.txt \
--quiet \
--remove-bpe

path/to/m2scorer path/to/model_output path/to/conll14st-test.m2
```

## Results

### Machine Translation

| Model                            | Layer | En-De | En-Fr |
| -------------------------------- | ----- | ----- | ----- |
| Residual-block (baseline)        | 6-6   | 29.21 | 42.89 |
| RK2-block (learnable $\gamma_i$) | 6-6   | 30.53 | 43.59 |
| Residual-block (baseline)        | 12-6  | 29.91 | 43.22 |
| RK2-block (learnable $\gamma_i$) | 12-6  | 30.76 | 44.11 |

### Abstractive Summarization Task

| Model                            | RG-1  | RG-2  | RG-L  |
| -------------------------------- | ----- | ----- | ----- |
| Residual-block                   | 40.47 | 17.73 | 37.29 |
| RK2-block (learnable $\gamma_i$) | 41.58 | 18.57 | 38.41 |
| RK4-block                        | 41.83 | 18.84 | 38.68 |

### Grammatical Error Correction Task

| Model                            | Prec. | Recall | $F_{0.5}$ |
| -------------------------------- | ----- | ------ | --------- |
| Residual-block                   | 67.97 | 32.17  | 55.61     |
| RK2-block (learnable $\gamma_i$) | 68.21 | 35.30  | 57.49     |
| RK4-block                        | 66.20 | 38.13  | 57.71     |