Spaces: Running on T4
Create app.py
app.py ADDED
@@ -0,0 +1,754 @@
import random
import tempfile
import time
import gradio as gr
import numpy as np
import torch

from gradio import inputs
from diffusers import (
    AutoencoderKL,
    DDIMScheduler,
    UNet2DConditionModel,
)
from modules.model_pww import CrossAttnProcessor, StableDiffusionPipeline, load_lora_attn_procs
from torchvision import transforms
from transformers import CLIPTokenizer, CLIPTextModel
from PIL import Image
from pathlib import Path
from safetensors.torch import load_file
import modules.safe as _

models = [
    ("AbyssOrangeMix_Base", "OrangeMix/AbyssOrangeMix2"),
]

base_name = "AbyssOrangeMix_Base"
base_model = "OrangeMix/AbyssOrangeMix2"

samplers_k_diffusion = [
    ("Euler a", "sample_euler_ancestral", {}),
    ("Euler", "sample_euler", {}),
    ("LMS", "sample_lms", {}),
    ("Heun", "sample_heun", {}),
    ("DPM2", "sample_dpm_2", {"discard_next_to_last_sigma": True}),
    ("DPM2 a", "sample_dpm_2_ancestral", {"discard_next_to_last_sigma": True}),
    ("DPM++ 2S a", "sample_dpmpp_2s_ancestral", {}),
    ("DPM++ 2M", "sample_dpmpp_2m", {}),
    ("DPM++ SDE", "sample_dpmpp_sde", {}),
    ("DPM fast", "sample_dpm_fast", {}),
    ("DPM adaptive", "sample_dpm_adaptive", {}),
    ("LMS Karras", "sample_lms", {"scheduler": "karras"}),
    (
        "DPM2 Karras",
        "sample_dpm_2",
        {"scheduler": "karras", "discard_next_to_last_sigma": True},
    ),
    (
        "DPM2 a Karras",
        "sample_dpm_2_ancestral",
        {"scheduler": "karras", "discard_next_to_last_sigma": True},
    ),
    ("DPM++ 2S a Karras", "sample_dpmpp_2s_ancestral", {"scheduler": "karras"}),
    ("DPM++ 2M Karras", "sample_dpmpp_2m", {"scheduler": "karras"}),
    ("DPM++ SDE Karras", "sample_dpmpp_sde", {"scheduler": "karras"}),
]
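# Each tuple above is (UI label, k-diffusion sampler function name, extra options);
# inference() matches the label chosen in the UI and passes the function name and
# options through to the custom pipeline in modules.model_pww, which presumably
# dispatches them to the corresponding k-diffusion sampling function.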

start_time = time.time()

scheduler = DDIMScheduler.from_pretrained(
    base_model,
    subfolder="scheduler",
)
vae = AutoencoderKL.from_pretrained(
    "stabilityai/sd-vae-ft-ema",
    torch_dtype=torch.float32
)
text_encoder = CLIPTextModel.from_pretrained(
    base_model,
    subfolder="text_encoder",
    torch_dtype=torch.float32,
)
tokenizer = CLIPTokenizer.from_pretrained(
    base_model,
    subfolder="tokenizer",
    torch_dtype=torch.float32,
)
unet = UNet2DConditionModel.from_pretrained(
    base_model,
    subfolder="unet",
    torch_dtype=torch.float32,
)
pipe = StableDiffusionPipeline(
    text_encoder=text_encoder,
    tokenizer=tokenizer,
    unet=unet,
    vae=vae,
    scheduler=scheduler,
)

unet.set_attn_processor(CrossAttnProcessor)
if torch.cuda.is_available():
    pipe = pipe.to("cuda")

def get_model_list():
    model_available = []
    for model in models:
        if Path(model[1]).is_dir():
            model_available.append(model)
    return model_available


unet_cache = dict()


def get_model(name):
    keys = [k[0] for k in models]
    if name not in unet_cache:
        if name not in keys:
            raise ValueError(name)
        else:
            unet = UNet2DConditionModel.from_pretrained(
                models[keys.index(name)][1],
                subfolder="unet",
                torch_dtype=torch.float32,
            )
            unet_cache[name] = unet

    g_unet = unet_cache[name]
    g_unet.set_attn_processor(None)
    return g_unet


def error_str(error, title="Error"):
    return (
        f"""#### {title}
{error}"""
        if error
        else ""
    )


# Snapshot of the original token-embedding weights, so restore_all() can revert
# any textual-inversion tokens appended during a generation.
te_base_weight = text_encoder.get_input_embeddings().weight.data.detach().clone()


def restore_all():
    global te_base_weight, tokenizer
    text_encoder.get_input_embeddings().weight.data = te_base_weight
    # Reload the tokenizer from the base model so tokens added for
    # textual-inversion embeddings are dropped again.
    tokenizer = CLIPTokenizer.from_pretrained(
        base_model,
        subfolder="tokenizer",
        torch_dtype=torch.float32,
    )


def inference(
    prompt,
    guidance,
    steps,
    width=512,
    height=512,
    seed=0,
    neg_prompt="",
    state=None,
    g_strength=0.4,
    img_input=None,
    i2i_scale=0.5,
    hr_enabled=False,
    hr_method="Latent",
    hr_scale=1.5,
    hr_denoise=0.8,
    sampler="DPM++ 2M Karras",
    embs=None,
    model=None,
    lora_state=None,
    lora_scale=None,
):
    global pipe, unet, tokenizer, text_encoder
    if seed is None or seed == 0:
        seed = random.randint(0, 2147483647)
    if torch.cuda.is_available():
        generator = torch.Generator("cuda").manual_seed(int(seed))
    else:
        generator = torch.Generator().manual_seed(int(seed))

    local_unet = get_model(model)
    if lora_state is not None and lora_state != "":
        load_lora_attn_procs(lora_state, local_unet, lora_scale)
    else:
        local_unet.set_attn_processor(CrossAttnProcessor())

    pipe.setup_unet(local_unet)
    sampler_name, sampler_opt = None, None
    for label, funcname, options in samplers_k_diffusion:
        if label == sampler:
            sampler_name, sampler_opt = funcname, options

    if embs is not None and len(embs) > 0:
        delta_weight = []
        for name, file in embs.items():
            if str(file).endswith(".pt"):
                loaded_learned_embeds = torch.load(file, map_location="cpu")
            else:
                loaded_learned_embeds = load_file(file, device="cpu")
            loaded_learned_embeds = loaded_learned_embeds["string_to_param"]["*"]
            added_length = tokenizer.add_tokens(name)

            assert added_length == loaded_learned_embeds.shape[0]
            delta_weight.append(loaded_learned_embeds)

        delta_weight = torch.cat(delta_weight, dim=0)
        text_encoder.resize_token_embeddings(len(tokenizer))
        text_encoder.get_input_embeddings().weight.data[-delta_weight.shape[0]:] = delta_weight

    config = {
        "negative_prompt": neg_prompt,
        "num_inference_steps": int(steps),
        "guidance_scale": guidance,
        "generator": generator,
        "sampler_name": sampler_name,
        "sampler_opt": sampler_opt,
        "pww_state": state,
        "pww_attn_weight": g_strength,
    }

    if img_input is not None:
        ratio = min(height / img_input.height, width / img_input.width)
        img_input = img_input.resize(
            (int(img_input.width * ratio), int(img_input.height * ratio)), Image.LANCZOS
        )
        result = pipe.img2img(prompt, image=img_input, strength=i2i_scale, **config)
    elif hr_enabled:
        result = pipe.txt2img(
            prompt,
            width=width,
            height=height,
            upscale=True,
            upscale_x=hr_scale,
            upscale_denoising_strength=hr_denoise,
            **config,
            **latent_upscale_modes[hr_method],
        )
    else:
        result = pipe.txt2img(prompt, width=width, height=height, **config)

    # restore
    if embs is not None and len(embs) > 0:
        restore_all()
    return gr.Image.update(result[0][0], label=f"Initial Seed: {seed}")


color_list = []


def get_color(n):
    for _ in range(n - len(color_list)):
        color_list.append(tuple(np.random.random(size=3) * 256))
    return color_list


def create_mixed_img(current, state, w=512, h=512):
    w, h = int(w), int(h)
    image_np = np.full([h, w, 4], 255)
    colors = get_color(len(state))
    idx = 0

    for key, item in state.items():
        if item["map"] is not None:
            m = item["map"] < 255
            alpha = 150
            if current == key:
                alpha = 200
            image_np[m] = colors[idx] + (alpha,)
        idx += 1

    return image_np


# width.change(apply_new_res, inputs=[width, height, global_stats], outputs=[global_stats, sp, rendered])
def apply_new_res(w, h, state):
    w, h = int(w), int(h)

    for key, item in state.items():
        if item["map"] is not None:
            item["map"] = resize(item["map"], w, h)

    update_img = gr.Image.update(value=create_mixed_img("", state, w, h))
    return state, update_img


def detect_text(text, state, width, height):

    t = text.split(",")
    new_state = {}

    for item in t:
        item = item.strip()
        if item == "":
            continue
        if item in state:
            new_state[item] = {
                "map": state[item]["map"],
                "weight": state[item]["weight"],
            }
        else:
            new_state[item] = {
                "map": None,
                "weight": 0.5,
            }
    update = gr.Radio.update(choices=[key for key in new_state.keys()], value=None)
    update_img = gr.update(value=create_mixed_img("", new_state, width, height))
    update_sketch = gr.update(value=None, interactive=False)
    return new_state, update_sketch, update, update_img


def resize(img, w, h):
    trs = transforms.Compose(
        [
            transforms.ToPILImage(),
            transforms.Resize(min(h, w)),
            transforms.CenterCrop((h, w)),
        ]
    )
    result = np.array(trs(img), dtype=np.uint8)
    return result


def switch_canvas(entry, state, width, height):
    if entry is None:
        return None, 0.5, create_mixed_img("", state, width, height)
    return (
        gr.update(value=None, interactive=True),
        gr.update(value=state[entry]["weight"]),
        create_mixed_img(entry, state, width, height),
    )


def apply_canvas(selected, draw, state, w, h):
    w, h = int(w), int(h)
    state[selected]["map"] = resize(draw, w, h)
    return state, gr.Image.update(value=create_mixed_img(selected, state, w, h))


def apply_weight(selected, weight, state):
    state[selected]["weight"] = weight
    return state


# sp2, radio, width, height, global_stats
def apply_image(image, selected, w, h, strength, state):
    if selected is not None:
        state[selected] = {"map": resize(image, w, h), "weight": strength}
    return state, gr.Image.update(value=create_mixed_img(selected, state, w, h))


# [ti_state, lora_state, ti_vals, lora_vals, uploads]
def add_net(files: list[tempfile._TemporaryFileWrapper], ti_state, lora_state):
    if files is None:
        # Nothing uploaded: keep the current state and leave the displays untouched.
        return ti_state, lora_state, gr.Text.update(), gr.Text.update(), gr.Files.update(value=None)

    for file in files:
        item = Path(file.name)
        stripedname = str(item.stem).strip()
        if item.suffix == ".pt":
            state_dict = torch.load(file.name, map_location="cpu")
        else:
            state_dict = load_file(file.name, device="cpu")
        if any("lora" in k for k in state_dict.keys()):
            lora_state = file.name
        else:
            ti_state[stripedname] = file.name

    return ti_state, lora_state, gr.Text.update(f"{[key for key in ti_state.keys()]}"), gr.Text.update(f"{lora_state}"), gr.Files.update(value=None)

# [ti_state, lora_state, ti_vals, lora_vals, uploads]
def clean_states(ti_state, lora_state):
    return dict(), None, gr.Text.update(f""), gr.Text.update(f""), gr.File.update(value=None)


latent_upscale_modes = {
    "Latent": {"upscale_method": "bilinear", "upscale_antialias": False},
    "Latent (antialiased)": {"upscale_method": "bilinear", "upscale_antialias": True},
    "Latent (bicubic)": {"upscale_method": "bicubic", "upscale_antialias": False},
    "Latent (bicubic antialiased)": {
        "upscale_method": "bicubic",
        "upscale_antialias": True,
    },
    "Latent (nearest)": {"upscale_method": "nearest", "upscale_antialias": False},
    "Latent (nearest-exact)": {
        "upscale_method": "nearest-exact",
        "upscale_antialias": False,
    },
}
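# The keys above populate the "Upscale method" dropdown in the hires-fix tab; the
# upscale_method / upscale_antialias values are forwarded to pipe.txt2img(), which
# presumably uses them when interpolating the latents for the second pass.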

css = """
.finetuned-diffusion-div div{
    display:inline-flex;
    align-items:center;
    gap:.8rem;
    font-size:1.75rem;
    padding-top:2rem;
}
.finetuned-diffusion-div div h1{
    font-weight:900;
    margin-bottom:7px
}
.finetuned-diffusion-div p{
    margin-bottom:10px;
    font-size:94%
}
.box {
    float: left;
    height: 20px;
    width: 20px;
    margin-bottom: 15px;
    border: 1px solid black;
    clear: both;
}
a{
    text-decoration:underline
}
.tabs{
    margin-top:0;
    margin-bottom:0
}
#gallery{
    min-height:20rem
}
.no-border {
    border: none !important;
}
"""
with gr.Blocks(css=css) as demo:
    gr.HTML(
        f"""
            <div class="finetuned-diffusion-div">
              <div>
                <h1>Demo for diffusion models</h1>
              </div>
              <p>Hso @ nyanko.sketch2img.gradio</p>
            </div>
        """
    )
    global_stats = gr.State(value={})

    with gr.Row():

        with gr.Column(scale=55):
            model = gr.Dropdown(
                choices=[k[0] for k in get_model_list()],
                label="Model",
                value=base_name,
            )
            image_out = gr.Image(height=512)
            # gallery = gr.Gallery(
            #     label="Generated images", show_label=False, elem_id="gallery"
            # ).style(grid=[1], height="auto")

        with gr.Column(scale=45):

            with gr.Group():

                with gr.Row():
                    with gr.Column(scale=70):

                        prompt = gr.Textbox(
                            label="Prompt",
                            value="cat girl, blue eyes, solo, long messy silver hair, blue capelet, garden, cat ears, cat tail, upper body",
                            show_label=True,
                            max_lines=4,
                            placeholder="Enter prompt.",
                        )
                        neg_prompt = gr.Textbox(
                            label="Negative Prompt",
                            value="bad quality, low quality, jpeg artifact, cropped",
                            show_label=True,
                            max_lines=4,
                            placeholder="Enter negative prompt.",
                        )

                    generate = gr.Button(value="Generate").style(
                        rounded=(False, True, True, False)
                    )

            with gr.Tab("Options"):

                with gr.Group():

                    # n_images = gr.Slider(label="Images", value=1, minimum=1, maximum=4, step=1)
                    with gr.Row():
                        guidance = gr.Slider(
                            label="Guidance scale", value=7.5, maximum=15
                        )
                        steps = gr.Slider(
                            label="Steps", value=25, minimum=2, maximum=75, step=1
                        )

                    with gr.Row():
                        width = gr.Slider(
                            label="Width", value=512, minimum=64, maximum=2048, step=64
                        )
                        height = gr.Slider(
                            label="Height", value=512, minimum=64, maximum=2048, step=64
                        )

                    sampler = gr.Dropdown(
                        value="DPM++ 2M Karras",
                        label="Sampler",
                        choices=[s[0] for s in samplers_k_diffusion],
                    )
                    seed = gr.Number(label="Seed (0 = random)", value=0)

            with gr.Tab("Image to image"):
                with gr.Group():

                    inf_image = gr.Image(
                        label="Image", height=256, tool="editor", type="pil"
                    )
                    inf_strength = gr.Slider(
                        label="Transformation strength",
                        minimum=0,
                        maximum=1,
                        step=0.01,
                        value=0.5,
                    )

            def res_cap(g, w, h, x):
                if g:
                    return f"Enable upscaler: {w}x{h} to {int(w*x)}x{int(h*x)}"
                else:
                    return "Enable upscaler"

            with gr.Tab("Hires fix"):
                with gr.Group():

                    hr_enabled = gr.Checkbox(label="Enable upscaler", value=False)
                    hr_method = gr.Dropdown(
                        [key for key in latent_upscale_modes.keys()],
                        value="Latent",
                        label="Upscale method",
                    )
                    hr_scale = gr.Slider(
                        label="Upscale factor",
                        minimum=1.0,
                        maximum=3,
                        step=0.1,
                        value=1.5,
                    )
                    hr_denoise = gr.Slider(
                        label="Denoising strength",
                        minimum=0.0,
                        maximum=1.0,
                        step=0.1,
                        value=0.8,
                    )

                    hr_scale.change(
                        lambda g, x, w, h: gr.Checkbox.update(
                            label=res_cap(g, w, h, x)
                        ),
                        inputs=[hr_enabled, hr_scale, width, height],
                        outputs=hr_enabled,
                    )
                    hr_enabled.change(
                        lambda g, x, w, h: gr.Checkbox.update(
                            label=res_cap(g, w, h, x)
                        ),
                        inputs=[hr_enabled, hr_scale, width, height],
                        outputs=hr_enabled,
                    )

            with gr.Tab("Embeddings/Loras"):

                ti_state = gr.State(dict())
                lora_state = gr.State()

                with gr.Group():
                    with gr.Row():
                        with gr.Column(scale=90):
                            ti_vals = gr.Text(label="Loaded embeddings")

                    with gr.Row():
                        with gr.Column(scale=90):
                            lora_vals = gr.Text(label="Loaded loras")

                    with gr.Row():

                        uploads = gr.Files(label="Upload new embeddings/lora")

                        with gr.Column():
                            lora_scale = gr.Slider(
                                label="Lora scale",
                                minimum=0,
                                maximum=2,
                                step=0.01,
                                value=1.0,
                            )
                            btn = gr.Button(value="Upload")
                            btn_del = gr.Button(value="Reset")

                    btn.click(
                        add_net, inputs=[uploads, ti_state, lora_state], outputs=[ti_state, lora_state, ti_vals, lora_vals, uploads]
                    )
                    btn_del.click(
                        clean_states, inputs=[ti_state, lora_state], outputs=[ti_state, lora_state, ti_vals, lora_vals, uploads]
                    )

    # error_output = gr.Markdown()

    gr.HTML(
        f"""
            <div class="finetuned-diffusion-div">
              <div>
                <h1>Paint with words</h1>
              </div>
              <p>
                Will use the following formula: w = scale * token_weight_matrix * log(1 + sigma) * max(qk).
              </p>
            </div>
        """
    )
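    # A rough sketch of how the formula above could enter cross-attention; this is
    # illustrative only, the actual implementation lives in modules.model_pww:
    #
    #     w = g_strength * token_weight_matrix * log(1 + sigma) * max(Q @ K^T)
    #     scores = (Q @ K^T) / sqrt(d_head) + w   # then softmax as usual
    #
    # where token_weight_matrix encodes, per latent position, the user-drawn weight
    # of each prompt token, and sigma is the noise level at the current sampler step.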

    with gr.Row():

        with gr.Column(scale=55):

            rendered = gr.Image(
                invert_colors=True,
                source="canvas",
                interactive=False,
                image_mode="RGBA",
            )

        with gr.Column(scale=45):

            with gr.Group():
                with gr.Row():
                    with gr.Column(scale=70):
                        g_strength = gr.Slider(
                            label="Weight scaling",
                            minimum=0,
                            maximum=0.8,
                            step=0.01,
                            value=0.4,
                        )

                        text = gr.Textbox(
                            lines=2,
                            interactive=True,
                            label="Token to Draw: (Separate by comma)",
                        )

                        radio = gr.Radio([], label="Tokens")

                    sk_update = gr.Button(value="Update").style(
                        rounded=(False, True, True, False)
                    )

            # g_strength.change(lambda b: gr.update(f"Scaled additional attn: $w = {b} \log (1 + \sigma) \std (Q^T K)$."), inputs=g_strength, outputs=[g_output])

            with gr.Tab("SketchPad"):

                sp = gr.Image(
                    image_mode="L",
                    tool="sketch",
                    source="canvas",
                    interactive=False,
                )

                strength = gr.Slider(
                    label="Token strength",
                    minimum=0,
                    maximum=0.8,
                    step=0.01,
                    value=0.5,
                )

                sk_update.click(
                    detect_text,
                    inputs=[text, global_stats, width, height],
                    outputs=[global_stats, sp, radio, rendered],
                )
                radio.change(
                    switch_canvas,
                    inputs=[radio, global_stats, width, height],
                    outputs=[sp, strength, rendered],
                )
                sp.edit(
                    apply_canvas,
                    inputs=[radio, sp, global_stats, width, height],
                    outputs=[global_stats, rendered],
                )
                strength.change(
                    apply_weight,
                    inputs=[radio, strength, global_stats],
                    outputs=[global_stats],
                )

            with gr.Tab("UploadFile"):

                sp2 = gr.Image(
                    image_mode="L",
                    source="upload",
                    shape=(512, 512),
                )

                strength2 = gr.Slider(
                    label="Token strength",
                    minimum=0,
                    maximum=0.8,
                    step=0.01,
                    value=0.5,
                )

                apply_style = gr.Button(value="Apply")
                apply_style.click(
                    apply_image,
                    inputs=[sp2, radio, width, height, strength2, global_stats],
                    outputs=[global_stats, rendered],
                )

            width.change(
                apply_new_res,
                inputs=[width, height, global_stats],
                outputs=[global_stats, rendered],
            )
            height.change(
                apply_new_res,
                inputs=[width, height, global_stats],
                outputs=[global_stats, rendered],
            )

    # color_stats = gr.State(value={})
    # text.change(detect_color, inputs=[sp, text, color_stats], outputs=[color_stats, rendered])
    # sp.change(detect_color, inputs=[sp, text, color_stats], outputs=[color_stats, rendered])

    inputs = [
        prompt,
        guidance,
        steps,
        width,
        height,
        seed,
        neg_prompt,
        global_stats,
        g_strength,
        inf_image,
        inf_strength,
        hr_enabled,
        hr_method,
        hr_scale,
        hr_denoise,
        sampler,
        ti_state,
        model,
        lora_state,
        lora_scale,
    ]
    outputs = [image_out]
    prompt.submit(inference, inputs=inputs, outputs=outputs)
    generate.click(inference, inputs=inputs, outputs=outputs)

print(f"Space built in {time.time() - start_time:.2f} seconds")
# demo.launch(share=True)
demo.launch(share=True, enable_queue=True)