ginipick commited on
Commit
d450fd7
·
verified ·
1 Parent(s): 189b9e3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +19 -7
app.py CHANGED
@@ -24,9 +24,6 @@ from transformers import T5EncoderModel, T5Tokenizer
24
  # from optimum.quanto import freeze, qfloat8, quantize
25
  from transformers import pipeline
26
 
27
- ko_translator = pipeline("translation", model="Helsinki-NLP/opus-mt-ko-en")
28
- ja_translator = pipeline("translation", model="Helsinki-NLP/opus-mt-ja-en")
29
-
30
  class HFEmbedder(nn.Module):
31
  def __init__(self, version: str, max_length: int, **hf_kwargs):
32
  super().__init__()
@@ -749,8 +746,12 @@ model = Flux().to(dtype=torch.bfloat16, device="cuda")
749
  result = model.load_state_dict(sd)
750
  model_zero_init = False
751
 
752
- # model = Flux().to(dtype=torch.bfloat16, device="cuda")
753
- # result = model.load_state_dict(load_file("/storage/dev/nyanko/flux-dev/flux1-dev.sft"))
 
 
 
 
754
 
755
 
756
  @spaces.GPU
@@ -762,14 +763,17 @@ def generate_image(
762
  ):
763
  translated_prompt = prompt
764
 
765
- # 한글 또는 일본어 문자 감지
766
  def contains_korean(text):
767
  return any('\u3131' <= c <= '\u318E' or '\uAC00' <= c <= '\uD7A3' for c in text)
768
 
769
  def contains_japanese(text):
770
  return any('\u3040' <= c <= '\u309F' or '\u30A0' <= c <= '\u30FF' or '\u4E00' <= c <= '\u9FFF' for c in text)
771
 
772
- # 한글이나 일본어가 있으면 번역
 
 
 
773
  if contains_korean(prompt):
774
  translated_prompt = ko_translator(prompt, max_length=512)[0]['translation_text']
775
  print(f"Translated Korean prompt: {translated_prompt}")
@@ -778,6 +782,14 @@ def generate_image(
778
  translated_prompt = ja_translator(prompt, max_length=512)[0]['translation_text']
779
  print(f"Translated Japanese prompt: {translated_prompt}")
780
  prompt = translated_prompt
 
 
 
 
 
 
 
 
781
 
782
  if seed == 0:
783
  seed = int(random.random() * 1000000)
 
24
  # from optimum.quanto import freeze, qfloat8, quantize
25
  from transformers import pipeline
26
 
 
 
 
27
  class HFEmbedder(nn.Module):
28
  def __init__(self, version: str, max_length: int, **hf_kwargs):
29
  super().__init__()
 
746
  result = model.load_state_dict(sd)
747
  model_zero_init = False
748
 
749
+
750
+
751
+ ko_translator = pipeline("translation", model="Helsinki-NLP/opus-mt-ko-en")
752
+ ja_translator = pipeline("translation", model="Helsinki-NLP/opus-mt-ja-en")
753
+ zh_translator = pipeline("translation", model="Helsinki-NLP/opus-mt-zh-en")
754
+
755
 
756
 
757
  @spaces.GPU
 
763
  ):
764
  translated_prompt = prompt
765
 
766
+ # 한글, 일본어, 중국어 문자 감지
767
  def contains_korean(text):
768
  return any('\u3131' <= c <= '\u318E' or '\uAC00' <= c <= '\uD7A3' for c in text)
769
 
770
  def contains_japanese(text):
771
  return any('\u3040' <= c <= '\u309F' or '\u30A0' <= c <= '\u30FF' or '\u4E00' <= c <= '\u9FFF' for c in text)
772
 
773
+ def contains_chinese(text):
774
+ return any('\u4e00' <= c <= '\u9fff' for c in text)
775
+
776
+ # 한글, 일본어, 중국어가 있으면 번역
777
  if contains_korean(prompt):
778
  translated_prompt = ko_translator(prompt, max_length=512)[0]['translation_text']
779
  print(f"Translated Korean prompt: {translated_prompt}")
 
782
  translated_prompt = ja_translator(prompt, max_length=512)[0]['translation_text']
783
  print(f"Translated Japanese prompt: {translated_prompt}")
784
  prompt = translated_prompt
785
+ elif contains_chinese(prompt):
786
+ translated_prompt = zh_translator(prompt, max_length=512)[0]['translation_text']
787
+ print(f"Translated Chinese prompt: {translated_prompt}")
788
+ prompt = translated_prompt
789
+
790
+
791
+
792
+
793
 
794
  if seed == 0:
795
  seed = int(random.random() * 1000000)