# Medi_Mind / app.py
# Yesandu's picture
# Update app.py
# 3d03888 verified
from transformers import AutoTokenizer, TFAutoModelForSequenceClassification
import tensorflow as tf
import gradio as gr
# Load the tokenizer and model once, at import time. `predict` below reads
# `tokenizer` and `model` as module-level globals, so these names must stay
# defined at module scope.
# NOTE(review): `from_pretrained` presumably fetches the checkpoint from the
# Hugging Face Hub on first run, making startup depend on network access --
# confirm for the deployment environment.
model_name = "Zabihin/Symptom_to_Diagnosis"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = TFAutoModelForSequenceClassification.from_pretrained(model_name)
# Clean the input text
def clean_input(symptom_text):
    """Return *symptom_text* reduced to lowercase ASCII.

    Every character whose code point is 128 or above is dropped (it is
    removed, not transliterated), and the remaining text is lowercased.

    Args:
        symptom_text: Raw text to clean.

    Returns:
        The ASCII-only, lowercased text. An empty string stays empty.
    """
    # Keep only 7-bit ASCII characters; everything else is discarded.
    ascii_only = "".join(c for c in symptom_text if ord(c) < 128)
    return ascii_only.lower()
# Define the predict function

# Class index produced by the model -> diagnosis name.
_LABELS = {
    0: "Allergy",
    1: "Arthritis",
    2: "Bronchial Asthma",
    3: "Cervical Spondylosis",
    4: "Chicken Pox",
    5: "Common Cold",
    6: "Dengue",
    7: "Diabetes",
    8: "Drug Reaction",
    9: "Fungal Infection",
    10: "Gastroesophageal Reflux Disease",
    11: "Hypertension",
    12: "Impetigo",
    13: "Jaundice",
    14: "Malaria",
    15: "Migraine",
    16: "Peptic Ulcer Disease",
    17: "Pneumonia",
    18: "Psoriasis",
    19: "Typhoid",
    20: "Urinary Tract Infection",
    21: "Varicose Veins",
}

# Diagnosis name -> one-sentence description shown alongside the prediction.
_DESCRIPTIONS = {
    "Allergy": "An immune system reaction to foreign substances.",
    "Arthritis": "Inflammation of one or more joints.",
    "Bronchial Asthma": "A condition where the airways become inflamed and narrow.",
    "Cervical Spondylosis": "Age-related changes in the bones, discs, and joints of the neck.",
    "Chicken Pox": "A highly contagious viral infection causing an itchy skin rash.",
    "Common Cold": "A viral infection of the upper respiratory tract, causing sneezing, runny nose, and sore throat.",
    "Dengue": "A viral disease transmitted by mosquitoes, causing fever and severe pain.",
    "Diabetes": "A disease that affects how your body processes blood sugar.",
    "Drug Reaction": "An adverse response to a medication.",
    "Fungal Infection": "An infection caused by fungi affecting the skin or organs.",
    "Gastroesophageal Reflux Disease": "A chronic digestive condition where stomach acid irritates the food pipe.",
    "Hypertension": "High blood pressure that can lead to heart disease.",
    "Impetigo": "A contagious bacterial skin infection.",
    "Jaundice": "A yellowing of the skin or eyes due to liver disease.",
    "Malaria": "A serious disease transmitted by mosquito bites, causing fever and chills.",
    "Migraine": "Severe headaches often accompanied by nausea and sensitivity to light.",
    "Peptic Ulcer Disease": "Sores in the stomach lining or the upper part of the small intestine.",
    "Pneumonia": "An infection that inflames the air sacs in one or both lungs.",
    "Psoriasis": "A chronic autoimmune disease causing the rapid growth of skin cells.",
    "Typhoid": "A bacterial infection causing high fever, abdominal pain, and weakness.",
    "Urinary Tract Infection": "An infection in any part of the urinary system.",
    "Varicose Veins": "Swollen, twisted veins caused by faulty valves in the veins.",
}


def predict(symptom_text, chat_history=None):
    """Classify a symptom description and append the exchange to the chat.

    The text is cleaned with ``clean_input``, tokenized, run through the
    module-level ``model``, and the highest-scoring class index is mapped
    to a diagnosis name and description.

    Args:
        symptom_text: Free-text symptom description typed by the user.
        chat_history: Existing list of ``(user_message, bot_message)``
            pairs. It is extended in place; ``None`` starts a new list.

    Returns:
        A ``(chat_history, "")`` tuple. The empty string is the new value
        for the input textbox, which clears it after each submission.
    """
    # A fresh list per call: a mutable default would be shared by every
    # call (and therefore every user) that omits the argument.
    if chat_history is None:
        chat_history = []

    # Nothing to classify; do not produce a diagnosis from blank input.
    if not symptom_text or not symptom_text.strip():
        return chat_history, ""

    try:
        cleaned_text = clean_input(symptom_text)
        inputs = tokenizer(
            cleaned_text,
            return_tensors="tf",
            padding=True,
            truncation=True,
            max_length=512,
        )
        logits = model(**inputs).logits
        # Single input, so the first entry of the argmax is the prediction.
        prediction = int(tf.argmax(logits, axis=-1).numpy()[0])

        diagnosis = _LABELS.get(prediction, "Unknown diagnosis")
        description = _DESCRIPTIONS.get(diagnosis, "No description available.")
        reply = f"Predicted Diagnosis: <b>{diagnosis}</b>. {description} Please consult a doctor for more accurate results."
    except Exception as e:
        # UI boundary: report the failure in the chat instead of crashing
        # the event handler.
        reply = f"Error: {e}"

    # One pair per turn: first element is the user's message, second is the
    # bot's reply. Appending ("User", text) / ("AI", reply) would display
    # the role names as user messages and the user's text as a bot reply.
    chat_history.append((symptom_text, reply))
    return chat_history, ""
# Gradio UI
with gr.Blocks() as interface:
    # Page title, rendered from inline HTML inside a Markdown component.
    gr.Markdown(
        "\n<h1 style='text-align: center; font-size: 50px; margin-top: 50px; margin-bottom: 30px;'>Medi Mind - Your AI Health Assistant</h1>\n"
    )
    chatbot = gr.Chatbot()
    input_box = gr.Textbox(show_label=False, placeholder="Describe your symptoms here...")
    send_button = gr.Button("Send")

    # Pressing Enter in the textbox and clicking "Send" run the same handler:
    # predict(input_box, chatbot) -> (chatbot, input_box). The second value
    # returned by predict is "", which clears the textbox.
    input_box.submit(predict, [input_box, chatbot], [chatbot, input_box])
    send_button.click(predict, [input_box, chatbot], [chatbot, input_box])
if __name__ == "__main__":
    # Start the app only when run as a script. It listens on every network
    # interface (0.0.0.0) on port 7860 with debug output enabled.
    # NOTE(review): share=True asks Gradio for a public share link --
    # confirm that exposure is intended for this deployment.
    interface.launch(share=True, server_name="0.0.0.0", server_port=7860, debug=True)