# Hugging Face Space (status when this copy was captured: Sleeping)
# -*- coding: utf-8 -*-
"""app.ipynb

Automatically generated by Colab.

Original file is located at
    https://colab.research.google.com/drive/11FAEDRYHuCI7iX5w3JaeKoD76-9pwrLi
"""
import heapq
import json
import os

import gradio as gr
import openai
import pandas as pd
from rank_bm25 import BM25Okapi
# Load the labelled training set.
# NOTE(review): "username" in this URL looks like a placeholder - confirm the real
# dataset repository before deploying.
dataset_url = "https://huggingface.co/datasets/username/mental-health-classification/resolve/main/train.csv"
train_data = pd.read_csv(dataset_url)
# Missing titles/contents are replaced by "" so every row yields a string; otherwise
# the concatenation is NaN for that row and the .split() below raises AttributeError.
train_data["text"] = (
    train_data["title"].fillna("").astype(str)
    + " "
    + train_data["content"].fillna("").astype(str)
)
# Build the BM25 index over whitespace-separated tokens (no lower-casing), which is
# how classify_text tokenizes its query as well.
tokenized_train = [doc.split() for doc in train_data["text"]]
bm25 = BM25Okapi(tokenized_train)
# Fail fast at startup when no OpenAI key is configured. An empty value counts as
# missing: it would otherwise pass the check and only fail later, inside every
# classification request.
if not os.getenv("OPENAI_API_KEY"):
    raise ValueError("Please set your OpenAI API key using `os.environ['OPENAI_API_KEY'] = 'your_api_key'`")
# Configure the module-level OpenAI client with the key.
openai.api_key = os.getenv("OPENAI_API_KEY")
# Label columns of the training data; they are also the keys the model must return.
_LABEL_COLUMNS = (
    "Ground_Truth_Stress",
    "Ground_Truth_Anxiety",
    "Ground_Truth_Depression",
    "Ground_Truth_Other_binary",
)


def _build_examples(input_tokens, k):
    """Return the few-shot prompt section for the k BM25-nearest training rows.

    Reads the module-level ``bm25`` index and ``train_data`` frame.
    """
    scores = bm25.get_scores(input_tokens)
    # Indices of the k highest-scoring rows, best match first (ties keep index order).
    top_k_indices = heapq.nlargest(k, range(len(scores)), key=lambda i: scores[i])
    blocks = []
    for rank, idx in enumerate(top_k_indices, start=1):
        row = train_data.iloc[idx]  # look the row up once, not once per field
        labels = ", ".join(f'"{col}": {row[col]}' for col in _LABEL_COLUMNS)
        # Show each example's labels in the same JSON shape the model must reply with.
        blocks.append(f"Example {rank}:\nText: {row['text']}\nClassification: {{{labels}}}\n")
    return "\n".join(blocks)


def _build_prompt(input_text, examples):
    """Return the user prompt: task description, few-shot examples, target text, format."""
    keys = ", ".join(f'"{col}"' for col in _LABEL_COLUMNS)
    return f"""
You are a mental health specialist. Classify the text for each of Stress, Anxiety, Depression, and Other.
The labels are independent: more than one of them may be 1.
### Examples:
{examples}
### Text to Classify:
"{input_text}"
### Output Format:
Reply with a single JSON object and nothing else. It must have exactly the keys {keys}, each with the integer value 1 or 0.
"""


def _chat_completion(**kwargs):
    """Call the chat-completions endpoint through whichever openai SDK is installed."""
    if hasattr(openai, "chat"):
        # openai>=1.0 module-level client; there the legacy
        # ``openai.ChatCompletion.create`` raises APIRemovedInV1.
        return openai.chat.completions.create(**kwargs)
    # openai<1.0 legacy interface.
    return openai.ChatCompletion.create(**kwargs)


def _parse_json_reply(reply):
    """Parse the JSON object in a model reply, tolerating ``` fences or extra prose."""
    if not reply:
        raise ValueError("The model returned an empty reply.")
    start = reply.find("{")
    end = reply.rfind("}")
    if start == -1 or end < start:
        raise ValueError(f"The model reply did not contain a JSON object: {reply!r}")
    return json.loads(reply[start:end + 1])


def classify_text(input_text, k=20):
    """Classify ``input_text`` with GPT-4o, using BM25-retrieved few-shot examples.

    Args:
        input_text: Free text to classify. ``None`` or whitespace-only text is rejected.
        k: Number of most similar training rows to include as examples.

    Returns:
        The dict parsed from the model's JSON reply (expected keys are the four
        ``Ground_Truth_*`` label columns), or ``{"error": message}`` when the input
        is empty or the API call / reply parsing fails.
    """
    tokens = (input_text or "").split()
    if not tokens:
        # Nothing to retrieve against or classify; don't spend an API call.
        return {"error": "Please enter some text to classify."}

    examples = _build_examples(tokens, k)
    prompt = _build_prompt(input_text, examples)
    try:
        response = _chat_completion(
            messages=[
                {"role": "system", "content": "You are a mental health specialist."},
                {"role": "user", "content": prompt},
            ],
            model="gpt-4o",
            temperature=0,
        )
        return _parse_json_reply(response.choices[0].message.content)
    except Exception as e:
        # Top-level UI boundary: report the failure as JSON, because the Gradio
        # JSON output tries to parse str return values and a bare message is not JSON.
        return {"error": str(e)}
# Gradio UI: a multi-line text box in, the classifier's JSON result out.
# classify_text is called with its default k.
interface = gr.Interface(
    fn=classify_text,
    inputs=gr.Textbox(lines=5, placeholder="Enter text for classification..."),
    outputs="json",
    title="Mental Health Text Classifier",
    # The model actually called by classify_text is "gpt-4o", not GPT-4.
    description="Classify text into Stress, Anxiety, Depression, or Other using BM25 and GPT-4o.",
)

if __name__ == "__main__":
    interface.launch()