Spaces:
Runtime error
Runtime error
import os | |
import torch | |
import gradio as gr | |
import time | |
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline | |
# Language table: one entry per line, "<display name>\t<language code>".
# The display name is what the UI dropdowns show; the code is what gets
# passed to the translation pipeline as src_lang / tgt_lang.
# The separator is written as an explicit \t escape (this is a normal,
# non-raw string, so it becomes a real tab at runtime) because the parser
# splits each line on '\t' and literal tab characters are easily lost
# when the file is copied or reformatted.
codes_as_string = '''Acehnese (Arabic script)\tace_Arab
Acehnese (Latin script)\tace_Latn
Mesopotamian Arabic\tacm_Arab
Ta’izzi-Adeni Arabic\tacq_Arab
Tunisian Arabic\taeb_Arab
Afrikaans\tafr_Latn
South Levantine Arabic\tajp_Arab
Akan\taka_Latn
Amharic\tamh_Ethi
North Levantine Arabic\tapc_Arab
Modern Standard Arabic\tarb_Arab
Modern Standard Arabic (Romanized)\tarb_Latn
Najdi Arabic\tars_Arab
Moroccan Arabic\tary_Arab
Egyptian Arabic\tarz_Arab
Assamese\tasm_Beng
Asturian\tast_Latn
Awadhi\tawa_Deva
Central Aymara\tayr_Latn
South Azerbaijani\tazb_Arab
North Azerbaijani\tazj_Latn
Bashkir\tbak_Cyrl
Bambara\tbam_Latn
Balinese\tban_Latn
Belarusian\tbel_Cyrl
Bemba\tbem_Latn
Bengali\tben_Beng
Bhojpuri\tbho_Deva
Banjar (Arabic script)\tbjn_Arab
Banjar (Latin script)\tbjn_Latn
Standard Tibetan\tbod_Tibt
Bosnian\tbos_Latn
Buginese\tbug_Latn
Bulgarian\tbul_Cyrl
Catalan\tcat_Latn
Cebuano\tceb_Latn
Czech\tces_Latn
Chokwe\tcjk_Latn
Central Kurdish\tckb_Arab
Crimean Tatar\tcrh_Latn
Welsh\tcym_Latn
Danish\tdan_Latn
German\tdeu_Latn
Southwestern Dinka\tdik_Latn
Dyula\tdyu_Latn
Dzongkha\tdzo_Tibt
Greek\tell_Grek
English\teng_Latn
Esperanto\tepo_Latn
Estonian\test_Latn
Basque\teus_Latn
Ewe\tewe_Latn
Faroese\tfao_Latn
Fijian\tfij_Latn
Finnish\tfin_Latn
Fon\tfon_Latn
French\tfra_Latn
Friulian\tfur_Latn
Nigerian Fulfulde\tfuv_Latn
Scottish Gaelic\tgla_Latn
Irish\tgle_Latn
Galician\tglg_Latn
Guarani\tgrn_Latn
Gujarati\tguj_Gujr
Haitian Creole\that_Latn
Hausa\thau_Latn
Hebrew\theb_Hebr
Hindi\thin_Deva
Chhattisgarhi\thne_Deva
Croatian\thrv_Latn
Hungarian\thun_Latn
Armenian\thye_Armn
Igbo\tibo_Latn
Ilocano\tilo_Latn
Indonesian\tind_Latn
Icelandic\tisl_Latn
Italian\tita_Latn
Javanese\tjav_Latn
Japanese\tjpn_Jpan
Kabyle\tkab_Latn
Jingpho\tkac_Latn
Kamba\tkam_Latn
Kannada\tkan_Knda
Kashmiri (Arabic script)\tkas_Arab
Kashmiri (Devanagari script)\tkas_Deva
Georgian\tkat_Geor
Central Kanuri (Arabic script)\tknc_Arab
Central Kanuri (Latin script)\tknc_Latn
Kazakh\tkaz_Cyrl
Kabiyè\tkbp_Latn
Kabuverdianu\tkea_Latn
Khmer\tkhm_Khmr
Kikuyu\tkik_Latn
Kinyarwanda\tkin_Latn
Kyrgyz\tkir_Cyrl
Kimbundu\tkmb_Latn
Northern Kurdish\tkmr_Latn
Kikongo\tkon_Latn
Korean\tkor_Hang
Lao\tlao_Laoo
Ligurian\tlij_Latn
Limburgish\tlim_Latn
Lingala\tlin_Latn
Lithuanian\tlit_Latn
Lombard\tlmo_Latn
Latgalian\tltg_Latn
Luxembourgish\tltz_Latn
Luba-Kasai\tlua_Latn
Ganda\tlug_Latn
Luo\tluo_Latn
Mizo\tlus_Latn
Standard Latvian\tlvs_Latn
Magahi\tmag_Deva
Maithili\tmai_Deva
Malayalam\tmal_Mlym
Marathi\tmar_Deva
Minangkabau (Arabic script)\tmin_Arab
Minangkabau (Latin script)\tmin_Latn
Macedonian\tmkd_Cyrl
Plateau Malagasy\tplt_Latn
Maltese\tmlt_Latn
Meitei (Bengali script)\tmni_Beng
Halh Mongolian\tkhk_Cyrl
Mossi\tmos_Latn
Maori\tmri_Latn
Burmese\tmya_Mymr
Dutch\tnld_Latn
Norwegian Nynorsk\tnno_Latn
Norwegian Bokmål\tnob_Latn
Nepali\tnpi_Deva
Northern Sotho\tnso_Latn
Nuer\tnus_Latn
Nyanja\tnya_Latn
Occitan\toci_Latn
West Central Oromo\tgaz_Latn
Odia\tory_Orya
Pangasinan\tpag_Latn
Eastern Panjabi\tpan_Guru
Papiamento\tpap_Latn
Western Persian\tpes_Arab
Polish\tpol_Latn
Portuguese\tpor_Latn
Dari\tprs_Arab
Southern Pashto\tpbt_Arab
Ayacucho Quechua\tquy_Latn
Romanian\tron_Latn
Rundi\trun_Latn
Russian\trus_Cyrl
Sango\tsag_Latn
Sanskrit\tsan_Deva
Santali\tsat_Olck
Sicilian\tscn_Latn
Shan\tshn_Mymr
Sinhala\tsin_Sinh
Slovak\tslk_Latn
Slovenian\tslv_Latn
Samoan\tsmo_Latn
Shona\tsna_Latn
Sindhi\tsnd_Arab
Somali\tsom_Latn
Southern Sotho\tsot_Latn
Spanish\tspa_Latn
Tosk Albanian\tals_Latn
Sardinian\tsrd_Latn
Serbian\tsrp_Cyrl
Swati\tssw_Latn
Sundanese\tsun_Latn
Swedish\tswe_Latn
Swahili\tswh_Latn
Silesian\tszl_Latn
Tamil\ttam_Taml
Tatar\ttat_Cyrl
Telugu\ttel_Telu
Tajik\ttgk_Cyrl
Tagalog\ttgl_Latn
Thai\ttha_Thai
Tigrinya\ttir_Ethi
Tamasheq (Latin script)\ttaq_Latn
Tamasheq (Tifinagh script)\ttaq_Tfng
Tok Pisin\ttpi_Latn
Tswana\ttsn_Latn
Tsonga\ttso_Latn
Turkmen\ttuk_Latn
Tumbuka\ttum_Latn
Turkish\ttur_Latn
Twi\ttwi_Latn
Central Atlas Tamazight\ttzm_Tfng
Uyghur\tuig_Arab
Ukrainian\tukr_Cyrl
Umbundu\tumb_Latn
Urdu\turd_Arab
Northern Uzbek\tuzn_Latn
Venetian\tvec_Latn
Vietnamese\tvie_Latn
Waray\twar_Latn
Wolof\twol_Latn
Xhosa\txho_Latn
Eastern Yiddish\tydd_Hebr
Yoruba\tyor_Latn
Yue Chinese\tyue_Hant
Chinese (Simplified)\tzho_Hans
Chinese (Traditional)\tzho_Hant
Standard Malay\tzsm_Latn
Zulu\tzul_Latn'''
def load_models():
    """Load each configured translation model together with a tokenizer.

    Returns:
        dict: two entries per configured call name ``<name>``:
        ``'<name>_model'`` holds the seq2seq model loaded from that name's
        checkpoint and ``'<name>_tokenizer'`` holds the tokenizer.
    """
    # Call name used inside this app -> checkpoint identifier to load.
    checkpoints = {
        'nllb-1.3B': "ychenNLP/nllb-200-distilled-1.3B-easyproject",
    }

    loaded = {}
    for short_name, checkpoint in checkpoints.items():
        print('\tLoading model: %s' % short_name)
        # The tokenizer is always taken from the fixed
        # facebook/nllb-200-distilled-600M checkpoint, independent of which
        # model checkpoint is being loaded. The model is loaded first.
        loaded.update({
            short_name + '_model': AutoModelForSeq2SeqLM.from_pretrained(checkpoint),
            short_name + '_tokenizer': AutoTokenizer.from_pretrained("facebook/nllb-200-distilled-600M"),
        })
    return loaded
def translation(source, target, text):
    """Translate ``text`` between two languages selected by display name.

    Reads the module-level ``flores_codes`` (display name -> language code)
    and ``model_dict`` (as returned by ``load_models``).

    Args:
        source: Display name of the input language; must be a key of
            ``flores_codes``.
        target: Display name of the output language; must be a key of
            ``flores_codes``.
        text: The text to translate.

    Returns:
        The ``'translation_text'`` field of the first pipeline result.

    Raises:
        KeyError: If ``source`` or ``target`` is not in ``flores_codes``, or
            the expected model/tokenizer entries are missing from
            ``model_dict``.
    """
    # Previously this name was bound only when model_dict held exactly two
    # entries, so any other registry size failed with UnboundLocalError
    # before a translation was attempted. Only one model is configured, so
    # bind it unconditionally.
    model_name = 'nllb-1.3B'

    src_code = flores_codes[source]
    tgt_code = flores_codes[target]

    translator = pipeline(
        'translation',
        model=model_dict[model_name + '_model'],
        tokenizer=model_dict[model_name + '_tokenizer'],
        src_lang=src_code,
        tgt_lang=tgt_code,
    )
    output = translator(text, max_length=400)
    return output[0]['translation_text']
if __name__ == '__main__':
    # Script entry point: build the language lookup, load the model, and
    # launch the Gradio demo. flores_codes and model_dict are assigned at
    # module level here because translation() reads them as globals.
    print('\tinit models')

    # Parse "<display name><whitespace><code>" lines. Splitting on the LAST
    # whitespace run keeps multi-word names intact and works whether the
    # separator is a tab or spaces; blank lines are skipped instead of
    # raising on unpacking.
    flores_codes = {}
    for entry in codes_as_string.split('\n'):
        entry = entry.strip()
        if not entry:
            continue
        lang, lang_code = entry.rsplit(None, 1)
        flores_codes[lang] = lang_code

    model_dict = load_models()

    # define gradio demo
    lang_codes = list(flores_codes)
    # Use the top-level component classes with `value=` for the initial
    # selection; the legacy `gr.inputs.*` namespace (with `default=`) is
    # deprecated and absent from newer Gradio releases.
    inputs = [
        gr.Dropdown(lang_codes, value='English', label='Source'),
        gr.Dropdown(lang_codes, value='Chinese (Simplified)', label='Target'),
        gr.Textbox(lines=5, label="Input text"),
    ]
    outputs = gr.Textbox(label="Output text")

    title = "EasyProject: a simple end-to-end label projection method for cross-lingual transfer"
    demo_status = "Check out our paper: Frustratingly Easy Label Projection for Cross-lingual Transfer (https://arxiv.org/abs/2211.15613). Powered by (ychenNLP/nllb-200-distilled-1.3B-easyproject)."
    description = demo_status

    # Each example row is [source language, target language, input text].
    examples = [
        ['English', 'Akan', "Davies is leaving to become chairman of the [London School of Economics]"],
        ['English', 'Telugu', 'Only [1] France [/1] and [2] Britain [/2] backed Fischer ’s proposal.'],
        ['English', 'Chinese (Simplified)', "[1] Gordon Brown [/1] on Tuesday named the current [2] head [/2] of the country 's energy regulator as the new chairman."],
        ['English', 'Afrikaans', "[1] Former [/1] senior banker [2] Callum McCarthy [/2] begins what is one of the most important jobs in London"],
        ['English', 'Kyrgyz', 'i would like to find flights from [0] columbus [/0] to [1] minneapolis [/1] on [2] monday june fourteenth [/2]'],
    ]

    gr.Interface(
        translation,
        inputs,
        outputs,
        title=title,
        description=description,
        examples=examples,
        examples_per_page=50,
        theme="JohnSmith9982/small_and_pretty",
    ).launch()