sheonhan commited on
Commit
a7f2f12
1 Parent(s): c757bee

add multi-language translation model

Browse files
Files changed (2) hide show
  1. app.py +56 -15
  2. requirements.txt +2 -0
app.py CHANGED
@@ -1,39 +1,80 @@
1
  import requests
2
  import os
 
3
  import gradio as gr
 
 
 
4
 
5
- title = "Translate Text"
6
- description = """"""
7
- article = "Check out [the original repo](https://huggingface.co/language-tools/language-translation) that this demo is based off of."
 
 
8
 
9
 
10
  TRANSLATION_API_URL = "https://api-inference.huggingface.co/models/t5-base"
11
  LANG_ID_API_URL = "https://noe30ht5sav83xm1.us-east-1.aws.endpoints.huggingface.cloud"
12
  ACCESS_TOKEN = os.environ.get("ACCESS_TOKEN")
13
- # ACCESS_TOKEN = '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'
14
  headers = {"Authorization": f"Bearer {ACCESS_TOKEN}"}
15
 
16
 
17
- def query(payload):
18
- translation_response = requests.post(TRANSLATION_API_URL, headers=headers, json={
19
- "inputs": payload, "wait_for_model": True, "use_cache": True})
20
- translation = translation_response.json()[0]['translation_text']
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
 
22
- lang_id_response = requests.post(LANG_ID_API_URL, headers=headers, json={
23
- "inputs": payload, "wait_for_model": True, "use_cache": True})
24
- lang_id = lang_id_response.json()[0][0]
25
 
26
- return [lang_id, translation]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27
 
28
 
29
  gr.Interface(
30
  query,
31
- gr.Textbox(lines=2),
 
 
 
 
 
 
32
  outputs=[
33
  gr.Textbox(lines=3, label="Detected Language"),
34
  gr.Textbox(lines=3, label="Translation")
35
  ],
36
  title=title,
37
- description=description,
38
- article=article
39
  ).launch()
 
1
  import requests
2
  import os
3
+
4
  import gradio as gr
5
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
6
+ import torch
7
+
8
 
9
+ title = "Community Tab Language Detection & Translation"
10
+ description = """
11
+ When comments are created in the community tab, detect the language of the content.
12
+ Then, if the detected language is different from the user's language, display an option to translate it.
13
+ """
14
 
15
 
16
  TRANSLATION_API_URL = "https://api-inference.huggingface.co/models/t5-base"
17
  LANG_ID_API_URL = "https://noe30ht5sav83xm1.us-east-1.aws.endpoints.huggingface.cloud"
18
  ACCESS_TOKEN = os.environ.get("ACCESS_TOKEN")
19
+ ACCESS_TOKEN = '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'
20
  headers = {"Authorization": f"Bearer {ACCESS_TOKEN}"}
21
 
22
 
23
+ model = AutoModelForSeq2SeqLM.from_pretrained("facebook/nllb-200-distilled-600M")
24
+ tokenizer = AutoTokenizer.from_pretrained("facebook/nllb-200-distilled-600M")
25
+ device = 0 if torch.cuda.is_available() else -1
26
+ LANGS = ["ace_Arab", "eng_Latn", "fra_Latn", "spa_Latn"]
27
+
28
+
29
+ language_code_map = {
30
+ "English": "eng_Latn",
31
+ "French": "fra_Latn",
32
+ "German": "deu_Latn",
33
+ "Spanish": "spa_Latn",
34
+ "Korean": "kor_Hang",
35
+ "Japanese": "jpn_Jpan"
36
+ }
37
+
38
+
39
+ def translate_from_api(text):
40
+ response = requests.post(TRANSLATION_API_URL, headers=headers, json={
41
+ "inputs": text, "wait_for_model": True, "use_cache": True})
42
 
43
+ return response.json()[0]['translation_text']
 
 
44
 
45
+
46
+ def translate(text, src_lang, tgt_lang):
47
+ src_lang_code = language_code_map[src_lang]
48
+ tgt_lang_code = language_code_map[tgt_lang]
49
+ print(f"src: {src_lang_code} tgt: {tgt_lang_code}")
50
+ translation_pipeline = pipeline(
51
+ "translation", model=model, tokenizer=tokenizer, src_lang=src_lang_code, tgt_lang=tgt_lang_code, device=device)
52
+ result = translation_pipeline(text)
53
+ return result[0]['translation_text']
54
+
55
+
56
+ def query(text, src_lang, tgt_lang):
57
+ translation = translate(text, src_lang, tgt_lang)
58
+ lang_id_response = requests.post(LANG_ID_API_URL, headers=headers, json={
59
+ "inputs": text, "wait_for_model": True, "use_cache": True})
60
+ lang_id = lang_id_response.json()[0]
61
+
62
+ return [lang_id, translation]
63
 
64
 
65
  gr.Interface(
66
  query,
67
+ [
68
+ gr.Textbox(lines=2),
69
+ gr.Radio(["English", "French", "Korean"], value="English", label="Source Language"),
70
+ gr.Radio(["Spanish", "German", "Japanese"], value="Spanish", label="Target Language")
71
+ # gr.Radio(["English", "French", "Korean"]),
72
+ # gr.Radio(["Spanish", "German", "French"]),
73
+ ],
74
  outputs=[
75
  gr.Textbox(lines=3, label="Detected Language"),
76
  gr.Textbox(lines=3, label="Translation")
77
  ],
78
  title=title,
79
+ description=description
 
80
  ).launch()
requirements.txt ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ torch
2
+ transformers