Update README.md
README.md (changed)
Previous version (Zamba2-1.2B model card):

---
license: apache-2.0
---

# Model Card for Zamba2-1.2B

Zamba2-1.2B is a hybrid SSM-transformer model built on our Zamba architecture, which combines a backbone of Mamba2 layers with shared attention blocks.

Zamba2-1.2B differs from our [2.7B model](https://huggingface.co/Zyphra/Zamba2-2.7B) in three ways:

1.) We have added rotary position embeddings.

2.) We use a single shared transformer block (instead of two blocks that we alternate between).

3.) We have added LoRA projectors to the attention blocks (instead of just a LoRA on the MLP block).

We found that while hybrid SSM-transformer models are perfectly capable of performing well without position embeddings, adding rotary embeddings to the shared attention block slightly improved performance. Second, we utilize a single attention block (instead of alternating between two independent transformer blocks) because this enables a higher FLOP count for the model at a given parameter budget; at smaller scales this becomes more important than the slightly faster latency.

Zamba2-1.2B uses the Mistral v0.1 tokenizer and was pre-trained on 3T tokens of text and code sourced from open web datasets, including [Zyda](https://arxiv.org/abs/2406.01981). In a second phase, Zamba2-1.2B was annealed on a mixture of 100B high-quality tokens.

Note: this is a temporary HuggingFace implementation of Zamba2-1.2B. It may not yet be fully compatible with all frameworks and tools intended to interface with HuggingFace models.

A standalone PyTorch implementation of Zamba2-1.2B may be found [here](https://github.com/Zyphra/Zamba2).

## Quick start

### Prerequisites

To download Zamba2-1.2B, clone Zyphra's fork of transformers:
1. `git clone https://github.com/Zyphra/transformers_zamba2.git`
2. `cd transformers_zamba2`
3. Install the repository: `pip install -e .`
4. `pip install accelerate`

You can run the model without using the optimized Mamba kernels, but it is **not** recommended as it will result in significantly higher latency and memory usage.

To run on CPU, please specify `use_mamba_kernels=False` when loading the model using `AutoModelForCausalLM.from_pretrained`.
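
For example, a minimal CPU-only loading sketch (the `use_mamba_kernels` flag comes from Zyphra's transformers fork):

```python
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

# Load on CPU without the optimized Mamba kernels (works, but noticeably slower)
tokenizer = AutoTokenizer.from_pretrained("Zyphra/Zamba2-1.2B")
model = AutoModelForCausalLM.from_pretrained(
    "Zyphra/Zamba2-1.2B",
    use_mamba_kernels=False,  # disable the fused Mamba kernels for CPU inference
    device_map="cpu",
    torch_dtype=torch.float32,
)
```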

### Inference

```python
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

# Instantiate model and tokenizer
tokenizer = AutoTokenizer.from_pretrained("Zyphra/Zamba2-1.2B")
model = AutoModelForCausalLM.from_pretrained("Zyphra/Zamba2-1.2B", device_map="cuda", torch_dtype=torch.bfloat16)

# Tokenize a prompt and generate a completion
input_text = "What factors contributed to the fall of the Roman Empire?"
input_ids = tokenizer(input_text, return_tensors="pt").to("cuda")

outputs = model.generate(**input_ids, max_new_tokens=100)
print(tokenizer.decode(outputs[0]))
```

## Training

### Fine-tuning

The fine-tuning setup includes an advanced learning rate optimization system, implemented through the `LROptimizerCallback` class from the custom `lr_optimizer` module:

```python
from transformers import AutoTokenizer, Trainer
from lr_optimizer import setup_training, LROptimizerCallback

# Method 1: Using the complete setup function
training_setup = setup_training(
    model_name="Zyphra/Zamba2-1.2B",
    dataset_name="your/dataset",
    num_trials=10
)
trainer = training_setup['trainer']

# Method 2: Using the callback directly
callback = LROptimizerCallback(
    num_trials=10,
    lr_range=(1e-6, 1e-4)
)
trainer = Trainer(
    # ... model, training arguments, and datasets go here ...
    callbacks=[callback]
)

# Start training with optimized configuration
trainer.train()
```

The learning rate optimization callback:

- Explores learning rates between 1e-6 and 1e-4 using Bayesian optimization
- Applies Gaussian Process Regression for precise LR selection
- Implements memory optimization through gradient checkpointing
- Supports both fp16 and bf16 training

For detailed configuration options, see the [fine-tuning documentation](link-to-docs).
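
The `lr_optimizer` module itself is not included with this card; conceptually, the search described above resembles Gaussian-process Bayesian optimization over the learning rate, which could be sketched with scikit-optimize as follows (an illustration of the idea, not the actual implementation):

```python
import numpy as np
from skopt import gp_minimize
from skopt.space import Real

def validation_loss_for_lr(params):
    (lr,) = params
    # Placeholder: run a short fine-tuning trial at this learning rate
    # and return the resulting validation loss.
    return float(np.random.rand())

result = gp_minimize(
    validation_loss_for_lr,              # objective evaluated once per trial
    dimensions=[Real(1e-6, 1e-4, prior="log-uniform", name="learning_rate")],
    n_calls=10,                          # mirrors num_trials=10 above
    random_state=0,
)
print("Best learning rate found:", result.x[0])
```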

Zamba2-1.2B utilizes and extends our original Zamba hybrid SSM-attention architecture.

<center>
<img src="https://cdn-uploads.huggingface.co/production/uploads/65c05e75c084467acab2f84a/Vay6htbnBcySR3Z6NEgwj.png" width="300" alt="Zamba architecture">
</center>

## Performance

<center>
<img src="https://cdn-uploads.huggingface.co/production/uploads/65bc13717c6ad1994b6619e9/7Japy8VaJzKaFEjJgtWBp.png" width="700" alt="Zamba performance">
</center>

<center>
<img src="https://cdn-uploads.huggingface.co/production/uploads/65c05e75c084467acab2f84a/Viwo3-bpYLFUu7cLIUFVv.png" width="800" alt="Zamba performance">
</center>

<!--
<center>
<img src="https://cdn-uploads.huggingface.co/production/uploads/65c05e75c084467acab2f84a/JVZUvVMPIpIJy9RDyohMJ.png" width="800" alt="Zamba performance">
</center>
-->

Time to First Token (TTFT) | Output Generation
:-------------------------:|:-------------------------:
![image/png](https://cdn-uploads.huggingface.co/production/uploads/65c05e75c084467acab2f84a/5lpWDLdtPPVAk8COJq7gZ.png) | ![image/png](https://cdn-uploads.huggingface.co/production/uploads/65c05e75c084467acab2f84a/V2tS6eCOGbpKybEoZmOB7.png)

And memory overhead:

<center>
<img src="https://cdn-uploads.huggingface.co/production/uploads/65c05e75c084467acab2f84a/m0YUmAmiVnRg6l9m10CEt.png" width="400" alt="Zamba inference and memory cost">
</center>

Updated version (Zamba2-1.2B-instruct-Dutch model card):

---
license: apache-2.0
datasets:
- HuggingFaceH4/ultrachat_200k
- BAAI/Infinity-Instruct
- HuggingFaceH4/ultrafeedback_binarized
- Intel/orca_dpo_pairs
- argilla/OpenHermesPreferences
- BramVanroy/dolly-15k-dutch
base_model:
- Zyphra/Zamba2-1.2B-instruct
library_name: transformers
---

# Model Card for Zamba2-1.2B-instruct-Dutch

Zamba2-1.2B-instruct-Dutch is a Dutch-language instruction-following model obtained through a two-stage fine-tuning process:

1. First stage (base instruction model by Zyphra): Zyphra fine-tuned Zamba2-1.2B to create Zamba2-1.2B-instruct through:
   - SFT training on [ultrachat_200k](https://huggingface.co/datasets/HuggingFaceH4/ultrachat_200k) and [Infinity-Instruct](https://huggingface.co/datasets/BAAI/Infinity-Instruct)
   - DPO training on [ultrafeedback_binarized](https://huggingface.co/datasets/HuggingFaceH4/ultrafeedback_binarized), [orca_dpo_pairs](https://huggingface.co/datasets/Intel/orca_dpo_pairs), and [OpenHermesPreferences](https://huggingface.co/datasets/argilla/OpenHermesPreferences)

2. Second stage (Dutch language adaptation): further fine-tuning of Zyphra's Zamba2-1.2B-instruct on the [dolly-15k-dutch](https://huggingface.co/datasets/BramVanroy/dolly-15k-dutch) dataset, specifically using the training split

The model maintains the core hybrid architecture of Zamba2 while being optimized for Dutch language understanding and generation.

## Quick start

### Prerequisites

To download Zamba2-1.2B-instruct-Dutch, clone Zyphra's fork of transformers:
1. `git clone https://github.com/Zyphra/transformers_zamba2.git`
2. `cd transformers_zamba2`
3. Install the repository: `pip install -e .`
4. `pip install accelerate`
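
As an optional sanity check (not part of the official instructions), you can verify the fork is installed correctly by loading just the model configuration, which is a small download:

```python
from transformers import AutoConfig

# Succeeds only if the installed transformers fork knows the Zamba2 architecture
config = AutoConfig.from_pretrained("Zyphra/Zamba2-1.2B-instruct-Dutch")
print(config.model_type)
```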

### Inference

```python
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

# Instantiate model and tokenizer
tokenizer = AutoTokenizer.from_pretrained("Zyphra/Zamba2-1.2B-instruct-Dutch")
model = AutoModelForCausalLM.from_pretrained("Zyphra/Zamba2-1.2B-instruct-Dutch", device_map="cuda", torch_dtype=torch.bfloat16)

# Format the input as a chat template
# ("What are the most important causes of the fall of the Roman Empire?")
prompt = "Wat zijn de belangrijkste oorzaken van de val van het Romeinse Rijk?"
sample = [{'role': 'user', 'content': prompt}]
chat_sample = tokenizer.apply_chat_template(sample, tokenize=False)

# Tokenize input and generate output
input_ids = tokenizer(chat_sample, return_tensors='pt', add_special_tokens=False).to("cuda")
outputs = model.generate(**input_ids, max_new_tokens=150, return_dict_in_generate=False, output_scores=False, use_cache=True, num_beams=1, do_sample=False)
print(tokenizer.decode(outputs[0]))
```
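
If you only want the model's reply without the echoed prompt, you can slice off the prompt tokens before decoding (a small addition to the example above):

```python
# Decode only the newly generated tokens, skipping the prompt
prompt_length = input_ids["input_ids"].shape[1]
reply = tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True)
print(reply)
```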

## Training Details

The model was fine-tuned using the following approach (a configuration sketch follows the list):

1. Started with the base Zamba2-1.2B-instruct model
2. Fine-tuned on the dolly-15k-dutch dataset using optimized learning rates
3. Implemented memory optimization through gradient checkpointing
4. Utilized mixed precision training (bf16)
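
The exact hyperparameters are not published with this card, but steps 2-4 map onto standard Hugging Face training options; a minimal sketch with placeholder values:

```python
from datasets import load_dataset
from transformers import TrainingArguments

# Training split of the Dutch Dolly dataset used in the second stage
train_dataset = load_dataset("BramVanroy/dolly-15k-dutch", split="train")

training_args = TrainingArguments(
    output_dir="zamba2-1.2b-instruct-dutch",
    per_device_train_batch_size=4,   # placeholder value
    num_train_epochs=3,              # placeholder value
    learning_rate=5e-5,              # starting point; refined by the LR optimizer below
    gradient_checkpointing=True,     # memory optimization (step 3)
    bf16=True,                       # mixed precision training (step 4)
)
```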

### Fine-tuning Configuration

The fine-tuning setup includes an advanced learning rate optimization system, implemented through the `LROptimizerCallback` class from the custom `lr_optimizer` module:

```python
from transformers import AutoTokenizer, Trainer
from lr_optimizer import setup_training, LROptimizerCallback

callback = LROptimizerCallback(
    num_trials=10,
    lr_range=(1e-6, 1e-4)
)
trainer = Trainer(
    # ... model, training arguments, and train_dataset go here ...
    callbacks=[callback]
)

trainer.train()
```

## Model Architecture

Zamba2-1.2B-instruct-Dutch maintains the hybrid SSM-attention architecture of the base model:

- Backbone of Mamba2 layers interleaved with shared attention layers
- LoRA projection matrices for the shared transformer blocks
- Rotary position embeddings in the shared attention layer
- Concatenated original embeddings for improved information retention

<center>
<img src="https://cdn-uploads.huggingface.co/production/uploads/65c05e75c084467acab2f84a/Vay6htbnBcySR3Z6NEgwj.png" width="300" alt="Zamba architecture">
</center>
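
To see this hybrid layout for yourself, you can count the module classes of the loaded model (reusing `model` from the inference example; the exact class names depend on the Zamba2 implementation in Zyphra's fork):

```python
from collections import Counter

# Tally the module classes that make up the network, which shows the mix of
# Mamba2 (SSM) blocks and shared attention/LoRA components.
layer_types = Counter(type(module).__name__ for module in model.modules())
for name, count in layer_types.most_common(15):
    print(f"{count:5d}  {name}")
```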
## Performance

The model maintains the efficient inference characteristics of the base Zamba2 architecture:

- Low-latency inference
- Rapid generation
- Small memory footprint

Time to First Token (TTFT) | Output Generation
:-------------------------:|:-------------------------:
![image/png](https://cdn-uploads.huggingface.co/production/uploads/65c05e75c084467acab2f84a/5lpWDLdtPPVAk8COJq7gZ.png) | ![image/png](https://cdn-uploads.huggingface.co/production/uploads/65c05e75c084467acab2f84a/V2tS6eCOGbpKybEoZmOB7.png)

Memory overhead:

<center>
<img src="https://cdn-uploads.huggingface.co/production/uploads/65c05e75c084467acab2f84a/m0YUmAmiVnRg6l9m10CEt.png" width="400" alt="Zamba inference and memory cost">
</center>
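
If you want to take comparable measurements on your own hardware, a rough sketch using only standard PyTorch utilities (reusing `model`, `tokenizer`, and `input_ids` from the inference example above; numbers will vary by GPU and generation settings):

```python
import time
import torch

# Approximate time-to-first-token and peak GPU memory for a single prompt
torch.cuda.reset_peak_memory_stats()
torch.cuda.synchronize()
start = time.perf_counter()
_ = model.generate(**input_ids, max_new_tokens=1, do_sample=False)
torch.cuda.synchronize()
ttft = time.perf_counter() - start

peak_mem_gb = torch.cuda.max_memory_allocated() / 1e9
print(f"TTFT: {ttft:.3f} s, peak GPU memory: {peak_mem_gb:.2f} GB")
```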

## Limitations

- The model is primarily focused on Dutch language understanding and generation
- Performance on other languages may be limited
- The fine-tuning dataset is relatively small compared to those used for larger multilingual models
- No explicit content moderation mechanisms are included

## License

This model is released under the Apache 2.0 license.

Note: This is a temporary HuggingFace implementation. A standalone PyTorch implementation may be found in the [Zamba2 GitHub repository](https://github.com/Zyphra/Zamba2).