File size: 6,433 Bytes
6bc94ac
 
436ce71
6bc94ac
 
 
 
edcdcdb
6bc94ac
98ad652
436ce71
 
 
 
 
aafa95b
e916883
 
 
aafa95b
 
 
 
6bc94ac
 
 
fb6ade2
6bc94ac
 
 
 
15303cb
 
 
 
fb6ade2
 
15303cb
 
325f09c
 
 
 
 
15303cb
8bc0535
 
 
 
aafa95b
e916883
15303cb
436ce71
15303cb
db5ef00
aafa95b
 
 
 
 
 
 
 
15303cb
db5ef00
15303cb
abca9bf
15303cb
 
db5ef00
15303cb
436ce71
15303cb
 
 
 
 
 
 
 
 
325f09c
15303cb
aafa95b
325f09c
 
 
 
 
 
 
 
15303cb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8bc0535
 
 
 
 
6bc94ac
15303cb
a0194f4
15303cb
 
fb6ade2
a0194f4
 
 
15303cb
 
 
 
 
 
 
 
 
 
 
6bc94ac
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
from celebbot import CelebBot
import streamlit as st
from streamlit_mic_recorder import speech_to_text
from utils import *


# Britannica biography slugs that do not follow the default "First-Last" pattern.
# (Presumably the plain-name URL resolves to a different page — TODO confirm.)
_BRITANNICA_SLUG_OVERRIDES = {
    "Madonna": "Madonna-American-singer-and-actress",
    "Anne Hathaway": "Anne-Hathaway-American-actress",
}


def _britannica_slug(celeb_name):
    """Return the britannica.com/biography/ path segment for ``celeb_name``.

    Names listed in ``_BRITANNICA_SLUG_OVERRIDES`` use their explicit slug;
    every other name has each space replaced with a hyphen.
    """
    default_slug = "-".join(celeb_name.split(" "))
    return _BRITANNICA_SLUG_OVERRIDES.get(celeb_name, default_slug)


def _init_session_state():
    """Seed Streamlit session state with defaults for keys not set yet.

    The defaults dict is rebuilt on every call. Each session therefore gets
    its own fresh ``messages`` list instead of sharing one mutable object.
    """
    defaults = {
        "messages": [],
        "QA_model_path": "google/flan-t5-xl",
        "sentTr_model_path": "sentence-transformers/all-mpnet-base-v2",
        "start_chat": False,
        "prompt_from_audio": "",
        "prompt_from_text": "",
        "celeb_bot": None,
        # Not read anywhere in this file. It is presumably consumed by
        # streamlit_mic_recorder or utils — TODO confirm before removing.
        "_last_audio_id": 0,
    }
    for key, value in defaults.items():
        if key not in st.session_state:
            st.session_state[key] = value


def _assistant_markdown(response, b64_audio):
    """Build the HTML for one assistant turn: the reply text plus an audio player.

    Args:
        response: Text answer produced by the bot. It is model-generated, so it
            is HTML-escaped before being embedded in markup that is rendered
            with ``unsafe_allow_html=True``.
        b64_audio: Base64-encoded WAV audio of the spoken reply.

    Returns:
        The HTML string passed to ``st.chat_message(...).markdown``.
    """
    import html

    # Newlines become <br>. Otherwise a blank line in the reply would end the
    # markdown HTML block, and the tags that follow would render as text.
    safe_text = html.escape(str(response)).replace("\n", "<br>")
    return "\n".join([
        f"<p>{safe_text}</p>",
        '<audio controls controlsList="autoplay nodownload">',
        f'<source src="data:audio/wav;base64,{b64_audio}" type="audio/wav">',
        "Your browser does not support the audio element.",
        "</audio>",
    ])


def main():
    """Render the CelebChat page and answer the pending prompt, if any.

    On every Streamlit rerun this function does the following:
    - Configures the page and sidebar.
    - Builds a ``CelebBot`` for the selected celebrity and model.
    - Replays the chat history.
    - Collects a prompt from speech, the text box, or an example button.
    - When a prompt is present, appends the question and the bot's answer
      (text plus audio) to the history, then clears the pending prompts.
    """
    st.set_page_config(initial_sidebar_state="expanded")
    hide_footer()
    model_list = ["flan-t5-xl"]
    celeb_data = get_celeb_data("data.json")

    st.sidebar.header("CelebChat")
    with st.sidebar.expander('About the app'):
        st.markdown("Experience the ultimate celebrity chats with this app!")

    with st.sidebar.expander('Disclaimer'):
        st.markdown("""
                    * CelebChat may produce inaccurate information about people, places, or facts.
                    * If you have any concerns about your privacy or believe that the app infringes on your rights, please contact me at liuhaozhe2000@gmail.com. I am committed to addressing your concerns and taking any necessary corrective actions.
                    """)

    _init_session_state()

    # Widget callbacks. They run before the next rerun of the script body.
    def text_submit():
        # Move the typed text into the pending prompt and clear the input box.
        st.session_state["prompt_from_text"] = st.session_state.text_input
        st.session_state.text_input = ''

    def example_submit(text):
        st.session_state["prompt_from_text"] = text

    def clear_chat_his():
        # Switching celebrity starts a fresh conversation.
        st.session_state["messages"] = []

    st.sidebar.selectbox('Choose your celebrity crush', key="celeb_name", options=sorted(celeb_data.keys()), on_change=clear_chat_his)
    model_id = st.sidebar.selectbox("Choose Your model", options=model_list)

    # flan-t5 checkpoints live under the "google/" namespace on the model hub.
    st.session_state["QA_model_path"] = f"google/{model_id}" if "flan-t5" in model_id else model_id

    celeb_name = st.session_state["celeb_name"]
    celeb_gender = celeb_data[celeb_name]["gender"]
    qa_path = st.session_state["QA_model_path"]
    sent_path = st.session_state["sentTr_model_path"]

    # NOTE(review): the article fetch and all model loads run on every rerun.
    # This is only cheap if the utils helpers are cached — TODO confirm.
    knowledge = get_article(f"https://www.britannica.com/biography/{_britannica_slug(celeb_name)}")
    qa_tokenizer = get_tokenizer(qa_path)
    if "flan-t5" in qa_path:
        qa_model = get_seq2seq_model(qa_path, _tokenizer=qa_tokenizer)
    else:
        qa_model = get_causal_model(qa_path)
    st.session_state["celeb_bot"] = CelebBot(
        celeb_name,
        celeb_gender,
        qa_tokenizer,
        qa_model,
        get_tokenizer(sent_path),
        get_auto_model(sent_path),
        *preprocess_text(celeb_name, knowledge, "en_core_web_lg"),
    )
    bot = st.session_state["celeb_bot"]

    # Replay the conversation so far.
    dialogue_container = st.container()
    with dialogue_container:
        for message in st.session_state["messages"]:
            with st.chat_message(message["role"]):
                st.markdown(message["content"])

    with st.sidebar:
        st.write("You can record your question...")
        st.session_state["prompt_from_audio"] = speech_to_text(start_prompt="Start Recording",stop_prompt="Stop Recording",language='en',use_container_width=True, just_once=True,key='STT')
        st.text_input('Or text something...', key='text_input', on_change=text_submit)
        st.write("Example questions:")

        for example in (
            "Hello! Did you win an Oscar?",
            "Hi! What is your profession?",
            "Can you tell me about your family background?",
        ):
            st.button(example, on_click=example_submit, args=[example])

    # Spoken input wins. An empty or None transcript falls back to the text
    # prompt, so `prompt` is always bound.
    prompt = st.session_state["prompt_from_audio"] or st.session_state["prompt_from_text"]
    if not prompt:
        return

    bot.text = prompt
    # Display the user message in the chat container and record it.
    with dialogue_container:
        st.chat_message("user").markdown(prompt)
    messages = st.session_state["messages"]
    messages.append({"role": "user", "content": prompt})

    # With at least one earlier exchange, pass the previous Q/A pair as context.
    # messages[-1] is the prompt just appended.
    if len(messages) < 3:
        response = bot.question_answer()
    else:
        chat_his = f"Question: {messages[-3]['content']}\n\nAnswer: {messages[-2]['content']}\n\n"
        response = bot.question_answer(chat_his=chat_his)

    # Disable autoplay so playback is driven by the HTML <audio> element.
    b64 = bot.text_to_speech(autoplay=False)
    # Display the assistant response in the chat container.
    with dialogue_container:
        st.chat_message("assistant").markdown(
            _assistant_markdown(response, b64),
            unsafe_allow_html=True,
        )
    # Add the assistant response to the chat history (text only; no audio).
    messages.append({"role": "assistant", "content": response})

    # Consume the pending prompts so they are not answered again on the next rerun.
    st.session_state["prompt_from_audio"] = ""
    st.session_state["prompt_from_text"] = ""


# Run the app only when this file is executed as the main script, not when imported.
if __name__ == "__main__":
    main()