Commit 86857a5

kahne authored and facebook-github-bot committed
Levenshtein Transformer paper code
Summary: Code for our NeurIPS paper [Levenshtein Transformer](https://arxiv.org/abs/1905.11006)
* Added Levenshtein Transformer model, task and criterion class
* Added iterative NAT Transformer, insertion Transformer and CMLM Transformer model class for baselines
* Add an option for prepending BOS to dictionary class and translation task class

Reviewed By: myleott

Differential Revision: D17297372

fbshipit-source-id: 54eca60831ae95dc721c2c34e882e1810ee575c7
1 parent 6c1da0f commit 86857a5

25 files changed (+2968, −15 lines)

README.md

Lines changed: 10 additions & 1 deletion

```diff
@@ -6,6 +6,7 @@ modeling and other text generation tasks.
 
 ### What's New:
 
+- September 2019: [Nonautoregressive translation code released](examples/nonautoregressive_translation/README.md)
 - August 2019: [WMT'19 models released](examples/wmt19/README.md)
 - July 2019: fairseq relicensed under MIT license
 - July 2019: [RoBERTa models and code released](examples/roberta/README.md)
@@ -32,6 +33,13 @@ Fairseq provides reference implementations of various sequence-to-sequence model
 - [Mixture Models for Diverse Machine Translation: Tricks of the Trade (Shen et al., 2019)](examples/translation_moe/README.md)
 - [RoBERTa: A Robustly Optimized BERT Pretraining Approach (Liu et al., 2019)](examples/roberta/README.md)
 - [Facebook FAIR's WMT19 News Translation Task Submission (Ng et al., 2019)](examples/wmt19/README.md)
+- **Non-autoregressive Transformers**
+  - Non-Autoregressive Neural Machine Translation (Gu et al., 2017)
+  - Deterministic Non-Autoregressive Neural Sequence Modeling by Iterative Refinement (Lee et al. 2018)
+  - Insertion Transformer: Flexible Sequence Generation via Insertion Operations (Stern et al. 2019)
+  - Mask-Predict: Parallel Decoding of Conditional Masked Language Models (Ghazvininejad et al., 2019)
+  - [Levenshtein Transformer (Gu et al., 2019)](examples/nonautoregressive_translation/README.md)
 
 **Additionally:**
 - multi-GPU (distributed) training on one machine or across multiple machines
@@ -50,7 +58,7 @@ translation and language modeling datasets.
 
 # Requirements and Installation
 
-* [PyTorch](http://pytorch.org/) version >= 1.1.0
+* [PyTorch](http://pytorch.org/) version >= 1.2.0
 * Python version >= 3.5
 * For training new models, you'll also need an NVIDIA GPU and [NCCL](https://github.com/NVIDIA/nccl)
 * **For faster training** install NVIDIA's [apex](https://github.com/NVIDIA/apex) library with the `--cuda_ext` option
@@ -92,6 +100,7 @@ as well as example training and evaluation commands.
 - [Language Modeling](examples/language_model/README.md): convolutional and transformer models are available
 
 We also have more detailed READMEs to reproduce results from specific papers:
+- [Levenshtein Transformer (Gu et al., 2019)](examples/nonautoregressive_translation/README.md)
 - [Facebook FAIR's WMT19 News Translation Task Submission (Ng et al., 2019)](examples/wmt19/README.md)
 - [RoBERTa: A Robustly Optimized BERT Pretraining Approach (Liu et al., 2019)](examples/roberta/README.md)
 - [wav2vec: Unsupervised Pre-training for Speech Recognition (Schneider et al., 2019)](examples/wav2vec/README.md)
```
examples/nonautoregressive_translation/README.md

Lines changed: 90 additions & 0 deletions (new file)
# Non-autoregressive Neural Machine Translation (NAT)

This page mainly includes instructions for reproducing results from the paper
* [Levenshtein Transformer (Gu et al., 2019)](https://arxiv.org/abs/1905.11006).

We also provide our own implementations of several popular non-autoregressive models as references:
* [Non-Autoregressive Neural Machine Translation (Gu et al., 2017)](https://arxiv.org/abs/1711.02281)
* [Deterministic Non-Autoregressive Neural Sequence Modeling by Iterative Refinement (Lee et al., 2018)](https://arxiv.org/abs/1802.06901)
* [Insertion Transformer: Flexible Sequence Generation via Insertion Operations (Stern et al., 2019)](https://arxiv.org/abs/1902.03249)
* [Mask-Predict: Parallel Decoding of Conditional Masked Language Models (Ghazvininejad et al., 2019)](https://arxiv.org/abs/1904.09324v2)

## Dataset

First, follow the [instructions to download and preprocess the WMT'14 En-De dataset](../translation#prepare-wmt14en2desh).
Make sure to learn a joint vocabulary by passing the `--joined-dictionary` option to `fairseq-preprocess`.
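As a minimal sketch of that step (assuming the tokenized, BPE-encoded text lives under a `wmt14_en_de/` directory with `train`/`valid`/`test` prefixes; these paths are placeholders, not part of this repository), the binarization could look like:

```bash
# Sketch only: binarize WMT'14 En-De with a joint (shared) source/target vocabulary.
# The $TEXT prefix below is an assumed location for the preprocessed text files.
TEXT=wmt14_en_de
fairseq-preprocess \
    --source-lang en --target-lang de \
    --trainpref $TEXT/train --validpref $TEXT/valid --testpref $TEXT/test \
    --joined-dictionary \
    --destdir data-bin/wmt14_en_de \
    --workers 20
```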

### Knowledge Distillation
Following [Gu et al. 2019](https://arxiv.org/abs/1905.11006), [knowledge distillation](https://arxiv.org/abs/1606.07947) from an autoregressive model can effectively simplify the training data distribution, which is sometimes essential for NAT-based models to learn good translations.
The easiest way to perform distillation is to follow the [instructions for training a standard transformer model](../translation) on the same data, and then decode the training set to produce a distillation dataset for NAT, as sketched below.
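As an illustration only (assuming a trained autoregressive teacher at the hypothetical path `checkpoints_at/checkpoint_best.pt`), decoding the training split and collecting the hypotheses in their original order might look roughly like this:

```bash
# Sketch only: decode the training split with an autoregressive teacher and keep
# the hypotheses as distilled targets. The checkpoint path and file names here
# are assumptions for illustration.
fairseq-generate \
    data-bin/wmt14_en_de \
    --gen-subset train \
    --path checkpoints_at/checkpoint_best.pt \
    --beam 4 --max-tokens 8000 > train.gen.out

# H-* lines hold the hypotheses; restore the original sentence order and write
# out the (still BPE-encoded) distilled target side.
grep ^H train.gen.out | sed 's/^H-//' | sort -n -k 1 | cut -f 3 > train.distill.de
```

The distilled targets can then be paired with the original source sentences and binarized with `fairseq-preprocess` as above.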

### Download
We also provide the preprocessed [original](http://dl.fbaipublicfiles.com/nat/original_dataset.zip) and [distillation](http://dl.fbaipublicfiles.com/nat/distill_dataset.zip) datasets. Please build the binarized dataset on your own.
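For example, a rough sketch for the distilled data (the file prefixes inside the archive are an assumption; adjust them to whatever the unpacked zip actually contains):

```bash
# Sketch only: fetch the distilled text data and binarize it with a joint vocabulary.
# The unpacked file prefixes are assumptions about the archive layout.
wget http://dl.fbaipublicfiles.com/nat/distill_dataset.zip
unzip distill_dataset.zip -d distill_dataset

fairseq-preprocess \
    --source-lang en --target-lang de \
    --trainpref distill_dataset/train --validpref distill_dataset/valid --testpref distill_dataset/test \
    --joined-dictionary \
    --destdir data-bin/wmt14_en_de_distill \
    --workers 20
```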

## Train a model

Then we can train a non-autoregressive model using the `translation_lev` task and a new criterion `nat_loss`.
Use the `--noise` flag to specify the input noise applied to the target sentences.
By default, we run the task for the *Levenshtein Transformer*, with `--noise='random_delete'`. Full scripts to run the other models can also be found [here](./scripts.md).

The following command will train a *Levenshtein Transformer* on the binarized dataset.

```bash
fairseq-train \
    data-bin/wmt14_en_de_distill \
    --save-dir checkpoints \
    --ddp-backend=no_c10d \
    --task translation_lev \
    --criterion nat_loss \
    --arch levenshtein_transformer \
    --noise random_delete \
    --share-all-embeddings \
    --optimizer adam --adam-betas '(0.9,0.98)' \
    --lr 0.0005 --lr-scheduler inverse_sqrt \
    --min-lr '1e-09' --warmup-updates 10000 \
    --warmup-init-lr '1e-07' --label-smoothing 0.1 \
    --dropout 0.3 --weight-decay 0.01 \
    --decoder-learned-pos \
    --encoder-learned-pos \
    --apply-bert-init \
    --log-format 'simple' --log-interval 100 \
    --fixed-validation-seed 7 \
    --max-tokens 8000 \
    --save-interval-updates 10000 \
    --max-update 300000
```

## Translate

Once a model is trained, we can generate translations using the `iterative_refinement_generator`, which starts from the model's initial output and iteratively refines the translation until (1) the model predicts the same translation for two consecutive iterations, or (2) the generator reaches the maximum number of iterations (`--iter-decode-max-iter`). Use `--print-step` to check the actual number of iterations for each sentence.

For the *Levenshtein Transformer*, it sometimes helps to apply an `--iter-decode-eos-penalty` (typically 0–3) to penalize the model for finishing generation too early and producing translations that are too short.

For example, to generate with `--iter-decode-max-iter=9`:
```bash
fairseq-generate \
    data-bin/wmt14_en_de_distill \
    --gen-subset test \
    --task translation_lev \
    --path checkpoints/checkpoint_best.pt \
    --iter-decode-max-iter 9 \
    --iter-decode-eos-penalty 0 \
    --beam 1 --remove-bpe \
    --print-step \
    --batch-size 400
```
At the end of generation, we can see the tokenized BLEU score for the translation.

## Citation

```bibtex
@article{gu2019levenshtein,
  title={Levenshtein Transformer},
  author={Gu, Jiatao and Wang, Changhan and Zhao, Jake},
  journal={arXiv preprint arXiv:1905.11006},
  year={2019}
}
```
examples/nonautoregressive_translation/scripts.md

Lines changed: 148 additions & 0 deletions (new file)
# Examples of Training Scripts for Non-autoregressive Machine Translation Models

### Non-autoregressive Transformer (NAT, Gu et al., 2017)
Note that we need an additional module to perform length prediction (weighted by `--length-loss-factor`) before generating the whole sequence.
```bash
fairseq-train \
    data-bin/wmt14_en_de_distill \
    --save-dir checkpoints \
    --ddp-backend=no_c10d \
    --task translation_lev \
    --criterion nat_loss \
    --arch nonautoregressive_transformer \
    --noise full_mask \
    --share-all-embeddings \
    --optimizer adam --adam-betas '(0.9,0.98)' \
    --lr 0.0005 --lr-scheduler inverse_sqrt \
    --min-lr '1e-09' --warmup-updates 10000 \
    --warmup-init-lr '1e-07' --label-smoothing 0.1 \
    --dropout 0.3 --weight-decay 0.01 \
    --decoder-learned-pos \
    --encoder-learned-pos \
    --pred-length-offset \
    --length-loss-factor 0.1 \
    --apply-bert-init \
    --log-format 'simple' --log-interval 100 \
    --fixed-validation-seed 7 \
    --max-tokens 8000 \
    --save-interval-updates 10000 \
    --max-update 300000
```

### Non-autoregressive Transformer with Iterative Refinement (iNAT, Lee et al., 2018)
Note that `--train-step` sets the number of refinement iterations used during training, and `--dae-ratio` controls the ratio of denoising auto-encoder training described in the original paper.
```bash
fairseq-train \
    data-bin/wmt14_en_de_distill \
    --save-dir checkpoints \
    --ddp-backend=no_c10d \
    --task translation_lev \
    --criterion nat_loss \
    --arch iterative_nonautoregressive_transformer \
    --noise full_mask \
    --share-all-embeddings \
    --optimizer adam --adam-betas '(0.9,0.98)' \
    --lr 0.0005 --lr-scheduler inverse_sqrt \
    --min-lr '1e-09' --warmup-updates 10000 \
    --warmup-init-lr '1e-07' --label-smoothing 0.1 \
    --dropout 0.3 --weight-decay 0.01 \
    --decoder-learned-pos \
    --encoder-learned-pos \
    --pred-length-offset \
    --length-loss-factor 0.1 \
    --train-step 4 \
    --dae-ratio 0.5 \
    --stochastic-approx \
    --apply-bert-init \
    --log-format 'simple' --log-interval 100 \
    --fixed-validation-seed 7 \
    --max-tokens 8000 \
    --save-interval-updates 10000 \
    --max-update 300000
```

### Insertion Transformer (InsT, Stern et al., 2019)
Note that we need to specify the "slot-loss" (uniform or balanced tree) described in the original paper. Here we use `--label-tau` to control the temperature.

```bash
fairseq-train \
    data-bin/wmt14_en_de_distill \
    --save-dir checkpoints \
    --ddp-backend=no_c10d \
    --task translation_lev \
    --criterion nat_loss \
    --arch insertion_transformer \
    --noise random_delete \
    --share-all-embeddings \
    --optimizer adam --adam-betas '(0.9,0.98)' \
    --lr 0.0005 --lr-scheduler inverse_sqrt \
    --min-lr '1e-09' --warmup-updates 10000 \
    --warmup-init-lr '1e-07' --label-smoothing 0.1 \
    --dropout 0.3 --weight-decay 0.01 \
    --decoder-learned-pos \
    --encoder-learned-pos \
    --pred-length-offset \
    --length-loss-factor 0.1 \
    --apply-bert-init \
    --log-format 'simple' --log-interval 100 \
    --fixed-validation-seed 7 \
    --max-tokens 8000 \
    --save-interval-updates 10000 \
    --max-update 300000
```

### Mask-Predict (CMLM, Ghazvininejad et al., 2019)
```bash
fairseq-train \
    data-bin/wmt14_en_de_distill \
    --save-dir checkpoints \
    --ddp-backend=no_c10d \
    --task translation_lev \
    --criterion nat_loss \
    --arch cmlm_transformer \
    --noise random_mask \
    --share-all-embeddings \
    --optimizer adam --adam-betas '(0.9,0.98)' \
    --lr 0.0005 --lr-scheduler inverse_sqrt \
    --min-lr '1e-09' --warmup-updates 10000 \
    --warmup-init-lr '1e-07' --label-smoothing 0.1 \
    --dropout 0.3 --weight-decay 0.01 \
    --decoder-learned-pos \
    --encoder-learned-pos \
    --apply-bert-init \
    --log-format 'simple' --log-interval 100 \
    --fixed-validation-seed 7 \
    --max-tokens 8000 \
    --save-interval-updates 10000 \
    --max-update 300000
```

### Levenshtein Transformer (LevT, Gu et al., 2019)
```bash
fairseq-train \
    data-bin/wmt14_en_de_distill \
    --save-dir checkpoints \
    --ddp-backend=no_c10d \
    --task translation_lev \
    --criterion nat_loss \
    --arch levenshtein_transformer \
    --noise random_delete \
    --share-all-embeddings \
    --optimizer adam --adam-betas '(0.9,0.98)' \
    --lr 0.0005 --lr-scheduler inverse_sqrt \
    --min-lr '1e-09' --warmup-updates 10000 \
    --warmup-init-lr '1e-07' --label-smoothing 0.1 \
    --dropout 0.3 --weight-decay 0.01 \
    --decoder-learned-pos \
    --encoder-learned-pos \
    --apply-bert-init \
    --log-format 'simple' --log-interval 100 \
    --fixed-validation-seed 7 \
    --max-tokens 8000 \
    --save-interval-updates 10000 \
    --max-update 300000
```
