Spaces:
Sleeping
Sleeping
# -*- coding: utf-8 -*-
"""app.ipynb

Automatically generated by Colab.

Original file is located at
    https://colab.research.google.com/drive/11FAEDRYHuCI7iX5w3JaeKoD76-9pwrLi
"""
import os | |
import json | |
import pandas as pd | |
from rank_bm25 import BM25Okapi | |
import gradio as gr | |
import openai | |
from datasets import load_dataset | |
# Both credentials are read from the environment; the script aborts at import
# time with a ValueError if either one is missing or empty.

# Hugging Face token
HF_TOKEN = os.environ.get("HF_TOKEN")
if not HF_TOKEN:
    raise ValueError(
        "Hugging Face token is not set. Please set HF_TOKEN as an environment variable."
    )

# OpenAI API key (stored on the openai module so later API calls pick it up)
openai.api_key = os.environ.get("OPENAI_API_KEY")
if not openai.api_key:
    raise ValueError(
        "OpenAI API key is not set. Please set OPENAI_API_KEY as an environment variable."
    )
# Explicitly define dataset file paths
data_files = {
    "train": "hf://datasets/farah1/mental-health-posts-classification/train.csv",
    "validation": "hf://datasets/farah1/mental-health-posts-classification/validation.csv",
}

# Load both splits and convert them to pandas DataFrames.
# The rest of the script cannot work without training data, so a load failure
# is reported and re-raised with its real cause rather than replaced by empty
# DataFrames (which would only surface later as a misleading
# "'text' column is missing" error).
try:
    print("Loading dataset...")
    dataset = load_dataset("csv", data_files=data_files)
    train_data = dataset["train"].to_pandas()
    validation_data = dataset["validation"].to_pandas()
except Exception as e:
    print(f"Failed to load dataset: {e}")
    raise RuntimeError(f"Failed to load dataset: {e}") from e
else:
    print("Dataset loaded successfully.")
# Check and create the 'text' column
if "text" not in train_data.columns:
    if "title" in train_data.columns and "content" in train_data.columns:
        # A missing title or body is NaN in the DataFrame, and NaN + str is
        # NaN, which would later break whitespace tokenization. Treat missing
        # cells as empty strings so every row ends up with a string 'text'.
        train_data["text"] = (
            train_data["title"].fillna("") + " " + train_data["content"].fillna("")
        )
    else:
        raise ValueError("The 'text' column is missing, and the required 'title' and 'content' columns are not available to create it.")

# Ensure the necessary columns exist in the training dataset.
# The four Ground_Truth_* columns are the labels shown in few-shot examples.
required_columns = ["text", "Ground_Truth_Stress", "Ground_Truth_Anxiety", "Ground_Truth_Depression", "Ground_Truth_Other_binary"]
for column in required_columns:
    if column not in train_data.columns:
        raise ValueError(f"Missing required column '{column}' in the training dataset.")
# Initialize BM25
# Each training post is whitespace-tokenized. Entries that are not strings
# (e.g. NaN from an empty CSV cell) become empty token lists instead of
# raising AttributeError on `.split()`; list positions still line up with
# train_data row positions, which classify_text relies on via `.iloc`.
tokenized_train = [
    doc.split() if isinstance(doc, str) else [] for doc in train_data["text"]
]
bm25 = BM25Okapi(tokenized_train)
# Few-shot classification function
def classify_text(input_text, k=20):
    """Classify a post as Stress / Anxiety / Depression / Other with GPT-4.

    The ``k`` training posts most similar to ``input_text`` (by BM25 score)
    are included in the prompt as labeled few-shot examples.

    Args:
        input_text: The text to classify.
        k: Number of BM25-retrieved training examples to put in the prompt.

    Returns:
        The parsed JSON object returned by the model, or a dict with an
        ``"error"`` key when the input is empty, the API call fails, or the
        model reply is not valid JSON (in which case ``"raw_response"`` holds
        the reply text).
    """
    # Nothing to classify: avoid spending an API call on empty input.
    if not input_text or not input_text.strip():
        return {"error": "Input text is empty."}

    # Tokenize input text the same way the BM25 corpus was tokenized
    tokenized_text = input_text.split()

    # Get top-k similar examples using BM25
    scores = bm25.get_scores(tokenized_text)
    top_k_indices = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)[:k]

    # Build examples for prompt (one row lookup per example)
    example_blocks = []
    for i, idx in enumerate(top_k_indices):
        row = train_data.iloc[idx]
        example_blocks.append(
            f"Example {i+1}:\nText: {row['text']}\nClassification: "
            f"Stress={row['Ground_Truth_Stress']}, "
            f"Anxiety={row['Ground_Truth_Anxiety']}, "
            f"Depression={row['Ground_Truth_Depression']}, "
            f"Other={row['Ground_Truth_Other_binary']}\n"
        )
    examples = "\n".join(example_blocks)

    # Construct OpenAI prompt. The retrieved examples are part of the prompt;
    # without them the BM25 retrieval above would have no effect.
    # NOTE: prompt length grows with k and with the length of the retrieved posts.
    prompt = f"""
You are a mental health specialist. Analyze the provided text and classify it into one or more of the following categories: Stress, Anxiety, Depression, or Other.
Respond **only** in JSON format with the following keys:
- Ground_Truth_Stress: 1 or 0
- Ground_Truth_Anxiety: 1 or 0
- Ground_Truth_Depression: 1 or 0
- Ground_Truth_Other_binary: 1 or 0
Here are labeled examples of similar texts for reference:
{examples}
Here is the text to classify:
"{input_text}"
### Example Response:
{{
"Ground_Truth_Stress": 0,
"Ground_Truth_Anxiety": 1,
"Ground_Truth_Depression": 0,
"Ground_Truth_Other_binary": 0
}}
"""

    # Call the model; any failure here is reported as an error dict.
    try:
        response = openai.ChatCompletion.create(
            messages=[
                {"role": "system", "content": "You are a mental health specialist."},
                {"role": "user", "content": prompt},
            ],
            model="gpt-4",
            temperature=0,
        )
        content = response.choices[0].message.content.strip()
    except Exception as e:
        print(f"Error occurred: {e}")
        return {"error": str(e)}

    print("Raw Response Content:", content)

    # Attempt to parse as JSON; kept separate from the API call so `content`
    # is always bound when the raw reply is reported.
    try:
        return json.loads(content)
    except json.JSONDecodeError:
        print("Failed to decode JSON. Returning raw content.")
        return {"error": "Failed to decode JSON", "raw_response": content}
# Gradio Interface
# A single multi-line textbox feeds classify_text; the returned dict is
# rendered by Gradio's JSON output component.
input_box = gr.Textbox(
    lines=5,
    placeholder="Enter text for classification...",
)
interface = gr.Interface(
    fn=classify_text,
    inputs=input_box,
    outputs="json",
    title="Mental Health Text Classifier",
    description="Classify text into Stress, Anxiety, Depression, or Other using BM25 and GPT-4.",
)

# Start the web app only when this file is run as a script.
if __name__ == "__main__":
    interface.launch()