Update readme with instructions on how to change the kernels. (#6)
- Update readme with instructions on how to change the kernels. (044aabbd6dbabb4c49e06f3fd961b656b9c8d734)
- fix import (8b0f5773ecf168d0eb6ad229bd6846148725e1d9)
Co-authored-by: Maximilian Beck <maxmbeck@users.noreply.huggingface.co>
README.md
CHANGED
@@ -1,58 +1,73 @@
|
|
1 |
-
---
|
2 |
-
license: other
|
3 |
-
---
|
4 |
-
|
5 |
-
# xLSTM-7B
|
6 |
-
This xLSTM-7B was pre-trained on DCLM and selected high-quality data for a total of approx. 2.3 T tokens using the `xlstm-jax` framework.
|
7 |
-
|
8 |
-
|
9 |
-
## How to use it
|
10 |
-
First, install `xlstm`, which now uses the `mlstm_kernels` package for triton kernels:
|
11 |
-
|
12 |
-
```bash
|
13 |
-
pip install xlstm
|
14 |
-
pip install mlstm_kernels
|
15 |
-
```
|
16 |
-
|
17 |
-
For now, install the transformers repository fork from NX-AI (until it is merged):
|
18 |
-
```bash
|
19 |
-
pip install 'transformers @ git+ssh://git@github.com/NX-AI/transformers.git@integrate_xlstm'
|
20 |
-
```
|
21 |
-
|
22 |
-
Use this model as:
|
23 |
-
```python
|
24 |
-
from transformers import AutoModelForCausalLM, AutoTokenizer
|
25 |
-
|
26 |
-
xlstm = AutoModelForCausalLM.from_pretrained("NX-AI/xLSTM-7b", device_map="auto")
|
27 |
-
|
28 |
-
# this is a fork of EleutherAI/gpt-neox-20b
|
29 |
-
tokenizer = AutoTokenizer.from_pretrained("NX-AI/xLSTM-7b")
|
30 |
-
|
31 |
-
tokens = tokenizer("Hello xLSTM, how are you doing?", return_tensors='pt')['input_ids'].to(device="cuda")
|
32 |
-
|
33 |
-
out = xlstm.generate(tokens, max_new_tokens=20)
|
34 |
-
|
35 |
-
print(tokenizer.decode(out[0]))
|
36 |
-
```
|
37 |
-
|
38 |
-
|
39 |
-
|
40 |
-
|
41 |
-
|
42 |
-
|
43 |
-
|
44 |
-
|
45 |
-
|
46 |
-
|
47 |
-
|
48 |
-
|
49 |
-
|
50 |
-
|
51 |
-
|
52 |
-
|
53 |
-
|
54 |
-
|
55 |
-
|
56 |
-
|
57 |
-
##
|
58 |
-
|
1 |
+
---
|
2 |
+
license: other
|
3 |
+
---
|
4 |
+
|
5 |
+
# xLSTM-7B
|
6 |
+
This xLSTM-7B was pre-trained on DCLM and selected high-quality data for a total of approx. 2.3 T tokens using the `xlstm-jax` framework.
|
7 |
+
|
8 |
+
|
9 |
+
## How to use it
|
10 |
+
First, install `xlstm`, which now uses the `mlstm_kernels` package for triton kernels:
|
11 |
+
|
12 |
+
```bash
|
13 |
+
pip install xlstm
|
14 |
+
pip install mlstm_kernels
|
15 |
+
```
|
16 |
+
|
17 |
+
For now, install the transformers repository fork from NX-AI (until it is merged):
|
18 |
+
```bash
|
19 |
+
pip install 'transformers @ git+ssh://git@github.com/NX-AI/transformers.git@integrate_xlstm'
|
20 |
+
```
|
21 |
+
|
22 |
+
Use this model as:
|
23 |
+
```python
|
24 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer
|
25 |
+
|
26 |
+
xlstm = AutoModelForCausalLM.from_pretrained("NX-AI/xLSTM-7b", device_map="auto")
|
27 |
+
|
28 |
+
# this is a fork of EleutherAI/gpt-neox-20b
|
29 |
+
tokenizer = AutoTokenizer.from_pretrained("NX-AI/xLSTM-7b")
|
30 |
+
|
31 |
+
tokens = tokenizer("Hello xLSTM, how are you doing?", return_tensors='pt')['input_ids'].to(device="cuda")
|
32 |
+
|
33 |
+
out = xlstm.generate(tokens, max_new_tokens=20)
|
34 |
+
|
35 |
+
print(tokenizer.decode(out[0]))
|
36 |
+
```
|
37 |
+
|
38 |
+
If you cannot or do not want to use the triton kernels, you can change them to native PyTorch implementations:
|
39 |
+
```python
|
40 |
+
from transformers import AutoConfig, AutoModelForCausalLM
xlstm_config = AutoConfig.from_pretrained("NX-AI/xLSTM-7b")
|
41 |
+
xlstm_config.step_kernel = "native"
|
42 |
+
xlstm_config.chunkwise_kernel = "chunkwise--native_autograd"
|
43 |
+
xlstm_config.sequence_kernel = "native_sequence__native"
|
44 |
+
|
45 |
+
xlstm = AutoModelForCausalLM.from_pretrained("NX-AI/xLSTM-7b", config=xlstm_config, device_map="auto")
|
46 |
+
|
47 |
+
# verify selected kernels
|
48 |
+
from pprint import pprint
|
49 |
+
pprint(xlstm.backbone.blocks[0].mlstm_layer.config)
|
50 |
+
```
|
51 |
+
|
52 |
+
|
53 |
+
## Speed results
|
54 |
+
Generation Speed using `torch.cuda.graph` and `torch.compile` optimizations on one NVIDIA H100:
|
55 |
+
![generation speed](plot_tokens_per_sec.svg)
|
56 |
+
|
57 |
+
## Performance
|
58 |
+
![mmlu_train_token](MMLUvsTrainToken.svg)
|
59 |
+
|
60 |
+
Using HuggingFace's `lm_eval`:
|
61 |
+
|
62 |
+
| BBH | MMLU-Pro | Math | MUSR | GPQA | IfEval |
|
63 |
+
|-------|----------|--------|------|------|--------|
|
64 |
+
| 0.381 | 0.242 | 0.036 | 0.379 | 0.280 | 0.244 |
|
65 |
+
|
66 |
+
Using HuggingFace's `lighteval` in the Leaderboard-v1 settings:
|
67 |
+
|
68 |
+
|Arc-Challenge (25-shot) |MMLU (5-shot) |Hellaswag (10-shot)|Winogrande (5-shot) |TruthfulQA (0-shot) |GSM8k (5-shot) |OpenbookQA (5-shot) | PiQA (5-shot)|
|
69 |
+
|------------------------|--------------|-------------------|--------------------|--------------------|---------------|--------------------|--------------|
|
70 |
+
| 0.584 |0.589 | 0.710 |0.742 | 0.420 | 0.004 | 0.443 | 0.817 |
|
71 |
+
|
72 |
+
## License
|
73 |
+
NXAI Community License (see `LICENSE` file)
|
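
The kernel override added to the README in this change sets three config fields before the model is loaded. A minimal sketch of that step as a reusable helper follows. The helper name `use_native_kernels`, the `NATIVE_KERNELS` mapping, and the `SimpleNamespace` stand-in config (with its placeholder starting value `"triton"`) are illustrative assumptions and not part of `xlstm` or `transformers`; only the three field names and kernel strings come from the README.

```python
from types import SimpleNamespace

# Field names and kernel strings copied from the README change.
NATIVE_KERNELS = {
    "step_kernel": "native",
    "chunkwise_kernel": "chunkwise--native_autograd",
    "sequence_kernel": "native_sequence__native",
}


def use_native_kernels(config):
    """Set the three kernel fields on `config` in place and return it.

    Hypothetical helper: it works on any object that accepts attribute
    assignment, such as a config returned by AutoConfig.from_pretrained.
    """
    for field, kernel in NATIVE_KERNELS.items():
        setattr(config, field, kernel)
    return config


# Stand-in config so the sketch runs without downloading the model.
# The starting value "triton" is a placeholder, not a documented kernel name.
demo_config = use_native_kernels(SimpleNamespace(step_kernel="triton"))
print(vars(demo_config))
```

With the real model, the same helper would be applied to the object returned by `AutoConfig.from_pretrained("NX-AI/xLSTM-7b")`, and the result passed as `config=` to `AutoModelForCausalLM.from_pretrained`, as the README snippet does.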