File size: 4,407 Bytes
9f200a2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
import torch
from types import SimpleNamespace

from .lora import (
    extract_lora_ups_down,
    inject_trainable_lora_extended,
    monkeypatch_or_replace_lora_extended,
)

# Argument names forwarded to the cloneofsimo loader/injector functions.
CLONE_OF_SIMO_KEYS = ["model", "loras", "target_replace_module", "r"]

# Supported LoRA implementation identifiers (each name maps to itself).
lora_versions = {
    "stable_lora": "stable_lora",
    "cloneofsimo": "cloneofsimo",
}

# Kinds of LoRA helper function a handler can look up (each name maps to itself).
lora_func_types = {
    "loader": "loader",
    "injector": "injector",
}

# Every known LoRA argument name together with its default value.
# NOTE: the list values are shared objects; copy before mutating them.
lora_args = {
    "model": None,
    "loras": None,
    "target_replace_module": [],
    "target_module": [],
    "r": 4,
    "search_class": [torch.nn.Linear],
    "dropout": 0,
    "lora_bias": "none",
}

# Attribute-style access, e.g. ``LoraVersions.cloneofsimo``.
LoraVersions = SimpleNamespace(**lora_versions)
LoraFuncTypes = SimpleNamespace(**lora_func_types)

# Ordered lists of the identifiers above (dict insertion order is preserved).
LORA_VERSIONS = list(lora_versions.values())
LORA_FUNC_TYPES = list(lora_func_types.values())


def filter_dict(_dict, keys=None):
    """Return the entries of ``_dict`` whose key is listed in ``keys``.

    Every requested key must be a known LoRA argument name, i.e. a key of the
    module-level ``lora_args`` template. Selected entries keep the order they
    have in ``_dict``.

    Args:
        _dict: Mapping to filter.
        keys: Non-empty iterable of argument names to keep.

    Returns:
        A new ``dict`` containing only the selected entries.

    Raises:
        ValueError: If ``keys`` is omitted/empty, or contains a name that is
            not a key of ``lora_args``.
    """
    keys = [] if keys is None else list(keys)

    # Explicit raises rather than ``assert "msg"``: asserting a non-empty
    # string is always true, and asserts are stripped under ``python -O``.
    if not keys:
        raise ValueError("Keys cannot be empty for filtering return dict.")

    for k in keys:
        if k not in lora_args:
            raise ValueError(f"{k} does not exist in available LoRA arguments")

    wanted = set(keys)
    return {k: v for k, v in _dict.items() if k in wanted}


class LoraHandler(object):
    """Configure and apply cloneofsimo-style LoRA injection to models.

    Only ``LoraVersions.cloneofsimo`` is accepted. The loader
    (``monkeypatch_or_replace_lora_extended``) and injector
    (``inject_trainable_lora_extended``) callables are resolved once in
    ``__init__`` and stored on the instance.
    """

    def __init__(
        self,
        version: str = LoraVersions.cloneofsimo,
        use_unet_lora: bool = False,
        use_text_lora: bool = False,
        save_for_webui: bool = False,
        only_for_webui: bool = False,
        lora_bias: str = "none",
        unet_replace_modules: "list | None" = None,
    ):
        """Store the LoRA configuration and resolve the helper functions.

        Args:
            version: LoRA implementation identifier; must be
                ``LoraVersions.cloneofsimo``.
            use_unet_lora: Whether LoRA is enabled for the UNet.
            use_text_lora: Whether LoRA is enabled for the text encoder.
            save_for_webui: Stored as-is for callers.
            only_for_webui: Stored as-is for callers.
            lora_bias: Bias mode stored on the instance and forwarded by
                ``add_lora_to_model``.
            unet_replace_modules: Module class names to target in the UNet;
                defaults to ``["UNet3DConditionModel"]``.

        Raises:
            ValueError: If ``version`` is not the cloneofsimo version.
        """
        self.version = version
        # Raise instead of ``assert`` so the check survives ``python -O``.
        if not self.is_cloneofsimo_lora():
            raise ValueError(
                f"Unsupported LoRA version {version!r}; only "
                f"{LoraVersions.cloneofsimo!r} is supported."
            )

        self.lora_loader = self.get_lora_func(func_type=LoraFuncTypes.loader)
        self.lora_injector = self.get_lora_func(func_type=LoraFuncTypes.injector)
        self.lora_bias = lora_bias
        self.use_unet_lora = use_unet_lora
        self.use_text_lora = use_text_lora
        self.save_for_webui = save_for_webui
        self.only_for_webui = only_for_webui
        # A fresh list per instance: a list literal as the default would be
        # shared (and mutable) across every handler.
        self.unet_replace_modules = (
            ["UNet3DConditionModel"]
            if unet_replace_modules is None
            else unet_replace_modules
        )
        self.use_lora = any([use_text_lora, use_unet_lora])

        if self.use_lora:
            print(f"Using LoRA Version: {self.version}")

    def is_cloneofsimo_lora(self):
        """Return True when this handler's version is the cloneofsimo one."""
        return self.version == LoraVersions.cloneofsimo

    def get_lora_func(self, func_type: str = LoraFuncTypes.loader):
        """Return the cloneofsimo helper function for ``func_type``.

        Args:
            func_type: ``LoraFuncTypes.loader`` or ``LoraFuncTypes.injector``.

        Returns:
            ``monkeypatch_or_replace_lora_extended`` for the loader, or
            ``inject_trainable_lora_extended`` for the injector.

        Raises:
            ValueError: If ``func_type`` is neither of the known types
                (previously this silently returned ``None``).
        """
        if func_type == LoraFuncTypes.loader:
            return monkeypatch_or_replace_lora_extended

        if func_type == LoraFuncTypes.injector:
            return inject_trainable_lora_extended

        raise ValueError(f"LoRA function type {func_type!r} does not exist.")

    def get_lora_func_args(
        self, lora_path, use_lora, model, replace_modules, r, dropout, lora_bias
    ):
        """Build the keyword arguments for the cloneofsimo loader/injector.

        Only the keys in ``CLONE_OF_SIMO_KEYS`` are produced. ``use_lora``,
        ``dropout`` and ``lora_bias`` are accepted for interface compatibility
        but are not included in the result.

        Returns:
            A new dict with ``model``, ``loras`` (= ``lora_path``),
            ``target_replace_module`` (= ``replace_modules``) and ``r``.
        """
        # filter_dict returns a new dict, so the shared template is not mutated.
        return_dict = filter_dict(lora_args, keys=CLONE_OF_SIMO_KEYS)
        return_dict.update(
            {
                "model": model,
                "loras": lora_path,
                "target_replace_module": replace_modules,
                "r": r,
            }
        )

        return return_dict

    def do_lora_injection(
        self,
        model,
        replace_modules,
        bias="none",
        dropout=0,
        r=4,
        lora_loader_args=None,
    ):
        """Run ``self.lora_injector`` on ``model`` and report the outcome.

        Args:
            model: Model handed to the injector and to ``extract_lora_ups_down``.
            replace_modules: Module names passed as ``target_replace_module``
                when inspecting the injected layers.
            bias: Forwarded to ``get_lora_func_args`` when
                ``lora_loader_args`` is None.
            dropout: Forwarded to ``get_lora_func_args`` when
                ``lora_loader_args`` is None.
            r: Forwarded to ``get_lora_func_args`` when
                ``lora_loader_args`` is None.
            lora_loader_args: Keyword arguments for the injector, as built by
                ``get_lora_func_args``. When None they are built from
                ``model``/``replace_modules``/``r`` with no LoRA path (this
                previously failed with ``TypeError`` on ``**None``).

        Returns:
            The ``(params, negation)`` pair returned by the injector.
        """
        if lora_loader_args is None:
            lora_loader_args = self.get_lora_func_args(
                None, True, model, replace_modules, r, dropout, bias
            )

        params, negation = self.lora_injector(**lora_loader_args)

        # Inspect only the first (up, down) pair, so success is reported once.
        for _up, _down in extract_lora_ups_down(
            model, target_replace_module=replace_modules
        ):
            if _up is not None and _down is not None:
                print(
                    f"Lora successfully injected into {model.__class__.__name__}."
                )
            break

        return params, negation

    def add_lora_to_model(
        self, use_lora, model, replace_modules, dropout=0.0, lora_path=None, r=16
    ):
        """Optionally inject LoRA into ``model``.

        Args:
            use_lora: When falsy, no injection is performed.
            model: Model to inject into.
            replace_modules: Module names to target.
            dropout: Forwarded to ``do_lora_injection``.
            lora_path: Passed to the injector as ``loras``.
            r: Forwarded as ``r``.

        Returns:
            ``(params, negation)``. ``params`` falls back to ``model`` itself
            when no injection ran or the injector returned ``None`` for it;
            ``negation`` is ``None`` when no injection ran.
        """
        params = None
        negation = None

        if use_lora:
            # Only build the injector arguments when they will be used.
            lora_loader_args = self.get_lora_func_args(
                lora_path, use_lora, model, replace_modules, r, dropout, self.lora_bias
            )
            params, negation = self.do_lora_injection(
                model,
                replace_modules,
                bias=self.lora_bias,
                lora_loader_args=lora_loader_args,
                dropout=dropout,
                r=r,
            )

        params = model if params is None else params
        return params, negation