This is a version of [ModernBERT-base](https://huggingface.co/answerdotai/ModernBERT-base) distilled down to 16 layers out of 22.
This reduces the parameter count from 149M to 119M. In practice, since the embedding parameters contribute little to latency,
the effective change is shrinking the "trunk" of the model from 110M to 80M parameters.
I would expect this to reduce latency by roughly 25% (equivalently, to increase throughput by roughly 33%).
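
For reference, here is the arithmetic behind those two percentages as a small sketch. It assumes latency scales roughly linearly with trunk parameter count, which is a simplification and not a measurement; the helper name `throughput_gain` is made up for illustration.

```python
# Back-of-the-envelope numbers only; nothing here is a benchmark.
# Assumption: latency is roughly proportional to trunk parameter count.

def throughput_gain(latency_reduction: float) -> float:
    """Relative throughput increase implied by a relative latency reduction."""
    return 1.0 / (1.0 - latency_reduction) - 1.0

trunk_before = 110e6  # trunk params in ModernBERT-base (from the text above)
trunk_after = 80e6    # trunk params after removing 6 layers

param_reduction = 1.0 - trunk_after / trunk_before
print(f"trunk parameter reduction: {param_reduction:.1%}")
print(f"throughput gain at 25% lower latency: {throughput_gain(0.25):.1%}")
```

The trunk shrinks by a bit over a quarter, which is where the "roughly 25%" latency estimate comes from, and a 25% latency cut corresponds to 1 / 0.75 of the original throughput.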
The last 6 local attention layers were removed:
0. Global
1. Local
2. Local
3. Global
4. Local
5. Local
6. Global
7. Local
8. Local
9. Global
10. Local
11. Local
12. Global
13. Local (REMOVED)
14. Local (REMOVED)
15. Global
16. Local (REMOVED)
17. Local (REMOVED)
18. Global
19. Local (REMOVED)
20. Local (REMOVED)
21. Global
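
The layout above can be reproduced in a few lines, which also shows what the remaining stack looks like. This is a sketch that assumes a global layer at every third index (0, 3, 6, ...), as the list suggests; the names `GLOBAL_EVERY` and `layer_kind` are illustrative.

```python
# Rebuild the layer layout from the list above.
# Assumption: global attention at every 3rd index, local attention elsewhere.
NUM_LAYERS = 22
GLOBAL_EVERY = 3
REMOVED = {13, 14, 16, 17, 19, 20}

def layer_kind(idx: int) -> str:
    return "Global" if idx % GLOBAL_EVERY == 0 else "Local"

# Original indices of the layers that survive, and their attention types.
kept = [idx for idx in range(NUM_LAYERS) if idx not in REMOVED]
pattern = [layer_kind(idx) for idx in kept]

print(len(kept))
print(kept)
print(pattern)
```

Note that the kept stack ends with four consecutive global layers (original indices 12, 15, 18, 21), so it no longer follows a uniform global/local stride.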
Unfortunately the HuggingFace modeling code for ModernBERT relies on global-local attention patterns being uniform throughout the model,
so loading this bad boy properly takes a bit of model surgery. I hope in the future that the HuggingFace team will update this
model configuration to allow custom striping of global+local layers. For now, here's how to do it:
1. Download the checkpoint (model.pt) from this repository.
2. Initialize ModernBERT-base:
```python
import torch
import torch.nn as nn
from transformers import AutoTokenizer, AutoModelForMaskedLM
model_id = "answerdotai/ModernBERT-base"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForMaskedLM.from_pretrained(model_id)
```
3. Remove the layers:
```python
layers_to_remove = [13, 14, 16, 17, 19, 20]
model.model.layers = nn.ModuleList([
    layer for idx, layer in enumerate(model.model.layers)
    if idx not in layers_to_remove
])
```
4. Load the checkpoint state dict:
```python
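# Load onto CPU first; move the model to a GPU afterwards if needed.
state_dict = torch.load("model.pt", map_location="cpu")
# The weights go into the inner model (model.model), which now has 16 layers.
model.model.load_state_dict(state_dict)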
```
5. Use the model! Yay!
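
If step 4 complains about mismatched keys, it can help to diff the key names yourself. The helper below is a hypothetical convenience and not part of this repository. It works on plain key collections, so the demo uses made-up key names instead of the real checkpoint.

```python
# Hypothetical helper: compare the keys the pruned model expects with the
# keys found in a checkpoint. Works on any iterable of strings.

def diff_keys(expected_keys, loaded_keys):
    expected, loaded = set(expected_keys), set(loaded_keys)
    missing = sorted(expected - loaded)     # model wants these, checkpoint lacks them
    unexpected = sorted(loaded - expected)  # checkpoint has these, model does not
    return missing, unexpected

# Toy demo with illustrative key names (not read from model.pt):
expected = ["layers.0.mlp.Wo.weight", "layers.15.mlp.Wo.weight"]
loaded = ["layers.0.mlp.Wo.weight", "layers.21.mlp.Wo.weight"]
missing, unexpected = diff_keys(expected, loaded)
print(missing)
print(unexpected)

# With the real objects from the steps above, the call would be:
#   diff_keys(model.model.state_dict().keys(), state_dict.keys())
```

After the surgery in step 3, the rebuilt `nn.ModuleList` numbers its layers 0 through 15, so both lists should come back empty if the checkpoint was saved from a model pruned the same way.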
# Training Information
This model was distilled from ModernBERT-base on the [MiniPile dataset](https://huggingface.co/datasets/JeanKaddour/minipile),
which includes English and code data. Distillation used all 1M samples in this dataset for 1 epoch, MSE loss on the logits,
batch size of 16, AdamW optimizer, and constant learning rate of 1.0e-5.
The embeddings/LM head were frozen and shared between the teacher and student; only the transformer blocks were trained.
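
As a rough sense of scale, the settings above imply the following number of optimizer steps. This sketch assumes exactly 1,000,000 samples, one optimizer step per batch (no gradient accumulation), and a single device; none of these are stated explicitly above.

```python
import math

# Assumptions (not confirmed above): exactly 1,000,000 samples, one optimizer
# step per batch of 16, single device, no gradient accumulation.
num_samples = 1_000_000
batch_size = 16
epochs = 1

steps = math.ceil(num_samples / batch_size) * epochs
print(steps)  # 62500
```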
I have not yet evaluated this model. However, after the initial model surgery, it failed to correctly complete
"The capital of France is [MASK]", and after training, it correctly says "Paris", so something good happened!