Yesandu commited on
Commit
9210351
·
verified ·
1 Parent(s): 372b9c1

Update app.py

Browse files

added a description for diagnose

Files changed (1) hide show
  1. app.py +106 -42
app.py CHANGED
@@ -1,52 +1,119 @@
1
- from transformers import AutoTokenizer, TFAutoModelForSequenceClassification
2
- import tensorflow as tf
3
  import gradio as gr
 
 
 
4
 
5
- # Load the tokenizer and model
6
- model_name = "Zabihin/Symptom_to_Diagnosis"
7
- tokenizer = AutoTokenizer.from_pretrained(model_name)
8
  model = TFAutoModelForSequenceClassification.from_pretrained(model_name)
 
9
 
10
- # Clean the input text
11
- def clean_input(symptom_text):
12
- # Remove unwanted characters or non-ASCII characters
13
- symptom_text = ''.join([c for c in symptom_text if ord(c) < 128])
14
- symptom_text = symptom_text.lower() # Optional: Convert to lowercase
15
- return symptom_text
16
-
17
- # Define the predict function
18
  def predict(symptom_text, chat_history=[]):
19
- try:
20
- # Clean the input
21
- symptom_text = clean_input(symptom_text)
22
-
23
- # Tokenize the input
24
- inputs = tokenizer(symptom_text, return_tensors="tf", padding=True, truncation=True, max_length=512)
25
 
26
- # Get model output
27
- outputs = model(**inputs)
28
- logits = outputs.logits
29
- prediction = tf.argmax(logits, axis=-1).numpy()[0]
30
 
31
- # Map the prediction to a label
32
- labels = {
33
- 0: "Allergy", 1: "Arthritis", 2: "Bronchial Asthma", 3: "Cervical Spondylosis",
34
- 4: "Chicken Pox", 5: "Common Cold", 6: "Dengue", 7: "Diabetes", 8: "Drug Reaction",
35
- 9: "Fungal Infection", 10: "Gastroesophageal Reflux Disease", 11: "Hypertension",
36
- 12: "Impetigo", 13: "Jaundice", 14: "Malaria", 15: "Migraine", 16: "Peptic Ulcer Disease",
37
- 17: "Pneumonia", 18: "Psoriasis", 19: "Typhoid", 20: "Urinary Tract Infection", 21: "Varicose Veins"
38
- }
 
 
 
 
 
 
39
 
40
- diagnosis = labels.get(prediction, "Unknown diagnosis")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
41
 
42
- # Add conversation history
43
- chat_history.append(("User", symptom_text))
44
- chat_history.append(("AI", f"Predicted Diagnosis: {diagnosis}. Please consult a doctor for more accurate results."))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
45
 
46
- except Exception as e:
47
- chat_history.append(("AI", f"Error: {str(e)}"))
48
-
49
- return chat_history, ""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
50
 
51
  # Gradio UI
52
  with gr.Blocks() as interface:
@@ -54,10 +121,7 @@ with gr.Blocks() as interface:
54
  chatbot = gr.Chatbot()
55
  input_box = gr.Textbox(show_label=False, placeholder="Describe your symptoms here...")
56
  send_button = gr.Button("Send")
57
-
58
- input_box.submit(predict, [input_box, chatbot], [chatbot, input_box])
59
  send_button.click(predict, [input_box, chatbot], [chatbot, input_box])
60
 
61
  if __name__ == "__main__":
62
  interface.launch(share=True, server_name="0.0.0.0", server_port=7860, debug=True)
63
- h
 
 
 
1
  import gradio as gr
2
+ from transformers import TFAutoModelForSequenceClassification, AutoTokenizer
3
+ import tensorflow as tf
4
+ import numpy as np
5
 
6
+ # Load the pre-trained model and tokenizer
7
+ model_name = "your_model_name_here"
 
8
  model = TFAutoModelForSequenceClassification.from_pretrained(model_name)
9
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
10
 
11
+ # Function to predict the disease based on symptoms
 
 
 
 
 
 
 
12
  def predict(symptom_text, chat_history=[]):
13
+ # Preprocess the input (lowercase, remove special characters)
14
+ symptom_text = symptom_text.lower().strip()
 
 
 
 
15
 
16
+ # Tokenize the input symptoms
17
+ inputs = tokenizer(symptom_text, return_tensors="tf", padding=True, truncation=True, max_length=512)
 
 
18
 
19
+ # Make the prediction
20
+ outputs = model(**inputs)
21
+ prediction = tf.argmax(outputs.logits, axis=-1).numpy()[0]
22
+ predicted_label = get_prediction_label(prediction)
23
+
24
+ # Provide a more descriptive output
25
+ response = f"Predicted Condition: {predicted_label}\n\n"
26
+ response += f"Description: {get_condition_description(predicted_label)}\n\n"
27
+ response += "Common symptoms for this condition include:\n"
28
+ response += get_condition_symptoms(predicted_label)
29
+
30
+ chat_history.append(("User", symptom_text))
31
+ chat_history.append(("AI", response))
32
+ return response, chat_history
33
 
34
+ # Function to map the prediction index to a condition
35
+ def get_prediction_label(prediction_index):
36
+ # Mapping of prediction index to disease
37
+ condition_map = {
38
+ 0: "Common Cold", # Common Cold
39
+ 1: "Flu", # Flu
40
+ 2: "Covid-19", # Covid-19
41
+ 3: "Diabetes", # Diabetes
42
+ 4: "Pneumonia", # Pneumonia
43
+ 5: "Allergy", # Allergy
44
+ 6: "Asthma", # Asthma
45
+ 7: "Cancer", # Cancer
46
+ 8: "Heart Disease", # Heart Disease
47
+ 9: "Hypertension", # Hypertension
48
+ 10: "Gastroenteritis", # Gastroenteritis
49
+ 11: "Tuberculosis", # Tuberculosis
50
+ 12: "Migraine", # Migraine
51
+ 13: "Chronic Obstructive Pulmonary Disease", # COPD
52
+ 14: "Stroke", # Stroke
53
+ 15: "Acid Reflux", # Acid Reflux
54
+ 16: "Kidney Disease", # Kidney Disease
55
+ 17: "Liver Disease", # Liver Disease
56
+ 18: "Anemia", # Anemia
57
+ 19: "Sepsis", # Sepsis
58
+ # Add more as needed...
59
+ }
60
+ return condition_map.get(prediction_index, "Unknown Condition")
61
 
62
+ # Function to get the description for a condition
63
+ def get_condition_description(condition):
64
+ # Description of conditions
65
+ condition_descriptions = {
66
+ "Common Cold": "A viral infection of your upper respiratory system, causing symptoms like a runny nose, sore throat, and cough.",
67
+ "Flu": "An infectious respiratory illness caused by influenza viruses, with symptoms including fever, chills, and body aches.",
68
+ "Covid-19": "A respiratory illness caused by the SARS-CoV-2 virus, with symptoms ranging from mild cold-like symptoms to severe respiratory distress.",
69
+ "Diabetes": "A group of diseases that affect how your body uses blood sugar (glucose), leading to high blood sugar levels over time.",
70
+ "Pneumonia": "An infection that inflames the air sacs in one or both lungs, which can lead to symptoms like cough, fever, and difficulty breathing.",
71
+ "Allergy": "A reaction of your immune system to a substance that is usually harmless, leading to symptoms such as sneezing, itching, and swelling.",
72
+ "Asthma": "A chronic condition in which your airways narrow and swell, causing difficulty breathing, wheezing, and coughing.",
73
+ "Cancer": "A group of diseases involving abnormal cell growth that can spread to other parts of the body, often accompanied by unexplained weight loss and fatigue.",
74
+ "Heart Disease": "A range of conditions that affect the heart, including coronary artery disease, arrhythmias, and heart attacks, leading to chest pain and shortness of breath.",
75
+ "Hypertension": "High blood pressure, which can lead to serious health problems such as heart disease and stroke if left untreated.",
76
+ "Gastroenteritis": "Inflammation of the stomach and intestines, typically caused by a viral or bacterial infection, leading to diarrhea, vomiting, and stomach cramps.",
77
+ "Tuberculosis": "A potentially serious bacterial infection that primarily affects the lungs, causing persistent cough, chest pain, and weight loss.",
78
+ "Migraine": "A neurological condition characterized by intense headaches, often accompanied by nausea, vomiting, and sensitivity to light and sound.",
79
+ "Chronic Obstructive Pulmonary Disease": "A group of lung diseases, including emphysema and chronic bronchitis, that cause long-term breathing problems.",
80
+ "Stroke": "A medical emergency in which the blood supply to part of the brain is interrupted, leading to symptoms such as numbness, difficulty speaking, and confusion.",
81
+ "Acid Reflux": "A condition in which stomach acid frequently flows back into the esophagus, causing heartburn and discomfort in the chest area.",
82
+ "Kidney Disease": "A condition in which the kidneys are damaged and unable to filter blood properly, leading to symptoms such as fatigue and swelling in the legs.",
83
+ "Liver Disease": "A group of conditions that damage the liver, leading to symptoms such as jaundice, abdominal pain, and swelling.",
84
+ "Anemia": "A condition where the blood doesn't have enough healthy red blood cells to carry oxygen to your body's tissues, causing fatigue and weakness.",
85
+ "Sepsis": "A life-threatening condition caused by the body's response to an infection, leading to symptoms such as fever, confusion, and rapid heart rate.",
86
+ # Add more conditions and descriptions as necessary...
87
+ }
88
+ return condition_descriptions.get(condition, "No description available.")
89
 
90
+ # Function to get the symptoms for a condition
91
+ def get_condition_symptoms(condition):
92
+ # Mapping of condition to common symptoms
93
+ condition_symptoms = {
94
+ "Common Cold": "Cough, runny nose, sore throat, congestion, mild headache",
95
+ "Flu": "Fever, chills, cough, body aches, sore throat, fatigue",
96
+ "Covid-19": "Cough, fever, shortness of breath, fatigue, loss of taste/smell",
97
+ "Diabetes": "Frequent urination, increased thirst, blurred vision, fatigue",
98
+ "Pneumonia": "Cough, fever, shortness of breath, chest pain",
99
+ "Allergy": "Sneezing, runny nose, itchy eyes, wheezing",
100
+ "Asthma": "Wheezing, shortness of breath, chest tightness, coughing",
101
+ "Cancer": "Fatigue, unexplained weight loss, persistent pain, cough",
102
+ "Heart Disease": "Chest pain, shortness of breath, fatigue, irregular heartbeat",
103
+ "Hypertension": "Headaches, dizziness, shortness of breath, blurred vision",
104
+ "Gastroenteritis": "Diarrhea, vomiting, stomach cramps, nausea",
105
+ "Tuberculosis": "Persistent cough, chest pain, coughing up blood, night sweats",
106
+ "Migraine": "Severe headache, nausea, vomiting, sensitivity to light/sound",
107
+ "Chronic Obstructive Pulmonary Disease": "Chronic cough, shortness of breath, wheezing, fatigue",
108
+ "Stroke": "Sudden numbness, confusion, trouble speaking, severe headache",
109
+ "Acid Reflux": "Heartburn, chest pain, regurgitation, difficulty swallowing",
110
+ "Kidney Disease": "Fatigue, swelling in the legs, loss of appetite, difficulty urinating",
111
+ "Liver Disease": "Jaundice, abdominal pain, dark urine, swelling in the legs",
112
+ "Anemia": "Fatigue, weakness, pale skin, shortness of breath",
113
+ "Sepsis": "Fever, chills, rapid heart rate, confusion, low blood pressure",
114
+ # Add more conditions and symptoms as necessary...
115
+ }
116
+ return condition_symptoms.get(condition, "No symptoms available.")
117
 
118
  # Gradio UI
119
  with gr.Blocks() as interface:
 
121
  chatbot = gr.Chatbot()
122
  input_box = gr.Textbox(show_label=False, placeholder="Describe your symptoms here...")
123
  send_button = gr.Button("Send")
 
 
124
  send_button.click(predict, [input_box, chatbot], [chatbot, input_box])
125
 
126
  if __name__ == "__main__":
127
  interface.launch(share=True, server_name="0.0.0.0", server_port=7860, debug=True)