Commit e57b2e7

Commit message: a revision of ODE Transformer

1 parent 323d5d1

File tree

349 files changed: +32310 −2 lines changed


.DS_Store

10 KB
Binary file not shown.

LICENSE

Lines changed: 30 additions & 0 deletions
@@ -0,0 +1,30 @@
BSD License

For fairseq software

Copyright (c) 2017-present, Facebook, Inc. All rights reserved.

Redistribution and use in source and binary forms, with or without modification,
are permitted provided that the following conditions are met:

* Redistributions of source code must retain the above copyright notice, this
list of conditions and the following disclaimer.

* Redistributions in binary form must reproduce the above copyright notice,
this list of conditions and the following disclaimer in the documentation
and/or other materials provided with the distribution.

* Neither the name Facebook nor the names of its contributors may be used to
endorse or promote products derived from this software without specific
prior written permission.

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR
ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

PATENTS

Lines changed: 33 additions & 0 deletions
@@ -0,0 +1,33 @@
Additional Grant of Patent Rights Version 2

"Software" means the fairseq software distributed by Facebook, Inc.

Facebook, Inc. ("Facebook") hereby grants to each recipient of the Software
("you") a perpetual, worldwide, royalty-free, non-exclusive, irrevocable
(subject to the termination provision below) license under any Necessary
Claims, to make, have made, use, sell, offer to sell, import, and otherwise
transfer the Software. For avoidance of doubt, no license is granted under
Facebook's rights in any patent claims that are infringed by (i) modifications
to the Software made by you or any third party or (ii) the Software in
combination with any software or other technology.

The license granted hereunder will terminate, automatically and without notice,
if you (or any of your subsidiaries, corporate affiliates or agents) initiate
directly or indirectly, or take a direct financial interest in, any Patent
Assertion: (i) against Facebook or any of its subsidiaries or corporate
affiliates, (ii) against any party if such Patent Assertion arises in whole or
in part from any software, technology, product or service of Facebook or any of
its subsidiaries or corporate affiliates, or (iii) against any party relating
to the Software. Notwithstanding the foregoing, if Facebook or any of its
subsidiaries or corporate affiliates files a lawsuit alleging patent
infringement against you in the first instance, and you respond by filing a
patent infringement counterclaim in that lawsuit against that party that is
unrelated to the Software, the license granted hereunder will not terminate
under section (i) of this paragraph due to such counterclaim.

A "Necessary Claim" is a claim of a patent owned by Facebook that is
necessarily infringed by the Software standing alone.

A "Patent Assertion" is any lawsuit or other action alleging direct, indirect,
or contributory infringement or inducement to infringe any patent, including a
cross-claim or counterclaim.

README.md

Lines changed: 264 additions & 2 deletions
@@ -1,2 +1,264 @@
- # ODE-Transformer
- This is a code repository for the ACL 2022 paper "ODE Transformer: An Ordinary Differential Equation-Inspired Model for Sequence Generation", which redesigns the Transformer architecture from the ODE perspective, using high-order ODE solvers to enhance the residual connections.

# ODE Transformer: An Ordinary Differential Equation-Inspired Model for Sequence Generation

This code is based on Fairseq v0.6.2.

## Requirements and Installation

- PyTorch version >= 1.2.0
- Python version >= 3.6
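
The README lists requirements but no install commands; a minimal sketch following stock Fairseq v0.6.2's from-source setup, which this fork is assumed to share:

```
# Assumes a PyTorch build >= 1.2.0 is already installed.
pip install -r requirements.txt
# Build and install the bundled Fairseq fork from the repository root.
python setup.py build develop
```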

## Prepare Data

### For Machine Translation

#### 1. Download [WMT'14 En-De](https://drive.google.com/uc?export=download&id=0B_bZck-ksdkpM25jRUN2X2UxMm8) and [WMT'14 En-Fr](https://github.com/pytorch/fairseq/blob/master/examples/translation/prepare-wmt14en2fr.sh)

#### 2. Preprocess the dataset
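
No command is given for this step; a minimal sketch using Fairseq v0.6.2's `preprocess.py`, assuming BPE-tokenized `train`/`valid`/`test` files and a joined source-target dictionary (the file prefixes and destination directory are assumptions):

```
# Binarize the tokenized WMT'14 En-De data into data-bin/wmt-en2de.
python3 preprocess.py --source-lang en --target-lang de \
    --trainpref wmt14_en_de/train --validpref wmt14_en_de/valid --testpref wmt14_en_de/test \
    --destdir data-bin/wmt-en2de \
    --joined-dictionary
```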

### For Abstractive Summarization Task

#### 1. Download the [CNN dataset](https://drive.google.com/uc?export=download&id=0BwmD_VLjROrfTHk4NFg2SndKcjQ) and the [Daily Mail dataset](https://drive.google.com/uc?export=download&id=0BwmD_VLjROrfM1BxdkxVaTY2bWs)

#### 2. Generate the binary dataset ```data-bin/cnndm```

```bash preprocess_cnndaily_bin.sh path/to/cnndm_raw_data```

### For Grammatical Error Correction Task

#### 1. Download the [FCE v2.1 dataset](https://www.cl.cam.ac.uk/research/nl/bea2019st/data/fce_v2.1.bea19.tar.gz), the [Lang-8 Corpus of Learner English dataset](https://docs.google.com/forms/d/e/1FAIpQLSflRX3h5QYxegivjHN7SJ194OxZ4XN_7Rt0cNpR2YbmNV-7Ag/viewform), the [NUCLE dataset](https://sterling8.d2.comp.nus.edu.sg/nucle_download/nucle.php), and the [W&I+LOCNESS v2.1 dataset](https://www.cl.cam.ac.uk/research/nl/bea2019st/data/wi+locness_v2.1.bea19.tar.gz)

#### 2. Get the CONLL14 test set

```bash prepare_conll14_test_data.sh```

#### 3. Preprocess the dataset

```bash preprocess_gec.sh```

#### 4. Generate the binary dataset ```data-bin/BEA```

```bash preprocess_gec_bin.sh```

## Train

### For WMT'14 En-De Task

#### Train an RK2-block $\textrm{learnable}\, \gamma_i$ model (6-layer Big model)

```bash train_wmt_en_de.sh```

```
python3 -u train.py data-bin/$data_dir \
    --distributed-world-size 8 -s src -t tgt \
    --arch transformer_ode_t2t_wmt_en_de_big \
    --share-all-embeddings \
    --optimizer adam --clip-norm 0.0 \
    --adam-betas '(0.9, 0.997)' \
    --lr-scheduler inverse_sqrt --warmup-init-lr 1e-07 --warmup-updates 16000 \
    --lr 0.002 --min-lr 1e-09 \
    --criterion label_smoothed_cross_entropy --label-smoothing 0.1 \
    --max-tokens 4096 \
    --update-freq 4 \
    --max-epoch 20 \
    --dropout 0.3 --attention-dropout 0.1 --relu-dropout 0.1 \
    --no-progress-bar \
    --log-interval 100 \
    --ddp-backend no_c10d \
    --seed 1 \
    --save-dir $save_dir \
    --keep-last-epochs 10
```
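
Since training keeps the last 10 epoch checkpoints, the final model is typically an average of them; a sketch using Fairseq's stock `scripts/average_checkpoints.py`, whose presence in this fork is an assumption:

```
# Average the 10 kept epoch checkpoints into a single model for decoding.
python3 scripts/average_checkpoints.py \
    --inputs $save_dir \
    --num-epoch-checkpoints 10 \
    --output $save_dir/checkpoint.avg10.pt
```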

### For WMT'14 En-Fr Task

#### Train an RK2-block $\textrm{learnable}\, \gamma_i$ model

```bash train_wmt_en_fr.sh```

```
python3 -u train.py data-bin/$data_dir \
    --distributed-world-size 8 -s src -t tgt \
    --arch transformer_ode_t2t_wmt_en_de_big \
    --share-all-embeddings \
    --optimizer adam --clip-norm 0.0 \
    --adam-betas '(0.9, 0.997)' \
    --lr-scheduler inverse_sqrt --warmup-init-lr 1e-07 --warmup-updates 16000 \
    --lr 0.002 --min-lr 1e-09 \
    --criterion label_smoothed_cross_entropy --label-smoothing 0.1 \
    --max-tokens 4096 \
    --update-freq 8 \
    --max-epoch 20 \
    --dropout 0.1 --attention-dropout 0.1 --relu-dropout 0.1 \
    --no-progress-bar \
    --log-interval 100 \
    --ddp-backend no_c10d \
    --seed 1 \
    --save-dir $save_dir \
    --keep-last-epochs 10
```

### For Abstractive Summarization Task

#### Train an RK2-block $\textrm{learnable}\, \gamma_i$ model

```bash train_cnn_daily.sh```

```
python3 -u train.py data-bin/$data_dir \
    --distributed-world-size 8 -s src -t tgt \
    --arch transformer_ode_t2t_wmt_en_de \
    --share-all-embeddings \
    --optimizer adam --clip-norm 0.0 \
    --adam-betas '(0.9, 0.997)' \
    --lr-scheduler inverse_sqrt --warmup-init-lr 1e-07 --warmup-updates 8000 \
    --lr 0.002 --min-lr 1e-09 \
    --weight-decay 0.0001 \
    --criterion label_smoothed_cross_entropy --label-smoothing 0.1 \
    --max-tokens 4096 \
    --update-freq 4 \
    --max-epoch 20 \
    --dropout 0.1 --attention-dropout 0.1 --relu-dropout 0.1 \
    --truncate-source --skip-invalid-size-inputs-valid-test --max-source-positions 500 \
    --no-progress-bar \
    --log-interval 100 \
    --ddp-backend no_c10d \
    --seed 1 \
    --save-dir $save_dir \
    --keep-last-epochs 10
```

### For Grammatical Error Correction Task

#### Train an RK2-block $\textrm{learnable}\, \gamma_i$ model

```bash train_gec.sh```

```
python3 -u train.py data-bin/$data_dir \
    --distributed-world-size 8 -s src -t tgt \
    --arch transformer_ode_t2t_wmt_en_de \
    --share-all-embeddings \
    --optimizer adam --clip-norm 0.0 \
    --adam-betas '(0.9, 0.98)' \
    --lr-scheduler inverse_sqrt --warmup-init-lr 1e-07 --warmup-updates 4000 \
    --lr 0.0015 --min-lr 1e-09 \
    --weight-decay 0.0001 \
    --criterion label_smoothed_cross_entropy --label-smoothing 0.1 \
    --max-tokens 4096 \
    --update-freq 2 \
    --max-epoch 55 \
    --dropout 0.2 --attention-dropout 0.1 --relu-dropout 0.1 \
    --no-progress-bar \
    --log-interval 100 \
    --ddp-backend no_c10d \
    --seed 1 \
    --save-dir $save_dir \
    --keep-last-epochs 10 \
    --tensorboard-logdir $save_dir
```

## Evaluation

### For WMT'14 En-De Task

We measure performance with multi-bleu and sacrebleu.

```
python3 generate.py \
    data-bin/wmt-en2de \
    --path $model_dir/$checkpoint \
    --gen-subset test \
    --batch-size 64 \
    --beam 4 \
    --lenpen 0.6 \
    --output hypo.txt \
    --quiet \
    --remove-bpe
```
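
The scoring commands themselves are not shown; a sketch, where `ref.tok.de` (a tokenized reference) and `hypo.detok.txt` (a detokenized copy of the hypotheses) are assumed file names:

```
# Tokenized BLEU with Moses' multi-bleu.perl.
perl multi-bleu.perl ref.tok.de < hypo.txt
# Detokenized BLEU with sacreBLEU against the official WMT'14 test set.
sacrebleu -t wmt14/full -l en-de < hypo.detok.txt
```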

### For WMT'14 En-Fr Task

We measure performance with multi-bleu and sacrebleu.

```
python3 generate.py \
    data-bin/wmt-en2fr \
    --path $model_dir/$checkpoint \
    --gen-subset test \
    --batch-size 64 \
    --beam 4 \
    --lenpen 0.6 \
    --output hypo.txt \
    --quiet \
    --remove-bpe
```

### For Abstractive Summarization Task

We use pyrouge as the scoring script.

```
python3 generate.py \
    data-bin/$data_dir \
    --path $model_dir/$checkpoint \
    --gen-subset test \
    --truncate-source \
    --batch-size 32 \
    --lenpen 2.0 \
    --min-len 55 \
    --max-len-b 140 \
    --max-source-positions 500 \
    --beam 4 \
    --no-repeat-ngram-size 3 \
    --remove-bpe

python3 get_rouge.py --decodes_filename cnndm.test.target.tok --targets_filename $model_dir/hypo.sorted.tok
```
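
Note that pyrouge only wraps the perl ROUGE-1.5.5 toolkit, which must be installed and registered separately; a setup sketch (the ROUGE path is a placeholder):

```
pip install pyrouge
# Point pyrouge at a local ROUGE-1.5.5 installation.
pyrouge_set_rouge_path /path/to/ROUGE-1.5.5
```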

### For Grammatical Error Correction Task

We use m2scorer as the scoring script.

```
python3 generate.py \
    data-bin/$data_dir \
    --path $model_dir/$checkpoint \
    --gen-subset test \
    --batch-size 64 \
    --beam 4 \
    --lenpen 2.0 \
    --output hypo.txt \
    --quiet \
    --remove-bpe

path/to/m2scorer path/to/model_output path/to/conll14st-test.m2
```
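
The scorer path above is left as a placeholder; a sketch that fetches the official M2 scorer, assuming its usual `m2scorer system_output gold.m2` calling convention:

```
git clone https://github.com/nusnlp/m2scorer.git
m2scorer/m2scorer hypo.txt path/to/conll14st-test.m2
```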

## Results

### Machine Translation

| Model                            | Layer | En-De | En-Fr |
| -------------------------------- | ----- | ----- | ----- |
| Residual-block (baseline)        | 6-6   | 29.21 | 42.89 |
| RK2-block (learnable $\gamma_i$) | 6-6   | 30.53 | 43.59 |
| Residual-block (baseline)        | 12-6  | 29.91 | 43.22 |
| RK2-block (learnable $\gamma_i$) | 12-6  | 30.76 | 44.11 |

### Abstractive Summarization Task

| Model                            | RG-1  | RG-2  | RG-L  |
| -------------------------------- | ----- | ----- | ----- |
| Residual-block                   | 40.47 | 17.73 | 37.29 |
| RK2-block (learnable $\gamma_i$) | 41.58 | 18.57 | 38.41 |
| RK4-block                        | 41.83 | 18.84 | 38.68 |

### Grammatical Error Correction Task

| Model                            | Prec. | Recall | F_0.5 |
| -------------------------------- | ----- | ------ | ----- |
| Residual-block                   | 67.97 | 32.17  | 55.61 |
| RK2-block (learnable $\gamma_i$) | 68.21 | 35.30  | 57.49 |
| RK4-block                        | 66.20 | 38.13  | 57.71 |

docs/Makefile

Lines changed: 20 additions & 0 deletions
@@ -0,0 +1,20 @@

# Minimal makefile for Sphinx documentation
#

# You can set these variables from the command line.
SPHINXOPTS    =
SPHINXBUILD   = python -msphinx
SPHINXPROJ    = fairseq
SOURCEDIR     = .
BUILDDIR      = _build

# Put it first so that "make" without argument is like "make help".
help:
	@$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)

.PHONY: help Makefile

# Catch-all target: route all unknown targets to Sphinx using the new
# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
%: Makefile
	@$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
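
A usage sketch for this Makefile, assuming Sphinx and the Read the Docs theme (whose `.wy-*` classes the CSS override below targets) are installed:

```
cd docs
pip install sphinx sphinx_rtd_theme
make html   # HTML output lands in docs/_build/html
```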

docs/_static/theme_overrides.css

Lines changed: 9 additions & 0 deletions
@@ -0,0 +1,9 @@

.wy-table-responsive table td kbd {
    white-space: nowrap;
}
.wy-table-responsive table td {
    white-space: normal !important;
}
.wy-table-responsive {
    overflow: visible !important;
}
