"""Streamlit demo: parse English text into an AMR graph and visualize it.

The script loads the tokenizer, model and constrained-decoding logits processor
once per Streamlit session, generates a linearized AMR for the entered text,
converts it to Penman notation and renders the result with graphviz.
"""
from collections import Counter

import graphviz
import penman
import streamlit as st
from optimum.bettertransformer import BetterTransformer
from penman.models.noop import NoOpModel
from transformers import LogitsProcessorList, MBartForConditionalGeneration

from mbart_amr.constraints.constraints import AMRLogitsProcessor
from mbart_amr.data.linearization import linearized2penmanstr
from mbart_amr.data.tokenization import AMRMBartTokenizer

if "logits_processor" not in st.session_state:
    st.session_state["logits_processor"] = None

if "tokenizer" not in st.session_state:
    st.session_state["tokenizer"] = None

if "model" not in st.session_state:
    # Build everything in locals first and only commit to the session state once
    # every step succeeded. Otherwise a failure half-way would leave a partially
    # configured model behind that later reruns would silently pick up, because
    # this whole branch is guarded by the presence of the "model" key.
    tokenizer = AMRMBartTokenizer.from_pretrained("BramVanroy/mbart-en-to-amr", src_lang="en_XX")
    model = MBartForConditionalGeneration.from_pretrained("BramVanroy/mbart-en-to-amr")
    model = BetterTransformer.transform(model, keep_original_model=False)
    model.resize_token_embeddings(len(tokenizer))
    logits_processor = AMRLogitsProcessor(tokenizer, model.config.max_length)

    st.session_state["tokenizer"] = tokenizer
    st.session_state["logits_processor"] = logits_processor
    st.session_state["model"] = model

st.title("📝 Parse text into AMR")

text = st.text_input(label="Text to transform (en)")

if text and "model" in st.session_state:
    gen_kwargs = {
        "max_length": st.session_state["model"].config.max_length,
        "num_beams": st.session_state["model"].config.num_beams,
        "logits_processor": LogitsProcessorList([st.session_state["logits_processor"]])
        if st.session_state["logits_processor"]
        else None,
    }

    encoded = st.session_state["tokenizer"](text, return_tensors="pt")
    generated = st.session_state["model"].generate(**encoded, **gen_kwargs)
    linearized = st.session_state["tokenizer"].decode_and_fix(generated)[0]
    penman_str = linearized2penmanstr(linearized)

    try:
        graph = penman.decode(penman_str, model=NoOpModel())
    except Exception as exc:
        # UI boundary: show the problem to the user instead of crashing the app
        st.write(
            "The generated graph is not valid so it cannot be visualized correctly. Below is the closest attempt"
            " to a valid graph but note that this is invalid Penman."
        )
        st.code(penman_str)
        with st.expander("Error trace"):
            st.write(exc)
    else:
        visualized = graphviz.Digraph(
            node_attr={"color": "#3aafa9", "style": "rounded,filled", "shape": "box", "fontcolor": "white"}
        )

        # Count which names occur multiple times, e.g. t/talk-01 t2/talk-01
        nodename_c = Counter(target for _, role, target in graph.triples if role == ":instance")
        # Generated initial nodenames for each variable, e.g. {"t": "talk-01", "t2": "talk-01"}
        nodenames = {source: target for source, role, target in graph.triples if role == ":instance"}

        # Modify nodenames, so that the values are unique, e.g. {"t": "talk-01 (1)", "t2": "talk-01 (2)"}
        # but only if the value occurs more than once
        nodename_str_c = Counter()
        for varname, nodename in nodenames.items():
            if nodename_c[nodename] > 1:
                nodename_str_c[nodename] += 1
                nodenames[varname] = f"{nodename} ({nodename_str_c[nodename]})"

        def get_node_name(item: str):
            """Return the display name for a variable, or `item` itself if it is not a known variable."""
            return nodenames.get(item, item)

        try:
            # Variables/values that take part in at least one relation
            connected = set()
            for source, role, target in graph.triples:
                if role == ":instance":
                    continue
                visualized.edge(get_node_name(source), get_node_name(target), label=role)
                connected.update((source, target))

            # Edges implicitly create their end points, but a concept without any
            # relation (e.g. the single-node graph "(h / hello)") would otherwise
            # never be drawn, leaving the chart empty. Declare those explicitly.
            for varname, nodename in nodenames.items():
                if varname not in connected:
                    visualized.node(nodename)
        except Exception as exc:
            # UI boundary: show the problem to the user instead of crashing the app
            st.write(
                "The generated graph is not valid so it cannot be visualized correctly. Below is the closest attempt"
                " to a valid graph but note that this is probably invalid Penman."
            )
            st.code(penman_str)
            st.write("The initial linearized output of the model was:")
            st.code(linearized)
            with st.expander("Error trace"):
                st.write(exc)
        else:
            st.subheader("Graph visualization")
            st.graphviz_chart(visualized, use_container_width=True)

            # Download
            img = visualized.pipe(format="png")
            st.download_button("Download graph", img, mime="image/png")

            # Additional info
            st.subheader("Model output and Penman graph")
            st.write("The linearized output of the model (after some post-processing) is:")
            st.code(linearized)
            st.write("When converted into Penman, it looks like this:")
            st.code(penman.encode(graph))

########################
# Information, socials #
########################
st.markdown("## Contact ✒️")

st.markdown(
    "Would you like additional functionality in the demo? Or just want to get in touch?"
    " Give me a shout on [Twitter](https://twitter.com/BramVanroy)"
    " or add me on [LinkedIn](https://www.linkedin.com/in/bramvanroy/)!"
)