# Streamlit demo: translate input text (in one of several languages) into a
# linearized AMR, convert it to Penman notation and render it as a graph.
from collections import Counter

import graphviz
import penman
from penman.models.noop import NoOpModel
from mbart_amr.data.linearization import linearized2penmanstr
from transformers import LogitsProcessorList

import streamlit as st

from utils import get_resources, LANGUAGES, translate


st.title("👩‍💻 Multilingual text to AMR")

# Input form: free text plus the source language of that text.
with st.form("input data"):
    text_col, lang_col = st.columns((4, 1))
    text = text_col.text_input(label="Input text")
    src_lang = lang_col.selectbox(label="Language", options=list(LANGUAGES.keys()), index=0)
    submitted = st.form_submit_button("Submit")

# Results are rendered outside the form: a download button is shown below and
# Streamlit does not allow download buttons inside a form.
if submitted:
    # get_resources is called with a flag that is True for any non-English source language.
    multilingual = src_lang != "English"
    model, tokenizer, logitsprocessor = get_resources(multilingual)
    gen_kwargs = {
        "max_length": model.config.max_length,
        "num_beams": model.config.num_beams,
        "logits_processor": LogitsProcessorList([logitsprocessor]),
    }

    linearized = translate(text, src_lang, model, tokenizer, **gen_kwargs)
    penman_str = linearized2penmanstr(linearized)

    try:
        graph = penman.decode(penman_str, model=NoOpModel())
    except Exception as exc:  # UI boundary: report any decoding failure to the user instead of crashing
        st.write("The generated graph is not valid so it cannot be visualized correctly. Below is the closest attempt"
                 " to a valid graph but note that this is invalid Penman.")
        st.code(penman_str)
        with st.expander("Error trace"):
            st.write(exc)
    else:
        visualized = graphviz.Digraph(
            node_attr={"color": "#3aafa9", "style": "rounded,filled", "shape": "box", "fontcolor": "white"}
        )

        instance_triples = [triple for triple in graph.triples if triple[1] == ":instance"]
        # Count which names occur multiple times, e.g. t/talk-01 t2/talk-01
        nodename_c = Counter(triple[2] for triple in instance_triples)
        # Generate initial nodenames for each variable, e.g. {"t": "talk-01", "t2": "talk-01"}
        nodenames = {triple[0]: triple[2] for triple in instance_triples}

        # Modify nodenames, so that the values are unique, e.g. {"t": "talk-01 (1)", "t2": "talk-01 (2)"}
        # but only if the value occurs more than once
        nodename_str_c = Counter()
        for varname, nodename in nodenames.items():
            if nodename_c[nodename] > 1:
                nodename_str_c[nodename] += 1
                nodenames[varname] = f"{nodename} ({nodename_str_c[nodename]})"

        def get_node_name(item: str):
            """Return the display name for a variable; any other value is returned unchanged."""
            return nodenames.get(item, item)

        try:
            # Declare every instance node explicitly. Otherwise nodes only come into
            # existence through edges, and a graph consisting of a single node
            # (no relations at all) would be rendered as an empty chart.
            for display_name in nodenames.values():
                visualized.node(display_name)

            # Every non-instance triple becomes a labelled edge.
            # NOTE(review): graphviz splits edge endpoints on ":" (node:port syntax), so
            # targets that contain a colon may not render as intended -- confirm.
            for triple in graph.triples:
                if triple[1] == ":instance":
                    continue
                visualized.edge(get_node_name(triple[0]), get_node_name(triple[2]), label=triple[1])
        except Exception as exc:  # UI boundary: show the failure rather than a stack trace
            st.write("The generated graph is not valid so it cannot be visualized correctly. Below is the closest attempt"
                     " to a valid graph but note that this is probably invalid Penman.")
            st.code(penman_str)
            st.write("The initial linearized output of the model was:")
            st.code(linearized)
            with st.expander("Error trace"):
                st.write(exc)
        else:
            st.subheader("Graph visualization")
            st.graphviz_chart(visualized, use_container_width=True)

            # Download
            img = visualized.pipe(format="png")
            st.download_button("Download graph", img, mime="image/png")

            # Additional info
            st.subheader("Model output and Penman graph")
            st.write("The linearized output of the model (after some post-processing) is:")
            st.code(linearized)
            st.write("When converted into Penman, it looks like this:")
            st.code(penman.encode(graph))


########################
# Information, socials #
########################
st.markdown("## Project: SignON 🤟")

st.markdown("""
SignON logo

SignON aims to bridge the communication gap between Deaf, hard of hearing and hearing people through an accessible translation service to translate between languages and modalities with particular attention to sign languages.

This space and the accompanying models and public code are part of the SignON-project. AMR (abstract meaning representation) is used as an interlingua to translate between modalities and languages.

""", unsafe_allow_html=True)

st.markdown("## Contact ✒️")

st.markdown("Would you like additional functionality in the demo? Or just want to get in touch?"
            " Give me a shout on [Twitter](https://twitter.com/BramVanroy)"
            " or add me on [LinkedIn](https://www.linkedin.com/in/bramvanroy/)!")