Spaces:
Runtime error
Runtime error
File size: 6,372 Bytes
6bc94ac 436ce71 6bc94ac edcdcdb 6bc94ac 98ad652 436ce71 aafa95b e916883 aafa95b 6bc94ac fb6ade2 6bc94ac 15303cb fb6ade2 15303cb 325f09c 15303cb 8bc0535 aafa95b e916883 15303cb 436ce71 15303cb db5ef00 aafa95b 15303cb db5ef00 15303cb db5ef00 15303cb 436ce71 15303cb 325f09c 15303cb aafa95b 325f09c 15303cb 8bc0535 6bc94ac 15303cb a0194f4 15303cb fb6ade2 a0194f4 15303cb 6bc94ac |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 |
from celebbot import CelebBot
import streamlit as st
from streamlit_mic_recorder import speech_to_text
from utils import *
# Britannica biography slugs that cannot be derived by hyphenating the
# display name (these pages carry a disambiguating suffix in their URL).
_BRITANNICA_SLUG_OVERRIDES = {
    "Madonna": "Madonna-American-singer-and-actress",
    "Anne Hathaway": "Anne-Hathaway-American-actress",
}


def _init_session_state():
    """Seed every session-state key the app reads, keeping existing values."""
    # Built per call so each session gets its own fresh "messages" list.
    defaults = {
        "messages": [],
        "QA_model_path": "google/flan-t5-xl",
        "sentTr_model_path": "sentence-transformers/all-mpnet-base-v2",
        "start_chat": False,
        "prompt_from_audio": "",
        "prompt_from_text": "",
        "celeb_bot": None,
        "_last_audio_id": 0,
    }
    for key, value in defaults.items():
        if key not in st.session_state:
            st.session_state[key] = value


def _text_submit():
    """on_change callback: queue the typed question and clear the text box."""
    st.session_state["prompt_from_text"] = st.session_state.text_input
    st.session_state.text_input = ''


def _example_submit(text):
    """on_click callback: queue one of the example questions."""
    st.session_state["prompt_from_text"] = text


def _clear_chat_his():
    """on_change callback: drop the chat history when the celebrity changes."""
    st.session_state["messages"] = []


def _britannica_slug(celeb_name):
    """Return the britannica.com/biography URL slug for *celeb_name*."""
    return _BRITANNICA_SLUG_OVERRIDES.get(celeb_name, celeb_name.replace(" ", "-"))


def _get_celeb_bot(celeb_name, celeb_gender):
    """Return the CelebBot for the selected celebrity and models.

    Streamlit re-executes the script on every widget interaction, so the bot
    (article download, text preprocessing, model handles) is built only when
    the celebrity or one of the model paths changes, and reused otherwise.
    """
    qa_path = st.session_state["QA_model_path"]
    sent_path = st.session_state["sentTr_model_path"]
    bot_key = (celeb_name, qa_path, sent_path)
    if (st.session_state["celeb_bot"] is None
            or st.session_state.get("_celeb_bot_key") != bot_key):
        knowledge = get_article(
            f"https://www.britannica.com/biography/{_britannica_slug(celeb_name)}"
        )
        st.session_state["celeb_bot"] = CelebBot(
            celeb_name,
            celeb_gender,
            get_tokenizer(qa_path),
            get_seq2seq_model(qa_path) if "flan-t5" in qa_path else get_causal_model(qa_path),
            get_tokenizer(sent_path),
            get_auto_model(sent_path),
            *preprocess_text(celeb_name, knowledge, "en_core_web_lg"),
        )
        # Recorded only after a successful build, so a failure retries next rerun.
        st.session_state["_celeb_bot_key"] = bot_key
    return st.session_state["celeb_bot"]


def _chat_history(messages):
    """Format the previous question/answer pair, or return None if there is none.

    Called right after the new user message is appended, so with alternating
    user/assistant entries messages[-3] is the previous question and
    messages[-2] its answer.
    """
    if len(messages) < 3:
        return None
    return "Question: {q}\n\nAnswer: {a}\n\n".format(
        q=messages[-3]["content"], a=messages[-2]["content"]
    )


def _assistant_markdown(response, b64):
    """Build the HTML for an assistant reply: the escaped text plus an audio player."""
    import html

    # The reply is model output influenced by user input and is rendered with
    # unsafe_allow_html, so escape it to keep it from injecting markup.
    return (
        "\n"
        f"<p>{html.escape(str(response))}</p>\n"
        '<audio controls controlsList="autoplay nodownload">\n'
        f'<source src="data:audio/wav;base64,{b64}" type="audio/wav">\n'
        "Your browser does not support the audio element.\n"
        "</audio>\n"
    )


def main():
    """Render the CelebChat page and answer the pending question, if any."""
    st.set_page_config(initial_sidebar_state="expanded")
    hide_footer()
    model_list = ["flan-t5-xl"]
    celeb_data = get_celeb_data('data.json')

    st.sidebar.header("CelebChat")
    with st.sidebar.expander('About the app'):
        st.markdown("Experience the ultimate celebrity chats with this app!")
    with st.sidebar.expander('Disclaimer'):
        st.markdown("""
* CelebChat may produce inaccurate information about people, places, or facts.
* If you have any concerns about your privacy or believe that the app infringes on your rights, please contact me at liuhaozhe2000@gmail.com. I am committed to addressing your concerns and taking any necessary corrective actions.
""")

    _init_session_state()

    st.sidebar.selectbox('Choose your celebrity crush', key="celeb_name",
                         options=sorted(celeb_data), on_change=_clear_chat_his)
    model_id = st.sidebar.selectbox("Choose Your model", options=model_list)
    st.session_state["QA_model_path"] = f"google/{model_id}" if "flan-t5" in model_id else model_id

    celeb_name = st.session_state["celeb_name"]
    bot = _get_celeb_bot(celeb_name, celeb_data[celeb_name]["gender"])

    dialogue_container = st.container()
    with dialogue_container:
        for message in st.session_state["messages"]:
            with st.chat_message(message["role"]):
                st.markdown(message["content"])

    with st.sidebar:
        st.write("You can record your question...")
        st.session_state["prompt_from_audio"] = speech_to_text(
            start_prompt="Start Recording", stop_prompt="Stop Recording",
            language='en', use_container_width=True, just_once=True, key='STT',
        )
        st.text_input('Or text something...', key='text_input', on_change=_text_submit)
        st.write("Example questions:")
        for example in ("Hello! Did you win an Oscar?",
                        "Hi! What is your profession?",
                        "Can you tell me about your family background?"):
            st.button(example, on_click=_example_submit, args=[example])

    # A recorded transcript takes precedence; when there is none (None) or it
    # is empty, fall back to the typed/example text. Always bound, so no
    # NameError when neither source produced anything.
    prompt = st.session_state["prompt_from_audio"] or st.session_state["prompt_from_text"]
    if not prompt:
        return

    bot.text = prompt
    # Display user message in chat message container and add it to history.
    with dialogue_container:
        st.chat_message("user").markdown(prompt)
    st.session_state["messages"].append({"role": "user", "content": prompt})

    chat_his = _chat_history(st.session_state["messages"])
    if chat_his is None:
        response = bot.question_answer()
    else:
        response = bot.question_answer(chat_his=chat_his)

    # Autoplay is disabled; playback goes through the HTML <audio> element.
    b64 = bot.text_to_speech(autoplay=False)
    with dialogue_container:
        st.chat_message("assistant").markdown(
            _assistant_markdown(response, b64),
            unsafe_allow_html=True,
        )
    st.session_state["messages"].append({"role": "assistant", "content": response})

    # Consume the prompt so it is not answered again on the next rerun.
    st.session_state["prompt_from_audio"] = ""
    st.session_state["prompt_from_text"] = ""


if __name__ == "__main__":
    main()
|