DreamStream-1 commited on
Commit
3814a8a
Β·
verified Β·
1 Parent(s): 8f7df83

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +272 -93
app.py CHANGED
@@ -1,110 +1,289 @@
 
1
  import gradio as gr
2
- import pandas as pd
3
  import numpy as np
 
 
 
 
 
 
 
 
 
 
 
 
 
4
  from sklearn.tree import DecisionTreeClassifier
5
  from sklearn.ensemble import RandomForestClassifier
6
  from sklearn.naive_bayes import GaussianNB
7
  from sklearn.metrics import accuracy_score
8
 
9
- # Load datasets
10
- def load_data():
11
- df = pd.read_csv("Training.csv")
12
- tr = pd.read_csv("Testing.csv")
13
-
14
- # Encode diseases
15
- disease_dict = {
16
- 'Fungal infection': 0, 'Allergy': 1, 'GERD': 2, 'Chronic cholestasis': 3, 'Drug Reaction': 4,
17
- 'Peptic ulcer diseae': 5, 'AIDS': 6, 'Diabetes ': 7, 'Gastroenteritis': 8, 'Bronchial Asthma': 9,
18
- 'Hypertension ': 10, 'Migraine': 11, 'Cervical spondylosis': 12, 'Paralysis (brain hemorrhage)': 13,
19
- 'Jaundice': 14, 'Malaria': 15, 'Chicken pox': 16, 'Dengue': 17, 'Typhoid': 18, 'hepatitis A': 19,
20
- 'Hepatitis B': 20, 'Hepatitis C': 21, 'Hepatitis D': 22, 'Hepatitis E': 23, 'Alcoholic hepatitis': 24,
21
- 'Tuberculosis': 25, 'Common Cold': 26, 'Pneumonia': 27, 'Dimorphic hemmorhoids(piles)': 28,
22
- 'Heart attack': 29, 'Varicose veins': 30, 'Hypothyroidism': 31, 'Hyperthyroidism': 32,
23
- 'Hypoglycemia': 33, 'Osteoarthristis': 34, 'Arthritis': 35,
24
- '(vertigo) Paroymsal Positional Vertigo': 36, 'Acne': 37, 'Urinary tract infection': 38,
25
- 'Psoriasis': 39, 'Impetigo': 40
26
- }
 
 
 
 
 
 
 
 
 
 
 
27
 
28
- df.replace({'prognosis': disease_dict}, inplace=True)
29
- df = df.infer_objects(copy=False)
30
 
31
- tr.replace({'prognosis': disease_dict}, inplace=True)
32
- tr = tr.infer_objects(copy=False)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33
 
34
- return df, tr, disease_dict
35
-
36
- try:
37
- df, tr, disease_dict = load_data()
38
- except FileNotFoundError as e:
39
- raise RuntimeError("Data files not found. Please ensure `Training.csv` and `Testing.csv` are available.")
40
- except Exception as e:
41
- raise RuntimeError(f"An error occurred while loading data: {e}")
42
-
43
- l1 = list(df.columns[:-1])
44
- X = df[l1]
45
- y = df['prognosis']
46
- X_test = tr[l1]
47
- y_test = tr['prognosis']
48
-
49
- # Trained models
50
- def train_models():
51
- models = {
52
- "Decision Tree": DecisionTreeClassifier(),
53
- "Random Forest": RandomForestClassifier(),
54
- "Naive Bayes": GaussianNB()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
55
  }
56
- trained_models = {}
57
- for model_name, model_obj in models.items():
58
- model_obj.fit(X, y)
59
- acc = accuracy_score(y_test, model_obj.predict(X_test))
60
- trained_models[model_name] = (model_obj, acc)
61
- return trained_models
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
62
 
63
- trained_models = train_models()
 
 
 
 
 
 
 
64
 
65
- def predict_disease(model, symptoms):
66
- input_test = np.zeros(len(l1))
 
 
67
  for symptom in symptoms:
68
- if symptom in l1:
69
- input_test[l1.index(symptom)] = 1
70
- prediction = model.predict([input_test])[0]
71
- return list(disease_dict.keys())[list(disease_dict.values()).index(prediction)]
72
-
73
- # Gradio Interface
74
- def app_function(name, symptom1, symptom2, symptom3, symptom4, symptom5):
75
- if not name.strip():
76
- return "Please enter the patient's name."
77
-
78
- symptoms_selected = [s for s in [symptom1, symptom2, symptom3, symptom4, symptom5] if s != "None"]
79
-
80
- if len(symptoms_selected) < 3:
81
- return "Please select at least 3 symptoms for accurate prediction."
82
-
83
- results = []
84
- for model_name, (model, acc) in trained_models.items():
85
- prediction = predict_disease(model, symptoms_selected)
86
- result = f"{model_name} Prediction: Predicted Disease: **{prediction}**"
87
- result += f" (Accuracy: {acc * 100:.2f}%)"
88
- results.append(result)
89
 
90
- return "\n\n".join(results)
91
-
92
- # Gradio Interface Setup
93
- iface = gr.Interface(
94
- fn=app_function,
95
- inputs=[
96
- gr.Textbox(label="Name of Patient"),
97
- gr.Dropdown(["None"] + l1, label="Symptom 1"),
98
- gr.Dropdown(["None"] + l1, label="Symptom 2"),
99
- gr.Dropdown(["None"] + l1, label="Symptom 3"),
100
- gr.Dropdown(["None"] + l1, label="Symptom 4"),
101
- gr.Dropdown(["None"] + l1, label="Symptom 5"),
102
- ],
103
- outputs=gr.Textbox(label="Prediction"),
104
- title="Disease Predictor Using Machine Learning",
105
- description="For accurate results, please select at least 3 symptoms.",
106
- article="**Caution:** This system is designed for informational purposes only. Please visit a healthcare provider for any medical concerns."
107
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
108
 
109
  # Launch the Gradio application
110
- iface.launch()
 
1
+ import os
2
  import gradio as gr
3
+ import nltk
4
  import numpy as np
5
+ import tflearn
6
+ import random
7
+ import json
8
+ import pickle
9
+ from nltk.tokenize import word_tokenize
10
+ from nltk.stem.lancaster import LancasterStemmer
11
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification, pipeline
12
+ import googlemaps
13
+ import folium
14
+ import torch
15
+ import pandas as pd
16
+ from sklearn.preprocessing import LabelEncoder
17
+ from sklearn.model_selection import train_test_split
18
  from sklearn.tree import DecisionTreeClassifier
19
  from sklearn.ensemble import RandomForestClassifier
20
  from sklearn.naive_bayes import GaussianNB
21
  from sklearn.metrics import accuracy_score
22
 
23
+ # Suppress TensorFlow warnings
24
+ os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
25
+ os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
26
+
27
+ # Download necessary NLTK resources
28
+ nltk.download("punkt")
29
+ stemmer = LancasterStemmer()
30
+
31
+ # Load intents and chatbot training data
32
+ with open("intents.json") as file:
33
+ intents_data = json.load(file)
34
+
35
+ with open("data.pickle", "rb") as f:
36
+ words, labels, training, output = pickle.load(f)
37
+
38
+ # Build the chatbot model
39
+ net = tflearn.input_data(shape=[None, len(training[0])])
40
+ net = tflearn.fully_connected(net, 8)
41
+ net = tflearn.fully_connected(net, 8)
42
+ net = tflearn.fully_connected(net, len(output[0]), activation="softmax")
43
+ net = tflearn.regression(net)
44
+ chatbot_model = tflearn.DNN(net)
45
+ chatbot_model.load("MentalHealthChatBotmodel.tflearn")
46
+
47
+ # Hugging Face sentiment and emotion models
48
+ tokenizer_sentiment = AutoTokenizer.from_pretrained("cardiffnlp/twitter-roberta-base-sentiment")
49
+ model_sentiment = AutoModelForSequenceClassification.from_pretrained("cardiffnlp/twitter-roberta-base-sentiment")
50
+ tokenizer_emotion = AutoTokenizer.from_pretrained("j-hartmann/emotion-english-distilroberta-base")
51
+ model_emotion = AutoModelForSequenceClassification.from_pretrained("j-hartmann/emotion-english-distilroberta-base")
52
 
53
+ # Google Maps API Client
54
+ gmaps = googlemaps.Client(key=os.getenv("GOOGLE_API_KEY"))
55
 
56
+ # Load the disease dataset
57
+ df_train = pd.read_csv("Training.csv") # Change the file path as necessary
58
+ df_test = pd.read_csv("Testing.csv") # Change the file path as necessary
59
+
60
+ # Encode diseases
61
+ disease_dict = {
62
+ 'Fungal infection': 0, 'Allergy': 1, 'GERD': 2, 'Chronic cholestasis': 3, 'Drug Reaction': 4,
63
+ 'Peptic ulcer disease': 5, 'AIDS': 6, 'Diabetes ': 7, 'Gastroenteritis': 8, 'Bronchial Asthma': 9,
64
+ 'Hypertension ': 10, 'Migraine': 11, 'Cervical spondylosis': 12, 'Paralysis (brain hemorrhage)': 13,
65
+ 'Jaundice': 14, 'Malaria': 15, 'Chicken pox': 16, 'Dengue': 17, 'Typhoid': 18, 'hepatitis A': 19,
66
+ 'Hepatitis B': 20, 'Hepatitis C': 21, 'Hepatitis D': 22, 'Hepatitis E': 23, 'Alcoholic hepatitis': 24,
67
+ 'Tuberculosis': 25, 'Common Cold': 26, 'Pneumonia': 27, 'Dimorphic hemorrhoids(piles)': 28,
68
+ 'Heart attack': 29, 'Varicose veins': 30, 'Hypothyroidism': 31, 'Hyperthyroidism': 32,
69
+ 'Hypoglycemia': 33, 'Osteoarthritis': 34, 'Arthritis': 35,
70
+ '(vertigo) Paroxysmal Positional Vertigo': 36, 'Acne': 37, 'Urinary tract infection': 38,
71
+ 'Psoriasis': 39, 'Impetigo': 40
72
+ }
73
+
74
+ # Function to prepare data
75
+ def prepare_data(df):
76
+ # Split the dataset into features and target
77
+ X = df.iloc[:, :-1] # All columns except the last one (features)
78
+ y = df.iloc[:, -1] # The last column (target)
79
+
80
+ # Encode the target variable
81
+ label_encoder = LabelEncoder()
82
+ y_encoded = label_encoder.fit_transform(y)
83
 
84
+ return X, y_encoded, label_encoder
85
+
86
+ # Preparing training and testing data
87
+ X_train, y_train, label_encoder_train = prepare_data(df_train)
88
+ X_test, y_test, label_encoder_test = prepare_data(df_test)
89
+
90
+ # Define the models
91
+ models = {
92
+ "Decision Tree": DecisionTreeClassifier(),
93
+ "Random Forest": RandomForestClassifier(),
94
+ "Naive Bayes": GaussianNB()
95
+ }
96
+
97
+ # Train and evaluate models
98
+ trained_models = {}
99
+ for model_name, model_obj in models.items():
100
+ model_obj.fit(X_train, y_train) # Fit the model
101
+ y_pred = model_obj.predict(X_test) # Make predictions
102
+ acc = accuracy_score(y_test, y_pred) # Calculate accuracy
103
+ trained_models[model_name] = {'model': model_obj, 'accuracy': acc}
104
+
105
+ # Helper Functions for Chatbot
106
+ def bag_of_words(s, words):
107
+ """Convert user input to bag-of-words vector."""
108
+ bag = [0] * len(words)
109
+ s_words = word_tokenize(s)
110
+ s_words = [stemmer.stem(word.lower()) for word in s_words if word.isalnum()]
111
+ for se in s_words:
112
+ for i, w in enumerate(words):
113
+ if w == se:
114
+ bag[i] = 1
115
+ return np.array(bag)
116
+
117
+ def generate_chatbot_response(message, history):
118
+ """Generate chatbot response and maintain conversation history."""
119
+ history = history or []
120
+ try:
121
+ result = chatbot_model.predict([bag_of_words(message, words)])
122
+ tag = labels[np.argmax(result)]
123
+ response = "I'm sorry, I didn't understand that. πŸ€”"
124
+ for intent in intents_data["intents"]:
125
+ if intent["tag"] == tag:
126
+ response = random.choice(intent["responses"])
127
+ break
128
+ except Exception as e:
129
+ response = f"Error: {e}"
130
+ history.append((message, response))
131
+ return history, response
132
+
133
+ def analyze_sentiment(user_input):
134
+ """Analyze sentiment and map to emojis."""
135
+ inputs = tokenizer_sentiment(user_input, return_tensors="pt")
136
+ with torch.no_grad():
137
+ outputs = model_sentiment(**inputs)
138
+ sentiment_class = torch.argmax(outputs.logits, dim=1).item()
139
+ sentiment_map = ["Negative πŸ˜”", "Neutral 😐", "Positive 😊"]
140
+ return f"Sentiment: {sentiment_map[sentiment_class]}"
141
+
142
+ def detect_emotion(user_input):
143
+ """Detect emotions based on input."""
144
+ pipe = pipeline("text-classification", model=model_emotion, tokenizer=tokenizer_emotion)
145
+ result = pipe(user_input)
146
+ emotion = result[0]["label"].lower().strip()
147
+ emotion_map = {
148
+ "joy": "Joy 😊",
149
+ "anger": "Anger 😠",
150
+ "sadness": "Sadness 😒",
151
+ "fear": "Fear 😨",
152
+ "surprise": "Surprise 😲",
153
+ "neutral": "Neutral 😐",
154
  }
155
+ return emotion_map.get(emotion, "Unknown πŸ€”"), emotion
156
+
157
+ def generate_suggestions(emotion):
158
+ """Return relevant suggestions based on detected emotions."""
159
+ emotion_key = emotion.lower()
160
+ suggestions = {
161
+ "joy": [
162
+ ("Mindfulness Practices", "https://www.helpguide.org/mental-health/meditation/mindful-breathing-meditation"),
163
+ ("Coping with Anxiety", "https://www.helpguide.org/mental-health/anxiety/tips-for-dealing-with-anxiety"),
164
+ ("Emotional Wellness Toolkit", "https://www.nih.gov/health-information/emotional-wellness-toolkit"),
165
+ ("Relaxation Video", "https://youtu.be/yGKKz185M5o"),
166
+ ],
167
+ "anger": [
168
+ ("Emotional Wellness Toolkit", "https://www.nih.gov/health-information/emotional-wellness-toolkit"),
169
+ ("Stress Management Tips", "https://www.health.harvard.edu/health-a-to-z"),
170
+ ("Dealing with Anger", "https://www.helpguide.org/mental-health/anxiety/tips-for-dealing-with-anxiety"),
171
+ ("Relaxation Video", "https://youtu.be/MIc299Flibs"),
172
+ ],
173
+ "fear": [
174
+ ("Mindfulness Practices", "https://www.helpguide.org/mental-health/meditation/mindful-breathing-meditation"),
175
+ ("Coping with Anxiety", "https://www.helpguide.org/mental-health/anxiety/tips-for-dealing-with-anxiety"),
176
+ ("Emotional Wellness Toolkit", "https://www.nih.gov/health-information/emotional-wellness-toolkit"),
177
+ ("Relaxation Video", "https://youtu.be/yGKKz185M5o"),
178
+ ],
179
+ "sadness": [
180
+ ("Emotional Wellness Toolkit", "https://www.nih.gov/health-information/emotional-wellness-toolkit"),
181
+ ("Dealing with Anxiety", "https://www.helpguide.org/mental-health/anxiety/tips-for-dealing-with-anxiety"),
182
+ ("Relaxation Video", "https://youtu.be/-e-4Kx5px_I"),
183
+ ],
184
+ "surprise": [
185
+ ("Managing Stress", "https://www.health.harvard.edu/health-a-to-z"),
186
+ ("Coping Strategies", "https://www.helpguide.org/mental-health/anxiety/tips-for-dealing-with-anxiety"),
187
+ ("Relaxation Video", "https://youtu.be/m1vaUGtyo-A"),
188
+ ],
189
+ }
190
+
191
+ # Create a markdown string for clickable suggestions
192
+ formatted_suggestions = [
193
+ f"- [{title}]({link})" for title, link in suggestions.get(emotion_key, [("No specific suggestions available.", "#")])
194
+ ]
195
+
196
+ return "\n".join(formatted_suggestions)
197
+
198
+ def get_health_professionals_and_map(location, query):
199
+ """Search nearby healthcare professionals using Google Maps API."""
200
+ try:
201
+ if not location or not query:
202
+ return [], "" # Return empty list if inputs are missing
203
+
204
+ geo_location = gmaps.geocode(location)
205
+ if geo_location:
206
+ lat, lng = geo_location[0]["geometry"]["location"].values()
207
+ places_result = gmaps.places_nearby(location=(lat, lng), radius=10000, keyword=query)["results"]
208
+ professionals = []
209
+ map_ = folium.Map(location=(lat, lng), zoom_start=13)
210
+ for place in places_result:
211
+ professionals.append([place['name'], place.get('vicinity', 'No address provided')])
212
+ folium.Marker(
213
+ location=[place["geometry"]["location"]["lat"], place["geometry"]["location"]["lng"]],
214
+ popup=f"{place['name']}"
215
+ ).add_to(map_)
216
+ return professionals, map_._repr_html_()
217
+
218
+ return [], "" # Return empty list if no professionals found
219
+ except Exception as e:
220
+ return [], "" # Return empty list on exception
221
 
222
+ # Main Application Logic for Chatbot
223
+ def app_function_chatbot(user_input, location, query, history):
224
+ chatbot_history, _ = generate_chatbot_response(user_input, history)
225
+ sentiment_result = analyze_sentiment(user_input)
226
+ emotion_result, cleaned_emotion = detect_emotion(user_input)
227
+ suggestions = generate_suggestions(cleaned_emotion)
228
+ professionals, map_html = get_health_professionals_and_map(location, query)
229
+ return chatbot_history, sentiment_result, emotion_result, suggestions, professionals, map_html
230
 
231
+ # Disease Prediction Logic
232
+ def predict_disease(symptoms):
233
+ """Predict disease based on input symptoms."""
234
+ input_test = np.zeros(len(X_train.columns)) # Create an array for feature input
235
  for symptom in symptoms:
236
+ if symptom in X_train.columns:
237
+ input_test[X_train.columns.get_loc(symptom)] = 1
238
+ predictions = {}
239
+ for model_name, info in trained_models.items():
240
+ prediction = info['model'].predict([input_test])[0]
241
+ predicted_disease = label_encoder_train.inverse_transform([prediction])[0]
242
+ predictions[model_name] = predicted_disease
243
+ return predictions
244
+
245
+ # Gradio Application Interface
246
+ with gr.Blocks() as app:
247
+ gr.HTML("<h1>🌟 Well-Being Companion</h1>")
 
 
 
 
 
 
 
 
 
248
 
249
+ with gr.Tab("Mental Health Chatbot"):
250
+ with gr.Row():
251
+ user_input = gr.Textbox(label="Please Enter Your Message Here")
252
+ location = gr.Textbox(label="Please Enter Your Current Location Here")
253
+ query = gr.Textbox(label="Please Enter Which Health Professional You Want To Search Nearby")
254
+
255
+ submit_chatbot = gr.Button(value="Submit Chatbot", variant="primary")
256
+
257
+ chatbot = gr.Chatbot(label="Chat History")
258
+ sentiment = gr.Textbox(label="Detected Sentiment")
259
+ emotion = gr.Textbox(label="Detected Emotion")
260
+
261
+ suggestions_markdown = gr.Markdown(label="Suggestions") # Use Markdown to display clickable links
262
+ professionals = gr.DataFrame(label="Nearby Health Professionals", headers=["Name", "Address"])
263
+ map_html = gr.HTML(label="Interactive Map")
264
+
265
+ submit_chatbot.click(
266
+ app_function_chatbot,
267
+ inputs=[user_input, location, query, chatbot],
268
+ outputs=[chatbot, sentiment, emotion, suggestions_markdown, professionals, map_html],
269
+ )
270
+
271
+ with gr.Tab("Disease Prediction"):
272
+ symptom1 = gr.Dropdown(X_train.columns.tolist(), label="Select Symptom 1")
273
+ symptom2 = gr.Dropdown(X_train.columns.tolist(), label="Select Symptom 2")
274
+ symptom3 = gr.Dropdown(X_train.columns.tolist(), label="Select Symptom 3")
275
+ symptom4 = gr.Dropdown(X_train.columns.tolist(), label="Select Symptom 4")
276
+ symptom5 = gr.Dropdown(X_train.columns.tolist(), label="Select Symptom 5")
277
+
278
+ submit_disease = gr.Button(value="Predict Disease", variant="primary")
279
+ disease_prediction_result = gr.Textbox(label="Predicted Diseases")
280
+
281
+ submit_disease.click(
282
+ lambda symptom1, symptom2, symptom3, symptom4, symptom5: predict_disease(
283
+ [symptom1, symptom2, symptom3, symptom4, symptom5]),
284
+ inputs=[symptom1, symptom2, symptom3, symptom4, symptom5],
285
+ outputs=disease_prediction_result,
286
+ )
287
 
288
  # Launch the Gradio application
289
+ app.launch()