# Truncated Backpropagation Through Time (BPTT)
Truncated BPTT is a useful technique for training language models on very long
sequences. Typically a long sequence is split into chunks and a language model
is trained over the chunks sequentially. The LM may condition on previous
chunks, but gradients only flow through the current chunk. This technique was
the basis for the paper [Transformer-XL: Attentive Language Models Beyond a
Fixed-Length Context](https://arxiv.org/abs/1901.02860), which achieved
state-of-the-art language modeling results at the time of publication.
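
To make the mechanics concrete, here is a minimal PyTorch sketch of truncated
BPTT using a toy LSTM language model (the model, chunk size, and random data
are placeholders, not the fairseq/Transformer-XL implementation): the
recurrent state is carried across chunks so the model conditions on earlier
context, but it is detached at each chunk boundary so gradients never flow
into previous chunks.

```python
# Minimal sketch of truncated BPTT (illustrative only; not the fairseq code).
import torch
import torch.nn as nn

vocab_size, chunk_size = 1000, 150
embed = nn.Embedding(vocab_size, 32)
model = nn.LSTM(input_size=32, hidden_size=64, batch_first=True)
proj = nn.Linear(64, vocab_size)
params = list(embed.parameters()) + list(model.parameters()) + list(proj.parameters())
optimizer = torch.optim.Adam(params)
criterion = nn.CrossEntropyLoss()

# One long token stream; batch size 1 to keep the sketch small.
stream = torch.randint(0, vocab_size, (1, 10 * chunk_size + 1))
state = None  # recurrent state carried across chunks

for i in range(0, stream.size(1) - 1, chunk_size):
    inputs = stream[:, i : i + chunk_size]
    targets = stream[:, i + 1 : i + chunk_size + 1]

    optimizer.zero_grad()
    output, state = model(embed(inputs), state)
    loss = criterion(proj(output).reshape(-1, vocab_size), targets.reshape(-1))
    loss.backward()   # gradients flow only through the current chunk
    optimizer.step()

    # Detach the carried state: the next chunk still conditions on it,
    # but backpropagation stops at this chunk boundary.
    state = tuple(s.detach() for s in state)
```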
It is slightly tricky to implement Truncated BPTT efficiently in fairseq, since
we need to iterate over the data sequentially and disable any batch shuffling
logic. The code provided in this example illustrates how to implement Truncated
BPTT in fairseq by overriding ``FairseqTask::get_batch_iterator`` to iterate
over the data sequentially. Crucially, this example supports batching and
multi-GPU (data parallel) training.
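
The batching logic behind such a sequential iterator can be illustrated
without any fairseq internals. The sketch below is illustrative only (the
helper `make_sequential_batches` is a placeholder name, not a fairseq API): it
splits a flat token stream into `batch_size` contiguous sub-streams and yields
consecutive fixed-length chunks, so each batch continues exactly where the
previous one left off; with data-parallel training, each worker would keep a
disjoint shard of the sub-streams.

```python
# Illustrative sketch of sequential (non-shuffled) batching for truncated BPTT.
import torch


def make_sequential_batches(tokens, batch_size, tokens_per_sample,
                            num_shards=1, shard_id=0):
    """Yield (inputs, targets) batches that traverse the corpus in order.

    `tokens` is a 1-D LongTensor holding the whole concatenated corpus. Each
    of the `batch_size` rows is a contiguous slice of the corpus, so chunk
    t+1 of a row directly follows chunk t of the same row.
    """
    # Trim the stream so it divides evenly into batch_size contiguous rows.
    n = (tokens.size(0) - 1) // batch_size * batch_size
    inputs = tokens[:n].view(batch_size, -1)        # (batch, stream_len)
    targets = tokens[1:n + 1].view(batch_size, -1)  # shifted by one token

    # Data-parallel sharding: each worker keeps a disjoint subset of rows.
    inputs = inputs[shard_id::num_shards]
    targets = targets[shard_id::num_shards]

    # Yield consecutive chunks; no shuffling, so recurrent memory stays valid.
    for start in range(0, inputs.size(1), tokens_per_sample):
        yield (inputs[:, start:start + tokens_per_sample],
               targets[:, start:start + tokens_per_sample])


# Example usage with a toy corpus of 101 "tokens".
corpus = torch.arange(0, 101)
for x, y in make_sequential_batches(corpus, batch_size=5, tokens_per_sample=10):
    print(x.shape, y.shape)  # torch.Size([5, 10]) for each of the two chunks
```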
##### 0. Setup
First, see the general [language modeling README](README.md) for instructions on
preprocessing the WikiText-103 data.
##### 1. Train a Transformer-XL model on WikiText-103
We will train a 16-layer Transformer-XL model following the [hyperparameters
used in the original
paper](https://github.com/kimiyoung/transformer-xl/blob/master/pytorch/run_wt103_base.sh).
The following command assumes 4 GPUs, so that the total batch size is 60
sequences (15 x 4). Training should take ~24 hours on 4 V100 GPUs:
```bash
CUDA_VISIBLE_DEVICES=0,1,2,3 fairseq-train \
    --user-dir examples/truncated_bptt \
    data-bin/wikitext-103/ \
    --task truncated_bptt_lm --tokens-per-sample 150 \
    --batch-size 15 --max-update 200000 \
    --arch transformer_xl --n-layer 16 --d-model 410 --n-head 10 \
    --d-head 41 --d-inner 2100 --dropout 0.1 --dropatt 0.0 --mem-len 150 \
    --optimizer adam --clip-norm 0.25 \
    --lr-scheduler cosine --warmup-updates 0 --min-lr 0.0 --lr 0.00025 \
    --log-format json --log-interval 25 \
    --fp16
```
If training on a single GPU, set `--update-freq=4` to accumulate 4x gradients
and simulate training on 4 GPUs.
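
`--update-freq` performs gradient accumulation. A rough PyTorch sketch of the
idea (placeholder model and data, not fairseq's trainer code) looks like this:

```python
# Gradient accumulation: run update_freq forward/backward passes before each
# optimizer step, so a single GPU sees the same effective batch size as
# update_freq data-parallel GPUs.
import torch
import torch.nn as nn

model = nn.Linear(10, 2)  # placeholder model
optimizer = torch.optim.Adam(model.parameters())
criterion = nn.CrossEntropyLoss()
update_freq = 4

batches = [(torch.randn(15, 10), torch.randint(0, 2, (15,))) for _ in range(8)]

optimizer.zero_grad()
for step, (x, y) in enumerate(batches, start=1):
    loss = criterion(model(x), y) / update_freq  # average over accumulated batches
    loss.backward()                              # gradients accumulate in .grad
    if step % update_freq == 0:
        optimizer.step()                         # one parameter update per 4 batches
        optimizer.zero_grad()
```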
##### 2. Evaluate
```bash
fairseq-eval-lm data-bin/wikitext-103/ \
    --path checkpoints/checkpoint_best.pt \
    --user-dir examples/truncated_bptt/ \
    --task truncated_bptt_lm \
    --batch-size 1 --required-batch-size-multiple 1 \
    --model-overrides '{"mem_len":640,"clamp_len":400,"same_length":True}' \
    --tokens-per-sample 64
# ... | INFO | fairseq_cli.eval_lm | num. model params: 151123537
# ... | INFO | fairseq_cli.eval_lm | Evaluated 245569 tokens in 83.1s (2956.82 tokens/s)
# ... | INFO | fairseq_cli.eval_lm | Loss (base 2): 4.5668, Perplexity: 23.70
# Compare to 24.0 test perplexity from the paper
```
*Note:* During training the model saw 150 tokens of context
(``--tokens-per-sample=150``) and 150 extra memory tokens (``--mem-len=150``).
During evaluation we measure perplexity on sequences of 64 tokens
(``--tokens-per-sample=64``) and increase the memory length
(``--model-overrides='{"mem_len":640}'``). These settings match the evaluation
settings from [the original
paper](https://github.com/kimiyoung/transformer-xl/blob/master/pytorch/run_wt103_base.sh).