import gradio as gr import os import json from dotenv import load_dotenv import requests from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline from huggingface_hub import login from datetime import datetime import numpy as np import torch from gtts import gTTS import tempfile from transformers import Wav2Vec2ForCTC, Wav2Vec2Tokenizer import torch # Load environment variables from .env file load_dotenv() token = os.getenv("HF_TOKEN") # Use the token in the login function login(token=token) # File paths for storing model configurations and chat history MODEL_CONFIG_FILE = "model_config.json" CHAT_HISTORY_FILE = "chat_history.json" # Load model configurations from a JSON file (if exists) def load_model_config(): if os.path.exists(MODEL_CONFIG_FILE): with open(MODEL_CONFIG_FILE, 'r') as f: return json.load(f) return { "gpt-4": { "endpoint": "https://roger-m38jr9pd-eastus2.openai.azure.com/openai/deployments/gpt-4/chat/completions?api-version=2024-08-01-preview", "api_key": os.getenv("GPT4_API_KEY"), "model_path": None # No model path for API models }, "gpt-4o": { "endpoint": "https://roger-m38jr9pd-eastus2.openai.azure.com/openai/deployments/gpt-4o/chat/completions?api-version=2024-08-01-preview", "api_key": os.getenv("GPT4O_API_KEY"), "model_path": None }, "gpt-35-turbo": { "endpoint": "https://rogerkoranteng.openai.azure.com/openai/deployments/gpt-35-turbo/chat/completions?api-version=2024-08-01-preview", "api_key": os.getenv("GPT35_TURBO_API_KEY"), "model_path": None }, "gpt-4-32k": { "endpoint": "https://roger-m38orjxq-australiaeast.openai.azure.com/openai/deployments/gpt-4-32k/chat/completions?api-version=2024-08-01-preview", "api_key": os.getenv("GPT4_32K_API_KEY"), "model_path": None } } predefined_messages = { "feeling_sad": "Hello, I am feeling sad today, what should I do?", "Nobody likes me": "Hello, Sage. I feel like nobody likes me. What should I do?", 'Boyfriend broke up': "Hi Sage, my boyfriend broke up with me. I'm feeling so sad. What should I do?", 'I am lonely': "Hi Sage, I am feeling lonely. What should I do?", 'I am stressed': "Hi Sage, I am feeling stressed. What should I do?", 'I am anxious': "Hi Sage, I am feeling anxious. What should I do?", } # Save model configuration to JSON def save_model_config(): with open(MODEL_CONFIG_FILE, 'w') as f: json.dump(model_config, f, indent=4) # Load chat history from a JSON file def load_chat_history(): if os.path.exists(CHAT_HISTORY_FILE): with open(CHAT_HISTORY_FILE, 'r') as f: return json.load(f) return [] # Save chat history to a JSON file def save_chat_history(chat_history): with open(CHAT_HISTORY_FILE, 'w') as f: json.dump(chat_history, f, indent=4) # Define model configurations model_config = load_model_config() # Function to dynamically add downloaded model to model_config def add_downloaded_model(model_name, model_path): model_config[model_name] = { "endpoint": None, "model_path": model_path, "api_key": None } save_model_config() return list(model_config.keys()) # Function to download model from Hugging Face synchronously def download_model(model_name): try: tokenizer = AutoTokenizer.from_pretrained(model_name) model = AutoModelForCausalLM.from_pretrained(model_name) model_path = f"./models/{model_name}" os.makedirs(model_path, exist_ok=True) model.save_pretrained(model_path) tokenizer.save_pretrained(model_path) updated_models = add_downloaded_model(model_name, model_path) return f"Model '{model_name}' downloaded and added.", updated_models except Exception as e: return f"Error downloading model '{model_name}': {e}", list(model_config.keys()) # Chat function using the selected model def generate_response(model_choice, user_message, chat_history): model_info = model_config.get(model_choice) if not model_info: return "Invalid model selection. Please choose a valid model.", chat_history chat_history.append({"role": "user", "content": user_message}) headers = {"Content-Type": "application/json"} # Check if the model is an API model (it will have an endpoint) if model_info["endpoint"]: if model_info["api_key"]: headers["api-key"] = model_info["api_key"] data = {"messages": chat_history, "max_tokens": 1500, "temperature": 0.7} try: # Send request to the API model endpoint response = requests.post(model_info["endpoint"], headers=headers, json=data) response.raise_for_status() assistant_message = response.json()['choices'][0]['message']['content'] chat_history.append({"role": "assistant", "content": assistant_message}) save_chat_history(chat_history) # Save chat history to JSON except requests.exceptions.RequestException as e: assistant_message = f"Error: {e}" chat_history.append({"role": "assistant", "content": assistant_message}) save_chat_history(chat_history) else: # If it's a local model, load the model and tokenizer from the local path model_path = model_info["model_path"] try: tokenizer = AutoTokenizer.from_pretrained(model_path) model = AutoModelForCausalLM.from_pretrained(model_path) inputs = tokenizer(user_message, return_tensors="pt") outputs = model.generate(inputs['input_ids'], max_length=500, num_return_sequences=1) assistant_message = tokenizer.decode(outputs[0], skip_special_tokens=True) chat_history.append({"role": "assistant", "content": assistant_message}) save_chat_history(chat_history) except Exception as e: assistant_message = f"Error loading model locally: {e}" chat_history.append({"role": "assistant", "content": assistant_message}) save_chat_history(chat_history) # Convert the assistant message to audio tts = gTTS(assistant_message) audio_file = tempfile.NamedTemporaryFile(delete=False, suffix=".mp3") tts.save(audio_file.name) return chat_history, audio_file.name # Function to format chat history with custom bubble styles def format_chat_bubble(history): formatted_history = "" for message in history: timestamp = datetime.now().strftime("%H:%M:%S") if message["role"] == "user": formatted_history += f'''
Me: {message["content"]}
''' else: formatted_history += f'''
Sage: {message["content"]}
''' return formatted_history tokenizer = Wav2Vec2Tokenizer.from_pretrained("facebook/wav2vec2-base-960h") model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h") def transcribe(audio): if audio is None: return "No audio input received." sr, y = audio # Convert to mono if stereo if y.ndim > 1: y = y.mean(axis=1) y = y.astype(np.float32) y /= np.max(np.abs(y)) # Tokenize the audio input_values = tokenizer(y, return_tensors="pt", sampling_rate=sr).input_values # Perform inference with torch.no_grad(): logits = model(input_values).logits # Decode the logits predicted_ids = torch.argmax(logits, dim=-1) transcription = tokenizer.decode(predicted_ids[0]) return transcription # Create the Gradio interface with gr.Blocks() as interface: gr.Markdown("## Chat with Sage - Your Mental Health Advisor") with gr.Tab("Model Management"): with gr.Tabs(): with gr.TabItem("Model Selection"): gr.Markdown("### Select Model for Chat") model_dropdown = gr.Dropdown(choices=list(model_config.keys()), label="Choose a Model", value="gpt-4", allow_custom_value=True) status_textbox = gr.Textbox(label="Model Selection Status", value="Selected model: gpt-4") model_dropdown.change(lambda model: f"Selected model: {model}", inputs=model_dropdown, outputs=status_textbox) with gr.TabItem("Download Model"): # Sub-tab for downloading models gr.Markdown("### Download a Model from Hugging Face") model_name_input = gr.Textbox(label="Enter Model Name from Hugging Face (e.g., gpt2)") download_button = gr.Button("Download Model") download_status = gr.Textbox(label="Download Status") # Model download synchronous handler def on_model_download(model_name): download_message, updated_models = download_model(model_name) # Trigger the dropdown update to show the newly added model return download_message, gr.update(choices=updated_models, value=updated_models[-1]) download_button.click(on_model_download, inputs=model_name_input, outputs=[download_status, model_dropdown]) refresh_button = gr.Button("Refresh Model List") refresh_button.click(lambda: gr.update(choices=list(model_config.keys())), inputs=[], outputs=model_dropdown) with gr.Tab("Chat Interface"): gr.Markdown("### Chat with Sage") # Chat history state for tracking conversation chat_history_state = gr.State(load_chat_history()) # Load existing chat history # Add initial introduction message if not chat_history_state.value: chat_history_state.value.append({"role": "assistant", "content": "Hello, I am Sage. How can I assist you today?"}) chat_display = gr.HTML(label="Chat", value=format_chat_bubble(chat_history_state.value), elem_id="chat-display") user_message = gr.Textbox(placeholder="Type your message here...", label="Your Message") send_button = gr.Button("Send Message") # Predefined message buttons predefined_buttons = [gr.Button(value=msg) for msg in predefined_messages.values()] # Real-time message updating def update_chat(model_choice, user_message, chat_history_state): chat_history, audio_file = generate_response(model_choice, user_message, chat_history_state) formatted_chat = format_chat_bubble(chat_history) return formatted_chat, chat_history, audio_file send_button.click( update_chat, inputs=[model_dropdown, user_message, chat_history_state], outputs=[chat_display, chat_history_state, gr.Audio(autoplay=True)] ) send_button.click(lambda: "", None, user_message) # Clears the user input after sending # Add click events for predefined message buttons for button, message in zip(predefined_buttons, predefined_messages.values()): button.click( update_chat, inputs=[model_dropdown, gr.State(message), chat_history_state], outputs=[chat_display, chat_history_state, gr.Audio(autoplay=True)] ) with gr.Tab("Speech Interface"): gr.Markdown("### Speak with Sage") audio_input = gr.Audio(type="numpy") transcribe_button = gr.Button("Transcribe") transcribed_text = gr.Textbox(label="Transcribed Text") transcribe_button.click( transcribe, inputs=audio_input, outputs=transcribed_text ) send_speech_button = gr.Button("Send Speech Message") send_speech_button.click( update_chat, inputs=[model_dropdown, transcribed_text, chat_history_state], outputs=[chat_display, chat_history_state, gr.Audio(autoplay=True)] ) # Add custom CSS for scrolling chat box and bubbles interface.css = """ #chat-display { max-height: 500px; overflow-y: auto; padding: 10px; background-color: #1a1a1a; border-radius: 10px; display: flex; flex-direction: column; justify-content: flex-start; box-shadow: 0px 4px 10px rgba(0, 0, 0, 0.1); scroll-behavior: smooth; } /* User message style - text only */ .user-bubble { color: #ffffff; /* Text color for the user */ padding: 8px 15px; margin: 8px 0; word-wrap: break-word; align-self: flex-end; font-size: 14px; position: relative; max-width: 70%; /* Make the bubble width dynamic */ border-radius: 15px; background-color: #121212; /* Light cyan background for the user */ transition: color 0.3s ease; } /* Assistant message style - text only */ .assistant-bubble { color: #ffffff; /* Text color for the assistant */ padding: 8px 15px; margin: 8px 0; word-wrap: break-word; align-self: flex-start; background-color: #2a2a2a; font-size: 14px; position: relative; max-width: 70%; transition: color 0.3s ease; } """ # Launch the Gradio interface interface.launch(server_name="0.0.0.0", server_port=8080, share=True)