import torch
from peft import LoraConfig


def add_lora_to_unet(unet, rank=4):
    """Attach three LoRA adapters to ``unet``, grouped by where each layer lives.

    Parameter names from ``unet.named_parameters()`` are scanned. Names
    containing ``"bias"`` or ``"norm"`` are ignored. Any remaining name that
    contains one of the target substrings (attention projections, convs,
    feed-forward projections) is recorded with ``".weight"`` stripped, and
    sorted into one of three groups:

    * encoder: the name contains ``"down_blocks"`` or ``"conv_in"``
    * decoder: the name contains ``"up_blocks"`` or ``"conv_out"``
    * others:  everything else that matched a target substring

    Each group is then registered through ``unet.add_adapter`` under the
    adapter names ``"default_encoder"``, ``"default_decoder"`` and
    ``"default_others"`` respectively.

    Args:
        unet: Model exposing ``named_parameters()`` and ``add_adapter()``
            (presumably a diffusers UNet -- confirm against the caller).
        rank: LoRA rank, passed as ``r`` to every ``LoraConfig``.

    Returns:
        The same ``unet`` object, after the three adapters have been added.
    """
    # Substrings identifying the layers that should receive LoRA weights.
    lora_patterns = (
        "to_k",
        "to_q",
        "to_v",
        "to_out.0",
        "conv",
        "conv1",
        "conv2",
        "conv_shortcut",
        "conv_out",
        "proj_out",
        "proj_in",
        "ff.net.2",
        "ff.net.0.proj",
    )
    encoder_targets = []
    decoder_targets = []
    other_targets = []

    for param_name, _ in unet.named_parameters():
        if "bias" in param_name or "norm" in param_name:
            continue
        if not any(pattern in param_name for pattern in lora_patterns):
            continue

        # The group depends only on the location markers in the name, not on
        # which target substring matched; the encoder check takes precedence.
        if "down_blocks" in param_name or "conv_in" in param_name:
            bucket = encoder_targets
        elif "up_blocks" in param_name or "conv_out" in param_name:
            bucket = decoder_targets
        else:
            bucket = other_targets
        bucket.append(param_name.replace(".weight", ""))

    # Register the adapters in a fixed order: encoder, decoder, others.
    adapter_groups = (
        ("default_encoder", encoder_targets),
        ("default_decoder", decoder_targets),
        ("default_others", other_targets),
    )
    for adapter_name, target_modules in adapter_groups:
        lora_config = LoraConfig(
            r=rank,
            init_lora_weights="gaussian",
            target_modules=target_modules,
        )
        unet.add_adapter(lora_config, adapter_name=adapter_name)

    return unet