Spaces: Running on A100
Commit • 722e096
Parent(s): 57558ed
Update lora.py
lora.py CHANGED
@@ -5,16 +5,12 @@
 
 import math
 import os
-from typing import
-from diffusers import AutoencoderKL
-from transformers import CLIPTextModel
+from typing import List, Tuple, Union
 import numpy as np
 import torch
 import re
 
 
-RE_UPDOWN = re.compile(r"(up|down)_blocks_(\d+)_(resnets|upsamplers|downsamplers|attentions)_(\d+)_")
-
 RE_UPDOWN = re.compile(r"(up|down)_blocks_(\d+)_(resnets|upsamplers|downsamplers|attentions)_(\d+)_")
 
 
@@ -219,13 +215,7 @@ class LoRAInfModule(LoRAModule):
 
     def default_forward(self, x):
         # print("default_forward", self.lora_name, x.size())
-        org_forward
-        lora_down = self.lora_down(x)
-        lora_up_down = self.lora_up(lora_down)
-        print(org_forward)
-        print(lora_up_down)
-        print(self.multiplier)
-        return org_forward + lora_up_down * self.multiplier #* self.scale
+        return self.org_forward(x) + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale
 
     def forward(self, x):
         if not self.enabled:
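Note on the hunk above: the rewritten default_forward is the standard LoRA inference path, i.e. the frozen layer's output plus the low-rank update lora_up(lora_down(x)), scaled by both multiplier and scale (the old code had the scale factor commented out and also printed debug values). A minimal stand-alone sketch of the same arithmetic, using a toy module rather than the repo's LoRAModule/LoRAInfModule classes (all names here are illustrative only):

import torch
import torch.nn as nn

class ToyLoRALinear(nn.Module):
    # Toy wrapper: a frozen Linear plus a rank-r update, mirroring
    # org_forward(x) + lora_up(lora_down(x)) * multiplier * scale from the diff.
    def __init__(self, in_features, out_features, rank=4, alpha=1.0, multiplier=1.0):
        super().__init__()
        self.org_module = nn.Linear(in_features, out_features)
        self.lora_down = nn.Linear(in_features, rank, bias=False)
        self.lora_up = nn.Linear(rank, out_features, bias=False)
        nn.init.zeros_(self.lora_up.weight)  # zero update at start, so the output equals the original layer's
        self.scale = alpha / rank
        self.multiplier = multiplier

    def forward(self, x):
        return self.org_module(x) + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale

print(ToyLoRALinear(16, 8)(torch.randn(2, 16)).shape)  # torch.Size([2, 8])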
@@ -410,16 +400,7 @@ def parse_block_lr_kwargs(nw_kwargs):
     return down_lr_weight, mid_lr_weight, up_lr_weight
 
 
-def create_network(
-    multiplier: float,
-    network_dim: Optional[int],
-    network_alpha: Optional[float],
-    vae: AutoencoderKL,
-    text_encoder: Union[CLIPTextModel, List[CLIPTextModel]],
-    unet,
-    neuron_dropout: Optional[float] = None,
-    **kwargs,
-):
+def create_network(multiplier, network_dim, network_alpha, vae, text_encoder, unet, neuron_dropout=None, **kwargs):
     if network_dim is None:
         network_dim = 4  # default
     if network_alpha is None:
@@ -738,36 +719,33 @@ def create_network_from_weights(multiplier, file, vae, text_encoder, unet, weigh
 class LoRANetwork(torch.nn.Module):
     NUM_OF_BLOCKS = 12  # number of up/down block layers in the full-model equivalent
 
-
+    # is it possible to apply conv_in and conv_out? -> yes, newer LoCon supports it (^^;)
+    UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel", "Attention"]
     UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 = ["ResnetBlock2D", "Downsample2D", "Upsample2D"]
     TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"]
     LORA_PREFIX_UNET = "lora_unet"
     LORA_PREFIX_TEXT_ENCODER = "lora_te"
 
-    # SDXL: must start with LORA_PREFIX_TEXT_ENCODER
-    LORA_PREFIX_TEXT_ENCODER1 = "lora_te1"
-    LORA_PREFIX_TEXT_ENCODER2 = "lora_te2"
-
     def __init__(
         self,
-        text_encoder
+        text_encoder,
         unet,
-        multiplier
-        lora_dim
-        alpha
-        dropout
-        rank_dropout
-        module_dropout
-        conv_lora_dim
-        conv_alpha
-        block_dims
-        block_alphas
-        conv_block_dims
-        conv_block_alphas
-        modules_dim
-        modules_alpha
-        module_class
-        varbose
+        multiplier=1.0,
+        lora_dim=4,
+        alpha=1,
+        dropout=None,
+        rank_dropout=None,
+        module_dropout=None,
+        conv_lora_dim=None,
+        conv_alpha=None,
+        block_dims=None,
+        block_alphas=None,
+        conv_block_dims=None,
+        conv_block_alphas=None,
+        modules_dim=None,
+        modules_alpha=None,
+        module_class=LoRAModule,
+        varbose=False,
     ) -> None:
         """
         LoRA network: there are a lot of arguments, but the patterns are as follows
@@ -805,21 +783,8 @@ class LoRANetwork(torch.nn.Module):
             print(f"apply LoRA to Conv2d with kernel size (3,3). dim (rank): {self.conv_lora_dim}, alpha: {self.conv_alpha}")
 
         # create module instances
-        def create_modules(
-            is_unet
-            text_encoder_idx: Optional[int],  # None, 1, 2
-            root_module: torch.nn.Module,
-            target_replace_modules: List[torch.nn.Module],
-        ) -> List[LoRAModule]:
-            prefix = (
-                self.LORA_PREFIX_UNET
-                if is_unet
-                else (
-                    self.LORA_PREFIX_TEXT_ENCODER
-                    if text_encoder_idx is None
-                    else (self.LORA_PREFIX_TEXT_ENCODER1 if text_encoder_idx == 1 else self.LORA_PREFIX_TEXT_ENCODER2)
-                )
-            )
+        def create_modules(is_unet, root_module: torch.nn.Module, target_replace_modules) -> List[LoRAModule]:
+            prefix = LoRANetwork.LORA_PREFIX_UNET if is_unet else LoRANetwork.LORA_PREFIX_TEXT_ENCODER
             loras = []
             skipped = []
             for name, module in root_module.named_modules():
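Note on the hunk above: the simplified create_modules drops the SDXL text_encoder_idx branching and always uses the lora_unet / lora_te prefixes. For orientation, a stand-alone sketch (a hypothetical helper, not code from this repo) of the discovery-and-naming pattern that create_modules applies to root_module.named_modules():

import torch.nn as nn

def find_lora_target_names(root_module: nn.Module, target_replace_modules, prefix="lora_te"):
    # Walk the module tree; inside every module whose class name is listed as a target,
    # collect each Linear / Conv2d child under a flattened "prefix_module_child" name.
    names = []
    for name, module in root_module.named_modules():
        if module.__class__.__name__ in target_replace_modules:
            for child_name, child_module in module.named_modules():
                if child_name and isinstance(child_module, (nn.Linear, nn.Conv2d)):
                    names.append(f"{prefix}.{name}.{child_name}".replace(".", "_"))
    return names

class CLIPMLP(nn.Module):  # toy stand-in that shares a target class name
    def __init__(self):
        super().__init__()
        self.fc1, self.fc2 = nn.Linear(8, 32), nn.Linear(32, 8)

print(find_lora_target_names(nn.Sequential(CLIPMLP()), ["CLIPMLP"]))
# ['lora_te_0_fc1', 'lora_te_0_fc2']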
@@ -835,14 +800,11 @@ class LoRANetwork(torch.nn.Module):
 
                             dim = None
                             alpha = None
-
                             if modules_dim is not None:
-                                # modules are specified explicitly
                                 if lora_name in modules_dim:
                                     dim = modules_dim[lora_name]
                                     alpha = modules_alpha[lora_name]
                             elif is_unet and block_dims is not None:
-                                # block_dims specified for the U-Net
                                 block_idx = get_block_index(lora_name)
                                 if is_linear or is_conv2d_1x1:
                                     dim = block_dims[block_idx]
@@ -851,7 +813,6 @@ class LoRANetwork(torch.nn.Module):
                                     dim = conv_block_dims[block_idx]
                                     alpha = conv_block_alphas[block_idx]
                             else:
-                                # by default, all modules are targeted
                                 if is_linear or is_conv2d_1x1:
                                     dim = self.lora_dim
                                     alpha = self.alpha
@@ -860,7 +821,6 @@ class LoRANetwork(torch.nn.Module):
                                     alpha = self.conv_alpha
 
                             if dim is None or dim == 0:
-                                # report the modules that were skipped
                                 if is_linear or is_conv2d_1x1 or (self.conv_lora_dim is not None or conv_block_dims is not None):
                                     skipped.append(lora_name)
                                 continue
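Note on the three hunks above: the surviving context lines implement a priority order for each candidate module's rank and alpha: an explicit per-module table (modules_dim / modules_alpha) wins, then per-block settings for the U-Net (block_dims / block_alphas, or conv_block_dims for 3x3 convolutions), then the network-wide defaults; a module whose resolved dim ends up None or 0 is recorded in skipped. A condensed stand-alone sketch of that order (hypothetical helper; parameter names mirror the diff):

def resolve_dim_alpha(lora_name, is_unet, block_idx, is_conv2d_3x3=False,
                      modules_dim=None, modules_alpha=None,
                      block_dims=None, block_alphas=None,
                      conv_block_dims=None, conv_block_alphas=None,
                      default_dim=4, default_alpha=1.0,
                      conv_dim=None, conv_alpha=None):
    # Returning (None, None) means "skip this module", as in the loop above.
    if modules_dim is not None:                     # 1) explicit per-module table
        if lora_name in modules_dim:
            return modules_dim[lora_name], modules_alpha[lora_name]
        return None, None
    if is_unet and block_dims is not None:          # 2) per-block settings for the U-Net
        if not is_conv2d_3x3:
            return block_dims[block_idx], block_alphas[block_idx]
        if conv_block_dims is not None:
            return conv_block_dims[block_idx], conv_block_alphas[block_idx]
        return None, None
    if not is_conv2d_3x3:                           # 3) network-wide defaults
        return default_dim, default_alpha
    return conv_dim, conv_alpha

print(resolve_dim_alpha("lora_unet_up_blocks_1_attentions_0_proj_in", True, 14,
                        block_dims=[8] * 25, block_alphas=[4.0] * 25))  # (8, 4.0)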
@@ -878,24 +838,7 @@ class LoRANetwork(torch.nn.Module):
                             loras.append(lora)
             return loras, skipped
 
-
-        print(text_encoders)
-        # create LoRA for text encoder
-        # creating every module each time is wasteful; worth revisiting
-        self.text_encoder_loras = []
-        skipped_te = []
-        for i, text_encoder in enumerate(text_encoders):
-            if len(text_encoders) > 1:
-                index = i + 1
-                print(f"create LoRA for Text Encoder {index}:")
-            else:
-                index = None
-                print(f"create LoRA for Text Encoder:")
-
-            print(text_encoder)
-            text_encoder_loras, skipped = create_modules(False, index, text_encoder, LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE)
-            self.text_encoder_loras.extend(text_encoder_loras)
-            skipped_te += skipped
+        self.text_encoder_loras, skipped_te = create_modules(False, text_encoder, LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE)
         print(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.")
 
         # extend U-Net target modules if conv2d 3x3 is enabled, or load from weights
@@ -903,7 +846,7 @@ class LoRANetwork(torch.nn.Module):
         if modules_dim is not None or self.conv_lora_dim is not None or conv_block_dims is not None:
             target_modules += LoRANetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3
 
-        self.unet_loras, skipped_un = create_modules(True,
+        self.unet_loras, skipped_un = create_modules(True, unet, target_modules)
         print(f"create LoRA for U-Net: {len(self.unet_loras)} modules.")
 
         skipped = skipped_te + skipped_un
@@ -937,6 +880,7 @@ class LoRANetwork(torch.nn.Module):
             weights_sd = load_file(file)
         else:
             weights_sd = torch.load(file, map_location="cpu")
+
         info = self.load_state_dict(weights_sd, False)
         return info
 
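Note on the hunk above: the added blank line sits inside the weight-loading branch, which calls load_file for safetensors checkpoints and torch.load otherwise. A stand-alone sketch of that pattern, assuming the dispatch is keyed on the ".safetensors" file extension and that load_file comes from the safetensors package (the actual condition and import lie outside this hunk; the helper name is hypothetical):

import os
import torch

def load_lora_state_dict(path):
    # Hypothetical helper, not the repo's load_weights method.
    if os.path.splitext(path)[1].lower() == ".safetensors":
        from safetensors.torch import load_file  # requires the safetensors package
        return load_file(path)
    return torch.load(path, map_location="cpu")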
@@ -1017,7 +961,6 @@ class LoRANetwork(torch.nn.Module):
 
         return lr_weight
 
-    # it might be good to allow separate learning rates for the two Text Encoders
     def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr):
         self.requires_grad_(True)
         all_params = []