import os
import torch
import numpy as np
from safetensors.torch import load_file
from diffusers.loaders.lora_conversion_utils import (
    _maybe_map_sgm_blocks_to_diffusers,
    _convert_kohya_lora_to_diffusers,
)
from types import SimpleNamespace

LORA_PREFIX_UNET = "lora_unet"
LORA_PREFIX_TEXT_ENCODER = "lora_te"
LORA_UNET_LAYERS = [
    'lora_unet_down_blocks_0_attentions_0', 'lora_unet_down_blocks_0_attentions_1',
    'lora_unet_down_blocks_1_attentions_0', 'lora_unet_down_blocks_1_attentions_1',
    'lora_unet_down_blocks_2_attentions_0', 'lora_unet_down_blocks_2_attentions_1',
    'lora_unet_mid_block_attentions_0',
    'lora_unet_up_blocks_1_attentions_0', 'lora_unet_up_blocks_1_attentions_1',
    'lora_unet_up_blocks_1_attentions_2', 'lora_unet_up_blocks_2_attentions_0',
    'lora_unet_up_blocks_2_attentions_1', 'lora_unet_up_blocks_2_attentions_2',
    'lora_unet_up_blocks_3_attentions_0', 'lora_unet_up_blocks_3_attentions_1',
    'lora_unet_up_blocks_3_attentions_2',
]


def add_text_lora_layer(clip_model, lora_model_path="Misaka.safetensors", alpha=1.0,
                        lora_file_format="fp32", device="cuda:0"):
    if lora_file_format == "fp32":
        model_dtype = np.float32
    elif lora_file_format == "fp16":
        model_dtype = np.float16
    else:
        raise Exception(f"unsupported model dtype: {lora_file_format}")

    all_files = os.scandir(lora_model_path)
    unload_dict = []
    # directly update the weights in the diffusers model
    for file in all_files:
        if 'text' in file.name:
            layer_infos = file.name.split('.')[0].split('text_model_')[-1].split('_')
            curr_layer = clip_model.text_model
        else:
            continue

        # walk the module tree to find the target layer; attribute names that
        # themselves contain underscores (e.g. "self_attn") are rebuilt piece
        # by piece whenever a lookup fails
        temp_name = layer_infos.pop(0)
        while len(layer_infos) > -1:
            try:
                curr_layer = curr_layer.__getattr__(temp_name)
                if len(layer_infos) > 0:
                    temp_name = layer_infos.pop(0)
                elif len(layer_infos) == 0:
                    break
            except Exception:
                if len(temp_name) > 0:
                    temp_name += '_' + layer_infos.pop(0)
                else:
                    temp_name = layer_infos.pop(0)

        data = torch.from_numpy(np.fromfile(file.path, dtype=model_dtype)).to(
            clip_model.dtype).to(clip_model.device).reshape(curr_layer.weight.data.shape)
        # conv weights are stored channels-last on disk; move channels back to dim 1
        if len(curr_layer.weight.data.shape) == 4:
            adding_weight = alpha * data.permute(0, 3, 1, 2)
        else:
            adding_weight = alpha * data
        curr_layer.weight.data += adding_weight

        curr_layer_unload_data = {
            "layer": curr_layer,
            "added_weight": adding_weight,
        }
        unload_dict.append(curr_layer_unload_data)
    return unload_dict

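# Usage sketch for add_text_lora_layer (names and paths are illustrative, not
# part of this module; assumes `pipe` is a loaded diffusers pipeline and
# `lora_dir/` holds the per-layer binary weight files this function expects):
#
#   unload_dict = add_text_lora_layer(pipe.text_encoder, "lora_dir/", alpha=0.8)
#   images = pipe(prompt).images
#   for entry in unload_dict:  # revert the in-place merge afterwards
#       entry["layer"].weight.data -= entry["added_weight"]
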

def add_xltext_lora_layer(clip_model, clip_model_2, lora_model_path, alpha=1.0,
                          lora_file_format="fp32", device="cuda:0"):
    if lora_file_format == "fp32":
        model_dtype = np.float32
    elif lora_file_format == "fp16":
        model_dtype = np.float16
    else:
        raise Exception(f"unsupported model dtype: {lora_file_format}")

    all_files = os.scandir(lora_model_path)
    unload_dict = []
    # directly update the weights in both diffusers text encoders
    for file in all_files:
        if 'text' in file.name:
            layer_infos = file.name.split('.')[0].split('text_model_')[-1].split('_')
            if "text_encoder_2" in file.name:
                curr_layer = clip_model_2.text_model
            elif "text_encoder" in file.name:
                curr_layer = clip_model.text_model
            else:
                raise ValueError(
                    f"Cannot identify clip model, need text_encoder or "
                    f"text_encoder_2 in filename, found: {file.name}")
        else:
            continue

        # walk the module tree to find the target layer (same scheme as
        # add_text_lora_layer above)
        temp_name = layer_infos.pop(0)
        while len(layer_infos) > -1:
            try:
                curr_layer = curr_layer.__getattr__(temp_name)
                if len(layer_infos) > 0:
                    temp_name = layer_infos.pop(0)
                elif len(layer_infos) == 0:
                    break
            except Exception:
                if len(temp_name) > 0:
                    temp_name += '_' + layer_infos.pop(0)
                else:
                    temp_name = layer_infos.pop(0)

        data = torch.from_numpy(np.fromfile(file.path, dtype=model_dtype)).to(
            curr_layer.weight.dtype).to(curr_layer.weight.device).reshape(
            curr_layer.weight.data.shape)
        # conv weights are stored channels-last on disk; move channels back to dim 1
        if len(curr_layer.weight.data.shape) == 4:
            adding_weight = alpha * data.permute(0, 3, 1, 2)
        else:
            adding_weight = alpha * data
        curr_layer.weight.data += adding_weight

        curr_layer_unload_data = {
            "layer": curr_layer,
            "added_weight": adding_weight,
        }
        unload_dict.append(curr_layer_unload_data)
    return unload_dict


def lora_trans(state_dict):
    # map SGM/kohya-style block names to diffusers naming, then normalize the
    # remaining LoRA key patterns to the layout the loader below expects
    unet_config = SimpleNamespace(**{'layers_per_block': 2})
    state_dicts = _maybe_map_sgm_blocks_to_diffusers(state_dict, unet_config)
    state_dicts_trans, state_dicts_alpha = _convert_kohya_lora_to_diffusers(state_dicts)
    keys = list(state_dicts_trans.keys())
    for k in keys:
        key = k.replace('processor.', '')
        for x in ['.lora_linear_layer.', '_lora.', '.lora.']:
            key = key.replace(x, '.lora_')
        if key.find('text_encoder') >= 0:
            for x in ['q', 'k', 'v', 'out']:
                key = key.replace(f'.to_{x}.', f'.{x}_proj.')
        key = key.replace('to_out.', 'to_out.0.')
        if key != k:
            state_dicts_trans[key] = state_dicts_trans.pop(k)
    alpha = torch.Tensor(list(set(state_dicts_alpha.values())))
    state_dicts_trans.update({'lora.alpha': alpha})
    return state_dicts_trans


def load_state_dict(filename, need_trans=True):
    state_dict = load_file(os.path.abspath(filename), device="cpu")
    if need_trans:
        state_dict = lora_trans(state_dict)
    return state_dict


def move_state_dict_to_cuda(state_dict):
    return {name: tensor.cuda() for name, tensor in state_dict.items()}

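# Usage sketch (illustrative path and variables; assumes a kohya-ss style
# .safetensors LoRA and already-initialized `unet`, `clip_model`,
# `clip_model_2` objects):
#
#   state_dict = load_state_dict("loras/Misaka.safetensors")  # diffusers-style keys
#   add_lora_to_opt_model(state_dict, unet, clip_model, clip_model_2, alpha=1.0)
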

def add_lora_to_opt_model(state_dict, unet, clip_model, clip_model_2, alpha=1.0,
                          need_trans=False):
    # directly update the weights in the diffusers model
    state_dict = move_state_dict_to_cuda(state_dict)

    alpha_ks = list(filter(lambda x: x.find('.alpha') >= 0, state_dict))
    lora_alpha = state_dict[alpha_ks[0]].item() if len(alpha_ks) > 0 else -1

    visited = set()
    for key in state_dict:
        # it is suggested to print out the key; it usually looks like
        # "lora_te_text_model_encoder_layers_0_self_attn_k_proj.lora_down.weight"

        # alpha was read out beforehand, so skip those entries and any
        # up/down pair already merged
        if '.alpha' in key or key in visited:
            continue

        if "text" in key:
            curr_layer = clip_model_2 if key.find('text_encoder_2') >= 0 else clip_model
            # walk the attribute path; the walk stops at the lora suffix
            layer_infos = key.split('.')[1:]
            for x in layer_infos:
                try:
                    curr_layer = curr_layer.__getattr__(x)
                except Exception:
                    break

            # update weight
            pair_keys = [key.replace("lora_down", "lora_up"),
                         key.replace("lora_up", "lora_down")]
            weight_up, weight_down = state_dict[pair_keys[0]], state_dict[pair_keys[1]]
            weight_scale = lora_alpha / weight_up.shape[1] if lora_alpha != -1 else 1.0
            if len(weight_up.shape) == 4:
                weight_up = weight_up.squeeze([2, 3])
                weight_down = weight_down.squeeze([2, 3])
                if len(weight_down.shape) == 4:
                    adding_weight = torch.einsum(
                        'a b, b c h w -> a c h w', weight_up, weight_down)
                else:
                    adding_weight = torch.mm(
                        weight_up, weight_down).unsqueeze(2).unsqueeze(3)
            else:
                adding_weight = torch.mm(weight_up, weight_down)
            adding_weight = alpha * weight_scale * adding_weight
            curr_layer.weight.data += adding_weight.to(torch.float16)

            # update visited list
            for item in pair_keys:
                visited.add(item)

        elif "unet" in key:
            layer_infos = key
            layer_infos = layer_infos.replace(".lora_up.weight", "")
            layer_infos = layer_infos.replace(".lora_down.weight", "")
            layer_infos = layer_infos[5:]  # strip the "unet." prefix

            # regroup dotted segments into diffusers attribute names,
            # e.g. "down.blocks.0" -> "down_blocks_0"
            layer_names = layer_infos.split(".")
            layers = []
            i = 0
            while i < len(layer_names):
                if len(layers) >= 4:
                    layers[-1] += "_" + layer_names[i]
                elif i + 1 < len(layer_names) and layer_names[i + 1].isdigit():
                    layers.append(layer_names[i] + "_" + layer_names[i + 1])
                    i += 1
                elif len(layers) > 0 and "samplers" in layers[-1]:
                    layers[-1] += "_" + layer_names[i]
                else:
                    layers.append(layer_names[i])
                i += 1
            layer_infos = ".".join(layers)

            pair_keys = [key.replace("lora_down", "lora_up"),
                         key.replace("lora_up", "lora_down")]

            # update weight
            if len(state_dict[pair_keys[0]].shape) == 4:
                weight_up = state_dict[pair_keys[0]].squeeze([2, 3]).to(torch.float32)
                weight_down = state_dict[pair_keys[1]].squeeze([2, 3]).to(torch.float32)
                weight_scale = lora_alpha / weight_up.shape[1] if lora_alpha != -1 else 1.0
                if len(weight_down.shape) == 4:
                    # 3x3 conv LoRA: contract over the rank dimension
                    curr_layer_weight = weight_scale * torch.einsum(
                        'a b, b c h w -> a c h w', weight_up, weight_down)
                else:
                    # 1x1 conv LoRA: plain matmul, then restore the spatial dims
                    curr_layer_weight = weight_scale * torch.mm(
                        weight_up, weight_down).unsqueeze(2).unsqueeze(3)
                curr_layer_weight = curr_layer_weight.permute(0, 2, 3, 1)
            else:
                weight_up = state_dict[pair_keys[0]].to(torch.float32)
                weight_down = state_dict[pair_keys[1]].to(torch.float32)
                weight_scale = lora_alpha / weight_up.shape[1] if lora_alpha != -1 else 1.0
                curr_layer_weight = weight_scale * torch.mm(weight_up, weight_down)
            # curr_layer_weight = curr_layer_weight.to(torch.float16)
            unet.load_lora_by_name(layers, curr_layer_weight, alpha)
            for item in pair_keys:
                visited.add(item)
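
if __name__ == "__main__":
    # Minimal self-check of the LoRA merge math used above, on synthetic
    # tensors rather than a real checkpoint (shapes are illustrative):
    # the merged delta is delta_W = alpha * (lora_alpha / rank) * up @ down.
    rank, fan_in, fan_out = 4, 320, 320
    up = torch.randn(fan_out, rank)
    down = torch.randn(rank, fan_in)
    lora_alpha, user_alpha = 4.0, 1.0
    delta_w = user_alpha * (lora_alpha / rank) * torch.mm(up, down)
    assert delta_w.shape == (fan_out, fan_in)
    print("delta_W norm:", delta_w.norm().item())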