File size: 2,691 Bytes
476e166
 
a7f2f12
476e166
a7f2f12
 
 
476e166
a7f2f12
 
 
 
 
476e166
 
ceaa373
 
476e166
a7f2f12
476e166
 
ceaa373
a7f2f12
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ceaa373
a7f2f12
ceaa373
a7f2f12
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ceaa373
 
476e166
ceaa373
a7f2f12
 
 
 
 
 
 
ceaa373
 
 
 
476e166
a7f2f12
97e52ab
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
import functools
import os

import gradio as gr
import requests
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline


# Heading and explanatory blurb passed to gr.Interface(title=..., description=...).
title = "Community Tab Language Detection & Translation"
description = (
    "\n"
    "When comments are created in the community tab, detect the language of the content.\n"
    "Then, if the detected language is different from the user's language, display an option to translate it.\n"
)


# Remote models: the serverless Inference API (translation) and a dedicated
# Inference Endpoint (language identification).
TRANSLATION_API_URL = "https://api-inference.huggingface.co/models/t5-base"
LANG_ID_API_URL = "https://noe30ht5sav83xm1.us-east-1.aws.endpoints.huggingface.cloud"

# The Hugging Face token must come from the environment (e.g. a Space secret).
# Never commit a literal token: one was previously hard-coded here, overriding
# this lookup, and should be treated as leaked and revoked.
ACCESS_TOKEN = os.environ.get("ACCESS_TOKEN")

# Only send an Authorization header when a token is configured, instead of
# sending the literal string "Bearer None".
headers = {"Authorization": f"Bearer {ACCESS_TOKEN}"} if ACCESS_TOKEN else {}


# Single checkpoint shared by the model and its tokenizer.
_NLLB_CHECKPOINT = "facebook/nllb-200-distilled-600M"
model = AutoModelForSeq2SeqLM.from_pretrained(_NLLB_CHECKPOINT)
tokenizer = AutoTokenizer.from_pretrained(_NLLB_CHECKPOINT)

# Device index handed to transformers.pipeline(device=...): the first CUDA
# GPU when one is available, otherwise -1 (CPU).
if torch.cuda.is_available():
    device = 0
else:
    device = -1

# NLLB language codes; not referenced by the code visible in this file.
LANGS = ["ace_Arab", "eng_Latn", "fra_Latn", "spa_Latn"]


# UI language name -> NLLB language code, passed as src_lang / tgt_lang
# to the translation pipeline.
language_code_map = dict(
    English="eng_Latn",
    French="fra_Latn",
    German="deu_Latn",
    Spanish="spa_Latn",
    Korean="kor_Hang",
    Japanese="jpn_Jpan",
)


def translate_from_api(text, timeout=60):
    """Translate *text* with the model behind ``TRANSLATION_API_URL``.

    Args:
        text: Source text to translate.
        timeout: Seconds to wait for the HTTP response. ``wait_for_model``
            may make a cold request block while the model loads, so the
            default is generous; it exists so a stalled request cannot hang
            the caller forever.

    Returns:
        The ``translation_text`` field of the first result in the response.

    Raises:
        requests.HTTPError: if the API answers with a non-2xx status, rather
            than failing later with an opaque KeyError/TypeError while
            indexing an error payload.
    """
    payload = {
        "inputs": text,
        # The Inference API's documented request format nests these flags
        # under "options" rather than at the top level.
        "options": {"wait_for_model": True, "use_cache": True},
    }
    response = requests.post(
        TRANSLATION_API_URL, headers=headers, json=payload, timeout=timeout)
    response.raise_for_status()
    return response.json()[0]['translation_text']


@functools.lru_cache(maxsize=64)
def _get_translation_pipeline(src_lang_code, tgt_lang_code):
    """Build (once) and return the NLLB translation pipeline for one language pair."""
    # Keys are always values of language_code_map (translate looks them up
    # first), so the number of cached pipelines is bounded.
    return pipeline(
        "translation",
        model=model,
        tokenizer=tokenizer,
        src_lang=src_lang_code,
        tgt_lang=tgt_lang_code,
        device=device,
    )


def translate(text, src_lang, tgt_lang):
    """Translate *text* locally with the NLLB model.

    Args:
        text: Text to translate.
        src_lang: Source language name; must be a key of ``language_code_map``.
        tgt_lang: Target language name; must be a key of ``language_code_map``.

    Returns:
        The ``translation_text`` of the first pipeline result.

    Raises:
        KeyError: if either language name is not in ``language_code_map``.
    """
    src_lang_code = language_code_map[src_lang]
    tgt_lang_code = language_code_map[tgt_lang]
    print(f"src: {src_lang_code} tgt: {tgt_lang_code}")
    # Reuse a cached pipeline instead of constructing a new one (and
    # re-placing the model on the device) on every request.
    translation_pipeline = _get_translation_pipeline(src_lang_code, tgt_lang_code)
    result = translation_pipeline(text)
    return result[0]['translation_text']


def query(text, src_lang, tgt_lang):
    """Translate *text* and ask the language-ID endpoint to classify it.

    This is the Gradio callback; its three parameters receive the three UI
    inputs, so the signature is intentionally kept exact.

    Args:
        text: Comment text entered by the user.
        src_lang: Source language name (key of ``language_code_map``).
        tgt_lang: Target language name (key of ``language_code_map``).

    Returns:
        ``[lang_id, translation]``, where ``lang_id`` is the first element of
        the endpoint's JSON response (presumably the top classification
        result(s) — confirm against the endpoint) and ``translation`` is the
        translated string.

    Raises:
        KeyError: if a language name is unknown (raised by ``translate``).
        requests.HTTPError: if the language-ID endpoint returns a non-2xx status.
    """
    translation = translate(text, src_lang, tgt_lang)
    lang_id_response = requests.post(
        LANG_ID_API_URL,
        headers=headers,
        json={"inputs": text, "wait_for_model": True, "use_cache": True},
        # Bound the wait so a stalled endpoint cannot hang the UI request.
        timeout=60,
    )
    # Surface HTTP failures explicitly instead of indexing an error payload.
    lang_id_response.raise_for_status()
    lang_id = lang_id_response.json()[0]

    return [lang_id, translation]


# Build the UI: a comment box plus source/target language pickers feed
# `query` (in the order listed); its two return values fill the two output
# boxes. The interface is bound to `demo` before launching.
demo = gr.Interface(
    fn=query,
    inputs=[
        gr.Textbox(lines=2),
        gr.Radio(
            ["English", "French", "Korean"],
            value="English",
            label="Source Language",
        ),
        gr.Radio(
            ["Spanish", "German", "Japanese"],
            value="Spanish",
            label="Target Language",
        ),
    ],
    outputs=[
        gr.Textbox(lines=3, label="Detected Language"),
        gr.Textbox(lines=3, label="Translation"),
    ],
    title=title,
    description=description,
)
demo.launch()