from types import SimpleNamespace
from typing import Optional

import torch

from .lora import (
    extract_lora_ups_down,
    inject_trainable_lora_extended,
    monkeypatch_or_replace_lora_extended,
)

# Argument names forwarded to the cloneofsimo-style loader/injector functions.
CLONE_OF_SIMO_KEYS = ["model", "loras", "target_replace_module", "r"]

lora_versions = dict(stable_lora="stable_lora", cloneofsimo="cloneofsimo")

lora_func_types = dict(loader="loader", injector="injector")

# Default value for every LoRA argument this module knows about.
# `filter_dict` validates requested keys against this mapping.
lora_args = dict(
    model=None,
    loras=None,
    target_replace_module=[],
    target_module=[],
    r=4,
    search_class=[torch.nn.Linear],
    dropout=0,
    lora_bias="none",
)

LoraVersions = SimpleNamespace(**lora_versions)
LoraFuncTypes = SimpleNamespace(**lora_func_types)

LORA_VERSIONS = [LoraVersions.stable_lora, LoraVersions.cloneofsimo]
LORA_FUNC_TYPES = [LoraFuncTypes.loader, LoraFuncTypes.injector]


def filter_dict(_dict, keys=None):
    """Return a new dict holding only the entries of ``_dict`` named in ``keys``.

    Args:
        _dict: Mapping to filter.
        keys: Iterable of key names to keep. Each name must be a key of the
            module-level ``lora_args`` mapping.

    Returns:
        A new dict with the entries of ``_dict`` whose key is in ``keys``.

    Raises:
        ValueError: If ``keys`` is empty or ``None``, or names a key that is
            not present in ``lora_args``.
    """
    # Materialize once so generators are not consumed by the checks below.
    keys = list(keys) if keys is not None else []
    # Explicit raises (not ``assert``) so validation survives ``python -O``.
    if not keys:
        raise ValueError("Keys cannot be empty when filtering a dict.")
    for k in keys:
        if k not in lora_args:
            raise ValueError(f"{k} does not exist in available LoRA arguments")
    wanted = set(keys)
    return {k: v for k, v in _dict.items() if k in wanted}


class LoraHandler(object):
    """Inject cloneofsimo-style LoRA layers into models.

    Only the ``cloneofsimo`` LoRA version is supported; the constructor
    asserts this.
    """

    def __init__(
        self,
        version: str = LoraVersions.cloneofsimo,
        use_unet_lora: bool = False,
        use_text_lora: bool = False,
        save_for_webui: bool = False,
        only_for_webui: bool = False,
        lora_bias: str = "none",
        unet_replace_modules: Optional[list] = None,
    ):
        """Configure the handler.

        Args:
            version: LoRA implementation version; must be
                ``LoraVersions.cloneofsimo``.
            use_unet_lora: Whether LoRA is enabled for the UNet.
            use_text_lora: Whether LoRA is enabled for the text encoder.
            save_for_webui: Stored flag; not used by the code in this class.
            only_for_webui: Stored flag; not used by the code in this class.
            lora_bias: Bias mode stored on the handler and passed along to
                ``do_lora_injection``.
            unet_replace_modules: Module class names targeted for UNet LoRA
                injection. ``None`` means ``["UNet3DConditionModel"]``.
        """
        self.version = version
        assert self.is_cloneofsimo_lora(), f"Unsupported LoRA version: {version!r}"

        # Resolved loader / injector callables from the `.lora` module.
        self.lora_loader = self.get_lora_func(func_type=LoraFuncTypes.loader)
        self.lora_injector = self.get_lora_func(func_type=LoraFuncTypes.injector)
        self.lora_bias = lora_bias
        self.use_unet_lora = use_unet_lora
        self.use_text_lora = use_text_lora
        self.save_for_webui = save_for_webui
        self.only_for_webui = only_for_webui
        # Build a fresh list per instance so instances never share one default.
        if unet_replace_modules is None:
            unet_replace_modules = ["UNet3DConditionModel"]
        self.unet_replace_modules = unet_replace_modules
        # True when LoRA is enabled for at least one of the two models.
        self.use_lora = any([use_text_lora, use_unet_lora])

        if self.use_lora:
            print(f"Using LoRA Version: {self.version}")

    def is_cloneofsimo_lora(self):
        """Return True when this handler's version is ``cloneofsimo``."""
        return self.version == LoraVersions.cloneofsimo

    def get_lora_func(self, func_type: str = LoraFuncTypes.loader):
        """Return the LoRA function matching ``func_type``.

        Args:
            func_type: ``LoraFuncTypes.loader`` or ``LoraFuncTypes.injector``.

        Returns:
            ``monkeypatch_or_replace_lora_extended`` for the loader, or
            ``inject_trainable_lora_extended`` for the injector.

        Raises:
            ValueError: If ``func_type`` is neither supported type.
        """
        if func_type == LoraFuncTypes.loader:
            return monkeypatch_or_replace_lora_extended

        if func_type == LoraFuncTypes.injector:
            return inject_trainable_lora_extended

        raise ValueError(f"LoRA function type does not exist: {func_type!r}")

    def get_lora_func_args(
        self, lora_path, use_lora, model, replace_modules, r, dropout, lora_bias
    ):
        """Build the keyword arguments for the cloneofsimo loader/injector.

        Only the keys in ``CLONE_OF_SIMO_KEYS`` are produced. ``use_lora``,
        ``dropout`` and ``lora_bias`` are accepted for interface compatibility
        but are not part of that key set, so they are not forwarded.

        Returns:
            Dict with ``model``, ``loras`` (= ``lora_path``),
            ``target_replace_module`` (= ``replace_modules``) and ``r``.
        """
        return_dict = filter_dict(lora_args, keys=CLONE_OF_SIMO_KEYS)
        return_dict.update(
            {
                "model": model,
                "loras": lora_path,
                "target_replace_module": replace_modules,
                "r": r,
            }
        )
        return return_dict

    def do_lora_injection(
        self,
        model,
        replace_modules,
        bias="none",
        dropout=0,
        r=4,
        lora_loader_args=None,
    ):
        """Inject trainable LoRA layers into ``model`` via ``self.lora_injector``.

        Args:
            model: Model to inject LoRA layers into.
            replace_modules: Module class names targeted for injection.
            bias: Bias mode; used only when ``lora_loader_args`` is None.
            dropout: Dropout value; used only when ``lora_loader_args`` is None.
            r: LoRA rank; used only when ``lora_loader_args`` is None.
            lora_loader_args: Keyword arguments for the injector. When None,
                they are built with ``get_lora_func_args`` (no LoRA path).

        Returns:
            The ``(params, negation)`` pair returned by the injector.
        """
        if lora_loader_args is None:
            # Previously ``**None`` raised TypeError; build defaults instead.
            lora_loader_args = self.get_lora_func_args(
                None, True, model, replace_modules, r, dropout, bias
            )

        params, negation = self.lora_injector(**lora_loader_args)

        # Report success once, on the first injected up/down pair found.
        for _up, _down in extract_lora_ups_down(
            model, target_replace_module=replace_modules
        ):
            if _up is not None and _down is not None:
                print(f"Lora successfully injected into {model.__class__.__name__}.")
                break

        return params, negation

    def add_lora_to_model(
        self, use_lora, model, replace_modules, dropout=0.0, lora_path=None, r=16
    ):
        """Optionally inject LoRA into ``model``.

        Args:
            use_lora: When falsy, no injection happens.
            model: Model to inject LoRA layers into.
            replace_modules: Module class names targeted for injection.
            dropout: Dropout value forwarded to ``do_lora_injection``.
            lora_path: Value passed as the ``loras`` argument to the injector.
            r: LoRA rank.

        Returns:
            ``(params, negation)``. ``params`` falls back to ``model`` when
            injection is skipped or the injector returns None for it.
            ``negation`` is None when injection is skipped.
        """
        params = None
        negation = None

        if use_lora:
            lora_loader_args = self.get_lora_func_args(
                lora_path, use_lora, model, replace_modules, r, dropout, self.lora_bias
            )
            params, negation = self.do_lora_injection(
                model,
                replace_modules,
                bias=self.lora_bias,
                lora_loader_args=lora_loader_args,
                dropout=dropout,
                r=r,
            )

        params = model if params is None else params
        return params, negation