# Understanding Back-Translation at Scale (Edunov et al., 2018)
This page includes pre-trained models from the paper [Understanding Back-Translation at Scale (Edunov et al., 2018)](https://arxiv.org/abs/1808.09381).
## Pre-trained models
Model | Description | Dataset | Download
---|---|---|---
`transformer.wmt18.en-de` | Transformer <br> ([Edunov et al., 2018](https://arxiv.org/abs/1808.09381)) <br> WMT'18 winner | [WMT'18 English-German](http://www.statmt.org/wmt18/translation-task.html) | [download (.tar.gz)](https://dl.fbaipublicfiles.com/fairseq/models/wmt18.en-de.ensemble.tar.gz) <br> See NOTE in the archive
## Example usage (torch.hub)
We require a few additional Python dependencies for preprocessing:
```bash
pip install subword_nmt sacremoses
```
Then to generate translations from the full model ensemble:
```python
import torch
# List available models
torch.hub.list('pytorch/fairseq') # [..., 'transformer.wmt18.en-de', ... ]
# Load the WMT'18 En-De ensemble
en2de_ensemble = torch.hub.load(
'pytorch/fairseq', 'transformer.wmt18.en-de',
checkpoint_file='wmt18.model1.pt:wmt18.model2.pt:wmt18.model3.pt:wmt18.model4.pt:wmt18.model5.pt',
tokenizer='moses', bpe='subword_nmt')
# The ensemble contains 5 models
len(en2de_ensemble.models)
# 5
# Translate
en2de_ensemble.translate('Hello world!')
# 'Hallo Welt!'
```
## Training your own model (WMT'18 English-German)
The following instructions can be adapted to reproduce the models from the paper.
#### Step 1. Prepare parallel data and optionally train a baseline (English-German) model
First download and preprocess the data:
```bash
# Download and prepare the data
cd examples/backtranslation/
bash prepare-wmt18en2de.sh
cd ../..
# Binarize the data
TEXT=examples/backtranslation/wmt18_en_de
fairseq-preprocess \
--joined-dictionary \
--source-lang en --target-lang de \
--trainpref $TEXT/train --validpref $TEXT/valid --testpref $TEXT/test \
--destdir data-bin/wmt18_en_de --thresholdtgt 0 --thresholdsrc 0 \
--workers 20
# Copy the BPE code into the data-bin directory for future use
cp examples/backtranslation/wmt18_en_de/code data-bin/wmt18_en_de/code
```
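`--joined-dictionary` builds a single vocabulary over both the English and German training files, which `--share-all-embeddings` (used in the training commands below) requires. `--thresholdsrc 0 --thresholdtgt 0` keeps every token. The following is a rough, hypothetical sketch of that counting step in plain Python. It is not fairseq's implementation: special symbols such as `<unk>` and fairseq's vocabulary padding are omitted, and the helper names are made up for illustration.

```python
from collections import Counter


def build_joined_dictionary(src_lines, tgt_lines, threshold=0):
    """Count whitespace-separated (BPE) tokens from BOTH sides into one vocabulary.

    Hypothetical helper: tokens seen fewer than `threshold` times are dropped,
    so threshold=0 keeps everything.
    """
    counts = Counter()
    for line in list(src_lines) + list(tgt_lines):
        counts.update(line.split())
    kept = [(tok, c) for tok, c in counts.items() if c >= threshold]
    # Most frequent first; ties broken alphabetically so the output is deterministic.
    kept.sort(key=lambda item: (-item[1], item[0]))
    return kept


def to_dict_lines(entries):
    # One "<token> <count>" line per entry, the layout used by dict.*.txt files.
    return [f"{tok} {count}" for tok, count in entries]


en = ["Hello world !", "the house"]
de = ["Hallo Welt !", "das Haus"]
vocab = build_joined_dictionary(en, de)
print(to_dict_lines(vocab)[:2])  # ['! 2', 'Hallo 1']
```

Because both languages share one list, `dict.en.txt` and `dict.de.txt` end up identical, so a single embedding matrix can serve the encoder input, decoder input and output projection.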
Optionally, train a baseline English-German model on the parallel data alone:
```bash
CHECKPOINT_DIR=checkpoints_en_de_parallel
fairseq-train --fp16 \
data-bin/wmt18_en_de \
--source-lang en --target-lang de \
--arch transformer_wmt_en_de_big --share-all-embeddings \
--dropout 0.3 --weight-decay 0.0 \
--criterion label_smoothed_cross_entropy --label-smoothing 0.1 \
--optimizer adam --adam-betas '(0.9, 0.98)' --clip-norm 0.0 \
--lr 0.001 --lr-scheduler inverse_sqrt --warmup-updates 4000 \
--max-tokens 3584 --update-freq 16 \
--max-update 30000 \
--save-dir $CHECKPOINT_DIR
# Note: the above command assumes 8 GPUs. Adjust `--update-freq` if you have a
# different number of GPUs.
```
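`--max-tokens` is a per-GPU budget and `--update-freq` accumulates gradients over that many batches before each optimizer step, so the token budget per update scales with `max_tokens * update_freq * num_gpus`. A small sketch, assuming you want to keep that product equal to the 8-GPU setting above (the helper names are hypothetical):

```python
def effective_batch_tokens(max_tokens, update_freq, num_gpus):
    """Upper bound on the token budget consumed by one optimizer update."""
    return max_tokens * update_freq * num_gpus


def update_freq_for(num_gpus, ref_gpus=8, ref_update_freq=16):
    """--update-freq that keeps max_tokens * update_freq * num_gpus unchanged."""
    total = ref_gpus * ref_update_freq
    if total % num_gpus:
        raise ValueError(f"{num_gpus} GPUs cannot exactly match {total} GPU-batches per update")
    return total // num_gpus


print(effective_batch_tokens(3584, 16, 8))  # 458752
print(update_freq_for(4))   # 32
print(update_freq_for(16))  # 8
```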
Average the last 10 checkpoints:
```bash
python scripts/average_checkpoints.py \
--inputs $CHECKPOINT_DIR \
--num-epoch-checkpoints 10 \
--output $CHECKPOINT_DIR/checkpoint.avg10.pt
```
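Checkpoint averaging takes the element-wise mean of each parameter over the most recent epoch checkpoints (`checkpoint<N>.pt`). The sketch below illustrates the idea in pure Python, with lists of floats standing in for PyTorch tensors. The function names are hypothetical and this is not the code in `scripts/average_checkpoints.py`, which additionally loads and saves real checkpoint files.

```python
import re

EPOCH_CKPT = re.compile(r"checkpoint(\d+)\.pt")


def last_n_epoch_checkpoints(filenames, n):
    """Return the n highest-numbered checkpoint<epoch>.pt names, newest first."""
    numbered = []
    for name in filenames:
        match = EPOCH_CKPT.fullmatch(name)
        if match:  # skips checkpoint_best.pt, checkpoint_last.pt, checkpoint.avg10.pt
            numbered.append((int(match.group(1)), name))
    numbered.sort(reverse=True)
    return [name for _, name in numbered[:n]]


def average_params(state_dicts):
    """Element-wise mean of each named parameter across checkpoints."""
    count = len(state_dicts)
    return {
        key: [sum(values) / count for values in zip(*(sd[key] for sd in state_dicts))]
        for key in state_dicts[0]
    }


files = ["checkpoint_best.pt", "checkpoint_last.pt"] + [f"checkpoint{i}.pt" for i in range(1, 13)]
print(last_n_epoch_checkpoints(files, 3))  # ['checkpoint12.pt', 'checkpoint11.pt', 'checkpoint10.pt']
print(average_params([{"w": [1.0, 2.0]}, {"w": [3.0, 6.0]}]))  # {'w': [2.0, 4.0]}
```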
Evaluate BLEU:
```bash
# tokenized BLEU on newstest2017:
bash examples/backtranslation/tokenized_bleu.sh \
wmt17 \
en-de \
data-bin/wmt18_en_de \
data-bin/wmt18_en_de/code \
$CHECKPOINT_DIR/checkpoint.avg10.pt
# BLEU4 = 29.57, 60.9/35.4/22.9/15.5 (BP=1.000, ratio=1.014, syslen=63049, reflen=62152)
# compare to 29.46 in Table 1, which is also for tokenized BLEU
# generally it's better to report (detokenized) sacrebleu though:
bash examples/backtranslation/sacrebleu.sh \
wmt17 \
en-de \
data-bin/wmt18_en_de \
data-bin/wmt18_en_de/code \
$CHECKPOINT_DIR/checkpoint.avg10.pt
# BLEU+case.mixed+lang.en-de+numrefs.1+smooth.exp+test.wmt17+tok.13a+version.1.4.3 = 29.0 60.6/34.7/22.4/14.9 (BP = 1.000 ratio = 1.013 hyp_len = 62099 ref_len = 61287)
```
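The `BLEU4` line contains everything needed to reproduce the score: BLEU-4 is the brevity penalty (BP) times the geometric mean of the 1- to 4-gram precisions. As a sanity check, this sketch recomputes the tokenized score from the rounded precisions printed above (BP=1.000). The small gap to 29.57 comes from the precisions being rounded to one decimal.

```python
import math


def bleu_from_precisions(precisions_pct, brevity_penalty=1.0):
    """BLEU-4 = BP * exp(mean of log n-gram precisions), on a 0-100 scale."""
    log_mean = sum(math.log(p / 100.0) for p in precisions_pct) / len(precisions_pct)
    return 100.0 * brevity_penalty * math.exp(log_mean)


score = bleu_from_precisions([60.9, 35.4, 22.9, 15.5])
print(round(score, 1))  # about 29.6, vs. 29.57 reported
```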
#### Step 2. Back-translate monolingual German data
Train a reverse model (German-English) to do the back-translation:
```bash
CHECKPOINT_DIR=checkpoints_de_en_parallel
fairseq-train --fp16 \
data-bin/wmt18_en_de \
--source-lang de --target-lang en \
--arch transformer_wmt_en_de_big --share-all-embeddings \
--dropout 0.3 --weight-decay 0.0 \
--criterion label_smoothed_cross_entropy --label-smoothing 0.1 \
--optimizer adam --adam-betas '(0.9, 0.98)' --clip-norm 0.0 \
--lr 0.001 --lr-scheduler inverse_sqrt --warmup-updates 4000 \
--max-tokens 3584 --update-freq 16 \
--max-update 30000 \
--save-dir $CHECKPOINT_DIR
# Note: the above command assumes 8 GPUs. Adjust `--update-freq` if you have a
# different number of GPUs.
```
Let's evaluate the back-translation (BT) model to make sure it is well trained:
```bash
bash examples/backtranslation/sacrebleu.sh \
wmt17 \
de-en \
data-bin/wmt18_en_de \
data-bin/wmt18_en_de/code \
  $CHECKPOINT_DIR/checkpoint_best.pt
# BLEU+case.mixed+lang.de-en+numrefs.1+smooth.exp+test.wmt17+tok.13a+version.1.4.3 = 34.9 66.9/41.8/28.5/19.9 (BP = 0.983 ratio = 0.984 hyp_len = 63342 ref_len = 64399)
# compare to the best system from WMT'17 which scored 35.1: http://matrix.statmt.org/matrix/systems_list/1868
```
Next prepare the monolingual data:
```bash
# Download and prepare the monolingual data
# By default the script samples 25M monolingual sentences, which after
# deduplication should be just over 24M sentences. These are split into 25
# shards, each with 1M sentences (except for the last shard).
cd examples/backtranslation/
bash prepare-de-monolingual.sh
cd ../..
# Binarize each shard of the monolingual data
TEXT=examples/backtranslation/wmt18_de_mono
for SHARD in $(seq -f "%02g" 0 24); do \
fairseq-preprocess \
--only-source \
--source-lang de --target-lang en \
--joined-dictionary \
--srcdict data-bin/wmt18_en_de/dict.de.txt \
--testpref $TEXT/bpe.monolingual.dedup.${SHARD} \
--destdir data-bin/wmt18_de_mono/shard${SHARD} \
--workers 20; \
cp data-bin/wmt18_en_de/dict.en.txt data-bin/wmt18_de_mono/shard${SHARD}/; \
done
```
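The deduplicate-then-shard layout described in the comments above can be sketched in a few lines. This is only an illustration of the resulting file layout; `prepare-de-monolingual.sh` may perform these steps differently, and the helper names are hypothetical.

```python
def dedup_and_shard(lines, shard_size=1_000_000):
    """Drop exact duplicate lines (keeping the first occurrence), then split into shards."""
    seen = set()
    unique = []
    for line in lines:
        if line not in seen:
            seen.add(line)
            unique.append(line)
    # Every shard has shard_size lines except possibly the last one.
    return [unique[i:i + shard_size] for i in range(0, len(unique), shard_size)]


def shard_names(num_shards, prefix="bpe.monolingual.dedup."):
    """Zero-padded suffixes matching `seq -f "%02g" 0 24`: 00, 01, ..., 24."""
    return [f"{prefix}{i:02d}" for i in range(num_shards)]


shards = dedup_and_shard(["a", "b", "a", "c", "d", "e"], shard_size=2)
print(shards)               # [['a', 'b'], ['c', 'd'], ['e']]
print(shard_names(25)[-1])  # bpe.monolingual.dedup.24
```

With just over 24M unique sentences and 1M lines per shard, this gives 24 full shards plus a small 25th one, which is why the loops iterate over `00` to `24`.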
Now we're ready to perform back-translation over the monolingual data. The
following command generates translations by unrestricted sampling, but it's
also possible to use greedy decoding (`--beam 1`), beam search (`--beam 5`),
top-k sampling (`--sampling --beam 1 --sampling-topk 10`), etc.:
```bash
mkdir backtranslation_output
for SHARD in $(seq -f "%02g" 0 24); do \
fairseq-generate --fp16 \
data-bin/wmt18_de_mono/shard${SHARD} \
--path $CHECKPOINT_DIR/checkpoint_best.pt \
--skip-invalid-size-inputs-valid-test \
--max-tokens 4096 \
--sampling --beam 1 \
> backtranslation_output/sampling.shard${SHARD}.out; \
done
```
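To make the decoding options concrete, here is a toy, single-step illustration of how greedy selection, unrestricted sampling and top-k sampling pick the next token from a model's distribution. The distribution and helper names are hypothetical; a real decoder repeats this choice at every time step, and beam search (which keeps several partial hypotheses) is not shown.

```python
import random


def greedy(probs):
    """Like --beam 1 without --sampling: always take the most probable token."""
    return max(range(len(probs)), key=lambda i: probs[i])


def sample(probs, rng, topk=None):
    """Like --sampling: draw from the full distribution, or only from the top-k tokens."""
    candidates = list(range(len(probs)))
    if topk is not None:
        candidates = sorted(candidates, key=lambda i: probs[i], reverse=True)[:topk]
    weights = [probs[i] for i in candidates]
    # random.choices does not require the weights to sum to 1, so truncation is fine.
    return rng.choices(candidates, weights=weights, k=1)[0]


probs = [0.5, 0.3, 0.15, 0.05]  # hypothetical next-token distribution
rng = random.Random(0)
print(greedy(probs))  # 0
draws = [sample(probs, rng, topk=2) for _ in range(1000)]
print(set(draws) <= {0, 1})  # True: top-k never picks tokens 2 or 3
```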
After BT, use the `extract_bt_data.py` script to recombine the shards, extract
the back-translations, and apply length and length-ratio filters:
```bash
python examples/backtranslation/extract_bt_data.py \
--minlen 1 --maxlen 250 --ratio 1.5 \
--output backtranslation_output/bt_data --srclang en --tgtlang de \
backtranslation_output/sampling.shard*.out
# Ensure lengths are the same:
# wc -l backtranslation_output/bt_data.{en,de}
# 21795614 backtranslation_output/bt_data.en
# 21795614 backtranslation_output/bt_data.de
# 43591228 total
```
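The extraction step pairs each original German sentence with its synthetic English translation and drops pairs that are too short, too long, or too unbalanced in length. The sketch below is our own approximation, not the actual `extract_bt_data.py`. It assumes `fairseq-generate`'s tab-separated output lines `S-<id>\t<source>` and `H-<id>\t<score>\t<hypothesis>`, and measures length in whitespace-separated tokens.

```python
def parse_generate_output(lines):
    """Yield (english_hypothesis, german_source) pairs from fairseq-generate style output."""
    source = None
    for line in lines:
        fields = line.rstrip("\n").split("\t")
        if fields[0].startswith("S-") and len(fields) > 1:
            source = fields[1]
        elif fields[0].startswith("H-") and len(fields) > 2 and source is not None:
            yield fields[2], source
            source = None  # keep only the first hypothesis per source


def keep_pair(en, de, minlen=1, maxlen=250, ratio=1.5):
    """Approximate --minlen/--maxlen/--ratio filters; empty sides are always dropped."""
    len_en, len_de = len(en.split()), len(de.split())
    shortest, longest = min(len_en, len_de), max(len_en, len_de)
    if shortest < max(minlen, 1) or longest > maxlen:
        return False
    return longest / shortest <= ratio


out = [
    "S-0\tDas ist ein Test .",
    "H-0\t-0.52\tThis is a test .",
    "S-1\tJa",
    "H-1\t-1.10\tYes , of course , that is absolutely right .",
]
pairs = [p for p in parse_generate_output(out) if keep_pair(*p)]
print(pairs)  # [('This is a test .', 'Das ist ein Test .')]
```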
Binarize the filtered BT data and combine it with the parallel data:
```bash
TEXT=backtranslation_output
fairseq-preprocess \
--source-lang en --target-lang de \
--joined-dictionary \
--srcdict data-bin/wmt18_en_de/dict.en.txt \
--trainpref $TEXT/bt_data \
--destdir data-bin/wmt18_en_de_bt \
--workers 20
# We want to train on the combined data, so we'll symlink the parallel + BT data
# in the wmt18_en_de_para_plus_bt directory. We link the parallel data as "train"
# and the BT data as "train1", so that fairseq will combine them automatically
# and so that we can use the `--upsample-primary` option to upsample the
# parallel data (if desired).
PARA_DATA=$(readlink -f data-bin/wmt18_en_de)
BT_DATA=$(readlink -f data-bin/wmt18_en_de_bt)
COMB_DATA=data-bin/wmt18_en_de_para_plus_bt
mkdir -p $COMB_DATA
for LANG in en de; do \
ln -s ${PARA_DATA}/dict.$LANG.txt ${COMB_DATA}/dict.$LANG.txt; \
for EXT in bin idx; do \
ln -s ${PARA_DATA}/train.en-de.$LANG.$EXT ${COMB_DATA}/train.en-de.$LANG.$EXT; \
ln -s ${BT_DATA}/train.en-de.$LANG.$EXT ${COMB_DATA}/train1.en-de.$LANG.$EXT; \
ln -s ${PARA_DATA}/valid.en-de.$LANG.$EXT ${COMB_DATA}/valid.en-de.$LANG.$EXT; \
ln -s ${PARA_DATA}/test.en-de.$LANG.$EXT ${COMB_DATA}/test.en-de.$LANG.$EXT; \
done; \
done
```
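`--upsample-primary` repeats the primary split (`train`, here the parallel data) that many times before it is concatenated with `train1` (the BT data), which shifts the per-epoch mix toward genuine parallel sentences. A small arithmetic sketch, with a hypothetical helper and toy sizes only:

```python
def training_mixture(parallel_size, bt_size, upsample_primary=1):
    """Examples per epoch when 'train' is repeated upsample_primary times and joined with 'train1'."""
    upsampled_parallel = parallel_size * upsample_primary
    total = upsampled_parallel + bt_size
    return {
        "parallel_examples": upsampled_parallel,
        "total_examples": total,
        "parallel_share": upsampled_parallel / total,
    }


# Toy sizes for illustration; substitute `wc -l` counts of your own train files.
print(training_mixture(parallel_size=100, bt_size=400, upsample_primary=4))
# {'parallel_examples': 400, 'total_examples': 800, 'parallel_share': 0.5}
```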
#### Step 3. Train an English-German model over the combined parallel + BT data
Finally we can train a model over the parallel + BT data:
```bash
CHECKPOINT_DIR=checkpoints_en_de_parallel_plus_bt
fairseq-train --fp16 \
data-bin/wmt18_en_de_para_plus_bt \
--upsample-primary 16 \
--source-lang en --target-lang de \
--arch transformer_wmt_en_de_big --share-all-embeddings \
--dropout 0.3 --weight-decay 0.0 \
--criterion label_smoothed_cross_entropy --label-smoothing 0.1 \
--optimizer adam --adam-betas '(0.9, 0.98)' --clip-norm 0.0 \
--lr 0.0007 --lr-scheduler inverse_sqrt --warmup-updates 4000 \
--max-tokens 3584 --update-freq 16 \
--max-update 100000 \
--save-dir $CHECKPOINT_DIR
# Note: the above command assumes 8 GPUs. Adjust `--update-freq` if you have a
# different number of GPUs.
```
Average the last 10 checkpoints:
```bash
python scripts/average_checkpoints.py \
--inputs $CHECKPOINT_DIR \
--num-epoch-checkpoints 10 \
--output $CHECKPOINT_DIR/checkpoint.avg10.pt
```
Evaluate BLEU:
```bash
# tokenized BLEU on newstest2017:
bash examples/backtranslation/tokenized_bleu.sh \
wmt17 \
en-de \
data-bin/wmt18_en_de \
data-bin/wmt18_en_de/code \
$CHECKPOINT_DIR/checkpoint.avg10.pt
# BLEU4 = 32.35, 64.4/38.9/26.2/18.3 (BP=0.977, ratio=0.977, syslen=60729, reflen=62152)
# compare to 32.35 in Table 1, which is also for tokenized BLEU
# generally it's better to report (detokenized) sacrebleu:
bash examples/backtranslation/sacrebleu.sh \
wmt17 \
en-de \
data-bin/wmt18_en_de \
data-bin/wmt18_en_de/code \
$CHECKPOINT_DIR/checkpoint.avg10.pt
# BLEU+case.mixed+lang.en-de+numrefs.1+smooth.exp+test.wmt17+tok.13a+version.1.4.3 = 31.5 64.3/38.2/25.6/17.6 (BP = 0.971 ratio = 0.971 hyp_len = 59515 ref_len = 61287)
```
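Unlike the baseline, this model's output is shorter than the reference, so BLEU is scaled down by the brevity penalty, `exp(1 - ref_len / sys_len)` when the hypothesis is not longer than the reference. This sketch reproduces the BP values in both logs above from their length fields:

```python
import math


def brevity_penalty(sys_len, ref_len):
    """BLEU brevity penalty: 1.0 if the hypothesis is longer than the reference, else exp(1 - r/c)."""
    if sys_len == 0:
        return 0.0
    if sys_len > ref_len:
        return 1.0
    return math.exp(1.0 - ref_len / sys_len)


print(round(brevity_penalty(60729, 62152), 3))  # 0.977 (tokenized BLEU above)
print(round(brevity_penalty(59515, 61287), 3))  # 0.971 (sacreBLEU above)
print(brevity_penalty(63049, 62152))            # 1.0 (Step 1 baseline, output longer than reference)
```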
## Citation
```bibtex
@inproceedings{edunov2018backtranslation,
title = {Understanding Back-Translation at Scale},
author = {Edunov, Sergey and Ott, Myle and Auli, Michael and Grangier, David},
  booktitle = {Conference of Empirical Methods in Natural Language Processing (EMNLP)},
year = 2018,
}
```