|
import torch |
|
from transformers import AutoTokenizer, AutoModelForCausalLM |
|
import gradio as gr |
|
import spaces |
|
|
|
|
|
# Hub repository holding both the fine-tuned tokenizer and model weights.
MODEL_ID = "ombhojane/mental-health-assistant"

# Load once at import time; generate_response reads these module-level
# objects on every request instead of reloading them.
print("Loading model...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.float32,
    device_map="auto",
)
print("Model loaded successfully!")
|
|
|
@spaces.GPU
def generate_response(message, max_length, temperature):
    """Generate a supportive reply to ``message`` with the fine-tuned model.

    Args:
        message: Text the user typed into the input box (may be empty/None).
        max_length: Maximum number of tokens to generate for the reply, taken
            from the "Response Length" slider. Coerced to ``int`` because
            slider values may arrive as floats.
        temperature: Sampling temperature from the "Response Style" slider.

    Returns:
        The decoded assistant reply with surrounding whitespace stripped, or a
        short prompt asking for input when ``message`` is empty or blank.
    """
    if not message or not message.strip():
        return "Please enter your concerns or feelings to get support."

    # Single-turn ChatML-style prompt; the model continues after the
    # assistant header.
    prompt = f"<|im_start|>user\n{message}<|im_end|>\n<|im_start|>assistant\n"

    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    input_ids = inputs["input_ids"]
    prompt_len = input_ids.shape[-1]

    # Many causal-LM tokenizers define no pad token; fall back to EOS so
    # generate() does not have to guess (and warn) on every call.
    pad_token_id = tokenizer.pad_token_id
    if pad_token_id is None:
        pad_token_id = tokenizer.eos_token_id

    outputs = model.generate(
        input_ids=input_ids,
        attention_mask=inputs.get("attention_mask"),
        # max_new_tokens bounds the reply only; max_length would also count
        # the prompt, so long messages would shrink or exhaust the budget.
        max_new_tokens=int(max_length),
        temperature=float(temperature),
        do_sample=True,
        top_p=0.9,
        pad_token_id=pad_token_id,
    )

    # Decode only the newly generated tokens so the prompt (role markers and
    # the user's own text) never leaks into the reply.
    reply_tokens = outputs[0][prompt_len:]
    return tokenizer.decode(reply_tokens, skip_special_tokens=True).strip()
|
|
|
|
|
# Stylesheet handed to gr.Blocks(css=...) when the UI is built.
# .gradio-container targets Gradio's root app container; .main-title and
# .description only take effect on components created with a matching
# elem_classes=[...] argument.
custom_css = """
.gradio-container {
    font-family: 'Arial', sans-serif;
    background-color: #f5f7f9;
}
.main-title {
    text-align: center;
    color: #2C3E50;
    margin-bottom: 1em;
}
.description {
    text-align: justify;
    margin-bottom: 2em;
    color: #34495E;
}
"""
|
|
|
# Build the single-page UI: header, about text, input/output panel,
# collapsible examples and resources, and the button -> model wiring.
with gr.Blocks(css=custom_css) as demo:
    # elem_classes hooks this header up to the .main-title CSS rule.
    gr.Markdown(
        """
        # Wellness 🌿
        ### Your Compassionate Mental Health Support Companion
        """,
        elem_classes=["main-title"],
    )

    with gr.Row():
        with gr.Column(scale=2):
            # elem_classes hooks this text up to the .description CSS rule.
            gr.Markdown(
                """
                ### About
                Welcome to your safe space for mental health support. This AI assistant is trained to provide
                empathetic listening, emotional support, and helpful suggestions for managing mental health
                concerns. While it's not a replacement for professional help, it can offer guidance and
                support when you need someone to talk to.

                **Remember**: In case of emergency or severe distress, please contact professional mental
                health services or emergency services immediately.
                """,
                elem_classes=["description"],
            )

    with gr.Row():
        # Left column: user input, generation controls and the submit button.
        with gr.Column(scale=2):
            input_text = gr.Textbox(
                lines=4,
                placeholder="Share your thoughts, feelings, or concerns...",
                label="Your Message",
            )

            with gr.Row():
                max_length = gr.Slider(
                    minimum=100,
                    maximum=1000,
                    value=512,
                    step=50,
                    label="Response Length",
                    info="Adjust the length of the response",
                )
                temperature = gr.Slider(
                    minimum=0.1,
                    maximum=1.0,
                    value=0.7,
                    step=0.1,
                    label="Response Style",
                    info="Higher values for more creative responses, lower for more focused ones",
                )

            submit_btn = gr.Button("Get Support", variant="primary")

        # Right column: the generated reply.
        with gr.Column(scale=2):
            output_text = gr.Textbox(
                lines=12,
                label="Support Response",
                show_copy_button=True,
            )

    with gr.Accordion("Sample Conversation Starters", open=False):
        gr.Markdown(
            """
            - I've been feeling overwhelmed lately and having trouble sleeping
            - How can I manage anxiety during stressful situations?
            - I'm having difficulty concentrating at work/school
            - What are some good self-care practices for mental health?
            - I'm feeling lonely and isolated, what can I do?
            """
        )

    with gr.Accordion("Important Resources", open=False):
        # 988 is the US "Suicide & Crisis Lifeline" and accepts both calls
        # and texts; use its real name so users can find it.
        gr.Markdown(
            """
            ### Emergency Contacts
            - 988 Suicide & Crisis Lifeline: call or text 988 (US)
            - Emergency Services: 911 (US) / 112 (EU)

            ### Mental Health Resources
            - National Alliance on Mental Illness (NAMI): 1-800-950-NAMI
            - Crisis Text Line: Text HOME to 741741
            - Psychology Today Therapist Finder
            - BetterHelp Online Counseling

            Remember: This AI assistant is not a substitute for professional mental health care.
            If you're experiencing severe symptoms or having thoughts of self-harm, please seek
            professional help immediately.
            """
        )

    submit_btn.click(
        generate_response,
        inputs=[input_text, max_length, temperature],
        outputs=output_text,
    )
|
|
|
def _main():
    """Start the Gradio server for the Wellness demo."""
    demo.launch()


if __name__ == "__main__":
    _main()