# LoRA network module: currently conv2d is not fully supported # reference: # https://github.com/microsoft/LoRA/blob/main/loralib/layers.py # https://github.com/cloneofsimo/lora/blob/master/lora_diffusion/lora.py import math import os from typing import Dict, List, Optional, Type, Union from diffusers import AutoencoderKL from transformers import CLIPTextModel import numpy as np import torch import torch.nn as nn import logging logger = logging.getLogger(__name__) logging.basicConfig(level=logging.INFO) HUNYUAN_TARGET_REPLACE_MODULES = ["MMDoubleStreamBlock", "MMSingleStreamBlock"] class LoRAModule(torch.nn.Module): """ replaces forward method of the original Linear, instead of replacing the original Linear module. """ def __init__( self, lora_name, org_module: torch.nn.Module, multiplier=1.0, lora_dim=4, alpha=1, dropout=None, rank_dropout=None, module_dropout=None, split_dims: Optional[List[int]] = None, ): """ if alpha == 0 or None, alpha is rank (no scaling). split_dims is used to mimic the split qkv of multi-head attention. """ super().__init__() self.lora_name = lora_name if org_module.__class__.__name__ == "Conv2d": in_dim = org_module.in_channels out_dim = org_module.out_channels else: in_dim = org_module.in_features out_dim = org_module.out_features self.lora_dim = lora_dim self.split_dims = split_dims if split_dims is None: if org_module.__class__.__name__ == "Conv2d": kernel_size = org_module.kernel_size stride = org_module.stride padding = org_module.padding self.lora_down = torch.nn.Conv2d(in_dim, self.lora_dim, kernel_size, stride, padding, bias=False) self.lora_up = torch.nn.Conv2d(self.lora_dim, out_dim, (1, 1), (1, 1), bias=False) else: self.lora_down = torch.nn.Linear(in_dim, self.lora_dim, bias=False) self.lora_up = torch.nn.Linear(self.lora_dim, out_dim, bias=False) torch.nn.init.kaiming_uniform_(self.lora_down.weight, a=math.sqrt(5)) torch.nn.init.zeros_(self.lora_up.weight) else: # conv2d not supported assert sum(split_dims) == out_dim, "sum of split_dims must be equal to out_dim" assert org_module.__class__.__name__ == "Linear", "split_dims is only supported for Linear" # print(f"split_dims: {split_dims}") self.lora_down = torch.nn.ModuleList( [torch.nn.Linear(in_dim, self.lora_dim, bias=False) for _ in range(len(split_dims))] ) self.lora_up = torch.nn.ModuleList([torch.nn.Linear(self.lora_dim, split_dim, bias=False) for split_dim in split_dims]) for lora_down in self.lora_down: torch.nn.init.kaiming_uniform_(lora_down.weight, a=math.sqrt(5)) for lora_up in self.lora_up: torch.nn.init.zeros_(lora_up.weight) if type(alpha) == torch.Tensor: alpha = alpha.detach().float().numpy() # without casting, bf16 causes error alpha = self.lora_dim if alpha is None or alpha == 0 else alpha self.scale = alpha / self.lora_dim self.register_buffer("alpha", torch.tensor(alpha)) # for save/load # same as microsoft's self.multiplier = multiplier self.org_module = org_module # remove in applying self.dropout = dropout self.rank_dropout = rank_dropout self.module_dropout = module_dropout def apply_to(self): self.org_forward = self.org_module.forward self.org_module.forward = self.forward del self.org_module def forward(self, x): org_forwarded = self.org_forward(x) # module dropout if self.module_dropout is not None and self.training: if torch.rand(1) < self.module_dropout: return org_forwarded if self.split_dims is None: lx = self.lora_down(x) # normal dropout if self.dropout is not None and self.training: lx = torch.nn.functional.dropout(lx, p=self.dropout) # rank dropout if self.rank_dropout is not None and self.training: mask = torch.rand((lx.size(0), self.lora_dim), device=lx.device) > self.rank_dropout if len(lx.size()) == 3: mask = mask.unsqueeze(1) # for Text Encoder elif len(lx.size()) == 4: mask = mask.unsqueeze(-1).unsqueeze(-1) # for Conv2d lx = lx * mask # scaling for rank dropout: treat as if the rank is changed scale = self.scale * (1.0 / (1.0 - self.rank_dropout)) # redundant for readability else: scale = self.scale lx = self.lora_up(lx) return org_forwarded + lx * self.multiplier * scale else: lxs = [lora_down(x) for lora_down in self.lora_down] # normal dropout if self.dropout is not None and self.training: lxs = [torch.nn.functional.dropout(lx, p=self.dropout) for lx in lxs] # rank dropout if self.rank_dropout is not None and self.training: masks = [torch.rand((lx.size(0), self.lora_dim), device=lx.device) > self.rank_dropout for lx in lxs] for i in range(len(lxs)): if len(lx.size()) == 3: masks[i] = masks[i].unsqueeze(1) elif len(lx.size()) == 4: masks[i] = masks[i].unsqueeze(-1).unsqueeze(-1) lxs[i] = lxs[i] * masks[i] # scaling for rank dropout: treat as if the rank is changed scale = self.scale * (1.0 / (1.0 - self.rank_dropout)) # redundant for readability else: scale = self.scale lxs = [lora_up(lx) for lora_up, lx in zip(self.lora_up, lxs)] return org_forwarded + torch.cat(lxs, dim=-1) * self.multiplier * scale class LoRAInfModule(LoRAModule): def __init__( self, lora_name, org_module: torch.nn.Module, multiplier=1.0, lora_dim=4, alpha=1, **kwargs, ): # no dropout for inference super().__init__(lora_name, org_module, multiplier, lora_dim, alpha) self.org_module_ref = [org_module] # for reference self.enabled = True self.network: LoRANetwork = None def set_network(self, network): self.network = network # merge weight to org_module def merge_to(self, sd, dtype, device): # extract weight from org_module org_sd = self.org_module.state_dict() weight = org_sd["weight"] org_dtype = weight.dtype org_device = weight.device weight = weight.to(device, dtype=torch.float) # for calculation if dtype is None: dtype = org_dtype if device is None: device = org_device if self.split_dims is None: # get up/down weight down_weight = sd["lora_down.weight"].to(device, dtype=torch.float) up_weight = sd["lora_up.weight"].to(device, dtype=torch.float) # merge weight if len(weight.size()) == 2: # linear weight = weight + self.multiplier * (up_weight @ down_weight) * self.scale elif down_weight.size()[2:4] == (1, 1): # conv2d 1x1 weight = ( weight + self.multiplier * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3) * self.scale ) else: # conv2d 3x3 conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3) # logger.info(conved.size(), weight.size(), module.stride, module.padding) weight = weight + self.multiplier * conved * self.scale # set weight to org_module org_sd["weight"] = weight.to(org_device, dtype=dtype) self.org_module.load_state_dict(org_sd) else: # split_dims total_dims = sum(self.split_dims) for i in range(len(self.split_dims)): # get up/down weight down_weight = sd[f"lora_down.{i}.weight"].to(torch.float).to(device) # (rank, in_dim) up_weight = sd[f"lora_up.{i}.weight"].to(torch.float).to(device) # (split dim, rank) # pad up_weight -> (total_dims, rank) padded_up_weight = torch.zeros((total_dims, up_weight.size(0)), device=device, dtype=torch.float) padded_up_weight[sum(self.split_dims[:i]) : sum(self.split_dims[: i + 1])] = up_weight # merge weight weight = weight + self.multiplier * (up_weight @ down_weight) * self.scale # set weight to org_module org_sd["weight"] = weight.to(dtype) self.org_module.load_state_dict(org_sd) # return weight for merge def get_weight(self, multiplier=None): if multiplier is None: multiplier = self.multiplier # get up/down weight from module up_weight = self.lora_up.weight.to(torch.float) down_weight = self.lora_down.weight.to(torch.float) # pre-calculated weight if len(down_weight.size()) == 2: # linear weight = self.multiplier * (up_weight @ down_weight) * self.scale elif down_weight.size()[2:4] == (1, 1): # conv2d 1x1 weight = ( self.multiplier * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3) * self.scale ) else: # conv2d 3x3 conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3) weight = self.multiplier * conved * self.scale return weight def default_forward(self, x): # logger.info(f"default_forward {self.lora_name} {x.size()}") if self.split_dims is None: lx = self.lora_down(x) lx = self.lora_up(lx) return self.org_forward(x) + lx * self.multiplier * self.scale else: lxs = [lora_down(x) for lora_down in self.lora_down] lxs = [lora_up(lx) for lora_up, lx in zip(self.lora_up, lxs)] return self.org_forward(x) + torch.cat(lxs, dim=-1) * self.multiplier * self.scale def forward(self, x): if not self.enabled: return self.org_forward(x) return self.default_forward(x) def create_network_hunyuan_video( multiplier: float, network_dim: Optional[int], network_alpha: Optional[float], vae: nn.Module, text_encoders: List[nn.Module], unet: nn.Module, neuron_dropout: Optional[float] = None, **kwargs, ): return create_network( HUNYUAN_TARGET_REPLACE_MODULES, "lora_unet", multiplier, network_dim, network_alpha, vae, text_encoders, unet, neuron_dropout=neuron_dropout, **kwargs, ) def create_network( target_replace_modules: List[str], prefix: str, multiplier: float, network_dim: Optional[int], network_alpha: Optional[float], vae: nn.Module, text_encoders: List[nn.Module], unet: nn.Module, neuron_dropout: Optional[float] = None, **kwargs, ): if network_dim is None: network_dim = 4 # default if network_alpha is None: network_alpha = 1.0 # extract dim/alpha for conv2d, and block dim conv_dim = kwargs.get("conv_dim", None) conv_alpha = kwargs.get("conv_alpha", None) if conv_dim is not None: conv_dim = int(conv_dim) if conv_alpha is None: conv_alpha = 1.0 else: conv_alpha = float(conv_alpha) # TODO generic rank/dim setting with regular expression # rank/module dropout rank_dropout = kwargs.get("rank_dropout", None) if rank_dropout is not None: rank_dropout = float(rank_dropout) module_dropout = kwargs.get("module_dropout", None) if module_dropout is not None: module_dropout = float(module_dropout) # verbose verbose = kwargs.get("verbose", False) if verbose is not None: verbose = True if verbose == "True" else False # too many arguments ( ^ω^)・・・ network = LoRANetwork( target_replace_modules, prefix, text_encoders, unet, multiplier=multiplier, lora_dim=network_dim, alpha=network_alpha, dropout=neuron_dropout, rank_dropout=rank_dropout, module_dropout=module_dropout, conv_lora_dim=conv_dim, conv_alpha=conv_alpha, verbose=verbose, ) loraplus_lr_ratio = kwargs.get("loraplus_lr_ratio", None) # loraplus_unet_lr_ratio = kwargs.get("loraplus_unet_lr_ratio", None) # loraplus_text_encoder_lr_ratio = kwargs.get("loraplus_text_encoder_lr_ratio", None) loraplus_lr_ratio = float(loraplus_lr_ratio) if loraplus_lr_ratio is not None else None # loraplus_unet_lr_ratio = float(loraplus_unet_lr_ratio) if loraplus_unet_lr_ratio is not None else None # loraplus_text_encoder_lr_ratio = float(loraplus_text_encoder_lr_ratio) if loraplus_text_encoder_lr_ratio is not None else None if loraplus_lr_ratio is not None: # or loraplus_unet_lr_ratio is not None or loraplus_text_encoder_lr_ratio is not None: network.set_loraplus_lr_ratio(loraplus_lr_ratio) # , loraplus_unet_lr_ratio, loraplus_text_encoder_lr_ratio) return network class LoRANetwork(torch.nn.Module): # only supports U-Net (DiT), Text Encoders are not supported def __init__( self, target_replace_modules: List[str], prefix: str, text_encoders: Union[List[CLIPTextModel], CLIPTextModel], unet: nn.Module, multiplier: float = 1.0, lora_dim: int = 4, alpha: float = 1, dropout: Optional[float] = None, rank_dropout: Optional[float] = None, module_dropout: Optional[float] = None, conv_lora_dim: Optional[int] = None, conv_alpha: Optional[float] = None, module_class: Type[object] = LoRAModule, modules_dim: Optional[Dict[str, int]] = None, modules_alpha: Optional[Dict[str, int]] = None, verbose: Optional[bool] = False, ) -> None: super().__init__() self.multiplier = multiplier self.lora_dim = lora_dim self.alpha = alpha self.conv_lora_dim = conv_lora_dim self.conv_alpha = conv_alpha self.dropout = dropout self.rank_dropout = rank_dropout self.module_dropout = module_dropout self.target_replace_modules = target_replace_modules self.prefix = prefix self.loraplus_lr_ratio = None # self.loraplus_unet_lr_ratio = None # self.loraplus_text_encoder_lr_ratio = None if modules_dim is not None: logger.info(f"create LoRA network from weights") else: logger.info(f"create LoRA network. base dim (rank): {lora_dim}, alpha: {alpha}") logger.info( f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}" ) # if self.conv_lora_dim is not None: # logger.info( # f"apply LoRA to Conv2d with kernel size (3,3). dim (rank): {self.conv_lora_dim}, alpha: {self.conv_alpha}" # ) # if train_t5xxl: # logger.info(f"train T5XXL as well") # create module instances def create_modules( is_unet: bool, pfx: str, root_module: torch.nn.Module, target_replace_mods: List[str], filter: Optional[str] = None, default_dim: Optional[int] = None, ) -> List[LoRAModule]: loras = [] skipped = [] for name, module in root_module.named_modules(): if target_replace_mods is None or module.__class__.__name__ in target_replace_mods: if target_replace_mods is None: # dirty hack for all modules module = root_module # search all modules for child_name, child_module in module.named_modules(): is_linear = child_module.__class__.__name__ == "Linear" is_conv2d = child_module.__class__.__name__ == "Conv2d" is_conv2d_1x1 = is_conv2d and child_module.kernel_size == (1, 1) if is_linear or is_conv2d: original_name = (name + "." if name else "") + child_name lora_name = f"{pfx}.{original_name}".replace(".", "_") if filter is not None and not filter in lora_name: continue dim = None alpha = None if modules_dim is not None: # モジュール指定あり if lora_name in modules_dim: dim = modules_dim[lora_name] alpha = modules_alpha[lora_name] else: # 通常、すべて対象とする if is_linear or is_conv2d_1x1: dim = default_dim if default_dim is not None else self.lora_dim alpha = self.alpha elif self.conv_lora_dim is not None: dim = self.conv_lora_dim alpha = self.conv_alpha if dim is None or dim == 0: # skipした情報を出力 if is_linear or is_conv2d_1x1 or (self.conv_lora_dim is not None): skipped.append(lora_name) continue lora = module_class( lora_name, child_module, self.multiplier, dim, alpha, dropout=dropout, rank_dropout=rank_dropout, module_dropout=module_dropout, ) loras.append(lora) if target_replace_mods is None: break # all modules are searched return loras, skipped # # create LoRA for text encoder # # it is redundant to create LoRA modules even if they are not used self.text_encoder_loras: List[Union[LoRAModule, LoRAInfModule]] = [] # skipped_te = [] # for i, text_encoder in enumerate(text_encoders): # index = i # if not train_t5xxl and index > 0: # 0: CLIP, 1: T5XXL, so we skip T5XXL if train_t5xxl is False # break # logger.info(f"create LoRA for Text Encoder {index+1}:") # text_encoder_loras, skipped = create_modules(False, index, text_encoder, LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE) # logger.info(f"create LoRA for Text Encoder {index+1}: {len(text_encoder_loras)} modules.") # self.text_encoder_loras.extend(text_encoder_loras) # skipped_te += skipped # create LoRA for U-Net self.unet_loras: List[Union[LoRAModule, LoRAInfModule]] self.unet_loras, skipped_un = create_modules(True, prefix, unet, target_replace_modules) logger.info(f"create LoRA for U-Net/DiT: {len(self.unet_loras)} modules.") if verbose: for lora in self.unet_loras: logger.info(f"\t{lora.lora_name:50} {lora.lora_dim}, {lora.alpha}") skipped = skipped_un if verbose and len(skipped) > 0: logger.warning( f"because dim (rank) is 0, {len(skipped)} LoRA modules are skipped / dim (rank)が0の為、次の{len(skipped)}個のLoRAモジュールはスキップされます:" ) for name in skipped: logger.info(f"\t{name}") # assertion names = set() for lora in self.text_encoder_loras + self.unet_loras: assert lora.lora_name not in names, f"duplicated lora name: {lora.lora_name}" names.add(lora.lora_name) def prepare_network(self, args): """ called after the network is created """ pass def set_multiplier(self, multiplier): self.multiplier = multiplier for lora in self.text_encoder_loras + self.unet_loras: lora.multiplier = self.multiplier def set_enabled(self, is_enabled): for lora in self.text_encoder_loras + self.unet_loras: lora.enabled = is_enabled def load_weights(self, file): if os.path.splitext(file)[1] == ".safetensors": from safetensors.torch import load_file weights_sd = load_file(file) else: weights_sd = torch.load(file, map_location="cpu") info = self.load_state_dict(weights_sd, False) return info def apply_to( self, text_encoders: Optional[nn.Module], unet: Optional[nn.Module], apply_text_encoder: bool = True, apply_unet: bool = True, ): if apply_text_encoder: logger.info(f"enable LoRA for text encoder: {len(self.text_encoder_loras)} modules") else: self.text_encoder_loras = [] if apply_unet: logger.info(f"enable LoRA for U-Net: {len(self.unet_loras)} modules") else: self.unet_loras = [] for lora in self.text_encoder_loras + self.unet_loras: lora.apply_to() self.add_module(lora.lora_name, lora) # マージできるかどうかを返す def is_mergeable(self): return True # TODO refactor to common function with apply_to def merge_to(self, text_encoders, unet, weights_sd, dtype=None, device=None): for lora in self.text_encoder_loras + self.unet_loras: sd_for_lora = {} for key in weights_sd.keys(): if key.startswith(lora.lora_name): sd_for_lora[key[len(lora.lora_name) + 1 :]] = weights_sd[key] if len(sd_for_lora) == 0: logger.info(f"no weight for {lora.lora_name}") continue lora.merge_to(sd_for_lora, dtype, device) logger.info(f"weights are merged") def set_loraplus_lr_ratio(self, loraplus_lr_ratio): # , loraplus_unet_lr_ratio, loraplus_text_encoder_lr_ratio): self.loraplus_lr_ratio = loraplus_lr_ratio logger.info(f"LoRA+ UNet LR Ratio: {self.loraplus_lr_ratio}") # logger.info(f"LoRA+ Text Encoder LR Ratio: {self.loraplus_text_encoder_lr_ratio or self.loraplus_lr_ratio}") def prepare_optimizer_params(self, unet_lr: float = 1e-4, **kwargs): self.requires_grad_(True) all_params = [] lr_descriptions = [] def assemble_params(loras, lr, loraplus_ratio): param_groups = {"lora": {}, "plus": {}} for lora in loras: for name, param in lora.named_parameters(): if loraplus_ratio is not None and "lora_up" in name: param_groups["plus"][f"{lora.lora_name}.{name}"] = param else: param_groups["lora"][f"{lora.lora_name}.{name}"] = param params = [] descriptions = [] for key in param_groups.keys(): param_data = {"params": param_groups[key].values()} if len(param_data["params"]) == 0: continue if lr is not None: if key == "plus": param_data["lr"] = lr * loraplus_ratio else: param_data["lr"] = lr if param_data.get("lr", None) == 0 or param_data.get("lr", None) is None: logger.info("NO LR skipping!") continue params.append(param_data) descriptions.append("plus" if key == "plus" else "") return params, descriptions if self.unet_loras: params, descriptions = assemble_params(self.unet_loras, unet_lr, self.loraplus_lr_ratio) all_params.extend(params) lr_descriptions.extend(["unet" + (" " + d if d else "") for d in descriptions]) return all_params, lr_descriptions def enable_gradient_checkpointing(self): # not supported pass def prepare_grad_etc(self, unet): self.requires_grad_(True) def on_epoch_start(self, unet): self.train() def on_step_start(self): pass def get_trainable_params(self): return self.parameters() def save_weights(self, file, dtype, metadata): if metadata is not None and len(metadata) == 0: metadata = None state_dict = self.state_dict() if dtype is not None: for key in list(state_dict.keys()): v = state_dict[key] v = v.detach().clone().to("cpu").to(dtype) state_dict[key] = v if os.path.splitext(file)[1] == ".safetensors": from safetensors.torch import save_file from utils import model_utils # Precalculate model hashes to save time on indexing if metadata is None: metadata = {} model_hash, legacy_hash = model_utils.precalculate_safetensors_hashes(state_dict, metadata) metadata["sshs_model_hash"] = model_hash metadata["sshs_legacy_hash"] = legacy_hash save_file(state_dict, file, metadata) else: torch.save(state_dict, file) def backup_weights(self): # 重みのバックアップを行う loras: List[LoRAInfModule] = self.text_encoder_loras + self.unet_loras for lora in loras: org_module = lora.org_module_ref[0] if not hasattr(org_module, "_lora_org_weight"): sd = org_module.state_dict() org_module._lora_org_weight = sd["weight"].detach().clone() org_module._lora_restored = True def restore_weights(self): # 重みのリストアを行う loras: List[LoRAInfModule] = self.text_encoder_loras + self.unet_loras for lora in loras: org_module = lora.org_module_ref[0] if not org_module._lora_restored: sd = org_module.state_dict() sd["weight"] = org_module._lora_org_weight org_module.load_state_dict(sd) org_module._lora_restored = True def pre_calculation(self): # 事前計算を行う loras: List[LoRAInfModule] = self.text_encoder_loras + self.unet_loras for lora in loras: org_module = lora.org_module_ref[0] sd = org_module.state_dict() org_weight = sd["weight"] lora_weight = lora.get_weight().to(org_weight.device, dtype=org_weight.dtype) sd["weight"] = org_weight + lora_weight assert sd["weight"].shape == org_weight.shape org_module.load_state_dict(sd) org_module._lora_restored = False lora.enabled = False def apply_max_norm_regularization(self, max_norm_value, device): downkeys = [] upkeys = [] alphakeys = [] norms = [] keys_scaled = 0 state_dict = self.state_dict() for key in state_dict.keys(): if "lora_down" in key and "weight" in key: downkeys.append(key) upkeys.append(key.replace("lora_down", "lora_up")) alphakeys.append(key.replace("lora_down.weight", "alpha")) for i in range(len(downkeys)): down = state_dict[downkeys[i]].to(device) up = state_dict[upkeys[i]].to(device) alpha = state_dict[alphakeys[i]].to(device) dim = down.shape[0] scale = alpha / dim if up.shape[2:] == (1, 1) and down.shape[2:] == (1, 1): updown = (up.squeeze(2).squeeze(2) @ down.squeeze(2).squeeze(2)).unsqueeze(2).unsqueeze(3) elif up.shape[2:] == (3, 3) or down.shape[2:] == (3, 3): updown = torch.nn.functional.conv2d(down.permute(1, 0, 2, 3), up).permute(1, 0, 2, 3) else: updown = up @ down updown *= scale norm = updown.norm().clamp(min=max_norm_value / 2) desired = torch.clamp(norm, max=max_norm_value) ratio = desired.cpu() / norm.cpu() sqrt_ratio = ratio**0.5 if ratio != 1: keys_scaled += 1 state_dict[upkeys[i]] *= sqrt_ratio state_dict[downkeys[i]] *= sqrt_ratio scalednorm = updown.norm() * ratio norms.append(scalednorm.item()) return keys_scaled, sum(norms) / len(norms), max(norms) def create_network_from_weights_hunyuan_video( multiplier: float, weights_sd: Dict[str, torch.Tensor], text_encoders: Optional[List[nn.Module]] = None, unet: Optional[nn.Module] = None, for_inference: bool = False, **kwargs, ) -> LoRANetwork: return create_network_from_weights( HUNYUAN_TARGET_REPLACE_MODULES, multiplier, weights_sd, text_encoders, unet, for_inference, **kwargs ) # Create network from weights for inference, weights are not loaded here (because can be merged) def create_network_from_weights( target_replace_modules: List[str], multiplier: float, weights_sd: Dict[str, torch.Tensor], text_encoders: Optional[List[nn.Module]] = None, unet: Optional[nn.Module] = None, for_inference: bool = False, **kwargs, ) -> LoRANetwork: # get dim/alpha mapping modules_dim = {} modules_alpha = {} for key, value in weights_sd.items(): if "." not in key: continue lora_name = key.split(".")[0] if "alpha" in key: modules_alpha[lora_name] = value elif "lora_down" in key: dim = value.shape[0] modules_dim[lora_name] = dim # logger.info(lora_name, value.size(), dim) module_class = LoRAInfModule if for_inference else LoRAModule network = LoRANetwork( target_replace_modules, "lora_unet", text_encoders, unet, multiplier=multiplier, modules_dim=modules_dim, modules_alpha=modules_alpha, module_class=module_class, ) return network