# Captured from a Hugging Face Spaces page (Space status at capture time: Sleeping).
import os | |
import pandas as pd | |
from rank_bm25 import BM25Okapi | |
import gradio as gr | |
import openai | |
from datasets import load_dataset | |
# Load the Hugging Face dataset.
# A missing token is rejected up front so the failure message names the real
# problem instead of surfacing later as a download error.
# NOTE(review): HF_TOKEN is only validated here; it is never passed to
# load_dataset explicitly. Presumably the Hub client reads it from the
# environment -- confirm.
HF_TOKEN = os.getenv("HF_TOKEN")
if not HF_TOKEN:
    raise ValueError("Hugging Face token is not set. Please set HF_TOKEN as an environment variable.")

# Explicitly define dataset file paths
data_files = {
    "train": "hf://datasets/farah1/mental-health-posts-classification/train.csv",
    "validation": "hf://datasets/farah1/mental-health-posts-classification/validation.csv",
}

try:
    print("Loading dataset...")
    dataset = load_dataset("csv", data_files=data_files)
    train_data = dataset["train"].to_pandas()
except Exception as e:
    print(f"Failed to load dataset: {e}")
    # Nothing below can work without training rows: the BM25 index and the
    # few-shot examples are both built from them. Falling back to an empty
    # DataFrame only produced a misleading "'text' column is missing" error
    # later on, so stop here and keep the real cause chained.
    raise RuntimeError(f"Failed to load dataset: {e}") from e
print("Dataset loaded successfully.")
# Ensure a 'text' column exists: the BM25 index and the few-shot examples both read it.
if "text" not in train_data.columns:
    if "title" in train_data.columns and "content" in train_data.columns:
        # fillna("") so a row with a missing title or body still yields a string.
        # Plain `+` would propagate NaN into 'text', and NaN cannot be tokenized.
        train_data["text"] = (
            train_data["title"].fillna("").astype(str)
            + " "
            + train_data["content"].fillna("").astype(str)
        )
    else:
        raise ValueError("The 'text' column is missing, and the required 'title' and 'content' columns are not available to create it.")
# Initialize BM25 over whitespace-tokenized training texts.
# Missing cells (NaN/None) become empty documents instead of crashing on
# `.split()`. They are kept rather than dropped so that each BM25 score position
# still equals the row position used with `train_data.iloc` when building examples.
tokenized_train = [
    str(doc).split() if pd.notna(doc) else []
    for doc in train_data["text"]
]
bm25 = BM25Okapi(tokenized_train)
# Few-shot classification function
def _top_k_indices(input_text, k):
    """Return the row positions of the ``k`` best BM25 matches, best first."""
    # Tokenize the query the same way the training corpus was tokenized
    # (whitespace split), so query and document terms are comparable.
    scores = bm25.get_scores(input_text.split())
    return sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)[:k]


def _format_examples(indices):
    """Render the given training rows as numbered few-shot examples with their labels."""
    rendered = []
    for number, idx in enumerate(indices, start=1):
        row = train_data.iloc[idx]  # look the row up once instead of once per field
        rendered.append(
            f"Example {number}:\nText: {row['text']}\nClassification: "
            f"Stress={row['Ground_Truth_Stress']}, "
            f"Anxiety={row['Ground_Truth_Anxiety']}, "
            f"Depression={row['Ground_Truth_Depression']}, "
            f"Other={row['Ground_Truth_Other_binary']}\n"
        )
    return "\n".join(rendered)


def _build_prompt(input_text, examples):
    """Build the user prompt: instructions, the text to classify, then the examples."""
    # The instructions ask for exactly one category. The previous wording said
    # "one or more" and then "a single category", which contradicted itself.
    return f"""
You are a mental health specialist. Analyze the provided text and classify it into exactly one of the following categories: Stress, Anxiety, Depression, or Other.
Respond with a single category that best matches the content: Stress, Anxiety, Depression, or Other.
Here is the text to classify:
"{input_text}"
### Examples:
{examples}
"""


def _request_label(api_key, prompt):
    """Send ``prompt`` to GPT-4 (temperature 0) using ``api_key`` and return the stripped reply."""
    messages = [
        {"role": "system", "content": "You are a mental health specialist."},
        {"role": "user", "content": prompt},
    ]
    # The key is passed per request and never stored in the `openai` module
    # global. Storing it globally would let concurrent users of this app end up
    # sending requests with each other's keys.
    if hasattr(openai, "OpenAI"):
        # openai>=1.0 removed `openai.ChatCompletion`; use a client object instead.
        client = openai.OpenAI(api_key=api_key)
        response = client.chat.completions.create(
            model="gpt-4",
            messages=messages,
            temperature=0,
        )
    else:
        # openai<1.0 accepts the key as a per-call argument.
        response = openai.ChatCompletion.create(
            api_key=api_key,
            model="gpt-4",
            messages=messages,
            temperature=0,
        )
    # `content` can be None (e.g. when the model returns no text).
    return (response.choices[0].message.content or "").strip()


def classify_text(api_key, input_text, k=20):
    """Classify ``input_text`` as Stress, Anxiety, Depression or Other via few-shot GPT-4.

    The ``k`` training posts most similar to the input (by BM25 score) are
    included in the prompt, together with their ground-truth labels, as examples.

    Args:
        api_key: OpenAI API key supplied by the user. Surrounding whitespace is ignored.
        input_text: The text to classify.
        k: Number of retrieved few-shot examples to include.

    Returns:
        The model's reply (expected to be a single category name). If the key
        or the text is missing, or the OpenAI call fails, returns a string
        starting with ``"Error:"``.
    """
    api_key = (api_key or "").strip()
    if not api_key:
        return "Error: OpenAI API key is not set."
    if not input_text or not input_text.strip():
        # Without any query tokens, BM25 would rank examples arbitrarily.
        return "Error: Please enter some text to classify."

    examples = _format_examples(_top_k_indices(input_text, k))
    prompt = _build_prompt(input_text, examples)
    try:
        content = _request_label(api_key, prompt)
        print("OpenAI Response:", content)
        return content  # Return the label directly
    except Exception as e:
        # UI boundary: report the failure in the output box instead of raising.
        print(f"Error occurred: {e}")
        return f"Error: {e}"
# Enhanced Gradio Interface with Examples
with gr.Blocks() as interface:
    # The Markdown body is kept at column 0 so that no leading indentation can
    # turn it into a Markdown code block.
    gr.Markdown(
        """
# 🧠 Mental Health Text Classifier

Welcome to the Mental Health Text Classifier. Enter your OpenAI API key and input text, and the system will classify it into one of the following categories:
**Stress**, **Anxiety**, **Depression**, or **Other**.

> **Disclaimer**: This tool is an AI-based model and may not always provide accurate results. It is not a substitute for professional psychological advice. If you have concerns about your mental health, please consult a licensed psychologist or mental health professional.
"""
    )
    with gr.Row():
        api_key_input = gr.Textbox(
            label="🔑 OpenAI API Key",
            placeholder="Enter your OpenAI API key here...",
            type="password",
        )
    with gr.Row():
        example_texts = gr.Dropdown(
            label="📋 Example Texts",
            choices=[
                "I feel like I'm constantly under pressure at work. The deadlines keep piling up, and I can't seem to catch a break. I'm exhausted and overwhelmed.",  # Stress
                "Lately, I’ve been feeling really anxious about social situations. I keep overthinking what people might think of me, and it’s making me avoid going out altogether.",  # Anxiety
                "I have no energy or motivation to do anything. Even getting out of bed feels like a struggle, and nothing seems to make me happy anymore.",  # Depression
            ],
            # The instruction text is not one of `choices`, so it is shown as
            # helper text instead of being used as the selected value. Recent
            # Gradio versions warn about a value outside `choices`.
            value=None,
            info="Select an example text or type your own below.",
            interactive=True,
        )
    with gr.Row():
        text_input = gr.Textbox(
            label="📝 Input Text",
            lines=5,
            placeholder="Enter your thoughts or feelings here...",
        )
    with gr.Row():
        submit_button = gr.Button("Classify")
    with gr.Row():
        output_text = gr.Textbox(
            label="🏷️ Classification Result",
            placeholder="The classification result will appear here.",
            interactive=False,
        )

    # Event listeners are registered inside the Blocks context.
    # Copy the selected example into the input box.
    example_texts.change(fn=lambda x: x, inputs=example_texts, outputs=text_input)
    submit_button.click(
        classify_text,
        inputs=[api_key_input, text_input],
        outputs=output_text,
    )
def _launch_app():
    """Start the Gradio UI held in the module-level ``interface`` object."""
    interface.launch()


# Serve the app only when this file is executed directly, not when it is imported.
if __name__ == "__main__":
    _launch_app()