pythia410m-sft-tldr / code /merge_peft_adapter.py
mnoukhov's picture
Training in progress, step 500
1904ee8 verified
from dataclasses import dataclass, field
from typing import Optional
import torch
from peft import PeftConfig, PeftModel
from transformers import AutoModelForCausalLM, AutoModelForSequenceClassification, AutoTokenizer, HfArgumentParser
@dataclass
class ScriptArguments:
"""
The name of the Casual LM model we wish to fine with PPO
"""
adapter_model_name: str = field(default=None, metadata={"help": "the model name"})
# base_model_name: Optional[str] = field(default=None, metadata={"help": "the model name"})
output_name: str = field(default=None, metadata={"help": "the model name"})
dtype: Optional[str] = field(default="bf16")
parser = HfArgumentParser(ScriptArguments)
script_args = parser.parse_args_into_dataclasses()[0]
# assert script_args.adapter_model_name is not None, "please provide the name of the Adapter you would like to merge"
# assert script_args.base_model_name is not None, "please provide the name of the Base model"
# assert script_args.base_model_name is not None, "please provide the output name of the merged model"
if script_args.dtype == "bf16":
torch_dtype = torch.bfloat16
elif script_args.dtype == "fp16":
torch_dtype = torch.float16
else:
torch_dtype = torch.float32
peft_config = PeftConfig.from_pretrained(script_args.adapter_model_name)
if peft_config.task_type == "SEQ_CLS":
# peft is for reward model so load sequence classification
model = AutoModelForSequenceClassification.from_pretrained(
peft_config.base_model_name_or_path,
num_labels=1,
torch_dtype=torch_dtype,
)
else:
model = AutoModelForCausalLM.from_pretrained(
peft_config.base_model_name_or_path,
return_dict=True,
torch_dtype=torch_dtype,
)
tokenizer = AutoTokenizer.from_pretrained(peft_config.base_model_name_or_path)
# Load the Lora model
model = PeftModel.from_pretrained(model, script_args.adapter_model_name)
model.eval()
model = model.merge_and_unload()
model.save_pretrained(f"{script_args.output_name}")
tokenizer.save_pretrained(f"{script_args.output_name}")
# model.push_to_hub(f"{script_args.output_name}", use_temp_dir=False)