import functools
import threading

import gradio as gr
import torch  # not referenced below; kept because other code may rely on it being imported here
from transformers import MBartForConditionalGeneration, MBart50TokenizerFast


class MBartTranslator:
    """Simple interface for translating text with an MBart-50 language model.

    By default the "facebook/mbart-large-50-many-to-many-mmt" checkpoint is
    loaded, but any compatible MBart-50 checkpoint can be used by passing its
    name to the constructor.

    Attributes:
        model (MBartForConditionalGeneration): The MBart language model.
        tokenizer (MBart50TokenizerFast): The MBart tokenizer.
        supported_languages (list[str]): Language codes known to the loaded
            tokenizer (e.g. "en_XX", "hi_IN"). These are the only codes
            accepted by :meth:`translate`.
    """

    def __init__(self, model_name="facebook/mbart-large-50-many-to-many-mmt", src_lang=None, tgt_lang=None):
        """Load the model and tokenizer.

        Args:
            model_name: Name or path of the checkpoint passed to
                ``from_pretrained`` for both the model and the tokenizer.
            src_lang: Optional default source language code for the tokenizer.
            tgt_lang: Optional default target language code for the tokenizer.
        """
        print("Building translator")
        print("Loading generator (this may take few minutes the first time as I need to download the model)")
        self.model = MBartForConditionalGeneration.from_pretrained(model_name)
        print("Loading tokenizer")
        self.tokenizer = MBart50TokenizerFast.from_pretrained(model_name, src_lang=src_lang, tgt_lang=tgt_lang)
        # Take the language codes from the tokenizer itself so that validation in
        # translate() can never accept a code that later fails the
        # lang_code_to_id lookup.
        self.supported_languages = list(self.tokenizer.lang_code_to_id)
        # translate() mutates tokenizer.src_lang; the lock keeps that mutation and
        # the tokenization that depends on it atomic when one instance is shared
        # between threads.
        self._tokenizer_lock = threading.Lock()
        print("Translator is ready")

    def translate(self, text: str, input_language: str, output_language: str) -> str:
        """Translate the given text from the input language to the output language.

        Args:
            text (str): The text to translate.
            input_language (str): The input language code (e.g. "hi_IN" for Hindi).
            output_language (str): The output language code (e.g. "en_XX" for English).

        Returns:
            str: The translated text, without special tokens. An empty string
            is returned when ``text`` is empty or only whitespace.

        Raises:
            ValueError: If either language code is not in
                ``self.supported_languages``.
        """
        if input_language not in self.supported_languages:
            raise ValueError(f"Input language not supported. Supported languages: {self.supported_languages}")
        if output_language not in self.supported_languages:
            raise ValueError(f"Output language not supported. Supported languages: {self.supported_languages}")

        # Nothing to translate: skip the model call instead of generating from
        # an input that consists only of special tokens.
        if not text or not text.strip():
            return ""

        with self._tokenizer_lock:
            self.tokenizer.src_lang = input_language
            encoded_input = self.tokenizer(text, return_tensors="pt")

        generated_tokens = self.model.generate(
            **encoded_input,
            # Forces the decoder to start with the target language token.
            forced_bos_token_id=self.tokenizer.lang_code_to_id[output_language],
        )
        # skip_special_tokens=True strips the language code and end-of-sequence
        # markers so that only the translated sentence is returned.
        translated_text = self.tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
        return translated_text[0]


@functools.lru_cache(maxsize=1)
def _get_translator():
    """Return the shared translator, loading the model on first use only."""
    return MBartTranslator()


def translate_text(source_lang, target_lang, text):
    """Gradio callback: translate ``text`` from ``source_lang`` to ``target_lang``.

    The translator is created lazily on the first call and reused afterwards,
    so the model is not reloaded for every request.
    """
    return _get_translator().translate(text, source_lang, target_lang)


# Language codes offered in the UI; each must be a code the MBart-50 tokenizer knows.
_LANGUAGE_CHOICES = ["en_XX", "es_XX", "fr_XX", "zh_CN", "hi_IN"]

translation_interface = gr.Interface(
    fn=translate_text,
    inputs=[
        gr.Dropdown(choices=_LANGUAGE_CHOICES, label="Source Language"),
        gr.Dropdown(choices=_LANGUAGE_CHOICES, label="Target Language"),
        gr.Textbox(label="Text to translate"),
    ],
    outputs=gr.Textbox(label="Translated text"),
)
translation_interface.launch()