mega-encoder-small-16k-v1
This is a "huggingface-native" pretrained encoder-only model with a context length of 16384 tokens. The model architecture is MEGA (Moving Average Equipped Gated Attention).
Numbers
Average GLUE score (Avg) for this model and several reference encoders. Although GLUE is a short-context benchmark, this long-context model holds up decently:
Model | Size | CTX | Avg |
---|---|---|---|
mega-encoder-small-16k-v1 | 122M | 16384 | 0.777 |
bert-base-uncased | 110M | 512 | 0.7905 |
roberta-base | 125M | 514 | 0.86 |
bert-plus-L8-4096-v1.0 | 88.1M | 4096 | 0.8278 |
mega-wikitext103 | 7.0M | 10000 | 0.48 |
GLUE Details
Model | Size | CTX | Avg | CoLA | SST2 | MRPC | STSB | QQP | MNLI | QNLI | RTE |
---|---|---|---|---|---|---|---|---|---|---|---|
mega-encoder-small-16k-v1 | 122M | 16384 | 0.777 | 0.454 | 0.914 | 0.8404 | 0.906 | 0.894 | 0.806 | 0.842 | 0.556 |
bert-base-uncased | 110M | 512 | 0.7905 | 0.521 | 0.935 | 0.889 | 0.858 | 0.712 | 0.84 | 0.905 | 0.664 |
roberta-base | 125M | 514 | 0.86 | 0.64 | 0.95 | 0.9 | 0.91 | 0.92 | 0.88 | 0.93 | 0.79 |
bert-plus-L8-4096-v1.0 | 88.1M | 4096 | 0.8278 | 0.6272 | 0.906 | 0.8659 | 0.9207 | 0.906 | 0.832 | 0.9 | 0.6643 |
mega-wikitext103 | 7M | 10000 | 0.480 | 0.00 | 0.732 | 0.748 | -0.087 | 0.701 | 0.54 | 0.598 | 0.513 |
The evaluations for the MEGA and bert-plus models can be found in this open wandb project. Their scores are the maximum values observed on the validation sets. Scores for the other models are as reported in their respective papers.
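For several rows, the Avg column matches an unweighted mean of the eight task scores. This is an inference from the numbers; the card does not state it. The quick check below covers three rows and copies their per-task values from the table above.

```python
# Per-task GLUE scores copied from the table, in column order:
# CoLA, SST2, MRPC, STSB, QQP, MNLI, QNLI, RTE
scores = {
    "mega-encoder-small-16k-v1": [0.454, 0.914, 0.8404, 0.906, 0.894, 0.806, 0.842, 0.556],
    "bert-base-uncased": [0.521, 0.935, 0.889, 0.858, 0.712, 0.84, 0.905, 0.664],
    "bert-plus-L8-4096-v1.0": [0.6272, 0.906, 0.8659, 0.9207, 0.906, 0.832, 0.9, 0.6643],
}


def glue_avg(task_scores):
    """Unweighted mean over the task scores (every task counts equally)."""
    return sum(task_scores) / len(task_scores)


for name, task_scores in scores.items():
    print(f"{name}: {glue_avg(task_scores):.4f}")
```

Because every task counts equally, a weak RTE or CoLA score pulls the average down as much as a weak score on a large task like QQP or MNLI.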
Design
Architecture
This encoder model has 8 layers, hidden size 768, and a feedforward ratio of 3x. The resulting total size is 122M params.
Architecture Details
- We use a hidden size of 768 and a feedforward ratio of 3x (feedforward size = 3 × hidden size).
  - This contrasts with the 2x ratio used in the paper.
- To handle the long context, we use MEGA's chunking mechanism with a chunk length of 1024. As a result, VRAM usage grows linearly with sequence length beyond 1024 tokens (one more chunk per additional 1024 tokens) rather than quadratically.
- EMA dimension: we use an EMA dimension of 32 in the interest of modeling long and (potentially) complex sequences.
- We use 8 layers, and a context length of 16384 tokens.
- We use "simple" relative positional embeddings instead of the rotary embeddings touted in the paper.
  - This choice came from examining the detailed logs of models trained/evaluated on the LRA benchmark: the models geared toward encoder-type tasks all use the simple relative positional embeddings.
  - We observed poor performance and inexplicable 'walls' in previous experiments using rotary positional embeddings with MEGA as an encoder.
- BART tokenizer: we use the tokenizer from facebook/bart-large.
  - This choice was motivated mostly by the desire to use the MEGA encoder together with a decoder model in the HF EncoderDecoderModel class in a "huggingface-native" way. BART is supported as a decoder for this class, and BART's tokenizer has the necessary preprocessing for encoder training.
  - Example usage of MEGA+BART to create an encoder-decoder here
  - The tokenizer's vocab is exactly the same as RoBERTa's.
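The chunking item above says VRAM grows linearly past 1024 tokens. A back-of-the-envelope sketch makes that concrete. It assumes the chunked variant described in the Mega paper: the sequence is split into non-overlapping 1024-token chunks, and attention is computed only within each chunk. It counts query-key score entries per attention head and ignores everything else (EMA, feedforward, other activations, batch size). The helper names are illustrative and not part of any library.

```python
CHUNK = 1024  # chunk length used by this model


def chunk_spans(seq_len, chunk=CHUNK):
    """(start, end) index pairs of the non-overlapping chunks covering a sequence."""
    return [(s, min(s + chunk, seq_len)) for s in range(0, seq_len, chunk)]


def attention_entries(seq_len, chunk=None):
    """Number of query-key score entries for one attention head.

    chunk=None -> full attention over the whole sequence (quadratic in seq_len).
    chunk=int  -> attention restricted to each chunk (linear in seq_len).
    """
    if chunk is None:
        return seq_len * seq_len
    return sum((end - start) ** 2 for start, end in chunk_spans(seq_len, chunk))


for length in (1024, 4096, 16384):
    full = attention_entries(length)
    chunked = attention_entries(length, CHUNK)
    print(f"{length:>6} tokens: full={full:,} chunked={chunked:,} ({full // chunked}x fewer)")
```

At 16384 tokens the sequence splits into 16 chunks. The chunked score matrices therefore hold 16 × 1024² entries instead of 16384², a 16x reduction. Doubling the length doubles the chunked count, whereas full attention would quadruple it. Information can still cross chunk boundaries through the EMA component.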
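This sketch assumes that "EMA dimension" means the expansion dimension h of the multi-dimensional damped EMA from the Mega paper. In that formulation, each input feature x_t is expanded into h dimensions and smoothed with its own decay:

- h_t = alpha * (beta * x_t) + (1 - alpha * delta) * h_{t-1}
- y_t = eta · h_t

Below is a toy, loop-based version for a single feature channel with h = 32. The parameter values are random placeholders rather than trained weights, and the function name is hypothetical. The real layer is vectorized and, in encoder configurations, can also run right-to-left.

```python
import math
import random


def damped_ema(x, alpha, delta, beta, eta):
    """Toy multi-dimensional damped EMA for ONE input feature channel.

    x: list of floats (one feature over time)
    alpha, delta: per-dimension values in (0, 1); beta, eta: expand/project weights
    """
    h = len(alpha)
    state = [0.0] * h  # h_{t-1}: one running average per EMA dimension
    out = []
    for x_t in x:
        for k in range(h):
            u = beta[k] * x_t  # expand the scalar input into dimension k
            state[k] = alpha[k] * u + (1.0 - alpha[k] * delta[k]) * state[k]
        out.append(sum(eta[k] * state[k] for k in range(h)))  # project back to a scalar
    return out


EMA_DIM = 32  # the EMA dimension used by this model
rng = random.Random(0)
alpha = [rng.uniform(0.05, 0.95) for _ in range(EMA_DIM)]
delta = [rng.uniform(0.05, 0.95) for _ in range(EMA_DIM)]
beta = [rng.gauss(0.0, 1.0) for _ in range(EMA_DIM)]
eta = [rng.gauss(0.0, 1.0) / math.sqrt(EMA_DIM) for _ in range(EMA_DIM)]

# An impulse at t=0 shows how a single token's signal decays across later positions.
signal = [1.0] + [0.0] * 15
response = damped_ema(signal, alpha, delta, beta, eta)
print([round(v, 3) for v in response])
```

With more EMA dimensions, the layer mixes more decay rates. That is the intuition behind using a larger h for long sequences.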
Training
This model was trained with the transformers package. You can find (mostly unorganized) training runs on wandb here.
Training Details
- Multi-task training: the majority of training is "standard" MLM, with no next-sentence prediction, etc. However, in the interest of pretraining a useful encoder for fine-tuning on various tasks, we mix in such tasks between several of the MLM phases, carrying over the model's backbone to the next training phase.
  - An example is multiple-choice tuning on the swag dataset.
- MLM mask ratio (40% by default): we mask 40% of tokens, following Wettig et al. (2022). The ratio is decreased slightly when training on longer sequences (8192+ tokens) to encourage the model to learn and leverage the available context in its predictions.
- AMP with bf16
- Gradient checkpointing implementation: training this model (or similar ones) at a context length of 8192 or longer becomes quite VRAM-intensive, despite the linear increase in memory usage.
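The sketch below shows what a 40% mask ratio means in practice. It is a simplified, stdlib-only masking step: a fixed fraction of non-special tokens is replaced with the mask id, and the loss labels keep the original ids only at masked positions. This is an illustration, not the training code used for this model, and the card does not say which collator settings were used. For comparison, transformers' DataCollatorForLanguageModeling with mlm_probability=0.4 samples positions independently and, by default, splits them 80/10/10 into mask/random/unchanged tokens. The function name and token ids below are hypothetical placeholders (aside from the RoBERTa/BART special-token ids noted in the comments).

```python
import random

IGNORE_INDEX = -100  # default ignore_index of PyTorch's cross-entropy loss


def mask_for_mlm(token_ids, mask_id, special_ids, mask_ratio=0.4, seed=0):
    """Simplified MLM masking: replace a fixed fraction of non-special tokens with mask_id.

    Returns (inputs, labels). labels holds the original id at masked positions and
    IGNORE_INDEX elsewhere, so the loss is computed only on masked tokens.
    """
    rng = random.Random(seed)
    candidates = [i for i, t in enumerate(token_ids) if t not in special_ids]
    n_mask = round(len(candidates) * mask_ratio)
    chosen = set(rng.sample(candidates, n_mask))
    inputs = [mask_id if i in chosen else t for i, t in enumerate(token_ids)]
    labels = [t if i in chosen else IGNORE_INDEX for i, t in enumerate(token_ids)]
    return inputs, labels


# <s>=0, </s>=2 and <mask>=50264 follow the RoBERTa/BART vocab;
# the 20 content ids in between are placeholders, not real tokenized text.
BOS, EOS, MASK = 0, 2, 50264
ids = [BOS] + list(range(1000, 1020)) + [EOS]
inputs, labels = mask_for_mlm(ids, MASK, special_ids={BOS, EOS}, mask_ratio=0.4)
print(sum(t == MASK for t in inputs), "of", len(ids) - 2, "content tokens masked")
```

Lowering mask_ratio for long sequences, as described above, leaves more unmasked context for the model to draw on when predicting each masked token.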
Usage
This is a pretrained model intended to be fine-tuned on various encoder-compatible tasks. However, if you are interested in testing inference with this model or have a deep passion for predicting mask tokens, you can use the following code:
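```python
import json

from transformers import pipeline

# Load the model into a fill-mask pipeline (downloads the weights from the Hub)
pipe = pipeline("fill-mask", model="BEE-spoke-data/mega-encoder-small-16k-v1")

# The BART/RoBERTa tokenizer uses "<mask>" as its mask token
text = "I love to <mask> memes."
result = pipe(text)
print(json.dumps(result, indent=2))
```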
Gradient checkpointing implementation
If fine-tuning this model on <task>, using gradient checkpointing makes training at a 16384 context quite feasible. By installing the transformers fork below and passing gradient_checkpointing=True in the training arguments, you should be able to fine-tune at batch size 1 with VRAM to spare on a single 3090/4090.
```bash
pip uninstall -y transformers
pip install -U git+https://github.com/pszemraj/transformers.git@mega-gradient-checkpointing
pip install -U huggingface-hub
```
If there is sufficient interest, we can look at making a PR into the official repo.
Citation
If you find this useful, please consider citing this DOI; it would make us happy.
```bibtex
@misc{beespoke_data_2024,
  author = {Peter Szemraj and Vincent Haines and {BEEspoke Data}},
  title = {mega-encoder-small-16k-v1 (Revision 1476bcf)},
  year = 2024,
  url = {https://huggingface.co/BEE-spoke-data/mega-encoder-small-16k-v1},
  doi = {10.57967/hf/1837},
  publisher = {Hugging Face}
}
```