multimodalart HF staff commited on
Commit
bd61b23
·
verified ·
1 Parent(s): d230c07

attempt custom

Browse files
Files changed (1) hide show
  1. app.py +115 -9
app.py CHANGED
@@ -303,7 +303,7 @@ def generate_image(prompt, negative, face_emb, face_image, face_kps, image_stren
303
  last_lora = repo_name
304
  return image
305
 
306
- def run_lora(face_image, prompt, negative, lora_scale, selected_state, face_strength, image_strength, guidance_scale, depth_control_scale, sdxl_loras, progress=gr.Progress(track_tqdm=True)):
307
  selected_state_index = selected_state.index
308
  st = time.time()
309
  face_image = center_crop_image_as_square(face_image)
@@ -334,13 +334,17 @@ def run_lora(face_image, prompt, negative, lora_scale, selected_state, face_stre
334
  print(sdxl_loras[selected_state_index]["repo"])
335
  if negative == "":
336
  negative = None
337
-
338
- if not selected_state:
339
- raise gr.Error("You must select a LoRA")
340
- repo_name = sdxl_loras[selected_state_index]["repo"]
341
- weight_name = sdxl_loras[selected_state_index]["weights"]
342
 
343
- full_path_lora = state_dicts[repo_name]["saved_name"]
 
 
 
 
 
 
 
 
 
344
  #loaded_state_dict = copy.deepcopy(state_dicts[repo_name]["state_dict"])
345
  cross_attention_kwargs = None
346
  et = time.time()
@@ -368,6 +372,92 @@ def swap_gallery(order, sdxl_loras):
368
  def deselect():
369
  return gr.Gallery(selected_index=None)
370
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
371
  with gr.Blocks(css="custom.css") as demo:
372
  gr_sdxl_loras = gr.State(value=sdxl_loras_raw)
373
  title = gr.HTML(
@@ -381,6 +471,7 @@ with gr.Blocks(css="custom.css") as demo:
381
  elem_id="title",
382
  )
383
  selected_state = gr.State()
 
384
  with gr.Row(elem_id="main_app"):
385
  with gr.Column(scale=4):
386
  with gr.Group(elem_id="gallery_box"):
@@ -402,6 +493,8 @@ with gr.Blocks(css="custom.css") as demo:
402
  height=550
403
  )
404
  custom_model = gr.Textbox(label="Enter a custom Hugging Face or CivitAI SDXL LoRA", interactive=False, placeholder="Coming soon...")
 
 
405
  with gr.Column(scale=5):
406
  with gr.Row():
407
  prompt = gr.Textbox(label="Prompt", show_label=False, lines=1, max_lines=1, info="Describe your subject (optional)", value="a person", elem_id="prompt")
@@ -431,6 +524,19 @@ with gr.Blocks(css="custom.css") as demo:
431
  # outputs=[gallery, gr_sdxl_loras],
432
  # queue=False
433
  #)
 
 
 
 
 
 
 
 
 
 
 
 
 
434
  gallery.select(
435
  fn=update_selection,
436
  inputs=[gr_sdxl_loras, face_strength, image_strength, weight, depth_control_scale, negative],
@@ -452,7 +558,7 @@ with gr.Blocks(css="custom.css") as demo:
452
  show_progress=False
453
  ).success(
454
  fn=run_lora,
455
- inputs=[photo, prompt, negative, weight, selected_state, face_strength, image_strength, guidance_scale, depth_control_scale, gr_sdxl_loras],
456
  outputs=[result, share_group],
457
  )
458
  button.click(
@@ -462,7 +568,7 @@ with gr.Blocks(css="custom.css") as demo:
462
  show_progress=False
463
  ).success(
464
  fn=run_lora,
465
- inputs=[photo, prompt, negative, weight, selected_state, face_strength, image_strength, guidance_scale, depth_control_scale, gr_sdxl_loras],
466
  outputs=[result, share_group],
467
  )
468
  share_button.click(None, [], [], js=share_js)
 
303
  last_lora = repo_name
304
  return image
305
 
306
+ def run_lora(face_image, prompt, negative, lora_scale, selected_state, face_strength, image_strength, guidance_scale, depth_control_scale, sdxl_loras, custom_lora_path, progress=gr.Progress(track_tqdm=True)):
307
  selected_state_index = selected_state.index
308
  st = time.time()
309
  face_image = center_crop_image_as_square(face_image)
 
334
  print(sdxl_loras[selected_state_index]["repo"])
335
  if negative == "":
336
  negative = None
 
 
 
 
 
337
 
338
+ if not selected_state and not custom_loaded_lora:
339
+ raise gr.Error("You must select a style")
340
+ elif custom_loaded_lora:
341
+ repo_name = "Loaded custom LoRA"
342
+ full_path_lora = custom_lora_path
343
+ else:
344
+ repo_name = sdxl_loras[selected_state_index]["repo"]
345
+ weight_name = sdxl_loras[selected_state_index]["weights"]
346
+ full_path_lora = state_dicts[repo_name]["saved_name"]
347
+
348
  #loaded_state_dict = copy.deepcopy(state_dicts[repo_name]["state_dict"])
349
  cross_attention_kwargs = None
350
  et = time.time()
 
372
  def deselect():
373
  return gr.Gallery(selected_index=None)
374
 
375
+ lora_archive = "/data"
376
+
377
+ def get_huggingface_safetensors(link):
378
+ split_link = link.split("/")
379
+ if(len(split_link) == 2):
380
+ model_card = ModelCard.load(link)
381
+ image_path = model_card.data.get("widget", [{}])[0].get("output", {}).get("url", None)
382
+
383
+ print(image_path)
384
+ image_url = f"https://huggingface.co/{link}/resolve/main/{image_path}" if image_path else None
385
+ fs = HfFileSystem()
386
+ try:
387
+ list_of_files = fs.ls(link, detail=False)
388
+ for file in list_of_files:
389
+ if(file.endswith(".safetensors")):
390
+ safetensors_name = file.replace("/", "_")
391
+ if(not os.path.exists(f"{lora_archive}/{safetensors_name}")):
392
+ fs.get_file(file, lpath=f"{lora_archive}/{safetensors_name}")
393
+ if (not image_url and file.lower().endswith((".jpg", ".jpeg", ".png", ".webp"))):
394
+ image_url = f"https://huggingface.co/{link}/resolve/main/{image_url}"
395
+ except:
396
+ raise gr.Error("You didn't include a link neither a valid Hugging Face repository with a *.safetensors LoRA")
397
+ return split_link[1], f"{lora_archive}/{safetensors_name}", image_url
398
+
399
+ def get_civitai_safetensors(link):
400
+ link_split = link.split("civitai.com/")
401
+ pattern = re.compile(r'models\/(\d+)')
402
+ regex_match = pattern.search(link_split[1])
403
+ if(regex_match):
404
+ civitai_model_id = regex_match.group(1)
405
+
406
+ else:
407
+ raise gr.Error("No CivitAI model id found in your URL")
408
+ model_request_url = f"https://civitai.com/api/v1/models/{civitai_model_id}?token=dc9537afe3dcbbfb4696dd79ee8c23d7"
409
+ x = requests.get(model_request_url)
410
+ model_data = x.json()
411
+ if(model_data["nsfw"] == True):
412
+ raise gr.Error("The model is tagged by CivitAI as adult content and cannot be used in this shared environment.")
413
+ elif(model_data["type"] != "LORA"):
414
+ raise gr.Error("The model isn't tagged at CivitAI as a LoRA")
415
+ model_link_download = None
416
+ image_url = None
417
+ for model in model_data["modelVersions"]:
418
+ if(model["baseModel"] == "SDXL 1.0"):
419
+ model_link_download = f"{model['downloadUrl']}/?token=dc9537afe3dcbbfb4696dd79ee8c23d7"
420
+ safetensors_name = model["files"][0]["name"]
421
+ if(not os.path.exists(f"{lora_archive}/{safetensors_name}")):
422
+ safetensors_file_request = requests.get(model_link_download)
423
+ with open(f"{lora_archive}/{safetensors_name}", 'wb') as file:
424
+ file.write(safetensors_file_request.content)
425
+ for image in model["images"]:
426
+ if(image["nsfwLevel"] == 1):
427
+ image_url = image["url"]
428
+ break
429
+ break
430
+ return model_data["name"], f"{lora_archive}/{safetensors_name}", image_url
431
+ if(not model_link_download):
432
+ raise gr.Error("We couldn't find a SDXL LoRA on the model you've sent")
433
+ def check_custom_model(link):
434
+ if(link.startswith("https://")):
435
+ if(link.startswith("https://huggingface.co") or link.startswith("https://www.huggingface.co")):
436
+ link_split = link.split("huggingface.co/")
437
+ return get_huggingface_safetensors(link_split[1])
438
+ elif(link.startswith("https://civitai.com") or link.startswith("https://www.civitai.com")):
439
+ return get_civitai_safetensors(link)
440
+ else:
441
+ return get_huggingface_safetensors(link)
442
+
443
+ def show_loading_widget():
444
+ return gr.update(visible=True)
445
+
446
+ def load_custom_lora(link):
447
+ title, path, image = check_custom_model(link)
448
+ card = f'''
449
+ <div class="custom_lora_card">
450
+ <span>Loaded custom LoRA:</span>
451
+ <div class="card_internal">
452
+ <h3>{title}</h3>
453
+ <img src="{image}" />
454
+ </div>
455
+ </div>
456
+ '''
457
+ return gr.update(visible=True), card, gr.update(visible=True), path, gr.Gallery(selected_index=None)
458
+
459
+ def remove_custom_lora():
460
+ return "", gr.update(visible=False), gr.update(visible=False), None
461
  with gr.Blocks(css="custom.css") as demo:
462
  gr_sdxl_loras = gr.State(value=sdxl_loras_raw)
463
  title = gr.HTML(
 
471
  elem_id="title",
472
  )
473
  selected_state = gr.State()
474
+ custom_loaded_lora = gr.State()
475
  with gr.Row(elem_id="main_app"):
476
  with gr.Column(scale=4):
477
  with gr.Group(elem_id="gallery_box"):
 
493
  height=550
494
  )
495
  custom_model = gr.Textbox(label="Enter a custom Hugging Face or CivitAI SDXL LoRA", interactive=False, placeholder="Coming soon...")
496
+ custom_model_card = gr.HTML(visible=False)
497
+ custom_model_button = gr.Button("Remove custom LoRA", visible=False)
498
  with gr.Column(scale=5):
499
  with gr.Row():
500
  prompt = gr.Textbox(label="Prompt", show_label=False, lines=1, max_lines=1, info="Describe your subject (optional)", value="a person", elem_id="prompt")
 
524
  # outputs=[gallery, gr_sdxl_loras],
525
  # queue=False
526
  #)
527
+ custom_model.change(
528
+ fn=show_loading_widget,
529
+ outputs=[custom_model_card]
530
+ ).then(
531
+ fn=load_custom_lora,
532
+ inputs=[custom_model],
533
+ outputs=[custom_model_card, custom_model_card, custom_model_button, custom_loaded_lora, gallery],
534
+ queue=False
535
+ )
536
+ custom_model_button.click(
537
+ fn=remove_custom_lora,
538
+ outputs=[custom_model, custom_loaded_lora, custom_model_card, custom_loaded_lora]
539
+ )
540
  gallery.select(
541
  fn=update_selection,
542
  inputs=[gr_sdxl_loras, face_strength, image_strength, weight, depth_control_scale, negative],
 
558
  show_progress=False
559
  ).success(
560
  fn=run_lora,
561
+ inputs=[photo, prompt, negative, weight, selected_state, face_strength, image_strength, guidance_scale, depth_control_scale, gr_sdxl_loras, custom_loaded_lora],
562
  outputs=[result, share_group],
563
  )
564
  button.click(
 
568
  show_progress=False
569
  ).success(
570
  fn=run_lora,
571
+ inputs=[photo, prompt, negative, weight, selected_state, face_strength, image_strength, guidance_scale, depth_control_scale, gr_sdxl_loras, custom_loaded_lora],
572
  outputs=[result, share_group],
573
  )
574
  share_button.click(None, [], [], js=share_js)