from peft import LoraConfig

def add_lora_to_unet(unet, rank=4):
    """Attach three rank-`rank` LoRA adapters to a diffusers UNet: one for the
    encoder (down blocks), one for the decoder (up blocks), one for the rest."""
    l_target_modules_encoder, l_target_modules_decoder, l_modules_others = [], [], []
    # Substrings of layer names eligible for LoRA: attention projections,
    # convolutions, and feed-forward projections.
    l_grep = ["to_k", "to_q", "to_v", "to_out.0", "conv", "conv1", "conv2",
              "conv_shortcut", "conv_out", "proj_out", "proj_in",
              "ff.net.2", "ff.net.0.proj"]
    for n, _ in unet.named_parameters():
        # LoRA wraps weight matrices only; skip biases and normalization layers.
        if "bias" in n or "norm" in n:
            continue
        for pattern in l_grep:
            # Bucket each matching module by its position in the UNet: down
            # blocks and conv_in form the encoder, up blocks and conv_out the
            # decoder, and everything else (e.g. the mid block) goes to "others".
            if pattern in n and ("down_blocks" in n or "conv_in" in n):
                l_target_modules_encoder.append(n.replace(".weight", ""))
                break
            elif pattern in n and ("up_blocks" in n or "conv_out" in n):
                l_target_modules_decoder.append(n.replace(".weight", ""))
                break
            elif pattern in n:
                l_modules_others.append(n.replace(".weight", ""))
                break
    # Register one adapter per group so each can be scaled or toggled independently.
    unet.add_adapter(LoraConfig(r=rank, init_lora_weights="gaussian",
                                target_modules=l_target_modules_encoder),
                     adapter_name="default_encoder")
    unet.add_adapter(LoraConfig(r=rank, init_lora_weights="gaussian",
                                target_modules=l_target_modules_decoder),
                     adapter_name="default_decoder")
    unet.add_adapter(LoraConfig(r=rank, init_lora_weights="gaussian",
                                target_modules=l_modules_others),
                     adapter_name="default_others")
    return unet
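

if __name__ == "__main__":
    # Minimal usage sketch. The diffusers import, checkpoint id, and rank
    # below are illustrative assumptions, not part of the function above.
    from diffusers import UNet2DConditionModel

    unet = UNet2DConditionModel.from_pretrained("stabilityai/sd-turbo", subfolder="unet")
    unet = add_lora_to_unet(unet, rank=8)
    # PEFT injects the adapters in place; count the injected LoRA parameters.
    n_lora = sum(p.numel() for n, p in unet.named_parameters() if "lora" in n)
    print(f"LoRA parameters added: {n_lora:,}")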