Fine-tuning Zamba2-1.2B Model: Step-by-Step Guide
Discussion #3, opened by ssmits
This guide walks you through fine-tuning the Zamba2-1.2B model. Full fine-tuning of a 1.2B-parameter model needs a lot of GPU memory; the steps below were tested on an RTX 4090 (24 GB VRAM).
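As a rough check on that memory requirement, the sketch below estimates only the static memory for full fine-tuning: weights, gradients, and the two AdamW moment buffers. It assumes all of these are held in bfloat16 (2 bytes each). That is what you get when a model loaded with `torch_dtype=torch.bfloat16` is trained with PyTorch's AdamW, which allocates its buffers in the parameters' dtype. Activations, kernel workspaces and allocator overhead are not included, so real usage is higher. The function name is hypothetical.

```python
def estimate_static_training_memory_gib(
    n_params: float,
    weight_bytes: int = 2,     # bfloat16 weights
    grad_bytes: int = 2,       # gradients share the weights' dtype
    optimizer_bytes: int = 4,  # AdamW: two moment buffers, 2 bytes each in bf16
) -> float:
    """Rough lower bound on GPU memory for full fine-tuning, excluding activations."""
    total_bytes = n_params * (weight_bytes + grad_bytes + optimizer_bytes)
    return total_bytes / 2**30


if __name__ == "__main__":
    n = 1.2e9  # Zamba2-1.2B, approximate parameter count
    print(f"bf16 weights/grads/AdamW: {estimate_static_training_memory_gib(n):.1f} GiB")
    # For comparison: keeping weights, gradients and AdamW states all in fp32
    print(f"fp32 weights/grads/AdamW: {estimate_static_training_memory_gib(n, 4, 4, 8):.1f} GiB")
```

For 1.2e9 parameters this gives about 8.9 GiB in the bf16 case and about 17.9 GiB with everything in fp32. In the bf16 case, the rest of a 24 GB card is left for activations, which grow with the context window.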
Tested Environment:
- Vast.ai cloud instance (template link)
- CUDA 11.8
- Python 3.10
- PyTorch image: pytorch/pytorch:2.3.1-cuda11.8-cudnn8-devel (selected for extra stability)
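You can compare your own machine against the environment listed above with a small check like the following. This is a sketch: the function name is hypothetical, and it uses only standard PyTorch attributes (`torch.__version__`, `torch.version.cuda`, `torch.cuda.is_available`, `torch.cuda.get_device_name`, `torch.cuda.is_bf16_supported`). It reports `None` for anything it cannot determine, for example when PyTorch is not installed.

```python
import importlib.util
import platform


def describe_environment() -> dict:
    """Collect the versions listed under 'Tested Environment' for comparison."""
    info = {"python": platform.python_version(), "torch": None,
            "cuda": None, "gpu": None, "bf16": None}
    if importlib.util.find_spec("torch") is None:
        return info  # PyTorch is not installed in this interpreter
    import torch
    info["torch"] = torch.__version__
    info["cuda"] = torch.version.cuda  # CUDA version PyTorch was built with (None on CPU builds)
    if torch.cuda.is_available():
        info["gpu"] = torch.cuda.get_device_name(0)
        info["bf16"] = torch.cuda.is_bf16_supported()  # needed for bf16=True training
    return info


if __name__ == "__main__":
    for key, value in describe_environment().items():
        print(f"{key}: {value}")
```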
1. Setup Environment
First, clone the repository and set up the environment:

```bash
git clone https://github.com/Zyphra/transformers_zamba2.git
cd transformers_zamba2

# Create and activate virtual environment
python -m venv venv
source venv/bin/activate  # For Windows use: venv\Scripts\activate

# Install dependencies
pip install -e .
pip install accelerate datasets
```
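After installing, you can confirm that the activated virtual environment can see the packages that the later steps import. This sketch uses only the standard library (`importlib.util.find_spec`). The package list mirrors the imports in this guide, and the function names are hypothetical.

```python
import importlib.util

# Top-level packages imported later in this guide
REQUIRED_PACKAGES = ("torch", "transformers", "accelerate", "datasets")


def missing_packages(names=REQUIRED_PACKAGES):
    """Return the names that cannot be imported in the current environment."""
    return [name for name in names if importlib.util.find_spec(name) is None]


def package_origin(name):
    """Return the file a package would be imported from, or None if it is missing."""
    spec = importlib.util.find_spec(name)
    return spec.origin if spec is not None else None


if __name__ == "__main__":
    missing = missing_packages()
    print("Missing:", ", ".join(missing) if missing else "none")
    # After `pip install -e .`, this path should point inside the cloned transformers_zamba2 directory
    print("transformers imported from:", package_origin("transformers"))
```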
2. Basic Inference Test
Let's first test if the model loads correctly:
```python
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

tokenizer = AutoTokenizer.from_pretrained("Zyphra/Zamba2-1.2B")
model = AutoModelForCausalLM.from_pretrained(
    "Zyphra/Zamba2-1.2B",
    device_map="cuda",
    torch_dtype=torch.bfloat16,
)

# Test generation
input_text = "What factors contributed to the fall of the Roman Empire?"
# tokenizer(...) returns a dict-like BatchEncoding (input_ids and attention_mask)
inputs = tokenizer(input_text, return_tensors="pt").to("cuda")
outputs = model.generate(**inputs, max_new_tokens=100)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```
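The test above uses a single prompt. When several prompts are batched for generation, decoder-only models are usually given left-padded inputs. That way every row ends with a real token, and new tokens are appended directly after it. This is the reason for the `padding_side = "left"` setting in the fine-tuning script below. The pure-Python sketch below only illustrates the layout; the helper name and token IDs are made up. In practice, `tokenizer(..., padding=True)` builds these tensors for you according to `tokenizer.padding_side`.

```python
def pad_batch(sequences, pad_id, side="left"):
    """Pad token-ID lists to equal length and build the matching attention mask.

    Illustrative helper only; real code would call the tokenizer with padding=True.
    """
    width = max(len(seq) for seq in sequences)
    input_ids, attention_mask = [], []
    for seq in sequences:
        n_pad = width - len(seq)
        pads, ones = [pad_id] * n_pad, [1] * len(seq)
        if side == "left":
            input_ids.append(pads + list(seq))
            attention_mask.append([0] * n_pad + ones)
        else:
            input_ids.append(list(seq) + pads)
            attention_mask.append(ones + [0] * n_pad)
    return input_ids, attention_mask


# Hypothetical token IDs; 0 plays the role of the pad token
ids, mask = pad_batch([[11, 12, 13], [21]], pad_id=0, side="left")
print(ids)   # [[11, 12, 13], [0, 0, 21]]  (last column is a real token in every row)
print(mask)  # [[1, 1, 1], [0, 0, 1]]
```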
3. Fine-tuning Setup
Here's the complete fine-tuning script with detailed configurations:
```python
import torch
from datasets import Dataset
from transformers import (
    AutoTokenizer, AutoModelForCausalLM, TrainingArguments,
    Trainer, DataCollatorForLanguageModeling
)

CONTEXT_WINDOW = 1024  # max tokens per training example

# Initialize tokenizer
tokenizer = AutoTokenizer.from_pretrained("Zyphra/Zamba2-1.2B")
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "left"  # Better for batched inference

# Initialize model
model = AutoModelForCausalLM.from_pretrained(
    "Zyphra/Zamba2-1.2B",
    torch_dtype=torch.bfloat16,
    device_map="auto"  # Handles multi-GPU/CPU mapping
)
model.config.pad_token_id = tokenizer.pad_token_id

# Tokenization function: truncate to the context window.
# Padding is left to the data collator, which pads each batch to its longest example.
def tokenize_function(examples):
    return tokenizer(
        examples["text"],
        truncation=True,
        max_length=CONTEXT_WINDOW,
    )

# Prepare training data
train_texts = [
    "What factors contributed to the fall of the Roman Empire?",
    # Add your training examples here
]
dataset = Dataset.from_dict({"text": train_texts})
tokenized_dataset = dataset.map(
    tokenize_function,
    batched=True,
    remove_columns=dataset.column_names
)

# Training configuration
training_args = TrainingArguments(
    output_dir="./zamba2-finetuned",
    num_train_epochs=3,
    per_device_train_batch_size=1,
    gradient_accumulation_steps=16,  # effective batch size of 16 per device
    save_steps=500,
    save_total_limit=2,
    logging_steps=100,
    learning_rate=2e-5,
    weight_decay=0.01,
    fp16=False,
    bf16=True,
)

# Data collator: mlm=False means causal language modeling; labels are a copy
# of input_ids with pad positions set to -100 (the model shifts them internally)
data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer,
    mlm=False
)

# Custom trainer: skip Trainer's own device placement, because
# device_map="auto" has already placed the model's layers
class CustomTrainer(Trainer):
    def _move_model_to_device(self, model, device):
        pass  # Model already mapped to devices

# Initialize trainer
trainer = CustomTrainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_dataset,
    data_collator=data_collator
)

# Train and save
trainer.train()
model.save_pretrained("./zamba2-finetuned-final")
tokenizer.save_pretrained("./zamba2-finetuned-final")
```
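With `mlm=False`, `DataCollatorForLanguageModeling` pads each batch, copies `input_ids` into `labels`, and replaces every position whose token equals the pad token ID with -100 so the loss ignores it. The model then shifts the labels by one position internally. The plain-Python sketch below mimics only that label step; the helper name and token IDs are hypothetical. Because the script sets `pad_token = eos_token`, genuine EOS tokens in your texts are masked the same way. This matters if you want the model to learn when to stop generating.

```python
IGNORE_INDEX = -100  # label value that PyTorch's cross-entropy loss skips by default


def causal_lm_labels(batch_input_ids, pad_token_id):
    """Copy input IDs into labels, masking pad positions with IGNORE_INDEX.

    Mirrors the label step of DataCollatorForLanguageModeling(mlm=False);
    the one-token shift for next-token prediction happens inside the model.
    """
    return [
        [IGNORE_INDEX if tok == pad_token_id else tok for tok in row]
        for row in batch_input_ids
    ]


# Hypothetical IDs: 2 is both EOS and pad, as in the script (pad_token = eos_token)
batch = [[2, 2, 5, 6, 2],    # two left-pad tokens, text, then a real EOS
         [7, 8, 9, 10, 2]]   # no padding, real EOS at the end
print(causal_lm_labels(batch, pad_token_id=2))
# [[-100, -100, 5, 6, -100], [7, 8, 9, 10, -100]]  (the real EOS is masked too)
```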
Important Notes:
Hardware Requirements:
- Recommended: a GPU with at least 24 GB of VRAM for `CONTEXT_WINDOW = 1024`. A context window of 2048 needs more than 24 GB. Multi-GPU training has not been tested yet.
- The script uses bfloat16 precision to reduce memory usage.
Training Configuration:
- Context window: 1024 tokens
- Learning rate: 2e-5
- Weight decay: 0.01
- Gradient accumulation: 16 steps
- Training epochs: 3
Customization:
- Add your training examples to the `train_texts` list.
- Adjust the `training_args` parameters based on your needs.
- Modify `max_length` in tokenization if needed.
Output:
- The fine-tuned model will be saved in `./zamba2-finetuned-final`.
- Checkpoints during training are saved in `./zamba2-finetuned`.
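Gradient accumulation changes how often the optimizer steps. That, in turn, determines whether `logging_steps=100` and `save_steps=500` are ever reached. The arithmetic below is a sketch that assumes a single GPU and rounds a final partial accumulation window up. Trainer's exact handling of a leftover partial window may differ between transformers versions. The function name is hypothetical.

```python
import math


def optimizer_steps(n_examples, per_device_batch=1, grad_accum=16,
                    n_devices=1, epochs=3):
    """Approximate effective batch size and total optimizer steps."""
    effective_batch = per_device_batch * grad_accum * n_devices
    steps_per_epoch = math.ceil(n_examples / effective_batch)
    return effective_batch, steps_per_epoch * epochs


effective, total = optimizer_steps(n_examples=1000)
print(f"effective batch = {effective}, optimizer steps = {total}")
```

With the guide's settings and a hypothetical 1,000 training texts, this gives an effective batch of 16 and about 189 optimizer steps. The `logging_steps=100` interval fires roughly once, and step 500 is never reached. The final `save_pretrained` call in the script still writes the model.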
Feel free to ask if you have any questions about the process.