Truncated Backpropagation Through Time (BPTT)

Truncated BPTT is a useful technique for training language models on very long sequences. Typically a long sequence is split into chunks, and a language model is trained over the chunks sequentially. The LM may condition on previous chunks, but gradients only flow through the current chunk. This technique was the basis for the paper Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context, which achieved state-of-the-art language modeling results at the time of publication.
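To make the chunking concrete, here is a minimal Python sketch (not fairseq code). The function name iter_chunks_with_memory and its parameters are hypothetical. As a simplification, the memory is represented by the raw tokens it covers; a model such as Transformer-XL would instead cache hidden states for those positions and treat them as constants, so that no gradient flows into them.

```python
from typing import Iterator, List, Sequence, Tuple


def iter_chunks_with_memory(
    tokens: Sequence[int], chunk_len: int, mem_len: int
) -> Iterator[Tuple[List[int], List[int]]]:
    """Yield (memory, chunk) pairs over one long sequence, in order.

    `chunk` is the span the loss (and hence the gradient) is computed on.
    `memory` is the preceding context the model may condition on; in a
    real implementation it would be detached from the autograd graph.
    """
    for start in range(0, len(tokens), chunk_len):
        chunk = list(tokens[start:start + chunk_len])
        # Keep at most `mem_len` positions immediately before the chunk.
        memory = list(tokens[max(0, start - mem_len):start])
        yield memory, chunk


if __name__ == "__main__":
    for memory, chunk in iter_chunks_with_memory(list(range(10)), chunk_len=4, mem_len=3):
        print("memory:", memory, "chunk:", chunk)
```

The first chunk has an empty memory, and every later chunk sees only the last mem_len positions before it, which bounds the cost of each training step regardless of the total sequence length.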

It is slightly tricky to implement Truncated BPTT efficiently in fairseq, since we need to iterate over the data sequentially and disable any batch shuffling logic. The code provided in this example illustrates how to implement Truncated BPTT in fairseq by overriding FairseqTask::get_batch_iterator to iterate over the data sequentially. Crucially, this example supports batching and multi-GPU (data parallel) training.
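The sequential, unshuffled batching described above can be sketched in plain Python. This is an illustration under stated assumptions, not the code shipped with this example: the helper names split_into_streams, iter_sequential_batches, and shard_rows are hypothetical. The idea is to cut the corpus into batch_size contiguous streams, so that row i of every batch continues row i of the previous batch, and to give each data-parallel worker a fixed subset of rows.

```python
from typing import Iterator, List, Sequence


def split_into_streams(tokens: Sequence[int], batch_size: int) -> List[List[int]]:
    """Cut the corpus into `batch_size` contiguous, equal-length streams.

    Trailing tokens that do not fill a full stream are dropped.
    """
    stream_len = len(tokens) // batch_size
    return [list(tokens[i * stream_len:(i + 1) * stream_len]) for i in range(batch_size)]


def iter_sequential_batches(
    streams: List[List[int]], chunk_len: int
) -> Iterator[List[List[int]]]:
    """Yield batches in order; row i always continues row i of the last batch."""
    stream_len = len(streams[0])
    for start in range(0, stream_len, chunk_len):
        yield [stream[start:start + chunk_len] for stream in streams]


def shard_rows(batch: List[List[int]], num_shards: int, shard_id: int) -> List[List[int]]:
    """Give each worker the same rows at every step, so its memory stays aligned."""
    return batch[shard_id::num_shards]


if __name__ == "__main__":
    streams = split_into_streams(list(range(24)), batch_size=4)
    for step, batch in enumerate(iter_sequential_batches(streams, chunk_len=3)):
        print(step, "worker 0:", shard_rows(batch, 2, 0), "worker 1:", shard_rows(batch, 2, 1))
```

Shuffling the batches, or reassigning rows to different workers between steps, would break the correspondence between a row and the memory cached from its previous chunk, which is why the shuffling logic has to be disabled.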

0. Setup

First, see the general language modeling README for instructions on preprocessing the WikiText-103 data.

1. Train a Transformer-XL model on WikiText-103

We will train a 16-layer Transformer-XL model following the hyperparameters used in the original paper.

The following command assumes 4 GPUs, so that the total batch size is 60 sequences (15 x 4). Training should take ~24 hours on 4 V100 GPUs:

CUDA_VISIBLE_DEVICES=0,1,2,3 fairseq-train \
    --user-dir examples/truncated_bptt \
    data-bin/wikitext-103/ \
    --task truncated_bptt_lm --tokens-per-sample 150 \
    --batch-size 15 --max-update 200000 \
    --arch transformer_xl --n-layer 16 --d-model 410 --n-head 10 \
    --d-head 41 --d-inner 2100 --dropout 0.1 --dropatt 0.0 --mem-len 150 \
    --optimizer adam --clip-norm 0.25 \
    --lr-scheduler cosine --warmup-updates 0 --min-lr 0.0 --lr 0.00025  \
    --log-format json --log-interval 25 \
    --fp16

If training on a single GPU, set --update-freq=4 to accumulate gradients over 4 batches before each parameter update, which simulates training on 4 GPUs.
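The batch-size arithmetic behind these two setups can be checked with a small Python helper. The function name effective_batch_size is hypothetical, and it assumes the effective batch is simply the product of the per-GPU batch size, the number of GPUs, and the gradient accumulation factor.

```python
def effective_batch_size(batch_size_per_gpu: int, num_gpus: int, update_freq: int = 1) -> int:
    """Number of sequences contributing to each parameter update."""
    return batch_size_per_gpu * num_gpus * update_freq


if __name__ == "__main__":
    # 4 GPUs with --batch-size 15 and no gradient accumulation.
    print(effective_batch_size(15, num_gpus=4))                 # 60
    # 1 GPU with --batch-size 15 and --update-freq=4.
    print(effective_batch_size(15, num_gpus=1, update_freq=4))  # 60
    # Tokens per update with --tokens-per-sample 150.
    print(effective_batch_size(15, num_gpus=4) * 150)           # 9000
```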

2. Evaluate

fairseq-eval-lm data-bin/wikitext-103/ \
    --path checkpoints/checkpoint_best.pt \
    --user-dir examples/truncated_bptt/ \
    --task truncated_bptt_lm \
    --batch-size 1 --required-batch-size-multiple 1 \
    --model-overrides '{"mem_len":640,"clamp_len":400,"same_length":True}' \
    --tokens-per-sample 64
# ... | INFO | fairseq_cli.eval_lm | num. model params: 151123537
# ... | INFO | fairseq_cli.eval_lm | Evaluated 245569 tokens in 83.1s (2956.82 tokens/s)
# ... | INFO | fairseq_cli.eval_lm | Loss (base 2): 4.5668, Perplexity: 23.70
# Compare to 24.0 test perplexity from the paper
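The logged loss is in base 2 (bits per token), so the perplexity is 2 raised to that loss. A short Python check, with a hypothetical helper name, reproduces the figure in the log above:

```python
def perplexity_from_base2_loss(loss_bits: float) -> float:
    """Convert an average per-token loss in bits to perplexity."""
    return 2.0 ** loss_bits


if __name__ == "__main__":
    print(f"{perplexity_from_base2_loss(4.5668):.2f}")  # 23.70
```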

Note: During training the model saw 150 tokens of context (--tokens-per-sample=150) and 150 extra memory tokens (--mem-len=150). During evaluation we measure perplexity on sequences of 64 tokens (--tokens-per-sample=64) and increase the memory length (--model-overrides='{"mem_len":640}'). These settings match the evaluation settings from the original paper.