File size: 4,819 Bytes
5d44adb 9210351 5d44adb 26a55ff 5d44adb d1842d2 282969e d1842d2 5d44adb d1842d2 5d44adb d1842d2 5d44adb 282969e 5d44adb 282969e 5d44adb 282969e d1842d2 5d44adb d1842d2 5d44adb 282969e d1842d2 282969e 5d44adb d1842d2 282969e 5d44adb 282969e 3d03888 282969e 5d44adb 282969e 26a55ff 5654e11 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 |
import logging
import unicodedata

import gradio as gr
import tensorflow as tf
from transformers import AutoTokenizer, TFAutoModelForSequenceClassification
# Hugging Face checkpoint for the symptom -> diagnosis sequence classifier.
model_name = "Zabihin/Symptom_to_Diagnosis"
# Tokenizer and TF classification model are loaded once, at import time, and
# reused by every call to predict().
tokenizer, model = (
    AutoTokenizer.from_pretrained(model_name),
    TFAutoModelForSequenceClassification.from_pretrained(model_name),
)
# Clean the input text
def clean_input(symptom_text):
    """Normalise free-text symptoms to lowercase ASCII for the tokenizer.

    Non-ASCII characters are still removed, but without gluing neighbouring
    words together:

    * the text is NFKD-normalised first, so accented letters fall back to
      their base letter ("fi\u00e8vre" -> "fievre") and compatibility
      characters such as the no-break space become their ASCII equivalent;
    * the combining accents produced by that step are dropped;
    * any other non-ASCII character (dashes, curly quotes, emoji, ...) is
      replaced by a space rather than deleted.

    Runs of whitespace are then collapsed to single spaces and the ends are
    trimmed before lowercasing.

    Args:
        symptom_text: Text typed by the user. ``None`` is treated as empty.

    Returns:
        The cleaned, lowercased string (``""`` for empty or ``None`` input).
    """
    if symptom_text is None:
        return ""
    decomposed = unicodedata.normalize("NFKD", str(symptom_text))
    kept = []
    for ch in decomposed:
        if ord(ch) < 128:
            kept.append(ch)
        elif not unicodedata.combining(ch):
            # Deleting this outright would turn "pain\u2014fever" into
            # "painfever"; a space keeps the two words separate.
            kept.append(" ")
        # Combining marks left over from NFKD (accents) are simply dropped.
    return " ".join("".join(kept).split()).lower()
# Logit index -> diagnosis name, hoisted so it is not rebuilt on every request.
# NOTE(review): assumed to follow the checkpoint's output order; worth
# confirming against model.config.id2label.
DIAGNOSIS_LABELS = {
    0: "Allergy", 1: "Arthritis", 2: "Bronchial Asthma", 3: "Cervical Spondylosis",
    4: "Chicken Pox", 5: "Common Cold", 6: "Dengue", 7: "Diabetes", 8: "Drug Reaction",
    9: "Fungal Infection", 10: "Gastroesophageal Reflux Disease", 11: "Hypertension",
    12: "Impetigo", 13: "Jaundice", 14: "Malaria", 15: "Migraine", 16: "Peptic Ulcer Disease",
    17: "Pneumonia", 18: "Psoriasis", 19: "Typhoid", 20: "Urinary Tract Infection", 21: "Varicose Veins",
}

# Short lay description shown next to each predicted diagnosis.
DIAGNOSIS_DESCRIPTIONS = {
    "Allergy": "An immune system reaction to foreign substances.",
    "Arthritis": "Inflammation of one or more joints.",
    "Bronchial Asthma": "A condition where the airways become inflamed and narrow.",
    "Cervical Spondylosis": "Age-related changes in the bones, discs, and joints of the neck.",
    "Chicken Pox": "A highly contagious viral infection causing an itchy skin rash.",
    "Common Cold": "A viral infection of the upper respiratory tract, causing sneezing, runny nose, and sore throat.",
    "Dengue": "A viral disease transmitted by mosquitoes, causing fever and severe pain.",
    "Diabetes": "A disease that affects how your body processes blood sugar.",
    "Drug Reaction": "An adverse response to a medication.",
    "Fungal Infection": "An infection caused by fungi affecting the skin or organs.",
    "Gastroesophageal Reflux Disease": "A chronic digestive condition where stomach acid irritates the food pipe.",
    "Hypertension": "High blood pressure that can lead to heart disease.",
    "Impetigo": "A contagious bacterial skin infection.",
    "Jaundice": "A yellowing of the skin or eyes due to liver disease.",
    "Malaria": "A serious disease transmitted by mosquito bites, causing fever and chills.",
    "Migraine": "Severe headaches often accompanied by nausea and sensitivity to light.",
    "Peptic Ulcer Disease": "Sores in the stomach lining or the upper part of the small intestine.",
    "Pneumonia": "An infection that inflames the air sacs in one or both lungs.",
    "Psoriasis": "A chronic autoimmune disease causing the rapid growth of skin cells.",
    "Typhoid": "A bacterial infection causing high fever, abdominal pain, and weakness.",
    "Urinary Tract Infection": "An infection in any part of the urinary system.",
    "Varicose Veins": "Swollen, twisted veins caused by faulty valves in the veins.",
}


def _diagnose(cleaned_text):
    """Run the classifier on already-cleaned text and build the chat reply."""
    inputs = tokenizer(cleaned_text, return_tensors="tf", padding=True, truncation=True, max_length=512)
    logits = model(**inputs).logits
    # A single string was tokenized, so the batch has one row; int() turns the
    # numpy scalar from .numpy() into a plain Python dict key.
    prediction = int(tf.argmax(logits, axis=-1).numpy()[0])
    diagnosis = DIAGNOSIS_LABELS.get(prediction, "Unknown diagnosis")
    description = DIAGNOSIS_DESCRIPTIONS.get(diagnosis, "No description available.")
    return (
        f"Predicted Diagnosis: <b>{diagnosis}</b>. {description} "
        "Please consult a doctor for more accurate results."
    )


# Define the predict function
def predict(symptom_text, chat_history=None):
    """Classify a symptom description and append the exchange to the chat.

    Args:
        symptom_text: Raw text typed by the user.
        chat_history: Current Gradio ``Chatbot`` value, assumed to be in
            Gradio's default tuple format (a list of
            ``(user_message, bot_message)`` pairs), or ``None``.

    Returns:
        ``(history, "")``: a new history list with one extra
        ``(symptom_text, reply)`` pair, and an empty string that clears the
        input textbox. Empty input leaves the history unchanged.
    """
    # Work on a copy: never mutate the caller's list, and no shared default.
    history = list(chat_history) if chat_history else []
    if not symptom_text or not str(symptom_text).strip():
        # Nothing typed: don't fabricate a diagnosis for an empty prompt.
        return history, ""
    try:
        cleaned = clean_input(symptom_text)
        if not cleaned.strip():
            # e.g. input made only of emoji/non-ASCII characters.
            reply = "Please describe your symptoms in plain text."
        else:
            reply = _diagnose(cleaned)
    except Exception as e:
        # UI boundary: surface the failure in the chat, but keep the traceback.
        logging.getLogger(__name__).exception("Symptom prediction failed")
        reply = f"Error: {str(e)}"
    # One (user, bot) pair per exchange, as the tuple-format Chatbot expects.
    history.append((symptom_text, reply))
    return history, ""
# Gradio UI: page title, chat transcript, symptom textbox and a Send button.
with gr.Blocks() as interface:
    gr.Markdown("""
<h1 style='text-align: center; font-size: 50px; margin-top: 50px; margin-bottom: 30px;'>Medi Mind - Your AI Health Assistant</h1>
""")
    chatbot = gr.Chatbot()
    input_box = gr.Textbox(show_label=False, placeholder="Describe your symptoms here...")
    send_button = gr.Button("Send")
    # Enter in the textbox and the Send button run the same handler; predict
    # returns (history, "") so the chat updates and the textbox is cleared.
    input_box.submit(fn=predict, inputs=[input_box, chatbot], outputs=[chatbot, input_box])
    send_button.click(fn=predict, inputs=[input_box, chatbot], outputs=[chatbot, input_box])

if __name__ == "__main__":
    # Listens on all interfaces (0.0.0.0) at port 7860; share=True also asks
    # Gradio for a public share link.
    interface.launch(
        share=True,
        server_name="0.0.0.0",
        server_port=7860,
        debug=True,
    )
|