adapter weights for sketch2image
Browse files
- S2I/modules/models.py +12 -3
- S2I/modules/utils.py +1 -0
- app.py +1 -1
S2I/modules/models.py
CHANGED
@@ -4,7 +4,7 @@ from diffusers import DDPMScheduler
|
|
4 |
from transformers import AutoTokenizer, CLIPTextModel
|
5 |
from diffusers import AutoencoderKL, UNet2DConditionModel
|
6 |
from peft import LoraConfig
|
7 |
-
from S2I.modules.utils import sc_vae_encoder_fwd, sc_vae_decoder_fwd, download_models, get_model_path
|
8 |
|
9 |
|
10 |
class RelationShipConvolution(torch.nn.Module):
|
@@ -50,7 +50,16 @@ class PrimaryModel:
|
|
50 |
vae.decoder.skip_conv_4 = torch.nn.Conv2d(128, 256, kernel_size=(1, 1), stride=(1, 1), bias=False).cuda()
|
51 |
vae.decoder.ignore_skip = False
|
52 |
return vae
|
53 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
54 |
def from_pretrained(self, model_name, r):
|
55 |
if self.global_tokenizer is None:
|
56 |
# self.global_tokenizer = AutoTokenizer.from_pretrained(self.backbone_diffusion_path,
|
@@ -72,7 +81,7 @@ class PrimaryModel:
|
|
72 |
self.global_unet = self._load_model(self.backbone_diffusion_path, UNet2DConditionModel, unet_mode=True)
|
73 |
p_ckpt_path = download_models()
|
74 |
p_ckpt = get_model_path(model_name=model_name, model_paths=p_ckpt_path)
|
75 |
-
sd = torch.load(p_ckpt, map_location="cpu")
|
76 |
conv_in_pretrained = copy.deepcopy(self.global_unet.conv_in)
|
77 |
self.global_unet.conv_in = RelationShipConvolution(conv_in_pretrained, self.global_unet.conv_in, r)
|
78 |
unet_lora_config = LoraConfig(r=sd["rank_unet"], init_lora_weights="gaussian",
|
|
|
4 |
from transformers import AutoTokenizer, CLIPTextModel
|
5 |
from diffusers import AutoencoderKL, UNet2DConditionModel
|
6 |
from peft import LoraConfig
|
7 |
+
from S2I.modules.utils import sc_vae_encoder_fwd, sc_vae_decoder_fwd, download_models, get_model_path, get_s2i_home
|
8 |
|
9 |
|
10 |
class RelationShipConvolution(torch.nn.Module):
|
|
|
50 |
vae.decoder.skip_conv_4 = torch.nn.Conv2d(128, 256, kernel_size=(1, 1), stride=(1, 1), bias=False).cuda()
|
51 |
vae.decoder.ignore_skip = False
|
52 |
return vae
|
53 |
+
def weights_adapter(self, p_ckpt, model_name):
|
54 |
+
if model_name == '350k-adapter':
|
55 |
+
home = get_s2i_home()
|
56 |
+
sd_sketch = torch.load(os.path.join(home, f"sketch2image_lora_350k.pkl"), map_location="cpu")
|
57 |
+
sd = torch.load(p_ckpt, map_location="cpu")
|
58 |
+
sd.update(sd_sketch)
|
59 |
+
return sd
|
60 |
+
else:
|
61 |
+
sd = torch.load(p_ckpt, map_location="cpu")
|
62 |
+
return sd
|
63 |
def from_pretrained(self, model_name, r):
|
64 |
if self.global_tokenizer is None:
|
65 |
# self.global_tokenizer = AutoTokenizer.from_pretrained(self.backbone_diffusion_path,
|
|
|
81 |
self.global_unet = self._load_model(self.backbone_diffusion_path, UNet2DConditionModel, unet_mode=True)
|
82 |
p_ckpt_path = download_models()
|
83 |
p_ckpt = get_model_path(model_name=model_name, model_paths=p_ckpt_path)
|
84 |
+
sd = self.weights_adapter(p_ckpt, model_name)
|
85 |
conv_in_pretrained = copy.deepcopy(self.global_unet.conv_in)
|
86 |
self.global_unet.conv_in = RelationShipConvolution(conv_in_pretrained, self.global_unet.conv_in, r)
|
87 |
unet_lora_config = LoraConfig(r=sd["rank_unet"], init_lora_weights="gaussian",
|
S2I/modules/utils.py
CHANGED
@@ -84,6 +84,7 @@ def get_s2i_home() -> str:
|
|
84 |
|
85 |
def download_models():
|
86 |
urls = {
|
|
|
87 |
'350k': 'https://huggingface.co/myn0908/sk2ks/resolve/main/sketch_to_image_mixed_weights_350k_lora.pkl?download=true',
|
88 |
'100k': 'https://huggingface.co/myn0908/sk2ks/resolve/main/model_16001.pkl?download=true',
|
89 |
}
|
|
|
84 |
|
85 |
def download_models():
|
86 |
urls = {
|
87 |
+
'350k-adapter': 'https://huggingface.co/myn0908/sk2ks/resolve/main/adapter_weights_large_sketch2image_lora.pkl?download=true',
|
88 |
'350k': 'https://huggingface.co/myn0908/sk2ks/resolve/main/sketch_to_image_mixed_weights_350k_lora.pkl?download=true',
|
89 |
'100k': 'https://huggingface.co/myn0908/sk2ks/resolve/main/model_16001.pkl?download=true',
|
90 |
}
|
app.py
CHANGED
@@ -263,7 +263,7 @@ with gr.Blocks(css=css, theme="NoCrypt/miku@1.2.1") as demo:
|
|
263 |
label="Demo Speed",
|
264 |
interactive=True)
|
265 |
model_options = gr.Radio(
|
266 |
-
choices=["100k", "350k"],
|
267 |
value="350k",
|
268 |
label="Type Sketch2Image models",
|
269 |
interactive=True)
|
|
|
263 |
label="Demo Speed",
|
264 |
interactive=True)
|
265 |
model_options = gr.Radio(
|
266 |
+
choices=["100k", "350k", "350k-adapter"],
|
267 |
value="350k",
|
268 |
label="Type Sketch2Image models",
|
269 |
interactive=True)
|