sheonhan commited on
Commit
8b8b295
·
1 Parent(s): a8a63e9

use run inference from Space

Browse files
Files changed (4) hide show
  1. app.py +16 -12
  2. lid.176.ftz +3 -0
  3. lid218e.bin +3 -0
  4. requirements.txt +1 -0
app.py CHANGED
@@ -1,6 +1,7 @@
1
  import requests
2
  import os
3
 
 
4
  import gradio as gr
5
  from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
6
  import torch
@@ -11,11 +12,8 @@ When comments are created in the community tab, detect the language of the conte
11
  Then, if the detected language is different from the user's language, display an option to translate it.
12
  """
13
 
14
-
15
- TRANSLATION_API_URL = "https://api-inference.huggingface.co/models/t5-base"
16
- LANG_ID_API_URL = "https://noe30ht5sav83xm1.us-east-1.aws.endpoints.huggingface.cloud"
17
  ACCESS_TOKEN = os.environ.get("ACCESS_TOKEN")
18
- ACCESS_TOKEN = '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'
19
  headers = {"Authorization": f"Bearer {ACCESS_TOKEN}"}
20
 
21
  model = AutoModelForSeq2SeqLM.from_pretrained("facebook/nllb-200-distilled-600M")
@@ -32,18 +30,22 @@ language_code_map = {
32
  "Japanese": "jpn_Jpan"
33
  }
34
 
35
-
36
- def translate_from_api(text):
37
- response = requests.post(TRANSLATION_API_URL, headers=headers, json={
38
- "inputs": text, "wait_for_model": True, "use_cache": True})
39
-
40
- return response.json()[0]['translation_text']
 
 
 
 
41
 
42
 
43
  def translate(text, src_lang, tgt_lang):
44
  src_lang_code = language_code_map[src_lang]
45
  tgt_lang_code = language_code_map[tgt_lang]
46
- print(f"src: {src_lang_code} tgt: {tgt_lang_code}")
47
  translation_pipeline = pipeline(
48
  "translation", model=model, tokenizer=tokenizer, src_lang=src_lang_code, tgt_lang=tgt_lang_code, device=device)
49
  result = translation_pipeline(text)
@@ -55,8 +57,10 @@ def query(text, src_lang, tgt_lang):
55
  lang_id_response = requests.post(LANG_ID_API_URL, headers=headers, json={
56
  "inputs": text, "wait_for_model": True, "use_cache": True})
57
  lang_id = lang_id_response.json()[0]
 
 
58
 
59
- return [lang_id, translation]
60
 
61
 
62
  examples = [
 
1
  import requests
2
  import os
3
 
4
+ import fasttext
5
  import gradio as gr
6
  from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
7
  import torch
 
12
  Then, if the detected language is different from the user's language, display an option to translate it.
13
  """
14
 
15
+ LANG_ID_API_URL = "https://q5esh83u7boq5qwd.us-east-1.aws.endpoints.huggingface.cloud"
 
 
16
  ACCESS_TOKEN = os.environ.get("ACCESS_TOKEN")
 
17
  headers = {"Authorization": f"Bearer {ACCESS_TOKEN}"}
18
 
19
  model = AutoModelForSeq2SeqLM.from_pretrained("facebook/nllb-200-distilled-600M")
 
30
  "Japanese": "jpn_Jpan"
31
  }
32
 
33
+ def identify_language(text):
34
+ model_file = "lid218e.bin"
35
+ model_full_path = os.path.join(os.path.dirname(__file__), model_file)
36
+ model = fasttext.load_model(model_full_path)
37
+ predictions = model.predict(text, k=1) # e.g., (('__label__eng_Latn',), array([0.81148803]))
38
+
39
+ PREFIX_LENGTH = 7 # To strip away '__label__' from language code
40
+ language_code = predictions[0][0][PREFIX_LENGTH:]
41
+ return language_code
42
+
43
 
44
 
45
  def translate(text, src_lang, tgt_lang):
46
  src_lang_code = language_code_map[src_lang]
47
  tgt_lang_code = language_code_map[tgt_lang]
48
+
49
  translation_pipeline = pipeline(
50
  "translation", model=model, tokenizer=tokenizer, src_lang=src_lang_code, tgt_lang=tgt_lang_code, device=device)
51
  result = translation_pipeline(text)
 
57
  lang_id_response = requests.post(LANG_ID_API_URL, headers=headers, json={
58
  "inputs": text, "wait_for_model": True, "use_cache": True})
59
  lang_id = lang_id_response.json()[0]
60
+
61
+ language_code = identify_language(text)
62
 
63
+ return [language_code, translation]
64
 
65
 
66
  examples = [
lid.176.ftz ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8f3472cfe8738a7b6099e8e999c3cbfae0dcd15696aac7d7738a8039db603e83
3
+ size 938013
lid218e.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8ded5749a2ad79ae9ab7c9190c7c8b97ff20d54ad8b9527ffa50107238fc7f6a
3
+ size 1176355829
requirements.txt CHANGED
@@ -1,2 +1,3 @@
 
1
  torch
2
  transformers
 
1
+ fasttext
2
  torch
3
  transformers