# Source: Hugging Face Space file view (scraped page header, kept as a comment)
# Author: miwojc
# Commit: 5861710 "Update app.py"
# Size: 896 Bytes
from transformers import pipeline
import streamlit as st
from streamlit.report_thread import get_report_ctx
def query_cache(q_emb=None):
    """Get, and optionally set, a value attached to the current Streamlit session.

    The value is stored as a ``_query_state`` attribute on the session object
    of the running script, so it persists across reruns within that session.

    Args:
        q_emb: If not None, stored as the session's value, replacing any
            previous one. If None, the stored value is left untouched.

    Returns:
        The value currently stored for this session, or None if nothing has
        been stored yet.
    """
    # NOTE(review): relies on private Streamlit internals (report_thread,
    # Server._get_session_info); these were moved/removed in later releases.
    ctx = get_report_ctx()
    session = st.server.server.Server.get_current()._get_session_info(ctx.session_id).session
    # Compare against None instead of testing truthiness: falsy values
    # (0, "", {}) must still be stored, and array-like embeddings (e.g. a
    # multi-element numpy array) raise when converted to bool.
    if q_emb is not None or not hasattr(session, "_query_state"):
        session._query_state = q_emb
    return session._query_state
@st.cache(allow_output_mutation=True)
def _load_generator():
    """Return the text-generation pipeline; st.cache reuses it across reruns."""
    return pipeline('text-generation')


# usage
# Per-session state dict, kept via query_cache so it persists across reruns.
q_emb = query_cache()  # will get from cache if exists
if q_emb is None:
    # First run of this session: nothing is cached yet, so seed the state.
    q_emb = query_cache({"user_text": "foo"})  # will set cache to value
q_emb.setdefault("user_text", "foo")

user_text = st.text_input("Write something", value=q_emb["user_text"])
if st.button("Write with transformer"):
    # Continue the user's text; fall back to a fixed prompt if the box is empty.
    prompt = user_text or "My name is Mario and"
    res = _load_generator()(prompt)[0]["generated_text"]
    # Store the result where the text box reads its default value from, so it
    # appears in the box on the next rerun; also show it right away.
    q_emb["user_text"] = res
    st.write(res)