import streamlit as st
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    TrainingArguments,
    Trainer,
    DataCollatorForLanguageModeling,
    pipeline,
)
from datasets import load_dataset


# Direct URL for the background image
flower_image_url = "https://i.postimg.cc/hG2FG85D/2.png"

# Inject custom CSS for the background with a centered and blurred image
st.markdown(
    f"""
    <style>
    /* Container for background */
    html, body {{
        margin: 0;
        padding: 0;
        overflow: hidden;
    }}
    [data-testid="stAppViewContainer"] {{
        position: relative;
        z-index: 1; /* Ensure UI elements are above the background */
    }}
    /* Blurred background image */
    .blurred-background {{
        position: fixed;
        top: 0;
        left: 0;
        width: 100%;
        height: 100%;
        z-index: -1; /* Send background image behind all UI elements */
        background-image: url("{flower_image_url}");
        background-size: cover;
        background-position: center;
        filter: blur(10px); /* Adjust blur ratio here */
        opacity: 0.8; /* Optional: Add slight transparency for a subtle effect */
    }}
    </style>
    """,
    unsafe_allow_html=True
)

# Add the blurred background div
st.markdown('<div class="blurred-background"></div>', unsafe_allow_html=True)

#"""""""""""""""""""""""""   Application Code Starts here   """""""""""""""""""""""""""""""""""""""""""""

# Cache resource for dataset loading
@st.cache_resource
def load_counseling_dataset():
    # Load the counseling conversations dataset (only a train split is published)
    dataset = load_dataset("Amod/mental_health_counseling_conversations", split="train")
    return dataset

# Stream a bounded random sample of examples, one at a time, so the full
# dataset never has to sit in memory (helper; not called in the flow below)
def process_dataset_in_batches(dataset, batch_size=500):
    for example in dataset.shuffle().select(range(batch_size)):
        yield example
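# Illustrative usage (a sketch, not wired into the app): eyeball a few raw
# examples before fine-tuning. Assumes the "context" field used below.
# for example in process_dataset_in_batches(load_counseling_dataset(), batch_size=5):
#     print(example["context"][:100])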

# Fine-tune the model and save it
@st.cache_resource
def fine_tune_model():
    # Load base model and tokenizer
    model_name = "prabureddy/Mental-Health-FineTuned-Mistral-7B-Instruct-v0.2"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    # Mistral tokenizers ship without a pad token; the data collator needs one
    # for padding, so fall back to the EOS token.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    model = AutoModelForCausalLM.from_pretrained(model_name)

    # Enable gradient checkpointing for memory optimization
    model.gradient_checkpointing_enable()

    # Prepare dataset for training
    dataset = load_counseling_dataset()
    
    def preprocess_function(examples):
        # Batched map passes lists, so join each context/response pair explicitly
        # (string-concatenating the raw lists would raise a TypeError).
        texts = [c + "\n" + r for c, r in zip(examples["context"], examples["response"])]
        return tokenizer(texts, truncation=True)

    # Drop the raw text columns after tokenization so the collator only sees
    # token fields, and hold out 10% for evaluation (the source has no eval split)
    tokenized_dataset = dataset.map(
        preprocess_function, batched=True, remove_columns=dataset.column_names
    )
    split_datasets = tokenized_dataset.train_test_split(test_size=0.1)
    data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)  # causal LM, not masked LM

    # Training arguments
    training_args = TrainingArguments(
        output_dir="./fine_tuned_model",
        evaluation_strategy="steps",
        learning_rate=2e-5,
        per_device_train_batch_size=5,
        per_device_eval_batch_size=5,
        num_train_epochs=3,
        weight_decay=0.01,
        fp16=True,  # Enable FP16 for lower memory usage
        save_total_limit=2,
        save_steps=250,
        logging_steps=50,
    )
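    # If GPU memory is still tight, a similar effective batch size can be kept
    # by trading batch size for gradient accumulation (illustrative values):
    #   per_device_train_batch_size=1, gradient_accumulation_steps=5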

    # Trainer
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=split_datasets["train"],
        eval_dataset=split_datasets["test"],
        tokenizer=tokenizer,
        data_collator=data_collator,
    )

    trainer.train()

    # Save the fine-tuned model
    trainer.save_model("./fine_tuned_model")
    tokenizer.save_pretrained("./fine_tuned_model")
    return "./fine_tuned_model"

# Fine-tune once per session (st.cache_resource memoizes the returned path)
model_dir = fine_tune_model()
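
# Possible refinement (a sketch, not part of the original flow): reuse a
# checkpoint saved by an earlier run instead of retraining on every restart.
# Requires `import os`.
# if os.path.isdir("./fine_tuned_model"):
#     model_dir = "./fine_tuned_model"
# else:
#     model_dir = fine_tune_model()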

# Load the fine-tuned model for inference
@st.cache_resource
def load_pipeline(model_dir):
    return pipeline("text-generation", model=model_dir)

pipe = load_pipeline(model_dir)
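
# Optional smoke test (illustrative): run once outside Streamlit to confirm
# the pipeline generates before serving the UI.
# print(pipe("I feel anxious before exams.", max_new_tokens=40)[0]["generated_text"])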

# Streamlit App
st.title("Mental Health Support Assistant")
st.markdown("""
Welcome to the **Mental Health Support Assistant**.  
This tool helps detect potential mental health concerns based on user input and provides **uplifting and positive suggestions** to boost morale.
""")

# User input for mental health concerns
user_input = st.text_area("Please share your concern:", placeholder="Type your question or concern here...")

if st.button("Get Supportive Response"):
    if user_input.strip():
        with st.spinner("Analyzing your input and generating a response..."):
            try:
                # Generate only the model's completion (return_full_text=False
                # strips the echoed prompt from the pipeline output)
                response = pipe(
                    user_input,
                    max_new_tokens=150,
                    num_return_sequences=1,
                    return_full_text=False,
                )[0]["generated_text"]
                st.subheader("Supportive Suggestion:")
                st.markdown(f"**{response}**")
            except Exception as e:
                st.error(f"An error occurred while generating the response: {e}")
    else:
        st.error("Please enter a concern to receive suggestions.")

# Sidebar for additional resources
st.sidebar.header("Additional Resources")
st.sidebar.markdown("""
- [Mental Health Foundation](https://www.mentalhealth.org)
- [Mind](https://www.mind.org.uk)
- [National Suicide Prevention Lifeline](https://suicidepreventionlifeline.org)
""")
st.sidebar.info("This application is not a replacement for professional counseling. If you're in crisis, seek professional help immediately.")