import os
import re
import bisect
from typing import Dict
from modules import shared


debug = os.environ.get('SD_LORA_DEBUG', None) is not None

suffix_conversion = {
    "attentions": {},
    "resnets": {
        "conv1": "in_layers_2",
        "conv2": "out_layers_3",
        "norm1": "in_layers_0",
        "norm2": "out_layers_0",
        "time_emb_proj": "emb_layers_1",
        "conv_shortcut": "skip_connection",
    }
}
re_digits = re.compile(r"\d+")
re_x_proj = re.compile(r"(.*)_([qkv]_proj)$")
re_compiled = {}


def make_unet_conversion_map() -> Dict[str, str]:
    unet_conversion_map_layer = []

    for i in range(3):  # num_blocks is 3 in sdxl
        # loop over downblocks/upblocks
        for j in range(2):
            # loop over resnets/attentions for downblocks
            hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}."
            sd_down_res_prefix = f"input_blocks.{3 * i + j + 1}.0."
            unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix))
            if i < 3:
                # no attention layers in down_blocks.3
                hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}."
                sd_down_atn_prefix = f"input_blocks.{3 * i + j + 1}.1."
                unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix))
        for j in range(3):
            # loop over resnets/attentions for upblocks
            hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}."
            sd_up_res_prefix = f"output_blocks.{3 * i + j}.0."
            unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix))
            # if i > 0: commentout for sdxl
            # no attention layers in up_blocks.0
            hf_up_atn_prefix = f"up_blocks.{i}.attentions.{j}."
            sd_up_atn_prefix = f"output_blocks.{3 * i + j}.1."
            unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix))
        if i < 3:
            # no downsample in down_blocks.3
            hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv."
            sd_downsample_prefix = f"input_blocks.{3 * (i + 1)}.0.op."
            unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix))
            # no upsample in up_blocks.3
            hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
            sd_upsample_prefix = f"output_blocks.{3 * i + 2}.{2}."  # change for sdxl
            unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix))

    hf_mid_atn_prefix = "mid_block.attentions.0."
    sd_mid_atn_prefix = "middle_block.1."
    unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix))
    for j in range(2):
        hf_mid_res_prefix = f"mid_block.resnets.{j}."
        sd_mid_res_prefix = f"middle_block.{2 * j}."
        unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix))

    unet_conversion_map_resnet = [
        # (stable-diffusion, HF Diffusers)
        ("in_layers.0.", "norm1."),
        ("in_layers.2.", "conv1."),
        ("out_layers.0.", "norm2."),
        ("out_layers.3.", "conv2."),
        ("emb_layers.1.", "time_emb_proj."),
        ("skip_connection.", "conv_shortcut."),
    ]

    unet_conversion_map = []
    for sd, hf in unet_conversion_map_layer:
        if "resnets" in hf:
            for sd_res, hf_res in unet_conversion_map_resnet:
                unet_conversion_map.append((sd + sd_res, hf + hf_res))
        else:
            unet_conversion_map.append((sd, hf))

    for j in range(2):
        hf_time_embed_prefix = f"time_embedding.linear_{j + 1}."
        sd_time_embed_prefix = f"time_embed.{j * 2}."
        unet_conversion_map.append((sd_time_embed_prefix, hf_time_embed_prefix))

    for j in range(2):
        hf_label_embed_prefix = f"add_embedding.linear_{j + 1}."
        sd_label_embed_prefix = f"label_emb.0.{j * 2}."
        unet_conversion_map.append((sd_label_embed_prefix, hf_label_embed_prefix))

    unet_conversion_map.append(("input_blocks.0.0.", "conv_in."))
    unet_conversion_map.append(("out.0.", "conv_norm_out."))
    unet_conversion_map.append(("out.2.", "conv_out."))

    sd_hf_conversion_map = {sd.replace(".", "_")[:-1]: hf.replace(".", "_")[:-1] for sd, hf in unet_conversion_map}
    return sd_hf_conversion_map


class KeyConvert:
    def __init__(self):
        if shared.backend == shared.Backend.ORIGINAL:
            self.converter = self.original
            self.is_sd2 = 'model_transformer_resblocks' in shared.sd_model.network_layer_mapping
        else:
            self.converter = self.diffusers
            self.is_sdxl = True if shared.sd_model_type == "sdxl" else False
            self.UNET_CONVERSION_MAP = make_unet_conversion_map() if self.is_sdxl else None
            self.LORA_PREFIX_UNET = "lora_unet_"
            self.LORA_PREFIX_TEXT_ENCODER = "lora_te_"
            self.OFT_PREFIX_UNET = "oft_unet_"
            # SDXL: must start with LORA_PREFIX_TEXT_ENCODER
            self.LORA_PREFIX_TEXT_ENCODER1 = "lora_te1_"
            self.LORA_PREFIX_TEXT_ENCODER2 = "lora_te2_"

    def original(self, key):
        key = convert_diffusers_name_to_compvis(key, self.is_sd2)
        sd_module = shared.sd_model.network_layer_mapping.get(key, None)
        if sd_module is None:
            m = re_x_proj.match(key)
            if m:
                sd_module = shared.sd_model.network_layer_mapping.get(m.group(1), None)
        # SDXL loras seem to already have correct compvis keys, so only need to replace "lora_unet" with "diffusion_model"
        if sd_module is None and "lora_unet" in key:
            key = key.replace("lora_unet", "diffusion_model")
            sd_module = shared.sd_model.network_layer_mapping.get(key, None)
        elif sd_module is None and "lora_te1_text_model" in key:
            key = key.replace("lora_te1_text_model", "0_transformer_text_model")
            sd_module = shared.sd_model.network_layer_mapping.get(key, None)
            # some SD1 Loras also have correct compvis keys
            if sd_module is None:
                key = key.replace("lora_te1_text_model", "transformer_text_model")
                sd_module = shared.sd_model.network_layer_mapping.get(key, None)
        # SegMoE begin
        expert_key = key + "_experts_0"
        expert_module = shared.sd_model.network_layer_mapping.get(expert_key, None)
        if expert_module is not None:
            sd_module = expert_module
            key = expert_key
        if sd_module is None:
            key = key.replace("_net_", "_experts_0_net_")
            sd_module = shared.sd_model.network_layer_mapping.get(key, None)
        key = key if isinstance(key, list) else [key]
        sd_module = sd_module if isinstance(sd_module, list) else [sd_module]
        if "_experts_0" in key[0]:
            i = expert_module = 1
            while expert_module is not None:
                expert_key = key[0].replace("_experts_0", f"_experts_{i}")
                expert_module = shared.sd_model.network_layer_mapping.get(expert_key, None)
                if expert_module is not None:
                    key.append(expert_key)
                    sd_module.append(expert_module)
                i += 1
        # SegMoE end
        return key, sd_module

    def diffusers(self, key):
        if self.is_sdxl:
            map_keys = list(self.UNET_CONVERSION_MAP.keys())  # prefix of U-Net modules
            map_keys.sort()
            search_key = key.replace(self.LORA_PREFIX_UNET, "").replace(self.OFT_PREFIX_UNET, "").replace(self.LORA_PREFIX_TEXT_ENCODER1, "").replace(self.LORA_PREFIX_TEXT_ENCODER2, "")
            position = bisect.bisect_right(map_keys, search_key)
            map_key = map_keys[position - 1]
            if search_key.startswith(map_key):
                key = key.replace(map_key, self.UNET_CONVERSION_MAP[map_key]).replace("oft", "lora")  # pylint: disable=unsubscriptable-object
        sd_module = shared.sd_model.network_layer_mapping.get(key, None)
        # SegMoE begin
        expert_key = key + "_experts_0"
        expert_module = shared.sd_model.network_layer_mapping.get(expert_key, None)
        if expert_module is not None:
            sd_module = expert_module
            key = expert_key
        if sd_module is None:
            key = key.replace("_net_", "_experts_0_net_")
            sd_module = shared.sd_model.network_layer_mapping.get(key, None)
        key = key if isinstance(key, list) else [key]
        sd_module = sd_module if isinstance(sd_module, list) else [sd_module]
        if "_experts_0" in key[0]:
            i = expert_module = 1
            while expert_module is not None:
                expert_key = key[0].replace("_experts_0", f"_experts_{i}")
                expert_module = shared.sd_model.network_layer_mapping.get(expert_key, None)
                if expert_module is not None:
                    key.append(expert_key)
                    sd_module.append(expert_module)
                i += 1
        # SegMoE end
        if debug and sd_module is None:
            raise RuntimeError(f"LoRA key not found in network_layer_mapping: key={key} mapping={shared.sd_model.network_layer_mapping.keys()}")
        return key, sd_module

    def __call__(self, key):
        return self.converter(key)


def convert_diffusers_name_to_compvis(key, is_sd2):
    def match(match_list, regex_text):
        regex = re_compiled.get(regex_text)
        if regex is None:
            regex = re.compile(regex_text)
            re_compiled[regex_text] = regex
        r = re.match(regex, key)
        if not r:
            return False
        match_list.clear()
        match_list.extend([int(x) if re.match(re_digits, x) else x for x in r.groups()])
        return True

    m = []
    if match(m, r"lora_unet_conv_in(.*)"):
        return f'diffusion_model_input_blocks_0_0{m[0]}'
    if match(m, r"lora_unet_conv_out(.*)"):
        return f'diffusion_model_out_2{m[0]}'
    if match(m, r"lora_unet_time_embedding_linear_(\d+)(.*)"):
        return f"diffusion_model_time_embed_{m[0] * 2 - 2}{m[1]}"
    if match(m, r"lora_unet_down_blocks_(\d+)_(attentions|resnets)_(\d+)_(.+)"):
        suffix = suffix_conversion.get(m[1], {}).get(m[3], m[3])
        return f"diffusion_model_input_blocks_{1 + m[0] * 3 + m[2]}_{1 if m[1] == 'attentions' else 0}_{suffix}"
    if match(m, r"lora_unet_mid_block_(attentions|resnets)_(\d+)_(.+)"):
        suffix = suffix_conversion.get(m[0], {}).get(m[2], m[2])
        return f"diffusion_model_middle_block_{1 if m[0] == 'attentions' else m[1] * 2}_{suffix}"
    if match(m, r"lora_unet_up_blocks_(\d+)_(attentions|resnets)_(\d+)_(.+)"):
        suffix = suffix_conversion.get(m[1], {}).get(m[3], m[3])
        return f"diffusion_model_output_blocks_{m[0] * 3 + m[2]}_{1 if m[1] == 'attentions' else 0}_{suffix}"
    if match(m, r"lora_unet_down_blocks_(\d+)_downsamplers_0_conv"):
        return f"diffusion_model_input_blocks_{3 + m[0] * 3}_0_op"
    if match(m, r"lora_unet_up_blocks_(\d+)_upsamplers_0_conv"):
        return f"diffusion_model_output_blocks_{2 + m[0] * 3}_{2 if m[0] > 0 else 1}_conv"
    if match(m, r"lora_te_text_model_encoder_layers_(\d+)_(.+)"):
        if is_sd2:
            if 'mlp_fc1' in m[1]:
                return f"model_transformer_resblocks_{m[0]}_{m[1].replace('mlp_fc1', 'mlp_c_fc')}"
            elif 'mlp_fc2' in m[1]:
                return f"model_transformer_resblocks_{m[0]}_{m[1].replace('mlp_fc2', 'mlp_c_proj')}"
            else:
                return f"model_transformer_resblocks_{m[0]}_{m[1].replace('self_attn', 'attn')}"
        return f"transformer_text_model_encoder_layers_{m[0]}_{m[1]}"
    if match(m, r"lora_te2_text_model_encoder_layers_(\d+)_(.+)"):
        if 'mlp_fc1' in m[1]:
            return f"1_model_transformer_resblocks_{m[0]}_{m[1].replace('mlp_fc1', 'mlp_c_fc')}"
        elif 'mlp_fc2' in m[1]:
            return f"1_model_transformer_resblocks_{m[0]}_{m[1].replace('mlp_fc2', 'mlp_c_proj')}"
        else:
            return f"1_model_transformer_resblocks_{m[0]}_{m[1].replace('self_attn', 'attn')}"
    return key
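

# Usage sketch (illustrative only; assumes a checkpoint has been loaded so that
# `shared.sd_model.network_layer_mapping` is populated, and the example LoRA key below
# is hypothetical):
#
#   converter = KeyConvert()
#   keys, modules = converter("lora_unet_down_blocks_0_attentions_0_transformer_blocks_0_attn1_to_q")
#
# `keys` and `modules` are parallel lists: SegMoE models may resolve one LoRA key to
# several expert modules, so the converter always returns lists even for a single match.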