ginipick committed (verified)
Commit 17d8233 · Parent(s): ef45d8e

Update app.py

Files changed (1)
  1. app.py +62 -194
app.py CHANGED
@@ -1,5 +1,6 @@
- import os
  import spaces
  import time
  import gradio as gr
  import torch
@@ -8,9 +9,11 @@ from torchvision import transforms
  from dataclasses import dataclass
  import math
  from typing import Callable
  from tqdm import tqdm
  import bitsandbytes as bnb
  from bitsandbytes.nn.modules import Params4bit, QuantState
  import torch
  import random
  from einops import rearrange, repeat
@@ -18,8 +21,11 @@ from diffusers import AutoencoderKL
  from torch import Tensor, nn
  from transformers import CLIPTextModel, CLIPTokenizer
  from transformers import T5EncoderModel, T5Tokenizer
- from transformers import MarianMTModel, MarianTokenizer, pipeline
- from huggingface_hub import snapshot_download
  class HFEmbedder(nn.Module):
      def __init__(self, version: str, max_length: int, **hf_kwargs):
@@ -54,6 +60,7 @@ class HFEmbedder(nn.Module):
              output_hidden_states=False,
          )
          return outputs[self.output_key]

  device = "cuda"
  t5 = HFEmbedder("DeepFloyd/t5-v1_1-xxl", max_length=512, torch_dtype=torch.bfloat16).to(device)
@@ -63,6 +70,9 @@ ae = AutoencoderKL.from_pretrained("black-forest-labs/FLUX.1-dev", subfolder="va
  # freeze(t5)


  def functional_linear_4bits(x, weight, bias):
      out = bnb.matmul_4bit(x, weight.t(), bias=bias, quant_state=weight.quant_state)
      out = out.to(x)
@@ -200,6 +210,9 @@ class Linear(ForgeLoader4Bit):
  nn.Linear = Linear


  def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor) -> Tensor:
      q, k = apply_rope(q, k, pe)

@@ -724,6 +737,8 @@ def get_image(image) -> torch.Tensor | None:
      return img[None, ...]


  from huggingface_hub import hf_hub_download
  from safetensors.torch import load_file
@@ -734,157 +749,43 @@ model = Flux().to(dtype=torch.bfloat16, device="cuda")
  result = model.load_state_dict(sd)
  model_zero_init = False

-
- # Language-to-model mapping dictionary
- TRANSLATORS = {
-     "Korean": "Helsinki-NLP/opus-mt-ko-en",
-     "Japanese": "Helsinki-NLP/opus-mt-ja-en",
-     "Chinese": "Helsinki-NLP/opus-mt-zh-en",
-     "Russian": "Helsinki-NLP/opus-mt-ru-en",
-     "Spanish": "Helsinki-NLP/opus-mt-es-en",
-     "French": "Helsinki-NLP/opus-mt-fr-en",
-     "Arabic": "Helsinki-NLP/opus-mt-ar-en",
-     "Bengali": "Helsinki-NLP/opus-mt-bn-en",
-     "Estonian": "Helsinki-NLP/opus-mt-et-en",
-     "Polish": "Helsinki-NLP/opus-mt-pl-en",
-     "Swedish": "Helsinki-NLP/opus-mt-sv-en",
-     "Thai": "Helsinki-NLP/opus-mt-th-en",
-     "Urdu": "Helsinki-NLP/opus-mt-ur-en",
-     "Bulgarian": "Helsinki-NLP/opus-mt-bg-en",
-     "Catalan": "Helsinki-NLP/opus-mt-ca-en",
-     "Czech": "Helsinki-NLP/opus-mt-cs-en",
-     "Azerbaijani": "Helsinki-NLP/opus-mt-az-en",
-     "Basque": "Helsinki-NLP/opus-mt-bat-en",
-     "Bicolano": "Helsinki-NLP/opus-mt-bcl-en",
-     "Bemba": "Helsinki-NLP/opus-mt-bem-en",
-     "Berber": "Helsinki-NLP/opus-mt-ber-en",
-     "Bislama": "Helsinki-NLP/opus-mt-bi-en",
-     "Bantu": "Helsinki-NLP/opus-mt-bnt-en",
-     "Brazilian Sign Language": "Helsinki-NLP/opus-mt-bzs-en",
-     "Caucasian": "Helsinki-NLP/opus-mt-cau-en",
-     "Cebuano": "Helsinki-NLP/opus-mt-ceb-en",
-     "Celtic": "Helsinki-NLP/opus-mt-cel-en",
-     "Chuukese": "Helsinki-NLP/opus-mt-chk-en",
-     "Creoles and pidgins (French)": "Helsinki-NLP/opus-mt-cpf-en",
-     "Seychelles Creole": "Helsinki-NLP/opus-mt-crs-en",
-     "American Sign Language": "Helsinki-NLP/opus-mt-ase-en",
-     "Artificial Language": "Helsinki-NLP/opus-mt-art-en",
-     "Atlantic-Congo": "Helsinki-NLP/opus-mt-alv-en",
-     "Afroasiatic": "Helsinki-NLP/opus-mt-afa-en",
-     "Afrikaans": "Helsinki-NLP/opus-mt-af-en",
-     "Austroasiatic": "Helsinki-NLP/opus-mt-aav-en"
- }
-
- translators_cache = {}
-
- # Set the model cache directory
- os.environ['TRANSFORMERS_CACHE'] = '/tmp/transformers_cache'
-
- def download_model(model_name):
-     """Pre-download a model."""
-     try:
-         cache_dir = os.path.join('/tmp/transformers_cache', model_name.split('/')[-1])
-         snapshot_download(
-             repo_id=model_name,
-             cache_dir=cache_dir,
-             local_files_only=False
-         )
-         return cache_dir
-     except Exception as e:
-         print(f"Error downloading model {model_name}: {e}")
-         return None
-
- def get_translator(lang):
-     """Initialize and return the translator for a language."""
-     if lang == "English":
-         return None
-
-     if lang not in translators_cache:
-         try:
-             model_name = TRANSLATORS[lang]
-
-             # Load the model directly instead of using pipeline
-             tokenizer = MarianTokenizer.from_pretrained(model_name)
-             model = MarianMTModel.from_pretrained(model_name)
-
-             # Run on CPU
-             model = model.to("cpu").eval()
-
-             translators_cache[lang] = {
-                 "model": model,
-                 "tokenizer": tokenizer
-             }
-             print(f"Successfully loaded translator for {lang}")
-
-         except Exception as e:
-             print(f"Error loading translator for {lang}: {e}")
-             return None
-
-     return translators_cache[lang]
-
- def translate_text(text, translator_info):
-     """Run the translation."""
-     if translator_info is None:
-         return text
-
-     try:
-         tokenizer = translator_info["tokenizer"]
-         model = translator_info["model"]
-
-         # Preprocess the input text
-         inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=512)
-
-         # Run generation
-         with torch.no_grad():
-             outputs = model.generate(**inputs)
-
-         # Decode the translated output
-         translated = tokenizer.decode(outputs[0], skip_special_tokens=True)
-
-         print(f"Original text: {text}")
-         print(f"Translated text: {translated}")
-
-         return translated
-
-     except Exception as e:
-         print(f"Translation error: {e}")
-         return text

  @spaces.GPU
  @torch.no_grad()
  def generate_image(
-     prompt, source_lang, width, height, guidance, inference_steps, seed,
      do_img2img, init_image, image2image_strength, resize_img,
      progress=gr.Progress(track_tqdm=True),
  ):
-     # Handle translation
-     try:
-         if source_lang != "English":
-             translator_info = get_translator(source_lang)
-             if translator_info is not None:
-                 translated_prompt = translate_text(prompt, translator_info)
-                 print(f"Using translated prompt: {translated_prompt}")
-             else:
-                 print(f"No translator available for {source_lang}, using original prompt")
-                 translated_prompt = prompt
-         else:
-             translated_prompt = prompt
-     except Exception as e:
-         print(f"Translation failed: {e}")
-         translated_prompt = prompt
-
      if seed == 0:
          seed = int(random.random() * 1000000)

      device = "cuda" if torch.cuda.is_available() else "cpu"
      torch_device = torch.device(device)

      global model, model_zero_init
      if not model_zero_init:
@@ -901,11 +802,10 @@ def generate_image(
              height = init_image.shape[-2]
              width = init_image.shape[-1]
          init_image = ae.encode(init_image.to(torch_device).to(torch.bfloat16)).latent_dist.sample()
-         init_image = (init_image - ae.config.shift_factor) * ae.config.scaling_factor

      generator = torch.Generator(device=device).manual_seed(seed)
-     x = torch.randn(1, 16, 2 * math.ceil(height / 16), 2 * math.ceil(width / 16),
-                     device=device, dtype=torch.bfloat16, generator=generator)

      num_steps = inference_steps
      timesteps = get_schedule(num_steps, (x.shape[-1] * x.shape[-2]) // 4, shift=True)
@@ -916,18 +816,22 @@ def generate_image(
          timesteps = timesteps[t_idx:]
          x = t * x + (1.0 - t) * init_image.to(x.dtype)

-     inp = prepare(t5=t5, clip=clip, img=x, prompt=translated_prompt)
      x = denoise(model, **inp, timesteps=timesteps, guidance=guidance)

      x = unpack(x.float(), height, width)
      with torch.autocast(device_type=torch_device.type, dtype=torch.bfloat16):
-         x = (x / ae.config.scaling_factor) + ae.config.shift_factor
          x = ae.decode(x).sample

      x = x.clamp(-1, 1)
      x = rearrange(x[0], "c h w -> h w c")
      img = Image.fromarray((127.5 * (x + 1.0)).cpu().byte().numpy())

      return img, seed, translated_prompt

  css = """
@@ -936,21 +840,14 @@ footer {
  }
  """

  def create_demo():
      with gr.Blocks(theme="Yntec/HaleyCH_Theme_Orange", css=css) as demo:
          with gr.Row():
              with gr.Column():
-                 source_lang = gr.Dropdown(
-                     choices=["English"] + sorted(list(TRANSLATORS.keys())),
-                     value="English",
-                     label="Source Language"
-                 )
-                 prompt = gr.Textbox(
-                     label="Prompt",
-                     value="A beautiful landscape"
-                 )
-
                  width = gr.Slider(minimum=128, maximum=2048, step=64, label="Width", value=768)
                  height = gr.Slider(minimum=128, maximum=2048, step=64, label="Height", value=768)
                  guidance = gr.Slider(minimum=1.0, maximum=5.0, step=0.1, label="Guidance", value=3.5)
@@ -964,44 +861,13 @@ def create_demo():
                  seed = gr.Number(label="Seed", precision=-1)
                  do_img2img = gr.Checkbox(label="Image to Image", value=False)
                  init_image = gr.Image(label="Input Image", visible=False)
-                 image2image_strength = gr.Slider(
-                     minimum=0.0, maximum=1.0, step=0.01,
-                     label="Noising strength", value=0.8, visible=False
-                 )
                  resize_img = gr.Checkbox(label="Resize image", value=True, visible=False)
                  generate_button = gr.Button("Generate")

              with gr.Column():
                  output_image = gr.Image(label="Generated Image")
                  output_seed = gr.Text(label="Used Seed")
-                 translated_prompt = gr.Text(label="Translated Prompt")
-
-         # Multilingual examples
-         examples = [
-             # English
-             ["A beautiful sunset over mountains", "English", 768, 768, 3.5, 30, 0, False, None, 0.8, True],
-             # Korean
-             ["벚꽃이 흩날리는 서울의 봄 풍경", "Korean", 768, 768, 3.5, 30, 0, False, None, 0.8, True],
-             # Japanese
-             ["富士山と桜の美しい風景", "Japanese", 768, 768, 3.5, 30, 0, False, None, 0.8, True],
-             # Chinese
-             ["长城日落的壮丽景色", "Chinese", 768, 768, 3.5, 30, 0, False, None, 0.8, True],
-             # Spanish
-             ["Un hermoso atardecer en la playa", "Spanish", 768, 768, 3.5, 30, 0, False, None, 0.8, True]
-         ]
-
-         gr.Examples(
-             examples=examples,
-             inputs=[
-                 prompt, source_lang, width, height, guidance, inference_steps,
-                 seed, do_img2img, init_image, image2image_strength, resize_img
-             ],
-             outputs=[output_image, output_seed, translated_prompt],
-             fn=generate_image,
-             cache_examples=True
-         )
-
          do_img2img.change(
              fn=lambda x: [gr.update(visible=x), gr.update(visible=x), gr.update(visible=x)],
@@ -1011,16 +877,18 @@ def create_demo():

          generate_button.click(
              fn=generate_image,
-             inputs=[
-                 prompt, source_lang, width, height, guidance, inference_steps,
-                 seed, do_img2img, init_image, image2image_strength, resize_img
-             ],
-             outputs=[output_image, output_seed, translated_prompt]
          )

      return demo

  if __name__ == "__main__":
-     print("Starting demo...")
      demo = create_demo()
-     demo.launch(share=True)
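For reference, a minimal standalone sketch of the per-language MarianMT path removed above (it assumes the transformers package and the Helsinki-NLP checkpoints named in the old code; the helper name marian_translate is illustrative, not part of app.py):

from transformers import MarianMTModel, MarianTokenizer
import torch

def marian_translate(text: str, model_name: str = "Helsinki-NLP/opus-mt-ko-en") -> str:
    # Load the tokenizer and seq2seq model for one language pair, run on CPU
    tokenizer = MarianTokenizer.from_pretrained(model_name)
    model = MarianMTModel.from_pretrained(model_name).to("cpu").eval()
    # Tokenize, generate, and decode the English translation
    inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=512)
    with torch.no_grad():
        outputs = model.generate(**inputs)
    return tokenizer.decode(outputs[0], skip_special_tokens=True)

print(marian_translate("벚꽃이 흩날리는 서울의 봄 풍경"))  # expected: an English rendering of the Korean prompt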
 
+ # import os
  import spaces
+
  import time
  import gradio as gr
  import torch
  from dataclasses import dataclass
  import math
  from typing import Callable
+
  from tqdm import tqdm
  import bitsandbytes as bnb
  from bitsandbytes.nn.modules import Params4bit, QuantState
+
  import torch
  import random
  from einops import rearrange, repeat
  from torch import Tensor, nn
  from transformers import CLIPTextModel, CLIPTokenizer
  from transformers import T5EncoderModel, T5Tokenizer
+ # from optimum.quanto import freeze, qfloat8, quantize
+ from transformers import pipeline
+
+ ko_translator = pipeline("translation", model="Helsinki-NLP/opus-mt-ko-en")
+ ja_translator = pipeline("translation", model="Helsinki-NLP/opus-mt-ja-en")

  class HFEmbedder(nn.Module):
      def __init__(self, version: str, max_length: int, **hf_kwargs):

              output_hidden_states=False,
          )
          return outputs[self.output_key]
+

  device = "cuda"
  t5 = HFEmbedder("DeepFloyd/t5-v1_1-xxl", max_length=512, torch_dtype=torch.bfloat16).to(device)

  # freeze(t5)

+ # ---------------- NF4 ----------------
+
+
  def functional_linear_4bits(x, weight, bias):
      out = bnb.matmul_4bit(x, weight.t(), bias=bias, quant_state=weight.quant_state)
      out = out.to(x)

  nn.Linear = Linear

+ # ---------------- Model ----------------
+
+
  def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor) -> Tensor:
      q, k = apply_rope(q, k, pe)

      return img[None, ...]

+ # ---------------- Demo ----------------
+

  from huggingface_hub import hf_hub_download
  from safetensors.torch import load_file

  result = model.load_state_dict(sd)
  model_zero_init = False

+ # model = Flux().to(dtype=torch.bfloat16, device="cuda")
+ # result = model.load_state_dict(load_file("/storage/dev/nyanko/flux-dev/flux1-dev.sft"))

  @spaces.GPU
  @torch.no_grad()
  def generate_image(
+     prompt, width, height, guidance, inference_steps, seed,
      do_img2img, init_image, image2image_strength, resize_img,
      progress=gr.Progress(track_tqdm=True),
  ):
+     translated_prompt = prompt
+
+     # Detect Korean or Japanese characters
+     def contains_korean(text):
+         return any('\u3131' <= c <= '\u318E' or '\uAC00' <= c <= '\uD7A3' for c in text)
+
+     def contains_japanese(text):
+         return any('\u3040' <= c <= '\u309F' or '\u30A0' <= c <= '\u30FF' or '\u4E00' <= c <= '\u9FFF' for c in text)
+
+     # Translate if the prompt contains Korean or Japanese
+     if contains_korean(prompt):
+         translated_prompt = ko_translator(prompt, max_length=512)[0]['translation_text']
+         print(f"Translated Korean prompt: {translated_prompt}")
+         prompt = translated_prompt
+     elif contains_japanese(prompt):
+         translated_prompt = ja_translator(prompt, max_length=512)[0]['translation_text']
+         print(f"Translated Japanese prompt: {translated_prompt}")
+         prompt = translated_prompt
+
      if seed == 0:
          seed = int(random.random() * 1000000)

      device = "cuda" if torch.cuda.is_available() else "cpu"
      torch_device = torch.device(device)
+
+
      global model, model_zero_init
      if not model_zero_init:

              height = init_image.shape[-2]
              width = init_image.shape[-1]
          init_image = ae.encode(init_image.to(torch_device).to(torch.bfloat16)).latent_dist.sample()
+         init_image = (init_image - ae.config.shift_factor) * ae.config.scaling_factor

      generator = torch.Generator(device=device).manual_seed(seed)
+     x = torch.randn(1, 16, 2 * math.ceil(height / 16), 2 * math.ceil(width / 16), device=device, dtype=torch.bfloat16, generator=generator)

      num_steps = inference_steps
      timesteps = get_schedule(num_steps, (x.shape[-1] * x.shape[-2]) // 4, shift=True)

          timesteps = timesteps[t_idx:]
          x = t * x + (1.0 - t) * init_image.to(x.dtype)

+     inp = prepare(t5=t5, clip=clip, img=x, prompt=prompt)
      x = denoise(model, **inp, timesteps=timesteps, guidance=guidance)

+     # with profile(activities=[ProfilerActivity.CPU],record_shapes=True,profile_memory=True) as prof:
+     #     print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=20))
+
      x = unpack(x.float(), height, width)
      with torch.autocast(device_type=torch_device.type, dtype=torch.bfloat16):
+         x = (x / ae.config.scaling_factor) + ae.config.shift_factor
          x = ae.decode(x).sample

      x = x.clamp(-1, 1)
      x = rearrange(x[0], "c h w -> h w c")
      img = Image.fromarray((127.5 * (x + 1.0)).cpu().byte().numpy())

+
      return img, seed, translated_prompt

  css = """
  }
  """

+
  def create_demo():
      with gr.Blocks(theme="Yntec/HaleyCH_Theme_Orange", css=css) as demo:
+
          with gr.Row():
              with gr.Column():
+                 prompt = gr.Textbox(label="Prompt (Korean supported)", value="A cute and fluffy golden retriever puppy sitting upright, holding a neatly designed white sign with bold, colorful lettering that reads 'Have a Happy Day!' in cheerful fonts. The puppy has expressive, sparkling eyes, a happy smile, and fluffy ears slightly flopped. The background is a vibrant and sunny meadow with soft-focus flowers, glowing sunlight filtering through the trees, and a warm golden glow that enhances the joyful atmosphere. The sign is framed with small decorative flowers, adding a charming and wholesome touch. Ensure the text on the sign is clear and legible.")
                  width = gr.Slider(minimum=128, maximum=2048, step=64, label="Width", value=768)
                  height = gr.Slider(minimum=128, maximum=2048, step=64, label="Height", value=768)
                  guidance = gr.Slider(minimum=1.0, maximum=5.0, step=0.1, label="Guidance", value=3.5)

                  seed = gr.Number(label="Seed", precision=-1)
                  do_img2img = gr.Checkbox(label="Image to Image", value=False)
                  init_image = gr.Image(label="Input Image", visible=False)
+                 image2image_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label="Noising strength", value=0.8, visible=False)
                  resize_img = gr.Checkbox(label="Resize image", value=True, visible=False)
                  generate_button = gr.Button("Generate")

              with gr.Column():
                  output_image = gr.Image(label="Generated Image")
                  output_seed = gr.Text(label="Used Seed")

          do_img2img.change(
              fn=lambda x: [gr.update(visible=x), gr.update(visible=x), gr.update(visible=x)],

          generate_button.click(
              fn=generate_image,
+             inputs=[prompt, width, height, guidance, inference_steps, seed, do_img2img, init_image, image2image_strength, resize_img],
+             outputs=[output_image, output_seed]
          )
+
+         examples = [
+             "a tiny astronaut hatching from an egg on the moon",
+             "a cat holding a sign that says hello world",
+             "an anime illustration of a wiener schnitzel",
+         ]

      return demo

  if __name__ == "__main__":
      demo = create_demo()
+     demo.launch()
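For comparison, a minimal standalone sketch of the detection-based routing that the updated generate_image now uses (it assumes the transformers package and the two Helsinki-NLP pipelines named above; translate_if_needed is an illustrative helper, not part of app.py):

from transformers import pipeline

ko_translator = pipeline("translation", model="Helsinki-NLP/opus-mt-ko-en")
ja_translator = pipeline("translation", model="Helsinki-NLP/opus-mt-ja-en")

def contains_korean(text: str) -> bool:
    # Hangul compatibility jamo and precomposed syllable ranges
    return any('\u3131' <= c <= '\u318E' or '\uAC00' <= c <= '\uD7A3' for c in text)

def contains_japanese(text: str) -> bool:
    # Hiragana, katakana, and CJK unified ideograph ranges
    return any('\u3040' <= c <= '\u309F' or '\u30A0' <= c <= '\u30FF' or '\u4E00' <= c <= '\u9FFF' for c in text)

def translate_if_needed(prompt: str) -> str:
    # Route the prompt to the matching translator; English (or anything else) passes through unchanged
    if contains_korean(prompt):
        return ko_translator(prompt, max_length=512)[0]['translation_text']
    if contains_japanese(prompt):
        return ja_translator(prompt, max_length=512)[0]['translation_text']
    return prompt

print(translate_if_needed("富士山と桜の美しい風景"))  # expected: an English rendering of the Japanese prompt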