import os
import time
from openai import OpenAI
import numpy as np
import streamlit as st
import tensorflow as tf
import tensorflow_text  # noqa: F401 -- side-effect import: registers the TF Text ops the saved model relies on
# import plotly.graph_objects as go
# from dotenv import load_dotenv
from langchain_openai import OpenAI as OpenAiLC
from langchain.memory import ConversationSummaryMemory, ChatMessageHistory
from llm import sys_instruction
##############
# PAGE STYLES
# Set page title and icon
st.set_page_config(page_title="EmoInsight",
                   page_icon=":robot_face:",
                   initial_sidebar_state="expanded",)
# Custom css styles
with open('style.css') as f:
    st.markdown(f'<style>{f.read()}</style>', unsafe_allow_html=True)
# Load variables from .env file
# load_dotenv()
# Load large model
# Decorator to cache non-data objects
@st.cache_resource
def load_sentiment_analysis_model():
    model = tf.saved_model.load('one_2')
    return model
senti_model = load_sentiment_analysis_model()
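# NOTE (assumption): the SavedModel in 'one_2' is expected to take a batch of raw
# strings and return one score per emotion class, e.g.
#   scores = senti_model(["I feel great today"])  # -> shape (1, 7)
# which is why np.argmax(..., axis=1) is applied to its output in the chat handler below.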
emoji_mapping = {
    "sadness": "😢",
    "neutral": "😐",
    "joy": "😂",
    "anger": "😡",
    "fear": "😨",
    "love": "❤️",
    "surprise": "😲",
}
emotion_categories = {
    0: 'anger',
    1: 'fear',
    2: 'joy',
    3: 'love',
    4: 'neutral',
    5: 'sadness',
    6: 'surprise'
}
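# NOTE (assumption): the integer keys above must match the label order the sentiment
# model was trained with; reordering them would mislabel predictions.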
##################
# STATE VARIABLES
# set api key
if 'key' not in st.session_state:
    st.session_state.key = os.environ["API_TOKEN"]
# openai.api_key = st.session_state.key
# gpt llm
if 'llm' not in st.session_state:
    st.session_state.llm = OpenAiLC(
        temperature=0.2, openai_api_key=st.session_state.key)
# model name
if "openai_model" not in st.session_state:
    st.session_state["openai_model"] = "gpt-3.5-turbo"
# openai client
if "client" not in st.session_state:
    st.session_state["client"] = OpenAI(
        api_key=st.session_state.key
    )
# st chat history
if "message_history" not in st.session_state:
    st.session_state.message_history = []
# set instruction for gpt response
if 'sys_inst' not in st.session_state:
    st.session_state.sys_inst = sys_instruction()
# dict to store user question emotion
if 'emotion_counts' not in st.session_state:
    st.session_state.emotion_counts = {
        'anger': 0,
        'fear': 0,
        'joy': 0,
        'love': 0,
        'neutral': 0,
        'sadness': 0,
        'surprise': 0
    }
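# NOTE: emotion_counts presumably backs the `Sentiment Plot` page advertised in the
# sidebar below; that page is not part of this file.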
#######################
# LANG-CHAIN VARIABLES
# storing chat history
if 'old_summary' not in st.session_state:
    st.session_state.old_summary = 'User came to psychological assistant chatbot'
# LangChain msg history
if 'lg_msg_history' not in st.session_state:
    st.session_state.lg_msg_history = ChatMessageHistory()
# summarize old conversation
if 'memory' not in st.session_state:
    st.session_state.memory = ConversationSummaryMemory.from_messages(
        llm=st.session_state.llm,
        buffer=st.session_state.old_summary,
        return_messages=True,
        chat_memory=st.session_state.lg_msg_history)
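# ConversationSummaryMemory keeps a rolling summary of lg_msg_history; old_summary is
# refreshed every 4 messages at the bottom of the chat handler via
# memory.predict_new_summary(...).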
#############################################
# MAIN APP #
#############################################
st.sidebar.markdown('')
st.sidebar.markdown('')
st.sidebar.markdown('')
st.sidebar.success("Select the `Sentiment Plot` button to see the Emotion Graph")
st.sidebar.markdown('')
clear_chats = st.sidebar.button('Clear Chat')
if clear_chats:
    st.session_state.lg_msg_history.clear()
    st.session_state.old_summary = 'User came to psychological assistant chatbot'
    st.session_state.message_history = []
    alert = st.sidebar.warning('Chat cleared', icon='🚨')
    time.sleep(2)  # Wait for 2 seconds
    alert.empty()  # Clear the alert
st.markdown("<h1><center>EmoInsight</center></h1>",
            unsafe_allow_html=True)
# greetings
if len(st.session_state.message_history) == 0:
    # add to st history
    st.session_state.message_history.append(
        {"role": "assistant", "content": "How can I help you?"})
    # add to lg history
    # st.session_state.lg_msg_history.add_ai_message("How can I help you?")
# HISTORY
for message in st.session_state.message_history:
    if message['role'] == 'system':
        with st.chat_message("Emotion", avatar=emoji_mapping.get(message["content"])):
            a = "Sentiment: {}".format(message["content"])
            st.markdown(a)
    else:
        with st.chat_message(message["role"]):
            st.markdown(message["content"])
# CHAT BOT
if prompt := st.chat_input("What is up?"):
    # USER
    with st.chat_message("user"):
        st.markdown(prompt)
    # add to st history
    st.session_state.message_history.append(
        {"role": "user", "content": prompt})
    # add to lg history
    st.session_state.lg_msg_history.add_user_message(prompt)
    # SENTIMENT PREDICTION
    emotion = senti_model([prompt])
    # pick the highest-scoring class for the single input
    true_classes = np.argmax(emotion, axis=1)
    emotion_category = emotion_categories.get(int(true_classes[0]))
    st.session_state.emotion_counts[emotion_category] += 1
    # EMOTION
    with st.chat_message("Emotion", avatar=emoji_mapping.get(emotion_category)):
        st.write("Sentiment: {}".format(emotion_category))
    st.session_state.message_history.append(
        {"role": "system", "content": emotion_category})
    # AI BOT
    with st.chat_message("assistant"):
        message_placeholder = st.empty()
        full_response = ""
        # get response
        for chunk in st.session_state.client.chat.completions.create(
                model=st.session_state["openai_model"],
                messages=[
                    {"role": "system", "content": st.session_state.sys_inst.format(
                        history=st.session_state.old_summary)},
                    {"role": "user", "content": prompt}
                ],  # pass old chat history
                stream=True):
            # render gpt response in realtime
            if chunk.choices[0].delta.content:
                # print(chunk.choices[0].delta.content)
                full_response += chunk.choices[0].delta.content
                message_placeholder.markdown(full_response + "▌")
        message_placeholder.markdown(full_response)
    # add to st history
    st.session_state.message_history.append(
        {"role": "assistant", "content": full_response})
    # add to lg history
    st.session_state.lg_msg_history.add_ai_message(full_response)
    # Clear old chat after 4 dialogs
    # And update old summary with new summary
    chat_len = len(st.session_state.lg_msg_history.messages)
    if (chat_len >= 4) and (chat_len % 4 == 0):
        # get new summary of chat
        st.session_state.old_summary = st.session_state.memory.predict_new_summary(
            messages=st.session_state.lg_msg_history.messages,
            existing_summary=st.session_state.old_summary)
        # flush old lg-chat history
        st.session_state.lg_msg_history.clear()