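# Split a DyLoRA model into multiple LoRA models of fixed rank.
# For every multiple of --unit below the maximum rank found in the model, the first
# `rank` rows of each lora_down weight and the first `rank` columns of each lora_up
# weight are extracted and saved as a standalone LoRA state dict.
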
import argparse
import math
import os

import numpy as np
import torch
from safetensors.torch import load_file, save_file, safe_open
from tqdm import tqdm

from library import train_util, model_util


def load_state_dict(file_name):
    """Load a LoRA state dict and its metadata from a safetensors or ckpt file."""
    if model_util.is_safetensors(file_name):
        sd = load_file(file_name)
        # load_file does not return metadata, so read it separately
        with safe_open(file_name, framework="pt") as f:
            metadata = f.metadata()
    else:
        sd = torch.load(file_name, map_location="cpu")
        metadata = None

    return sd, metadata


def save_to_file(file_name, model, metadata):
    """Save a state dict to a safetensors or ckpt file, with metadata when supported."""
    if model_util.is_safetensors(file_name):
        save_file(model, file_name, metadata)
    else:
        torch.save(model, file_name)


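# DyLoRA is trained so that the leading rows/columns of the low-rank factors already
# form a usable adapter at any smaller rank, so extracting a fixed-rank LoRA only
# requires slicing the weights; no retraining or alpha rescaling is performed here.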
def split_lora_model(lora_sd, unit):
    # find the maximum rank trained into the DyLoRA model
    max_rank = 0
    for key, value in lora_sd.items():
        if "lora_down" in key:
            rank = value.size()[0]
            if rank > max_rank:
                max_rank = rank
    print(f"Max rank: {max_rank}")

    rank = unit
    split_models = []
    new_alpha = None
    while rank < max_rank:
        print(f"Splitting rank {rank}")
        new_sd = {}
        for key, value in lora_sd.items():
            if "lora_down" in key:
                # lora_down is (rank, in_dim): keep the first `rank` rows
                new_sd[key] = value[:rank].contiguous()
            elif "lora_up" in key:
                # lora_up is (out_dim, rank): keep the first `rank` columns
                new_sd[key] = value[:, :rank].contiguous()
            else:
                # alpha and any other tensors are copied unchanged
                new_sd[key] = value

        split_models.append((new_sd, rank, new_alpha))
        rank += unit

    return max_rank, split_models


def split(args):
    print("loading Model...")
    lora_sd, metadata = load_state_dict(args.model)

    print("Splitting Model...")
    original_rank, split_models = split_lora_model(lora_sd, args.unit)

    comment = metadata.get("ss_training_comment", "") if metadata is not None else ""
    for state_dict, new_rank, new_alpha in split_models:
        # update metadata for the split model
        if metadata is None:
            new_metadata = {}
        else:
            new_metadata = metadata.copy()

        new_metadata["ss_training_comment"] = f"split from DyLoRA, rank {original_rank} to {new_rank}; {comment}"
        new_metadata["ss_network_dim"] = str(new_rank)

        # recompute hashes for the new state dict and store them in the metadata being saved
        model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, new_metadata)
        new_metadata["sshs_model_hash"] = model_hash
        new_metadata["sshs_legacy_hash"] = legacy_hash

        filename, ext = os.path.splitext(args.save_to)
        model_file_name = filename + f"-{new_rank:04d}{ext}"

        print(f"saving model to: {model_file_name}")
        save_to_file(model_file_name, state_dict, new_metadata)


def setup_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser()

    parser.add_argument("--unit", type=int, default=None, help="size of rank to split into")
    parser.add_argument(
        "--save_to",
        type=str,
        default=None,
        help="destination base file name: ckpt or safetensors file",
    )
    parser.add_argument(
        "--model",
        type=str,
        default=None,
        help="DyLoRA model to load and split: ckpt or safetensors file",
    )

    return parser


if __name__ == "__main__":
    parser = setup_parser()

    args = parser.parse_args()
    split(args)
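# Example invocation (script and file names are illustrative):
#   python split_dylora.py --model dylora-rank16.safetensors --save_to lora.safetensors --unit 4
# With a rank-16 DyLoRA and --unit 4, this writes lora-0004.safetensors,
# lora-0008.safetensors and lora-0012.safetensors.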