# Truncated Backpropagation Through Time (BPTT)
Truncated BPTT is a useful technique for training language models on very long
sequences. Typically, a long sequence is split into chunks and a language model
is trained over the chunks sequentially. The LM may condition on previous
chunks, but gradients only flow through the current chunk. This technique was
the basis for the paper [Transformer-XL: Attentive Language Models Beyond a
Fixed-Length Context](https://arxiv.org/abs/1901.02860), which achieved
state-of-the-art language modeling results at the time of publication.

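
As a rough illustration of the chunking idea, the hypothetical shell function below walks a sequence of `NUM_TOKENS` tokens in consecutive chunks of `CHUNK_LEN` tokens (the role played by `--tokens-per-sample`) and, for each chunk, reports which earlier tokens fall inside a memory window of `MEM_LEN` tokens (compare `--mem-len`). It only prints index ranges; the function name and its arguments are illustrative assumptions, not part of fairseq.

```bash
#!/usr/bin/env bash
# Hypothetical illustration of truncated BPTT chunking (index ranges only).
# show_chunks NUM_TOKENS CHUNK_LEN MEM_LEN
show_chunks() {
  local num_tokens=$1 chunk_len=$2 mem_len=$3
  local start=0 end mem_start
  while (( start < num_tokens )); do
    end=$(( start + chunk_len ))
    if (( end > num_tokens )); then
      end=$num_tokens              # the final chunk may be shorter
    fi
    mem_start=$(( start - mem_len ))
    if (( mem_start < 0 )); then
      mem_start=0                  # no memory before the first chunk
    fi
    # Gradients flow only through [start, end); [mem_start, start) is
    # conditioned on but treated as a constant (no backprop into it).
    echo "chunk [$start, $end) memory [$mem_start, $start)"
    start=$end
  done
}

show_chunks 10 4 4
```

For example, `show_chunks 10 4 4` prints three chunks; the last one covers only tokens [8, 10) and uses tokens [4, 8) as memory.
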
It is slightly tricky to implement Truncated BPTT efficiently in fairseq, since
we need to iterate over the data sequentially and disable any batch shuffling
logic. The code provided in this example illustrates how to implement Truncated
BPTT in fairseq by overriding ``FairseqTask::get_batch_iterator`` to iterate
over the data sequentially. Crucially, this example supports batching and
multi-GPU (data parallel) training.
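
To see how sequential iteration can coexist with batching and data-parallel training, here is a hypothetical sketch of one common scheme (an assumption for illustration, not copied from the fairseq task code): the token stream is cut into `BSZ_PER_GPU x WORLD_SIZE` contiguous rows, each GPU rank owns `BSZ_PER_GPU` of those rows, and step *i* reads the next `TGT_LEN` tokens of every row. Because each row always continues where it stopped at the previous step, the memory carried over for that row stays valid, which is why batches must not be shuffled.

```bash
#!/usr/bin/env bash
# Hypothetical sketch: sequential batches for truncated BPTT with data parallelism.
# show_batches NUM_TOKENS BSZ_PER_GPU WORLD_SIZE RANK TGT_LEN
show_batches() {
  local num_tokens=$1 bsz=$2 world_size=$3 rank=$4 tgt_len=$5
  # Total number of rows (parallel streams) across all GPUs.
  local global_bsz=$(( bsz * world_size ))
  # Each row is a contiguous slice; leftover tail tokens are dropped.
  local row_len=$(( num_tokens / global_bsz ))
  # Number of steps needed to cover one row (ceiling division).
  local num_steps=$(( (row_len + tgt_len - 1) / tgt_len ))
  local step row start end row_end
  for (( step = 0; step < num_steps; step++ )); do
    # This rank only reads its own block of rows.
    for (( row = rank * bsz; row < (rank + 1) * bsz; row++ )); do
      start=$(( row * row_len + step * tgt_len ))
      end=$(( start + tgt_len ))
      row_end=$(( (row + 1) * row_len ))
      if (( end > row_end )); then
        end=$row_end               # last chunk of a row may be shorter
      fi
      echo "rank $rank step $step row $row: tokens [$start, $end)"
    done
  done
}

show_batches 24 2 2 0 4
```

Here `show_batches 24 2 2 0 4` lists what rank 0 reads when 24 tokens are split into 4 rows of 6 tokens and each step advances 4 tokens: row 0 covers tokens [0, 4) and then [4, 6), while rank 1 independently reads rows 2 and 3.
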
##### 0. Setup
First, see the general [language modeling README](README.md) for instructions on
preprocessing the WikiText-103 data.
##### 1. Train a Transformer-XL model on WikiText-103
We will train a 16-layer Transformer-XL model following the [hyperparameters
used in the original
paper](https://github.com/kimiyoung/transformer-xl/blob/master/pytorch/run_wt103_base.sh).
The following command assumes 4 GPUs, for a total batch size of 60 sequences
(15 per GPU x 4 GPUs). Training should take ~24 hours on 4 V100 GPUs:
```bash
CUDA_VISIBLE_DEVICES=0,1,2,3 fairseq-train \
--user-dir examples/truncated_bptt \
data-bin/wikitext-103/ \
--task truncated_bptt_lm --tokens-per-sample 150 \
--batch-size 15 --max-update 200000 \
--arch transformer_xl --n-layer 16 --d-model 410 --n-head 10 \
--d-head 41 --d-inner 2100 --dropout 0.1 --dropatt 0.0 --mem-len 150 \
--optimizer adam --clip-norm 0.25 \
--lr-scheduler cosine --warmup-updates 0 --min-lr 0.0 --lr 0.00025 \
--log-format json --log-interval 25 \
--fp16
```
If training on a single GPU, set `--update-freq=4` to accumulate gradients over
4 batches before each update and simulate training on 4 GPUs.
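
The batch-size arithmetic above can be checked with a few lines of shell. The helper `effective_batch` is a hypothetical name; it multiplies the per-GPU `--batch-size` by the number of GPUs and by `--update-freq` (default 1), and the last line multiplies the result by `--tokens-per-sample` to get the number of target tokens per update.

```bash
#!/usr/bin/env bash
# effective_batch BATCH_SIZE NUM_GPUS UPDATE_FREQ
# Sequences contributing to one optimizer update (hypothetical helper).
effective_batch() {
  echo $(( $1 * $2 * $3 ))
}

effective_batch 15 4 1   # 4 GPUs, default --update-freq: prints 60
effective_batch 15 1 4   # 1 GPU with --update-freq 4: prints 60
# Tokens per update = sequences per update x --tokens-per-sample (150)
echo $(( $(effective_batch 15 4 1) * 150 ))   # prints 9000
```

Both configurations give the same 60 sequences per update; the single-GPU run just takes roughly four times as long per update.
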
##### 2. Evaluate
```bash
fairseq-eval-lm data-bin/wikitext-103/ \
--path checkpoints/checkpoint_best.pt \
--user-dir examples/truncated_bptt/ \
--task truncated_bptt_lm \
--batch-size 1 --required-batch-size-multiple 1 \
--model-overrides '{"mem_len":640,"clamp_len":400,"same_length":True}' \
--tokens-per-sample 64
# ... | INFO | fairseq_cli.eval_lm | num. model params: 151123537
# ... | INFO | fairseq_cli.eval_lm | Evaluated 245569 tokens in 83.1s (2956.82 tokens/s)
# ... | INFO | fairseq_cli.eval_lm | Loss (base 2): 4.5668, Perplexity: 23.70
# Compare to 24.0 test perplexity from the paper
```
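
The reported perplexity is the base-2 loss exponentiated, i.e. perplexity = 2^loss. A quick check with `awk` (POSIX awk supports `^` for exponentiation) reproduces the logged perplexity from the logged loss; the function name is a hypothetical helper.

```bash
#!/usr/bin/env bash
# ppl_from_base2_loss LOSS: perplexity = 2 ^ (average loss in bits per token)
ppl_from_base2_loss() {
  awk -v loss="$1" 'BEGIN { printf "%.2f\n", 2 ^ loss }'
}

ppl_from_base2_loss 4.5668   # prints 23.70, matching the eval-lm log line
```
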
*Note:* During training the model saw 150 tokens of context
(``--tokens-per-sample=150``) and 150 extra memory tokens (``--mem-len=150``).
During evaluation we measure perplexity on sequences of 64 tokens
(``--tokens-per-sample=64``) and increase the memory length to 640 tokens
(``--model-overrides='{"mem_len":640}'``). These settings match the evaluation
settings from [the original
paper](https://github.com/kimiyoung/transformer-xl/blob/master/pytorch/run_wt103_base.sh).
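
Counting context the same way as this note (tokens in the current segment plus cached memory tokens), the training and evaluation settings compare as follows. This is plain arithmetic on the flag values; the helper name is hypothetical, and the count ignores the effect of `clamp_len` and `same_length`.

```bash
#!/usr/bin/env bash
# context_tokens TOKENS_PER_SAMPLE MEM_LEN  (hypothetical helper)
context_tokens() {
  echo $(( $1 + $2 ))
}

context_tokens 150 150   # training:   --tokens-per-sample 150, --mem-len 150 -> 300
context_tokens 64 640    # evaluation: --tokens-per-sample 64, mem_len 640    -> 704
```

By this count, evaluation uses shorter segments but more than twice the context available during training.
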