ssmits committed
Commit 0d4225a
1 Parent(s): 68bc2e4

Update README.md

Files changed (1):
  1. README.md +61 -80
README.md CHANGED
@@ -1,86 +1,79 @@
  ---
  license: apache-2.0
- library_name: transformers_zamba2
  ---

- # Model Card for Zamba2-1.2B

- Zamba2-1.2B is a hybrid model composed of state-space ([Mamba](https://github.com/state-spaces/mamba)) and transformer blocks. It broadly follows the [Zamba architecture](https://arxiv.org/abs/2405.16712) which consists of a Mamba backbone alternating with shared transformer blocks (see diagram in [Model Details](#model-details)). Zamba2-1.2B possesses three major improvements over Zamba1:

- 1.) Mamba1 blocks have been replaced with Mamba2 blocks.

- 2.) We apply a LoRA projector to each shared MLP and attention block, which allows the network to specialize at each invocation of the shared transformer layer across depth. LoRA enables us to add depth-specialization for only a minimal increase in total parameter count.

- 3.) We utilize rotary position embeddings in the shared attention layer.
-
- Zamba2-1.2B differs from our [2.7B model](https://huggingface.co/Zyphra/Zamba2-2.7B) in three ways:
-
- 1.) We have added rotary position embeddings
-
- 2.) A single shared transformer block (instead of two that we alternate between)
-
- 3.) Added LoRA projectors to attention blocks (instead of just a LoRA on the MLP block)
-
- We found that while hybrid SSM-transformer models are perfectly capable of performing well without position embeddings, adding rotary embeddings to the shared attention block slightly improved performance. Secondly, we utilize a single attention block (instead of alternating between two independent transformer blocks) because this enables a higher flop count for the model at a given parameter budget and at smaller scales this becomes more important than the slightly faster latency.
-
- Zamba2-1.2B uses the Mistral v0.1 tokenizer and was pre-trained on 3T tokens of text and code data sourced from open web-datasets, including [Zyda](https://arxiv.org/abs/2406.01981). Subsequently, in a second phase, Zamba2-1.2B was annealed on a mixture of 100B high-quality tokens.
-
- Note: this is a temporary HuggingFace implementation of Zamba2-1.2B. It may not yet be fully compatible with all frameworks and tools intended to interface with HuggingFace models.
-
- A standalone Pytorch implementation of Zamba2-1.2B may be found [here](https://github.com/Zyphra/Zamba2).

  ## Quick start

  ### Prerequisites

- To download Zamba2-1.2B, clone Zyphra's fork of transformers:
  1. `git clone https://github.com/Zyphra/transformers_zamba2.git`
  2. `cd transformers_zamba2`
  3. Install the repository: `pip install -e .`
  4. `pip install accelerate`

-
- You can run the model without using the optimized Mamba kernels, but it is **not** recommended as it will result in significantly higher latency and memory usage.
-
- To run on CPU, please specify `use_mamba_kernels=False` when loading the model using ``AutoModelForCausalLM.from_pretrained``.
-
-
  ### Inference
  ```python
  from transformers import AutoTokenizer, AutoModelForCausalLM
  import torch

- tokenizer = AutoTokenizer.from_pretrained("Zyphra/Zamba2-1.2B")
- model = AutoModelForCausalLM.from_pretrained("Zyphra/Zamba2-1.2B", device_map="cuda", torch_dtype=torch.bfloat16)

- input_text = "What factors contributed to the fall of the Roman Empire?"
- input_ids = tokenizer(input_text, return_tensors="pt").to("cuda")

- outputs = model.generate(**input_ids, max_new_tokens=100)
- print(tokenizer.decode(outputs[0]))
  ```

- ## Training Data

- The model is fine-tuned on the **BramVanroy/dolly-15k-dutch** dataset, specifically using the training split (`train_sft`). This dataset is not SoTA, however the goal is to demonstrate the capabilities and it fits <1024 tokens.

- ### Fine-tuning with Learning Rate Optimization

- The model includes an advanced learning rate optimization system for fine-tuning, implemented through the `LROptimizerCallback` class. This callback automatically handles learning rate optimization during training. Here's how to use it:

  ```python
  from transformers import AutoTokenizer, Trainer
  from lr_optimizer import setup_training, LROptimizerCallback

- # Method 1: Using the complete setup function
- training_setup = setup_training(
-     model_name="Zyphra/Zamba2-1.2B",
-     dataset_name="your/dataset",
-     num_trials=10
- )
- trainer = training_setup['trainer']
-
- # Method 2: Using the callback directly
  callback = LROptimizerCallback(
      num_trials=10,
      lr_range=(1e-6, 1e-4)
@@ -91,21 +84,17 @@ trainer = Trainer(
      callbacks=[callback]
  )

- # Start training with optimized configuration
  trainer.train()
  ```

- The optimization process automatically:
- - Explores learning rates between 1e-6 and 1e-4 using Bayesian optimization
- - Applies Gaussian Process Regression for precise LR selection
- - Implements memory optimization through gradient checkpointing
- - Supports both fp16 and bf16 training
-
- For detailed configuration options, see the [fine-tuning documentation](link-to-docs).

- ## Model Details

- Zamba2-1.2B utilizes and extends our original Zamba hybrid SSM-attention architecture. The core Zamba architecture consists of a backbone of Mamba layers interleaved with one or more shared attention layers. This attention has shared weights to minimize the parameter cost of the model. We find that concatenating the original model embeddings to the input to this attention block improves performance, likely due to better maintenance of information across depth. The Zamba2 architecture also applies LoRA projection matrices to the shared transformer blocks to gain some additional expressivity in each block and allow each shared block to specialize slightly to its own unique position while keeping the additional parameter overhead small.

  <center>
  <img src="https://cdn-uploads.huggingface.co/production/uploads/65c05e75c084467acab2f84a/Vay6htbnBcySR3Z6NEgwj.png" width="300" alt="Zamba architecture">
@@ -113,37 +102,29 @@ Zamba2-1.2B utilizes and extends our original Zamba hybrid SSM-attention archite

  ## Performance

- Zamba2-1.2B achieves leading and state-of-the-art performance among models of <2B parameters and is competitive with some models of significantly greater size. Moreover, due to its unique hybrid SSM architecture, Zamba2-1.2B achieves extremely low inference latency and rapid generation with a significantly smaller memory footprint than comparable transformer based models.
-
- Zamba2-1.2B's high performance and small inference compute and memory footprint renders it an ideal generalist model for on-device applications.
-
- <center>
- <img src="https://cdn-uploads.huggingface.co/production/uploads/65bc13717c6ad1994b6619e9/7Japy8VaJzKaFEjJgtWBp.png" width="700" alt="Zamba performance">
- </center>
-
- <center>
- <img src="https://cdn-uploads.huggingface.co/production/uploads/65c05e75c084467acab2f84a/Viwo3-bpYLFUu7cLIUFVv.png" width="800" alt="Zamba performance">
- </center>
-
- <!--
- <center>
- <img src="https://cdn-uploads.huggingface.co/production/uploads/65c05e75c084467acab2f84a/JVZUvVMPIpIJy9RDyohMJ.png" width="800" alt="Zamba performance">
- </center>
- -->

  Time to First Token (TTFT) | Output Generation
  :-------------------------:|:-------------------------:
  ![image/png](https://cdn-uploads.huggingface.co/production/uploads/65c05e75c084467acab2f84a/5lpWDLdtPPVAk8COJq7gZ.png) | ![image/png](https://cdn-uploads.huggingface.co/production/uploads/65c05e75c084467acab2f84a/V2tS6eCOGbpKybEoZmOB7.png)

-
- And memory overhead
-
  <center>
  <img src="https://cdn-uploads.huggingface.co/production/uploads/65c05e75c084467acab2f84a/m0YUmAmiVnRg6l9m10CEt.png" width="400" alt="Zamba inference and memory cost">
  </center>

- ## Notice

- Zamba2-1.2B is a pretrained base model and therefore does not have any moderation mechanism and may output toxic or otherwise harmful language. In addition, one should not expect good instruct or chat performance, as this model was not fine-tuned for instruction following or chat.

  ---
  license: apache-2.0
+ datasets:
+ - HuggingFaceH4/ultrachat_200k
+ - BAAI/Infinity-Instruct
+ - HuggingFaceH4/ultrafeedback_binarized
+ - Intel/orca_dpo_pairs
+ - argilla/OpenHermesPreferences
+ - BramVanroy/dolly-15k-dutch
+ base_model:
+ - Zyphra/Zamba2-1.2B-instruct
+ library_name: transformers
  ---

+ # Model Card for Zamba2-1.2B-instruct-Dutch

+ Zamba2-1.2B-instruct-Dutch is a Dutch-language instruction-following model obtained through a two-stage fine-tuning process:

+ 1. First stage (base instruction model by Zyphra):
+    - Zyphra fine-tuned Zamba2-1.2B to create Zamba2-1.2B-instruct through:
+      - SFT training on [ultrachat_200k](https://huggingface.co/datasets/HuggingFaceH4/ultrachat_200k) and [Infinity-Instruct](https://huggingface.co/datasets/BAAI/Infinity-Instruct)
+      - DPO training on [ultrafeedback_binarized](https://huggingface.co/datasets/HuggingFaceH4/ultrafeedback_binarized), [orca_dpo_pairs](https://huggingface.co/datasets/Intel/orca_dpo_pairs), and [OpenHermesPreferences](https://huggingface.co/datasets/argilla/OpenHermesPreferences)

+ 2. Second stage (Dutch language adaptation):
+    - Further fine-tuning of Zyphra's Zamba2-1.2B-instruct on the [dolly-15k-dutch](https://huggingface.co/datasets/BramVanroy/dolly-15k-dutch) dataset, specifically using the training split (`train_sft`)

+ The model maintains the core hybrid architecture of Zamba2 while being optimized for Dutch language understanding and generation.

  ## Quick start

  ### Prerequisites

+ To download Zamba2-1.2B-instruct-Dutch, clone Zyphra's fork of transformers:
  1. `git clone https://github.com/Zyphra/transformers_zamba2.git`
  2. `cd transformers_zamba2`
  3. Install the repository: `pip install -e .`
  4. `pip install accelerate`

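As with the base Zamba2 models, the optimized Mamba kernels are strongly recommended; running without them works but incurs significantly higher latency and memory usage. To run on CPU, the base model card suggests passing `use_mamba_kernels=False` when loading; a minimal sketch (carried over from the base Zamba2-1.2B card as an assumption, not verified for this fine-tune):

```python
from transformers import AutoTokenizer, AutoModelForCausalLM

tokenizer = AutoTokenizer.from_pretrained("Zyphra/Zamba2-1.2B-instruct-Dutch")
# Without a device_map the model loads on CPU; use_mamba_kernels=False disables
# the optimized Mamba kernels (per the base model card), trading speed for portability.
model = AutoModelForCausalLM.from_pretrained(
    "Zyphra/Zamba2-1.2B-instruct-Dutch",
    use_mamba_kernels=False,
)
```
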
  ### Inference
+
  ```python
  from transformers import AutoTokenizer, AutoModelForCausalLM
  import torch

+ # Instantiate model and tokenizer
+ tokenizer = AutoTokenizer.from_pretrained("Zyphra/Zamba2-1.2B-instruct-Dutch")
+ model = AutoModelForCausalLM.from_pretrained("Zyphra/Zamba2-1.2B-instruct-Dutch", device_map="cuda", torch_dtype=torch.bfloat16)

+ # Format the input as a chat template
+ prompt = "Wat zijn de belangrijkste oorzaken van de val van het Romeinse Rijk?"  # "What are the main causes of the fall of the Roman Empire?"
+ sample = [{'role': 'user', 'content': prompt}]
+ chat_sample = tokenizer.apply_chat_template(sample, tokenize=False)

+ # Tokenize input and generate output (greedy decoding: do_sample=False, num_beams=1)
+ input_ids = tokenizer(chat_sample, return_tensors='pt', add_special_tokens=False).to("cuda")
+ outputs = model.generate(**input_ids, max_new_tokens=150, return_dict_in_generate=False, output_scores=False, use_cache=True, num_beams=1, do_sample=False)
+ print(tokenizer.decode(outputs[0]))
  ```

+ ## Training Details
+
+ The model was fine-tuned using the following approach (a minimal configuration sketch follows the list):

+ 1. Started with the base Zamba2-1.2B-instruct model
+ 2. Fine-tuned on the dolly-15k-dutch dataset using optimized learning rates
+ 3. Implemented memory optimization through gradient checkpointing
+ 4. Utilized mixed precision training (bf16)

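Items 2-4 map onto standard `TrainingArguments`; a minimal sketch (the output path, batch size, epoch count, and starting learning rate are illustrative placeholders, not values from the actual run):

```python
from transformers import TrainingArguments

# Only gradient_checkpointing and bf16 reflect the training notes above;
# every other value here is an illustrative placeholder.
training_args = TrainingArguments(
    output_dir="zamba2-1.2b-instruct-dutch",
    per_device_train_batch_size=8,
    num_train_epochs=1,
    learning_rate=2e-5,           # starting point; the LR search below refines this
    gradient_checkpointing=True,  # memory optimization (item 3)
    bf16=True,                    # mixed-precision training (item 4)
)
```
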
+ ### Fine-tuning Configuration

+ The model includes an advanced learning rate optimization system for fine-tuning, implemented through the `LROptimizerCallback` class:

  ```python
  from transformers import AutoTokenizer, Trainer
  from lr_optimizer import setup_training, LROptimizerCallback

  callback = LROptimizerCallback(
      num_trials=10,
      lr_range=(1e-6, 1e-4)

      callbacks=[callback]
  )

  trainer.train()
  ```

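The snippet above also imports `setup_training` without using it; based on the usage shown earlier in this document, it provides a one-call alternative that builds the trainer itself. A sketch, with the model and dataset names filled in to match this card's training description (the helper's exact signature is assumed from that earlier example):

```python
from lr_optimizer import setup_training

# One-call setup: builds a trainer with the learning-rate search pre-configured.
# Model and dataset names follow this card's training description; the helper's
# signature is assumed from the example shown earlier in this document.
training_setup = setup_training(
    model_name="Zyphra/Zamba2-1.2B-instruct",
    dataset_name="BramVanroy/dolly-15k-dutch",
    num_trials=10,
)
trainer = training_setup['trainer']
trainer.train()
```
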
+ ## Model Architecture

+ Zamba2-1.2B-instruct-Dutch maintains the hybrid SSM-attention architecture of the base model:

+ - Backbone of Mamba2 layers interleaved with shared attention layers
+ - LoRA projection matrices for shared transformer blocks
+ - Rotary position embeddings in the shared attention layer
+ - Concatenated original embeddings for improved information maintenance

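As a rough illustration of how these pieces fit together, the sketch below shows a shared block with per-depth LoRA adapters and concatenated embeddings. It is schematic pseudocode only, not Zyphra's implementation: the plain linear layer standing in for the shared attention/MLP block, the LoRA rank, and all sizes are invented for the example.

```python
import torch
import torch.nn as nn

class SharedBlockWithLoRA(nn.Module):
    """One block whose main weights are shared across several depths; each depth
    owns a small LoRA adapter so repeated invocations can specialize."""
    def __init__(self, d_model: int, n_depths: int, rank: int = 8):
        super().__init__()
        # Shared weights: the input is the hidden state concatenated with the
        # original token embeddings (the "concatenated original embeddings" bullet).
        self.shared = nn.Linear(2 * d_model, d_model)
        # Per-depth low-rank adapters (the "LoRA projection matrices" bullet).
        self.lora_down = nn.ModuleList(nn.Linear(2 * d_model, rank, bias=False) for _ in range(n_depths))
        self.lora_up = nn.ModuleList(nn.Linear(rank, d_model, bias=False) for _ in range(n_depths))

    def forward(self, hidden: torch.Tensor, embeddings: torch.Tensor, depth: int) -> torch.Tensor:
        x = torch.cat([hidden, embeddings], dim=-1)
        update = self.shared(x) + self.lora_up[depth](self.lora_down[depth](x))
        return hidden + update  # residual update back into the Mamba2 backbone

# Example: the same shared block invoked at two different depths of the backbone.
block = SharedBlockWithLoRA(d_model=64, n_depths=6)
h = torch.randn(1, 10, 64)   # hidden states from the Mamba2 backbone
e = torch.randn(1, 10, 64)   # original embeddings, carried alongside
out = block(h, e, depth=2)
```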
 
  <center>
  <img src="https://cdn-uploads.huggingface.co/production/uploads/65c05e75c084467acab2f84a/Vay6htbnBcySR3Z6NEgwj.png" width="300" alt="Zamba architecture">

  ## Performance

+ The model maintains the efficient inference characteristics of the base Zamba2 architecture:
+ - Low latency inference
+ - Rapid generation
+ - Small memory footprint

  Time to First Token (TTFT) | Output Generation
  :-------------------------:|:-------------------------:
  ![image/png](https://cdn-uploads.huggingface.co/production/uploads/65c05e75c084467acab2f84a/5lpWDLdtPPVAk8COJq7gZ.png) | ![image/png](https://cdn-uploads.huggingface.co/production/uploads/65c05e75c084467acab2f84a/V2tS6eCOGbpKybEoZmOB7.png)

+ Memory overhead:
  <center>
  <img src="https://cdn-uploads.huggingface.co/production/uploads/65c05e75c084467acab2f84a/m0YUmAmiVnRg6l9m10CEt.png" width="400" alt="Zamba inference and memory cost">
  </center>

+ ## Limitations
+
+ - The model is primarily focused on Dutch language understanding and generation
+ - Performance on other languages may be limited
+ - The training dataset size is relatively small compared to larger multilingual models
+ - No explicit content moderation mechanisms are included

+ ## License

+ This model is released under the Apache 2.0 license.

+ Note: This is a temporary HuggingFace implementation. A standalone PyTorch implementation may be found in the [Zamba2 GitHub repository](https://github.com/Zyphra/Zamba2).