Yesandu commited on
Commit
d1842d2
·
verified ·
1 Parent(s): ae900b2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +42 -54
app.py CHANGED
@@ -1,39 +1,18 @@
1
  from transformers import AutoTokenizer, TFAutoModelForSequenceClassification
2
  import tensorflow as tf
3
  import gradio as gr
4
- import os
5
- import re
6
- from nltk.corpus import stopwords
7
- from nltk.stem import PorterStemmer
8
-
9
- # Ensure you have the necessary nltk resources
10
- import nltk
11
- nltk.download('stopwords')
12
-
13
- # Caching the model locally to avoid re-downloading
14
- MODEL_NAME = "Zabihin/Symptom_to_Diagnosis"
15
- CACHE_DIR = "./cached_model"
16
- if not os.path.exists(CACHE_DIR):
17
- os.makedirs(CACHE_DIR)
18
 
19
  # Load the tokenizer and model
20
- tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, cache_dir=CACHE_DIR)
21
- model = TFAutoModelForSequenceClassification.from_pretrained(MODEL_NAME, cache_dir=CACHE_DIR)
22
-
23
- # Initialize stopwords and stemmer
24
- stop_words = set(stopwords.words('english'))
25
- stemmer = PorterStemmer()
26
 
27
- # Clean the input text with advanced preprocessing
28
  def clean_input(symptom_text):
29
- # Remove non-ASCII characters and convert to lowercase
30
  symptom_text = ''.join([c for c in symptom_text if ord(c) < 128])
31
- symptom_text = symptom_text.lower().strip() # Remove leading/trailing spaces
32
-
33
- # Remove stopwords and apply stemming
34
- words = symptom_text.split()
35
- filtered_words = [stemmer.stem(word) for word in words if word not in stop_words]
36
- return ' '.join(filtered_words)
37
 
38
  # Define the predict function
39
  def predict(symptom_text, chat_history=[]):
@@ -49,37 +28,46 @@ def predict(symptom_text, chat_history=[]):
49
  logits = outputs.logits
50
  prediction = tf.argmax(logits, axis=-1).numpy()[0]
51
 
52
- # Map the prediction to a label and description
53
  labels = {
54
- 0: ("Allergy", "A condition in which the immune system reacts abnormally to a foreign substance."),
55
- 1: ("Arthritis", "A disease causing painful inflammation and stiffness of the joints."),
56
- 2: ("Bronchial Asthma", "A condition in which your airways narrow and swell and produce extra mucus."),
57
- 3: ("Cervical Spondylosis", "Age-related wear and tear affecting the bones, discs, and joints in the neck."),
58
- 4: ("Chicken Pox", "A viral infection causing an itchy, blister-like rash on the skin."),
59
- 5: ("Common Cold", "A viral infection affecting the upper respiratory tract, causing a sore throat and runny nose."),
60
- 6: ("Dengue", "A viral illness transmitted by mosquitoes, leading to flu-like symptoms."),
61
- 7: ("Diabetes", "A group of diseases that affect how your body uses blood sugar (glucose)."),
62
- 8: ("Drug Reaction", "A harmful reaction in the body due to a medication or drug."),
63
- 9: ("Fungal Infection", "An infection caused by fungi, often affecting the skin or nails."),
64
- 10: ("Gastroesophageal Reflux Disease", "A digestive disorder where stomach acid irritates the food pipe."),
65
- 11: ("Hypertension", "High blood pressure, a condition that can lead to serious health issues if untreated."),
66
- 12: ("Impetigo", "A contagious bacterial skin infection that causes red sores or blisters."),
67
- 13: ("Jaundice", "A yellowish tint to the skin or eyes caused by excess bilirubin in the blood."),
68
- 14: ("Malaria", "A life-threatening disease transmitted by mosquitoes, caused by a parasite."),
69
- 15: ("Migraine", "A severe headache that can cause intense throbbing or a pulsing sensation."),
70
- 16: ("Peptic Ulcer Disease", "Sores that develop on the lining of the stomach or the upper part of the small intestine."),
71
- 17: ("Pneumonia", "An infection that inflames the air sacs in one or both lungs, causing cough and difficulty breathing."),
72
- 18: ("Psoriasis", "An autoimmune condition that causes skin cells to multiply too quickly, leading to patches of red, scaly skin."),
73
- 19: ("Typhoid", "A bacterial infection caused by Salmonella typhi, leading to fever, fatigue, and abdominal pain."),
74
- 20: ("Urinary Tract Infection", "An infection in any part of the urinary system, including the kidneys, bladder, or urethra."),
75
- 21: ("Varicose Veins", "Swollen and enlarged veins, often appearing blue or dark purple, usually in the legs."),
 
 
 
 
 
 
 
 
76
  }
77
 
78
- diagnosis, description = labels.get(prediction, ("Unknown diagnosis", "No description available for this condition"))
 
79
 
80
  # Add conversation history
81
  chat_history.append(("User", symptom_text))
82
- chat_history.append(("AI", f"Predicted Diagnosis: **{diagnosis}**. {description} Please consult a doctor for more accurate results."))
83
 
84
  except Exception as e:
85
  chat_history.append(("AI", f"Error: {str(e)}"))
@@ -88,7 +76,7 @@ def predict(symptom_text, chat_history=[]):
88
 
89
  # Gradio UI
90
  with gr.Blocks() as interface:
91
- gr.Markdown("<h1 style='text-align: center; font-size: 50px; margin-top: 40px; margin-bottom: 40px;'>Medi Mind - Your AI Health Assistant</h1>")
92
  chatbot = gr.Chatbot()
93
  input_box = gr.Textbox(show_label=False, placeholder="Describe your symptoms here...")
94
  send_button = gr.Button("Send")
 
1
  from transformers import AutoTokenizer, TFAutoModelForSequenceClassification
2
  import tensorflow as tf
3
  import gradio as gr
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4
 
5
  # Load the tokenizer and model
6
+ model_name = "Zabihin/Symptom_to_Diagnosis"
7
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
8
+ model = TFAutoModelForSequenceClassification.from_pretrained(model_name)
 
 
 
9
 
10
+ # Clean the input text
11
  def clean_input(symptom_text):
12
+ # Remove unwanted characters or non-ASCII characters
13
  symptom_text = ''.join([c for c in symptom_text if ord(c) < 128])
14
+ symptom_text = symptom_text.lower() # Optional: Convert to lowercase
15
+ return symptom_text
 
 
 
 
16
 
17
  # Define the predict function
18
  def predict(symptom_text, chat_history=[]):
 
28
  logits = outputs.logits
29
  prediction = tf.argmax(logits, axis=-1).numpy()[0]
30
 
31
+ # Map the prediction to a label
32
  labels = {
33
+ 0: "Allergy", 1: "Arthritis", 2: "Bronchial Asthma", 3: "Cervical Spondylosis",
34
+ 4: "Chicken Pox", 5: "Common Cold", 6: "Dengue", 7: "Diabetes", 8: "Drug Reaction",
35
+ 9: "Fungal Infection", 10: "Gastroesophageal Reflux Disease", 11: "Hypertension",
36
+ 12: "Impetigo", 13: "Jaundice", 14: "Malaria", 15: "Migraine", 16: "Peptic Ulcer Disease",
37
+ 17: "Pneumonia", 18: "Psoriasis", 19: "Typhoid", 20: "Urinary Tract Infection", 21: "Varicose Veins"
38
+ }
39
+
40
+ descriptions = {
41
+ "Allergy": "An immune system reaction to foreign substances.",
42
+ "Arthritis": "Inflammation of one or more joints.",
43
+ "Bronchial Asthma": "A condition where the airways become inflamed and narrow.",
44
+ "Cervical Spondylosis": "Age-related changes in the bones, discs, and joints of the neck.",
45
+ "Chicken Pox": "A highly contagious viral infection causing an itchy skin rash.",
46
+ "Common Cold": "A viral infection of the upper respiratory tract, causing sneezing, runny nose, and sore throat.",
47
+ "Dengue": "A viral disease transmitted by mosquitoes, causing fever and severe pain.",
48
+ "Diabetes": "A disease that affects how your body processes blood sugar.",
49
+ "Drug Reaction": "An adverse response to a medication.",
50
+ "Fungal Infection": "An infection caused by fungi affecting the skin or organs.",
51
+ "Gastroesophageal Reflux Disease": "A chronic digestive condition where stomach acid irritates the food pipe.",
52
+ "Hypertension": "High blood pressure that can lead to heart disease.",
53
+ "Impetigo": "A contagious bacterial skin infection.",
54
+ "Jaundice": "A yellowing of the skin or eyes due to liver disease.",
55
+ "Malaria": "A serious disease transmitted by mosquito bites, causing fever and chills.",
56
+ "Migraine": "Severe headaches often accompanied by nausea and sensitivity to light.",
57
+ "Peptic Ulcer Disease": "Sores in the stomach lining or the upper part of the small intestine.",
58
+ "Pneumonia": "An infection that inflames the air sacs in one or both lungs.",
59
+ "Psoriasis": "A chronic autoimmune disease causing the rapid growth of skin cells.",
60
+ "Typhoid": "A bacterial infection causing high fever, abdominal pain, and weakness.",
61
+ "Urinary Tract Infection": "An infection in any part of the urinary system.",
62
+ "Varicose Veins": "Swollen, twisted veins caused by faulty valves in the veins."
63
  }
64
 
65
+ diagnosis = labels.get(prediction, "Unknown diagnosis")
66
+ description = descriptions.get(diagnosis, "No description available.")
67
 
68
  # Add conversation history
69
  chat_history.append(("User", symptom_text))
70
+ chat_history.append(("AI", f"Predicted Diagnosis: <b>{diagnosis}</b>. {description} Please consult a doctor for more accurate results."))
71
 
72
  except Exception as e:
73
  chat_history.append(("AI", f"Error: {str(e)}"))
 
76
 
77
  # Gradio UI
78
  with gr.Blocks() as interface:
79
+ gr.Markdown("<h1 style='text-align: center; font-size: 40px; margin-top: 30px; margin-bottom: 30px;'>Medi Mind - Your AI Health Assistant</h1>")
80
  chatbot = gr.Chatbot()
81
  input_box = gr.Textbox(show_label=False, placeholder="Describe your symptoms here...")
82
  send_button = gr.Button("Send")