# Scraped page header: author Daimon, commit 3187942 ("Update app.py"), 6.6 kB.
import streamlit as st
import pandas as pd
from pathlib import Path
#from transformers import MBartForConditionalGeneration, MBart50TokenizerFast
from transformers import M2M100ForConditionalGeneration
from tokenization_small100 import SMALL100Tokenizer
import io
# Browser-tab title and icon, plus the full-width page layout.
st.set_page_config(
    layout="wide",
    page_icon=":milky_way:",
    page_title="Translation Demo",
)
# Caching decorator for the translation model. Newer Streamlit provides
# ``st.cache_resource`` for shared, unhashable objects such as models. Older
# versions only have ``st.cache``, whose default mode re-hashes the returned
# object on every call to detect mutation; that is expensive (or fails) for a
# large model, so the fallback disables that check.
_cache_model = (
    st.cache_resource
    if hasattr(st, "cache_resource")
    else st.cache(allow_output_mutation=True)
)


@_cache_model
def load_model():
    """Load the SMALL-100 model once and return the cached instance.

    Returns:
        M2M100ForConditionalGeneration: weights from the
        ``alirezamsh/small100`` checkpoint.
    """
    return M2M100ForConditionalGeneration.from_pretrained("alirezamsh/small100")
def get_translation(src_code, trg_code, src):
    """Translate ``src`` into the language ``trg_code`` with SMALL-100.

    Relies on the module-level ``tokenizer`` (a SMALL100Tokenizer). Its
    ``tgt_lang`` attribute is overwritten on every call.

    Args:
        src_code: Source-language code. Not used: only the target language is
            configured on the SMALL-100 tokenizer. The parameter is kept so
            existing callers keep working.
        trg_code: Target-language code, e.g. ``'de'``.
        src: Text to translate.

    Returns:
        list[str]: Decoded output of ``tokenizer.batch_decode``. Callers in
        this app take element ``[0]``.
    """
    model = load_model()
    tokenizer.tgt_lang = trg_code
    encoded = tokenizer(src, return_tensors="pt")
    generated_tokens = model.generate(**encoded)
    return tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
def open_input(the_file):
    """Parse an uploaded TSV or XLSX file into pandas objects.

    Args:
        the_file: File-like object with a ``name`` attribute. The extension of
            ``name`` selects the parser and is compared case-insensitively.

    Returns:
        ``(parsed, sheets)``:
        - TSV or single-sheet XLSX: ``parsed`` is a DataFrame and ``sheets``
          is ``[]``.
        - Multi-sheet XLSX: ``parsed`` is a list of DataFrames in the same
          order as the sheet names in ``sheets``.

    Raises:
        ValueError: if the extension is neither ``.tsv`` nor ``.xlsx``. Without
            this check the function used to fail with an unbound local variable.
    """
    suffix = Path(the_file.name).suffix.lower()
    if suffix == '.tsv':
        return pd.read_csv(the_file, sep="\t"), []
    if suffix == '.xlsx':
        xlsx = pd.ExcelFile(the_file)
        sheet_names = list(xlsx.sheet_names)
        if len(sheet_names) > 1:
            frames = [pd.read_excel(xlsx, sheet_name=name) for name in sheet_names]
            return frames, sheet_names
        # Single sheet: reuse the already-opened workbook rather than
        # re-reading the upload stream.
        return pd.read_excel(xlsx), []
    raise ValueError(
        f"Unsupported file type: {the_file.name!r} (expected .tsv or .xlsx)"
    )
def translate_data(df, s_lang, t_lang, col_for_translation, languages):
    """Translate one column of a DataFrame with SMALL-100.

    Args:
        df: Input table. It is not modified.
        s_lang: Source-language code, passed through to get_translation().
        t_lang: Target-language code.
        col_for_translation: Name of the column in ``df`` to translate.
        languages: Allowed language codes. Both ``s_lang`` and ``t_lang`` must
            be in it.

    Returns:
        A copy of ``df`` with an added "SMALL-100 translation" column that is
        always the same length as ``df``. The following cells get "":
        - cells that are not non-blank strings (e.g. NaN from empty cells, or
          numbers);
        - every cell after the first failed translation;
        - the whole column, if either language code is invalid (a message is
          shown in that case).
    """
    # Copy so the caller's DataFrame is left unchanged.
    new_df = df.copy()
    # The language check does not depend on the row, so do it once.
    if s_lang not in languages or t_lang not in languages:
        st.write("Please enter the source text, source language and target language.")
        new_df["SMALL-100 translation"] = ""
        return new_df

    translated_data = []
    failed = False
    # A single spinner for the whole column instead of one per row.
    with st.spinner("Translating..."):
        for text in df[col_for_translation]:
            # Skip non-strings and blanks instead of crashing on len(). After
            # a failure, stop calling the model but keep padding so the new
            # column matches the DataFrame's length.
            if failed or not isinstance(text, str) or not text.strip():
                translated_data.append("")
                continue
            try:
                translated_data.append(get_translation(s_lang, t_lang, text)[0])
            except Exception:
                st.subheader("Translation failed :sad:")
                failed = True
                translated_data.append("")
    new_df["SMALL-100 translation"] = translated_data
    return new_df
def select_column(data, valid_source, valid_target, is_excel=False):
    """Render the column and language pickers for an uploaded table.

    Args:
        data: A DataFrame, or (when ``is_excel``) a list of DataFrames whose
            first element supplies the selectable column names.
        valid_source: Iterable of language codes offered as source languages.
        valid_target: Iterable of language codes offered as target languages.
        is_excel: True when ``data`` is a list of per-sheet DataFrames.

    Returns:
        ``(submitted, column, source_lang, target_lang)``. If there is no
        column to choose from, returns ``(False, None, None, None)``, so a
        caller's ``if submitted`` check skips translation.
    """
    # For workbooks, the first sheet's header supplies the column choices.
    frame = data[0] if is_excel else data
    src_col = st.selectbox(
        'Select the column to translate (WARNING: You can only select a single column - please make sure all columns are named accordingly):',
        list(frame.columns),
    )
    # st.selectbox yields None when there is nothing to select. Test for None
    # explicitly so a falsy column name (e.g. 0) is still accepted; this also
    # avoids returning unbound locals.
    if src_col is None:
        return False, None, None, None
    col_src_lang = st.selectbox(
        'Source language:',
        valid_source,
    )
    col_trg_lang = st.selectbox(
        'Target language:',
        valid_target,
    )
    submitted_cols = st.button("Translate column")
    return submitted_cols, src_col, col_src_lang, col_trg_lang
st.subheader("SMALL-100 Translator")

# Default text pre-filled in the "Source" text area.
source = "In the beginning the Universe was created. This has made a lot of people very angry and been widely regarded as a bad move."
target = ""

# Read as a module-level global by get_translation().
tokenizer = SMALL100Tokenizer.from_pretrained("alirezamsh/small100")

# Language codes offered by every language selector in the app.
valid_languages = ['de', 'it', 'en']

# --- Free-text translation ---------------------------------------------------
with st.form("my_form"):
    src_lang = st.selectbox(
        'Source language',
        valid_languages,
    )
    trg_lang = st.selectbox(
        'Target language',
        valid_languages,
    )
    source = st.text_area("Source", value=source, height=130, placeholder="Enter the source text...")
    submitted = st.form_submit_button("Translate")

if submitted:
    if len(source) > 0 and src_lang in valid_languages and trg_lang in valid_languages:
        with st.spinner("Translating..."):
            try:
                target = get_translation(src_lang, trg_lang, source)[0]
            except Exception:
                st.subheader("Translation failed :sad:")
            else:
                st.subheader("Translation done!")
                target = st.text_area("Target", value=target, height=130)
    else:
        st.write("Please enter the source text, source language and target language.")

# --- Table (TSV / XLSX) translation ------------------------------------------
st.subheader('Input XLSX/TSV')
# Restrict uploads to the formats open_input() parses.
uploaded_file = st.file_uploader("Choose a file", type=["tsv", "xlsx"])
done = False
if uploaded_file is not None:
    data, sheets = open_input(uploaded_file)
    if len(sheets) > 0:
        # Multi-sheet workbook: data is a list of DataFrames, one per sheet.
        translated_sheets = []
        submitted_cols, src_col, src_code, trg_code = select_column(
            data, valid_languages, valid_languages, is_excel=True
        )
        if submitted_cols:
            for name, sheet in zip(sheets, data):
                # Columns are offered from the first sheet only. Sheets that
                # lack that column are kept as-is instead of raising KeyError.
                if src_col in sheet.columns:
                    translated_sheets.append(
                        translate_data(sheet, src_code, trg_code, src_col, valid_languages)
                    )
                else:
                    st.warning(f"Sheet '{name}' has no column '{src_col}'; it was left untranslated.")
                    translated_sheets.append(sheet)
            done = True
    else:
        # TSV or single-sheet workbook: data is a single DataFrame.
        submitted_cols, src_col, src_code, trg_code = select_column(
            data, valid_languages, valid_languages
        )
        st.subheader("DataFrame")
        st.write(data)
        st.write(data.describe())
        if submitted_cols:
            new_df = translate_data(data, src_code, trg_code, src_col, valid_languages)
            done = True
    if done:
        st.subheader("Translated DataFrame")
        if len(sheets) > 0:
            for name, sheet in zip(sheets, translated_sheets):
                st.write(f"Sheet: {name}")
                st.write(sheet)
            buffer = io.BytesIO()
            with pd.ExcelWriter(buffer) as writer:
                for name, sheet in zip(sheets, translated_sheets):
                    # index=False matches the TSV export and avoids an extra
                    # unnamed index column in every sheet.
                    sheet.to_excel(writer, sheet_name=name, index=False)
            st.download_button('Download XLSX', buffer.getvalue(), 'translated_file.xlsx', 'application/vnd.openxmlformats-officedocument.spreadsheetml.sheet', key='download-xlsx')
        else:
            st.write(new_df)
            st.write(new_df.describe())
            to_dl = new_df.to_csv(index=False, sep='\t').encode('utf-8')
            st.download_button('Download TSV', to_dl, 'translated_file.tsv', 'text/tab-separated-values', key='download-tsv')
else:
    st.info("☝️ Upload a TSV or XLSX file")