---
license: mit
datasets:
- DAMO-NLP-SG/LongCorpus-2.5B
---

# CLEX: Continuous Length Extrapolation for Large Language Models
This repo stores the checkpoint of CLEX-Mixtral-8x7B-Chat-32K.


## Features and Highlights of CLEX
![CLEX_diagram](https://github.com/DAMO-NLP-SG/CLEX/assets/18526640/063ffe34-0116-4759-92bf-e22fc7264cdf)

- **Simple and Clear**: _MINIMAL_ code and architecture changes. Only one up-and-down projection layer is introduced; _NO_ recurrent memory caching or sparse attention is required.
- **Train Short, Test Long**: _NO_ performance drop on sequences _4x~8x longer_ than those seen in training (see [here](https://github.com/DAMO-NLP-SG/CLEX#language-modelling)).
- **Continuous Length Extrapolation**: Explicitly models the continuous dynamics of the context window size during length extrapolation.
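The features above mention continuous dynamics of the context window size and a single up-and-down projection layer, but this README does not spell out how they fit together. The sketch below is only a toy illustration under our own assumptions, not CLEX's implementation: it treats the log of the standard RoPE frequencies as a state `z` that evolves continuously in a length scaling factor `t` (from `t = 1` to the target factor), with `dz/dt` produced by a small up-and-down projection. `UpDownProjection` and `integrate_frequencies` are hypothetical names, and the weights are random rather than learned. As a sanity check, the fixed dynamics `dz/dt = -1/t` reproduce linear position interpolation, i.e. every frequency divided by `t`.

```python
import math
import random


def rope_frequencies(dim, base=10000.0):
    """Standard RoPE frequencies theta_i = base^(-2i/dim), i = 0 .. dim/2 - 1."""
    return [base ** (-2 * i / dim) for i in range(dim // 2)]


class UpDownProjection:
    """Hypothetical dynamics network f(z, t): concatenate t to z, project up to
    `hidden` units, apply tanh, then project back down to len(z) outputs.
    Weights are random here; a trained model would learn them."""

    def __init__(self, n, hidden, seed=0, scale=0.01):
        rng = random.Random(seed)
        self.up = [[rng.gauss(0, scale) for _ in range(n + 1)] for _ in range(hidden)]
        self.down = [[rng.gauss(0, scale) for _ in range(hidden)] for _ in range(n)]

    def __call__(self, z, t):
        x = z + [t]  # condition the derivative on the current scaling factor
        h = [math.tanh(sum(w * xi for w, xi in zip(row, x))) for row in self.up]
        return [sum(w * hj for w, hj in zip(row, h)) for row in self.down]


def integrate_frequencies(theta, f, t_target, steps=1000):
    """Evolve z = log(theta) from t = 1 to t_target with forward Euler on
    dz/dt = f(z, t), then return exp(z) as the scaled frequencies."""
    z = [math.log(x) for x in theta]
    h = (t_target - 1.0) / steps
    t = 1.0
    for _ in range(steps):
        dz = f(z, t)
        z = [zi + h * dzi for zi, dzi in zip(z, dz)]
        t += h
    return [math.exp(zi) for zi in z]


theta = rope_frequencies(dim=8)
# Fixed dynamics dz/dt = -1/t: approximately theta / 4 at t = 4 (position interpolation).
pi_like = integrate_frequencies(theta, lambda z, t: [-1.0 / t] * len(z), t_target=4.0)
# "Learned-style" dynamics from the toy up-and-down projection.
flow = integrate_frequencies(theta, UpDownProjection(n=len(theta), hidden=16), t_target=4.0)
print(pi_like)
print(flow)
```

The point of the continuous formulation is that one set of dynamics yields frequencies for any scaling factor `t`, rather than a separate hand-picked scaling rule per target length.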

If you have any questions, feel free to contact us. (Emails: guanzzh.chen@gmail.com, lixin4ever@gmail.com)

## Model Zoo
<div align="center">

| Model Name | Model Type | Starting Point | Train Data | Train Length | Max Test Length | HF Repo |
|:-----|:-----|:-----------|:-----------|:-----------|:-----------|:------:|
| CLEX-LLaMA-2-7B-16K | base | LLaMA-2-7B | [RedPajama-Book](https://huggingface.co/datasets/togethercomputer/RedPajama-Data-1T) | 16K | 64K | [link](https://huggingface.co/DAMO-NLP-SG/CLEX-7B-16K) |
| CLEX-LLaMA-2-7B-Chat-16K | chat | CLEX-LLaMA-2-7B-16K | [UltraChat](https://github.com/thunlp/UltraChat) | 16K | 64K | [link](https://huggingface.co/DAMO-NLP-SG/CLEX-7B-Chat-16K) |
| CLEX-LLaMA-2-7B-64K | base | LLaMA-2-7B | [RedPajama-Book](https://huggingface.co/datasets/togethercomputer/RedPajama-Data-1T) | 64K | 256K | [link](https://huggingface.co/DAMO-NLP-SG/CLEX-LLaMA-2-7B-64K) |
| CLEX-Phi-2-32K | base | Phi-2-2.7B | [LongCorpus-2.5B](https://huggingface.co/datasets/DAMO-NLP-SG/LongCorpus-2.5B) | 32K | 128K | [link](https://huggingface.co/DAMO-NLP-SG/CLEX-Phi-2-32K) |
| CLEX-Mixtral-8x7B-32K | base | Mixtral-8x7B-v0.1 | [LongCorpus-2.5B](https://huggingface.co/datasets/DAMO-NLP-SG/LongCorpus-2.5B) | 32K | >128K | [link](https://huggingface.co/DAMO-NLP-SG/CLEX-Mixtral-8x7B-32K) |
| **CLEX-Mixtral-8x7B-Chat-32K** (this checkpoint) | chat | CLEX-Mixtral-8x7B-32K | [UltraChat 200k](https://huggingface.co/datasets/HuggingFaceH4/ultrachat_200k) | 32K | >128K | [link](https://huggingface.co/DAMO-NLP-SG/CLEX-Mixtral-8x7B-Chat-32K) |
</div>
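Each row of the Model Zoo table pairs a training length with a maximum test length, so the extrapolation ratio is simply their quotient. The sketch below copies those numbers into a dictionary and lists the checkpoints whose reported maximum test length covers a target length. It assumes 1K = 1024 tokens and treats ">128K" as a lower bound of exactly 128K; `MODEL_ZOO`, `extrapolation_ratio`, and `checkpoints_for` are hypothetical names, not part of this repo.

```python
K = 1024  # assumption: "16K" means 16 * 1024 tokens

# (train length, max test length) per checkpoint, copied from the Model Zoo table.
# ">128K" entries are treated as a lower bound of exactly 128K.
MODEL_ZOO = {
    "CLEX-LLaMA-2-7B-16K": (16 * K, 64 * K),
    "CLEX-LLaMA-2-7B-Chat-16K": (16 * K, 64 * K),
    "CLEX-LLaMA-2-7B-64K": (64 * K, 256 * K),
    "CLEX-Phi-2-32K": (32 * K, 128 * K),
    "CLEX-Mixtral-8x7B-32K": (32 * K, 128 * K),
    "CLEX-Mixtral-8x7B-Chat-32K": (32 * K, 128 * K),
}


def extrapolation_ratio(name, zoo=MODEL_ZOO):
    """Max test length divided by train length for one checkpoint."""
    train_len, max_test_len = zoo[name]
    return max_test_len / train_len


def checkpoints_for(target_len, zoo=MODEL_ZOO):
    """Names of checkpoints whose listed max test length is at least target_len tokens."""
    return sorted(name for name, (_, max_test) in zoo.items() if max_test >= target_len)


for name in MODEL_ZOO:
    print(f"{name}: {extrapolation_ratio(name):.0f}x")
print(checkpoints_for(200 * K))
```

Every row works out to 4x (a lower bound for the two Mixtral checkpoints), which is the low end of the "4x~8x" range stated above.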


## Usage


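```python
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

model_id = "DAMO-NLP-SG/CLEX-Mixtral-8x7B-Chat-32K"
# trust_remote_code=True loads the custom CLEX modeling code shipped with the repo.
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
# device_map="auto" (requires `accelerate`) spreads the large Mixtral weights across available devices.
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16, device_map="auto", trust_remote_code=True)
# Move the tokenized prompt to the device holding the model's first layers.
inputs = tokenizer("What is CLEX?", return_tensors="pt").to(model.device)
sample = model.generate(**inputs, max_length=128)
print(tokenizer.decode(sample[0], skip_special_tokens=True))
```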
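Inputs longer than the maximum test lengths listed in the Model Zoo table have to be shortened before generation. One common choice for long-context evaluation, assumed here rather than taken from this repo, is to drop tokens from the middle so that both the beginning (often the instructions) and the end (often the question) survive. A minimal sketch on plain token-ID lists; `truncate_middle` is a hypothetical helper:

```python
def truncate_middle(token_ids, max_len):
    """Keep the first and last tokens and drop the middle so that
    len(result) <= max_len. Inputs that already fit are returned unchanged."""
    if max_len <= 0:
        return []
    if len(token_ids) <= max_len:
        return list(token_ids)
    head = max_len // 2          # tokens kept from the start
    tail = max_len - head        # tokens kept from the end (always >= 1 here)
    return list(token_ids[:head]) + list(token_ids[-tail:])


print(truncate_middle(list(range(10)), 4))
```

With the tokenizer from the usage example, one would apply it to `tokenizer(long_text)["input_ids"]` and wrap the result as `torch.tensor([ids])` before calling `model.generate`.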




## Evaluation

### InfiniteBench
We evaluate CLEX-Mixtral-8x7B-Chat-32K on [InfiniteBench](https://github.com/OpenBMB/InfiniteBench), a 128K-length benchmark covering retrieval, English summarization, QA, multiple-choice and dialogue, code, and math tasks. We compare it with GPT-4, YaRN-Mistral-7B, Kimi-Chat, Claude 2, and the original Mixtral-8x7B-Instruct-v0.1.

| Task Name           | GPT-4  | YaRN-Mistral-7B | Kimi-Chat | Claude 2 | CLEX-Mixtral-8x7B-Chat-32K | Mixtral-8x7B-Instruct-v0.1 |
| ------------------- | ------ | --------------- | --------- | -------- | -------------------------- | -------------------------- |
| Retrieve.PassKey    | 100%   | 92.71%          | 98.14%    | 97.80%   | 99.72%                     | 96.78%                     |
| **Retrieve.Number** | 100%   | 56.61%          | 95.42%    | 98.14%   | 76.10%                     | 76.61%                     |
| **Retrieve.KV**     | 89.00% | < 5%            | 53.60%    | 65.40%   | < 5%                       | < 5%                       |
| En.Sum              | 14.73% | 9.09%           | 17.93%    | 14.45%   | 15.48%                     | 14.3%                      |
| En.QA               | 22.22% | 9.55%           | 16.52%    | 11.97%   | 15.52%                     | 16.81%                     |
| En.MC               | 67.25% | 27.95%          | 72.49%    | 62.88%   | 58.96%                     | 56.77%                     |
| En.Dia              | 8.50%  | 7.50%           | 11.50%    | 46.50%   | 9%                         | < 5%                       |
| Code.Debug          | 39.59% | < 5%            | 18.02%    | < 5%     | 21.32%                     | < 5%                       |
| Code.Run            | 23.25% | < 5%            | < 5%      | < 5%     | < 5%                       | < 5%                       |
| Math.Calc           | < 5%   | < 5%            | < 5%      | < 5%     | < 5%                       | < 5%                       |
| Math.Find           | 60.00% | 17.14%          | 12.57%    | 32.29%   | 28%                        | 26.57%                     |
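Retrieve.PassKey-style tasks hide a short number inside a long stretch of irrelevant filler text and ask the model to recall it. The sketch below builds a prompt in that spirit for quick local checks; the filler sentence, the prompt wording, the 5-digit key, and the names `make_passkey_prompt` and `score_passkey` are our own assumptions, not InfiniteBench's exact template or scorer.

```python
import random

FILLER = "The grass is green. The sky is blue. The sun is yellow. Here we go. There and back again. "


def make_passkey_prompt(n_filler, depth=0.5, seed=0):
    """Return (prompt, passkey). The passkey sentence is inserted after roughly
    `depth` of the n_filler filler repetitions (0.0 = start, 1.0 = end)."""
    if not 0.0 <= depth <= 1.0:
        raise ValueError("depth must be in [0, 1]")
    rng = random.Random(seed)
    passkey = str(rng.randint(10000, 99999))  # random 5-digit key
    parts = [FILLER] * n_filler
    parts.insert(round(depth * n_filler),
                 f"The pass key is {passkey}. Remember it. {passkey} is the pass key. ")
    prompt = ("There is an important piece of information hidden inside a lot of "
              "irrelevant text. Find it and memorize it.\n\n"
              + "".join(parts)
              + "\n\nWhat is the pass key?")
    return prompt, passkey


def score_passkey(response, passkey):
    """Simple substring check: True if the model's response contains the key."""
    return passkey in response


prompt, key = make_passkey_prompt(n_filler=1000, depth=0.25)
print(len(prompt), key)
```

Varying `n_filler` and `depth` gives a quick sweep over context length and key position, which can be combined with middle truncation when the prompt exceeds the model's supported length.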



## Citation
If you find our project useful, please consider starring our repo and citing our paper:
```bibtex
@article{damonlpsg2023clex,
  author = {Chen, Guanzheng and Li, Xin and Meng, Zaiqiao and Liang, Shangsong and Bing, Lidong},
  title = {CLEX: Continuous Length Extrapolation for Large Language Models},
  year = 2023,
  journal = {arXiv preprint arXiv:2310.16450},
  url = {https://arxiv.org/abs/2310.16450}
}
```