ginipick commited on
Commit
ef45d8e
·
verified ·
1 Parent(s): f1cb913

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +26 -14
app.py CHANGED
@@ -805,21 +805,20 @@ def get_translator(lang):
805
  if lang not in translators_cache:
806
  try:
807
  model_name = TRANSLATORS[lang]
808
- cache_dir = download_model(model_name)
809
 
810
- if cache_dir is None:
811
- return None
812
-
813
- tokenizer = MarianTokenizer.from_pretrained(model_name, cache_dir=cache_dir)
814
- model = MarianMTModel.from_pretrained(model_name, cache_dir=cache_dir)
815
 
816
- model = model.to("cpu")
 
817
 
818
  translators_cache[lang] = {
819
  "model": model,
820
  "tokenizer": tokenizer
821
  }
822
  print(f"Successfully loaded translator for {lang}")
 
823
  except Exception as e:
824
  print(f"Error loading translator for {lang}: {e}")
825
  return None
@@ -835,18 +834,26 @@ def translate_text(text, translator_info):
835
  tokenizer = translator_info["tokenizer"]
836
  model = translator_info["model"]
837
 
 
838
  inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=512)
839
- translated = model.generate(**inputs)
840
- result = tokenizer.decode(translated[0], skip_special_tokens=True)
 
 
 
 
 
841
 
842
  print(f"Original text: {text}")
843
- print(f"Translated text: {result}")
 
 
844
 
845
- return result
846
  except Exception as e:
847
  print(f"Translation error: {e}")
848
  return text
849
 
 
850
  @spaces.GPU
851
  @torch.no_grad()
852
  def generate_image(
@@ -854,17 +861,22 @@ def generate_image(
854
  do_img2img, init_image, image2image_strength, resize_img,
855
  progress=gr.Progress(track_tqdm=True),
856
  ):
 
857
  try:
858
  if source_lang != "English":
859
  translator_info = get_translator(source_lang)
860
- translated_prompt = translate_text(prompt, translator_info)
861
- print(f"Using translated prompt: {translated_prompt}")
 
 
 
 
862
  else:
863
  translated_prompt = prompt
864
  except Exception as e:
865
  print(f"Translation failed: {e}")
866
  translated_prompt = prompt
867
-
868
 
869
 
870
 
 
805
  if lang not in translators_cache:
806
  try:
807
  model_name = TRANSLATORS[lang]
 
808
 
809
+ # pipeline 사용 대신 직접 모델 로드
810
+ tokenizer = MarianTokenizer.from_pretrained(model_name)
811
+ model = MarianMTModel.from_pretrained(model_name)
 
 
812
 
813
+ # CPU에서 실행
814
+ model = model.to("cpu").eval()
815
 
816
  translators_cache[lang] = {
817
  "model": model,
818
  "tokenizer": tokenizer
819
  }
820
  print(f"Successfully loaded translator for {lang}")
821
+
822
  except Exception as e:
823
  print(f"Error loading translator for {lang}: {e}")
824
  return None
 
834
  tokenizer = translator_info["tokenizer"]
835
  model = translator_info["model"]
836
 
837
+ # 입력 텍스트 전처리
838
  inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=512)
839
+
840
+ # 번역 수행
841
+ with torch.no_grad():
842
+ outputs = model.generate(**inputs)
843
+
844
+ # 번역 결과 디코딩
845
+ translated = tokenizer.decode(outputs[0], skip_special_tokens=True)
846
 
847
  print(f"Original text: {text}")
848
+ print(f"Translated text: {translated}")
849
+
850
+ return translated
851
 
 
852
  except Exception as e:
853
  print(f"Translation error: {e}")
854
  return text
855
 
856
+
857
  @spaces.GPU
858
  @torch.no_grad()
859
  def generate_image(
 
861
  do_img2img, init_image, image2image_strength, resize_img,
862
  progress=gr.Progress(track_tqdm=True),
863
  ):
864
+ # 번역 처리
865
  try:
866
  if source_lang != "English":
867
  translator_info = get_translator(source_lang)
868
+ if translator_info is not None:
869
+ translated_prompt = translate_text(prompt, translator_info)
870
+ print(f"Using translated prompt: {translated_prompt}")
871
+ else:
872
+ print(f"No translator available for {source_lang}, using original prompt")
873
+ translated_prompt = prompt
874
  else:
875
  translated_prompt = prompt
876
  except Exception as e:
877
  print(f"Translation failed: {e}")
878
  translated_prompt = prompt
879
+
880
 
881
 
882