CelebChat / app.py
lhzstar
initial commits
6bc94ac
raw
history blame
No virus
4.5 kB
from celebbot import CelebBot
import streamlit as st
import re
import spacy
import json
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, AutoModel
from utils import *
@st.cache_resource
def get_seq2seq_model(model_id):
    """Load the sequence-to-sequence model identified by ``model_id``.

    ``model_id`` is handed unchanged to
    ``AutoModelForSeq2SeqLM.from_pretrained``. The function is wrapped in
    ``st.cache_resource`` so the loaded model can be reused rather than
    rebuilt on every call with the same identifier.
    """
    seq2seq_model = AutoModelForSeq2SeqLM.from_pretrained(model_id)
    return seq2seq_model
@st.cache_resource
def get_auto_model(model_id):
    """Load the generic ``AutoModel`` identified by ``model_id``.

    ``model_id`` is handed unchanged to ``AutoModel.from_pretrained``. The
    function is wrapped in ``st.cache_resource`` so the loaded model can be
    reused rather than rebuilt on every call with the same identifier.
    """
    auto_model = AutoModel.from_pretrained(model_id)
    return auto_model
@st.cache_resource
def get_tokenizer(model_id):
    """Load the tokenizer that belongs to the model ``model_id``.

    ``model_id`` is handed unchanged to ``AutoTokenizer.from_pretrained``.
    The function is wrapped in ``st.cache_resource`` so the tokenizer can be
    reused rather than rebuilt on every call with the same identifier.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    return tokenizer
@st.cache_data
def get_celeb_data(fpath):
    """Read and parse the celebrity data file at ``fpath``.

    Args:
        fpath: Path of a JSON file. ``main`` indexes the result by celebrity
            name and reads the ``"gender"`` and ``"knowledge"`` keys of each
            entry.

    Returns:
        The parsed JSON document as returned by ``json.load``.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the file is not valid JSON.
    """
    # JSON is UTF-8 by specification; without an explicit encoding, open()
    # falls back to the platform locale and can mis-decode non-ASCII text
    # such as the typographic apostrophes this app looks for.
    with open(fpath, encoding="utf-8") as json_file:
        return json.load(json_file)
@st.cache_resource
def preprocess_text(name, gender, text, model_id):
    """Rewrite third-person text about ``name`` into first person and split it.

    Pronouns are replaced first (only for gender ``"M"`` or ``"F"``), then
    possessive forms of the full name and last name become ``"my"``, and the
    remaining occurrences of the full name and last name become ``"I"``.
    The result is segmented into sentences with the spaCy pipeline
    ``model_id``.

    Args:
        name: The person's full name; the last whitespace-separated token is
            treated as the last name.
        gender: ``"M"`` or ``"F"`` selects which pronoun patterns are
            applied; any other value leaves pronouns untouched.
        text: The text to rewrite.
        model_id: Name of the spaCy pipeline passed to ``spacy.load``.

    Returns:
        A ``(spacy_model, texts)`` tuple: the loaded spaCy pipeline and the
        list of stripped sentence strings.

    Note:
        ``he_regex``, ``his_regex``, ``she_regex`` and ``her_regex`` are not
        defined in this block; presumably they come from ``utils`` via the
        wildcard import -- confirm there.
    """

    def _whole_word(literal, suffix=""):
        # The name is escaped so characters such as "." or "(" are matched
        # literally. Lookarounds replace ``\b`` because ``\b`` cannot match
        # after a trailing apostrophe that is followed by a space, which
        # made possessives like "Williams’ book" impossible to match.
        return re.compile(rf"(?<!\w){re.escape(literal)}{suffix}(?!\w)")

    def _possessive(literal):
        # Names ending in "s" take a bare apostrophe ("Williams’"), with the
        # "’s" spelling accepted too; all other names require "’s".
        suffix = "’s?" if literal.endswith("s") else "’s"
        return _whole_word(literal, suffix)

    if gender == "M":
        text = re.sub(he_regex, "I", text)
        text = re.sub(his_regex, "my", text)
    elif gender == "F":
        text = re.sub(she_regex, "I", text)
        text = re.sub(her_regex, "my", text)

    full_name = name.strip()
    name_parts = full_name.split()
    # A blank name would produce an empty pattern that matches at every
    # position, so name substitution is skipped entirely in that case.
    if name_parts:
        last_name = name_parts[-1]
        # Possessives go first so "Name’s" is not left as "I’s", and the
        # full name goes before the last name so it is consumed as a unit.
        text = _possessive(full_name).sub("my", text)
        text = _possessive(last_name).sub("my", text)
        text = _whole_word(full_name).sub("I", text)
        text = _whole_word(last_name).sub("I", text)

    spacy_model = spacy.load(model_id)
    texts = [sentence.text.strip() for sentence in spacy_model(text).sents]
    return spacy_model, texts
def main():
    """Render the chat page and process at most one chat turn per run.

    The function initialises the session-state defaults, replays the stored
    message history, shows the sidebar selectors for celebrity and Flan-T5
    model size, builds a ``CelebBot`` for that selection and, when the user
    has submitted a prompt, displays the prompt and the bot's answer
    (together with an audio element carrying the synthesized speech) and
    appends both to the history.
    """
    # Local import: only needed to escape the model output embedded in HTML.
    import html

    hide_footer()

    session_defaults = {
        "messages": [],
        "QA_model_path": "google/flan-t5-base",
        "sentTr_model_path": "sentence-transformers/all-mpnet-base-v2",
        "start_chat": False,
    }
    for key, default in session_defaults.items():
        if key not in st.session_state:
            st.session_state[key] = default

    model_list = ["base", "large", "xl", "xxl"]

    # Replay the conversation so far.
    for message in st.session_state["messages"]:
        with st.chat_message(message["role"]):
            st.markdown(message["content"])

    celeb_data = get_celeb_data("data.json")

    # Sidebar controls for choosing the celebrity and the model size.
    celeb_name = st.sidebar.selectbox(
        'Choose a celebrity', options=list(celeb_data.keys())
    )
    celeb_gender = celeb_data[celeb_name]["gender"]
    knowledge = celeb_data[celeb_name]["knowledge"]
    model_choice = st.sidebar.selectbox(
        "Choose Your Flan-T5 model", options=model_list
    )
    st.session_state["QA_model_path"] = f"google/flan-t5-{model_choice}"

    qa_model_path = st.session_state["QA_model_path"]
    sent_tr_model_path = st.session_state["sentTr_model_path"]
    celeb_bot = CelebBot(
        celeb_name,
        get_tokenizer(qa_model_path),
        get_seq2seq_model(qa_model_path),
        get_tokenizer(sent_tr_model_path),
        get_auto_model(sent_tr_model_path),
        *preprocess_text(celeb_name, celeb_gender, knowledge, "en_core_web_sm"),
    )

    prompt = st.chat_input("Say something")
    if not prompt:
        return

    celeb_bot.text = prompt
    # Display the user message in a chat message container.
    st.chat_message("user").markdown(prompt)
    # Add the user message to the chat history.
    st.session_state["messages"].append({"role": "user", "content": prompt})

    # Generate the assistant response.
    response = celeb_bot.question_answer()
    # Disable autoplay here; playback is driven by the <audio> tag below.
    b64 = celeb_bot.text_to_speech(autoplay=False)

    # The markup is rendered with unsafe_allow_html=True, so the generated
    # answer is escaped to keep any "<", ">" or "&" from being interpreted
    # as HTML.
    md = (
        "\n"
        f"<p>{html.escape(str(response))}</p>\n"
        '<audio controls autoplay style="display:none;">\n'
        f'<source src="data:audio/wav;base64,{b64}" type="audio/wav">\n'
        "Your browser does not support the audio element.\n"
        "</audio>\n"
    )
    # Display the assistant response in a chat message container.
    st.chat_message("assistant").markdown(
        md,
        unsafe_allow_html=True,
    )
    # Add the (unescaped) assistant response to the chat history; the replay
    # loop renders it without unsafe_allow_html.
    st.session_state["messages"].append({"role": "assistant", "content": response})
# Start the app only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()