import argparse
import logging

import torch
from safetensors import safe_open
from safetensors.torch import load_file, save_file

from utils import model_utils

logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)


def convert_from_diffusers(prefix, weights_sd):
    # convert from Diffusers format to the default LoRA format
    # Diffusers format: {"diffusion_model.module.name.lora_A.weight": weight, "diffusion_model.module.name.lora_B.weight": weight, ...}
    # default LoRA format: {"prefix_module_name.lora_down.weight": weight, "prefix_module_name.lora_up.weight": weight, ...}
    # note: Diffusers has no alpha, so alpha is set to the rank
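    # For illustration, with a hypothetical module name (not taken from a real checkpoint)
    # and prefix "lora_unet_":
    #   "diffusion_model.double_blocks.0.img_attn_qkv.lora_A.weight"
    #     -> "lora_unet_double_blocks_0_img_attn_qkv.lora_down.weight"
    # and "lora_unet_double_blocks_0_img_attn_qkv.alpha" = tensor(rank) is added below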
    new_weights_sd = {}
    lora_dims = {}
    for key, weight in weights_sd.items():
        diffusers_prefix, key_body = key.split(".", 1)
        if diffusers_prefix != "diffusion_model":
            logger.warning(f"unexpected key: {key} in diffusers format")
            continue

        new_key = f"{prefix}{key_body}".replace(".", "_").replace("_lora_A_", ".lora_down.").replace("_lora_B_", ".lora_up.")
        new_weights_sd[new_key] = weight

        lora_name = new_key.split(".")[0]  # before the first dot
        if lora_name not in lora_dims and "lora_down" in new_key:
            lora_dims[lora_name] = weight.shape[0]  # rank of this LoRA module

    # add alpha with rank
    for lora_name, dim in lora_dims.items():
        new_weights_sd[f"{lora_name}.alpha"] = torch.tensor(dim)

    return new_weights_sd


def convert_to_diffusers(prefix, weights_sd):
    # convert from the default LoRA format to Diffusers format

    # get alphas
    lora_alphas = {}
    for key, weight in weights_sd.items():
        if key.startswith(prefix):
            lora_name = key.split(".", 1)[0]  # before the first dot
            if lora_name not in lora_alphas and "alpha" in key:
                lora_alphas[lora_name] = weight
    new_weights_sd = {}
    for key, weight in weights_sd.items():
        if key.startswith(prefix):
            if "alpha" in key:
                continue

            lora_name = key.split(".", 1)[0]  # before the first dot

            # HunyuanVideo lora name to module name: ugly but works
            module_name = lora_name[len(prefix):]  # remove "lora_unet_"
            module_name = module_name.replace("_", ".")  # replace "_" with "."
            module_name = module_name.replace("double.blocks.", "double_blocks.")  # fix double blocks
            module_name = module_name.replace("single.blocks.", "single_blocks.")  # fix single blocks
            module_name = module_name.replace("img.", "img_")  # fix img
            module_name = module_name.replace("txt.", "txt_")  # fix txt
            module_name = module_name.replace("attn.", "attn_")  # fix attn
            diffusers_prefix = "diffusion_model"
            if "lora_down" in key:
                new_key = f"{diffusers_prefix}.{module_name}.lora_A.weight"
                dim = weight.shape[0]  # rank
            elif "lora_up" in key:
                new_key = f"{diffusers_prefix}.{module_name}.lora_B.weight"
                dim = weight.shape[1]  # rank
            else:
                logger.warning(f"unexpected key: {key} in default LoRA format")
                continue
            # scale weight by alpha
            if lora_name in lora_alphas:
                # Diffusers format carries no alpha, so bake the alpha/rank scale into the weights;
                # we scale both down and up, so each side gets the square root of the scale
                scale = lora_alphas[lora_name] / dim
                scale = scale.sqrt()
                weight = weight * scale
            else:
                logger.warning(f"missing alpha for {lora_name}")
            new_weights_sd[new_key] = weight

    return new_weights_sd


def convert(input_file, output_file, target_format):
    logger.info(f"loading {input_file}")
    weights_sd = load_file(input_file)
    with safe_open(input_file, framework="pt") as f:
        metadata = f.metadata()

    logger.info(f"converting to {target_format}")
    prefix = "lora_unet_"
    if target_format == "default":
        new_weights_sd = convert_from_diffusers(prefix, weights_sd)
        metadata = metadata or {}
        model_utils.precalculate_safetensors_hashes(new_weights_sd, metadata)
    elif target_format == "other":
        new_weights_sd = convert_to_diffusers(prefix, weights_sd)
    else:
        raise ValueError(f"unknown target format: {target_format}")

    logger.info(f"saving to {output_file}")
    save_file(new_weights_sd, output_file, metadata=metadata)
    logger.info("done")


def parse_args():
    parser = argparse.ArgumentParser(description="Convert LoRA weights between default and other formats")
    parser.add_argument("--input", type=str, required=True, help="input model file")
    parser.add_argument("--output", type=str, required=True, help="output model file")
    parser.add_argument("--target", type=str, required=True, choices=["other", "default"], help="target format")
    args = parser.parse_args()
    return args


if __name__ == "__main__":
    args = parse_args()
    convert(args.input, args.output, args.target)
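

# Example usage (hypothetical file names; the script name is assumed, the flags match parse_args above):
#   python convert_lora.py --input lora_diffusers.safetensors --output lora.safetensors --target default
#   python convert_lora.py --input lora.safetensors --output lora_diffusers.safetensors --target other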