from celebbot import CelebBot
import streamlit as st
import re
import spacy
import json
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, AutoModel
from utils import *  # provides hide_footer() and the pronoun regexes (he_regex, his_regex, she_regex, her_regex) used below


@st.cache_resource
def get_seq2seq_model(model_id):
    return AutoModelForSeq2SeqLM.from_pretrained(model_id)

@st.cache_resource
def get_auto_model(model_id):
    return AutoModel.from_pretrained(model_id)

@st.cache_resource
def get_tokenizer(model_id):
    return AutoTokenizer.from_pretrained(model_id)

@st.cache_data
def get_celeb_data(fpath):
    with open(fpath) as json_file:
        return json.load(json_file)
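
# NOTE: data.json is assumed (based on how it is read in main() below) to map each
# celebrity name to a dict with at least "gender" ("M"/"F") and "knowledge"
# (a biography text) keys; a hypothetical entry:
#   {"Some Celebrity": {"gender": "F", "knowledge": "Some Celebrity is a singer. She ..."}}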

@st.cache_resource
def preprocess_text(name, gender, text, model_id):
    """Rewrite the celebrity's third-person knowledge text into first person and split it into sentences.

    Assumes `name` contains no regex metacharacters (e.g. "Taylor Swift").
    Returns the loaded spaCy model and the list of rewritten sentences.
    """
    lname = name.split(" ")[-1]
    lname_regex = re.compile(rf'\b({lname})\b')
    name_regex = re.compile(rf'\b({name})\b')
    # Possessive forms ("Swift’s", "Taylor Swift’s") must be replaced before the bare names below.
    lnames = lname+"’s" if not lname.endswith("s") else lname+"’"
    lnames_regex = re.compile(rf'\b({lnames})\b')
    names = name+"’s" if not name.endswith("s") else name+"’"
    names_regex = re.compile(rf'\b({names})\b')
    # Pronoun regexes (he_regex, his_regex, she_regex, her_regex) come from utils.
    if gender == "M":
        text = re.sub(he_regex, "I", text)
        text = re.sub(his_regex, "my", text)
    elif gender == "F":
        text = re.sub(she_regex, "I", text)
        text = re.sub(her_regex, "my", text)
    text = re.sub(names_regex, "my", text)
    text = re.sub(lnames_regex, "my", text)
    text = re.sub(name_regex, "I", text)
    text = re.sub(lname_regex, "I", text)
    # Split the rewritten text into sentences with spaCy.
    spacy_model = spacy.load(model_id)
    texts = [i.text.strip() for i in spacy_model(text).sents]
    return spacy_model, texts
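
# Sketch of the substitutions performed above for a hypothetical "Taylor Swift" / "F" entry
# (the pronoun regexes themselves live in utils, so their exact patterns are not shown here):
#   "She"            -> "I"     (she_regex)
#   "her"            -> "my"    (her_regex)
#   "Taylor Swift’s" -> "my"    (names_regex)
#   "Swift’s"        -> "my"    (lnames_regex)
#   "Taylor Swift"   -> "I"     (name_regex)
#   "Swift"          -> "I"     (lname_regex)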

def main():
    hide_footer()
    if "messages" not in st.session_state:
        st.session_state["messages"] = []
    if "QA_model_path" not in st.session_state:          
        st.session_state["QA_model_path"] = "google/flan-t5-base"
    if "sentTr_model_path" not in st.session_state:          
        st.session_state["sentTr_model_path"] = "sentence-transformers/all-mpnet-base-v2"
    if "start_chat" not in st.session_state:          
        st.session_state["start_chat"] = False


    model_list = ["base", "large", "xl", "xxl"]

    # Replay the conversation history stored in session state
    for message in st.session_state["messages"]:
        with st.chat_message(message["role"]):
            st.markdown(message["content"])

    celeb_data = get_celeb_data('data.json')

    # Sidebar widgets for choosing the celebrity and the Flan-T5 model size
    celeb_name = st.sidebar.selectbox('Choose a celebrity', options=list(celeb_data.keys()))
    celeb_gender = celeb_data[celeb_name]["gender"]
    knowledge = celeb_data[celeb_name]["knowledge"]
    model_choice = st.sidebar.selectbox("Choose Your Flan-T5 model", options=model_list)
    st.session_state["QA_model_path"] = f"google/flan-t5-{model_choice}"

    # Instantiate the bot with the QA (Flan-T5) and sentence-transformer models
    # plus the celebrity's preprocessed, first-person knowledge text
    celeb_bot = CelebBot(celeb_name,
                         get_tokenizer(st.session_state["QA_model_path"]),
                         get_seq2seq_model(st.session_state["QA_model_path"]),
                         get_tokenizer(st.session_state["sentTr_model_path"]),
                         get_auto_model(st.session_state["sentTr_model_path"]),
                         *preprocess_text(celeb_name, celeb_gender, knowledge, "en_core_web_sm")
                         )

    prompt = st.chat_input("Say something")
    if prompt:
        celeb_bot.text = prompt
        # Display user message in chat message container
        st.chat_message("user").markdown(prompt)
        # Add user message to chat history
        st.session_state["messages"].append({"role": "user", "content": prompt})

        # Generate the assistant's response
        response = celeb_bot.question_answer()

        # Synthesise speech without server-side playback; the hidden HTML <audio>
        # element below autoplays the base64-encoded WAV in the browser instead
        b64 = celeb_bot.text_to_speech(autoplay=False)
        md = f"""
        <p>{response}</p>
        <audio controls autoplay style="display:none;">
        <source src="data:audio/wav;base64,{b64}" type="audio/wav">
        Your browser does not support the audio element.
        </audio>
        """
        # Display the assistant response (text plus hidden audio player) in a chat message container
        st.chat_message("assistant").markdown(
            md,
            unsafe_allow_html=True,
        )
        # Add assistant response to chat history
        st.session_state["messages"].append({"role": "assistant", "content": response})


if __name__ == "__main__":
    main()
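
# Launch locally with (assuming this module is saved as app.py):
#   streamlit run app.py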