File size: 5,670 Bytes
61cf524
 
5bf61fc
61cf524
a352d2f
edaa6c5
5bf61fc
50ca63c
215c29e
 
 
 
fee8826
edaa6c5
 
 
 
5bf61fc
 
fee8826
 
 
 
5bf61fc
fee8826
50ca63c
4b8c206
b369092
 
 
 
 
 
 
61cf524
b369092
61cf524
 
 
50ca63c
 
 
 
 
 
61cf524
 
 
 
 
 
610a6d3
61cf524
 
 
 
 
 
 
 
 
 
 
716b464
 
610a6d3
716b464
 
61cf524
716b464
610a6d3
 
61cf524
 
 
a352d2f
61cf524
 
 
 
4b8c206
61cf524
 
716b464
610a6d3
 
61cf524
215c29e
50ca63c
716b464
0eb69ca
0807606
 
 
 
 
 
017a61f
 
0807606
 
 
 
 
 
 
 
0eb69ca
 
 
 
 
 
 
 
 
 
 
0807606
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0eb69ca
 
 
0807606
 
 
 
 
a180cf5
61cf524
4b8c206
017a61f
0eb69ca
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
import os
import pandas as pd
from rank_bm25 import BM25Okapi
import gradio as gr
import openai
from datasets import load_dataset

# Load Hugging Face dataset
# Fail fast when no Hugging Face token is configured. The token is not passed
# to load_dataset() explicitly; presumably huggingface_hub reads HF_TOKEN from
# the environment when resolving the hf:// paths below -- TODO confirm.
HF_TOKEN = os.getenv("HF_TOKEN")
if not HF_TOKEN:
    raise ValueError("Hugging Face token is not set. Please set HF_TOKEN as an environment variable.")

# Explicitly define dataset file paths
data_files = {
    "train": "hf://datasets/farah1/mental-health-posts-classification/train.csv",
    "validation": "hf://datasets/farah1/mental-health-posts-classification/validation.csv",
}

try:
    print("Loading dataset...")
    dataset = load_dataset("csv", data_files=data_files)
    train_data = dataset["train"].to_pandas()
except Exception as e:
    # Falling back to an empty DataFrame only postponed the crash to the
    # column check below, where it surfaced as a misleading "'text' column is
    # missing" error. Report the real cause instead.
    print(f"Failed to load dataset: {e}")
    raise RuntimeError(f"Failed to load dataset: {e}") from e
else:
    print("Dataset loaded successfully.")

# BM25 cannot be built from an empty corpus, so stop with a clear message.
if train_data.empty:
    raise ValueError("The training split is empty; cannot build the BM25 index.")

# Check and create the 'text' column
if "text" not in train_data.columns:
    if "title" in train_data.columns and "content" in train_data.columns:
        # Fill missing values first: NaN + str would make the whole text NaN.
        train_data["text"] = (
            train_data["title"].fillna("").astype(str)
            + " "
            + train_data["content"].fillna("").astype(str)
        )
    else:
        raise ValueError("The 'text' column is missing, and the required 'title' and 'content' columns are not available to create it.")

# A pre-existing 'text' column may also hold missing values; they would reach
# str.split() below as float NaN and raise AttributeError.
train_data["text"] = train_data["text"].fillna("").astype(str)

# Initialize BM25
# Whitespace tokenization; classify_text tokenizes queries the same way.
tokenized_train = [doc.split() for doc in train_data["text"]]
bm25 = BM25Okapi(tokenized_train)

# Few-shot classification function
def classify_text(api_key, input_text, k=20):
    """Classify a text with GPT-4, using BM25-retrieved posts as few-shot examples.

    The ``k`` training posts scoring highest under BM25 for ``input_text`` are
    formatted with their ground-truth labels and appended to the prompt.

    Args:
        api_key: OpenAI API key supplied by the user.
        input_text: Text to classify.
        k: Number of retrieved training examples to include in the prompt.

    Returns:
        The model's reply with surrounding whitespace stripped, or a string
        starting with ``"Error:"`` when the key or text is missing or the
        API call fails.
    """
    # Validate before doing any work so that missing input never costs a
    # retrieval pass or an API call.
    api_key = (api_key or "").strip()
    if not api_key:
        return "Error: OpenAI API key is not set."
    if not input_text or not input_text.strip():
        return "Error: Input text is empty."

    # Tokenize input text the same way the corpus was tokenized
    query_tokens = input_text.split()
    # Get top-k similar examples using BM25
    scores = bm25.get_scores(query_tokens)
    top_k_indices = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)[:k]

    # Build examples for the prompt (one row lookup per example)
    example_blocks = []
    for position, idx in enumerate(top_k_indices, start=1):
        row = train_data.iloc[idx]
        example_blocks.append(
            f"Example {position}:\nText: {row['text']}\nClassification: "
            f"Stress={row['Ground_Truth_Stress']}, "
            f"Anxiety={row['Ground_Truth_Anxiety']}, "
            f"Depression={row['Ground_Truth_Depression']}, "
            f"Other={row['Ground_Truth_Other_binary']}\n"
        )
    examples = "\n".join(example_blocks)

    # Construct OpenAI prompt
    prompt = f"""
    You are a mental health specialist. Analyze the provided text and classify it into one of the following categories: Stress, Anxiety, Depression, or Other.

    Respond with a single category that best matches the content: Stress, Anxiety, Depression, or Other.

    Here is the text to classify:
    "{input_text}"

    ### Examples:
    {examples}
    """

    try:
        # NOTE: openai.ChatCompletion is the legacy (pre-1.0) client interface.
        # The key is passed per request rather than stored in the module-level
        # openai.api_key, so concurrent users of the app cannot end up using
        # each other's keys.
        response = openai.ChatCompletion.create(
            api_key=api_key,
            messages=[
                {"role": "system", "content": "You are a mental health specialist."},
                {"role": "user", "content": prompt},
            ],
            model="gpt-4",
            temperature=0,
        )
        content = response.choices[0].message.content.strip()
        print("OpenAI Response:", content)
        return content  # Return the label directly
    except Exception as e:
        # UI boundary: surface the failure as text instead of crashing the app.
        print(f"Error occurred: {e}")
        return f"Error: {e}"

# Enhanced Gradio Interface with Examples
# Layout: API key box, example picker, free-text input, submit button, result.
with gr.Blocks() as interface:
    gr.Markdown(
        """
        # 🧠 Mental Health Text Classifier
        Welcome to the Mental Health Text Classifier. Enter your OpenAI API key and input text, and the system will classify it into one of the following categories:
        **Stress**, **Anxiety**, **Depression**, or **Other**.

        > **Disclaimer**: This tool is an AI-based model and may not always provide accurate results. It is not a substitute for professional psychological advice. If you have concerns about your mental health, please consult a licensed psychologist or mental health professional.
        """
    )
    with gr.Row():
        api_key_input = gr.Textbox(
            label="🔑 OpenAI API Key",
            placeholder="Enter your OpenAI API key here...",
            type="password",
        )
    with gr.Row():
        example_texts = gr.Dropdown(
            label="🔎 Example Texts",
            choices=[
                "I feel like I'm constantly under pressure at work. The deadlines keep piling up, and I can't seem to catch a break. I'm exhausted and overwhelmed.",  # Stress
                "Lately, I’ve been feeling really anxious about social situations. I keep overthinking what people might think of me, and it’s making me avoid going out altogether.",  # Anxiety
                "I have no energy or motivation to do anything. Even getting out of bed feels like a struggle, and nothing seems to make me happy anymore.",  # Depression
            ],
            # The hint is shown as helper text; using it as the initial value
            # put a string into the dropdown that is not one of its choices.
            info="Select an example text or type your own below.",
            value=None,
            interactive=True,
        )
    with gr.Row():
        text_input = gr.Textbox(
            label="📝 Input Text",
            lines=5,
            placeholder="Enter your thoughts or feelings here...",
        )
    with gr.Row():
        submit_button = gr.Button("Classify")
    with gr.Row():
        output_text = gr.Textbox(
            label="🏷️ Classification Result",
            placeholder="The classification result will appear here.",
            interactive=False,
        )

    # Update the text input field when an example is selected
    example_texts.change(fn=lambda x: x, inputs=example_texts, outputs=text_input)

    # Run the classifier with the entered key and text; k keeps its default.
    submit_button.click(
        classify_text,
        inputs=[api_key_input, text_input],
        outputs=output_text,
    )

if __name__ == "__main__":
    interface.launch()