# Adaptive Span

Adaptive Span is a novel self-attention mechanism that can learn its optimal
attention span. This allows us to significantly extend the maximum context size
used in Transformer models, while maintaining control over their memory footprint
and computational time. It uses the Truncated BPTT technique for training,
as in [transformerXL](https://github.com/pytorch/fairseq/blob/main/examples/truncated_bptt/README.md).

Adaptive Span was introduced in the paper
[Adaptive Attention Span in Transformers](https://arxiv.org/abs/1905.07799),
which achieved state-of-the-art language modeling results at the time of publication.

We manage to reproduce their results in fairseq and keep most of the
[original implementation](https://github.com/facebookresearch/adaptive-span) untouched.
You can also refer to their sweep file if any hyperparameter combination is unclear.
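To make the mechanism concrete, here is a minimal PyTorch sketch of the soft masking function described in the paper. It is an illustration only, not the module shipped under `examples/adaptive_span`; the class name `AdaptiveSpanMask` and the single shared span parameter are simplifications (the real model learns one span per attention head). Each head learns a span `z`, and the attention weight at distance `x` from the query is scaled by `m_z(x) = clamp((R + z - x) / R, 0, 1)`, with an L1 penalty on `z` (the `--aux-loss-scaler` term below) encouraging short spans.

```python
import torch
import torch.nn as nn


class AdaptiveSpanMask(nn.Module):
    """Soft span mask m_z(x) = clamp((R + z - x) / R, 0, 1) (illustrative sketch)."""

    def __init__(self, max_span: int, ramp_size: int = 32):
        super().__init__()
        self.max_span = max_span
        self.ramp_size = ramp_size  # R: controls how soft the ramp is
        # learnable span ratio in [0, 1]; the real model keeps one per head
        self.span_ratio = nn.Parameter(torch.zeros(1))
        # distance of each attended position from the current query
        self.register_buffer("distance", torch.arange(max_span, dtype=torch.float))

    def forward(self, attn: torch.Tensor) -> torch.Tensor:
        # attn: (..., max_span) softmax weights over the last `max_span` positions
        z = self.span_ratio.clamp(0, 1) * self.max_span
        mask = ((self.ramp_size + z - self.distance) / self.ramp_size).clamp(0, 1)
        attn = attn * mask
        # re-normalize so the masked weights still sum to 1
        return attn / attn.sum(dim=-1, keepdim=True).clamp(min=1e-8)

    def span_penalty(self) -> torch.Tensor:
        # the L1 term scaled by --aux-loss-scaler, pushing spans to stay short
        return (self.span_ratio.clamp(0, 1) * self.max_span).sum()


# usage sketch
mask = AdaptiveSpanMask(max_span=8192)
attn = torch.softmax(torch.randn(2, 8, 8192), dim=-1)  # (batch, heads, span)
out = mask(attn)
```

Because positions beyond `z + R` receive zero weight, their keys and values never need to be computed or stored, which is what keeps memory and compute under control even with `--attn-span 8192`.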
##### 0. Setup

First you need to process the Enwik8 dataset. We use the pre-tokenized dataset
from the [adaptive span paper](https://github.com/facebookresearch/adaptive-span/blob/master/get_data.sh).
Download the dataset, then run:

```bash
fairseq-preprocess --only-source --trainpref ~/data/enwik8/train.txt \
    --validpref ~/data/enwik8/valid.txt --testpref ~/data/enwik8/test.txt \
    --destdir ~/data/enwik8/data-bin/ --joined-dictionary --workers 20
```
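Optionally, you can sanity-check the binarized output before training with a few lines of Python. This is a sketch assuming the `--destdir` used above and fairseq's standard data utilities; the exact on-disk layout can vary across fairseq versions.

```python
# Quick sanity check of the binarized Enwik8 data produced by fairseq-preprocess.
import os

from fairseq.data import Dictionary, data_utils

data_bin = os.path.expanduser("~/data/enwik8/data-bin")
dictionary = Dictionary.load(os.path.join(data_bin, "dict.txt"))
dataset = data_utils.load_indexed_dataset(os.path.join(data_bin, "train"), dictionary)

print(f"vocab size: {len(dictionary)}")    # character-level vocab, so this is small
print(f"train items: {len(dataset)}")
print(dictionary.string(dataset[0][:50]))  # first few characters of the training data
```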
##### 1. Train an Adaptive Span model on Enwik8

We will train a 12-layer Adaptive Span model following the [hyperparameters
used in the original
paper](https://github.com/facebookresearch/adaptive-span/blob/master/experiments/enwik8.sh).

The following command assumes 4 GPUs, so that the total batch size is 64
sequences (4 x 16). Training should take 2-3 days on 4 V100 GPUs:

```bash
CUDA_VISIBLE_DEVICES=0,1,2,3 fairseq-train \
    --user-dir examples/adaptive_span \
    --data ~/data/enwik8/data-bin/ \
    --fp16 --fp16-no-flatten-grads --max-update 600000 \
    --task truncated_bptt_lm --tokens-per-sample 512 --arch adaptive_span \
    --n-layer 12 --d-model 512 --n-head 8 --d-inner 2048 --dropout 0.3 \
    --attn-span 8192 --optimizer adagrad_with_grad_clip --adagrad-clip 0.03 \
    --validate-interval-updates 1000 \
    --lr-scheduler fixed --warmup-updates 32000 --batch-size-valid 32 \
    --lr 0.07 --criterion adaptive_span_loss --batch-size 16 --update-freq 1 \
    --seed 2 --log-format json --log-interval 25 --aux-loss-scaler 5e-07
```
This should land around 1.05 on validation and 1.03 on test. You can lower
`--aux-loss-scaler` for better performance (longer spans). This is roughly a
0.03 bpc improvement over the transformerXL baseline here.
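To see where your run lands, note that `--log-format json` makes each log record a JSON object, so validation loss can be pulled out with a short script like the sketch below. It assumes stdout was redirected to a hypothetical `train.log` and that validation records carry a `valid_loss` key; key names can differ across fairseq versions, so adjust to whatever your log actually contains.

```python
# Extract validation losses from a JSON-formatted fairseq training log.
import json

with open("train.log") as f:           # assumes you redirected fairseq-train output here
    for line in f:
        start = line.find("{")         # log lines may carry a "date | level | name" prefix
        if start == -1:
            continue
        try:
            record = json.loads(line[start:])
        except json.JSONDecodeError:
            continue
        if "valid_loss" in record:     # reported in base 2, i.e. roughly bpc on enwik8
            print(record.get("num_updates"), record["valid_loss"])
```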
If training on a single GPU, set `--update-freq=4` to accumulate 4x gradients
and simulate training on 4 GPUs.

You can also reproduce the transformerXL result on enwik8 using this code base.
It should land around 1.06 on test, matching the [original paper](https://github.com/kimiyoung/transformer-xl/blob/master/pytorch/run_enwik8_base.sh).
You can try it with:
```bash
CUDA_VISIBLE_DEVICES=0,1,2,3 fairseq-train \
    --user-dir examples/truncated_bptt \
    ~/data/enwik8/data-bin/ \
    --task truncated_bptt_lm --fp16 --max-update 400000 \
    --tokens-per-sample 512 --arch transformer_xl --n-layer 12 \
    --d-model 512 --n-head 8 --d-head 64 --d-inner 2048 --dropout 0.1 \
    --dropatt 0.0 --mem-len 512 --optimizer adam --clip-norm 0.25 \
    --lr-scheduler cosine --warmup-updates 0 \
    --lr 0.0 --lr 0.00025 --batch-size 15 \
    --update-freq 1 --seed 2 --log-format json --log-interval 25 \
    --fp16
```
##### 2. Evaluate

For Adaptive Span:

```bash
fairseq-eval-lm ~/data/enwik8/data-bin/ --path model/checkpoint_best.pt \
    --user-dir examples/adaptive_span \
    --task truncated_bptt_lm --batch-size 8 --tokens-per-sample 512 --gen-subset test
```

For Transformer-XL evaluation:

```bash
fairseq-eval-lm ~/data/enwik8/data-bin/ --path model/checkpoint_best.pt \
    --user-dir examples/truncated_bptt/ --task truncated_bptt_lm --batch-size 8 \
    --tokens-per-sample 80 \
    --model-overrides '{"mem_len":2100,"clamp_len":820,"same_length":True}' \
    --gen-subset valid
```
*Note:* During training the model saw 512 tokens of context
(``--tokens-per-sample=512``); the Adaptive Span evaluation above also uses 512 tokens
of context with batch size 8, matching the evaluation settings from [the original
paper](https://github.com/facebookresearch/adaptive-span/blob/master/experiments/enwik8.sh).
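A small aside on reading the output: `fairseq-eval-lm` reports a base-2 loss alongside a perplexity, and on a character-level dataset like enwik8 the base-2 loss is bits per character. If you only have the perplexity, the conversion is `bpc = log2(ppl)`; the helper below is a hypothetical convenience, not part of this README's commands.

```python
import math


def ppl_to_bpc(perplexity: float) -> float:
    """Bits per character corresponding to a perplexity printed by fairseq-eval-lm."""
    return math.log2(perplexity)


print(round(ppl_to_bpc(2.04), 3))  # a perplexity of ~2.04 is ~1.03 bpc
```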