Pretraining Without Attention (BiGS)

Official JAX models with a maximum sequence length of 512

This repository contains the BiGS JAX model definitions, pretrained model weights, and training and fine-tuning code for our paper on pretraining with state space models. You can find more details in our paper.

Pretraining Without Attention
Junxiong Wang, Jing Nathan Yan, Albert Gu, Alexander M. Rush
Cornell University, Cornell Tech, DeepMind

Transformers have been essential to pretraining success in NLP. While other architectures have been used, downstream accuracy is either significantly worse, or requires attention layers to match standard benchmarks such as GLUE. This work explores pretraining without attention by using recent advances in sequence routing based on state-space models (SSMs). Our proposed model, Bidirectional Gated SSM (BiGS), combines SSM layers with a multiplicative gating architecture that has been effective in simplified sequence modeling architectures. The model learns static layers that do not consider pair-wise interactions. Even so, BiGS is able to match BERT pretraining accuracy on GLUE and can be extended to long-form pretraining of 4096 tokens without approximation. Analysis shows that while the models have similar accuracy, the approach has significantly different inductive biases than BERT in terms of interactions and syntactic representations.
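
To give a rough feel for what "static" sequence mixing with multiplicative gating can look like, here is a toy sketch. It is not the BiGS layer: it assumes a scalar linear state-space recurrence, and the function names and the parameters `a`, `b`, `c` are hypothetical choices for illustration. The point it tries to show is that the mixing weights depend only on the fixed parameters and token distance, not on pairwise comparisons between tokens.

```python
def ssm_scan(u, a, b, c):
    """Toy scalar linear SSM: x_k = a * x_{k-1} + b * u_k, y_k = c * x_k."""
    x = 0.0
    ys = []
    for u_k in u:
        x = a * x + b * u_k
        ys.append(c * x)
    return ys


def bidirectional_gated(u, gate, a=0.5, b=1.0, c=1.0):
    """Run the toy SSM left-to-right and right-to-left, sum both, then gate.

    `gate` is passed in directly here (an assumption made for brevity);
    the gating is a plain elementwise product.
    """
    forward = ssm_scan(u, a, b, c)
    backward = ssm_scan(u[::-1], a, b, c)[::-1]
    return [g * (f + r) for g, f, r in zip(gate, forward, backward)]


# An impulse at position 0 decays geometrically to the right.
print(ssm_scan([1.0, 0.0, 0.0], a=0.5, b=1.0, c=1.0))  # [1.0, 0.5, 0.25]
print(bidirectional_gated([1.0, 0.0, 0.0], gate=[1.0, 1.0, 1.0]))  # [2.0, 0.5, 0.25]
```

Each output position sees context from both directions, yet no step compares one token with another, which is the contrast with attention that the abstract draws.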

Load Masked Language Model

import jax
from jax import numpy as jnp
from transformers import BertTokenizer
from BiGS.modeling_flax_bigs import FlaxBiGSForMaskedLM

tokenizer = BertTokenizer.from_pretrained('bert-large-uncased')
model = FlaxBiGSForMaskedLM.from_pretrained('JunxiongWang/BiGS_512')

text = "The goal of life is [MASK]."
encoded_input = tokenizer(text, return_tensors='np', padding='max_length', max_length=512)
output = model(**encoded_input)
# 103 is the [MASK] token id in the bert-large-uncased vocabulary (tokenizer.mask_token_id).
# Top-10 predicted tokens at the masked position, most probable first:
tokenizer.convert_ids_to_tokens(jnp.flip(jnp.argsort(jax.nn.softmax(output.logits[encoded_input['input_ids']==103]))[0])[:10])
# output: ['happiness', 'love', 'peace', 'perfection', 'life', 'enlightenment', 'god', 'survival', 'freedom', 'good']
# The corresponding probabilities, in the same order:
jnp.flip(jnp.sort(jax.nn.softmax(output.logits[encoded_input['input_ids']==103]))[0])[:10]
# probability: [0.16052087, 0.04306792, 0.03651363, 0.03468223, 0.02927081, 0.02549769, 0.02385132, 0.02261189, 0.01672831, 0.01619471]

text = "Paris is the [MASK] of France."
encoded_input = tokenizer(text, return_tensors='np', padding='max_length', max_length=512)
output = model(**encoded_input)
tokenizer.convert_ids_to_tokens(jnp.flip(jnp.argsort(jax.nn.softmax(output.logits[encoded_input['input_ids']==103]))[0])[:10])
# output: ['capital', 'centre', 'center', 'city', 'capitol', 'prefecture', 'headquarters', 'president', 'metropolis', 'heart']
jnp.flip(jnp.sort(jax.nn.softmax(output.logits[encoded_input['input_ids']==103]))[0])[:10]
# probability: [0.9981787 , 0.00034076, 0.00026992, 0.00026926, 0.00017787, 0.00004816, 0.00004256, 0.00003716, 0.00003634, 0.00002893]
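
The one-liners above apply a softmax to the logits at the masked position, sort them, and keep the ten largest. As a dependency-free sketch of that same computation, the hypothetical helper below (`top_k_probs` is not part of this repository) returns the indices and probabilities together so they cannot get out of sync. The toy logits are made up for illustration.

```python
import math


def top_k_probs(logits, k=10):
    """Return the k most probable entries as (index, probability) pairs."""
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    probs = [e / total for e in exps]
    # Indices sorted by probability, highest first.
    order = sorted(range(len(probs)), key=probs.__getitem__, reverse=True)
    return [(i, probs[i]) for i in order[:k]]


toy_logits = [2.0, 0.5, 3.0, -1.0]  # stand-in for one row of vocabulary logits
top2 = top_k_probs(toy_logits, k=2)
print([i for i, _ in top2])  # [2, 0]
```

With the real model, the logits at the masked position could be passed in as a list, for example `output.logits[0, pos].tolist()` where `pos` is the index of the [MASK] token, and the returned indices handed to `tokenizer.convert_ids_to_tokens`.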

Load Sequence Classification Model

from BiGS.modeling_flax_bigs import FlaxBiGSForSequenceClassification
model = FlaxBiGSForSequenceClassification.from_pretrained('JunxiongWang/BiGS_512')

Load Question Answering Model

from BiGS.modeling_flax_bigs import FlaxBiGSForQuestionAnswering
model = FlaxBiGSForQuestionAnswering.from_pretrained('JunxiongWang/BiGS_512')

Load Multiple Choice Classification Model

from BiGS.modeling_flax_bigs import FlaxBiGSForMultipleChoice
model = FlaxBiGSForMultipleChoice.from_pretrained('JunxiongWang/BiGS_512')

GLUE Experiments

GLUE consists of 9 tasks. You can use this Python script to run them.

We fine-tune BiGS on a TPU-v3 with 8 cores. With a per-device batch size of 2, the total batch size is 16.
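
The batch-size arithmetic can be sketched as follows, which may help when adapting the command to a different number of devices. The helper names are hypothetical, the dataset size is a made-up example, and counting a final partial batch as a full step is an assumption (the training script may drop it instead).

```python
import math


def global_batch_size(per_device_batch_size, num_devices):
    """Total examples consumed per optimizer step across all devices."""
    return per_device_batch_size * num_devices


def optimizer_steps(num_examples, batch_size, num_epochs):
    """Number of steps, assuming the last partial batch counts as a step."""
    return math.ceil(num_examples / batch_size) * num_epochs


batch = global_batch_size(per_device_batch_size=2, num_devices=8)
print(batch)  # 16

# Hypothetical dataset of 8000 training examples, 3 epochs.
print(optimizer_steps(8000, batch, num_epochs=3))  # 1500
```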

export TASK_NAME=cola

python run_glue2.py \
    --model_name_or_path JunxiongWang/BiGS_512 \
    --task_name $TASK_NAME \
    --max_seq_length 512 \
    --learning_rate 2e-5 \
    --num_train_epochs 3 \
    --per_device_train_batch_size 2 \
    --logging_steps 100 \
    --eval_steps 500 \
    --weight_decay 0.01 \
    --output_dir BiGS_$TASK_NAME/
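
To run several tasks with the same settings, the command can be generated per task. The sketch below only builds and prints the commands. The flags are copied from the command above, while the task identifiers other than `cola` are assumptions based on the naming used by the Hugging Face GLUE example scripts, and `build_glue_command` is a hypothetical helper.

```python
import shlex

# Assumed task identifiers for the eight tasks reported in this README.
GLUE_TASKS = ["cola", "sst2", "qqp", "mnli", "qnli", "mrpc", "stsb", "rte"]


def build_glue_command(task_name):
    """Return the fine-tuning command for one task as an argument list."""
    return [
        "python", "run_glue2.py",
        "--model_name_or_path", "JunxiongWang/BiGS_512",
        "--task_name", task_name,
        "--max_seq_length", "512",
        "--learning_rate", "2e-5",
        "--num_train_epochs", "3",
        "--per_device_train_batch_size", "2",
        "--logging_steps", "100",
        "--eval_steps", "500",
        "--weight_decay", "0.01",
        "--output_dir", f"BiGS_{task_name}/",
    ]


for task in GLUE_TASKS:
    print(shlex.join(build_glue_command(task)))
```

Each argument list can be passed to `subprocess.run(cmd, check=True)` to launch the run.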

These settings give the following results:

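| Task  | Metric                       | Result    |
|-------|------------------------------|-----------|
| CoLA  | Matthews corr.               | 67.9      |
| SST-2 | Accuracy                     | 93.8      |
| QQP   | Accuracy/F1                  | 91.4/88.4 |
| MNLI  | Matched acc./Mismatched acc. | 86.2      |
| QNLI  | Accuracy                     | 91.6      |
| MRPC  | F1/Accuracy                  | 86.4/80.4 |
| STS-B | Pearson/Spearman corr.       | 89.1/89.0 |
| RTE   | Accuracy                     | 73.3      |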
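
For readers less familiar with the metrics, here is a minimal sketch of how Matthews correlation (CoLA) and F1 (QQP, MRPC) are computed from binary confusion counts. The scores in the table appear to be these quantities scaled by 100, which is an assumption. The function names and the toy counts are illustrative and are not taken from the evaluation script.

```python
import math


def matthews_corrcoef(tp, tn, fp, fn):
    """Matthews correlation coefficient from binary confusion counts."""
    denom = math.sqrt((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn))
    if denom == 0:
        return 0.0  # common convention when a row or column is empty
    return (tp * tn - fp * fn) / denom


def f1_score(tp, fp, fn):
    """Harmonic mean of precision and recall."""
    denom = 2 * tp + fp + fn
    return 2 * tp / denom if denom else 0.0


# Toy confusion counts, not real results.
print(round(100 * matthews_corrcoef(tp=40, tn=45, fp=5, fn=10), 1))
print(round(100 * f1_score(tp=40, fp=5, fn=10), 1))
```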

If you use our models, please cite the following paper.

@article{wang2022pretraining,
  title={Pretraining Without Attention},
  author={Wang, Junxiong and Yan, Jing Nathan and Gu, Albert and Rush, Alexander M},
  journal={arXiv preprint arXiv:2212.10544},
  year={2022}
}