primerz committed on
Commit 5235988 · verified · 1 Parent(s): 8be8f68

Update app.py

Files changed (1)
  1. app.py +109 -549
app.py CHANGED
@@ -1,606 +1,166 @@
1
- import gradio as gr
2
- import torch
3
- import spaces
4
- torch.jit.script = lambda f: f
5
- import timm
6
  import time
7
-
8
- from huggingface_hub import hf_hub_download, HfFileSystem, ModelCard, snapshot_download
9
- from safetensors.torch import load_file
10
- from share_btn import community_icon_html, loading_icon_html, share_js
11
- from cog_sdxl_dataset_and_utils import TokenEmbeddingsHandler
12
-
13
- import lora
14
- import copy
15
  import json
16
- import gc
17
  import random
18
- from urllib.parse import quote
19
- import gdown
20
- import os
21
- import re
22
  import requests
23
24
  import diffusers
25
  from diffusers.utils import load_image
26
  from diffusers.models import ControlNetModel
27
  from diffusers import AutoencoderKL, DPMSolverMultistepScheduler, UNet2DConditionModel
28
- import cv2
29
- import torch
30
- import numpy as np
31
- from PIL import Image
32
-
33
  from insightface.app import FaceAnalysis
34
- from pipeline_stable_diffusion_xl_instantid_img2img import StableDiffusionXLInstantIDImg2ImgPipeline, draw_kps
35
  from controlnet_aux import ZoeDetector
36
-
37
  from compel import Compel, ReturnedEmbeddingsType
38
-
39
  from gradio_imageslider import ImageSlider
40
41
 
42
- #from gradio_imageslider import ImageSlider
 
43
 
 
44
  with open("sdxl_loras.json", "r") as file:
45
- data = json.load(file)
46
- sdxl_loras_raw = [
47
- {
48
- "image": item["image"],
49
- "title": item["title"],
50
- "repo": item["repo"],
51
- "trigger_word": item["trigger_word"],
52
- "weights": item["weights"],
53
- "is_compatible": item["is_compatible"],
54
- "is_pivotal": item.get("is_pivotal", False),
55
- "text_embedding_weights": item.get("text_embedding_weights", None),
56
- "likes": item.get("likes", 0),
57
- "downloads": item.get("downloads", 0),
58
- "is_nc": item.get("is_nc", False),
59
- "new": item.get("new", False),
60
- }
61
- for item in data
62
- ]
63
 
64
  with open("defaults_data.json", "r") as file:
65
  lora_defaults = json.load(file)
66
-
67
-
68
- device = "cuda"
69
-
70
- state_dicts = {}
71
-
72
- for item in sdxl_loras_raw:
73
- saved_name = hf_hub_download(item["repo"], item["weights"])
74
-
75
- if not saved_name.endswith('.safetensors'):
76
- state_dict = torch.load(saved_name)
77
- else:
78
- state_dict = load_file(saved_name)
79
-
80
- state_dicts[item["repo"]] = {
81
- "saved_name": saved_name,
82
- "state_dict": state_dict
83
- }
84
 
85
- sdxl_loras_raw = [item for item in sdxl_loras_raw if item.get("new") != True]
86
-
87
- # download models
88
- hf_hub_download(
89
- repo_id="InstantX/InstantID",
90
- filename="ControlNetModel/config.json",
91
- local_dir="/data/checkpoints",
92
- )
93
- hf_hub_download(
94
- repo_id="InstantX/InstantID",
95
- filename="ControlNetModel/diffusion_pytorch_model.safetensors",
96
- local_dir="/data/checkpoints",
97
- )
98
- hf_hub_download(
99
- repo_id="InstantX/InstantID", filename="ip-adapter.bin", local_dir="/data/checkpoints"
100
- )
101
- hf_hub_download(
102
- repo_id="latent-consistency/lcm-lora-sdxl",
103
- filename="pytorch_lora_weights.safetensors",
104
- local_dir="/data/checkpoints",
105
- )
106
- # download antelopev2
107
- #if not os.path.exists("/data/antelopev2.zip"):
108
- # gdown.download(url="https://drive.google.com/file/d/18wEUfMNohBJ4K3Ly5wpTejPfDzp-8fI8/view?usp=sharing", output="/data/", quiet=False, fuzzy=True)
109
- # os.system("unzip /data/antelopev2.zip -d /data/models/")
110
 
 
111
  antelope_download = snapshot_download(repo_id="DIAMONIK7777/antelopev2", local_dir="/data/models/antelopev2")
112
- print(antelope_download)
113
- app = FaceAnalysis(name='antelopev2', root='/data', providers=['CPUExecutionProvider'])
 
 
114
  app.prepare(ctx_id=0, det_size=(640, 640))
115
 
116
- # prepare models under ./checkpoints
117
- face_adapter = f'/data/checkpoints/ip-adapter.bin'
118
- controlnet_path = f'/data/checkpoints/ControlNetModel'
119
 
120
- # load IdentityNet
121
- st = time.time()
122
  identitynet = ControlNetModel.from_pretrained(controlnet_path, torch_dtype=torch.float16)
123
- zoedepthnet = ControlNetModel.from_pretrained("diffusers/controlnet-zoe-depth-sdxl-1.0",torch_dtype=torch.float16)
124
- et = time.time()
125
- elapsed_time = et - st
126
- print('Loading ControlNet took: ', elapsed_time, 'seconds')
127
- st = time.time()
128
  vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16)
129
- et = time.time()
130
- elapsed_time = et - st
131
- print('Loading VAE took: ', elapsed_time, 'seconds')
132
- st = time.time()
133
 
134
- #pipe = StableDiffusionXLInstantIDImg2ImgPipeline.from_pretrained("stablediffusionapi/albedobase-xl-v21",
135
- pipe = StableDiffusionXLInstantIDImg2ImgPipeline.from_pretrained("frankjoshua/albedobaseXL_v21",
136
- vae=vae,
137
- controlnet=[identitynet, zoedepthnet],
138
- torch_dtype=torch.float16)
 
 
139
  pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config, use_karras_sigmas=True)
140
  pipe.load_ip_adapter_instantid(face_adapter)
141
  pipe.set_ip_adapter_scale(0.8)
142
- et = time.time()
143
- elapsed_time = et - st
144
- print('Loading pipeline took: ', elapsed_time, 'seconds')
145
- st = time.time()
146
- compel = Compel(tokenizer=[pipe.tokenizer, pipe.tokenizer_2] , text_encoder=[pipe.text_encoder, pipe.text_encoder_2], returned_embeddings_type=ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED, requires_pooled=[False, True])
147
- et = time.time()
148
- elapsed_time = et - st
149
- print('Loading Compel took: ', elapsed_time, 'seconds')
150
 
151
- st = time.time()
152
  zoe = ZoeDetector.from_pretrained("lllyasviel/Annotators")
153
- et = time.time()
154
- elapsed_time = et - st
155
- print('Loading Zoe took: ', elapsed_time, 'seconds')
156
  zoe.to(device)
157
  pipe.to(device)
158
 
 
159
  last_lora = ""
160
  last_fused = False
161
- js = '''
162
- var button = document.getElementById('button');
163
- // Add a click event listener to the button
164
- button.addEventListener('click', function() {
165
- element.classList.add('selected');
166
- });
167
- '''
168
- lora_archive = "/data"
169
 
170
- def update_selection(selected_state: gr.SelectData, sdxl_loras, face_strength, image_strength, weight, depth_control_scale, negative, is_new=False):
171
- lora_repo = sdxl_loras[selected_state.index]["repo"]
172
- new_placeholder = "Type a prompt to use your selected LoRA"
173
- weight_name = sdxl_loras[selected_state.index]["weights"]
174
- updated_text = f"### Selected: [{lora_repo}](https://huggingface.co/{lora_repo}) ✨ {'(non-commercial LoRA, `cc-by-nc`)' if sdxl_loras[selected_state.index]['is_nc'] else '' }"
175
 
176
  for lora_list in lora_defaults:
177
- if lora_list["model"] == sdxl_loras[selected_state.index]["repo"]:
178
  face_strength = lora_list.get("face_strength", 0.85)
179
  image_strength = lora_list.get("image_strength", 0.15)
180
  weight = lora_list.get("weight", 0.9)
181
  depth_control_scale = lora_list.get("depth_control_scale", 0.8)
182
  negative = lora_list.get("negative", "")
183
-
184
- if(is_new):
185
- if(selected_state.index == 0):
186
- selected_state.index = -9999
187
- else:
188
- selected_state.index *= -1
189
-
190
  return (
191
- updated_text,
192
- gr.update(placeholder=new_placeholder),
193
- face_strength,
194
- image_strength,
195
- weight,
196
- depth_control_scale,
197
- negative,
198
- selected_state
199
  )
200
 
201
- def center_crop_image_as_square(img):
202
  square_size = min(img.size)
203
-
204
- left = (img.width - square_size) / 2
205
- top = (img.height - square_size) / 2
206
- right = (img.width + square_size) / 2
207
- bottom = (img.height + square_size) / 2
208
-
209
- img_cropped = img.crop((left, top, right, bottom))
210
- return img_cropped
211
-
212
- def check_selected(selected_state, custom_lora):
213
- if not selected_state and not custom_lora:
214
- raise gr.Error("You must select a style")
215
-
216
- def merge_incompatible_lora(full_path_lora, lora_scale):
217
- for weights_file in [full_path_lora]:
218
- if ";" in weights_file:
219
- weights_file, multiplier = weights_file.split(";")
220
- multiplier = float(multiplier)
221
- else:
222
- multiplier = lora_scale
223
-
224
- lora_model, weights_sd = lora.create_network_from_weights(
225
- multiplier,
226
- full_path_lora,
227
- pipe.vae,
228
- pipe.text_encoder,
229
- pipe.unet,
230
- for_inference=True,
231
- )
232
- lora_model.merge_to(
233
- pipe.text_encoder, pipe.unet, weights_sd, torch.float16, "cuda"
234
- )
235
- del weights_sd
236
- del lora_model
237
-
238
- @spaces.GPU(duration=80)
239
- def generate_image(prompt, negative, face_emb, face_image, face_kps, image_strength, guidance_scale, face_strength, depth_control_scale, repo_name, loaded_state_dict, lora_scale, sdxl_loras, selected_state_index, st):
240
- print(loaded_state_dict)
241
- et = time.time()
242
- elapsed_time = et - st
243
- print('Getting into the decorated function took: ', elapsed_time, 'seconds')
244
  global last_fused, last_lora
245
- print("Last LoRA: ", last_lora)
246
- print("Current LoRA: ", repo_name)
247
- print("Last fused: ", last_fused)
248
- #prepare face zoe
249
- st = time.time()
250
- with torch.no_grad():
251
- image_zoe = zoe(face_image)
252
- width, height = face_kps.size
253
- images = [face_kps, image_zoe.resize((height, width))]
254
- et = time.time()
255
- elapsed_time = et - st
256
- print('Zoe Depth calculations took: ', elapsed_time, 'seconds')
257
- if last_lora != repo_name:
258
- if(last_fused):
259
- st = time.time()
260
- pipe.unfuse_lora()
261
- pipe.unload_lora_weights()
262
- pipe.unload_textual_inversion()
263
- et = time.time()
264
- elapsed_time = et - st
265
- print('Unfuse and unload LoRA took: ', elapsed_time, 'seconds')
266
- st = time.time()
267
- pipe.load_lora_weights(loaded_state_dict)
268
- pipe.fuse_lora(lora_scale)
269
- et = time.time()
270
- elapsed_time = et - st
271
- print('Fuse and load LoRA took: ', elapsed_time, 'seconds')
272
- last_fused = True
273
- is_pivotal = sdxl_loras[selected_state_index]["is_pivotal"]
274
- if(is_pivotal):
275
- #Add the textual inversion embeddings from pivotal tuning models
276
- text_embedding_name = sdxl_loras[selected_state_index]["text_embedding_weights"]
277
- embedding_path = hf_hub_download(repo_id=repo_name, filename=text_embedding_name, repo_type="model")
278
- state_dict_embedding = load_file(embedding_path)
279
- pipe.load_textual_inversion(state_dict_embedding["clip_l" if "clip_l" in state_dict_embedding else "text_encoders_0"], token=["<s0>", "<s1>"], text_encoder=pipe.text_encoder, tokenizer=pipe.tokenizer)
280
- pipe.load_textual_inversion(state_dict_embedding["clip_g" if "clip_g" in state_dict_embedding else "text_encoders_1"], token=["<s0>", "<s1>"], text_encoder=pipe.text_encoder_2, tokenizer=pipe.tokenizer_2)
281
-
282
- print("Processing prompt...")
283
- st = time.time()
284
  conditioning, pooled = compel(prompt)
285
- if(negative):
286
- negative_conditioning, negative_pooled = compel(negative)
287
- else:
288
- negative_conditioning, negative_pooled = None, None
289
- et = time.time()
290
- elapsed_time = et - st
291
- print('Prompt processing took: ', elapsed_time, 'seconds')
292
- print("Processing image...")
293
- st = time.time()
294
- image = pipe(
295
- prompt_embeds=conditioning,
296
- pooled_prompt_embeds=pooled,
297
- negative_prompt_embeds=negative_conditioning,
298
- negative_pooled_prompt_embeds=negative_pooled,
299
- width=1024,
300
- height=1024,
301
- image_embeds=face_emb,
302
- image=face_image,
303
- strength=1-image_strength,
304
- control_image=images,
305
- num_inference_steps=20,
306
- guidance_scale = guidance_scale,
307
- controlnet_conditioning_scale=[face_strength, depth_control_scale],
308
- ).images[0]
309
- et = time.time()
310
- elapsed_time = et - st
311
- print('Image processing took: ', elapsed_time, 'seconds')
312
- last_lora = repo_name
313
- return image
314
 
315
- def run_lora(face_image, prompt, negative, lora_scale, selected_state, face_strength, image_strength, guidance_scale, depth_control_scale, sdxl_loras, custom_lora, progress=gr.Progress(track_tqdm=True)):
316
- print("Custom LoRA: ", custom_lora)
317
- custom_lora_path = custom_lora[0] if custom_lora else None
318
- selected_state_index = selected_state.index if selected_state else -1
319
- st = time.time()
320
- face_image = center_crop_image_as_square(face_image)
321
- try:
322
- face_info = app.get(cv2.cvtColor(np.array(face_image), cv2.COLOR_RGB2BGR))
323
- face_info = sorted(face_info, key=lambda x:(x['bbox'][2]-x['bbox'][0])*x['bbox'][3]-x['bbox'][1])[-1] # only use the maximum face
324
- face_emb = face_info['embedding']
325
- face_kps = draw_kps(face_image, face_info['kps'])
326
- except:
327
- raise gr.Error("No face found in your image. Only face images work here. Try again")
328
- et = time.time()
329
- elapsed_time = et - st
330
- print('Cropping and calculating face embeds took: ', elapsed_time, 'seconds')
331
-
332
- st = time.time()
333
-
334
- if(custom_lora_path and custom_lora[1]):
335
- prompt = f"{prompt} {custom_lora[1]}"
336
- else:
337
- for lora_list in lora_defaults:
338
- if lora_list["model"] == sdxl_loras[selected_state_index]["repo"]:
339
- prompt_full = lora_list.get("prompt", None)
340
- if(prompt_full):
341
- prompt = prompt_full.replace("<subject>", prompt)
342
-
343
- print("Prompt:", prompt)
344
- if(prompt == ""):
345
- prompt = "a person"
346
- print(f"Executing prompt: {prompt}")
347
- #print("Selected State: ", selected_state_index)
348
- #print(sdxl_loras[selected_state_index]["repo"])
349
- if negative == "":
350
- negative = None
351
- print("Custom Loaded LoRA: ", custom_lora_path)
352
- if not selected_state and not custom_lora_path:
353
- raise gr.Error("You must select a style")
354
- elif custom_lora_path:
355
- repo_name = custom_lora_path
356
- full_path_lora = custom_lora_path
357
- else:
358
- repo_name = sdxl_loras[selected_state_index]["repo"]
359
- weight_name = sdxl_loras[selected_state_index]["weights"]
360
- full_path_lora = state_dicts[repo_name]["saved_name"]
361
- print("Full path LoRA ", full_path_lora)
362
- #loaded_state_dict = copy.deepcopy(state_dicts[repo_name]["state_dict"])
363
- cross_attention_kwargs = None
364
- et = time.time()
365
- elapsed_time = et - st
366
- print('Small content processing took: ', elapsed_time, 'seconds')
367
-
368
- st = time.time()
369
- image = generate_image(prompt, negative, face_emb, face_image, face_kps, image_strength, guidance_scale, face_strength, depth_control_scale, repo_name, full_path_lora, lora_scale, sdxl_loras, selected_state_index, st)
370
- return (face_image, image), gr.update(visible=True)
371
-
372
- run_lora.zerogpu = True
373
-
374
- def shuffle_gallery(sdxl_loras):
375
- random.shuffle(sdxl_loras)
376
- return [(item["image"], item["title"]) for item in sdxl_loras], sdxl_loras
377
-
378
- def classify_gallery(sdxl_loras):
379
- sorted_gallery = sorted(sdxl_loras, key=lambda x: x.get("likes", 0), reverse=True)
380
- return [(item["image"], item["title"]) for item in sorted_gallery], sorted_gallery
381
-
382
- def swap_gallery(order, sdxl_loras):
383
- if(order == "random"):
384
- return shuffle_gallery(sdxl_loras)
385
- else:
386
- return classify_gallery(sdxl_loras)
387
-
388
- def deselect():
389
- return gr.Gallery(selected_index=None)
390
-
391
- def get_huggingface_safetensors(link):
392
- split_link = link.split("/")
393
- if(len(split_link) == 2):
394
- model_card = ModelCard.load(link)
395
- image_path = model_card.data.get("widget", [{}])[0].get("output", {}).get("url", None)
396
- trigger_word = model_card.data.get("instance_prompt", "")
397
- image_url = f"https://huggingface.co/{link}/resolve/main/{image_path}" if image_path else None
398
- fs = HfFileSystem()
399
- try:
400
- list_of_files = fs.ls(link, detail=False)
401
- for file in list_of_files:
402
- if(file.endswith(".safetensors")):
403
- safetensors_name = file.replace("/", "_")
404
- if(not os.path.exists(f"{lora_archive}/{safetensors_name}")):
405
- fs.get_file(file, lpath=f"{lora_archive}/{safetensors_name}")
406
- if (not image_url and file.lower().endswith((".jpg", ".jpeg", ".png", ".webp"))):
407
- image_elements = file.split("/")
408
- image_url = f"https://huggingface.co/{link}/resolve/main/{image_elements[-1]}"
409
- except:
410
- gr.Warning("You didn't include a link neither a valid Hugging Face repository with a *.safetensors LoRA")
411
- raise Exception("You didn't include a link neither a valid Hugging Face repository with a *.safetensors LoRA")
412
- return split_link[1], f"{lora_archive}/{safetensors_name}", trigger_word, image_url
413
-
414
- def get_civitai_safetensors(link):
415
- link_split = link.split("civitai.com/")
416
- pattern = re.compile(r'models\/(\d+)')
417
- regex_match = pattern.search(link_split[1])
418
- if(regex_match):
419
- civitai_model_id = regex_match.group(1)
420
- else:
421
- gr.Warning("No CivitAI model id found in your URL")
422
- raise Exception("No CivitAI model id found in your URL")
423
- model_request_url = f"https://civitai.com/api/v1/models/{civitai_model_id}?token={os.getenv('CIVITAI_TOKEN')}"
424
- x = requests.get(model_request_url)
425
- if(x.status_code != 200):
426
- raise Exception("Invalid CivitAI URL")
427
- model_data = x.json()
428
- #if(model_data["nsfw"] == True or model_data["nsfwLevel"] > 20):
429
- # gr.Warning("The model is tagged by CivitAI as adult content and cannot be used in this shared environment.")
430
- # raise Exception("The model is tagged by CivitAI as adult content and cannot be used in this shared environment.")
431
- if(model_data["type"] != "LORA"):
432
- gr.Warning("The model isn't tagged at CivitAI as a LoRA")
433
- raise Exception("The model isn't tagged at CivitAI as a LoRA")
434
- model_link_download = None
435
- image_url = None
436
- trigger_word = ""
437
- for model in model_data["modelVersions"]:
438
- if(model["baseModel"] == "SDXL 1.0"):
439
- model_link_download = f"{model['downloadUrl']}/?token={os.getenv('CIVITAI_TOKEN')}"
440
- safetensors_name = model["files"][0]["name"]
441
- if(not os.path.exists(f"{lora_archive}/{safetensors_name}")):
442
- safetensors_file_request = requests.get(model_link_download)
443
- if(safetensors_file_request.status_code != 200):
444
- raise Exception("Invalid CivitAI download link")
445
- with open(f"{lora_archive}/{safetensors_name}", 'wb') as file:
446
- file.write(safetensors_file_request.content)
447
- trigger_word = model.get("trainedWords", [""])[0]
448
- for image in model["images"]:
449
- if(image["nsfwLevel"] == 1):
450
- image_url = image["url"]
451
- break
452
- break
453
- if(not model_link_download):
454
- gr.Warning("We couldn't find a SDXL LoRA on the model you've sent")
455
- raise Exception("We couldn't find a SDXL LoRA on the model you've sent")
456
- return model_data["name"], f"{lora_archive}/{safetensors_name}", trigger_word, image_url
457
-
458
- def check_custom_model(link):
459
- if(link.startswith("https://")):
460
- if(link.startswith("https://huggingface.co") or link.startswith("https://www.huggingface.co")):
461
- link_split = link.split("huggingface.co/")
462
- return get_huggingface_safetensors(link_split[1])
463
- elif(link.startswith("https://civitai.com") or link.startswith("https://www.civitai.com")):
464
- return get_civitai_safetensors(link)
465
- else:
466
- return get_huggingface_safetensors(link)
467
-
468
- def show_loading_widget():
469
- return gr.update(visible=True)
470
 
471
- def load_custom_lora(link):
472
- if(link):
473
- try:
474
- title, path, trigger_word, image = check_custom_model(link)
475
- card = f'''
476
- <div class="custom_lora_card">
477
- <span>Loaded custom LoRA:</span>
478
- <div class="card_internal">
479
- <img src="{image}" />
480
- <div>
481
- <h3>{title}</h3>
482
- <small>{"Using: <code><b>"+trigger_word+"</code></b> as the trigger word" if trigger_word else "No trigger word found. If there's a trigger word, include it in your prompt"}<br></small>
483
- </div>
484
- </div>
485
- </div>
486
- '''
487
- return gr.update(visible=True), card, gr.update(visible=True), [path, trigger_word], gr.Gallery(selected_index=None), f"Custom: {path}"
488
- except Exception as e:
489
- gr.Warning("Invalid LoRA: either you entered an invalid link, a non-SDXL LoRA or a LoRA with mature content")
490
- return gr.update(visible=True), "Invalid LoRA: either you entered an invalid link, a non-SDXL LoRA or a LoRA with mature content", gr.update(visible=False), None, gr.update(visible=True), gr.update(visible=True)
491
- else:
492
- return gr.update(visible=False), "", gr.update(visible=False), None, gr.update(visible=True), gr.update(visible=True)
493
 
494
- def remove_custom_lora():
495
- return "", gr.update(visible=False), gr.update(visible=False), None
496
- with gr.Blocks(css="custom.css") as demo:
497
- gr_sdxl_loras = gr.State(value=sdxl_loras_raw)
498
- title = gr.HTML(
499
- """<h1><img src="https://i.imgur.com/DVoGw04.png">
500
- <span>Face to All<br><small style="
501
- font-size: 13px;
502
- display: block;
503
- font-weight: normal;
504
- opacity: 0.75;
505
- ">🧨 diffusers InstantID + ControlNet<br> inspired by fofr's <a href="https://github.com/fofr/cog-face-to-many" target="_blank">face-to-many</a></small></span></h1>""",
506
- elem_id="title",
507
- )
508
- selected_state = gr.State()
509
- custom_loaded_lora = gr.State()
510
- with gr.Row(elem_id="main_app"):
511
- with gr.Column(scale=4, elem_id="box_column"):
512
- with gr.Group(elem_id="gallery_box"):
513
- photo = gr.Image(label="Upload a picture of yourself", interactive=True, type="pil", height=300)
514
- selected_loras = gr.Gallery(label="Selected LoRAs", height=80, show_share_button=False, visible=False, elem_id="gallery_selected", )
515
- #order_gallery = gr.Radio(choices=["random", "likes"], value="random", label="Order by", elem_id="order_radio")
516
- #new_gallery = gr.Gallery(
517
- # label="New LoRAs",
518
- # elem_id="gallery_new",
519
- # columns=3,
520
- # value=[(item["image"], item["title"]) for item in sdxl_loras_raw_new], allow_preview=False, show_share_button=False)
521
- gallery = gr.Gallery(
522
- #value=[(item["image"], item["title"]) for item in sdxl_loras],
523
- label="Pick a style from the gallery",
524
- allow_preview=False,
525
- columns=4,
526
- elem_id="gallery",
527
- show_share_button=False,
528
- height=550
529
- )
530
- custom_model = gr.Textbox(label="or enter a custom Hugging Face or CivitAI SDXL LoRA", placeholder="Paste Hugging Face or CivitAI model path...")
531
- custom_model_card = gr.HTML(visible=False)
532
- custom_model_button = gr.Button("Remove custom LoRA", visible=False)
533
- with gr.Column(scale=5):
534
- with gr.Row():
535
- prompt = gr.Textbox(label="Prompt", show_label=False, lines=1, max_lines=1, info="Describe your subject (optional)", value="a person", elem_id="prompt")
536
- button = gr.Button("Run", elem_id="run_button")
537
- result = ImageSlider(
538
- interactive=False, label="Generated Image", elem_id="result-image", position=0.1
539
- )
540
- with gr.Group(elem_id="share-btn-container", visible=False) as share_group:
541
- community_icon = gr.HTML(community_icon_html)
542
- loading_icon = gr.HTML(loading_icon_html)
543
- share_button = gr.Button("Share to community", elem_id="share-btn")
544
- with gr.Accordion("Advanced options", open=False):
545
- negative = gr.Textbox(label="Negative Prompt")
546
- weight = gr.Slider(0, 10, value=0.9, step=0.1, label="LoRA weight")
547
- face_strength = gr.Slider(0, 2, value=0.85, step=0.01, label="Face strength", info="Higher values increase the face likeness but reduce the creative liberty of the models")
548
- image_strength = gr.Slider(0, 1, value=0.15, step=0.01, label="Image strength", info="Higher values increase the similarity with the structure/colors of the original photo")
549
- guidance_scale = gr.Slider(0, 50, value=7, step=0.1, label="Guidance Scale")
550
- depth_control_scale = gr.Slider(0, 1, value=0.8, step=0.01, label="Zoe Depth ControlNet strenght")
551
- prompt_title = gr.Markdown(
552
- value="### Click on a LoRA in the gallery to select it",
553
- visible=True,
554
- elem_id="selected_lora",
555
- )
556
- #order_gallery.change(
557
- # fn=swap_gallery,
558
- # inputs=[order_gallery, gr_sdxl_loras],
559
- # outputs=[gallery, gr_sdxl_loras],
560
- # queue=False
561
- #)
562
- custom_model.input(
563
- fn=load_custom_lora,
564
- inputs=[custom_model],
565
- outputs=[custom_model_card, custom_model_card, custom_model_button, custom_loaded_lora, gallery, prompt_title],
566
- )
567
- custom_model_button.click(
568
- fn=remove_custom_lora,
569
- outputs=[custom_model, custom_model_button, custom_model_card, custom_loaded_lora]
570
- )
571
- gallery.select(
572
- fn=update_selection,
573
- inputs=[gr_sdxl_loras, face_strength, image_strength, weight, depth_control_scale, negative],
574
- outputs=[prompt_title, prompt, face_strength, image_strength, weight, depth_control_scale, negative, selected_state],
575
- show_progress=False
576
- )
577
- #new_gallery.select(
578
- # fn=update_selection,
579
- # inputs=[gr_sdxl_loras_new, gr.State(True)],
580
- # outputs=[prompt_title, prompt, prompt, selected_state, gallery],
581
- # queue=False,
582
- # show_progress=False
583
- #)
584
- prompt.submit(
585
- fn=check_selected,
586
- inputs=[selected_state, custom_loaded_lora],
587
- show_progress=False
588
- ).success(
589
- fn=run_lora,
590
- inputs=[photo, prompt, negative, weight, selected_state, face_strength, image_strength, guidance_scale, depth_control_scale, gr_sdxl_loras, custom_loaded_lora],
591
- outputs=[result, share_group],
592
- )
593
- button.click(
594
- fn=check_selected,
595
- inputs=[selected_state, custom_loaded_lora],
596
- show_progress=False
597
- ).success(
598
- fn=run_lora,
599
- inputs=[photo, prompt, negative, weight, selected_state, face_strength, image_strength, guidance_scale, depth_control_scale, gr_sdxl_loras, custom_loaded_lora],
600
- outputs=[result, share_group],
601
- )
602
- share_button.click(None, [], [], js=share_js)
603
- demo.load(fn=classify_gallery, inputs=[gr_sdxl_loras], outputs=[gallery, gr_sdxl_loras], js=js)
604
 
605
- demo.queue(default_concurrency_limit=None, api_open=True)
606
- demo.launch(share=True)
 
1
+ import os
2
+ import re
3
  import time
4
  import json
5
+ import copy
6
  import random
7
  import requests
8
+ import torch
9
+ import cv2
10
+ import numpy as np
11
+ import gradio as gr
12
+ import spaces
13
+ from PIL import Image
14
+ from urllib.parse import quote
15
+
16
+ # Disable Torch JIT compilation for compatibility
17
+ torch.jit.script = lambda f: f
18
 
19
+ # Model & Utilities
20
+ import timm
21
  import diffusers
22
  from diffusers.utils import load_image
23
  from diffusers.models import ControlNetModel
24
  from diffusers import AutoencoderKL, DPMSolverMultistepScheduler, UNet2DConditionModel
25
+ from safetensors.torch import load_file
26
+ from huggingface_hub import hf_hub_download, HfFileSystem, ModelCard, snapshot_download
27
  from insightface.app import FaceAnalysis
 
28
  from controlnet_aux import ZoeDetector
 
29
  from compel import Compel, ReturnedEmbeddingsType
 
30
  from gradio_imageslider import ImageSlider
31
 
32
+ # Custom imports
33
+ try:
34
+ from pipeline_stable_diffusion_xl_instantid_img2img import StableDiffusionXLInstantIDImg2ImgPipeline, draw_kps
35
+ from cog_sdxl_dataset_and_utils import TokenEmbeddingsHandler
36
+ except ImportError as e:
37
+ print(f"Import Error: {e}. Check if modules exist or paths are correct.")
38
+ exit()
39
 
40
+ # Device setup
41
+ device = "cuda" if torch.cuda.is_available() else "cpu"
42
 
43
+ # Load LoRA configuration
44
  with open("sdxl_loras.json", "r") as file:
45
+ sdxl_loras_raw = json.load(file)
46
 
47
  with open("defaults_data.json", "r") as file:
48
  lora_defaults = json.load(file)
49
 
50
+ # Download required models
51
+ CHECKPOINT_DIR = "/data/checkpoints"
52
+ hf_hub_download(repo_id="InstantX/InstantID", filename="ControlNetModel/config.json", local_dir=CHECKPOINT_DIR)
53
+ hf_hub_download(repo_id="InstantX/InstantID", filename="ControlNetModel/diffusion_pytorch_model.safetensors", local_dir=CHECKPOINT_DIR)
54
+ hf_hub_download(repo_id="InstantX/InstantID", filename="ip-adapter.bin", local_dir=CHECKPOINT_DIR)
55
+ hf_hub_download(repo_id="latent-consistency/lcm-lora-sdxl", filename="pytorch_lora_weights.safetensors", local_dir=CHECKPOINT_DIR)
56
 
57
+ # Download Antelopev2 Face Recognition model
58
  antelope_download = snapshot_download(repo_id="DIAMONIK7777/antelopev2", local_dir="/data/models/antelopev2")
59
+ print("Antelopev2 Download Path:", antelope_download)
60
+
61
+ # Initialize FaceAnalysis
62
+ app = FaceAnalysis(name="antelopev2", root="/data", providers=["CPUExecutionProvider"])
63
  app.prepare(ctx_id=0, det_size=(640, 640))
64
 
65
+ # Load identity & depth models
66
+ face_adapter = os.path.join(CHECKPOINT_DIR, "ip-adapter.bin")
67
+ controlnet_path = os.path.join(CHECKPOINT_DIR, "ControlNetModel")
68
69
  identitynet = ControlNetModel.from_pretrained(controlnet_path, torch_dtype=torch.float16)
70
+ zoedepthnet = ControlNetModel.from_pretrained("diffusers/controlnet-zoe-depth-sdxl-1.0", torch_dtype=torch.float16)
71
+
72
  vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16)
73
 
74
+ # Load main pipeline
75
+ pipe = StableDiffusionXLInstantIDImg2ImgPipeline.from_pretrained(
76
+ "frankjoshua/albedobaseXL_v21",
77
+ vae=vae,
78
+ controlnet=[identitynet, zoedepthnet],
79
+ torch_dtype=torch.float16
80
+ )
81
  pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config, use_karras_sigmas=True)
82
  pipe.load_ip_adapter_instantid(face_adapter)
83
  pipe.set_ip_adapter_scale(0.8)
84
 
85
+ # Initialize Compel for text conditioning
86
+ compel = Compel(
87
+ tokenizer=[pipe.tokenizer, pipe.tokenizer_2],
88
+ text_encoder=[pipe.text_encoder, pipe.text_encoder_2],
89
+ returned_embeddings_type=ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED,
90
+ requires_pooled=[False, True]
91
+ )
92
+
93
+ # Load ZoeDetector for depth estimation
94
  zoe = ZoeDetector.from_pretrained("lllyasviel/Annotators")
95
  zoe.to(device)
96
  pipe.to(device)
97
 
98
+ # LoRA Management
99
  last_lora = ""
100
  last_fused = False
101
 
102
+ # --- Utility Functions ---
103
+ def update_selection(selected_state, sdxl_loras, face_strength, image_strength, weight, depth_control_scale, negative):
104
+ index = selected_state.index
105
+ lora_repo = sdxl_loras[index]["repo"]
106
+ updated_text = f"### Selected: [{lora_repo}](https://huggingface.co/{lora_repo})"
107
 
108
  for lora_list in lora_defaults:
109
+ if lora_list["model"] == lora_repo:
110
  face_strength = lora_list.get("face_strength", 0.85)
111
  image_strength = lora_list.get("image_strength", 0.15)
112
  weight = lora_list.get("weight", 0.9)
113
  depth_control_scale = lora_list.get("depth_control_scale", 0.8)
114
  negative = lora_list.get("negative", "")
115
+
116
  return (
117
+ updated_text, gr.update(placeholder="Type a prompt"), face_strength,
118
+ image_strength, weight, depth_control_scale, negative, selected_state
119
  )
120
 
121
+ def center_crop_image(img):
122
  square_size = min(img.size)
123
+ left = (img.width - square_size) // 2
124
+ top = (img.height - square_size) // 2
125
+ return img.crop((left, top, left + square_size, top + square_size))
126
+
127
+ def process_face(image):
128
+ face_info = app.get(cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR))
129
+ face_info = sorted(face_info, key=lambda x: (x['bbox'][2]-x['bbox'][0]) * (x['bbox'][3]-x['bbox'][1]))[-1]
130
+ face_emb = face_info['embedding']
131
+ face_kps = draw_kps(image, face_info['kps'])
132
+ return face_emb, face_kps
133
+
134
+ def generate_image(prompt, negative, face_emb, face_image, face_kps, image_strength, guidance_scale, face_strength, depth_control_scale, repo_name, lora_scale):
135
  global last_fused, last_lora
136
+ if last_lora != repo_name and last_fused:
137
+ pipe.unfuse_lora()
138
+ pipe.unload_lora_weights()
139
+ pipe.load_lora_weights(repo_name)
140
+ pipe.fuse_lora(lora_scale)
141
+ last_lora, last_fused = repo_name, True
142
+
143
  conditioning, pooled = compel(prompt)
144
+ negative_conditioning, negative_pooled = compel(negative) if negative else (None, None)
145
 
146
+ images = [face_kps, zoe(face_image).resize(face_kps.size)]
147
+ return pipe(
148
+ prompt_embeds=conditioning, pooled_prompt_embeds=pooled,
149
+ negative_prompt_embeds=negative_conditioning, negative_pooled_prompt_embeds=negative_pooled,
150
+ width=1024, height=1024, image_embeds=face_emb, image=face_image,
151
+ strength=1-image_strength, control_image=images, num_inference_steps=20,
152
+ guidance_scale=guidance_scale, controlnet_conditioning_scale=[face_strength, depth_control_scale]
153
+ ).images[0]
154
 
155
+ # --- UI Setup ---
156
+ with gr.Blocks() as demo:
157
+ photo = gr.Image(label="Upload a picture", interactive=True, type="pil", height=300)
158
+ gallery = gr.Gallery(label="Pick a style", allow_preview=False, columns=4, height=550)
159
+ prompt = gr.Textbox(label="Prompt", placeholder="Enter prompt...")
160
+ button = gr.Button("Run")
161
+ result = ImageSlider(interactive=False, label="Generated Image")
162
 
163
+ button.click(fn=generate_image, inputs=[prompt, gr.State(), gr.State()], outputs=result)
164
 
165
+ demo.queue()
166
+ demo.launch(share=True)