---
license: apache-2.0
language:
- en
datasets:
- Skylion007/openwebtext
metrics:
- perplexity
- mauve
---

# Self-Distillation Through Time (SDTT)

SDTT is a distillation method for diffusion language models. Recent diffusion language models such as [SEDD](https://huggingface.co/louaaron/sedd-small) or [MDLM](https://huggingface.co/kuleshov-group/mdlm-owt) achieve strong results.

However, their non-causal architecture prevents KV-caching, which makes sampling slow. We therefore devise a novel distillation method to reduce the inference latency of discrete diffusion models.

After distillation, we can sample up to 8x faster than GPT-2 (which uses KV-caching). Find more details below and in [our GitHub repo](https://github.com/jdeschena/sdtt).

## Using SDTT

- We release three groups of models:

  1. The **baseline students**, distilled with the `kld`, `mse` and `tvd` objectives from a teacher trained for 1M steps.
  2. The **students from the scaling experiments**, with sizes `sm`, `md` and `large`, distilled from models trained for 400k steps.
  3. The **teachers from the scaling experiments**, with sizes `sm`, `md` and `large`, before any distillation.

- To load these models, first install our code (an optional import check follows the commands):

```bash
git clone https://github.com/jdeschena/sdtt.git
cd sdtt
pip install -r requirements.txt
pip install flash-attn
pip install --pre torchdata --index-url https://download.pytorch.org/whl/nightly/cpu
pip install -e .
```
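If the installation succeeded, the loaders used throughout this card should be importable. The snippet below is our own quick sanity check, not part of the official instructions:

```python
# Optional sanity check (our own suggestion, not part of the official install
# instructions): the loaders documented below should import without errors.
from sdtt import load_small_student, load_scaling_student, load_scaling_teacher

print("sdtt is installed and the loaders are importable")
```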
- You can then import our models, sample from them, and evaluate them:

#### Load the baseline students

```python
from sdtt import load_small_student

student = load_small_student(loss="kld", round=7)  # load the kld student after the last distillation round
student = load_small_student(loss="mse", round=2)  # load the mse student after the second distillation round
student = load_small_student(loss="tvd", round=1)  # load the tvd student after the first distillation round
```

#### Load the students from the scaling experiment

```python
from sdtt import load_scaling_student

student = load_scaling_student(size="sm", round=7)  # load the small student after the last distillation round
student = load_scaling_student(size="md", round=1)  # load the medium student after the first distillation round
student = load_scaling_student(size="large", round=3)  # load the large student after the third distillation round
```

#### Load the teachers from the scaling experiment

```python
from sdtt import load_scaling_teacher

teacher = load_scaling_teacher(size="sm")  # load the small teacher
teacher = load_scaling_teacher(size="md")  # load the medium teacher
teacher = load_scaling_teacher(size="large")  # load the large teacher
```

#### Sample from the pretrained models

```python
from sdtt import load_small_student, load_scaling_student, load_scaling_teacher
import torch

model = load_small_student(loss="kld", round=7)  # load model, see above
model.cuda()  # put the model on the gpu

# Unconditional generation
tokens = model.sample(
    n_samples=8,
    num_steps=256,
    seq_len=1024,
    verbose=True,
)
# Detokenize
uncond_text = model.tokenizer.batch_decode(tokens)

# Conditional generation, based on a prompt
# Prepare the prompt
prompt = "Today is a great day. The sun is shining,"
prompt_tokens = model.tokenizer(prompt)["input_ids"]
prompt_tokens.insert(0, model.tokenizer.bos_token_id)
prompt_tokens = torch.tensor(prompt_tokens, device="cuda")
prompt_len = len(prompt_tokens)


def project_fn(x):
    # Project the first prompt_len tokens of every sample onto the prompt
    x[:, :prompt_len] = prompt_tokens
    return x  # Don't forget to return the projected sequences


tokens = model.sample(
    n_samples=8,
    num_steps=256,
    seq_len=1024,
    verbose=True,
    project_fn=project_fn,
)

cond_text = model.tokenizer.batch_decode(tokens)
```
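Since distillation is meant to cut the number of denoising steps, a simple way to see its effect is to time `model.sample` at several step counts. The sketch below is our own addition; it continues from the example above (reusing `model` and `torch`), and the step counts are arbitrary:

```python
# Our own sketch (not from the original card): time unconditional sampling at
# different numbers of denoising steps, reusing `model` loaded above.
import time

for num_steps in (16, 64, 256):
    torch.cuda.synchronize()
    start = time.time()
    model.sample(n_samples=8, num_steps=num_steps, seq_len=1024, verbose=True)
    torch.cuda.synchronize()
    elapsed = time.time() - start
    print(f"{num_steps} steps: {elapsed:.1f}s for 8 samples of length 1024")
```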
For more details, please see our GitHub repository: [SDTT](https://github.com/jdeschena/sdtt).

## Model Details

Our small checkpoints are distilled from the [MDLM](https://github.com/kuleshov-group/mdlm) checkpoints. We also release medium (424M parameters) and large (863M parameters) checkpoints that we pretrained ourselves.
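As a quick sanity check of these sizes, you can count parameters directly. The snippet below is our own sketch and assumes the loaders return standard PyTorch modules (so `.parameters()` is available):

```python
# Our own sketch: count the parameters of the scaling teachers, assuming the
# loaders return standard PyTorch modules exposing .parameters().
from sdtt import load_scaling_teacher

for size in ("sm", "md", "large"):
    model = load_scaling_teacher(size=size)
    n_params = sum(p.numel() for p in model.parameters())
    print(f"{size}: {n_params / 1e6:.0f}M parameters")
```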
## Citation

Please cite our work using the BibTeX entry below:

**BibTeX:**

```
@article{deschenaux2024autoregressionfastllmsselfdistillation,
  title={Beyond Autoregression: Fast LLMs via Self-Distillation Through Time},
  author={Deschenaux, Justin and Gulcehre, Caglar},
  year={2024},
  eprint={2410.21035},
  archivePrefix={arXiv},
  primaryClass={cs.LG},
  url={https://arxiv.org/abs/2410.21035},
}
```

## Contact

Justin Deschenaux (justin.deschenaux@epfl.ch)