File size: 4,819 Bytes
5d44adb
9210351
5d44adb
26a55ff
5d44adb
d1842d2
 
 
282969e
d1842d2
5d44adb
d1842d2
5d44adb
d1842d2
 
5d44adb
 
282969e
5d44adb
 
 
 
 
 
282969e
5d44adb
 
 
 
282969e
d1842d2
5d44adb
d1842d2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5d44adb
282969e
d1842d2
 
282969e
5d44adb
 
d1842d2
282969e
5d44adb
 
 
 
282969e
 
 
3d03888
 
 
282969e
 
 
5d44adb
 
282969e
26a55ff
 
5654e11
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
import logging

from transformers import AutoTokenizer, TFAutoModelForSequenceClassification
import tensorflow as tf
import gradio as gr

# Tokenizer and TF sequence-classification model for the given Hugging Face
# model id, loaded once when this module is imported and shared by predict().
model_name = "Zabihin/Symptom_to_Diagnosis"
tokenizer, model = (
    AutoTokenizer.from_pretrained(model_name),
    TFAutoModelForSequenceClassification.from_pretrained(model_name),
)

# Clean the input text
def clean_input(symptom_text):
    # Remove unwanted characters or non-ASCII characters
    symptom_text = ''.join([c for c in symptom_text if ord(c) < 128])
    symptom_text = symptom_text.lower()  # Optional: Convert to lowercase
    return symptom_text

# Define the predict function
def predict(symptom_text, chat_history=[]):
    try:
        # Clean the input
        symptom_text = clean_input(symptom_text)

        # Tokenize the input
        inputs = tokenizer(symptom_text, return_tensors="tf", padding=True, truncation=True, max_length=512)

        # Get model output
        outputs = model(**inputs)
        logits = outputs.logits
        prediction = tf.argmax(logits, axis=-1).numpy()[0]

        # Map the prediction to a label
        labels = {
            0: "Allergy", 1: "Arthritis", 2: "Bronchial Asthma", 3: "Cervical Spondylosis", 
            4: "Chicken Pox", 5: "Common Cold", 6: "Dengue", 7: "Diabetes", 8: "Drug Reaction", 
            9: "Fungal Infection", 10: "Gastroesophageal Reflux Disease", 11: "Hypertension", 
            12: "Impetigo", 13: "Jaundice", 14: "Malaria", 15: "Migraine", 16: "Peptic Ulcer Disease", 
            17: "Pneumonia", 18: "Psoriasis", 19: "Typhoid", 20: "Urinary Tract Infection", 21: "Varicose Veins"
        }

        descriptions = {
            "Allergy": "An immune system reaction to foreign substances.",
            "Arthritis": "Inflammation of one or more joints.",
            "Bronchial Asthma": "A condition where the airways become inflamed and narrow.",
            "Cervical Spondylosis": "Age-related changes in the bones, discs, and joints of the neck.",
            "Chicken Pox": "A highly contagious viral infection causing an itchy skin rash.",
            "Common Cold": "A viral infection of the upper respiratory tract, causing sneezing, runny nose, and sore throat.",
            "Dengue": "A viral disease transmitted by mosquitoes, causing fever and severe pain.",
            "Diabetes": "A disease that affects how your body processes blood sugar.",
            "Drug Reaction": "An adverse response to a medication.",
            "Fungal Infection": "An infection caused by fungi affecting the skin or organs.",
            "Gastroesophageal Reflux Disease": "A chronic digestive condition where stomach acid irritates the food pipe.",
            "Hypertension": "High blood pressure that can lead to heart disease.",
            "Impetigo": "A contagious bacterial skin infection.",
            "Jaundice": "A yellowing of the skin or eyes due to liver disease.",
            "Malaria": "A serious disease transmitted by mosquito bites, causing fever and chills.",
            "Migraine": "Severe headaches often accompanied by nausea and sensitivity to light.",
            "Peptic Ulcer Disease": "Sores in the stomach lining or the upper part of the small intestine.",
            "Pneumonia": "An infection that inflames the air sacs in one or both lungs.",
            "Psoriasis": "A chronic autoimmune disease causing the rapid growth of skin cells.",
            "Typhoid": "A bacterial infection causing high fever, abdominal pain, and weakness.",
            "Urinary Tract Infection": "An infection in any part of the urinary system.",
            "Varicose Veins": "Swollen, twisted veins caused by faulty valves in the veins."
        }

        diagnosis = labels.get(prediction, "Unknown diagnosis")
        description = descriptions.get(diagnosis, "No description available.")

        # Add conversation history
        chat_history.append(("User", symptom_text))
        chat_history.append(("AI", f"Predicted Diagnosis: <b>{diagnosis}</b>. {description} Please consult a doctor for more accurate results."))

    except Exception as e:
        chat_history.append(("AI", f"Error: {str(e)}"))

    return chat_history, ""

# Gradio UI
with gr.Blocks() as interface:
    gr.Markdown("""
        <h1 style='text-align: center; font-size: 50px; margin-top: 50px; margin-bottom: 30px;'>Medi Mind - Your AI Health Assistant</h1>
    """)
    chatbot = gr.Chatbot()
    input_box = gr.Textbox(show_label=False, placeholder="Describe your symptoms here...")
    send_button = gr.Button("Send")

    input_box.submit(predict, [input_box, chatbot], [chatbot, input_box])
    send_button.click(predict, [input_box, chatbot], [chatbot, input_box])

if __name__ == "__main__":
    interface.launch(share=True, server_name="0.0.0.0", server_port=7860, debug=True)