SeerAttention-Llama-3.1-8B-AttnGates
This repo contains only the AttnGate weights for the Llama-3.1-8B-Instruct model.
SeerAttention introduces learnable AttnGate modules that accelerate the computationally intensive prefill stage of long-context large language models (LLMs) through dynamic block-level sparsity. The AttnGates are trained in a parameter-efficient self-distillation framework: they learn to mimic the 2D max-pooled attention patterns of the original, frozen model, which preserves the model's integrity and avoids costly retraining. During inference, the gates produce block-sparse binary masks by applying a threshold or TopK selection to their learned soft scores, and a custom block-sparse FlashAttention kernel then computes attention only for the selected blocks.
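The self-distillation target described above can be illustrated with a minimal NumPy sketch. It assumes a single attention head, a sequence length divisible by the block size, row-normalized pooled scores as the target, and KL divergence as the distillation loss. These choices and all function names (`causal_attention`, `block_maxpool`, `distill_target`, `kl_divergence`) are illustrative assumptions, not the repo's training code.

```python
import numpy as np


def causal_attention(q: np.ndarray, k: np.ndarray) -> np.ndarray:
    """Dense causal softmax attention probabilities for one head, shape (T, T)."""
    T, d = q.shape
    scores = q @ k.T / np.sqrt(d)
    future = np.triu(np.ones((T, T), dtype=bool), k=1)  # keys after each query position
    scores = np.where(future, -np.inf, scores)
    scores = scores - scores.max(axis=-1, keepdims=True)  # numerically stable softmax
    probs = np.exp(scores)
    return probs / probs.sum(axis=-1, keepdims=True)


def block_maxpool(attn: np.ndarray, block: int) -> np.ndarray:
    """2D max-pool a (T, T) attention map into (T // block, T // block) block scores."""
    nb = attn.shape[0] // block
    trimmed = attn[: nb * block, : nb * block]
    # Axes after reshape: (query block, row within block, key block, column within block).
    return trimmed.reshape(nb, block, nb, block).max(axis=(1, 3))


def distill_target(attn: np.ndarray, block: int) -> np.ndarray:
    """Assumption: pooled scores are row-normalized into one distribution per query block."""
    pooled = block_maxpool(attn, block)
    return pooled / pooled.sum(axis=-1, keepdims=True)


def kl_divergence(target: np.ndarray, pred: np.ndarray, eps: float = 1e-12) -> float:
    """Assumption: mean KL(target || pred) over query-block rows as the distillation loss."""
    per_row = np.sum(target * (np.log(target + eps) - np.log(pred + eps)), axis=-1)
    return float(np.mean(per_row))


rng = np.random.default_rng(0)
T, d, block = 256, 64, 64
q = rng.standard_normal((T, d))
k = rng.standard_normal((T, d))
attn = causal_attention(q, k)
target = distill_target(attn, block)  # (4, 4) block-level target

# Hypothetical stand-in for an untrained gate: uniform over the causal key blocks of each row.
nb = T // block
causal = np.tril(np.ones((nb, nb)))
uniform_gate = causal / causal.sum(axis=-1, keepdims=True)
print(target.round(3))
print(kl_divergence(target, uniform_gate))
```

Blocks above the diagonal pool to zero because every key in them lies in the future, so under these assumptions the target spreads probability only over causal blocks, and the frozen model's attention map is the only supervision the gate needs.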
Original GitHub Repo
https://github.com/microsoft/SeerAttention
Evaluation Results
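Perplexity on PG19
Rows give the attention block density (1.00 = dense attention); columns give the evaluation context length in tokens. Lower perplexity is better.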
Density | 8192 | 16384 | 32768 | 65536 | 131072 |
---|---|---|---|---|---|
1.00 | 10.03 | 9.88 | 9.92 | 9.97 | 10.03 |
0.50 | 10.04 | 9.89 | 9.92 | 9.99 | 10.05 |
0.40 | 10.06 | 9.89 | 9.93 | 9.99 | 10.07 |
0.30 | 10.09 | 9.91 | 9.95 | 10.01 | 10.15 |
0.20 | 10.19 | 9.94 | 9.97 | 10.04 | 10.37 |
0.10 | 10.61 | 10.08 | 10.04 | 10.09 | 10.88 |
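The Density column is the fraction of attention blocks that are actually computed. The card does not say how a fixed density was enforced for this sweep; the sketch below assumes, purely for illustration, a per-row TopK over causal blocks, using hypothetical helpers `topk_block_mask` and `causal_density`.

```python
import math

import numpy as np


def topk_block_mask(scores: np.ndarray, density: float) -> np.ndarray:
    """Hypothetical TopK mask: per query-block row, keep ceil(density * n_causal) blocks."""
    nb = scores.shape[0]
    mask = np.zeros((nb, nb), dtype=bool)
    for i in range(nb):
        n_causal = i + 1  # key blocks 0..i are visible to query block i
        k = max(1, math.ceil(density * n_causal))  # always keep at least one block
        keep = np.argsort(scores[i, :n_causal])[::-1][:k]  # indices of the k largest scores
        mask[i, keep] = True
    return mask


def causal_density(mask: np.ndarray) -> float:
    """Fraction of causal (lower-triangular) blocks that the mask keeps."""
    causal = np.tril(np.ones(mask.shape, dtype=bool))
    return float(mask[causal].sum() / causal.sum())


rng = np.random.default_rng(0)
scores = rng.random((8, 8))  # hypothetical stand-in for AttnGate block scores
for density in (1.0, 0.5, 0.1):
    mask = topk_block_mask(scores, density)
    print(density, round(causal_density(mask), 3))
```

Because each row keeps at least one block and rounds up, the realized density at small block counts can exceed the requested value: at 0.1 with 8 blocks per side it comes out near 0.22.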
LongBench
All LongBench results use an AttnGate threshold of 2e-3. Columns group samples by input length (0-4k, 4-8k, and 8k+).
Task | 0-4k | 4-8k | 8k+ |
---|---|---|---|
2wikimqa | 51.1 | 47.85 | 33.36 |
gov_report | 35.03 | 35.05 | 34.57 |
hotpotqa | 63.97 | 60.0 | 56.7 |
lcc | 67.98 | 73.18 | 65.28 |
multi_news | 28.1 | 25.78 | 24.25 |
multifieldqa_en | 58.63 | 51.45 | 51.87 |
passage_count | 18.0 | 10.15 | 11.88 |
passage_retrieval_en | 100.0 | 99.0 | 98.0 |
qasper | 47.77 | 44.04 | 39.63 |
repobench-p | 51.78 | 56.24 | 56.75 |
samsum | 43.28 | 41.19 | 45.29 |
trec | 64.0 | 76.0 | 75.0 |
triviaqa | 90.91 | 88.45 | 92.43 |
Average | 55.43 | 54.49 | 52.69 |
RULER
Context length | Dense Baseline | SeerAttn | Avg density |
---|---|---|---|
4k | 95.53 | 95.53 | 0.87 |
8k | 92.27 | 92.71 | 0.72 |
16k | 92.01 | 92.02 | 0.56 |
32k | 87.63 | 88.49 | 0.46 |
64k | 84.39 | 83.48 | 0.32 |
128k | 76.26 | 73.37 | 0.17 |
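The card does not state the threshold used for RULER. Assuming, as an illustration, that gate scores are softmax-normalized over the causal key blocks of each query-block row and that blocks scoring at or above the threshold are kept, a fixed threshold prunes more as the number of blocks grows, because each block's share of its row shrinks. That is one plausible reading of why Avg density falls with context length. The helpers and the random logits below are hypothetical stand-ins, not the repo's API.

```python
import numpy as np


def causal_block_softmax(logits: np.ndarray) -> np.ndarray:
    """Assumption: softmax over the causal key blocks (j <= i) of each query-block row i."""
    causal = np.tril(np.ones(logits.shape, dtype=bool))
    z = np.where(causal, logits, -np.inf)
    z = z - z.max(axis=-1, keepdims=True)  # numerically stable softmax
    p = np.exp(z)
    return p / p.sum(axis=-1, keepdims=True)


def threshold_block_mask(probs: np.ndarray, threshold: float) -> np.ndarray:
    """Keep every causal block whose normalized score is at least `threshold`."""
    causal = np.tril(np.ones(probs.shape, dtype=bool))
    return (probs >= threshold) & causal


def causal_density(mask: np.ndarray) -> float:
    """Fraction of causal (lower-triangular) blocks that the mask keeps."""
    causal = np.tril(np.ones(mask.shape, dtype=bool))
    return float(mask[causal].sum() / causal.sum())


rng = np.random.default_rng(0)
threshold = 0.02  # arbitrary illustrative value, not the one used for RULER
densities = {}
for nb in (16, 128):  # more blocks per side ~ longer context at a fixed block size
    logits = rng.standard_normal((nb, nb))  # hypothetical stand-in for AttnGate logits
    probs = causal_block_softmax(logits)
    mask = threshold_block_mask(probs, threshold)
    densities[nb] = causal_density(mask)
print(densities)
```

Unlike TopK, a threshold lets each row keep a different number of blocks, so the realized density adapts to how concentrated the gate scores are rather than being fixed in advance.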
LongBenchV2 CoT Benchmark
All SeerAttention models use threshold=5e-4.
For the R1-Distilled models, we drop the two-pass generation setup (think + summary) and instead ask the models to output the answer directly after thinking. The maximum generation length is 10240 tokens.
Model | Overall | Easy | Hard | Short | Medium | Long |
---|---|---|---|---|---|---|
Llama-3.1-8B-Instruct | 30.4 | 31.2 | 29.9 | 37.8 | 24.7 | 29.6 |
SeerAttention-Llama-3.1-8B | 31.6 | 33.3 | 30.5 | 33.9 | 31.6 | 27.8 |
Qwen2.5-14B-Instruct | 34.8 | 37.5 | 33.1 | 44.4 | 32.1 | 24.1 |
SeerAttention-Qwen2.5-14B | 32.8 | 38.0 | 29.6 | 45.0 | 30.2 | 17.6 |
Qwen2.5-32B-Instruct | 36.4 | 42.2 | 32.8 | 47.8 | 29.8 | 30.6 |
SeerAttention-Qwen2.5-32B | 36.4 | 41.1 | 33.4 | 49.4 | 29.8 | 27.8 |
DeepSeek-R1-Distill-Qwen-14B | 34.2 | 43.2 | 28.6 | 45.0 | 27.9 | 28.7 |
SeerAttention-DeepSeek-R1-Distill-Qwen-14B | 31.6 | 35.9 | 28.9 | 41.7 | 26.0 | 25.9 |
DeepSeek-R1-Distill-Qwen-32B | 37.2 | 42.7 | 33.8 | 47.2 | 35.8 | 23.1 |
SeerAttention-DeepSeek-R1-Distill-Qwen-32B | 37.0 | 42.2 | 33.8 | 49.4 | 31.6 | 26.9 |
Model tree for SeerAttention/SeerAttention-Llama-3.1-8B-AttnGates
Base model: meta-llama/Llama-3.1-8B