# Hugging Face Space status at time of capture: Runtime error
import functools
import os

import gradio as gr
import requests
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
title = "Community Tab Language Detection & Translation"
description = """
When comments are created in the community tab, detect the language of the content.
Then, if the detected language is different from the user's language, display an option to translate it.
"""

# Hosted endpoints: a serverless T5 translation model (used by
# translate_from_api) and a dedicated language-identification endpoint
# (used by query).
TRANSLATION_API_URL = "https://api-inference.huggingface.co/models/t5-base"
LANG_ID_API_URL = "https://noe30ht5sav83xm1.us-east-1.aws.endpoints.huggingface.cloud"

# The access token must come from the environment (e.g. a Space secret).
# It was previously hard-coded here, overriding the environment value and
# exposing the secret to anyone who can read this file; that token should
# be revoked.
ACCESS_TOKEN = os.environ.get("ACCESS_TOKEN")
# Only send an Authorization header when a token is configured; otherwise
# the literal string "Bearer None" would be sent.
headers = {"Authorization": f"Bearer {ACCESS_TOKEN}"} if ACCESS_TOKEN else {}

# Local NLLB-200 model and tokenizer used by translate().
MODEL_NAME = "facebook/nllb-200-distilled-600M"
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
# transformers pipeline device convention: 0 = first CUDA GPU, -1 = CPU.
device = 0 if torch.cuda.is_available() else -1

# NLLB language codes. LANGS is not referenced in the visible code.
LANGS = ["ace_Arab", "eng_Latn", "fra_Latn", "spa_Latn"]
# Maps the language names shown in the UI to NLLB language codes.
language_code_map = {
    "English": "eng_Latn",
    "French": "fra_Latn",
    "German": "deu_Latn",
    "Spanish": "spa_Latn",
    "Korean": "kor_Hang",
    "Japanese": "jpn_Jpan",
}
def translate_from_api(text):
    """Translate *text* with the hosted ``t5-base`` Inference API model.

    Args:
        text: Input string sent as the request's ``inputs``.

    Returns:
        The ``translation_text`` field of the first result.

    Raises:
        requests.HTTPError: If the API answers with an error status (bad
            token, model failure, ...). Without this check, indexing the
            error payload would fail with a confusing KeyError/TypeError.
        requests.Timeout: If no response arrives within the timeout.
    """
    response = requests.post(
        TRANSLATION_API_URL,
        headers=headers,
        # The Inference API reads these flags from a nested "options" object.
        json={
            "inputs": text,
            "options": {"wait_for_model": True, "use_cache": True},
        },
        # wait_for_model can make a cold call slow, but never block forever.
        timeout=120,
    )
    response.raise_for_status()
    return response.json()[0]["translation_text"]
@functools.lru_cache(maxsize=16)
def _get_translation_pipeline(src_lang_code, tgt_lang_code):
    """Return a translation pipeline for one NLLB language pair, built once and reused."""
    return pipeline(
        "translation",
        model=model,
        tokenizer=tokenizer,
        src_lang=src_lang_code,
        tgt_lang=tgt_lang_code,
        device=device,
    )


def translate(text, src_lang, tgt_lang):
    """Translate *text* locally with the NLLB-200 model.

    Args:
        text: Text to translate.
        src_lang: Source language display name (a key of ``language_code_map``).
        tgt_lang: Target language display name (a key of ``language_code_map``).

    Returns:
        The translated string, or ``""`` when *text* is empty or whitespace.

    Raises:
        KeyError: If either language name is not in ``language_code_map``.
    """
    src_lang_code = language_code_map[src_lang]
    tgt_lang_code = language_code_map[tgt_lang]
    print(f"src: {src_lang_code} tgt: {tgt_lang_code}")
    # Don't run the model on empty input (the Gradio textbox defaults to "").
    if not text or not text.strip():
        return ""
    # Reuse a cached pipeline instead of rebuilding one on every request.
    translation_pipeline = _get_translation_pipeline(src_lang_code, tgt_lang_code)
    result = translation_pipeline(text)
    return result[0]["translation_text"]
def query(text, src_lang, tgt_lang):
    """Gradio handler: translate *text* and identify its language.

    Args:
        text: User-entered text.
        src_lang: Source language display name.
        tgt_lang: Target language display name.

    Returns:
        A two-item list ``[lang_id, translation]`` for the two output
        textboxes. ``lang_id`` is the first element of the language-ID
        endpoint's JSON response. Its exact structure is not visible here;
        presumably a label/score record (TODO confirm).

    Raises:
        requests.HTTPError: If the language-ID endpoint returns an error status.
        requests.Timeout: If the endpoint does not respond in time.
    """
    translation = translate(text, src_lang, tgt_lang)
    lang_id_response = requests.post(
        LANG_ID_API_URL,
        headers=headers,
        json={"inputs": text, "wait_for_model": True, "use_cache": True},
        # Never let a stuck endpoint hang the request handler.
        timeout=120,
    )
    # Surface HTTP failures clearly instead of indexing an error payload.
    lang_id_response.raise_for_status()
    lang_id = lang_id_response.json()[0]
    return [lang_id, translation]
# Build the Gradio UI: one text input plus source/target language pickers,
# wired to query(), which fills the two output textboxes.
source_language_picker = gr.Radio(
    ["English", "French", "Korean"], value="English", label="Source Language"
)
target_language_picker = gr.Radio(
    ["Spanish", "German", "Japanese"], value="Spanish", label="Target Language"
)

demo = gr.Interface(
    query,
    [gr.Textbox(lines=2), source_language_picker, target_language_picker],
    outputs=[
        gr.Textbox(lines=3, label="Detected Language"),
        gr.Textbox(lines=3, label="Translation"),
    ],
    title=title,
    description=description,
)
demo.launch()