"""Streamlit demo: keep per-session text state and extend it with GPT-2.

Targets the pre-1.0 Streamlit API (``streamlit.report_thread``); newer
Streamlit releases provide the same facility as ``st.session_state``.
"""
from types import SimpleNamespace

from transformers import pipeline
import streamlit as st
from streamlit.report_thread import get_report_ctx

# Prompt fed to the text-generation pipeline when the button is pressed.
PROMPT = "My name is Mario and"


def query_cache(q_emb=None):
    """Return the object cached on the current Streamlit session.

    Call with no argument to read the cached object back (``None`` if nothing
    was ever stored for this session). Call with a non-``None`` value to
    replace the cached object, e.g. ``q_emb = query_cache(new_emb)``.

    Raises:
        RuntimeError: if called outside a running Streamlit script (no report
            context), or if the current session cannot be found.
    """
    ctx = get_report_ctx()
    if ctx is None:
        raise RuntimeError(
            "query_cache() must be called from within a Streamlit script run"
        )
    # NOTE(review): relies on private Streamlit internals
    # (Server._get_session_info), which changed between releases.
    session_info = st.server.server.Server.get_current()._get_session_info(
        ctx.session_id
    )
    if session_info is None:
        raise RuntimeError(f"no Streamlit session found for id {ctx.session_id!r}")
    session = session_info.session
    # Compare against None (not truthiness) so falsy values such as empty
    # containers can still be cached.
    if q_emb is not None or not hasattr(session, "_query_state"):
        session._query_state = q_emb
    return session._query_state


@st.cache(allow_output_mutation=True)
def _load_generator():
    """Build the text-generation pipeline once and reuse it across reruns."""
    # Streamlit re-executes this script on every interaction; st.cache keeps
    # the (expensive) model alive instead of reloading it on each click.
    return pipeline("text-generation")


# Per-session state. The first run of a session has nothing cached yet, so
# seed it with an attribute-style container.
q_emb = query_cache()
if q_emb is None:
    q_emb = query_cache(SimpleNamespace())
if not hasattr(q_emb, "user_text"):
    q_emb.user_text = "foo"

# Reserve the input's slot so it stays above the button, but fill it only
# after the button is handled, so freshly generated text appears immediately.
input_slot = st.empty()
if st.button("Write with transformer"):
    gpt2 = _load_generator()
    # Store on the per-session state. Assigning to `st.user_text` would put the
    # value on the shared streamlit module, where no session ever reads it.
    q_emb.user_text = gpt2(PROMPT)[0]["generated_text"]
input_slot.text_input("Write something", value=q_emb.user_text)