multimodalart HF staff commited on
Commit
5cd0360
·
verified ·
1 Parent(s): 652664b

revert custom

Browse files
Files changed (1) hide show
  1. app.py +10 -116
app.py CHANGED
@@ -303,7 +303,7 @@ def generate_image(prompt, negative, face_emb, face_image, face_kps, image_stren
303
  last_lora = repo_name
304
  return image
305
 
306
- def run_lora(face_image, prompt, negative, lora_scale, selected_state, face_strength, image_strength, guidance_scale, depth_control_scale, sdxl_loras, custom_lora_path, progress=gr.Progress(track_tqdm=True)):
307
  selected_state_index = selected_state.index
308
  st = time.time()
309
  face_image = center_crop_image_as_square(face_image)
@@ -334,17 +334,13 @@ def run_lora(face_image, prompt, negative, lora_scale, selected_state, face_stre
334
  print(sdxl_loras[selected_state_index]["repo"])
335
  if negative == "":
336
  negative = None
 
 
 
 
 
337
 
338
- if not selected_state and not custom_loaded_lora:
339
- raise gr.Error("You must select a style")
340
- elif custom_loaded_lora:
341
- repo_name = "Loaded custom LoRA"
342
- full_path_lora = custom_lora_path
343
- else:
344
- repo_name = sdxl_loras[selected_state_index]["repo"]
345
- weight_name = sdxl_loras[selected_state_index]["weights"]
346
- full_path_lora = state_dicts[repo_name]["saved_name"]
347
-
348
  #loaded_state_dict = copy.deepcopy(state_dicts[repo_name]["state_dict"])
349
  cross_attention_kwargs = None
350
  et = time.time()
@@ -372,92 +368,6 @@ def swap_gallery(order, sdxl_loras):
372
  def deselect():
373
  return gr.Gallery(selected_index=None)
374
 
375
- lora_archive = "/data"
376
-
377
- def get_huggingface_safetensors(link):
378
- split_link = link.split("/")
379
- if(len(split_link) == 2):
380
- model_card = ModelCard.load(link)
381
- image_path = model_card.data.get("widget", [{}])[0].get("output", {}).get("url", None)
382
-
383
- print(image_path)
384
- image_url = f"https://huggingface.co/{link}/resolve/main/{image_path}" if image_path else None
385
- fs = HfFileSystem()
386
- try:
387
- list_of_files = fs.ls(link, detail=False)
388
- for file in list_of_files:
389
- if(file.endswith(".safetensors")):
390
- safetensors_name = file.replace("/", "_")
391
- if(not os.path.exists(f"{lora_archive}/{safetensors_name}")):
392
- fs.get_file(file, lpath=f"{lora_archive}/{safetensors_name}")
393
- if (not image_url and file.lower().endswith((".jpg", ".jpeg", ".png", ".webp"))):
394
- image_url = f"https://huggingface.co/{link}/resolve/main/{image_url}"
395
- except:
396
- raise gr.Error("You didn't include a link neither a valid Hugging Face repository with a *.safetensors LoRA")
397
- return split_link[1], f"{lora_archive}/{safetensors_name}", image_url
398
-
399
- def get_civitai_safetensors(link):
400
- link_split = link.split("civitai.com/")
401
- pattern = re.compile(r'models\/(\d+)')
402
- regex_match = pattern.search(link_split[1])
403
- if(regex_match):
404
- civitai_model_id = regex_match.group(1)
405
-
406
- else:
407
- raise gr.Error("No CivitAI model id found in your URL")
408
- model_request_url = f"https://civitai.com/api/v1/models/{civitai_model_id}?token=dc9537afe3dcbbfb4696dd79ee8c23d7"
409
- x = requests.get(model_request_url)
410
- model_data = x.json()
411
- if(model_data["nsfw"] == True):
412
- raise gr.Error("The model is tagged by CivitAI as adult content and cannot be used in this shared environment.")
413
- elif(model_data["type"] != "LORA"):
414
- raise gr.Error("The model isn't tagged at CivitAI as a LoRA")
415
- model_link_download = None
416
- image_url = None
417
- for model in model_data["modelVersions"]:
418
- if(model["baseModel"] == "SDXL 1.0"):
419
- model_link_download = f"{model['downloadUrl']}/?token=dc9537afe3dcbbfb4696dd79ee8c23d7"
420
- safetensors_name = model["files"][0]["name"]
421
- if(not os.path.exists(f"{lora_archive}/{safetensors_name}")):
422
- safetensors_file_request = requests.get(model_link_download)
423
- with open(f"{lora_archive}/{safetensors_name}", 'wb') as file:
424
- file.write(safetensors_file_request.content)
425
- for image in model["images"]:
426
- if(image["nsfwLevel"] == 1):
427
- image_url = image["url"]
428
- break
429
- break
430
- return model_data["name"], f"{lora_archive}/{safetensors_name}", image_url
431
- if(not model_link_download):
432
- raise gr.Error("We couldn't find a SDXL LoRA on the model you've sent")
433
- def check_custom_model(link):
434
- if(link.startswith("https://")):
435
- if(link.startswith("https://huggingface.co") or link.startswith("https://www.huggingface.co")):
436
- link_split = link.split("huggingface.co/")
437
- return get_huggingface_safetensors(link_split[1])
438
- elif(link.startswith("https://civitai.com") or link.startswith("https://www.civitai.com")):
439
- return get_civitai_safetensors(link)
440
- else:
441
- return get_huggingface_safetensors(link)
442
-
443
- def show_loading_widget():
444
- return gr.update(visible=True)
445
-
446
- def load_custom_lora(link):
447
- title, path, image = check_custom_model(link)
448
- card = f'''
449
- <div class="custom_lora_card">
450
- <span>Loaded custom LoRA:</span>
451
- <div class="card_internal">
452
- <h3>{title}</h3>
453
- <img src="{image}" />
454
- </div>
455
- </div>
456
- '''
457
- return gr.update(visible=True), card, gr.update(visible=True), path, gr.Gallery(selected_index=None)
458
-
459
- def remove_custom_lora():
460
- return "", gr.update(visible=False), gr.update(visible=False), None
461
  with gr.Blocks(css="custom.css") as demo:
462
  gr_sdxl_loras = gr.State(value=sdxl_loras_raw)
463
  title = gr.HTML(
@@ -471,7 +381,6 @@ with gr.Blocks(css="custom.css") as demo:
471
  elem_id="title",
472
  )
473
  selected_state = gr.State()
474
- custom_loaded_lora = gr.State()
475
  with gr.Row(elem_id="main_app"):
476
  with gr.Column(scale=4):
477
  with gr.Group(elem_id="gallery_box"):
@@ -492,9 +401,7 @@ with gr.Blocks(css="custom.css") as demo:
492
  show_share_button=False,
493
  height=550
494
  )
495
- custom_model = gr.Textbox(label="Enter a custom Hugging Face or CivitAI SDXL LoRA", placeholder="Paste Hugging Face or CivitAI model path...")
496
- custom_model_card = gr.HTML(visible=False)
497
- custom_model_button = gr.Button("Remove custom LoRA", visible=False)
498
  with gr.Column(scale=5):
499
  with gr.Row():
500
  prompt = gr.Textbox(label="Prompt", show_label=False, lines=1, max_lines=1, info="Describe your subject (optional)", value="a person", elem_id="prompt")
@@ -524,19 +431,6 @@ with gr.Blocks(css="custom.css") as demo:
524
  # outputs=[gallery, gr_sdxl_loras],
525
  # queue=False
526
  #)
527
- custom_model.change(
528
- fn=show_loading_widget,
529
- outputs=[custom_model_card]
530
- ).then(
531
- fn=load_custom_lora,
532
- inputs=[custom_model],
533
- outputs=[custom_model_card, custom_model_card, custom_model_button, custom_loaded_lora, gallery],
534
- queue=False
535
- )
536
- custom_model_button.click(
537
- fn=remove_custom_lora,
538
- outputs=[custom_model, custom_loaded_lora, custom_model_card, custom_loaded_lora]
539
- )
540
  gallery.select(
541
  fn=update_selection,
542
  inputs=[gr_sdxl_loras, face_strength, image_strength, weight, depth_control_scale, negative],
@@ -558,7 +452,7 @@ with gr.Blocks(css="custom.css") as demo:
558
  show_progress=False
559
  ).success(
560
  fn=run_lora,
561
- inputs=[photo, prompt, negative, weight, selected_state, face_strength, image_strength, guidance_scale, depth_control_scale, gr_sdxl_loras, custom_loaded_lora],
562
  outputs=[result, share_group],
563
  )
564
  button.click(
@@ -568,7 +462,7 @@ with gr.Blocks(css="custom.css") as demo:
568
  show_progress=False
569
  ).success(
570
  fn=run_lora,
571
- inputs=[photo, prompt, negative, weight, selected_state, face_strength, image_strength, guidance_scale, depth_control_scale, gr_sdxl_loras, custom_loaded_lora],
572
  outputs=[result, share_group],
573
  )
574
  share_button.click(None, [], [], js=share_js)
 
303
  last_lora = repo_name
304
  return image
305
 
306
+ def run_lora(face_image, prompt, negative, lora_scale, selected_state, face_strength, image_strength, guidance_scale, depth_control_scale, sdxl_loras, progress=gr.Progress(track_tqdm=True)):
307
  selected_state_index = selected_state.index
308
  st = time.time()
309
  face_image = center_crop_image_as_square(face_image)
 
334
  print(sdxl_loras[selected_state_index]["repo"])
335
  if negative == "":
336
  negative = None
337
+
338
+ if not selected_state:
339
+ raise gr.Error("You must select a LoRA")
340
+ repo_name = sdxl_loras[selected_state_index]["repo"]
341
+ weight_name = sdxl_loras[selected_state_index]["weights"]
342
 
343
+ full_path_lora = state_dicts[repo_name]["saved_name"]
 
 
 
 
 
 
 
 
 
344
  #loaded_state_dict = copy.deepcopy(state_dicts[repo_name]["state_dict"])
345
  cross_attention_kwargs = None
346
  et = time.time()
 
368
  def deselect():
369
  return gr.Gallery(selected_index=None)
370
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
371
  with gr.Blocks(css="custom.css") as demo:
372
  gr_sdxl_loras = gr.State(value=sdxl_loras_raw)
373
  title = gr.HTML(
 
381
  elem_id="title",
382
  )
383
  selected_state = gr.State()
 
384
  with gr.Row(elem_id="main_app"):
385
  with gr.Column(scale=4):
386
  with gr.Group(elem_id="gallery_box"):
 
401
  show_share_button=False,
402
  height=550
403
  )
404
+ custom_model = gr.Textbox(label="Enter a custom Hugging Face or CivitAI SDXL LoRA", interactive=False, placeholder="Coming soon...")
 
 
405
  with gr.Column(scale=5):
406
  with gr.Row():
407
  prompt = gr.Textbox(label="Prompt", show_label=False, lines=1, max_lines=1, info="Describe your subject (optional)", value="a person", elem_id="prompt")
 
431
  # outputs=[gallery, gr_sdxl_loras],
432
  # queue=False
433
  #)
 
 
 
 
 
 
 
 
 
 
 
 
 
434
  gallery.select(
435
  fn=update_selection,
436
  inputs=[gr_sdxl_loras, face_strength, image_strength, weight, depth_control_scale, negative],
 
452
  show_progress=False
453
  ).success(
454
  fn=run_lora,
455
+ inputs=[photo, prompt, negative, weight, selected_state, face_strength, image_strength, guidance_scale, depth_control_scale, gr_sdxl_loras],
456
  outputs=[result, share_group],
457
  )
458
  button.click(
 
462
  show_progress=False
463
  ).success(
464
  fn=run_lora,
465
+ inputs=[photo, prompt, negative, weight, selected_state, face_strength, image_strength, guidance_scale, depth_control_scale, gr_sdxl_loras],
466
  outputs=[result, share_group],
467
  )
468
  share_button.click(None, [], [], js=share_js)