# lorahub / util.py
# SivilTaram
# update demo
# 470be5c
# raw
# history blame
# 6.06 kB
from transformers import AutoModelForSeq2SeqLM
import torch
from datasets import Dataset
from torch.utils.data import DataLoader
from transformers import default_data_collator
from transformers import AutoTokenizer
from tqdm import tqdm
import pandas as pd
import numpy
import random
import nevergrad as ng
from peft.utils.save_and_load import set_peft_model_state_dict, get_peft_model_state_dict
from peft import PeftModel, PeftConfig
from functools import partial
# Seed the stdlib and NumPy global random number generators with a fixed value
# at import time (presumably so the gradient-free search below is repeatable
# from run to run -- confirm that every consumer draws from these global RNGs).
random.seed(42)
numpy.random.seed(42)
def load_base_model_and_lora_modules(lora_module_list):
    """Load the shared base model, its tokenizer and every LoRA state dict.

    The base model name is read from the PEFT config of the first entry of
    ``lora_module_list``; that first module is also the adapter attached to
    the returned model.

    Args:
        lora_module_list: Non-empty sequence of PEFT model ids / paths.

    Returns:
        A ``(peft_model, tokenizer, cache)`` tuple, where ``cache`` maps each
        module id to a private copy of that module's LoRA state dict.

    Raises:
        ValueError: If a module is missing a key of the first module, or holds
            a tensor of a different shape under the same key, so that the
            modules cannot be combined element-wise.
    """
    # use gpu if available
    device = "cuda" if torch.cuda.is_available() else "cpu"
    # the first module decides which base model is loaded
    default_peft_model_id = lora_module_list[0]
    model_name_or_path = PeftConfig.from_pretrained(
        default_peft_model_id
    ).base_model_name_or_path
    base_model = AutoModelForSeq2SeqLM.from_pretrained(model_name_or_path)
    tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
    # 0 is the default model
    peft_model = PeftModel.from_pretrained(base_model, default_peft_model_id)
    peft_model = peft_model.to(device)
    peft_model.eval()

    print("> Begin to load lora modules")
    cache = {}
    reference_state = None
    for peft_model_id in tqdm(lora_module_list):
        print("> Loading {} ...".format(peft_model_id))
        cur_peft_model = PeftModel.from_pretrained(base_model, peft_model_id)
        # Snapshot the tensors instead of caching them directly: state-dict
        # tensors share storage with the live parameters, and every PeftModel
        # here wraps the same ``base_model``, so a later in-place
        # ``set_peft_model_state_dict`` could otherwise rewrite cached weights.
        module_state = {
            key: tensor.detach().clone()
            for key, tensor in get_peft_model_state_dict(cur_peft_model).items()
        }
        if reference_state is None:
            reference_state = module_state
        else:
            # Fail fast when a module cannot be summed with the first one.
            for key, reference_tensor in reference_state.items():
                if (
                    key not in module_state
                    or module_state[key].shape != reference_tensor.shape
                ):
                    raise ValueError(
                        "LoRA module {} cannot be merged with {}: "
                        "mismatch at parameter {!r}".format(
                            peft_model_id, default_peft_model_id, key
                        )
                    )
        cache[peft_model_id] = module_state

    return peft_model, tokenizer, cache
def preprocess_function(examples, tokenizer):
    """Tokenize a batch of input/output text pairs for seq2seq scoring.

    Args:
        examples: Mapping with an ``"input"`` and an ``"output"`` entry, each
            holding the texts of one batch.
        tokenizer: Callable tokenizer exposing ``pad_token_id``.

    Returns:
        The tokenizer's encoding of the inputs (truncated to 2048 tokens) with
        an added ``"labels"`` entry: the token ids of the outputs (truncated
        to 256 tokens) in which every padding id is replaced by -100.
    """
    source_texts = examples["input"]
    target_texts = examples["output"]

    # Options shared by both tokenizer calls; only the length cap differs.
    shared_options = {"padding": True, "truncation": True, "return_tensors": "pt"}

    encoded = tokenizer(source_texts, max_length=2048, **shared_options)
    target_ids = tokenizer(target_texts, max_length=256, **shared_options)["input_ids"]

    # -100 is the default ``ignore_index`` of torch's cross-entropy loss, so
    # padded label positions are presumably excluded from the model's loss.
    padding_mask = target_ids == tokenizer.pad_token_id
    target_ids[padding_mask] = -100

    encoded["labels"] = target_ids
    return encoded
def load_dataset_and_run(example_inputs, example_outputs, tokenizer):
    """Build a tokenized ``Dataset`` from parallel input and output texts.

    Args:
        example_inputs: Input texts; one record is created per entry.
        example_outputs: Output texts, indexed in step with ``example_inputs``.
        tokenizer: Tokenizer forwarded to ``preprocess_function``.

    Returns:
        The dataset produced by mapping ``preprocess_function`` over the
        records in batched mode.
    """
    records = []
    for index, source_text in enumerate(example_inputs):
        records.append({"input": source_text, "output": example_outputs[index]})

    frame = pd.DataFrame(records)
    tokenize_batch = partial(preprocess_function, tokenizer=tokenizer)
    return Dataset.from_pandas(frame).map(
        tokenize_batch,
        batched=True,
        num_proc=1,
        desc="Running tokenizer on dataset",
    )
def get_score(weights, model, cache, example_dataset):
    """Score one weight vector: example loss plus an L1 penalty on the weights.

    The cached LoRA state dicts are combined as a weighted sum, loaded into
    ``model``, and the model's loss on ``example_dataset`` is measured.

    Args:
        weights: One coefficient per cached module, in ``cache`` insertion order.
        model: PEFT model whose adapter weights are overwritten in place.
        cache: Mapping of module id to LoRA state dict.
        example_dataset: Tokenized dataset; must expose an ``"input"`` column.

    Returns:
        ``loss + 0.05 * mean(|w|)``; lower is better.
    """
    # Compose through the shared helper so that scoring and the final export
    # in ``get_final_weights`` can never drift apart.
    final_state_dict = get_final_weights(weights, list(cache), cache)
    # reload the model with the new adapter weights
    set_peft_model_state_dict(model, final_state_dict)

    def get_loss():
        """Return the model loss over the whole example set."""
        use_cuda = torch.cuda.is_available()
        device = "cuda" if use_cuda else "cpu"
        train_dataloader = DataLoader(
            example_dataset,
            collate_fn=default_data_collator,
            # all examples are evaluated as a single batch
            batch_size=len(example_dataset),
            # pinned memory only helps host-to-GPU copies; requesting it on a
            # CPU-only host just triggers a warning on every evaluation
            pin_memory=use_cuda,
        )
        train_loss = 0
        with torch.no_grad():
            for batch in train_dataloader:
                batch = {k: v.to(device) for k, v in batch.items()}
                outputs = model(**batch)
                train_loss += outputs.loss.detach().float()
        # NOTE(review): ``outputs.loss`` is presumably already averaged over
        # the batch, so this division scales it down a second time. Kept as-is
        # to preserve existing metric values -- confirm it is intended.
        return float(train_loss) / len(example_dataset["input"])

    # minimize the metric
    loss = get_loss()
    # L1 regularization term: mean absolute weight
    l1_penalty = sum(abs(x) for x in weights) / len(weights)
    return loss + 0.05 * l1_penalty
def get_final_weights(weights, lora_module_list, cache):
    """Return the weighted sum of the cached LoRA state dicts.

    Args:
        weights: Coefficients, indexed in step with ``lora_module_list``.
        lora_module_list: Module ids to combine; the first one supplies the
            set of parameter keys.
        cache: Mapping of module id to state dict.

    Returns:
        A new dict mapping each key of the first module's state dict to
        ``sum(weights[i] * cache[lora_module_list[i]][key])``.
    """
    # all modules are expected to share the first module's keys
    reference_keys = cache[lora_module_list[0]].keys()
    merged = {}
    for position, module_id in enumerate(lora_module_list):
        scale = weights[position]
        module_state = cache[module_id]
        for key in reference_keys:
            contribution = scale * module_state[key]
            # the first module initialises each entry, later ones accumulate
            merged[key] = contribution if position == 0 else merged[key] + contribution
    return merged
def lorahub_learning(lora_module_list, text_input, text_output, max_inference_step):
    """Search for LoRA mixing weights that minimise the loss on a few examples.

    Args:
        lora_module_list: PEFT model ids / paths of the candidate modules.
            Repeated ids are collapsed (first occurrence wins).
        text_input: Example inputs, one per line.
        text_output: Example outputs, one per line, in step with ``text_input``.
        max_inference_step: Evaluation budget handed to the optimizer.

    Returns:
        ``None`` when ``lora_module_list`` is empty; otherwise a
        ``(recommendation, final_lora)`` tuple with the optimizer's
        recommendation and the composed LoRA state dict built from its value.
    """
    if len(lora_module_list) == 0:
        return None

    # The cache is keyed by module id, so repeated ids collapse there. Without
    # de-duplicating here, scoring (cache order) and the final composition
    # (list order) would pair the same weight with different modules.
    lora_module_list = list(dict.fromkeys(lora_module_list))
    number_of_loras = len(lora_module_list)

    # load model
    model, tokenizer, cache = load_base_model_and_lora_modules(lora_module_list)
    # process dataset
    dataset = load_dataset_and_run(
        text_input.split("\n"), text_output.split("\n"), tokenizer
    )
    get_score_partial = partial(
        get_score, model=model, cache=cache, example_dataset=dataset
    )
    # set up the limit of the weights: start at 0, search within [-1.5, 1.5]
    instrum = ng.p.Array(
        init=[0] * number_of_loras,
        upper=[1.5] * number_of_loras,
        lower=[-1.5] * number_of_loras,
    )
    optimizer = ng.optimizers.NGOpt(parametrization=instrum, budget=max_inference_step)
    print("> Begin to perform gradient-free optimization ...")
    recommendation = optimizer.minimize(get_score_partial, verbosity=1)
    final_lora = get_final_weights(recommendation.value, lora_module_list, cache)
    return recommendation, final_lora