Jeffiyyyy committed on
Commit 90ad87a
1 Parent(s): 97817d3
Files changed (7)
  1. .DS_Store +0 -0
  2. README.md +6 -6
  3. app.py +878 -0
  4. modules/lora.py +183 -0
  5. modules/model.py +897 -0
  6. modules/prompt_parser.py +391 -0
  7. modules/safe.py +188 -0
.DS_Store ADDED
Binary file (6.15 kB)
 
README.md CHANGED
@@ -1,13 +1,13 @@
---
- title: LSPDEMO
- emoji: 🏃
- colorFrom: pink
- colorTo: yellow
+ title: LSP LearningandStrivePartner Model
+ emoji: 🐠
+ colorFrom: green
+ colorTo: green
sdk: gradio
- sdk_version: 3.18.0
+ sdk_version: 3.17.0
app_file: app.py
pinned: false
- license: openrail
+ license: afl-3.0
---

Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,878 @@
1
+ import random
2
+ import tempfile
3
+ import time
4
+ import gradio as gr
5
+ import numpy as np
6
+ import torch
7
+ import math
8
+ import re
9
+
10
+ from gradio import inputs
11
+ from diffusers import (
12
+ AutoencoderKL,
13
+ DDIMScheduler,
14
+ UNet2DConditionModel,
15
+ )
16
+ from modules.model import (
17
+ CrossAttnProcessor,
18
+ StableDiffusionPipeline,
19
+ )
20
+ from torchvision import transforms
21
+ from transformers import CLIPTokenizer, CLIPTextModel
22
+ from PIL import Image
23
+ from pathlib import Path
24
+ from safetensors.torch import load_file
25
+ import modules.safe as _
26
+ from modules.lora import LoRANetwork
27
+
28
+ models = [
29
+ ("LSPV1", "Jeffsun/LSP", 2),
30
+ ("Pastal Mix", "andite/pastel-mix", 2),
31
+ ("Basil Mix", "nuigurumi/basil_mix", 2)
32
+ ]
33
+
34
+ keep_vram = ["Korakoe/AbyssOrangeMix2-HF", "andite/pastel-mix"]
35
+ base_name, base_model, clip_skip = models[0]
36
+
37
+ samplers_k_diffusion = [
38
+ ("Euler a", "sample_euler_ancestral", {}),
39
+ ("Euler", "sample_euler", {}),
40
+ ("LMS", "sample_lms", {}),
41
+ ("Heun", "sample_heun", {}),
42
+ ("DPM2", "sample_dpm_2", {"discard_next_to_last_sigma": True}),
43
+ ("DPM2 a", "sample_dpm_2_ancestral", {"discard_next_to_last_sigma": True}),
44
+ ("DPM++ 2S a", "sample_dpmpp_2s_ancestral", {}),
45
+ ("DPM++ 2M", "sample_dpmpp_2m", {}),
46
+ ("DPM++ SDE", "sample_dpmpp_sde", {}),
47
+ ("LMS Karras", "sample_lms", {"scheduler": "karras"}),
48
+ ("DPM2 Karras", "sample_dpm_2", {"scheduler": "karras", "discard_next_to_last_sigma": True}),
49
+ ("DPM2 a Karras", "sample_dpm_2_ancestral", {"scheduler": "karras", "discard_next_to_last_sigma": True}),
50
+ ("DPM++ 2S a Karras", "sample_dpmpp_2s_ancestral", {"scheduler": "karras"}),
51
+ ("DPM++ 2M Karras", "sample_dpmpp_2m", {"scheduler": "karras"}),
52
+ ("DPM++ SDE Karras", "sample_dpmpp_sde", {"scheduler": "karras"}),
53
+ ]
54
+
55
+ # samplers_diffusers = [
56
+ # ("DDIMScheduler", "diffusers.schedulers.DDIMScheduler", {})
57
+ # ("DDPMScheduler", "diffusers.schedulers.DDPMScheduler", {})
58
+ # ("DEISMultistepScheduler", "diffusers.schedulers.DEISMultistepScheduler", {})
59
+ # ]
60
+
61
+ start_time = time.time()
62
+ timeout = 90
63
+
64
+ scheduler = DDIMScheduler.from_pretrained(
65
+ base_model,
66
+ subfolder="scheduler",
67
+ )
68
+ vae = AutoencoderKL.from_pretrained(
69
+ "stabilityai/sd-vae-ft-ema",
70
+ torch_dtype=torch.float16
71
+ )
72
+ text_encoder = CLIPTextModel.from_pretrained(
73
+ base_model,
74
+ subfolder="text_encoder",
75
+ torch_dtype=torch.float16,
76
+ )
77
+ tokenizer = CLIPTokenizer.from_pretrained(
78
+ base_model,
79
+ subfolder="tokenizer",
80
+ torch_dtype=torch.float16,
81
+ )
82
+ unet = UNet2DConditionModel.from_pretrained(
83
+ base_model,
84
+ subfolder="unet",
85
+ torch_dtype=torch.float16,
86
+ )
87
+ pipe = StableDiffusionPipeline(
88
+ text_encoder=text_encoder,
89
+ tokenizer=tokenizer,
90
+ unet=unet,
91
+ vae=vae,
92
+ scheduler=scheduler,
93
+ )
94
+
95
+ unet.set_attn_processor(CrossAttnProcessor())
96
+ pipe.setup_text_encoder(clip_skip, text_encoder)
97
+ if torch.cuda.is_available():
98
+ pipe = pipe.to("cuda")
99
+
100
+ def get_model_list():
101
+ return models
102
+
103
+ te_cache = {
104
+ base_model: text_encoder
105
+ }
106
+
107
+ unet_cache = {
108
+ base_model: unet
109
+ }
110
+
111
+ lora_cache = {
112
+ base_model: LoRANetwork(text_encoder, unet)
113
+ }
114
+
115
+ te_base_weight_length = text_encoder.get_input_embeddings().weight.data.shape[0]
116
+ original_prepare_for_tokenization = tokenizer.prepare_for_tokenization
117
+ current_model = base_model
118
+
119
+ def setup_model(name, lora_state=None, lora_scale=1.0):
120
+ global pipe, current_model
121
+
122
+ keys = [k[0] for k in models]
123
+ model = models[keys.index(name)][1]
124
+ if model not in unet_cache:
125
+ unet = UNet2DConditionModel.from_pretrained(model, subfolder="unet", torch_dtype=torch.float16)
126
+ text_encoder = CLIPTextModel.from_pretrained(model, subfolder="text_encoder", torch_dtype=torch.float16)
127
+
128
+ unet_cache[model] = unet
129
+ te_cache[model] = text_encoder
130
+ lora_cache[model] = LoRANetwork(text_encoder, unet)
131
+
132
+ if current_model != model:
133
+ if current_model not in keep_vram:
134
+ # offload current model
135
+ unet_cache[current_model].to("cpu")
136
+ te_cache[current_model].to("cpu")
137
+ lora_cache[current_model].to("cpu")
138
+ current_model = model
139
+
140
+ local_te, local_unet, local_lora = te_cache[model], unet_cache[model], lora_cache[model]
141
+ local_unet.set_attn_processor(CrossAttnProcessor())
142
+ local_lora.reset()
143
+ clip_skip = models[keys.index(name)][2]
144
+
145
+ if torch.cuda.is_available():
146
+ local_unet.to("cuda")
147
+ local_te.to("cuda")
148
+
149
+ if lora_state is not None and lora_state != "":
150
+ local_lora.load(lora_state, lora_scale)
151
+ local_lora.to(local_unet.device, dtype=local_unet.dtype)
152
+
153
+ pipe.text_encoder, pipe.unet = local_te, local_unet
154
+ pipe.setup_unet(local_unet)
155
+ pipe.tokenizer.prepare_for_tokenization = original_prepare_for_tokenization
156
+ pipe.tokenizer.added_tokens_encoder = {}
157
+ pipe.tokenizer.added_tokens_decoder = {}
158
+ pipe.setup_text_encoder(clip_skip, local_te)
159
+ return pipe
160
+
161
+
162
+ def error_str(error, title="Error"):
163
+ return (
164
+ f"""#### {title}
165
+ {error}"""
166
+ if error
167
+ else ""
168
+ )
169
+
170
+ def make_token_names(embs):
171
+ all_tokens = []
172
+ for name, vec in embs.items():
173
+ tokens = [f'emb-{name}-{i}' for i in range(len(vec))]
174
+ all_tokens.append(tokens)
175
+ return all_tokens
176
+
177
+ def setup_tokenizer(tokenizer, embs):
178
+ reg_match = [re.compile(fr"(?:^|(?<=\s|,)){re.escape(k)}(?=,|\s|$)") for k in embs.keys()]
179
+ clip_keywords = [' '.join(s) for s in make_token_names(embs)]
180
+
181
+ def parse_prompt(prompt: str):
182
+ for m, v in zip(reg_match, clip_keywords):
183
+ prompt = m.sub(v, prompt)
184
+ return prompt
185
+
186
+ def prepare_for_tokenization(self, text: str, is_split_into_words: bool = False, **kwargs):
187
+ text = parse_prompt(text)
188
+ r = original_prepare_for_tokenization(text, is_split_into_words, **kwargs)
189
+ return r
190
+ tokenizer.prepare_for_tokenization = prepare_for_tokenization.__get__(tokenizer, CLIPTokenizer)
191
+ return [t for sublist in make_token_names(embs) for t in sublist]
192
+
193
+
194
+ def convert_size(size_bytes):
195
+ if size_bytes == 0:
196
+ return "0B"
197
+ size_name = ("B", "KB", "MB", "GB", "TB", "PB", "EB", "ZB", "YB")
198
+ i = int(math.floor(math.log(size_bytes, 1024)))
199
+ p = math.pow(1024, i)
200
+ s = round(size_bytes / p, 2)
201
+ return "%s %s" % (s, size_name[i])
202
+
203
+ def inference(
204
+ prompt,
205
+ guidance,
206
+ steps,
207
+ width=512,
208
+ height=512,
209
+ seed=0,
210
+ neg_prompt="",
211
+ state=None,
212
+ g_strength=0.4,
213
+ img_input=None,
214
+ i2i_scale=0.5,
215
+ hr_enabled=False,
216
+ hr_method="Latent",
217
+ hr_scale=1.5,
218
+ hr_denoise=0.8,
219
+ sampler="DPM++ 2M Karras",
220
+ embs=None,
221
+ model=None,
222
+ lora_state=None,
223
+ lora_scale=None,
224
+ ):
225
+ if seed is None or seed == 0:
226
+ seed = random.randint(0, 2147483647)
227
+
228
+ pipe = setup_model(model, lora_state, lora_scale)
229
+ generator = torch.Generator("cuda").manual_seed(int(seed))
230
+ start_time = time.time()
231
+
232
+ sampler_name, sampler_opt = None, None
233
+ for label, funcname, options in samplers_k_diffusion:
234
+ if label == sampler:
235
+ sampler_name, sampler_opt = funcname, options
236
+
237
+ tokenizer, text_encoder = pipe.tokenizer, pipe.text_encoder
238
+ if embs is not None and len(embs) > 0:
239
+ ti_embs = {}
240
+ for name, file in embs.items():
241
+ if str(file).endswith(".pt"):
242
+ loaded_learned_embeds = torch.load(file, map_location="cpu")
243
+ else:
244
+ loaded_learned_embeds = load_file(file, device="cpu")
245
+ loaded_learned_embeds = loaded_learned_embeds["string_to_param"]["*"] if "string_to_param" in loaded_learned_embed else loaded_learned_embed
246
+ ti_embs[name] = loaded_learned_embeds
247
+
248
+ if len(ti_embs) > 0:
249
+ tokens = setup_tokenizer(tokenizer, ti_embs)
250
+ added_tokens = tokenizer.add_tokens(tokens)
251
+ delta_weight = torch.cat([val for val in ti_embs.values()], dim=0)
252
+
253
+ assert added_tokens == delta_weight.shape[0]
254
+ text_encoder.resize_token_embeddings(len(tokenizer))
255
+ token_embeds = text_encoder.get_input_embeddings().weight.data
256
+ token_embeds[-delta_weight.shape[0]:] = delta_weight
257
+
258
+ config = {
259
+ "negative_prompt": neg_prompt,
260
+ "num_inference_steps": int(steps),
261
+ "guidance_scale": guidance,
262
+ "generator": generator,
263
+ "sampler_name": sampler_name,
264
+ "sampler_opt": sampler_opt,
265
+ "pww_state": state,
266
+ "pww_attn_weight": g_strength,
267
+ "start_time": start_time,
268
+ "timeout": timeout,
269
+ }
270
+
271
+ if img_input is not None:
272
+ ratio = min(height / img_input.height, width / img_input.width)
273
+ img_input = img_input.resize(
274
+ (int(img_input.width * ratio), int(img_input.height * ratio)), Image.LANCZOS
275
+ )
276
+ result = pipe.img2img(prompt, image=img_input, strength=i2i_scale, **config)
277
+ elif hr_enabled:
278
+ result = pipe.txt2img(
279
+ prompt,
280
+ width=width,
281
+ height=height,
282
+ upscale=True,
283
+ upscale_x=hr_scale,
284
+ upscale_denoising_strength=hr_denoise,
285
+ **config,
286
+ **latent_upscale_modes[hr_method],
287
+ )
288
+ else:
289
+ result = pipe.txt2img(prompt, width=width, height=height, **config)
290
+
291
+ end_time = time.time()
292
+ vram_free, vram_total = torch.cuda.mem_get_info()
293
+ print(f"done: model={model}, res={width}x{height}, step={steps}, time={round(end_time-start_time, 2)}s, vram_alloc={convert_size(vram_total-vram_free)}/{convert_size(vram_total)}")
294
+ return gr.Image.update(result[0][0], label=f"Initial Seed: {seed}")
295
+
296
+
297
+ color_list = []
298
+
299
+
300
+ def get_color(n):
301
+ for _ in range(n - len(color_list)):
302
+ color_list.append(tuple(np.random.random(size=3) * 256))
303
+ return color_list
304
+
305
+
306
+ def create_mixed_img(current, state, w=512, h=512):
307
+ w, h = int(w), int(h)
308
+ image_np = np.full([h, w, 4], 255)
309
+ if state is None:
310
+ state = {}
311
+
312
+ colors = get_color(len(state))
313
+ idx = 0
314
+
315
+ for key, item in state.items():
316
+ if item["map"] is not None:
317
+ m = item["map"] < 255
318
+ alpha = 150
319
+ if current == key:
320
+ alpha = 200
321
+ image_np[m] = colors[idx] + (alpha,)
322
+ idx += 1
323
+
324
+ return image_np
325
+
326
+
327
+ # width.change(apply_new_res, inputs=[width, height, global_stats], outputs=[global_stats, sp, rendered])
328
+ def apply_new_res(w, h, state):
329
+ w, h = int(w), int(h)
330
+
331
+ for key, item in state.items():
332
+ if item["map"] is not None:
333
+ item["map"] = resize(item["map"], w, h)
334
+
335
+ update_img = gr.Image.update(value=create_mixed_img("", state, w, h))
336
+ return state, update_img
337
+
338
+
339
+ def detect_text(text, state, width, height):
340
+
341
+ if text is None or text == "":
342
+ return None, None, gr.Radio.update(value=None), None
343
+
344
+ t = text.split(",")
345
+ new_state = {}
346
+
347
+ for item in t:
348
+ item = item.strip()
349
+ if item == "":
350
+ continue
351
+ if state is not None and item in state:
352
+ new_state[item] = {
353
+ "map": state[item]["map"],
354
+ "weight": state[item]["weight"],
355
+ "mask_outsides": state[item]["mask_outsides"],
356
+ }
357
+ else:
358
+ new_state[item] = {
359
+ "map": None,
360
+ "weight": 0.5,
361
+ "mask_outsides": False
362
+ }
363
+ update = gr.Radio.update(choices=[key for key in new_state.keys()], value=None)
364
+ update_img = gr.update(value=create_mixed_img("", new_state, width, height))
365
+ update_sketch = gr.update(value=None, interactive=False)
366
+ return new_state, update_sketch, update, update_img
367
+
368
+
369
+ def resize(img, w, h):
370
+ trs = transforms.Compose(
371
+ [
372
+ transforms.ToPILImage(),
373
+ transforms.Resize(min(h, w)),
374
+ transforms.CenterCrop((h, w)),
375
+ ]
376
+ )
377
+ result = np.array(trs(img), dtype=np.uint8)
378
+ return result
379
+
380
+
381
+ def switch_canvas(entry, state, width, height):
382
+ if entry is None:
383
+ return None, 0.5, False, create_mixed_img("", state, width, height)
384
+
385
+ return (
386
+ gr.update(value=None, interactive=True),
387
+ gr.update(value=state[entry]["weight"] if entry in state else 0.5),
388
+ gr.update(value=state[entry]["mask_outsides"] if entry in state else False),
389
+ create_mixed_img(entry, state, width, height),
390
+ )
391
+
392
+
393
+ def apply_canvas(selected, draw, state, w, h):
394
+ if selected in state:
395
+ w, h = int(w), int(h)
396
+ state[selected]["map"] = resize(draw, w, h)
397
+ return state, gr.Image.update(value=create_mixed_img(selected, state, w, h))
398
+
399
+
400
+ def apply_weight(selected, weight, state):
401
+ if selected in state:
402
+ state[selected]["weight"] = weight
403
+ return state
404
+
405
+
406
+ def apply_option(selected, mask, state):
407
+ if selected in state:
408
+ state[selected]["mask_outsides"] = mask
409
+ return state
410
+
411
+
412
+ # sp2, radio, width, height, global_stats
413
+ def apply_image(image, selected, w, h, strength, mask, state):
414
+ if selected in state:
415
+ state[selected] = {
416
+ "map": resize(image, w, h),
417
+ "weight": strgength,
418
+ "mask_outsides": mask
419
+ }
420
+
421
+ return state, gr.Image.update(value=create_mixed_img(selected, state, w, h))
422
+
423
+
424
+ # [ti_state, lora_state, ti_vals, lora_vals, uploads]
425
+ def add_net(files, ti_state, lora_state):
426
+ if files is None:
427
+ return ti_state, "", lora_state, None
428
+
429
+ for file in files:
430
+ item = Path(file.name)
431
+ stripedname = str(item.stem).strip()
432
+ if item.suffix == ".pt":
433
+ state_dict = torch.load(file.name, map_location="cpu")
434
+ else:
435
+ state_dict = load_file(file.name, device="cpu")
436
+ if any("lora" in k for k in state_dict.keys()):
437
+ lora_state = file.name
438
+ else:
439
+ ti_state[stripedname] = file.name
440
+
441
+ return (
442
+ ti_state,
443
+ lora_state,
444
+ gr.Text.update(f"{[key for key in ti_state.keys()]}"),
445
+ gr.Text.update(f"{lora_state}"),
446
+ gr.Files.update(value=None),
447
+ )
448
+
449
+
450
+ # [ti_state, lora_state, ti_vals, lora_vals, uploads]
451
+ def clean_states(ti_state, lora_state):
452
+ return (
453
+ dict(),
454
+ None,
455
+ gr.Text.update(f""),
456
+ gr.Text.update(f""),
457
+ gr.Files.update(value=None),
458
+ )
459
+
460
+
461
+ latent_upscale_modes = {
462
+ "Latent": {"upscale_method": "bilinear", "upscale_antialias": False},
463
+ "Latent (antialiased)": {"upscale_method": "bilinear", "upscale_antialias": True},
464
+ "Latent (bicubic)": {"upscale_method": "bicubic", "upscale_antialias": False},
465
+ "Latent (bicubic antialiased)": {
466
+ "upscale_method": "bicubic",
467
+ "upscale_antialias": True,
468
+ },
469
+ "Latent (nearest)": {"upscale_method": "nearest", "upscale_antialias": False},
470
+ "Latent (nearest-exact)": {
471
+ "upscale_method": "nearest-exact",
472
+ "upscale_antialias": False,
473
+ },
474
+ }
475
+
476
+ css = """
477
+ .finetuned-diffusion-div div{
478
+ display:inline-flex;
479
+ align-items:center;
480
+ gap:.8rem;
481
+ font-size:1.75rem;
482
+ padding-top:2rem;
483
+ }
484
+ .finetuned-diffusion-div div h1{
485
+ font-weight:900;
486
+ margin-bottom:7px
487
+ }
488
+ .finetuned-diffusion-div p{
489
+ margin-bottom:10px;
490
+ font-size:94%
491
+ }
492
+ .box {
493
+ float: left;
494
+ height: 20px;
495
+ width: 20px;
496
+ margin-bottom: 15px;
497
+ border: 1px solid black;
498
+ clear: both;
499
+ }
500
+ a{
501
+ text-decoration:underline
502
+ }
503
+ .tabs{
504
+ margin-top:0;
505
+ margin-bottom:0
506
+ }
507
+ #gallery{
508
+ min-height:20rem
509
+ }
510
+ .no-border {
511
+ border: none !important;
512
+ }
513
+ """
514
+ with gr.Blocks(css=css) as demo:
515
+ gr.HTML(
516
+ f"""
517
+ <div class="finetuned-diffusion-div">
518
+ <div>
519
+ <h1>Demo for diffusion models</h1>
520
+ </div>
521
+ <p>Hso @ nyanko.sketch2img.gradio</p>
522
+ </div>
523
+ """
524
+ )
525
+ global_stats = gr.State(value={})
526
+
527
+ with gr.Row():
528
+
529
+ with gr.Column(scale=55):
530
+ model = gr.Dropdown(
531
+ choices=[k[0] for k in get_model_list()],
532
+ label="Model",
533
+ value=base_name,
534
+ )
535
+ image_out = gr.Image(height=512)
536
+ # gallery = gr.Gallery(
537
+ # label="Generated images", show_label=False, elem_id="gallery"
538
+ # ).style(grid=[1], height="auto")
539
+
540
+ with gr.Column(scale=45):
541
+
542
+ with gr.Group():
543
+
544
+ with gr.Row():
545
+ with gr.Column(scale=70):
546
+
547
+ prompt = gr.Textbox(
548
+ label="Prompt",
549
+ value="best quality, masterpiece, highres, an extremely delicate and beautiful, original, extremely detailed wallpaper, highres , 1girl",
550
+ show_label=True,
551
+ max_lines=4,
552
+ placeholder="Enter prompt.",
553
+ )
554
+ neg_prompt = gr.Textbox(
555
+ label="Negative Prompt",
556
+ value="simple background,monochrome ,lowres, bad anatomy, bad hands, text, error, missing fingers, extra digit, fewer digits,twisting jawline, cropped, worst quality, low quality, normal quality, jpeg artifacts, signature, watermark, username, blurry, lowres, bad anatomy, bad hands, text, error, extra digit, fewer digits, cropped, worst quality, low quality, normal quality, jpeg artifacts, signature, watermark, username, blurry, ugly,pregnant,vore,duplicate,morbid,mut ilated,tran nsexual, hermaphrodite,long neck,mutated hands,poorly drawn hands,poorly drawn face,mutation,deformed,blurry,bad anatomy,bad proportions,malformed limbs,extra limbs,cloned face,disfigured,gross proportions, missing arms, missing legs, extra arms,extra legs,pubic hair, plump,bad legs,error legs,username,blurry,bad feet",
557
+ show_label=True,
558
+ max_lines=4,
559
+ placeholder="Enter negative prompt.",
560
+ )
561
+
562
+ generate = gr.Button(value="Generate").style(
563
+ rounded=(False, True, True, False)
564
+ )
565
+
566
+ with gr.Tab("Options"):
567
+
568
+ with gr.Group():
569
+
570
+ # n_images = gr.Slider(label="Images", value=1, minimum=1, maximum=4, step=1)
571
+ with gr.Row():
572
+ guidance = gr.Slider(
573
+ label="Guidance scale", value=7.5, maximum=15
574
+ )
575
+ steps = gr.Slider(
576
+ label="Steps", value=25, minimum=2, maximum=50, step=1
577
+ )
578
+
579
+ with gr.Row():
580
+ width = gr.Slider(
581
+ label="Width", value=512, minimum=64, maximum=1024, step=64
582
+ )
583
+ height = gr.Slider(
584
+ label="Height", value=512, minimum=64, maximum=1024, step=64
585
+ )
586
+
587
+ sampler = gr.Dropdown(
588
+ value="DPM++ 2M Karras",
589
+ label="Sampler",
590
+ choices=[s[0] for s in samplers_k_diffusion],
591
+ )
592
+ seed = gr.Number(label="Seed (0 = random)", value=0)
593
+
594
+ with gr.Tab("Image to image"):
595
+ with gr.Group():
596
+
597
+ inf_image = gr.Image(
598
+ label="Image", height=256, tool="editor", type="pil"
599
+ )
600
+ inf_strength = gr.Slider(
601
+ label="Transformation strength",
602
+ minimum=0,
603
+ maximum=1,
604
+ step=0.01,
605
+ value=0.5,
606
+ )
607
+
608
+ def res_cap(g, w, h, x):
609
+ if g:
610
+ return f"Enable upscaler: {w}x{h} to {int(w*x)}x{int(h*x)}"
611
+ else:
612
+ return "Enable upscaler"
613
+
614
+ with gr.Tab("Hires fix"):
615
+ with gr.Group():
616
+
617
+ hr_enabled = gr.Checkbox(label="Enable upscaler", value=False)
618
+ hr_method = gr.Dropdown(
619
+ [key for key in latent_upscale_modes.keys()],
620
+ value="Latent",
621
+ label="Upscale method",
622
+ )
623
+ hr_scale = gr.Slider(
624
+ label="Upscale factor",
625
+ minimum=1.0,
626
+ maximum=2.0,
627
+ step=0.1,
628
+ value=1.5,
629
+ )
630
+ hr_denoise = gr.Slider(
631
+ label="Denoising strength",
632
+ minimum=0.0,
633
+ maximum=1.0,
634
+ step=0.1,
635
+ value=0.8,
636
+ )
637
+
638
+ hr_scale.change(
639
+ lambda g, x, w, h: gr.Checkbox.update(
640
+ label=res_cap(g, w, h, x)
641
+ ),
642
+ inputs=[hr_enabled, hr_scale, width, height],
643
+ outputs=hr_enabled,
644
+ queue=False,
645
+ )
646
+ hr_enabled.change(
647
+ lambda g, x, w, h: gr.Checkbox.update(
648
+ label=res_cap(g, w, h, x)
649
+ ),
650
+ inputs=[hr_enabled, hr_scale, width, height],
651
+ outputs=hr_enabled,
652
+ queue=False,
653
+ )
654
+
655
+ with gr.Tab("Embeddings/Loras"):
656
+
657
+ ti_state = gr.State(dict())
658
+ lora_state = gr.State()
659
+
660
+ with gr.Group():
661
+ with gr.Row():
662
+ with gr.Column(scale=90):
663
+ ti_vals = gr.Text(label="Loaded embeddings")
664
+
665
+ with gr.Row():
666
+ with gr.Column(scale=90):
667
+ lora_vals = gr.Text(label="Loaded loras")
668
+
669
+ with gr.Row():
670
+
671
+ uploads = gr.Files(label="Upload new embeddings/lora")
672
+
673
+ with gr.Column():
674
+ lora_scale = gr.Slider(
675
+ label="Lora scale",
676
+ minimum=0,
677
+ maximum=2,
678
+ step=0.01,
679
+ value=1.0,
680
+ )
681
+ btn = gr.Button(value="Upload")
682
+ btn_del = gr.Button(value="Reset")
683
+
684
+ btn.click(
685
+ add_net,
686
+ inputs=[uploads, ti_state, lora_state],
687
+ outputs=[ti_state, lora_state, ti_vals, lora_vals, uploads],
688
+ queue=False,
689
+ )
690
+ btn_del.click(
691
+ clean_states,
692
+ inputs=[ti_state, lora_state],
693
+ outputs=[ti_state, lora_state, ti_vals, lora_vals, uploads],
694
+ queue=False,
695
+ )
696
+
697
+ # error_output = gr.Markdown()
698
+
699
+ gr.HTML(
700
+ f"""
701
+ <div class="finetuned-diffusion-div">
702
+ <div>
703
+ <h1>Paint with words</h1>
704
+ </div>
705
+ <p>
706
+ Will use the following formula: w = scale * token_weight_matrix * log(1 + sigma) * max(qk).
707
+ </p>
708
+ </div>
709
+ """
710
+ )
711
+
712
+ with gr.Row():
713
+
714
+ with gr.Column(scale=55):
715
+
716
+ rendered = gr.Image(
717
+ invert_colors=True,
718
+ source="canvas",
719
+ interactive=False,
720
+ image_mode="RGBA",
721
+ )
722
+
723
+ with gr.Column(scale=45):
724
+
725
+ with gr.Group():
726
+ with gr.Row():
727
+ with gr.Column(scale=70):
728
+ g_strength = gr.Slider(
729
+ label="Weight scaling",
730
+ minimum=0,
731
+ maximum=0.8,
732
+ step=0.01,
733
+ value=0.4,
734
+ )
735
+
736
+ text = gr.Textbox(
737
+ lines=2,
738
+ interactive=True,
739
+ label="Token to Draw: (Separate by comma)",
740
+ )
741
+
742
+ radio = gr.Radio([], label="Tokens")
743
+
744
+ sk_update = gr.Button(value="Update").style(
745
+ rounded=(False, True, True, False)
746
+ )
747
+
748
+ # g_strength.change(lambda b: gr.update(f"Scaled additional attn: $w = {b} \log (1 + \sigma) \std (Q^T K)$."), inputs=g_strength, outputs=[g_output])
749
+
750
+ with gr.Tab("SketchPad"):
751
+
752
+ sp = gr.Image(
753
+ image_mode="L",
754
+ tool="sketch",
755
+ source="canvas",
756
+ interactive=False,
757
+ )
758
+
759
+ mask_outsides = gr.Checkbox(
760
+ label="Mask other areas",
761
+ value=False
762
+ )
763
+
764
+ strength = gr.Slider(
765
+ label="Token strength",
766
+ minimum=0,
767
+ maximum=0.8,
768
+ step=0.01,
769
+ value=0.5,
770
+ )
771
+
772
+
773
+ sk_update.click(
774
+ detect_text,
775
+ inputs=[text, global_stats, width, height],
776
+ outputs=[global_stats, sp, radio, rendered],
777
+ queue=False,
778
+ )
779
+ radio.change(
780
+ switch_canvas,
781
+ inputs=[radio, global_stats, width, height],
782
+ outputs=[sp, strength, mask_outsides, rendered],
783
+ queue=False,
784
+ )
785
+ sp.edit(
786
+ apply_canvas,
787
+ inputs=[radio, sp, global_stats, width, height],
788
+ outputs=[global_stats, rendered],
789
+ queue=False,
790
+ )
791
+ strength.change(
792
+ apply_weight,
793
+ inputs=[radio, strength, global_stats],
794
+ outputs=[global_stats],
795
+ queue=False,
796
+ )
797
+ mask_outsides.change(
798
+ apply_option,
799
+ inputs=[radio, mask_outsides, global_stats],
800
+ outputs=[global_stats],
801
+ queue=False,
802
+ )
803
+
804
+ with gr.Tab("UploadFile"):
805
+
806
+ sp2 = gr.Image(
807
+ image_mode="L",
808
+ source="upload",
809
+ shape=(512, 512),
810
+ )
811
+
812
+ mask_outsides2 = gr.Checkbox(
813
+ label="Mask other areas",
814
+ value=False,
815
+ )
816
+
817
+ strength2 = gr.Slider(
818
+ label="Token strength",
819
+ minimum=0,
820
+ maximum=0.8,
821
+ step=0.01,
822
+ value=0.5,
823
+ )
824
+
825
+ apply_style = gr.Button(value="Apply")
826
+ apply_style.click(
827
+ apply_image,
828
+ inputs=[sp2, radio, width, height, strength2, mask_outsides2, global_stats],
829
+ outputs=[global_stats, rendered],
830
+ queue=False,
831
+ )
832
+
833
+ width.change(
834
+ apply_new_res,
835
+ inputs=[width, height, global_stats],
836
+ outputs=[global_stats, rendered],
837
+ queue=False,
838
+ )
839
+ height.change(
840
+ apply_new_res,
841
+ inputs=[width, height, global_stats],
842
+ outputs=[global_stats, rendered],
843
+ queue=False,
844
+ )
845
+
846
+ # color_stats = gr.State(value={})
847
+ # text.change(detect_color, inputs=[sp, text, color_stats], outputs=[color_stats, rendered])
848
+ # sp.change(detect_color, inputs=[sp, text, color_stats], outputs=[color_stats, rendered])
849
+
850
+ inputs = [
851
+ prompt,
852
+ guidance,
853
+ steps,
854
+ width,
855
+ height,
856
+ seed,
857
+ neg_prompt,
858
+ global_stats,
859
+ g_strength,
860
+ inf_image,
861
+ inf_strength,
862
+ hr_enabled,
863
+ hr_method,
864
+ hr_scale,
865
+ hr_denoise,
866
+ sampler,
867
+ ti_state,
868
+ model,
869
+ lora_state,
870
+ lora_scale,
871
+ ]
872
+ outputs = [image_out]
873
+ prompt.submit(inference, inputs=inputs, outputs=outputs)
874
+ generate.click(inference, inputs=inputs, outputs=outputs)
875
+
876
+ print(f"Space built in {time.time() - start_time:.2f} seconds")
877
+ # demo.launch(share=True)
878
+ demo.launch(enable_queue=True, server_name="0.0.0.0", server_port=7860)
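Note on the "Paint with words" formula referenced in the UI above: the sketched weight map becomes an additive bias on the cross-attention scores. The snippet below is a minimal sketch of that weighting, mirroring the weight_func lambda defined later in modules/model.py (w * log(1 + sigma) * qk.max()); the helper name pww_attention_bias and the tensor shapes are illustrative assumptions, not part of this Space.

import math
import torch

def pww_attention_bias(token_weight_map, sigma, attention_scores):
    # token_weight_map: per-token region weights, already scaled by the UI's
    # "Weight scaling" (g_strength) and per-token strength sliders.
    # attention_scores: raw scaled QK^T from the cross-attention layer.
    return token_weight_map * math.log(1 + sigma) * attention_scores.max()

# Usage sketch: bias the scores before softmax, as CrossAttnProcessor does.
scores = torch.randn(2, 64, 77)        # (batch * heads, latent tokens, text tokens)
weights = torch.zeros_like(scores)
weights[:, :32, 5:8] = 0.4             # a sketched region tied to prompt tokens 5..7
probs = (scores + pww_attention_bias(weights, sigma=7.0, attention_scores=scores)).softmax(dim=-1)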
modules/lora.py ADDED
@@ -0,0 +1,183 @@
1
+ # LoRA network module
2
+ # reference:
3
+ # https://github.com/microsoft/LoRA/blob/main/loralib/layers.py
4
+ # https://github.com/cloneofsimo/lora/blob/master/lora_diffusion/lora.py
5
+ # https://github.com/bmaltais/kohya_ss/blob/master/networks/lora.py#L48
6
+
7
+ import math
8
+ import os
9
+ import torch
10
+ import modules.safe as _
11
+ from safetensors.torch import load_file
12
+
13
+
14
+ class LoRAModule(torch.nn.Module):
15
+ """
16
+ replaces forward method of the original Linear, instead of replacing the original Linear module.
17
+ """
18
+
19
+ def __init__(
20
+ self,
21
+ lora_name,
22
+ org_module: torch.nn.Module,
23
+ multiplier=1.0,
24
+ lora_dim=4,
25
+ alpha=1,
26
+ ):
27
+ """if alpha == 0 or None, alpha is rank (no scaling)."""
28
+ super().__init__()
29
+ self.lora_name = lora_name
30
+ self.lora_dim = lora_dim
31
+
32
+ if org_module.__class__.__name__ == "Conv2d":
33
+ in_dim = org_module.in_channels
34
+ out_dim = org_module.out_channels
35
+ self.lora_down = torch.nn.Conv2d(in_dim, lora_dim, (1, 1), bias=False)
36
+ self.lora_up = torch.nn.Conv2d(lora_dim, out_dim, (1, 1), bias=False)
37
+ else:
38
+ in_dim = org_module.in_features
39
+ out_dim = org_module.out_features
40
+ self.lora_down = torch.nn.Linear(in_dim, lora_dim, bias=False)
41
+ self.lora_up = torch.nn.Linear(lora_dim, out_dim, bias=False)
42
+
43
+ if type(alpha) == torch.Tensor:
44
+ alpha = alpha.detach().float().numpy() # without casting, bf16 causes error
45
+
46
+ alpha = lora_dim if alpha is None or alpha == 0 else alpha
47
+ self.scale = alpha / self.lora_dim
48
+ self.register_buffer("alpha", torch.tensor(alpha)) # 定数として扱える
49
+
50
+ # same as microsoft's
51
+ torch.nn.init.kaiming_uniform_(self.lora_down.weight, a=math.sqrt(5))
52
+ torch.nn.init.zeros_(self.lora_up.weight)
53
+
54
+ self.multiplier = multiplier
55
+ self.org_module = org_module # remove in applying
56
+ self.enable = False
57
+
58
+ def resize(self, rank, alpha, multiplier):
59
+ self.alpha = torch.tensor(alpha)
60
+ self.multiplier = multiplier
61
+ self.scale = alpha / rank
62
+ if self.lora_down.__class__.__name__ == "Conv2d":
63
+ in_dim = self.lora_down.in_channels
64
+ out_dim = self.lora_up.out_channels
65
+ self.lora_down = torch.nn.Conv2d(in_dim, rank, (1, 1), bias=False)
66
+ self.lora_up = torch.nn.Conv2d(rank, out_dim, (1, 1), bias=False)
67
+ else:
68
+ in_dim = self.lora_down.in_features
69
+ out_dim = self.lora_up.out_features
70
+ self.lora_down = torch.nn.Linear(in_dim, rank, bias=False)
71
+ self.lora_up = torch.nn.Linear(rank, out_dim, bias=False)
72
+
73
+ def apply(self):
74
+ if hasattr(self, "org_module"):
75
+ self.org_forward = self.org_module.forward
76
+ self.org_module.forward = self.forward
77
+ del self.org_module
78
+
79
+ def forward(self, x):
80
+ if self.enable:
81
+ return (
82
+ self.org_forward(x)
83
+ + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale
84
+ )
85
+ return self.org_forward(x)
86
+
87
+
88
+ class LoRANetwork(torch.nn.Module):
89
+ UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel", "Attention"]
90
+ TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"]
91
+ LORA_PREFIX_UNET = "lora_unet"
92
+ LORA_PREFIX_TEXT_ENCODER = "lora_te"
93
+
94
+ def __init__(self, text_encoder, unet, multiplier=1.0, lora_dim=4, alpha=1) -> None:
95
+ super().__init__()
96
+ self.multiplier = multiplier
97
+ self.lora_dim = lora_dim
98
+ self.alpha = alpha
99
+
100
+ # create module instances
101
+ def create_modules(prefix, root_module: torch.nn.Module, target_replace_modules):
102
+ loras = []
103
+ for name, module in root_module.named_modules():
104
+ if module.__class__.__name__ in target_replace_modules:
105
+ for child_name, child_module in module.named_modules():
106
+ if child_module.__class__.__name__ == "Linear" or (child_module.__class__.__name__ == "Conv2d" and child_module.kernel_size == (1, 1)):
107
+ lora_name = prefix + "." + name + "." + child_name
108
+ lora_name = lora_name.replace(".", "_")
109
+ lora = LoRAModule(lora_name, child_module, self.multiplier, self.lora_dim, self.alpha,)
110
+ loras.append(lora)
111
+ return loras
112
+
113
+ if isinstance(text_encoder, list):
114
+ self.text_encoder_loras = text_encoder
115
+ else:
116
+ self.text_encoder_loras = create_modules(LoRANetwork.LORA_PREFIX_TEXT_ENCODER, text_encoder, LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE)
117
+ print(f"Create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.")
118
+
119
+ self.unet_loras = create_modules(LoRANetwork.LORA_PREFIX_UNET, unet, LoRANetwork.UNET_TARGET_REPLACE_MODULE)
120
+ print(f"Create LoRA for U-Net: {len(self.unet_loras)} modules.")
121
+
122
+ self.weights_sd = None
123
+
124
+ # assertion
125
+ names = set()
126
+ for lora in self.text_encoder_loras + self.unet_loras:
127
+ assert (lora.lora_name not in names), f"duplicated lora name: {lora.lora_name}"
128
+ names.add(lora.lora_name)
129
+
130
+ lora.apply()
131
+ self.add_module(lora.lora_name, lora)
132
+
133
+ def reset(self):
134
+ for lora in self.text_encoder_loras + self.unet_loras:
135
+ lora.enable = False
136
+
137
+ def load(self, file, scale):
138
+
139
+ weights = None
140
+ if os.path.splitext(file)[1] == ".safetensors":
141
+ weights = load_file(file)
142
+ else:
143
+ weights = torch.load(file, map_location="cpu")
144
+
145
+ if not weights:
146
+ return
147
+
148
+ network_alpha = None
149
+ network_dim = None
150
+ for key, value in weights.items():
151
+ if network_alpha is None and "alpha" in key:
152
+ network_alpha = value
153
+ if network_dim is None and "lora_down" in key and len(value.size()) == 2:
154
+ network_dim = value.size()[0]
155
+
156
+ if network_alpha is None:
157
+ network_alpha = network_dim
158
+
159
+ weights_has_text_encoder = weights_has_unet = False
160
+ weights_to_modify = []
161
+
162
+ for key in weights.keys():
163
+ if key.startswith(LoRANetwork.LORA_PREFIX_TEXT_ENCODER):
164
+ weights_has_text_encoder = True
165
+
166
+ if key.startswith(LoRANetwork.LORA_PREFIX_UNET):
167
+ weights_has_unet = True
168
+
169
+ if weights_has_text_encoder:
170
+ weights_to_modify += self.text_encoder_loras
171
+
172
+ if weights_has_unet:
173
+ weights_to_modify += self.unet_loras
174
+
175
+ for lora in self.text_encoder_loras + self.unet_loras:
176
+ lora.resize(network_dim, network_alpha, scale)
177
+ if lora in weights_to_modify:
178
+ lora.enable = True
179
+
180
+ info = self.load_state_dict(weights, False)
181
+ if len(info.unexpected_keys) > 0:
182
+ print(f"Weights are loaded. Unexpected keys={info.unexpected_keys}")
183
+
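For reference, app.py's setup_model() consumes this module roughly as follows; the sketch assumes a diffusers-format checkpoint and a hypothetical LoRA file name, and is not an additional file in this Space.

import torch
from transformers import CLIPTextModel
from diffusers import UNet2DConditionModel
from modules.lora import LoRANetwork

base_model = "Jeffsun/LSP"  # any checkpoint with unet/text_encoder subfolders
unet = UNet2DConditionModel.from_pretrained(base_model, subfolder="unet", torch_dtype=torch.float16)
text_encoder = CLIPTextModel.from_pretrained(base_model, subfolder="text_encoder", torch_dtype=torch.float16)

# Constructing the network patches every eligible Linear / 1x1 Conv2d so its forward
# becomes org_forward(x) + lora_up(lora_down(x)) * multiplier * scale (disabled until loaded).
lora = LoRANetwork(text_encoder, unet)

# load() resizes each module to the checkpoint's rank/alpha and enables only the
# branches the file actually contains (lora_te / lora_unet prefixes).
lora.load("my_lora.safetensors", scale=1.0)   # hypothetical file path
lora.to(unet.device, dtype=unet.dtype)

# lora.reset() disables all branches again, which is how setup_model() clears the
# previously loaded LoRA before switching base models.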
modules/model.py ADDED
@@ -0,0 +1,897 @@
1
+ import importlib
2
+ import inspect
3
+ import math
4
+ from pathlib import Path
5
+ import re
6
+ from collections import defaultdict
7
+ from typing import List, Optional, Union
8
+
9
+ import time
10
+ import k_diffusion
11
+ import numpy as np
12
+ import PIL
13
+ import torch
14
+ import torch.nn as nn
15
+ import torch.nn.functional as F
16
+ from einops import rearrange
17
+ from k_diffusion.external import CompVisDenoiser, CompVisVDenoiser
18
+ from modules.prompt_parser import FrozenCLIPEmbedderWithCustomWords
19
+ from torch import einsum
20
+ from torch.autograd.function import Function
21
+
22
+ from diffusers import DiffusionPipeline
23
+ from diffusers.utils import PIL_INTERPOLATION, is_accelerate_available
24
+ from diffusers.utils import logging, randn_tensor
25
+
26
+ import modules.safe as _
27
+ from safetensors.torch import load_file
28
+
29
+ xformers_available = False
30
+ try:
31
+ import xformers
32
+
33
+ xformers_available = True
34
+ except ImportError:
35
+ pass
36
+
37
+ EPSILON = 1e-6
38
+ exists = lambda val: val is not None
39
+ default = lambda val, d: val if exists(val) else d
40
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
41
+
42
+
43
+ def get_attention_scores(attn, query, key, attention_mask=None):
44
+
45
+ if attn.upcast_attention:
46
+ query = query.float()
47
+ key = key.float()
48
+
49
+ attention_scores = torch.baddbmm(
50
+ torch.empty(
51
+ query.shape[0],
52
+ query.shape[1],
53
+ key.shape[1],
54
+ dtype=query.dtype,
55
+ device=query.device,
56
+ ),
57
+ query,
58
+ key.transpose(-1, -2),
59
+ beta=0,
60
+ alpha=attn.scale,
61
+ )
62
+
63
+ if attention_mask is not None:
64
+ attention_scores = attention_scores + attention_mask
65
+
66
+ if attn.upcast_softmax:
67
+ attention_scores = attention_scores.float()
68
+
69
+ return attention_scores
70
+
71
+
72
+ class CrossAttnProcessor(nn.Module):
73
+ def __call__(
74
+ self,
75
+ attn,
76
+ hidden_states,
77
+ encoder_hidden_states=None,
78
+ attention_mask=None,
79
+ ):
80
+ batch_size, sequence_length, _ = hidden_states.shape
81
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length)
82
+
83
+ encoder_states = hidden_states
84
+ is_xattn = False
85
+ if encoder_hidden_states is not None:
86
+ is_xattn = True
87
+ img_state = encoder_hidden_states["img_state"]
88
+ encoder_states = encoder_hidden_states["states"]
89
+ weight_func = encoder_hidden_states["weight_func"]
90
+ sigma = encoder_hidden_states["sigma"]
91
+
92
+ query = attn.to_q(hidden_states)
93
+ key = attn.to_k(encoder_states)
94
+ value = attn.to_v(encoder_states)
95
+
96
+ query = attn.head_to_batch_dim(query)
97
+ key = attn.head_to_batch_dim(key)
98
+ value = attn.head_to_batch_dim(value)
99
+
100
+ if is_xattn and isinstance(img_state, dict):
101
+ # use torch.baddbmm method (slow)
102
+ attention_scores = get_attention_scores(attn, query, key, attention_mask)
103
+ w = img_state[sequence_length].to(query.device)
104
+ cross_attention_weight = weight_func(w, sigma, attention_scores)
105
+ attention_scores += torch.repeat_interleave(
106
+ cross_attention_weight, repeats=attn.heads, dim=0
107
+ )
108
+
109
+ # calc probs
110
+ attention_probs = attention_scores.softmax(dim=-1)
111
+ attention_probs = attention_probs.to(query.dtype)
112
+ hidden_states = torch.bmm(attention_probs, value)
113
+
114
+ elif xformers_available:
115
+ hidden_states = xformers.ops.memory_efficient_attention(
116
+ query.contiguous(),
117
+ key.contiguous(),
118
+ value.contiguous(),
119
+ attn_bias=attention_mask,
120
+ )
121
+ hidden_states = hidden_states.to(query.dtype)
122
+
123
+ else:
124
+ q_bucket_size = 512
125
+ k_bucket_size = 1024
126
+
127
+ # use flash-attention
128
+ hidden_states = FlashAttentionFunction.apply(
129
+ query.contiguous(),
130
+ key.contiguous(),
131
+ value.contiguous(),
132
+ attention_mask,
133
+ False,
134
+ q_bucket_size,
135
+ k_bucket_size,
136
+ )
137
+ hidden_states = hidden_states.to(query.dtype)
138
+
139
+ hidden_states = attn.batch_to_head_dim(hidden_states)
140
+
141
+ # linear proj
142
+ hidden_states = attn.to_out[0](hidden_states)
143
+
144
+ # dropout
145
+ hidden_states = attn.to_out[1](hidden_states)
146
+
147
+ return hidden_states
148
+
149
+ class ModelWrapper:
150
+ def __init__(self, model, alphas_cumprod):
151
+ self.model = model
152
+ self.alphas_cumprod = alphas_cumprod
153
+
154
+ def apply_model(self, *args, **kwargs):
155
+ if len(args) == 3:
156
+ encoder_hidden_states = args[-1]
157
+ args = args[:2]
158
+ if kwargs.get("cond", None) is not None:
159
+ encoder_hidden_states = kwargs.pop("cond")
160
+ return self.model(
161
+ *args, encoder_hidden_states=encoder_hidden_states, **kwargs
162
+ ).sample
163
+
164
+
165
+ class StableDiffusionPipeline(DiffusionPipeline):
166
+
167
+ _optional_components = ["safety_checker", "feature_extractor"]
168
+
169
+ def __init__(
170
+ self,
171
+ vae,
172
+ text_encoder,
173
+ tokenizer,
174
+ unet,
175
+ scheduler,
176
+ ):
177
+ super().__init__()
178
+
179
+ # get correct sigmas from LMS
180
+ self.register_modules(
181
+ vae=vae,
182
+ text_encoder=text_encoder,
183
+ tokenizer=tokenizer,
184
+ unet=unet,
185
+ scheduler=scheduler,
186
+ )
187
+ self.setup_unet(self.unet)
188
+ self.setup_text_encoder()
189
+
190
+ def setup_text_encoder(self, n=1, new_encoder=None):
191
+ if new_encoder is not None:
192
+ self.text_encoder = new_encoder
193
+
194
+ self.prompt_parser = FrozenCLIPEmbedderWithCustomWords(self.tokenizer, self.text_encoder)
195
+ self.prompt_parser.CLIP_stop_at_last_layers = n
196
+
197
+ def setup_unet(self, unet):
198
+ unet = unet.to(self.device)
199
+ model = ModelWrapper(unet, self.scheduler.alphas_cumprod)
200
+ if self.scheduler.prediction_type == "v_prediction":
201
+ self.k_diffusion_model = CompVisVDenoiser(model)
202
+ else:
203
+ self.k_diffusion_model = CompVisDenoiser(model)
204
+
205
+ def get_scheduler(self, scheduler_type: str):
206
+ library = importlib.import_module("k_diffusion")
207
+ sampling = getattr(library, "sampling")
208
+ return getattr(sampling, scheduler_type)
209
+
210
+ def encode_sketchs(self, state, scale_ratio=8, g_strength=1.0, text_ids=None):
211
+ uncond, cond = text_ids[0], text_ids[1]
212
+
213
+ img_state = []
214
+ if state is None:
215
+ return torch.FloatTensor(0)
216
+
217
+ for k, v in state.items():
218
+ if v["map"] is None:
219
+ continue
220
+
221
+ v_input = self.tokenizer(
222
+ k,
223
+ max_length=self.tokenizer.model_max_length,
224
+ truncation=True,
225
+ add_special_tokens=False,
226
+ ).input_ids
227
+
228
+ dotmap = v["map"] < 255
229
+ out = dotmap.astype(float)
230
+ if v["mask_outsides"]:
231
+ out[out==0] = -1
232
+
233
+ arr = torch.from_numpy(
234
+ out * float(v["weight"]) * g_strength
235
+ )
236
+ img_state.append((v_input, arr))
237
+
238
+ if len(img_state) == 0:
239
+ return torch.FloatTensor(0)
240
+
241
+ w_tensors = dict()
242
+ cond = cond.tolist()
243
+ uncond = uncond.tolist()
244
+ for layer in self.unet.down_blocks:
245
+ c = int(len(cond))
246
+ w, h = img_state[0][1].shape
247
+ w_r, h_r = w // scale_ratio, h // scale_ratio
248
+
249
+ ret_cond_tensor = torch.zeros((1, int(w_r * h_r), c), dtype=torch.float32)
250
+ ret_uncond_tensor = torch.zeros((1, int(w_r * h_r), c), dtype=torch.float32)
251
+
252
+ for v_as_tokens, img_where_color in img_state:
253
+ is_in = 0
254
+
255
+ ret = (
256
+ F.interpolate(
257
+ img_where_color.unsqueeze(0).unsqueeze(1),
258
+ scale_factor=1 / scale_ratio,
259
+ mode="bilinear",
260
+ align_corners=True,
261
+ )
262
+ .squeeze()
263
+ .reshape(-1, 1)
264
+ .repeat(1, len(v_as_tokens))
265
+ )
266
+
267
+ for idx, tok in enumerate(cond):
268
+ if cond[idx : idx + len(v_as_tokens)] == v_as_tokens:
269
+ is_in = 1
270
+ ret_cond_tensor[0, :, idx : idx + len(v_as_tokens)] += ret
271
+
272
+ for idx, tok in enumerate(uncond):
273
+ if uncond[idx : idx + len(v_as_tokens)] == v_as_tokens:
274
+ is_in = 1
275
+ ret_uncond_tensor[0, :, idx : idx + len(v_as_tokens)] += ret
276
+
277
+ if not is_in == 1:
278
+ print(f"tokens {v_as_tokens} not found in text")
279
+
280
+ w_tensors[w_r * h_r] = torch.cat([ret_uncond_tensor, ret_cond_tensor])
281
+ scale_ratio *= 2
282
+
283
+ return w_tensors
284
+
285
+ def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
286
+ r"""
287
+ Enable sliced attention computation.
288
+
289
+ When this option is enabled, the attention module will split the input tensor in slices, to compute attention
290
+ in several steps. This is useful to save some memory in exchange for a small speed decrease.
291
+
292
+ Args:
293
+ slice_size (`str` or `int`, *optional*, defaults to `"auto"`):
294
+ When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
295
+ a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case,
296
+ `attention_head_dim` must be a multiple of `slice_size`.
297
+ """
298
+ if slice_size == "auto":
299
+ # half the attention head size is usually a good trade-off between
300
+ # speed and memory
301
+ slice_size = self.unet.config.attention_head_dim // 2
302
+ self.unet.set_attention_slice(slice_size)
303
+
304
+ def disable_attention_slicing(self):
305
+ r"""
306
+ Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go
307
+ back to computing attention in one step.
308
+ """
309
+ # set slice_size = `None` to disable `attention slicing`
310
+ self.enable_attention_slicing(None)
311
+
312
+ def enable_sequential_cpu_offload(self, gpu_id=0):
313
+ r"""
314
+ Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,
315
+ text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a
316
+ `torch.device('meta')` and loaded to GPU only when their specific submodule has its `forward` method called.
317
+ """
318
+ if is_accelerate_available():
319
+ from accelerate import cpu_offload
320
+ else:
321
+ raise ImportError("Please install accelerate via `pip install accelerate`")
322
+
323
+ device = torch.device(f"cuda:{gpu_id}")
324
+
325
+ for cpu_offloaded_model in [
326
+ self.unet,
327
+ self.text_encoder,
328
+ self.vae,
329
+ self.safety_checker,
330
+ ]:
331
+ if cpu_offloaded_model is not None:
332
+ cpu_offload(cpu_offloaded_model, device)
333
+
334
+ @property
335
+ def _execution_device(self):
336
+ r"""
337
+ Returns the device on which the pipeline's models will be executed. After calling
338
+ `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
339
+ hooks.
340
+ """
341
+ if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"):
342
+ return self.device
343
+ for module in self.unet.modules():
344
+ if (
345
+ hasattr(module, "_hf_hook")
346
+ and hasattr(module._hf_hook, "execution_device")
347
+ and module._hf_hook.execution_device is not None
348
+ ):
349
+ return torch.device(module._hf_hook.execution_device)
350
+ return self.device
351
+
352
+ def decode_latents(self, latents):
353
+ latents = latents.to(self.device, dtype=self.vae.dtype)
354
+ latents = 1 / 0.18215 * latents
355
+ image = self.vae.decode(latents).sample
356
+ image = (image / 2 + 0.5).clamp(0, 1)
357
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
358
+ image = image.cpu().permute(0, 2, 3, 1).float().numpy()
359
+ return image
360
+
361
+ def check_inputs(self, prompt, height, width, callback_steps):
362
+ if not isinstance(prompt, str) and not isinstance(prompt, list):
363
+ raise ValueError(
364
+ f"`prompt` has to be of type `str` or `list` but is {type(prompt)}"
365
+ )
366
+
367
+ if height % 8 != 0 or width % 8 != 0:
368
+ raise ValueError(
369
+ f"`height` and `width` have to be divisible by 8 but are {height} and {width}."
370
+ )
371
+
372
+ if (callback_steps is None) or (
373
+ callback_steps is not None
374
+ and (not isinstance(callback_steps, int) or callback_steps <= 0)
375
+ ):
376
+ raise ValueError(
377
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
378
+ f" {type(callback_steps)}."
379
+ )
380
+
381
+ def prepare_latents(
382
+ self,
383
+ batch_size,
384
+ num_channels_latents,
385
+ height,
386
+ width,
387
+ dtype,
388
+ device,
389
+ generator,
390
+ latents=None,
391
+ ):
392
+ shape = (batch_size, num_channels_latents, height // 8, width // 8)
393
+ if latents is None:
394
+ if device.type == "mps":
395
+ # randn does not work reproducibly on mps
396
+ latents = torch.randn(
397
+ shape, generator=generator, device="cpu", dtype=dtype
398
+ ).to(device)
399
+ else:
400
+ latents = torch.randn(
401
+ shape, generator=generator, device=device, dtype=dtype
402
+ )
403
+ else:
404
+ # if latents.shape != shape:
405
+ # raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
406
+ latents = latents.to(device)
407
+
408
+ # scale the initial noise by the standard deviation required by the scheduler
409
+ return latents
410
+
411
+ def preprocess(self, image):
412
+ if isinstance(image, torch.Tensor):
413
+ return image
414
+ elif isinstance(image, PIL.Image.Image):
415
+ image = [image]
416
+
417
+ if isinstance(image[0], PIL.Image.Image):
418
+ w, h = image[0].size
419
+ w, h = map(lambda x: x - x % 8, (w, h)) # resize to integer multiple of 8
420
+
421
+ image = [
422
+ np.array(i.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]))[
423
+ None, :
424
+ ]
425
+ for i in image
426
+ ]
427
+ image = np.concatenate(image, axis=0)
428
+ image = np.array(image).astype(np.float32) / 255.0
429
+ image = image.transpose(0, 3, 1, 2)
430
+ image = 2.0 * image - 1.0
431
+ image = torch.from_numpy(image)
432
+ elif isinstance(image[0], torch.Tensor):
433
+ image = torch.cat(image, dim=0)
434
+ return image
435
+
436
+ @torch.no_grad()
437
+ def img2img(
438
+ self,
439
+ prompt: Union[str, List[str]],
440
+ num_inference_steps: int = 50,
441
+ guidance_scale: float = 7.5,
442
+ negative_prompt: Optional[Union[str, List[str]]] = None,
443
+ generator: Optional[torch.Generator] = None,
444
+ image: Optional[torch.FloatTensor] = None,
445
+ output_type: Optional[str] = "pil",
446
+ latents=None,
447
+ strength=1.0,
448
+ pww_state=None,
449
+ pww_attn_weight=1.0,
450
+ sampler_name="",
451
+ sampler_opt={},
452
+ start_time=-1,
453
+ timeout=180,
454
+ scale_ratio=8.0,
455
+ ):
456
+ sampler = self.get_scheduler(sampler_name)
457
+ if image is not None:
458
+ image = self.preprocess(image)
459
+ image = image.to(self.vae.device, dtype=self.vae.dtype)
460
+
461
+ init_latents = self.vae.encode(image).latent_dist.sample(generator)
462
+ latents = 0.18215 * init_latents
463
+
464
+ # 2. Define call parameters
465
+ batch_size = 1 if isinstance(prompt, str) else len(prompt)
466
+ device = self._execution_device
467
+ latents = latents.to(device, dtype=self.unet.dtype)
468
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
469
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
470
+ # corresponds to doing no classifier free guidance.
471
+ do_classifier_free_guidance = True
472
+ if guidance_scale <= 1.0:
473
+ raise ValueError("has to use guidance_scale")
474
+
475
+ # 3. Encode input prompt
476
+ text_ids, text_embeddings = self.prompt_parser([negative_prompt, prompt])
477
+ text_embeddings = text_embeddings.to(self.unet.dtype)
478
+
479
+ init_timestep = (
480
+ int(num_inference_steps / min(strength, 0.999)) if strength > 0 else 0
481
+ )
482
+ sigmas = self.get_sigmas(init_timestep, sampler_opt).to(
483
+ text_embeddings.device, dtype=text_embeddings.dtype
484
+ )
485
+
486
+ t_start = max(init_timestep - num_inference_steps, 0)
487
+ sigma_sched = sigmas[t_start:]
488
+
489
+ noise = randn_tensor(
490
+ latents.shape,
491
+ generator=generator,
492
+ device=device,
493
+ dtype=text_embeddings.dtype,
494
+ )
495
+ latents = latents.to(device)
496
+ latents = latents + noise * sigma_sched[0]
497
+
498
+ # 5. Prepare latent variables
499
+ self.k_diffusion_model.sigmas = self.k_diffusion_model.sigmas.to(latents.device)
500
+ self.k_diffusion_model.log_sigmas = self.k_diffusion_model.log_sigmas.to(
501
+ latents.device
502
+ )
503
+
504
+ img_state = self.encode_sketchs(
505
+ pww_state,
506
+ g_strength=pww_attn_weight,
507
+ text_ids=text_ids,
508
+ )
509
+
510
+ def model_fn(x, sigma):
511
+
512
+ if start_time > 0 and timeout > 0:
513
+ assert (time.time() - start_time) < timeout, "inference process timed out"
514
+
515
+ latent_model_input = torch.cat([x] * 2)
516
+ weight_func = lambda w, sigma, qk: w * math.log(1 + sigma) * qk.max()
517
+ encoder_state = {
518
+ "img_state": img_state,
519
+ "states": text_embeddings,
520
+ "sigma": sigma[0],
521
+ "weight_func": weight_func,
522
+ }
523
+
524
+ noise_pred = self.k_diffusion_model(
525
+ latent_model_input, sigma, cond=encoder_state
526
+ )
527
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
528
+ noise_pred = noise_pred_uncond + guidance_scale * (
529
+ noise_pred_text - noise_pred_uncond
530
+ )
531
+ return noise_pred
532
+
533
+ sampler_args = self.get_sampler_extra_args_i2i(sigma_sched, sampler)
534
+ latents = sampler(model_fn, latents, **sampler_args)
535
+
536
+ # 8. Post-processing
537
+ image = self.decode_latents(latents)
538
+
539
+ # 10. Convert to PIL
540
+ if output_type == "pil":
541
+ image = self.numpy_to_pil(image)
542
+
543
+ return (image,)
544
+
545
+ def get_sigmas(self, steps, params):
546
+ discard_next_to_last_sigma = params.get("discard_next_to_last_sigma", False)
547
+ steps += 1 if discard_next_to_last_sigma else 0
548
+
549
+ if params.get("scheduler", None) == "karras":
550
+ sigma_min, sigma_max = (
551
+ self.k_diffusion_model.sigmas[0].item(),
552
+ self.k_diffusion_model.sigmas[-1].item(),
553
+ )
554
+ sigmas = k_diffusion.sampling.get_sigmas_karras(
555
+ n=steps, sigma_min=sigma_min, sigma_max=sigma_max, device=self.device
556
+ )
557
+ else:
558
+ sigmas = self.k_diffusion_model.get_sigmas(steps)
559
+
560
+ if discard_next_to_last_sigma:
561
+ sigmas = torch.cat([sigmas[:-2], sigmas[-1:]])
562
+
563
+ return sigmas
564
+
565
+ # https://github.com/AUTOMATIC1111/stable-diffusion-webui/blob/48a15821de768fea76e66f26df83df3fddf18f4b/modules/sd_samplers.py#L454
566
+ def get_sampler_extra_args_t2i(self, sigmas, eta, steps, func):
567
+ extra_params_kwargs = {}
568
+
569
+ if "eta" in inspect.signature(func).parameters:
570
+ extra_params_kwargs["eta"] = eta
571
+
572
+ if "sigma_min" in inspect.signature(func).parameters:
573
+ extra_params_kwargs["sigma_min"] = sigmas[0].item()
574
+ extra_params_kwargs["sigma_max"] = sigmas[-1].item()
575
+
576
+ if "n" in inspect.signature(func).parameters:
577
+ extra_params_kwargs["n"] = steps
578
+ else:
579
+ extra_params_kwargs["sigmas"] = sigmas
580
+
581
+ return extra_params_kwargs
582
+
583
+ # https://github.com/AUTOMATIC1111/stable-diffusion-webui/blob/48a15821de768fea76e66f26df83df3fddf18f4b/modules/sd_samplers.py#L454
584
+ def get_sampler_extra_args_i2i(self, sigmas, func):
585
+ extra_params_kwargs = {}
586
+
587
+ if "sigma_min" in inspect.signature(func).parameters:
588
+ ## last sigma is zero which isn't allowed by DPM Fast & Adaptive so taking value before last
589
+ extra_params_kwargs["sigma_min"] = sigmas[-2]
590
+
591
+ if "sigma_max" in inspect.signature(func).parameters:
592
+ extra_params_kwargs["sigma_max"] = sigmas[0]
593
+
594
+ if "n" in inspect.signature(func).parameters:
595
+ extra_params_kwargs["n"] = len(sigmas) - 1
596
+
597
+ if "sigma_sched" in inspect.signature(func).parameters:
598
+ extra_params_kwargs["sigma_sched"] = sigmas
599
+
600
+ if "sigmas" in inspect.signature(func).parameters:
601
+ extra_params_kwargs["sigmas"] = sigmas
602
+
603
+ return extra_params_kwargs
604
+
605
+ @torch.no_grad()
606
+ def txt2img(
607
+ self,
608
+ prompt: Union[str, List[str]],
609
+ height: int = 512,
610
+ width: int = 512,
611
+ num_inference_steps: int = 50,
612
+ guidance_scale: float = 7.5,
613
+ negative_prompt: Optional[Union[str, List[str]]] = None,
614
+ eta: float = 0.0,
615
+ generator: Optional[torch.Generator] = None,
616
+ latents: Optional[torch.FloatTensor] = None,
617
+ output_type: Optional[str] = "pil",
618
+ callback_steps: Optional[int] = 1,
619
+ upscale=False,
620
+ upscale_x: float = 2.0,
621
+ upscale_method: str = "bicubic",
622
+ upscale_antialias: bool = False,
623
+ upscale_denoising_strength: float = 0.7,
624
+ pww_state=None,
625
+ pww_attn_weight=1.0,
626
+ sampler_name="",
627
+ sampler_opt={},
628
+ start_time=-1,
629
+ timeout=180,
630
+ ):
631
+ sampler = self.get_scheduler(sampler_name)
632
+ # 1. Check inputs. Raise error if not correct
633
+ self.check_inputs(prompt, height, width, callback_steps)
634
+
635
+ # 2. Define call parameters
636
+ batch_size = 1 if isinstance(prompt, str) else len(prompt)
637
+ device = self._execution_device
638
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
639
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
640
+ # corresponds to doing no classifier free guidance.
641
+ do_classifier_free_guidance = True
642
+ if guidance_scale <= 1.0:
643
+ raise ValueError("guidance_scale must be greater than 1")
644
+
645
+ # 3. Encode input prompt
646
+ text_ids, text_embeddings = self.prompt_parser([negative_prompt, prompt])
647
+ text_embeddings = text_embeddings.to(self.unet.dtype)
648
+
649
+ # 4. Prepare timesteps
650
+ sigmas = self.get_sigmas(num_inference_steps, sampler_opt).to(
651
+ text_embeddings.device, dtype=text_embeddings.dtype
652
+ )
653
+
654
+ # 5. Prepare latent variables
655
+ num_channels_latents = self.unet.in_channels
656
+ latents = self.prepare_latents(
657
+ batch_size,
658
+ num_channels_latents,
659
+ height,
660
+ width,
661
+ text_embeddings.dtype,
662
+ device,
663
+ generator,
664
+ latents,
665
+ )
666
+ latents = latents * sigmas[0]
667
+ self.k_diffusion_model.sigmas = self.k_diffusion_model.sigmas.to(latents.device)
668
+ self.k_diffusion_model.log_sigmas = self.k_diffusion_model.log_sigmas.to(
669
+ latents.device
670
+ )
671
+
672
+ img_state = self.encode_sketchs(
673
+ pww_state,
674
+ g_strength=pww_attn_weight,
675
+ text_ids=text_ids,
676
+ )
677
+
678
+ def model_fn(x, sigma):
679
+
680
+ if start_time > 0 and timeout > 0:
681
+ assert (time.time() - start_time) < timeout, "inference process timed out"
682
+
683
+ latent_model_input = torch.cat([x] * 2)
684
+ weight_func = lambda w, sigma, qk: w * math.log(1 + sigma) * qk.max()
685
+ encoder_state = {
686
+ "img_state": img_state,
687
+ "states": text_embeddings,
688
+ "sigma": sigma[0],
689
+ "weight_func": weight_func,
690
+ }
691
+
692
+ noise_pred = self.k_diffusion_model(
693
+ latent_model_input, sigma, cond=encoder_state
694
+ )
695
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
696
+ noise_pred = noise_pred_uncond + guidance_scale * (
697
+ noise_pred_text - noise_pred_uncond
698
+ )
699
+ return noise_pred
700
+
701
+ extra_args = self.get_sampler_extra_args_t2i(
702
+ sigmas, eta, num_inference_steps, sampler
703
+ )
704
+ latents = sampler(model_fn, latents, **extra_args)
705
+
706
+ if upscale:
707
+ target_height = height * upscale_x
708
+ target_width = width * upscale_x
709
+ vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
710
+ latents = torch.nn.functional.interpolate(
711
+ latents,
712
+ size=(
713
+ int(target_height // vae_scale_factor),
714
+ int(target_width // vae_scale_factor),
715
+ ),
716
+ mode=upscale_method,
717
+ antialias=upscale_antialias,
718
+ )
719
+ return self.img2img(
720
+ prompt=prompt,
721
+ num_inference_steps=num_inference_steps,
722
+ guidance_scale=guidance_scale,
723
+ negative_prompt=negative_prompt,
724
+ generator=generator,
725
+ latents=latents,
726
+ strength=upscale_denoising_strength,
727
+ sampler_name=sampler_name,
728
+ sampler_opt=sampler_opt,
729
+ pww_state=None,
730
+ pww_attn_weight=pww_attn_weight / 2,
731
+ )
732
+
733
+ # 6. Post-processing
734
+ image = self.decode_latents(latents)
735
+
736
+ # 7. Convert to PIL
737
+ if output_type == "pil":
738
+ image = self.numpy_to_pil(image)
739
+
740
+ return (image,)
741
+
742
+
743
+ class FlashAttentionFunction(Function):
744
+ @staticmethod
745
+ @torch.no_grad()
746
+ def forward(ctx, q, k, v, mask, causal, q_bucket_size, k_bucket_size):
747
+ """Algorithm 2 in the paper"""
748
+
749
+ device = q.device
750
+ max_neg_value = -torch.finfo(q.dtype).max
751
+ qk_len_diff = max(k.shape[-2] - q.shape[-2], 0)
752
+
753
+ o = torch.zeros_like(q)
754
+ all_row_sums = torch.zeros((*q.shape[:-1], 1), device=device)
755
+ all_row_maxes = torch.full((*q.shape[:-1], 1), max_neg_value, device=device)
756
+
757
+ scale = q.shape[-1] ** -0.5
758
+
759
+ if not exists(mask):
760
+ mask = (None,) * math.ceil(q.shape[-2] / q_bucket_size)
761
+ else:
762
+ mask = rearrange(mask, "b n -> b 1 1 n")
763
+ mask = mask.split(q_bucket_size, dim=-1)
764
+
765
+ row_splits = zip(
766
+ q.split(q_bucket_size, dim=-2),
767
+ o.split(q_bucket_size, dim=-2),
768
+ mask,
769
+ all_row_sums.split(q_bucket_size, dim=-2),
770
+ all_row_maxes.split(q_bucket_size, dim=-2),
771
+ )
772
+
773
+ for ind, (qc, oc, row_mask, row_sums, row_maxes) in enumerate(row_splits):
774
+ q_start_index = ind * q_bucket_size - qk_len_diff
775
+
776
+ col_splits = zip(
777
+ k.split(k_bucket_size, dim=-2),
778
+ v.split(k_bucket_size, dim=-2),
779
+ )
780
+
781
+ for k_ind, (kc, vc) in enumerate(col_splits):
782
+ k_start_index = k_ind * k_bucket_size
783
+
784
+ attn_weights = einsum("... i d, ... j d -> ... i j", qc, kc) * scale
785
+
786
+ if exists(row_mask):
787
+ attn_weights.masked_fill_(~row_mask, max_neg_value)
788
+
789
+ if causal and q_start_index < (k_start_index + k_bucket_size - 1):
790
+ causal_mask = torch.ones(
791
+ (qc.shape[-2], kc.shape[-2]), dtype=torch.bool, device=device
792
+ ).triu(q_start_index - k_start_index + 1)
793
+ attn_weights.masked_fill_(causal_mask, max_neg_value)
794
+
795
+ block_row_maxes = attn_weights.amax(dim=-1, keepdims=True)
796
+ attn_weights -= block_row_maxes
797
+ exp_weights = torch.exp(attn_weights)
798
+
799
+ if exists(row_mask):
800
+ exp_weights.masked_fill_(~row_mask, 0.0)
801
+
802
+ block_row_sums = exp_weights.sum(dim=-1, keepdims=True).clamp(
803
+ min=EPSILON
804
+ )
805
+
806
+ new_row_maxes = torch.maximum(block_row_maxes, row_maxes)
807
+
808
+ exp_values = einsum("... i j, ... j d -> ... i d", exp_weights, vc)
809
+
810
+ exp_row_max_diff = torch.exp(row_maxes - new_row_maxes)
811
+ exp_block_row_max_diff = torch.exp(block_row_maxes - new_row_maxes)
812
+
813
+ new_row_sums = (
814
+ exp_row_max_diff * row_sums
815
+ + exp_block_row_max_diff * block_row_sums
816
+ )
817
+
818
+ oc.mul_((row_sums / new_row_sums) * exp_row_max_diff).add_(
819
+ (exp_block_row_max_diff / new_row_sums) * exp_values
820
+ )
821
+
822
+ row_maxes.copy_(new_row_maxes)
823
+ row_sums.copy_(new_row_sums)
824
+
825
+ lse = all_row_sums.log() + all_row_maxes
826
+
827
+ ctx.args = (causal, scale, mask, q_bucket_size, k_bucket_size)
828
+ ctx.save_for_backward(q, k, v, o, lse)
829
+
830
+ return o
831
+
832
+ @staticmethod
833
+ @torch.no_grad()
834
+ def backward(ctx, do):
835
+ """Algorithm 4 in the paper"""
836
+
837
+ causal, scale, mask, q_bucket_size, k_bucket_size = ctx.args
838
+ q, k, v, o, lse = ctx.saved_tensors
839
+
840
+ device = q.device
841
+
842
+ max_neg_value = -torch.finfo(q.dtype).max
843
+ qk_len_diff = max(k.shape[-2] - q.shape[-2], 0)
844
+
845
+ dq = torch.zeros_like(q)
846
+ dk = torch.zeros_like(k)
847
+ dv = torch.zeros_like(v)
848
+
849
+ row_splits = zip(
850
+ q.split(q_bucket_size, dim=-2),
851
+ o.split(q_bucket_size, dim=-2),
852
+ do.split(q_bucket_size, dim=-2),
853
+ mask,
854
+ lse.split(q_bucket_size, dim=-2),
855
+ dq.split(q_bucket_size, dim=-2),
856
+ )
857
+
858
+ for ind, (qc, oc, doc, row_mask, lsec, dqc) in enumerate(row_splits):
859
+ q_start_index = ind * q_bucket_size - qk_len_diff
860
+
861
+ col_splits = zip(
862
+ k.split(k_bucket_size, dim=-2),
863
+ v.split(k_bucket_size, dim=-2),
864
+ dk.split(k_bucket_size, dim=-2),
865
+ dv.split(k_bucket_size, dim=-2),
866
+ )
867
+
868
+ for k_ind, (kc, vc, dkc, dvc) in enumerate(col_splits):
869
+ k_start_index = k_ind * k_bucket_size
870
+
871
+ attn_weights = einsum("... i d, ... j d -> ... i j", qc, kc) * scale
872
+
873
+ if causal and q_start_index < (k_start_index + k_bucket_size - 1):
874
+ causal_mask = torch.ones(
875
+ (qc.shape[-2], kc.shape[-2]), dtype=torch.bool, device=device
876
+ ).triu(q_start_index - k_start_index + 1)
877
+ attn_weights.masked_fill_(causal_mask, max_neg_value)
878
+
879
+ p = torch.exp(attn_weights - lsec)
880
+
881
+ if exists(row_mask):
882
+ p.masked_fill_(~row_mask, 0.0)
883
+
884
+ dv_chunk = einsum("... i j, ... i d -> ... j d", p, doc)
885
+ dp = einsum("... i d, ... j d -> ... i j", doc, vc)
886
+
887
+ D = (doc * oc).sum(dim=-1, keepdims=True)
888
+ ds = p * scale * (dp - D)
889
+
890
+ dq_chunk = einsum("... i j, ... j d -> ... i d", ds, kc)
891
+ dk_chunk = einsum("... i j, ... i d -> ... j d", ds, qc)
892
+
893
+ dqc.add_(dq_chunk)
894
+ dkc.add_(dk_chunk)
895
+ dvc.add_(dv_chunk)
896
+
897
+ return dq, dk, dv, None, None, None, None
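The `model_fn` closures in `txt2img` and `img2img` above implement standard classifier-free guidance over the k-diffusion-wrapped UNet: the latent batch is duplicated, the unconditional and text-conditioned predictions are split apart, and the final prediction extrapolates away from the unconditional one. Below is a minimal sketch of just that combination step, using random placeholder tensors instead of the pipeline's real UNet output.

```python
# Sketch of the classifier-free guidance mix used inside model_fn above.
# The tensors are random placeholders; in the pipeline, noise_pred comes from
# self.k_diffusion_model(torch.cat([x] * 2), sigma, cond=encoder_state).
import torch

guidance_scale = 7.5  # same default as txt2img above

noise_pred = torch.randn(2, 4, 64, 64)  # [uncond, text] stacked on the batch dim
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
guided = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
print(guided.shape)  # torch.Size([1, 4, 64, 64])
```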
modules/prompt_parser.py ADDED
@@ -0,0 +1,391 @@
1
+
2
+ import re
3
+ import math
4
+ import numpy as np
5
+ import torch
6
+
7
+ # Code from https://github.com/AUTOMATIC1111/stable-diffusion-webui/commit/8e2aeee4a127b295bfc880800e4a312e0f049b85, modified.
8
+
9
+ class PromptChunk:
10
+ """
11
+ This object contains token ids, token weights (emphasis multipliers such as 1.4) and textual inversion embedding info for a chunk of prompt.
12
+ If a prompt is short, it is represented by one PromptChunk, otherwise, multiple are necessary.
13
+ Each PromptChunk contains exactly 77 tokens, including one each for the start and end token,
14
+ so just 75 of them come from the prompt.
15
+ """
16
+
17
+ def __init__(self):
18
+ self.tokens = []
19
+ self.multipliers = []
20
+ self.fixes = []
21
+
22
+
23
+ class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module):
24
+ """A PyTorch module that wraps the FrozenCLIPEmbedder module. It enhances FrozenCLIPEmbedder, making it possible to
25
+ have unlimited prompt length and to assign weights to tokens in the prompt.
26
+ """
27
+
28
+ def __init__(self, text_encoder, enable_emphasis=True):
29
+ super().__init__()
30
+
31
+ self.device = lambda: text_encoder.device
32
+ self.enable_emphasis = enable_emphasis
33
+ """Original FrozenCLIPEmbedder module; can also be FrozenOpenCLIPEmbedder or xlmr.BertSeriesModelWithTransformation,
34
+ depending on model."""
35
+
36
+ self.chunk_length = 75
37
+
38
+ def empty_chunk(self):
39
+ """creates an empty PromptChunk and returns it"""
40
+
41
+ chunk = PromptChunk()
42
+ chunk.tokens = [self.id_start] + [self.id_end] * (self.chunk_length + 1)
43
+ chunk.multipliers = [1.0] * (self.chunk_length + 2)
44
+ return chunk
45
+
46
+ def get_target_prompt_token_count(self, token_count):
47
+ """returns the maximum number of tokens a prompt of a known length can have before it requires one more PromptChunk to be represented"""
48
+
49
+ return math.ceil(max(token_count, 1) / self.chunk_length) * self.chunk_length
50
+
51
+ def tokenize_line(self, line):
52
+ """
53
+ this transforms a single prompt into a list of PromptChunk objects - as many as needed to
54
+ represent the prompt.
55
+ Returns the list and the total number of tokens in the prompt.
56
+ """
57
+
58
+ if self.enable_emphasis:
59
+ parsed = parse_prompt_attention(line)
60
+ else:
61
+ parsed = [[line, 1.0]]
62
+
63
+ tokenized = self.tokenize([text for text, _ in parsed])
64
+
65
+ chunks = []
66
+ chunk = PromptChunk()
67
+ token_count = 0
68
+ last_comma = -1
69
+
70
+ def next_chunk(is_last=False):
71
+ """puts current chunk into the list of results and produces the next one - empty;
72
+ if is_last is true, the <end-of-text> tokens at the end won't add to token_count"""
73
+ nonlocal token_count
74
+ nonlocal last_comma
75
+ nonlocal chunk
76
+
77
+ if is_last:
78
+ token_count += len(chunk.tokens)
79
+ else:
80
+ token_count += self.chunk_length
81
+
82
+ to_add = self.chunk_length - len(chunk.tokens)
83
+ if to_add > 0:
84
+ chunk.tokens += [self.id_end] * to_add
85
+ chunk.multipliers += [1.0] * to_add
86
+
87
+ chunk.tokens = [self.id_start] + chunk.tokens + [self.id_end]
88
+ chunk.multipliers = [1.0] + chunk.multipliers + [1.0]
89
+
90
+ last_comma = -1
91
+ chunks.append(chunk)
92
+ chunk = PromptChunk()
93
+
94
+ comma_padding_backtrack = 20 # default value in https://github.com/AUTOMATIC1111/stable-diffusion-webui/blob/6cff4401824299a983c8e13424018efc347b4a2b/modules/shared.py#L410
95
+ for tokens, (text, weight) in zip(tokenized, parsed):
96
+ if text == "BREAK" and weight == -1:
97
+ next_chunk()
98
+ continue
99
+
100
+ position = 0
101
+ while position < len(tokens):
102
+ token = tokens[position]
103
+
104
+ if token == self.comma_token:
105
+ last_comma = len(chunk.tokens)
106
+
107
+ # this is when we are at the end of the allotted 75 tokens for the current chunk, and the current token is not a comma. comma_padding_backtrack
108
+ # is a setting that specifies that if there is a comma nearby, the text after the comma should be moved out of this chunk and into the next.
109
+ elif (
110
+ comma_padding_backtrack != 0
111
+ and len(chunk.tokens) == self.chunk_length
112
+ and last_comma != -1
113
+ and len(chunk.tokens) - last_comma <= comma_padding_backtrack
114
+ ):
115
+ break_location = last_comma + 1
116
+
117
+ reloc_tokens = chunk.tokens[break_location:]
118
+ reloc_mults = chunk.multipliers[break_location:]
119
+
120
+ chunk.tokens = chunk.tokens[:break_location]
121
+ chunk.multipliers = chunk.multipliers[:break_location]
122
+
123
+ next_chunk()
124
+ chunk.tokens = reloc_tokens
125
+ chunk.multipliers = reloc_mults
126
+
127
+ if len(chunk.tokens) == self.chunk_length:
128
+ next_chunk()
129
+
130
+ chunk.tokens.append(token)
131
+ chunk.multipliers.append(weight)
132
+ position += 1
133
+
134
+ if len(chunk.tokens) > 0 or len(chunks) == 0:
135
+ next_chunk(is_last=True)
136
+
137
+ return chunks, token_count
138
+
139
+ def process_texts(self, texts):
140
+ """
141
+ Accepts a list of texts and calls tokenize_line() on each, with cache. Returns the list of results and maximum
142
+ length, in tokens, of all texts.
143
+ """
144
+
145
+ token_count = 0
146
+
147
+ cache = {}
148
+ batch_chunks = []
149
+ for line in texts:
150
+ if line in cache:
151
+ chunks = cache[line]
152
+ else:
153
+ chunks, current_token_count = self.tokenize_line(line)
154
+ token_count = max(current_token_count, token_count)
155
+
156
+ cache[line] = chunks
157
+
158
+ batch_chunks.append(chunks)
159
+
160
+ return batch_chunks, token_count
161
+
162
+ def forward(self, texts):
163
+ """
164
+ Accepts an array of texts; passes them through the transformers network to create a tensor with a numerical representation of those texts.
165
+ Returns the token ids and a tensor with shape of (B, T, C), where B is length of the array; T is length, in tokens, of texts (including padding) - T will
166
+ be a multiple of 77; and C is dimensionality of each token - for SD1 it's 768, and for SD2 it's 1024.
167
+ An example shape returned by this function can be: (2, 77, 768).
168
+ Webui usually sends just one text at a time through this function - the only time when texts is an array with more than one element
169
+ is when you do prompt editing: "a picture of a [cat:dog:0.4] eating ice cream"
170
+ """
171
+
172
+ batch_chunks, token_count = self.process_texts(texts)
173
+ chunk_count = max([len(x) for x in batch_chunks])
174
+
175
+ zs = []
176
+ ts = []
177
+ for i in range(chunk_count):
178
+ batch_chunk = [
179
+ chunks[i] if i < len(chunks) else self.empty_chunk()
180
+ for chunks in batch_chunks
181
+ ]
182
+
183
+ tokens = [x.tokens for x in batch_chunk]
184
+ multipliers = [x.multipliers for x in batch_chunk]
185
+ # self.embeddings.fixes = [x.fixes for x in batch_chunk]
186
+
187
+ # for fixes in self.embeddings.fixes:
188
+ # for position, embedding in fixes:
189
+ # used_embeddings[embedding.name] = embedding
190
+
191
+ z = self.process_tokens(tokens, multipliers)
192
+ zs.append(z)
193
+ ts.append(tokens)
194
+
195
+ return np.hstack(ts), torch.hstack(zs)
196
+
197
+ def process_tokens(self, remade_batch_tokens, batch_multipliers):
198
+ """
199
+ sends a single prompt chunk to be encoded by the transformers neural network.
200
+ remade_batch_tokens is a batch of tokens - a list, where every element is a list of tokens; usually
201
+ there are exactly 77 tokens in the list. batch_multipliers is the same but for multipliers instead of tokens.
202
+ Multipliers are used to give more or less weight to the outputs of transformers network. Each multiplier
203
+ corresponds to one token.
204
+ """
205
+ tokens = torch.asarray(remade_batch_tokens).to(self.device())
206
+
207
+ # this is for SD2: SD1 uses the same token for padding and end of text, while SD2 uses different ones.
208
+ if self.id_end != self.id_pad:
209
+ for batch_pos in range(len(remade_batch_tokens)):
210
+ index = remade_batch_tokens[batch_pos].index(self.id_end)
211
+ tokens[batch_pos, index + 1 : tokens.shape[1]] = self.id_pad
212
+
213
+ z = self.encode_with_transformers(tokens)
214
+
215
+ # restoring original mean is likely not correct, but it seems to work well to prevent artifacts that happen otherwise
216
+ batch_multipliers = torch.asarray(batch_multipliers).to(self.device())
217
+ original_mean = z.mean()
218
+ z = z * batch_multipliers.reshape(batch_multipliers.shape + (1,)).expand(z.shape)
219
+ new_mean = z.mean()
220
+ z = z * (original_mean / new_mean)
221
+
222
+ return z
223
+
224
+
225
+ class FrozenCLIPEmbedderWithCustomWords(FrozenCLIPEmbedderWithCustomWordsBase):
226
+ def __init__(self, tokenizer, text_encoder):
227
+ super().__init__(text_encoder)
228
+ self.tokenizer = tokenizer
229
+ self.text_encoder = text_encoder
230
+
231
+ vocab = self.tokenizer.get_vocab()
232
+
233
+ self.comma_token = vocab.get(",</w>", None)
234
+
235
+ self.token_mults = {}
236
+ tokens_with_parens = [
237
+ (k, v)
238
+ for k, v in vocab.items()
239
+ if "(" in k or ")" in k or "[" in k or "]" in k
240
+ ]
241
+ for text, ident in tokens_with_parens:
242
+ mult = 1.0
243
+ for c in text:
244
+ if c == "[":
245
+ mult /= 1.1
246
+ if c == "]":
247
+ mult *= 1.1
248
+ if c == "(":
249
+ mult *= 1.1
250
+ if c == ")":
251
+ mult /= 1.1
252
+
253
+ if mult != 1.0:
254
+ self.token_mults[ident] = mult
255
+
256
+ self.id_start = self.tokenizer.bos_token_id
257
+ self.id_end = self.tokenizer.eos_token_id
258
+ self.id_pad = self.id_end
259
+
260
+ def tokenize(self, texts):
261
+ tokenized = self.tokenizer(
262
+ texts, truncation=False, add_special_tokens=False
263
+ )["input_ids"]
264
+
265
+ return tokenized
266
+
267
+ def encode_with_transformers(self, tokens):
268
+ CLIP_stop_at_last_layers = 1
269
+ tokens = tokens.to(self.text_encoder.device)
270
+ outputs = self.text_encoder(tokens, output_hidden_states=True)
271
+
272
+ if CLIP_stop_at_last_layers > 1:
273
+ z = outputs.hidden_states[-CLIP_stop_at_last_layers]
274
+ z = self.text_encoder.text_model.final_layer_norm(z)
275
+ else:
276
+ z = outputs.last_hidden_state
277
+
278
+ return z
279
+
280
+
281
+ re_attention = re.compile(
282
+ r"""
283
+ \\\(|
284
+ \\\)|
285
+ \\\[|
286
+ \\]|
287
+ \\\\|
288
+ \\|
289
+ \(|
290
+ \[|
291
+ :([+-]?[.\d]+)\)|
292
+ \)|
293
+ ]|
294
+ [^\\()\[\]:]+|
295
+ :
296
+ """,
297
+ re.X,
298
+ )
299
+
300
+ re_break = re.compile(r"\s*\bBREAK\b\s*", re.S)
301
+
302
+
303
+ def parse_prompt_attention(text):
304
+ """
305
+ Parses a string with attention tokens and returns a list of pairs: text and its associated weight.
306
+ Accepted tokens are:
307
+ (abc) - increases attention to abc by a multiplier of 1.1
308
+ (abc:3.12) - increases attention to abc by a multiplier of 3.12
309
+ [abc] - decreases attention to abc by a multiplier of 1.1
310
+ \( - literal character '('
311
+ \[ - literal character '['
312
+ \) - literal character ')'
313
+ \] - literal character ']'
314
+ \\ - literal character '\'
315
+ anything else - just text
316
+
317
+ >>> parse_prompt_attention('normal text')
318
+ [['normal text', 1.0]]
319
+ >>> parse_prompt_attention('an (important) word')
320
+ [['an ', 1.0], ['important', 1.1], [' word', 1.0]]
321
+ >>> parse_prompt_attention('(unbalanced')
322
+ [['unbalanced', 1.1]]
323
+ >>> parse_prompt_attention('\(literal\]')
324
+ [['(literal]', 1.0]]
325
+ >>> parse_prompt_attention('(unnecessary)(parens)')
326
+ [['unnecessaryparens', 1.1]]
327
+ >>> parse_prompt_attention('a (((house:1.3)) [on] a (hill:0.5), sun, (((sky))).')
328
+ [['a ', 1.0],
329
+ ['house', 1.5730000000000004],
330
+ [' ', 1.1],
331
+ ['on', 1.0],
332
+ [' a ', 1.1],
333
+ ['hill', 0.55],
334
+ [', sun, ', 1.1],
335
+ ['sky', 1.4641000000000006],
336
+ ['.', 1.1]]
337
+ """
338
+
339
+ res = []
340
+ round_brackets = []
341
+ square_brackets = []
342
+
343
+ round_bracket_multiplier = 1.1
344
+ square_bracket_multiplier = 1 / 1.1
345
+
346
+ def multiply_range(start_position, multiplier):
347
+ for p in range(start_position, len(res)):
348
+ res[p][1] *= multiplier
349
+
350
+ for m in re_attention.finditer(text):
351
+ text = m.group(0)
352
+ weight = m.group(1)
353
+
354
+ if text.startswith("\\"):
355
+ res.append([text[1:], 1.0])
356
+ elif text == "(":
357
+ round_brackets.append(len(res))
358
+ elif text == "[":
359
+ square_brackets.append(len(res))
360
+ elif weight is not None and len(round_brackets) > 0:
361
+ multiply_range(round_brackets.pop(), float(weight))
362
+ elif text == ")" and len(round_brackets) > 0:
363
+ multiply_range(round_brackets.pop(), round_bracket_multiplier)
364
+ elif text == "]" and len(square_brackets) > 0:
365
+ multiply_range(square_brackets.pop(), square_bracket_multiplier)
366
+ else:
367
+ parts = re.split(re_break, text)
368
+ for i, part in enumerate(parts):
369
+ if i > 0:
370
+ res.append(["BREAK", -1])
371
+ res.append([part, 1.0])
372
+
373
+ for pos in round_brackets:
374
+ multiply_range(pos, round_bracket_multiplier)
375
+
376
+ for pos in square_brackets:
377
+ multiply_range(pos, square_bracket_multiplier)
378
+
379
+ if len(res) == 0:
380
+ res = [["", 1.0]]
381
+
382
+ # merge runs of identical weights
383
+ i = 0
384
+ while i + 1 < len(res):
385
+ if res[i][1] == res[i + 1][1]:
386
+ res[i][0] += res[i + 1][0]
387
+ res.pop(i + 1)
388
+ else:
389
+ i += 1
390
+
391
+ return res
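`parse_prompt_attention` is what gives `(word)`, `[word]` and `(word:1.3)` their emphasis semantics before tokenization: each `(` multiplies the enclosed span's weight by 1.1, each `[` divides it by 1.1, and an explicit `:N` weight overrides the default. A small usage sketch follows; the prompt string is illustrative only, and the expected output follows from the multiplier rules in the doctests above.

```python
# Illustrative call; output values follow the 1.1 / (1/1.1) multiplier rules above.
from modules.prompt_parser import parse_prompt_attention

print(parse_prompt_attention("a (red:1.3) bird on a [branch]"))
# Expected output:
# [['a ', 1.0], ['red', 1.3], [' bird on a ', 1.0], ['branch', 0.9090909090909091]]
```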
modules/safe.py ADDED
@@ -0,0 +1,188 @@
1
+ # this code is adapted from the script contributed by anon from /h/
2
+ # modified, from https://github.com/AUTOMATIC1111/stable-diffusion-webui/blob/6cff4401824299a983c8e13424018efc347b4a2b/modules/safe.py
3
+
4
+ import io
5
+ import pickle
6
+ import collections
7
+ import sys
8
+ import traceback
9
+
10
+ import torch
11
+ import numpy
12
+ import _codecs
13
+ import zipfile
14
+ import re
15
+
16
+
17
+ # PyTorch 1.13 and later have _TypedStorage renamed to TypedStorage
18
+ TypedStorage = torch.storage.TypedStorage if hasattr(torch.storage, 'TypedStorage') else torch.storage._TypedStorage
19
+
20
+
21
+ def encode(*args):
22
+ out = _codecs.encode(*args)
23
+ return out
24
+
25
+
26
+ class RestrictedUnpickler(pickle.Unpickler):
27
+ extra_handler = None
28
+
29
+ def persistent_load(self, saved_id):
30
+ assert saved_id[0] == 'storage'
31
+ return TypedStorage()
32
+
33
+ def find_class(self, module, name):
34
+ if self.extra_handler is not None:
35
+ res = self.extra_handler(module, name)
36
+ if res is not None:
37
+ return res
38
+
39
+ if module == 'collections' and name == 'OrderedDict':
40
+ return getattr(collections, name)
41
+ if module == 'torch._utils' and name in ['_rebuild_tensor_v2', '_rebuild_parameter', '_rebuild_device_tensor_from_numpy']:
42
+ return getattr(torch._utils, name)
43
+ if module == 'torch' and name in ['FloatStorage', 'HalfStorage', 'IntStorage', 'LongStorage', 'DoubleStorage', 'ByteStorage', 'float32']:
44
+ return getattr(torch, name)
45
+ if module == 'torch.nn.modules.container' and name in ['ParameterDict']:
46
+ return getattr(torch.nn.modules.container, name)
47
+ if module == 'numpy.core.multiarray' and name in ['scalar', '_reconstruct']:
48
+ return getattr(numpy.core.multiarray, name)
49
+ if module == 'numpy' and name in ['dtype', 'ndarray']:
50
+ return getattr(numpy, name)
51
+ if module == '_codecs' and name == 'encode':
52
+ return encode
53
+ if module == "pytorch_lightning.callbacks" and name == 'model_checkpoint':
54
+ import pytorch_lightning.callbacks
55
+ return pytorch_lightning.callbacks.model_checkpoint
56
+ if module == "pytorch_lightning.callbacks.model_checkpoint" and name == 'ModelCheckpoint':
57
+ import pytorch_lightning.callbacks.model_checkpoint
58
+ return pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint
59
+ if module == "__builtin__" and name == 'set':
60
+ return set
61
+
62
+ # Forbid everything else.
63
+ raise Exception(f"global '{module}/{name}' is forbidden")
64
+
65
+
66
+ # Regular expression that accepts 'dirname/version', 'dirname/data.pkl', and 'dirname/data/<number>'
67
+ allowed_zip_names_re = re.compile(r"^([^/]+)/((data/\d+)|version|(data\.pkl))$")
68
+ data_pkl_re = re.compile(r"^([^/]+)/data\.pkl$")
69
+
70
+ def check_zip_filenames(filename, names):
71
+ for name in names:
72
+ if allowed_zip_names_re.match(name):
73
+ continue
74
+
75
+ raise Exception(f"bad file inside {filename}: {name}")
76
+
77
+
78
+ def check_pt(filename, extra_handler):
79
+ try:
80
+
81
+ # new pytorch format is a zip file
82
+ with zipfile.ZipFile(filename) as z:
83
+ check_zip_filenames(filename, z.namelist())
84
+
85
+ # find filename of data.pkl in zip file: '<directory name>/data.pkl'
86
+ data_pkl_filenames = [f for f in z.namelist() if data_pkl_re.match(f)]
87
+ if len(data_pkl_filenames) == 0:
88
+ raise Exception(f"data.pkl not found in {filename}")
89
+ if len(data_pkl_filenames) > 1:
90
+ raise Exception(f"Multiple data.pkl found in {filename}")
91
+ with z.open(data_pkl_filenames[0]) as file:
92
+ unpickler = RestrictedUnpickler(file)
93
+ unpickler.extra_handler = extra_handler
94
+ unpickler.load()
95
+
96
+ except zipfile.BadZipfile:
97
+
98
+ # if it's not a zip file, it's an olf pytorch format, with five objects written to pickle
99
+ with open(filename, "rb") as file:
100
+ unpickler = RestrictedUnpickler(file)
101
+ unpickler.extra_handler = extra_handler
102
+ for i in range(5):
103
+ unpickler.load()
104
+
105
+
106
+ def load(filename, *args, **kwargs):
107
+ return load_with_extra(filename, extra_handler=global_extra_handler, *args, **kwargs)
108
+
109
+
110
+ def load_with_extra(filename, extra_handler=None, *args, **kwargs):
111
+ """
112
+ this function is intended to be used by extensions that want to load models with
113
+ some extra classes in them that the usual unpickler would find suspicious.
114
+
115
+ Use the extra_handler argument to specify a function that takes module and field name as text,
116
+ and returns that field's value:
117
+
118
+ ```python
119
+ def extra(module, name):
120
+ if module == 'collections' and name == 'OrderedDict':
121
+ return collections.OrderedDict
122
+
123
+ return None
124
+
125
+ safe.load_with_extra('model.pt', extra_handler=extra)
126
+ ```
127
+
128
+ The alternative to this is just to use safe.unsafe_torch_load('model.pt'), which as the name implies is
129
+ definitely unsafe.
130
+ """
131
+
132
+ try:
133
+ check_pt(filename, extra_handler)
134
+
135
+ except pickle.UnpicklingError:
136
+ print(f"Error verifying pickled file from {filename}:", file=sys.stderr)
137
+ print(traceback.format_exc(), file=sys.stderr)
138
+ print("The file is most likely corrupted.", file=sys.stderr)
139
+ return None
140
+
141
+ except Exception:
142
+ print(f"Error verifying pickled file from {filename}:", file=sys.stderr)
143
+ print(traceback.format_exc(), file=sys.stderr)
144
+ print("\nThe file may be malicious, so the program is not going to read it.", file=sys.stderr)
145
+ print("You can skip this check with --disable-safe-unpickle commandline argument.\n\n", file=sys.stderr)
146
+ return None
147
+
148
+ return unsafe_torch_load(filename, *args, **kwargs)
149
+
150
+
151
+ class Extra:
152
+ """
153
+ A class for temporarily setting the global handler for when you can't explicitly call load_with_extra
154
+ (because it's not your code making the torch.load call). The intended use is like this:
155
+
156
+ ```
157
+ import torch
158
+ from modules import safe
159
+
160
+ def handler(module, name):
161
+ if module == 'torch' and name in ['float64', 'float16']:
162
+ return getattr(torch, name)
163
+
164
+ return None
165
+
166
+ with safe.Extra(handler):
167
+ x = torch.load('model.pt')
168
+ ```
169
+ """
170
+
171
+ def __init__(self, handler):
172
+ self.handler = handler
173
+
174
+ def __enter__(self):
175
+ global global_extra_handler
176
+
177
+ assert global_extra_handler is None, 'already inside an Extra() block'
178
+ global_extra_handler = self.handler
179
+
180
+ def __exit__(self, exc_type, exc_val, exc_tb):
181
+ global global_extra_handler
182
+
183
+ global_extra_handler = None
184
+
185
+
186
+ unsafe_torch_load = torch.load
187
+ torch.load = load
188
+ global_extra_handler = None
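Because the last two lines replace `torch.load` with the restricted loader, simply importing this module (as `app.py` does with `import modules.safe as _`) is enough to route subsequent checkpoint loads through the unpickling whitelist. Below is a minimal sketch of the `Extra` pattern from the docstring above; the `model.pt` path is a placeholder and the handler mirrors the docstring's illustrative whitelist.

```python
# 'model.pt' is a placeholder path; the handler mirrors the Extra docstring above.
import torch
import modules.safe as safe  # importing this patches torch.load with the restricted loader

def handler(module, name):
    # whitelist a couple of extra torch globals that the restricted unpickler would reject
    if module == 'torch' and name in ['float64', 'float16']:
        return getattr(torch, name)
    return None

with safe.Extra(handler):
    state_dict = torch.load('model.pt', map_location='cpu')
```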