File size: 13,335 Bytes
00e4075
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
###
# - Author: Jaelin Lee, Abhishek Dutta
# - Date: Mar 23, 2024
# - Description: Streamlit UI for a mental health support chatbot using sentiment analysis, RL, BM25/ChromaDB, and an LLM.

# - Note:
#   - Updated the UI to show the predicted mental health condition in the "Behind the Scene" section regardless of positive/negative sentiment.
###

from dotenv import load_dotenv, find_dotenv
import pandas as pd
import streamlit as st
from q_learning_chatbot import QLearningChatbot
from xgb_mental_health import MentalHealthClassifier
from bm25_retreive_question import QuestionRetriever as QuestionRetriever_bm25
from Chromadb_storage_JyotiNigam import QuestionRetriever as QuestionRetriever_chromaDB
from llm_response_generator import LLLResponseGenerator
import os
from llama_guard import moderate_chat, get_category_name

from gtts import gTTS
from io import BytesIO
from streamlit_mic_recorder import speech_to_text

import re

# Streamlit UI
st.title("MindfulMedia Mentor")

# Sentiment buckets used as the RL states, from most negative to most positive.
states = [
    "Negative",
    "Moderately Negative",
    "Neutral",
    "Moderately Positive",
    "Positive",
]
# Response tones the Q-learning agent can choose between.
actions = ["encouragement", "empathy", "spiritual"]

# Keep a single Q-learning chatbot per browser session. Streamlit re-executes
# this whole script on every interaction, so a plain module-level instance
# would be rebuilt before each message, discarding whatever state it
# accumulates (its q_values, and whatever update_mood_history() records).
if "chatbot" not in st.session_state:
    st.session_state.chatbot = QLearningChatbot(states, actions)
chatbot = st.session_state.chatbot

# Initialize MentalHealthClassifier
# NOTE(review): the classifier is still rebuilt on every rerun; consider
# caching it as well if its construction is expensive.
data_path = os.path.join("data", "processed", "data.csv")
print(data_path)

tokenizer_model_name = "nlptown/bert-base-multilingual-uncased-sentiment"
mental_classifier_model_path = os.path.join("app", "mental_health_model.pkl")
mental_classifier = MentalHealthClassifier(data_path, mental_classifier_model_path)


# Function to display Q-table
def display_q_table(q_values, states, actions):
    """Build a DataFrame view of the Q-table for display.

    Args:
        q_values: 2-D array indexed as ``[state, action]``; must support
            NumPy-style ``[:, i]`` column slicing.
        states: Row labels, placed in the leading ``"State"`` column.
        actions: Action names; column ``i`` of ``q_values`` becomes the
            DataFrame column named ``actions[i]``.

    Returns:
        A ``pandas.DataFrame`` with a ``"State"`` column followed by one
        column per action, in the order given.
    """
    columns = {"State": states}
    columns.update(
        (action_name, q_values[:, column_index])
        for column_index, action_name in enumerate(actions)
    )
    return pd.DataFrame(columns)


def text_to_speech(text, lang="en"):
    """Synthesize *text* to speech with gTTS and return it as an in-memory stream.

    Args:
        text: The text to speak.
        lang: gTTS language code; defaults to ``"en"``.

    Returns:
        An ``io.BytesIO`` holding the audio produced by gTTS (the callers in
        this file play it as ``audio/mp3``), rewound to position 0 so it can
        be read directly.
    """
    tts = gTTS(text=text, lang=lang)
    fp = BytesIO()
    tts.write_to_fp(fp)
    # write_to_fp leaves the cursor at the end of the buffer; rewind so a
    # consumer calling .read() gets the audio instead of an empty stream.
    fp.seek(0)
    return fp


def speech_recognition_callback():
    """Move a finished speech transcript into ``st.session_state.speech_input``.

    Registered as the callback of the ``speech_to_text`` widget (key
    ``"my_stt"``), whose result is read from ``st.session_state.my_stt_output``.
    When a transcript is present, any earlier ``p01_error_message`` is cleared
    and the transcript is stored. When it is ``None``, an error message is set
    and ``speech_input`` is left unchanged.
    """
    transcript = st.session_state.my_stt_output
    if transcript is not None:
        st.session_state.p01_error_message = None
        st.session_state.speech_input = transcript
    else:
        st.session_state.p01_error_message = "Please record your response again."


# Matches one ``<...>`` tag. Using ``[^>]`` instead of ``.`` lets a match span
# newlines, so tags whose attributes wrap onto several lines are removed too.
_HTML_TAG_PATTERN = re.compile(r"<[^>]*>")


def remove_html_tags(text):
    """Strip HTML/XML-style tags from *text*.

    Every ``<...>`` span is removed, including tags broken across lines. Text
    outside the tags is kept unchanged.

    Args:
        text: String to clean (in this app, the raw LLM response).

    Returns:
        *text* with all tag spans removed.
    """
    return _HTML_TAG_PATTERN.sub("", text)


# Initialize memory
def _init_session_state():
    """Seed per-session memory with defaults on the first script run.

    Streamlit reruns this script on every interaction; keys that already
    exist in ``st.session_state`` are left untouched so the conversation
    state survives between reruns.
    """
    defaults = {
        "entered_text": [],
        "entered_mood": [],
        "messages": [],
        "user_sentiment": "Neutral",
        "mood_trend": "Unchanged",
        "predicted_mental_category": "",
        "ai_tone": "Empathy",
        "mood_trend_symbol": "",
        "show_question": False,
        "asked_questions": [],
        # LlamaGuard moderation is enabled by default.
        "llama_guard_enabled": True,
    }
    for state_key, default_value in defaults.items():
        if state_key not in st.session_state:
            st.session_state[state_key] = default_value


_init_session_state()

# Select Question Retriever: map each sidebar option to its retriever class.
_RETRIEVER_FACTORIES = {
    "BM25": QuestionRetriever_bm25,
    "ChromaDB": QuestionRetriever_chromaDB,
}
selected_retriever_option = st.sidebar.selectbox(
    "Choose Question Retriever", tuple(_RETRIEVER_FACTORIES)
)
if selected_retriever_option in _RETRIEVER_FACTORIES:
    # Runs on every Streamlit rerun, so a fresh retriever is built each time.
    retriever = _RETRIEVER_FACTORIES[selected_retriever_option]()

# Redraw the stored conversation; Streamlit rebuilds the page on each rerun.
for history_entry in st.session_state.messages:
    entry_role = history_entry.get("role")
    entry_content = history_entry.get("content")
    with st.chat_message(entry_role):
        st.write(entry_content)

# Whether the "Behind the Scene" sidebar expander opens expanded.
section_visible = True

# Collect this run's user message, either typed or dictated, based on the
# input mode chosen in the sidebar.
input_mode = st.sidebar.radio("Select input mode:", ["Text", "Speech"])
user_message = None
if input_mode != "Speech":
    user_message = st.chat_input("Type your message here:")
else:
    # speech_recognition_callback copies the transcript into
    # st.session_state.speech_input once recording finishes.
    speech_input = speech_to_text(key="my_stt", callback=speech_recognition_callback)
    has_pending_speech = (
        "speech_input" in st.session_state and st.session_state.speech_input
    )
    if has_pending_speech:
        user_message = st.session_state.speech_input
        # Consume the transcript so the next rerun does not resend it.
        st.session_state.speech_input = None


# LlamaGuard toggle: its initial value comes from session state, and the
# explicit key gives the widget a stable identity across reruns.
llama_guard_enabled = st.sidebar.checkbox(
    "Enable LlamaGuard",
    key="llama_guard_toggle",
    value=st.session_state["llama_guard_enabled"],
)

# Mirror the widget's value back so later code reads it from one place.
st.session_state["llama_guard_enabled"] = llama_guard_enabled

# Take user input
if user_message:
    # Record the user's turn and echo it in the transcript.
    st.session_state.entered_text.append(user_message)

    st.session_state.messages.append({"role": "user", "content": user_message})
    with st.chat_message("user"):
        st.write(user_message)

    # Optionally screen the message with Llama Guard before any processing.
    is_safe = True
    if st.session_state["llama_guard_enabled"]:
        guard_status, error = moderate_chat(user_message)
        if error:
            # NOTE(review): when moderation fails the message is still treated
            # as safe (fail-open); confirm this is the intended policy.
            st.error(f"Failed to retrieve data from Llama Guard: {error}")
        else:
            if "unsafe" in guard_status[0]["generated_text"]:
                is_safe = False
                unsafe_category_name = get_category_name(
                    guard_status[0]["generated_text"]
                )

    if not is_safe:
        # Decline, point the user toward human help, and read the reply aloud.
        response = f"I see you are asking something about {unsafe_category_name}. Due to ethical and safety reasons, I can't provide the help you need. Please reach out to someone who can, like a family member, friend, or therapist. In urgent situations, contact emergency services or a crisis hotline. Remember, asking for help is brave, and you're not alone."
        st.session_state.messages.append({"role": "ai", "content": response})
        with st.chat_message("ai"):
            st.markdown(response)
        speech_fp = text_to_speech(response)
        # Play the speech
        st.audio(speech_fp, format="audio/mp3")
    else:
        with st.spinner("Processing..."):
            # Detect mental condition
            mental_classifier.initialize_tokenizer(tokenizer_model_name)
            mental_classifier.preprocess_data()
            predicted_mental_category = mental_classifier.predict_category(user_message)
            print("Predicted mental health condition:", predicted_mental_category)

            # Detect sentiment
            user_sentiment = chatbot.detect_sentiment(user_message)

            # Only negative/neutral users get a retrieved follow-up question.
            if user_sentiment in ["Negative", "Moderately Negative", "Neutral"]:
                question = retriever.get_response(
                    user_message, predicted_mental_category
                )
                show_question = True
            else:
                show_question = False
                question = ""

            # Update mood history / mood_trend
            chatbot.update_mood_history()
            mood_trend = chatbot.check_mood_trend()

            # Define rewards: an improving mood is rewarded; a falling mood
            # (or a stagnant non-positive one) is penalised.
            if user_sentiment in ["Positive", "Moderately Positive"]:
                if mood_trend == "increased":
                    reward = +1
                    mood_trend_symbol = " ⬆️"
                elif mood_trend == "unchanged":
                    reward = +0.8
                    mood_trend_symbol = ""
                else:  # decreased
                    reward = -0.2
                    mood_trend_symbol = " ⬇️"
            else:
                if mood_trend == "increased":
                    reward = +1
                    mood_trend_symbol = " ⬆️"
                elif mood_trend == "unchanged":
                    reward = -0.2
                    mood_trend_symbol = ""
                else:  # decreased
                    reward = -1
                    mood_trend_symbol = " ⬇️"

            print(
                f"mood_trend - sentiment - reward: {mood_trend} - {user_sentiment} - 🛑{reward}🛑"
            )

            # Update Q-values. The mood change observed now is the outcome of
            # the previous AI reply, so credit the reward to the (state, tone)
            # of that turn and transition into the current sentiment. The
            # session defaults ("Neutral", "Empathy") stand in on the first turn.
            # NOTE(review): assumes update_q_values(state, action, reward,
            # next_state); confirm against QLearningChatbot.
            previous_state = st.session_state.user_sentiment
            if previous_state not in states:
                previous_state = user_sentiment
            previous_action = st.session_state.ai_tone.lower()
            if previous_action not in chatbot.actions:
                previous_action = chatbot.actions[0]
            chatbot.update_q_values(
                previous_state, previous_action, reward, user_sentiment
            )

            # Get recommended action based on the updated Q-values
            ai_tone = chatbot.get_action(user_sentiment)
            print(ai_tone)

            print(st.session_state.messages)

            # LLM Response Generator
            load_dotenv(find_dotenv())

            llm_model = LLLResponseGenerator()
            temperature = 0.1
            max_length = 128

            # Collect all messages exchanged so far into a single text string
            all_messages = "\n".join(
                [message.get("content") for message in st.session_state.messages]
            )

            # Question asked to the user: {question}

            template = """INSTRUCTIONS: {context}
            
                Respond to the user with a tone of {ai_tone}. 
                
                Response by the user: {user_text}  
                Response;
                """
            context = f"You are a mental health supporting non-medical assistant. Provide some advice and ask a relevant question back to the user. {all_messages}"

            llm_response = llm_model.llm_inference(
                model_type="huggingface",
                question=question,
                prompt_template=template,
                context=context,
                ai_tone=ai_tone,
                questionnaire=predicted_mental_category,
                user_text=user_message,
                temperature=temperature,
                max_length=max_length,
            )

            llm_response = remove_html_tags(llm_response)

            # Append the retrieved question when one was selected above.
            if show_question:
                llm_response_with_question = f"{llm_response}\n\n{question}"
            else:
                llm_response_with_question = llm_response

            # Append the AI response to the chat history
            st.session_state.messages.append(
                {"role": "ai", "content": llm_response_with_question}
            )

        with st.chat_message("ai"):
            st.markdown(llm_response_with_question)
            # Convert the response to speech
            speech_fp = text_to_speech(llm_response_with_question)
            # Play the speech
            st.audio(speech_fp, format="audio/mp3")

            # Update data to memory so the sidebar (and the next turn's
            # Q-update) can read this turn's results.
            st.session_state.user_sentiment = user_sentiment
            st.session_state.mood_trend = mood_trend
            st.session_state.predicted_mental_category = predicted_mental_category
            st.session_state.ai_tone = ai_tone
            st.session_state.mood_trend_symbol = mood_trend_symbol
            st.session_state.show_question = show_question

    # "Behind the Scene" sidebar: explains what the AI did on the last safe turn.
    with st.sidebar.expander("Behind the Scene", expanded=section_visible):
        st.subheader("What AI is doing:")
        # Use the values stored in session state
        st.write(
            f"- Detected User Tone: {st.session_state.user_sentiment} ({st.session_state.mood_trend.capitalize()}{st.session_state.mood_trend_symbol})"
        )
        st.write(
            f"- Possible Mental Condition: {st.session_state.predicted_mental_category.capitalize()}"
        )
        st.write(f"- AI Tone: {st.session_state.ai_tone.capitalize()}")
        st.write(f"- Question retrieved from: {selected_retriever_option}")
        st.write(
            "- If the user feels negative, moderately negative, or neutral, at the end of the AI response, it adds a mental health condition related question. The question is retrieved from DB. The categories of questions are limited to Depression, Anxiety, ADHD, Social Media Addiction, Social Isolation, and Cyberbullying which are most associated with FOMO related to excessive social media usage."
        )
        st.write(
            "- Below q-table is continuously updated after each interaction with the user. If the user's mood increases, AI gets a reward. Else, AI gets a punishment."
        )

        # Display Q-table
        st.dataframe(display_q_table(chatbot.q_values, states, actions))