Update LoRA fine-tune example - more target_modules, lower LR, bf16 (#49)
Browse files
- Update LoRA fine-tune example - more target_modules, lower LR, bf16 (a69ca0f303d6079e51f4d323a81e2ec76484fc92)
Co-authored-by: Michael Gokhman <michael-go@users.noreply.huggingface.co>
README.md
CHANGED
@@ -96,31 +96,40 @@ model = AutoModelForCausalLM.from_pretrained("ai21labs/Jamba-v0.1",
|
|
96 |
</details>
|
97 |
|
98 |
### Fine-tuning example
|
99 |
-
Jamba is a base model that can be fine-tuned for custom solutions (including for chat/instruct versions). You can fine-tune it using any technique of your choice. Here is an example of fine-tuning with the [PEFT](https://huggingface.co/docs/peft/index) library:
|
100 |
|
101 |
```python
|
|
|
102 |
from datasets import load_dataset
|
103 |
-
from trl import SFTTrainer
|
104 |
from peft import LoraConfig
|
105 |
from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments
|
106 |
|
107 |
tokenizer = AutoTokenizer.from_pretrained("ai21labs/Jamba-v0.1")
|
108 |
-
model = AutoModelForCausalLM.from_pretrained(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
109 |
|
110 |
dataset = load_dataset("Abirate/english_quotes", split="train")
|
111 |
-
training_args =
|
112 |
output_dir="./results",
|
113 |
-
num_train_epochs=
|
114 |
per_device_train_batch_size=4,
|
115 |
logging_dir='./logs',
|
116 |
logging_steps=10,
|
117 |
-
learning_rate=
|
118 |
-
|
119 |
-
lora_config = LoraConfig(
|
120 |
-
r=8,
|
121 |
-
target_modules=["embed_tokens", "x_proj", "in_proj", "out_proj"],
|
122 |
-
task_type="CAUSAL_LM",
|
123 |
-
bias="none"
|
124 |
)
|
125 |
trainer = SFTTrainer(
|
126 |
model=model,
|
@@ -128,9 +137,7 @@ trainer = SFTTrainer(
|
|
128 |
args=training_args,
|
129 |
peft_config=lora_config,
|
130 |
train_dataset=dataset,
|
131 |
-
dataset_text_field="quote",
|
132 |
)
|
133 |
-
|
134 |
trainer.train()
|
135 |
```
|
136 |
|
|
|
96 |
</details>
|
97 |
|
98 |
### Fine-tuning example
|
99 |
+
Jamba is a base model that can be fine-tuned for custom solutions (including for chat/instruct versions). You can fine-tune it using any technique of your choice. Here is an example of fine-tuning with the [PEFT](https://huggingface.co/docs/peft/index) library (requires ~120GB of GPU RAM, e.g. 2x A100 80GB):
|
100 |
|
101 |
```python
|
102 |
+
import torch
|
103 |
from datasets import load_dataset
|
104 |
+
from trl import SFTTrainer, SFTConfig
|
105 |
from peft import LoraConfig
|
106 |
from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments
|
107 |
|
108 |
tokenizer = AutoTokenizer.from_pretrained("ai21labs/Jamba-v0.1")
|
109 |
+
model = AutoModelForCausalLM.from_pretrained(
|
110 |
+
"ai21labs/Jamba-v0.1", device_map='auto', torch_dtype=torch.bfloat16)
|
111 |
+
|
112 |
+
lora_config = LoraConfig(
|
113 |
+
r=8,
|
114 |
+
target_modules=[
|
115 |
+
"embed_tokens",
|
116 |
+
"x_proj", "in_proj", "out_proj", # mamba
|
117 |
+
"gate_proj", "up_proj", "down_proj", # mlp
|
118 |
+
"q_proj", "k_proj", "v_proj" # attention
|
119 |
+
],
|
120 |
+
task_type="CAUSAL_LM",
|
121 |
+
bias="none"
|
122 |
+
)
|
123 |
|
124 |
dataset = load_dataset("Abirate/english_quotes", split="train")
|
125 |
+
training_args = SFTConfig(
|
126 |
output_dir="./results",
|
127 |
+
num_train_epochs=2,
|
128 |
per_device_train_batch_size=4,
|
129 |
logging_dir='./logs',
|
130 |
logging_steps=10,
|
131 |
+
learning_rate=1e-5,
|
132 |
+
dataset_text_field="quote",
|
|
|
|
|
|
|
|
|
|
|
133 |
)
|
134 |
trainer = SFTTrainer(
|
135 |
model=model,
|
|
|
137 |
args=training_args,
|
138 |
peft_config=lora_config,
|
139 |
train_dataset=dataset,
|
|
|
140 |
)
|
|
|
141 |
trainer.train()
|
142 |
```
|
143 |
|