# Omnibus's picture
# Update app.py
# commit 82b38e9
# (Hugging Face commit-page header accidentally captured with the file; kept as a comment so the script parses.)
"""Gradio app: translate text between 100 languages with Facebook's M2M100 (1.2B)."""
import os  # NOTE(review): unused in this file — kept in case the Space config reads it; confirm before removing
import torch
import gradio as gr  # was imported twice in the original; once is enough
from transformers import M2M100Tokenizer, M2M100ForConditionalGeneration

# Prefer the first CUDA device when available, otherwise fall back to CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Load the M2M100 1.2B checkpoint once at startup (downloads on first run).
tokenizer = M2M100Tokenizer.from_pretrained("facebook/m2m100_1.2B")
model = M2M100ForConditionalGeneration.from_pretrained("facebook/m2m100_1.2B").to(device)
model.eval()  # inference only: disable dropout / training behavior
class Language:
    """A translatable language: a human-readable display name plus its M2M100 code.

    Attributes
    ----------
    name : str
        Display name shown in the UI dropdowns, e.g. ``"English"``.
    code : str
        ISO-style language code the tokenizer/model expect, e.g. ``"en"``.
    """

    def __init__(self, name, code):
        self.name = name
        self.code = code

    def __repr__(self):
        # Helpful when inspecting lang_id entries in a debugger / REPL.
        return f"{type(self).__name__}({self.name!r}, {self.code!r})"
# (display name, M2M100 language code) for every language the model supports.
# NOTE(review): the three Serbian entries deliberately share the code "sr".
_LANGUAGE_TABLE = [
    ("Afrikaans", "af"), ("Albanian", "sq"), ("Amharic", "am"),
    ("Arabic", "ar"), ("Armenian", "hy"), ("Asturian", "ast"),
    ("Azerbaijani", "az"), ("Bashkir", "ba"), ("Belarusian", "be"),
    ("Bulgarian", "bg"), ("Bengali", "bn"), ("Breton", "br"),
    ("Bosnian", "bs"), ("Burmese", "my"), ("Catalan", "ca"),
    ("Cebuano", "ceb"), ("Chinese", "zh"), ("Croatian", "hr"),
    ("Czech", "cs"), ("Danish", "da"), ("Dutch", "nl"),
    ("English", "en"), ("Estonian", "et"), ("Fulah", "ff"),
    ("Finnish", "fi"), ("French", "fr"), ("Western Frisian", "fy"),
    ("Gaelic", "gd"), ("Galician", "gl"), ("Georgian", "ka"),
    ("German", "de"), ("Greek", "el"), ("Gujarati", "gu"),
    ("Hausa", "ha"), ("Hebrew", "he"), ("Hindi", "hi"),
    ("Haitian", "ht"), ("Hungarian", "hu"), ("Irish", "ga"),
    ("Indonesian", "id"), ("Igbo", "ig"), ("Iloko", "ilo"),
    ("Icelandic", "is"), ("Italian", "it"), ("Japanese", "ja"),
    ("Javanese", "jv"), ("Kazakh", "kk"), ("Central Khmer", "km"),
    ("Kannada", "kn"), ("Korean", "ko"), ("Luxembourgish", "lb"),
    ("Ganda", "lg"), ("Lingala", "ln"), ("Lao", "lo"),
    ("Lithuanian", "lt"), ("Latvian", "lv"), ("Malagasy", "mg"),
    ("Macedonian", "mk"), ("Malayalam", "ml"), ("Mongolian", "mn"),
    ("Marathi", "mr"), ("Malay", "ms"), ("Nepali", "ne"),
    ("Norwegian", "no"), ("Northern Sotho", "ns"), ("Occitan", "oc"),
    ("Oriya", "or"), ("Panjabi", "pa"), ("Persian", "fa"),
    ("Polish", "pl"), ("Pushto", "ps"), ("Portuguese", "pt"),
    ("Romanian", "ro"), ("Russian", "ru"), ("Sindhi", "sd"),
    ("Sinhala", "si"), ("Slovak", "sk"), ("Slovenian", "sl"),
    ("Spanish", "es"), ("Somali", "so"), ("Serbian", "sr"),
    ("Serbian (cyrillic)", "sr"), ("Serbian (latin)", "sr"), ("Swati", "ss"),
    ("Sundanese", "su"), ("Swedish", "sv"), ("Swahili", "sw"),
    ("Tamil", "ta"), ("Thai", "th"), ("Tagalog", "tl"),
    ("Tswana", "tn"), ("Turkish", "tr"), ("Ukrainian", "uk"),
    ("Urdu", "ur"), ("Uzbek", "uz"), ("Vietnamese", "vi"),
    ("Welsh", "cy"), ("Wolof", "wo"), ("Xhosa", "xh"),
    ("Yiddish", "yi"), ("Yoruba", "yo"), ("Zulu", "zu"),
]

# Same ordered list of Language objects the rest of the app iterates over.
lang_id = [Language(display_name, iso_code) for display_name, iso_code in _LANGUAGE_TABLE]
# Default source language for page translation (the UI text is authored in English).
# Looked up by name rather than the fragile magic index lang_id[21], which silently
# breaks if the list order ever changes.
d_lang = next(lang for lang in lang_id if lang.name == "English")
def trans_page(input, trg):
    """Translate the page text from the default language into *trg*.

    Parameters
    ----------
    input : str
        Text to translate (the page's Markdown, assumed to be in ``d_lang``).
    trg : str
        Display name of the target language; must match a ``lang_id`` entry.

    Returns
    -------
    str
        The translated text, or *input* unchanged when the target equals the
        source or *trg* is not a known language name.
    """
    src_lang = d_lang.code
    # Resolve the display name to its code. Fall back to the source code so an
    # unknown name becomes a no-op — the original left trg_lang unbound and
    # raised NameError in that case.
    trg_lang = next((lang.code for lang in lang_id if lang.name == trg), src_lang)
    if trg_lang == src_lang:
        return input
    tokenizer.src_lang = src_lang
    with torch.no_grad():
        encoded_input = tokenizer(input, return_tensors="pt").to(device)
        generated_tokens = model.generate(
            **encoded_input, forced_bos_token_id=tokenizer.get_lang_id(trg_lang)
        )
    return tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0]
def trans_to(input, src, trg):
    """Translate *input* from language *src* to language *trg*.

    Parameters
    ----------
    input : str
        Text to translate.
    src, trg : str
        Display names of the source and target languages; each must match a
        ``lang_id`` entry.

    Returns
    -------
    str
        The translated text, or *input* unchanged when either name is unknown
        or the two languages are the same. (The original raised NameError when
        a name matched no entry.)
    """
    trg_lang = next((lang.code for lang in lang_id if lang.name == trg), None)
    src_lang = next((lang.code for lang in lang_id if lang.name == src), None)
    if src_lang is None or trg_lang is None or trg_lang == src_lang:
        return input
    tokenizer.src_lang = src_lang
    with torch.no_grad():
        encoded_input = tokenizer(input, return_tensors="pt").to(device)
        generated_tokens = model.generate(
            **encoded_input, forced_bos_token_id=tokenizer.get_lang_id(trg_lang)
        )
    return tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0]
# Page title string (currently unused; the heading below is built inline).
md1 = "Translate - 100 Languages"

# Build the UI. Layout nesting matters in gr.Blocks: components are placed in
# the context (Row/Column) that is active when they are created.
with gr.Blocks() as transbot:
    #this=gr.State()
    with gr.Row():
        gr.Column()  # empty spacer column (left gutter)
        with gr.Column():
            with gr.Row():
                # Target language for translating the page heading itself.
                t_space = gr.Dropdown(label="Translate Space to:", choices=[l.name for l in lang_id], value="English")
                #t_space = gr.Dropdown(label="Translate Space", choices=list(lang_id.keys()),value="English")
                t_submit = gr.Button("Translate Space")
        gr.Column()  # empty spacer column (right gutter)
    with gr.Row():
        gr.Column()  # left gutter
        with gr.Column():
            md = gr.Markdown("""<h1><center>Translate - 100 Languages</center></h1><h4><center>Translation may not be accurate</center></h4>""")
            with gr.Row():
                # Source / target languages for the free-text translator below.
                lang_from = gr.Dropdown(label="From:", choices=[l.name for l in lang_id],value="English")
                lang_to = gr.Dropdown(label="To:", choices=[l.name for l in lang_id],value="Chinese")
                #lang_from = gr.Dropdown(label="From:", choices=list(lang_id.keys()),value="English")
                #lang_to = gr.Dropdown(label="To:", choices=list(lang_id.keys()),value="Chinese")
            submit = gr.Button("Go")
            with gr.Row():
                with gr.Column():
                    message = gr.Textbox(label="Prompt",placeholder="Enter Prompt",lines=4)
                    translated = gr.Textbox(label="Translated",lines=4,interactive=False)
        gr.Column()  # right gutter
    # Wire events: translating the "space" feeds the heading Markdown back into
    # itself; the Go button runs the free-text translator.
    t_submit.click(trans_page,[md,t_space],[md])
    submit.click(trans_to, inputs=[message,lang_from,lang_to], outputs=[translated])

# Queue requests so concurrent users share the single loaded model, then serve.
transbot.queue(concurrency_count=20)
transbot.launch()