import os
import time
import uuid
from datetime import datetime

from dotenv import load_dotenv
import streamlit as st
from streamlit.runtime.scriptrunner import RerunException, StopException
from streamlit.runtime.caching import cache_data
from openai import OpenAI
from pymongo import MongoClient
from pinecone import Pinecone

# Load environment variables
load_dotenv()
# Configuration
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
MONGODB_URI = os.getenv("MONGODB_URI")
PINECONE_API_KEY = os.getenv("PINECONE_API_KEY")
PINECONE_ENVIRONMENT = os.getenv("PINECONE_ENVIRONMENT")
PINECONE_INDEX_NAME = os.getenv("PINECONE_INDEX_NAME")
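# For reference, the local .env file is expected to define the variables read
# above. The values below are placeholders only (names come from the os.getenv
# calls, not from any real credentials):
#
#   OPENAI_API_KEY=<your-openai-key>
#   MONGODB_URI=<your-mongodb-connection-string>
#   PINECONE_API_KEY=<your-pinecone-key>
#   PINECONE_ENVIRONMENT=<your-pinecone-environment>
#   PINECONE_INDEX_NAME=<your-pinecone-index>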
# Initialize clients
openai_client = OpenAI(api_key=OPENAI_API_KEY)
mongo_client = MongoClient(MONGODB_URI)
db = mongo_client["Wall_Street"]
conversation_history = db["conversation_history"]
global_common_memory = db["global_common_memory"]  # Global common memory collection
chat_history = db["chat_history"]  # Collection for storing all chats

# Initialize Pinecone
pc = Pinecone(api_key=PINECONE_API_KEY)
pinecone_index = pc.Index(PINECONE_INDEX_NAME)
# Set up Streamlit page configuration
st.set_page_config(page_title="GPT-Driven Chat System - User", page_icon="💬", layout="wide")

# Custom CSS to improve the UI
st.markdown("""
    <style>
        /* Your custom CSS styles */
    </style>
""", unsafe_allow_html=True)
# Initialize Streamlit session state
if 'chat_history' not in st.session_state:
    st.session_state['chat_history'] = []
if 'user_type' not in st.session_state:
    st.session_state['user_type'] = None
if 'session_id' not in st.session_state:
    st.session_state['session_id'] = str(uuid.uuid4())
# --- Common Memory Functions ---
@cache_data(ttl=300)  # Cache for 5 minutes
def get_global_common_memory():
    """Retrieve the global common memory."""
    memory_doc = global_common_memory.find_one({"memory_id": "global_common_memory_id"})
    return memory_doc.get('memory', []) if memory_doc else []
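
# NOTE: view_full_chat() further down calls clear_global_common_memory(), which is
# not defined in this file (it presumably lives on the operator/admin page). A
# minimal sketch, assuming the global memory document keeps its entries in the
# same 'memory' array that get_global_common_memory() reads:
def clear_global_common_memory():
    """Empty the global common memory document (assumed document shape)."""
    global_common_memory.update_one(
        {"memory_id": "global_common_memory_id"},
        {"$set": {"memory": []}},
        upsert=True,
    )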
# --- Relevant Context Retrieval ---
@cache_data(ttl=60)  # Cache for 1 minute
def get_relevant_context(query, top_k=3):
    """
    Retrieve relevant context from Pinecone based on the user query.
    """
    try:
        query_embedding = openai_client.embeddings.create(
            model="text-embedding-3-large",  # Updated to use the larger model
            input=query
        ).data[0].embedding
        results = pinecone_index.query(vector=query_embedding, top_k=top_k, include_metadata=True)
        contexts = [item['metadata']['text'] for item in results['matches']]
        return " ".join(contexts)
    except Exception as e:
        print(f"Error retrieving context: {str(e)}")
        return ""
# --- GPT Response Function ---
def get_gpt_response(prompt, context=""):
    """
    Generates a response from the GPT model based on the user prompt, retrieved context, and global chat memory.
    Assesses confidence and marks the response as 'uncertain' if confidence is low.
    """
    try:
        print(prompt)
        # print(context)
        common_memory = get_global_common_memory()
        system_message = (
            "You are a helpful assistant. Use the following context and global chat memory to inform your responses, "
            "but don't mention them explicitly unless directly relevant to the user's question. "
            "If you are uncertain about the answer, respond with 'I am not sure about that.'"
        )
        if common_memory:
            # Join the memory items into a single string
            memory_str = "\n".join(common_memory)
            system_message += f"\n\nGlobal Chat Memory:\n{memory_str}"
        print(system_message)
        messages = [
            {"role": "system", "content": system_message},
            {"role": "user", "content": f"Context: {context}\n\nUser query: {prompt}"}
        ]
        completion = openai_client.chat.completions.create(
            model="gpt-4o-mini",
            messages=messages,
            temperature=0.7  # Adjust temperature for confidence control
        )
        response = completion.choices[0].message.content.strip()
        print(response)
        # Determine if the response indicates uncertainty
        is_uncertain = "i am not sure about that" in response.lower()
        print(is_uncertain)
        return response, is_uncertain
    except Exception as e:
        return f"Error generating response: {str(e)}", False
# --- Send User Message ---
def send_message(message):
    """
    Sends a user message. If the chatbot is uncertain, its response is held for operator approval.
    """
    context = get_relevant_context(message)
    user_message = {
        "role": "user",
        "content": message,
        "timestamp": datetime.utcnow(),
        "status": "approved"  # User messages are always approved
    }
    # Upsert the user message immediately
    conversation_history.update_one(
        {"session_id": st.session_state['session_id']},
        {
            "$push": {"messages": user_message},
            "$set": {"last_updated": datetime.utcnow()},
            "$setOnInsert": {"created_at": datetime.utcnow()}
        },
        upsert=True
    )
    # Update or create the chat history document
    chat_history.update_one(
        {"session_id": st.session_state['session_id']},
        {
            "$push": {"messages": user_message},
            "$set": {"last_updated": datetime.utcnow()},
            "$setOnInsert": {"created_at": datetime.utcnow()}
        },
        upsert=True
    )
    # Update the session state with the user message
    st.session_state['chat_history'].append(user_message)

    if not st.session_state.get('admin_takeover_active'):
        # Generate GPT response if takeover is not active
        gpt_response, is_uncertain = get_gpt_response(message, context)
        if is_uncertain:
            status = "pending"  # Mark as pending for operator approval
        else:
            status = "approved"
        assistant_message = {
            "role": "assistant",
            "content": gpt_response,
            "timestamp": datetime.utcnow(),
            "status": status
        }
        # Upsert the assistant message
        conversation_history.update_one(
            {"session_id": st.session_state['session_id']},
            {
                "$push": {"messages": assistant_message},
                "$set": {"last_updated": datetime.utcnow()}
            }
        )
        # Update the chat history document
        chat_history.update_one(
            {"session_id": st.session_state['session_id']},
            {
                "$push": {"messages": assistant_message},
                "$set": {"last_updated": datetime.utcnow()}
            }
        )
        # Update the session state with the assistant message
        st.session_state['chat_history'].append(assistant_message)
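
# For reference: the operator dashboard (not part of this file) is assumed to
# approve a pending assistant message by flipping its status field. A hypothetical
# helper along those lines, using the field names defined in send_message above:
def approve_pending_response(session_id):
    """Mark the first pending assistant message in a session as approved (assumed operator-side logic)."""
    conversation_history.update_one(
        {"session_id": session_id, "messages.status": "pending"},
        {"$set": {"messages.$.status": "approved", "last_updated": datetime.utcnow()}}
    )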
# --- Send Admin Message ---
def send_admin_message(message):
    """
    Sends an admin message directly to the user during a takeover.
    """
    admin_message = {
        "role": "admin",
        "content": message,
        "timestamp": datetime.utcnow(),
        "status": "approved"
    }
    # Upsert the admin message
    conversation_history.update_one(
        {"session_id": st.session_state['session_id']},
        {
            "$push": {"messages": admin_message},
            "$set": {"last_updated": datetime.utcnow()}
        }
    )
    # Update the chat history document
    chat_history.update_one(
        {"session_id": st.session_state['session_id']},
        {
            "$push": {"messages": admin_message},
            "$set": {"last_updated": datetime.utcnow()}
        }
    )
    # Update the session state with the admin message
    st.session_state['chat_history'].append(admin_message)
# --- User Page ---
def user_page():
    if 'session_id' not in st.session_state:
        st.session_state['session_id'] = str(uuid.uuid4())

    st.title("Chat Interface")
    chat_col, info_col = st.columns([3, 1])

    with chat_col:
        # Create a placeholder for the chat interface
        chat_placeholder = st.empty()
        # Add a manual refresh button
        if st.button("Refresh Chat"):
            fetch_and_update_chat()
        # Handle new user input outside the loop
        user_input = st.chat_input("Type your message here...", key="user_chat_input")
        if user_input:
            send_message(user_input)
        # If admin takeover is active, allow admin to send messages
        if st.session_state.get('admin_takeover_active'):
            admin_input = st.chat_input("Admin is currently taking over the chat...", key="admin_chat_input")
            if admin_input:
                send_admin_message(admin_input)

    with info_col:
        st.subheader("Session Information")
        stat_cols = st.columns(2)
        with stat_cols[0]:
            st.write(f"**Session ID:** {st.session_state['session_id'][:8]}...")
        with stat_cols[1]:
            st.write(f"**User Type:** {st.session_state.get('user_type', 'Regular User')}")
        stat_cols = st.columns(2)
        with stat_cols[0]:
            st.write(f"**Chat History Count:** {len(st.session_state['chat_history'])}")
        with stat_cols[1]:
            st.write(f"**Last Active:** {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
        # Add more user-specific information or settings as needed

    # Main loop for continuous updates. It runs after both columns are laid out so
    # the info column still renders; the placeholder keeps drawing inside chat_col.
    while True:
        with chat_placeholder.container():
            # Display all messages in the chat history
            for message in st.session_state['chat_history']:
                if message["role"] == "user":
                    with st.chat_message("user"):
                        st.markdown(message["content"])
                elif message["role"] == "assistant":
                    with st.chat_message("assistant"):
                        if message.get("status") == "approved":
                            st.markdown(message["content"])
                        elif message.get("status") == "pending":
                            st.info("This response is pending operator approval...")
                elif message["role"] == "admin":
                    with st.chat_message("admin"):
                        st.markdown(f"**Admin:** {message['content']}")
        # Fetch updates every 5 seconds
        fetch_and_update_chat()
        time.sleep(5)
# --- Fetch and Update Chat ---
def fetch_and_update_chat():
    """
    Fetches the latest chat history from the database and updates the session state.
    """
    chat = conversation_history.find_one({"session_id": st.session_state['session_id']})
    if chat:
        st.session_state['chat_history'] = chat.get('messages', [])
    else:
        st.session_state['chat_history'] = []
    # Check if admin takeover is active
    takeover_status = db.takeover_status.find_one({"session_id": st.session_state['session_id']})
    is_active = bool(takeover_status and takeover_status.get("active", False))
    st.session_state['admin_takeover_active'] = is_active
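
# For reference: the operator/admin page is assumed to write the takeover flag that
# fetch_and_update_chat() polls above. A hypothetical helper using the same field
# names ("session_id", "active") read by this page:
def set_admin_takeover(session_id, active=True):
    """Toggle the admin-takeover flag for a session (assumed operator-side logic)."""
    db.takeover_status.update_one(
        {"session_id": session_id},
        {"$set": {"active": active}},
        upsert=True,
    )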
# --- View Full Chat (User Perspective) ---
def view_full_chat(session_id):
    st.title("Full Chat View")
    chat = conversation_history.find_one({"session_id": session_id})
    if chat:
        st.subheader(f"Session ID: {session_id[:8]}...")
        last_updated = chat.get('last_updated', datetime.utcnow())
        st.write(f"Last Updated: {last_updated}")

        # Display global common memory
        st.subheader("Global Chat Memory")
        common_memory = get_global_common_memory()
        if common_memory:
            for idx, item in enumerate(common_memory, 1):
                st.text(f"{idx}. {item}")
        else:
            st.info("No global chat memory found.")

        # Add button to clear global common memory
        if st.button("Clear Global Chat Memory"):
            clear_global_common_memory()
            st.success("Global chat memory cleared successfully!")
            st.rerun()

        # Add admin takeover button
        if st.button("Admin Takeover"):
            st.session_state['admin_takeover'] = session_id
            st.rerun()

        for message in chat.get('messages', []):
            if message['role'] == 'user':
                with st.chat_message("user"):
                    st.markdown(message["content"])
            elif message['role'] == 'assistant':
                with st.chat_message("assistant"):
                    if message.get("status") == "approved":
                        st.markdown(message["content"])
                    else:
                        st.info("Waiting for admin approval...")
            elif message['role'] == 'admin':  # Display admin messages
                with st.chat_message("admin"):
                    st.markdown(f"**Admin:** {message['content']}")
            st.caption(f"Timestamp: {message.get('timestamp', 'N/A')}")
    else:
        st.error("Chat not found.")

    if st.button("Back to Admin Dashboard"):
        st.session_state.pop('selected_chat', None)
        st.rerun()
def main():
    try:
        user_page()
    except (RerunException, StopException):
        # These exceptions are used by Streamlit for page navigation and stopping the script
        raise
    except Exception as e:
        st.error(f"An unexpected error occurred: {str(e)}")

if __name__ == "__main__":
    main()