openfree committed
Commit 3552d05
Parent: 4765e59

Update app-backup1.py

Files changed (1): app-backup1.py (+202 -16)
app-backup1.py CHANGED
@@ -19,6 +19,8 @@ from gradio_imageslider import ImageSlider
 import numpy as np
 import warnings
 
+# Be sure to set your own Hugging Face USERNAME (this account) here at the top
+USERNAME = "openfree"
 
 huggingface_token = os.getenv("HF_TOKEN")
 
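The hard-coded USERNAME added here is what refresh_models() further down passes as the author= filter to the Hub API, which is why the comment insists that each deployment set its own account. A hedged alternative is to read it from the environment; HF_USERNAME below is an assumed variable name, not something the app defines:

    import os

    # Assumed convention: fall back to the hard-coded account when HF_USERNAME is unset.
    USERNAME = os.getenv("HF_USERNAME", "openfree")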
@@ -325,7 +327,7 @@ def remove_custom_lora(selected_indices, current_loras):
         lora_image_3
     )
 
-@spaces.GPU(duration=75)
+
 def generate_image(prompt_mash, steps, seed, cfg_scale, width, height, progress):
     print("Generating image...")
     pipe.to("cuda")
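This hunk and the next one drop the @spaces.GPU(duration=75) decorator from both generator functions. On ZeroGPU Spaces that decorator is what leases a GPU for the duration of each call, so removing it assumes the Space runs on dedicated GPU hardware where pipe.to("cuda") always succeeds. For reference, a minimal sketch of the decorated form being removed (the body is a stub, not the app's code):

    import spaces

    @spaces.GPU(duration=75)  # lease a ZeroGPU device for up to 75 seconds per call
    def generate(prompt: str):
        # stub body; the real functions move the Flux pipeline to CUDA here
        return prompt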
@@ -345,7 +347,7 @@ def generate_image(prompt_mash, steps, seed, cfg_scale, width, height, progress)
     ):
         yield img
 
-@spaces.GPU(duration=75)
+
 def generate_image_to_image(prompt_mash, image_input_path, image_strength, steps, cfg_scale, width, height, seed):
     pipe_i2i.to("cuda")
     generator = torch.Generator(device="cuda").manual_seed(seed)
@@ -364,9 +366,11 @@ def generate_image_to_image(prompt_mash, image_input_path, image_strength, steps
     ).images[0]
     return final_image
 
-def run_lora(prompt, image_input, image_strength, cfg_scale, steps, selected_indices, lora_scale_1, lora_scale_2, lora_scale_3, randomize_seed, seed, width, height, loras_state, progress=gr.Progress(track_tqdm=True)):
+def run_lora(prompt, image_input, image_strength, cfg_scale, steps, selected_indices,
+             lora_scale_1, lora_scale_2, lora_scale_3, randomize_seed, seed,
+             width, height, loras_state, progress=gr.Progress(track_tqdm=True)):
     try:
-        # Detect Korean and translate (keep this part as-is)
+        # Detect Korean and translate
         if any('\u3131' <= char <= '\u318E' or '\uAC00' <= char <= '\uD7A3' for char in prompt):
             translated = translator(prompt, max_length=512)[0]['translation_text']
             print(f"Original prompt: {prompt}")
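The range test above covers Hangul Compatibility Jamo (U+3131 to U+318E) and precomposed Hangul Syllables (U+AC00 to U+D7A3), enough to decide whether the prompt needs the Korean-to-English translator. The same predicate in isolation:

    def contains_korean(text: str) -> bool:
        # True when any character falls in the Hangul Jamo or Syllables blocks
        return any('\u3131' <= ch <= '\u318E' or '\uAC00' <= ch <= '\uD7A3' for ch in text)

    assert contains_korean("안녕하세요")
    assert not contains_korean("a watercolor fox")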
@@ -378,7 +382,7 @@ def run_lora(prompt, image_input, image_strength, cfg_scale, steps, selected_ind
 
         selected_loras = [loras_state[idx] for idx in selected_indices]
 
-        # Build the prompt with trigger words (keep this part as-is)
+        # Build the prompt with trigger words
         prepends = []
         appends = []
         for lora in selected_loras:
@@ -401,27 +405,40 @@ def run_lora(prompt, image_input, image_strength, cfg_scale, steps, selected_ind
         # Load LoRA weights with respective scales
         lora_names = []
         lora_weights = []
+
         with calculateDuration("Loading LoRA weights"):
             for idx, lora in enumerate(selected_loras):
                 try:
                     lora_name = f"lora_{idx}"
                     lora_path = lora['repo']
-                    weight_name = lora.get("weights")
-                    print(f"Loading LoRA {lora_name} from {lora_path}")
+
+                    # Special handling for private models
+                    if lora.get('private', False):
+                        lora_path = load_private_model(lora_path, huggingface_token)
+                        print(f"Using private model path: {lora_path}")
+
                     if image_input is not None:
-                        if weight_name:
-                            pipe_i2i.load_lora_weights(lora_path, weight_name=weight_name, adapter_name=lora_name)
-                        else:
-                            pipe_i2i.load_lora_weights(lora_path, adapter_name=lora_name)
+                        pipe_i2i.load_lora_weights(
+                            lora_path,
+                            adapter_name=lora_name,
+                            token=huggingface_token
+                        )
                     else:
-                        if weight_name:
-                            pipe.load_lora_weights(lora_path, weight_name=weight_name, adapter_name=lora_name)
-                        else:
-                            pipe.load_lora_weights(lora_path, adapter_name=lora_name)
+                        pipe.load_lora_weights(
+                            lora_path,
+                            adapter_name=lora_name,
+                            token=huggingface_token
+                        )
+
                     lora_names.append(lora_name)
                     lora_weights.append(lora_scale_1 if idx == 0 else lora_scale_2 if idx == 1 else lora_scale_3)
+                    print(f"Successfully loaded LoRA {lora_name} from {lora_path}")
+
                 except Exception as e:
                     print(f"Failed to load LoRA {lora_name}: {str(e)}")
+                    continue
+
+
 
         print("Loaded LoRAs:", lora_names)
         print("Adapter weights:", lora_weights)
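Two behavioral notes on this hunk: the rewritten calls no longer pass weight_name=lora.get("weights"), so diffusers now has to locate the weight file inside each repo on its own, and token= is forwarded so private repos resolve. The adapter_name bookkeeping feeds the later set_adapters call; a minimal sketch of that diffusers pattern, reusing the app's pipe and placeholder repo ids:

    # Load two LoRAs under distinct adapter names, then blend them.
    pipe.load_lora_weights("user/lora-a", adapter_name="lora_0")
    pipe.load_lora_weights("user/lora-b", adapter_name="lora_1")
    pipe.set_adapters(["lora_0", "lora_1"], adapter_weights=[1.0, 0.8])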
@@ -437,11 +454,12 @@ def run_lora(prompt, image_input, image_strength, cfg_scale, steps, selected_ind
 
         print(f"Active adapters after loading: {pipe.get_active_adapters()}")
 
-        # Image generation logic from here on (keep this part as-is)
+        # Randomize seed if needed
         with calculateDuration("Randomizing seed"):
             if randomize_seed:
                 seed = random.randint(0, MAX_SEED)
 
+        # Generate image
         if image_input is not None:
             final_image = generate_image_to_image(prompt_mash, image_input, image_strength, steps, cfg_scale, width, height, seed)
         else:
@@ -522,6 +540,138 @@ def update_history(new_image, history):
     history.insert(0, new_image)
     return history
 
+
+
+def refresh_models(huggingface_token):
+    try:
+        headers = {
+            "Authorization": f"Bearer {huggingface_token}",
+            "Accept": "application/json"
+        }
+
+        username = USERNAME
+        api_url = f"https://huggingface.co/api/models?author={username}"
+        response = requests.get(api_url, headers=headers)
+        if response.status_code != 200:
+            raise Exception(f"Failed to fetch models from HuggingFace. Status code: {response.status_code}")
+
+        all_models = response.json()
+        print(f"Found {len(all_models)} models for user {username}")
+
+        user_models = [
+            model for model in all_models
+            if model.get('tags') and ('flux' in [tag.lower() for tag in model.get('tags', [])] or
+                                      'flux-lora' in [tag.lower() for tag in model.get('tags', [])])
+        ]
+
+        print(f"Found {len(user_models)} FLUX models")
+
+        new_models = []
+        for model in user_models:
+            try:
+                model_id = model['id']
+                model_card_url = f"https://huggingface.co/api/models/{model_id}"
+                model_info_response = requests.get(model_card_url, headers=headers)
+                model_info = model_info_response.json()
+
+                # Switched to including the token when fetching image URLs
+                is_private = model.get('private', False)
+                base_image_name = "1732195028106__000001000_0.jpg"  # default image name
+
+                try:
+                    # Check which image files actually exist
+                    fs = HfFileSystem(token=huggingface_token)
+                    samples_path = f"{model_id}/samples"
+                    files = fs.ls(samples_path, detail=True)
+                    jpg_files = [
+                        f['name'] for f in files
+                        if isinstance(f, dict) and
+                        'name' in f and
+                        f['name'].lower().endswith('.jpg') and
+                        any(char.isdigit() for char in os.path.basename(f['name']))
+                    ]
+
+                    if jpg_files:
+                        base_image_name = os.path.basename(jpg_files[0])
+                except Exception as e:
+                    print(f"Error accessing samples folder for {model_id}: {str(e)}")
+
+                # Build the image URL (token included)
+                if is_private:
+                    # For private models, use a local cache path
+                    cache_dir = f"models/{model_id.replace('/', '_')}/samples"
+                    os.makedirs(cache_dir, exist_ok=True)
+
+                    # Download the image
+                    image_url = f"https://huggingface.co/{model_id}/resolve/main/samples/{base_image_name}"
+                    local_image_path = os.path.join(cache_dir, base_image_name)
+
+                    if not os.path.exists(local_image_path):
+                        response = requests.get(image_url, headers=headers)
+                        if response.status_code == 200:
+                            with open(local_image_path, 'wb') as f:
+                                f.write(response.content)
+
+                    image_url = local_image_path
+                else:
+                    image_url = f"https://huggingface.co/{model_id}/resolve/main/samples/{base_image_name}"
+
+                model_info = {
+                    "image": image_url,
+                    "title": f"[Private] {model_id.split('/')[-1]}" if is_private else model_id.split('/')[-1],
+                    "repo": model_id,
+                    "weights": "pytorch_lora_weights.safetensors",
+                    "trigger_word": model_info.get('instance_prompt', ''),
+                    "private": is_private
+                }
+                new_models.append(model_info)
+                print(f"Added model: {model_id} with image: {image_url}")
+
+            except Exception as e:
+                print(f"Error processing model {model['id']}: {str(e)}")
+                continue
+
+        updated_loras = new_models + [lora for lora in loras if lora['repo'] not in [m['repo'] for m in new_models]]
+
+        print(f"Total models after refresh: {len(updated_loras)}")
+        return updated_loras
+    except Exception as e:
+        print(f"Error refreshing models: {str(e)}")
+        return loras
+
+def load_private_model(model_id, huggingface_token):
+    """Load a private model."""
+    try:
+        headers = {"Authorization": f"Bearer {huggingface_token}"}
+
+        # Download the model
+        local_dir = snapshot_download(
+            repo_id=model_id,
+            token=huggingface_token,
+            local_dir=f"models/{model_id.replace('/', '_')}",
+            local_dir_use_symlinks=False
+        )
+
+        # Find a safetensors file
+        safetensors_file = None
+        for root, dirs, files in os.walk(local_dir):
+            for file in files:
+                if file.endswith('.safetensors'):
+                    safetensors_file = os.path.join(root, file)
+                    break
+            if safetensors_file:
+                break
+
+        if not safetensors_file:
+            raise Exception(f"No .safetensors file found in {local_dir}")
+
+        print(f"Found safetensors file: {safetensors_file}")
+        return safetensors_file  # return the full path
+
+    except Exception as e:
+        print(f"Error loading private model {model_id}: {str(e)}")
+        raise e
+
 custom_theme = gr.themes.Base(
     primary_hue="blue",
     secondary_hue="purple",
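refresh_models() drives the Hub with raw requests calls, assuming requests, HfFileSystem, and snapshot_download are imported earlier in the file (this diff does not add them). One wrinkle worth noting: inside the loop, model_info first holds the API response and is then reused for the gallery entry; instance_prompt is read before the name is overwritten, so it works, but the shadowing is easy to trip over. The listing step could equally be expressed through huggingface_hub; a sketch under that assumption, not the code above:

    from huggingface_hub import HfApi

    api = HfApi(token=huggingface_token)
    # list_models yields ModelInfo objects; tags can be None on some repos
    flux_models = [
        m for m in api.list_models(author=USERNAME)
        if any(t.lower() in ("flux", "flux-lora") for t in (m.tags or []))
    ]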
@@ -825,6 +975,25 @@ input:focus, textarea:focus {
     max-width: 90% !important;
     margin: 0 !important;  /* changed from auto to 0 */
     margin-left: 20px !important;  /* add a left margin */
+
+    /* refresh-button styles */
+    #refresh-button {
+        margin: 10px;
+        padding: 8px 16px;
+        background-color: #4a5568;
+        color: white;
+        border-radius: 8px;
+        transition: all 0.3s ease;
+    }
+
+    #refresh-button:hover {
+        background-color: #2d3748;
+        transform: scale(1.05);
+    }
+
+    #refresh-button:active {
+        transform: scale(0.95);
+    }
 }
 '''
 
@@ -839,6 +1008,9 @@ with gr.Blocks(theme=custom_theme, css=css, delete_cache=(60, 3600)) as app:
         갤러리에서 원하는 모델을 선택(최대 3개까지) < 프롬프트에 한글 또는 영문으로 원하는 내용을 입력 < Generate 버튼 실행
         """
     )
+    # Add a refresh button
+    with gr.Row():
+        refresh_button = gr.Button("🔄 모델 새로고침(나만의 맞춤 학습된 Private 모델 불러오기)", variant="secondary")
 
     with gr.Row(elem_id="lora_gallery", equal_height=True):
         gallery = gr.Gallery(
@@ -855,6 +1027,7 @@ with gr.Blocks(theme=custom_theme, css=css, delete_cache=(60, 3600)) as app:
             preview=False
         )
 
+
     with gr.Tab(label="Generate"):
         # Prompt and Generate Button
         with gr.Row():
@@ -1023,6 +1196,19 @@ with gr.Blocks(theme=custom_theme, css=css, delete_cache=(60, 3600)) as app:
         outputs=history_gallery
     )
 
+    # Refresh-button event handler
+    def refresh_gallery():
+        updated_loras = refresh_models(huggingface_token)
+        return (
+            gr.update(value=[(item["image"], item["title"]) for item in updated_loras]),
+            updated_loras
+        )
+
+    refresh_button.click(
+        refresh_gallery,
+        outputs=[gallery, loras_state]
+    )
+
 if __name__ == "__main__":
     app.queue(max_size=20)
     app.launch(debug=True)
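The click wiring passes no inputs, matching refresh_gallery's zero-argument signature, and the returned tuple maps one-to-one onto outputs (a gallery update plus the refreshed state list). The same Blocks pattern in isolation, with placeholder data rather than the app's:

    import gradio as gr

    with gr.Blocks() as demo:
        state = gr.State([])
        gallery = gr.Gallery()
        button = gr.Button("Refresh")

        def on_refresh():
            items = [("cat.jpg", "cat LoRA")]  # placeholder (image, caption) pairs
            return gr.update(value=items), items

        # No inputs: the callback takes no arguments; outputs receive the tuple in order.
        button.click(on_refresh, outputs=[gallery, state])

    demo.launch()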