# Llama2-TwAddr-LoRA/scripts/step1_finetuning.py
import datasets
import torch
from peft import LoraConfig, TaskType, get_peft_model
from peft.peft_model import PeftModel
from transformers import LlamaForCausalLM as ModelCls
from transformers import Trainer, TrainingArguments
# Load the base model (bfloat16 weights, automatic device placement)
model_name = "TheBloke/Llama-2-7B-Chat-fp16"
model: ModelCls = ModelCls.from_pretrained(
    model_name,
    device_map="auto",
    torch_dtype=torch.bfloat16,
)
# Wrap the base model with a LoRA adapter (PEFT)
peft_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,
    r=8,  # rank of the low-rank update matrices
    lora_alpha=32,  # scaling factor applied to the LoRA updates
    lora_dropout=0.1,
)
model: PeftModel = get_peft_model(model, peft_config)
model.print_trainable_parameters()  # report how many parameters are actually trainable
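# With r=8 and PEFT's default LoRA target modules for Llama (typically the q_proj/v_proj
# projections), only a small fraction of the 7B parameters, on the order of 0.1%, is trainable.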
# Load the dataset
data_files = {
    "train": "data/train.tokens.json.gz",
    "dev": "data/dev.tokens.json.gz",
}
dataset = datasets.load_dataset(
    "json",
    data_files=data_files,
    cache_dir="cache",
)
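# Note: judging by the .tokens.json.gz file names, the records are presumably pre-tokenized
# (e.g. input_ids/labels of uniform length), which would explain why no tokenizer or custom
# data collator is passed to the Trainer below. This is an assumption, not confirmed here.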
# Configure the training arguments
output_dir = "models/Llama-7B-TwAddr-LoRA"
train_args = TrainingArguments(
    output_dir,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    eval_accumulation_steps=2,
    evaluation_strategy="epoch",  # evaluate at the end of every epoch
    save_strategy="epoch",  # must match evaluation_strategy for load_best_model_at_end
    learning_rate=5e-4,
    save_total_limit=3,  # keep at most three checkpoints on disk
    num_train_epochs=5,
    load_best_model_at_end=True,
    bf16=True,
)
# Start training
trainer = Trainer(
    model=model,
    args=train_args,
    train_dataset=dataset["train"],
    eval_dataset=dataset["dev"],
)
trainer.train()
# Save the fine-tuned model (for a PEFT-wrapped model this writes the LoRA adapter weights)
trainer.save_model()
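
# Usage sketch (illustrative, not executed as part of this script): the saved adapter can
# later be reloaded on top of the same base checkpoint, assuming `output_dir` holds the
# adapter written above.
# base = ModelCls.from_pretrained(model_name, device_map="auto", torch_dtype=torch.bfloat16)
# finetuned = PeftModel.from_pretrained(base, output_dir)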