# Mental-Sage / main.py
import html
import json
import os
import tempfile

import gradio as gr
import numpy as np
import requests
import torch
from dotenv import load_dotenv
from gtts import gTTS
from huggingface_hub import login
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    Wav2Vec2ForCTC,
    Wav2Vec2Tokenizer,
)
# Load environment variables from .env file
load_dotenv()
token = os.getenv("HF_TOKEN")

# Log in to the Hugging Face Hub only if a token is actually set;
# login(None) would fail at startup
if token:
    login(token=token)
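# For reference, the .env file is expected to supply the secrets read via
# os.getenv in this script; illustrative placeholders only:
#
#     HF_TOKEN=hf_xxxxxxxxxxxxxxxx
#     GPT4_API_KEY=...
#     GPT4O_API_KEY=...
#     GPT35_TURBO_API_KEY=...
#     GPT4_32K_API_KEY=...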
# File paths for storing model configurations and chat history
MODEL_CONFIG_FILE = "model_config.json"
CHAT_HISTORY_FILE = "chat_history.json"
# Load model configurations from a JSON file (if exists)
def load_model_config():
if os.path.exists(MODEL_CONFIG_FILE):
with open(MODEL_CONFIG_FILE, 'r') as f:
return json.load(f)
return {
"gpt-4": {
"endpoint": "https://roger-m38jr9pd-eastus2.openai.azure.com/openai/deployments/gpt-4/chat/completions?api-version=2024-08-01-preview",
"api_key": os.getenv("GPT4_API_KEY"),
"model_path": None # No model path for API models
},
"gpt-4o": {
"endpoint": "https://roger-m38jr9pd-eastus2.openai.azure.com/openai/deployments/gpt-4o/chat/completions?api-version=2024-08-01-preview",
"api_key": os.getenv("GPT4O_API_KEY"),
"model_path": None
},
"gpt-35-turbo": {
"endpoint": "https://rogerkoranteng.openai.azure.com/openai/deployments/gpt-35-turbo/chat/completions?api-version=2024-08-01-preview",
"api_key": os.getenv("GPT35_TURBO_API_KEY"),
"model_path": None
},
"gpt-4-32k": {
"endpoint": "https://roger-m38orjxq-australiaeast.openai.azure.com/openai/deployments/gpt-4-32k/chat/completions?api-version=2024-08-01-preview",
"api_key": os.getenv("GPT4_32K_API_KEY"),
"model_path": None
}
}
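# For reference, a saved model_config.json mirrors the dictionaries above; an
# illustrative sketch of the file after downloading one local model (note that
# any api_key values present in the config are serialized to disk as well):
#
#     {
#         "gpt-4": {"endpoint": "https://...", "api_key": null, "model_path": null},
#         "gpt2": {"endpoint": null, "model_path": "./models/gpt2", "api_key": null}
#     }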
predefined_messages = {
    "feeling_sad": "Hello, I am feeling sad today, what should I do?",
    "nobody_likes_me": "Hello, Sage. I feel like nobody likes me. What should I do?",
    "boyfriend_broke_up": "Hi Sage, my boyfriend broke up with me. I'm feeling so sad. What should I do?",
    "i_am_lonely": "Hi Sage, I am feeling lonely. What should I do?",
    "i_am_stressed": "Hi Sage, I am feeling stressed. What should I do?",
    "i_am_anxious": "Hi Sage, I am feeling anxious. What should I do?",
}
# Save model configuration to JSON
def save_model_config():
with open(MODEL_CONFIG_FILE, 'w') as f:
json.dump(model_config, f, indent=4)
# Load chat history from a JSON file
def load_chat_history():
if os.path.exists(CHAT_HISTORY_FILE):
with open(CHAT_HISTORY_FILE, 'r') as f:
return json.load(f)
return []
# Save chat history to a JSON file
def save_chat_history(chat_history):
with open(CHAT_HISTORY_FILE, 'w') as f:
json.dump(chat_history, f, indent=4)
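# chat_history.json ends up as a flat list of OpenAI-style messages, e.g.
# (illustrative):
#
#     [
#         {"role": "user", "content": "Hello, Sage."},
#         {"role": "assistant", "content": "Hello, how can I help?"}
#     ]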
# Define model configurations
model_config = load_model_config()
# Function to dynamically add downloaded model to model_config
def add_downloaded_model(model_name, model_path):
model_config[model_name] = {
"endpoint": None,
"model_path": model_path,
"api_key": None
}
save_model_config()
return list(model_config.keys())
# Function to download model from Hugging Face synchronously
def download_model(model_name):
try:
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)
model_path = f"./models/{model_name}"
os.makedirs(model_path, exist_ok=True)
model.save_pretrained(model_path)
tokenizer.save_pretrained(model_path)
updated_models = add_downloaded_model(model_name, model_path)
return f"Model '{model_name}' downloaded and added.", updated_models
except Exception as e:
return f"Error downloading model '{model_name}': {e}", list(model_config.keys())
# Chat function using the selected model
def generate_response(model_choice, user_message, chat_history):
    model_info = model_config.get(model_choice)
    if not model_info:
        # Keep the (history, audio) return shape consistent even on bad input
        chat_history.append({"role": "assistant", "content": "Invalid model selection. Please choose a valid model."})
        return chat_history, None
chat_history.append({"role": "user", "content": user_message})
headers = {"Content-Type": "application/json"}
# Check if the model is an API model (it will have an endpoint)
if model_info["endpoint"]:
if model_info["api_key"]:
headers["api-key"] = model_info["api_key"]
data = {"messages": chat_history, "max_tokens": 1500, "temperature": 0.7}
try:
# Send request to the API model endpoint
response = requests.post(model_info["endpoint"], headers=headers, json=data)
response.raise_for_status()
assistant_message = response.json()['choices'][0]['message']['content']
chat_history.append({"role": "assistant", "content": assistant_message})
save_chat_history(chat_history) # Save chat history to JSON
except requests.exceptions.RequestException as e:
assistant_message = f"Error: {e}"
chat_history.append({"role": "assistant", "content": assistant_message})
save_chat_history(chat_history)
    else:
        # Local model: load the saved weights and tokenizer from disk
        model_path = model_info["model_path"]
        try:
            tokenizer = AutoTokenizer.from_pretrained(model_path)
            model = AutoModelForCausalLM.from_pretrained(model_path)
            inputs = tokenizer(user_message, return_tensors="pt")
            outputs = model.generate(
                inputs["input_ids"],
                attention_mask=inputs["attention_mask"],
                max_length=500,
                num_return_sequences=1,
            )
            assistant_message = tokenizer.decode(outputs[0], skip_special_tokens=True)
            chat_history.append({"role": "assistant", "content": assistant_message})
            save_chat_history(chat_history)
        except Exception as e:
            assistant_message = f"Error loading model locally: {e}"
            chat_history.append({"role": "assistant", "content": assistant_message})
            save_chat_history(chat_history)
    # Convert the assistant message to audio; close the temp-file handle before
    # gTTS writes to it, so the file is not locked on platforms like Windows
    tts = gTTS(assistant_message)
    with tempfile.NamedTemporaryFile(delete=False, suffix=".mp3") as audio_file:
        audio_path = audio_file.name
    tts.save(audio_path)
    return chat_history, audio_path
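# A sketch of a direct call outside the UI (assumes a valid API key for the
# chosen model and network access for gTTS):
#
#     history, audio_path = generate_response("gpt-4", "Hello, Sage.", [])
#     print(history[-1]["content"])  # assistant reply
#     print(audio_path)              # path to the generated MP3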
# Function to format chat history with custom bubble styles
def format_chat_bubble(history):
    formatted_history = ""
    for message in history:
        # Escape message text so user-typed "<" or "&" cannot break the HTML
        content = html.escape(message["content"])
        if message["role"] == "user":
            formatted_history += f'''
            <div class="user-bubble">
                <strong>Me:</strong> {content}
            </div>
            '''
        else:
            formatted_history += f'''
            <div class="assistant-bubble">
                <strong>Sage:</strong> {content}
            </div>
            '''
    return formatted_history
# Speech-to-text model for the Speech Interface tab; named distinctly so it is
# not confused with the per-request tokenizer/model in generate_response
asr_tokenizer = Wav2Vec2Tokenizer.from_pretrained("facebook/wav2vec2-base-960h")
asr_model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h")
def transcribe(audio):
    if audio is None:
        return "No audio input received."
    sr, y = audio
    # Convert to mono if stereo
    if y.ndim > 1:
        y = y.mean(axis=1)
    y = y.astype(np.float32)
    # Normalize to [-1, 1]; guard against an all-silence buffer
    peak = np.max(np.abs(y))
    if peak > 0:
        y /= peak
    # wav2vec2-base-960h expects 16 kHz input; resample via linear interpolation
    # (a lightweight fallback) if the browser recorded at a different rate
    if sr != 16000:
        target_len = int(len(y) * 16000 / sr)
        y = np.interp(np.linspace(0, len(y) - 1, target_len), np.arange(len(y)), y).astype(np.float32)
        sr = 16000
    # Tokenize the audio
    input_values = asr_tokenizer(y, return_tensors="pt", sampling_rate=sr).input_values
    # Perform inference without gradient tracking
    with torch.no_grad():
        logits = asr_model(input_values).logits
    # Greedy CTC decode: take the most likely token at each frame
    predicted_ids = torch.argmax(logits, dim=-1)
    transcription = asr_tokenizer.decode(predicted_ids[0])
    return transcription
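# A quick smoke test for transcribe() with a synthetic buffer (one second of
# 16 kHz silence; real speech is needed for a meaningful transcription):
#
#     print(transcribe((16000, np.zeros(16000, dtype=np.float32))))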
# Create the Gradio interface
with gr.Blocks() as interface:
gr.Markdown("## Chat with Sage - Your Mental Health Advisor")
with gr.Tab("Model Management"):
with gr.Tabs():
with gr.TabItem("Model Selection"):
gr.Markdown("### Select Model for Chat")
model_dropdown = gr.Dropdown(choices=list(model_config.keys()), label="Choose a Model", value="gpt-4",
allow_custom_value=True)
status_textbox = gr.Textbox(label="Model Selection Status", value="Selected model: gpt-4")
model_dropdown.change(lambda model: f"Selected model: {model}", inputs=model_dropdown,
outputs=status_textbox)
with gr.TabItem("Download Model"): # Sub-tab for downloading models
gr.Markdown("### Download a Model from Hugging Face")
model_name_input = gr.Textbox(label="Enter Model Name from Hugging Face (e.g., gpt2)")
download_button = gr.Button("Download Model")
download_status = gr.Textbox(label="Download Status")
                # Model download synchronous handler
                def on_model_download(model_name):
                    download_message, updated_models = download_model(model_name)
                    # Refresh the dropdown; select the newly added model only on success
                    if download_message.startswith("Error"):
                        return download_message, gr.update(choices=updated_models)
                    return download_message, gr.update(choices=updated_models, value=updated_models[-1])
download_button.click(on_model_download, inputs=model_name_input,
outputs=[download_status, model_dropdown])
refresh_button = gr.Button("Refresh Model List")
refresh_button.click(lambda: gr.update(choices=list(model_config.keys())), inputs=[],
outputs=model_dropdown)
with gr.Tab("Chat Interface"):
gr.Markdown("### Chat with Sage")
# Chat history state for tracking conversation
chat_history_state = gr.State(load_chat_history()) # Load existing chat history
        # Seed the conversation with an introduction when no saved history exists
        if not chat_history_state.value:
            chat_history_state.value.append({"role": "assistant", "content": "Hello, I am Sage. How can I assist you today?"})
        chat_display = gr.HTML(label="Chat", value=format_chat_bubble(chat_history_state.value), elem_id="chat-display")
        user_message = gr.Textbox(placeholder="Type your message here...", label="Your Message")
        send_button = gr.Button("Send Message")
        # One shared audio player for Sage's spoken replies, instead of creating
        # a fresh gr.Audio inside each event's outputs list
        audio_output = gr.Audio(label="Sage's Voice", autoplay=True)
# Predefined message buttons
predefined_buttons = [gr.Button(value=msg) for msg in predefined_messages.values()]
# Real-time message updating
def update_chat(model_choice, user_message, chat_history_state):
chat_history, audio_file = generate_response(model_choice, user_message, chat_history_state)
formatted_chat = format_chat_bubble(chat_history)
return formatted_chat, chat_history, audio_file
        send_button.click(
            update_chat,
            inputs=[model_dropdown, user_message, chat_history_state],
            outputs=[chat_display, chat_history_state, audio_output],
        )
        send_button.click(lambda: "", None, user_message)  # Clear the input box after sending
        # Add click events for predefined message buttons
        for button, message in zip(predefined_buttons, predefined_messages.values()):
            button.click(
                update_chat,
                inputs=[model_dropdown, gr.State(message), chat_history_state],
                outputs=[chat_display, chat_history_state, audio_output],
            )
with gr.Tab("Speech Interface"):
gr.Markdown("### Speak with Sage")
audio_input = gr.Audio(type="numpy")
transcribe_button = gr.Button("Transcribe")
transcribed_text = gr.Textbox(label="Transcribed Text")
transcribe_button.click(
transcribe,
inputs=audio_input,
outputs=transcribed_text
)
send_speech_button = gr.Button("Send Speech Message")
        # Reuse the Chat tab's audio player for the spoken reply
        send_speech_button.click(
            update_chat,
            inputs=[model_dropdown, transcribed_text, chat_history_state],
            outputs=[chat_display, chat_history_state, audio_output],
        )
# Add custom CSS for scrolling chat box and bubbles
interface.css = """
#chat-display {
max-height: 500px;
overflow-y: auto;
padding: 10px;
background-color: #1a1a1a;
border-radius: 10px;
display: flex;
flex-direction: column;
justify-content: flex-start;
box-shadow: 0px 4px 10px rgba(0, 0, 0, 0.1);
scroll-behavior: smooth;
}
/* User message style - text only */
.user-bubble {
color: #ffffff; /* Text color for the user */
padding: 8px 15px;
margin: 8px 0;
word-wrap: break-word;
align-self: flex-end;
font-size: 14px;
position: relative;
max-width: 70%; /* Make the bubble width dynamic */
border-radius: 15px;
    background-color: #121212; /* Dark background for the user bubble */
transition: color 0.3s ease;
}
/* Assistant message style - text only */
.assistant-bubble {
color: #ffffff; /* Text color for the assistant */
padding: 8px 15px;
margin: 8px 0;
word-wrap: break-word;
align-self: flex-start;
background-color: #2a2a2a;
font-size: 14px;
position: relative;
max-width: 70%;
transition: color 0.3s ease;
}
"""
# Launch the Gradio interface
interface.launch(server_name="0.0.0.0", server_port=8080, share=True)