|
import time |
|
import torch |
|
|
|
import psutil |
|
import streamlit as st |
|
|
|
from generator import GeneratorFactory |
|
|
|
# Index of the last visible CUDA device, or -1 when no CUDA device is visible.
# NOTE(review): not referenced elsewhere in this file; presumably consumed as a
# Hugging Face pipeline ``device`` argument (where -1 means CPU) — confirm.
device = torch.cuda.device_count() - 1

# Task identifier for English -> Dutch translation, following the Hugging Face
# ``translation_<src>_to_<tgt>`` naming scheme.
TRANSLATION_EN_TO_NL = "translation_en_to_nl"

# Backward-compatible alias. The name is misleading: its value (and every model
# below) translates English -> Dutch, not Dutch -> English. Prefer
# TRANSLATION_EN_TO_NL in new code.
TRANSLATION_NL_TO_EN = TRANSLATION_EN_TO_NL

# Models offered by the app; the whole list is handed to GeneratorFactory.
# Key meanings are inferred from their names — verify against generator.py:
#   model_name      - presumably the Hugging Face Hub model id to load
#   desc            - human-readable label for the model
#   task            - pipeline task identifier (see TRANSLATION_EN_TO_NL)
#   split_sentences - presumably whether the input is translated sentence by
#                     sentence (disabled for the long-context longT5 model)
GENERATOR_LIST = [
    {
        "model_name": "Helsinki-NLP/opus-mt-en-nl",
        "desc": "Opus MT en->nl",
        "task": TRANSLATION_EN_TO_NL,
        "split_sentences": True,
    },
    {
        "model_name": "yhavinga/t5-small-24L-ccmatrix-multi",
        "desc": "T5 small nl24 ccmatrix en->nl",
        "task": TRANSLATION_EN_TO_NL,
        "split_sentences": True,
    },
    {
        "model_name": "yhavinga/longt5-local-eff-large-nl8-voc8k-ddwn-512beta-512l-nedd-256ccmatrix-en-nl",
        "desc": "longT5 large nl8 256cc/512beta/512l en->nl",
        "task": TRANSLATION_EN_TO_NL,
        "split_sentences": False,
    },
    {
        "model_name": "yhavinga/byt5-small-ccmatrix-en-nl",
        "desc": "ByT5 small ccmatrix en->nl",
        "task": TRANSLATION_EN_TO_NL,
        "split_sentences": True,
    },
]
|
|
|
|
|
# Example English passage pre-filled in the input box on a session's first run.
_DEFAULT_PROMPT = """It was a wet, gusty night and I had a lonely walk home. By taking the river road, though I hated it, I saved two miles, so I sloshed ahead trying not to think at all. Through the barbed wire fence I could see the racing river. Its black swollen body writhed along with extraordinary swiftness, breathlessly silent, only occasionally making a swishing ripple. I did not enjoy looking at it. I was somehow afraid.



And there, at the end of the river road where I swerved off, a figure stood waiting for me, motionless and enigmatic. I had to meet it or turn back.



It was a quite young girl, unknown to me, with a hood over her head, and with large unhappy eyes.



“My father is very ill,” she said without a word of introduction. “The nurse is frightened. Could you come in and help?”"""


def _render_sidebar_header():
    """Draw the logo, the app title/tagline and the "Parameters:" heading in the sidebar."""
    st.sidebar.image("babel.png", width=200)
    st.sidebar.markdown(
        """# Babel

Vertaal van en naar Engels"""
    )
    st.sidebar.title("Parameters:")


def _generation_params():
    """Render the generation-parameter widgets in the sidebar.

    Returns:
        dict of keyword arguments forwarded unchanged to every generator's
        ``generate`` call.
    """
    num_beams = st.sidebar.number_input("Num beams", min_value=1, max_value=10, value=1)
    num_beam_groups = st.sidebar.number_input(
        "Num beam groups", min_value=1, max_value=10, value=1
    )
    length_penalty = st.sidebar.number_input(
        "Length penalty", min_value=0.0, max_value=2.0, value=1.2, step=0.1
    )
    st.sidebar.markdown(
        """For an explanation of the parameters, head over to the [Huggingface blog post about text generation](https://huggingface.co/blog/how-to-generate)

and the [Huggingface text generation interface doc](https://huggingface.co/transformers/main_classes/model.html?highlight=generate#transformers.generation_utils.GenerationMixin.generate).

"""
    )
    return {
        "num_beams": num_beams,
        "num_beam_groups": num_beam_groups,
        "length_penalty": length_penalty,
        "early_stopping": True,
    }


def _run_generators(generators, text, params):
    """Translate ``text`` with every generator and write each result plus timing.

    Args:
        generators: iterable of generator objects whose
            ``generate(text=..., **params)`` returns ``(result, params_used)``.
        text: the input text to translate.
        params: generation keyword arguments shared by all generators.
    """
    for generator in generators:
        st.markdown(f"🧮 **Model `{generator}`**")
        # perf_counter is monotonic and high-resolution; time.time() follows the
        # wall clock and can jump if the system clock is adjusted mid-run.
        time_start = time.perf_counter()
        result, params_used = generator.generate(text=text, **params)
        time_diff = time.perf_counter() - time_start

        # Markdown only renders a hard line break for two trailing spaces, so
        # this keeps the model's newlines visible instead of collapsing them.
        st.write(result.replace("\n", "  \n"))
        text_line = ", ".join(f"{k}={v}" for k, v in params_used.items())
        st.markdown(f" 🕙 *generated in {time_diff:.2f}s, `{text_line}`*")


def main():
    """Render the Babel Streamlit app and, on "Run", translate the input with every model."""
    # st.set_page_config has to be the first Streamlit command of a script run.
    st.set_page_config(
        page_title="Babel",
        layout="wide",
        initial_sidebar_state="expanded",
        page_icon="🌐",
    )

    # Build the generators once per session and reuse them on every rerun.
    if "generators" not in st.session_state:
        st.session_state["generators"] = GeneratorFactory(GENERATOR_LIST)
    generators = st.session_state["generators"]

    # Explicit encoding so the stylesheet reads the same on every platform.
    with open("style.css", encoding="utf-8") as f:
        st.markdown(f"<style>{f.read()}</style>", unsafe_allow_html=True)

    _render_sidebar_header()

    if "prompt_box" not in st.session_state:
        st.session_state["prompt_box"] = _DEFAULT_PROMPT
    st.session_state["text"] = st.text_area(
        "Enter text", st.session_state.prompt_box, height=250
    )

    params = _generation_params()

    if st.button("Run"):
        _run_generators(generators, st.session_state.text, params)

        # Sampled at display time so the figures reflect whatever memory the
        # run left allocated, rather than the state before it started.
        memory = psutil.virtual_memory()
        st.write(
            f"""

---

*Memory: {memory.total / 10**9:.2f}GB, used: {memory.percent}%, available: {memory.available / 10**9:.2f}GB*

"""
        )


if __name__ == "__main__":
    main()
|
|