import time

import psutil
import streamlit as st
import torch

from generator import GeneratorFactory

# Device index: -1 when no CUDA device is visible, otherwise the index of the
# last visible GPU. Not referenced in this file; presumably consumed by other
# modules (e.g. generator) -- verify before removing.
device = torch.cuda.device_count() - 1

# NOTE: despite its name, this constant holds the English -> Dutch task id,
# which matches every model listed below. The name is kept unchanged in case
# other modules import it.
TRANSLATION_NL_TO_EN = "translation_en_to_nl"

# Generators to run, in display order. Each entry is passed to
# GeneratorFactory: "model_name" is the Hugging Face hub id, "desc" a short
# label, "task" the task id, and "split_sentences" a per-model flag
# (presumably: translate sentence by sentence -- see generator.py).
GENERATOR_LIST = [
    {
        "model_name": "Helsinki-NLP/opus-mt-en-nl",
        "desc": "Opus MT en->nl",
        "task": TRANSLATION_NL_TO_EN,
        "split_sentences": True,
    },
    {
        "model_name": "yhavinga/t5-small-24L-ccmatrix-multi",
        "desc": "T5 small nl24 ccmatrix en->nl",
        "task": TRANSLATION_NL_TO_EN,
        "split_sentences": True,
    },
    {
        "model_name": "yhavinga/longt5-local-eff-large-nl8-voc8k-ddwn-512beta-512l-nedd-256ccmatrix-en-nl",
        "desc": "longT5 large nl8 256cc/512beta/512l en->nl",
        "task": TRANSLATION_NL_TO_EN,
        "split_sentences": False,
    },
    {
        "model_name": "yhavinga/byt5-small-ccmatrix-en-nl",
        "desc": "ByT5 small ccmatrix en->nl",
        "task": TRANSLATION_NL_TO_EN,
        "split_sentences": True,
    },
    # {
    #     "model_name": "yhavinga/t5-eff-large-8l-nedd-en-nl",
    #     "desc": "T5 eff large nl8 en->nl",
    #     "task": TRANSLATION_NL_TO_EN,
    #     "split_sentences": True,
    # },
    # {
    #     "model_name": "yhavinga/t5-base-36L-ccmatrix-multi",
    #     "desc": "T5 base nl36 ccmatrix en->nl",
    #     "task": TRANSLATION_NL_TO_EN,
    #     "split_sentences": True,
    # },
    # {
    #     "model_name": "yhavinga/longt5-local-eff-large-nl8-voc8k-ddwn-512beta-512-nedd-en-nl",
    #     "desc": "longT5 large nl8 512beta/512l en->nl",
    #     "task": TRANSLATION_NL_TO_EN,
    #     "split_sentences": False,
    # },
    # {
    #     "model_name": "yhavinga/t5-base-36L-nedd-x-en-nl-300",
    #     "desc": "T5 base 36L nedd en->nl 300",
    #     "task": TRANSLATION_NL_TO_EN,
    #     "split_sentences": True,
    # },
    # {
    #     "model_name": "yhavinga/long-t5-local-small-ccmatrix-en-nl",
    #     "desc": "longT5 small ccmatrix en->nl",
    #     "task": TRANSLATION_NL_TO_EN,
    #     "split_sentences": True,
    # },
]


def main():
    """Render the Babel Streamlit app and translate the entered text on demand.

    Sets up the page and sidebar and keeps one ``GeneratorFactory`` per browser
    session in ``st.session_state``. When *Run* is pressed, it feeds the
    text-area content to every configured generator. Each translation is shown
    with its wall-clock time and the parameters the generator reports having
    used, followed by a system-memory footer.
    """
    st.set_page_config(
        page_title="Babel",
        layout="wide",
        initial_sidebar_state="expanded",
        page_icon="📚",
    )

    # Build the generators only once per session; reruns reuse them.
    if "generators" not in st.session_state:
        st.session_state["generators"] = GeneratorFactory(GENERATOR_LIST)
    generators = st.session_state["generators"]

    # Inject the custom stylesheet into the page.
    with open("style.css", encoding="utf-8") as f:
        st.markdown(f"<style>{f.read()}</style>", unsafe_allow_html=True)

    st.sidebar.image("babel.png", width=200)
    st.sidebar.markdown(
        """# Babel
Vertaal van en naar Engels"""
    )
    st.sidebar.title("Parameters:")

    # Seed the text area with a default sample on first load only.
    if "prompt_box" not in st.session_state:
        # Text is from https://www.gutenberg.org/files/35091/35091-h/35091-h.html
        st.session_state[
            "prompt_box"
        ] = """It was a wet, gusty night and I had a lonely walk home. By taking the river road, though I hated it, I saved two miles, so I sloshed ahead trying not to think at all. Through the barbed wire fence I could see the racing river. Its black swollen body writhed along with extraordinary swiftness, breathlessly silent, only occasionally making a swishing ripple. I did not enjoy looking at it. I was somehow afraid.

And there, at the end of the river road where I swerved off, a figure stood waiting for me, motionless and enigmatic. I had to meet it or turn back.

It was a quite young girl, unknown to me, with a hood over her head, and with large unhappy eyes.

“My father is very ill,” she said without a word of introduction. “The nurse is frightened. Could you come in and help?”"""
    st.session_state["text"] = st.text_area(
        "Enter text", st.session_state.prompt_box, height=250
    )

    num_beams = st.sidebar.number_input("Num beams", min_value=1, max_value=10, value=1)
    num_beam_groups = st.sidebar.number_input(
        "Num beam groups", min_value=1, max_value=10, value=1
    )
    length_penalty = st.sidebar.number_input(
        "Length penalty", min_value=0.0, max_value=2.0, value=1.2, step=0.1
    )
    st.sidebar.markdown(
        """For an explanation of the parameters, head over to the [Huggingface blog post about text generation](https://huggingface.co/blog/how-to-generate)
and the [Huggingface text generation interface doc](https://huggingface.co/transformers/main_classes/model.html?highlight=generate#transformers.generation_utils.GenerationMixin.generate).
"""
    )
    params = {
        "num_beams": num_beams,
        "num_beam_groups": num_beam_groups,
        "length_penalty": length_penalty,
        "early_stopping": True,
    }

    if st.button("Run"):
        # Group (diverse) beam search in transformers requires num_beams to be
        # a multiple of num_beam_groups; report it instead of crashing.
        if num_beams % num_beam_groups != 0:
            st.error(
                f"Num beams ({num_beams}) must be a multiple of "
                f"num beam groups ({num_beam_groups})."
            )
            return

        for generator in generators:
            st.markdown(f"🧮 **Model `{generator}`**")
            time_start = time.perf_counter()
            result, params_used = generator.generate(
                text=st.session_state.text, **params
            )
            time_diff = time.perf_counter() - time_start
            # Markdown needs two trailing spaces to render a hard line break.
            st.write(result.replace("\n", "  \n"))
            text_line = ", ".join(f"{k}={v}" for k, v in params_used.items())
            st.markdown(f" 🕙 *generated in {time_diff:.2f}s, `{text_line}`*")

        # Sample memory after generation so the footer reflects the state
        # after the models have run, not before.
        memory = psutil.virtual_memory()
        st.write(
            f"""
---
*Memory: {memory.total / 10**9:.2f}GB, used: {memory.percent}%, available: {memory.available / 10**9:.2f}GB*
"""
        )


if __name__ == "__main__":
    main()