File size: 1,361 Bytes
34b61ae
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
import torch
from peft import LoraConfig

def add_lora_to_unet(unet, rank=4):
    """Attach three LoRA adapters to ``unet``, grouped by layer location.

    Parameter names from ``unet.named_parameters()`` are scanned; names
    containing ``"bias"`` or ``"norm"`` are ignored.  A remaining name is
    selected when it contains any of the substrings in ``l_grep`` and is then
    sorted into one of three groups based on the name itself:

    * encoder -- name contains ``"down_blocks"`` or ``"conv_in"``
    * decoder -- name contains ``"up_blocks"`` or ``"conv_out"``
    * others  -- every other selected name

    The encoder test takes precedence over the decoder test.  The module name
    recorded for a parameter is its name with ``".weight"`` removed.

    Args:
        unet: Object exposing ``named_parameters()`` and
            ``add_adapter(config, adapter_name=...)`` (presumably a diffusers
            UNet with PEFT support -- confirm against the caller).
        rank: LoRA rank ``r`` used for all three adapters.

    Returns:
        The same ``unet`` object, after ``add_adapter`` has been called for
        ``"default_encoder"``, ``"default_decoder"`` and ``"default_others"``
        (in that order).
    """
    l_grep = (
        "to_k", "to_q", "to_v", "to_out.0",
        "conv", "conv1", "conv2", "conv_shortcut", "conv_out",
        "proj_out", "proj_in",
        "ff.net.2", "ff.net.0.proj",
    )
    encoder_modules, decoder_modules, other_modules = [], [], []

    for param_name, _ in unet.named_parameters():
        if "bias" in param_name or "norm" in param_name:
            continue
        # The group depends only on the parameter name, not on which pattern
        # matched, so one "any pattern matches" test is sufficient.
        if not any(pattern in param_name for pattern in l_grep):
            continue

        module_name = param_name.replace(".weight", "")
        if "down_blocks" in param_name or "conv_in" in param_name:
            encoder_modules.append(module_name)
        elif "up_blocks" in param_name or "conv_out" in param_name:
            decoder_modules.append(module_name)
        else:
            other_modules.append(module_name)

    # Registration order is kept: encoder, decoder, then everything else.
    adapter_groups = (
        ("default_encoder", encoder_modules),
        ("default_decoder", decoder_modules),
        ("default_others", other_modules),
    )
    for adapter_name, target_modules in adapter_groups:
        lora_config = LoraConfig(
            r=rank,
            init_lora_weights="gaussian",
            target_modules=target_modules,
        )
        unet.add_adapter(lora_config, adapter_name=adapter_name)
    return unet