Prajith04 committed on
Commit
add1f34
·
verified ·
1 Parent(s): 2aabcef

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +161 -121
app.py CHANGED
@@ -1,12 +1,15 @@
1
  from fastapi import FastAPI
2
  from pydantic import BaseModel
3
- import pandas as pd
4
  from sentence_transformers import SentenceTransformer
5
  import chromadb
6
  from fastapi.middleware.cors import CORSMiddleware
7
  import uvicorn
8
  import requests
9
  from itertools import combinations
 
 
 
 
10
  # Define FastAPI app
11
  app = FastAPI()
12
 
@@ -15,7 +18,6 @@ origins = [
15
  "localhost:5173"
16
  ]
17
 
18
-
19
  app.add_middleware(
20
  CORSMiddleware,
21
  allow_origins=origins,
@@ -24,22 +26,72 @@ app.add_middleware(
24
  allow_headers=["*"]
25
  )
26
 
27
- # Load the dataset and model at startup
28
- df = pd.read_csv("hf://datasets/QuyenAnhDE/Diseases_Symptoms/Diseases_Symptoms.csv")
29
- df['Symptoms'] = df['Symptoms'].str.split(',')
30
- df['Symptoms'] = df['Symptoms'].apply(lambda x: [s.strip() for s in x])
31
-
32
  model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
33
  client = chromadb.PersistentClient(path='./chromadb')
34
  collection = client.get_or_create_collection(name="symptomsvector")
35
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
36
  class SymptomQuery(BaseModel):
37
  symptom: str
38
 
39
- # Endpoint to handle symptom query and return matching symptoms
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
40
  @app.post("/find_matching_symptoms")
41
  def find_matching_symptoms(query: SymptomQuery):
42
- # Generate embedding for the symptom query
43
  symptoms = query.symptom.split(',')
44
  all_results = []
45
 
@@ -50,122 +102,126 @@ def find_matching_symptoms(query: SymptomQuery):
50
  # Perform similarity search in ChromaDB
51
  results = collection.query(
52
  query_embeddings=query_embedding.tolist(),
53
- n_results=3 # Return top 3 similar symptoms for each symptom
54
  )
55
  all_results.extend(results['documents'][0])
56
 
57
- # Remove duplicates while preserving order
58
  matching_symptoms = list(dict.fromkeys(all_results))
59
-
60
  return {"matching_symptoms": matching_symptoms}
61
 
62
- # Endpoint to handle symptom query and return matching diseases
63
- @app.post("/find_matching_diseases")
64
- def find_matching_diseases(query: SymptomQuery):
65
- # Generate embedding for the symptom query
66
- query_embedding = model.encode([query.symptom])
 
67
 
68
- # Perform similarity search in ChromaDB
69
- results = collection.query(
70
- query_embeddings=query_embedding.tolist(),
71
- n_results=5 # Return top 5 similar symptoms
72
- )
73
 
74
- # Extract matching symptoms
75
- matching_symptoms = results['documents'][0]
 
 
 
 
 
76
 
77
- # Filter diseases that match the symptoms
78
- matching_diseases = df[df['Symptoms'].apply(lambda x: any(s in matching_symptoms for s in x))]
 
 
 
79
 
80
- return {"matching_diseases": matching_diseases['Name'].tolist()}
81
- all_symptoms=[]
82
- all_selected_symptoms=[]
83
- # Endpoint to handle symptom query and return detailed disease list
84
- @app.post("/find_disease_list")
85
- def find_disease_list(query: SymptomQuery):
86
- # Generate embedding for the symptom query
87
- query_embedding = model.encode([query.symptom])
88
-
89
- # Perform similarity search in ChromaDB
90
- results = collection.query(
91
- query_embeddings=query_embedding.tolist(),
92
- n_results=5 # Return top 5 similar symptoms
93
- )
94
-
95
- # Extract matching symptoms
96
- matching_symptoms = results['documents'][0]
97
- all_symptoms.append(matching_symptoms)
98
- # Filter diseases that match the symptoms
99
- matching_diseases = df[df['Symptoms'].apply(lambda x: any(s in matching_symptoms for s in x))]
100
-
101
- # Create a list of disease information
102
  disease_list = []
103
- symptoms_list = []
104
- unique_symptoms_list = []
105
- for _, row in matching_diseases.iterrows():
106
- disease_info = {
107
- 'Disease': row['Name'],
108
- 'Symptoms': row['Symptoms'],
109
- 'Treatments': row['Treatments']
110
- }
111
- disease_list.append(disease_info)
112
- symptoms_info = row['Symptoms']
113
- symptoms_list.append(symptoms_info)
114
- for i in range(len(symptoms_list)):
115
- for j in range(len(symptoms_list[i])):
116
- if symptoms_list[i][j] not in unique_symptoms_list:
117
- unique_symptoms_list.append(symptoms_list[i][j].lower())
118
- return {"disease_list": disease_list, "unique_symptoms_list": unique_symptoms_list}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
119
 
120
  class SelectedSymptomsQuery(BaseModel):
121
  selected_symptoms: list
122
 
 
 
 
123
  @app.post("/find_disease")
124
  def find_disease(query: SelectedSymptomsQuery):
125
- SelectedSymptoms = query.selected_symptoms
126
- all_selected_symptoms.extend(SelectedSymptoms)
 
 
 
 
127
 
128
  disease_list = []
129
- symptoms_list = []
130
- unique_symptoms_set = set() # Use a set for unique symptoms
131
-
132
- # Combine all the symptoms we already know (all_symptoms + selected symptoms)
133
- known_symptoms = {symptom.lower() for symptom_set in all_symptoms for symptom in symptom_set}
134
- known_symptoms.update([symptom.lower() for symptom in SelectedSymptoms])
135
-
136
- # Generate combinations of symptoms from all_symptoms and selected symptoms
137
- for symptoms_set in all_symptoms:
138
- for i in range(1, len(symptoms_set) + 1): # Generate combinations with all lengths
139
- for symptom_combination in combinations(symptoms_set, i):
140
- temp = list(symptom_combination) + SelectedSymptoms # Combine with selected symptoms
141
-
142
- # Search for diseases that match the combination
143
- matching_diseases = df[df['Symptoms'].apply(lambda x: all(s in x for s in temp))]
144
-
145
- for _, row in matching_diseases.iterrows():
146
- disease_info = {
147
- 'Disease': row['Name'],
148
- 'Symptoms': row['Symptoms'],
149
- 'Treatments': row['Treatments']
150
- }
151
- disease_list.append(disease_info)
152
-
153
- # Add each symptom in lowercase to the set
154
- for symptom in row['Symptoms']:
155
- if symptom.lower() not in known_symptoms:
156
- unique_symptoms_set.add(symptom.lower())
157
-
158
- # Convert the set back to a list for the response
159
- unique_symptoms_list = list(unique_symptoms_set)
160
 
161
  return {
162
  "unique_symptoms_list": unique_symptoms_list,
163
- "all_selected_symptoms": all_selected_symptoms,
164
- "all_symptoms": all_symptoms,
165
  "disease_list": disease_list
166
  }
167
- class DiseaseListQuery(BaseModel):
168
- disease_list: list
169
 
170
  class DiseaseDetail(BaseModel):
171
  Disease: str
@@ -175,36 +231,21 @@ class DiseaseDetail(BaseModel):
175
 
176
  @app.post("/pass2llm")
177
  def pass2llm(query: DiseaseDetail):
178
- # Prepare the data to be sent to the LLM API
179
- disease_list_details = query
180
-
181
- # Make the API request to the Ngrok endpoint to get the public URL
182
  headers = {
183
  "Authorization": "Bearer 2npJaJjnLBj1RGPcGf0QiyAAJHJ_5qqtw2divkpoAipqN9WLG",
184
  "Ngrok-Version": "2"
185
  }
186
  response = requests.get("https://api.ngrok.com/endpoints", headers=headers)
187
 
188
- # Check if the request was successful
189
  if response.status_code == 200:
190
  llm_api_response = response.json()
191
  public_url = llm_api_response['endpoints'][0]['public_url']
 
192
 
193
- # Prepare the prompt with the disease list details
194
- prompt = f"Here is a list of diseases and their details: {disease_list_details}. Please generate a summary."
195
-
196
- # Make the request to the LLM API
197
- llm_headers = {
198
- "Content-Type": "application/json"
199
- }
200
- llm_payload = {
201
- "model": "llama3",
202
- "prompt": prompt,
203
- "stream": False
204
- }
205
  llm_response = requests.post(f"{public_url}/api/generate", headers=llm_headers, json=llm_payload)
206
 
207
- # Check if the request to the LLM API was successful
208
  if llm_response.status_code == 200:
209
  llm_response_json = llm_response.json()
210
  return {"message": "Successfully passed to LLM!", "llm_response": llm_response_json.get("response")}
@@ -212,8 +253,7 @@ def pass2llm(query: DiseaseDetail):
212
  return {"message": "Failed to get response from LLM!", "error": llm_response.text}
213
  else:
214
  return {"message": "Failed to get public URL from Ngrok!", "error": response.text}
 
215
  # To run the FastAPI app with Uvicorn
216
  # if __name__ == "__main__":
217
  # uvicorn.run(app, host="0.0.0.0", port=8000)
218
-
219
-
 
1
  from fastapi import FastAPI
2
  from pydantic import BaseModel
 
3
  from sentence_transformers import SentenceTransformer
4
  import chromadb
5
  from fastapi.middleware.cors import CORSMiddleware
6
  import uvicorn
7
  import requests
8
  from itertools import combinations
9
+ import sqlite3
10
+ import pandas as pd
11
+ import os
12
+
13
  # Define FastAPI app
14
  app = FastAPI()
15
 
 
18
  "localhost:5173"
19
  ]
20
 
 
21
  app.add_middleware(
22
  CORSMiddleware,
23
  allow_origins=origins,
 
26
  allow_headers=["*"]
27
  )
28
 
29
+ # Load the model at startup
 
 
 
 
30
  model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
31
  client = chromadb.PersistentClient(path='./chromadb')
32
  collection = client.get_or_create_collection(name="symptomsvector")
33
 
34
def init_db(db_path="diseases_symptoms.db"):
    """Open the SQLite database and make sure the ``diseases`` table exists.

    Only the schema is created here; inserting rows is left to the caller.

    Args:
        db_path: Path of the SQLite file. The default is the file name used
            by the rest of this module, so ``init_db()`` behaves as before.

    Returns:
        An open ``sqlite3.Connection``. The caller must close it.

    Raises:
        sqlite3.Error: If the table cannot be created. The connection is
            closed before the error propagates.
    """
    conn = sqlite3.connect(db_path)
    try:
        # ``with conn`` commits on success and rolls back on failure.
        with conn:
            conn.execute('''
                CREATE TABLE IF NOT EXISTS diseases (
                    id INTEGER PRIMARY KEY,
                    name TEXT,
                    symptoms TEXT,
                    treatments TEXT
                )
            ''')
    except sqlite3.Error:
        # Do not leak the handle when schema creation fails.
        conn.close()
        raise
    return conn
48
+
49
# Populate the database from the CSV the first time the app starts.
# The existence of the file is the "already populated" marker, so the file
# must never be left behind half-filled: the CSV is downloaded and parsed
# before the database is created, all rows go in as one transaction, and the
# file is removed again if the insert fails.
if not os.path.exists("diseases_symptoms.db"):
    df = pd.read_csv("hf://datasets/QuyenAnhDE/Diseases_Symptoms/Diseases_Symptoms.csv")
    df['Symptoms'] = df['Symptoms'].str.split(',').apply(lambda x: [s.strip() for s in x])

    # Symptoms are stored as one comma-separated string per disease.
    rows = [
        (row['Name'], ",".join(row['Symptoms']), row.get('Treatments', ''))
        for _, row in df.iterrows()
    ]

    conn = init_db()
    try:
        # ``with conn`` commits on success and rolls back on failure.
        with conn:
            conn.executemany(
                "INSERT INTO diseases (name, symptoms, treatments) VALUES (?, ?, ?)",
                rows,
            )
    except Exception:
        # Remove the partial file so the next start retries the import.
        conn.close()
        os.remove("diseases_symptoms.db")
        raise
    conn.close()
62
+
63
  class SymptomQuery(BaseModel):
64
  symptom: str
65
 
66
def fetch_diseases_by_symptoms(matching_symptoms, db_path="diseases_symptoms.db"):
    """Return the diseases whose stored symptoms mention any given symptom.

    Each symptom is matched on its own with a SQL ``LIKE '%symptom%'``
    substring test against the comma-separated ``symptoms`` column, and the
    tests are OR-ed together. ``%``, ``_`` and ``\\`` in a symptom are
    escaped so they are matched literally.

    Args:
        matching_symptoms: Iterable of symptom strings. Blank entries are
            ignored.
        db_path: Path of the SQLite file; defaults to the application
            database.

    Returns:
        A ``(disease_list, unique_symptoms_list)`` tuple. ``disease_list``
        holds one ``{'Disease', 'Symptoms', 'Treatments'}`` dict per matching
        row. ``unique_symptoms_list`` holds the lower-cased symptoms of those
        rows, de-duplicated in first-seen order. Both are empty when no
        usable symptom was given.
    """
    # A blank entry would become the pattern '%%', which matches every row.
    wanted = [s.strip() for s in matching_symptoms if s and s.strip()]
    if not wanted:
        return [], []

    def like_pattern(text):
        # Escape LIKE metacharacters so the text is matched literally.
        escaped = text.replace('\\', '\\\\').replace('%', '\\%').replace('_', '\\_')
        return f'%{escaped}%'

    # Only placeholders are interpolated into the SQL text; values are bound.
    where = " OR ".join(["symptoms LIKE ? ESCAPE '\\'"] * len(wanted))
    sql = f"SELECT name, symptoms, treatments FROM diseases WHERE {where} ORDER BY id"
    params = [like_pattern(s) for s in wanted]

    disease_list = []
    unique_symptoms = {}  # dict keeps insertion order with O(1) membership
    conn = sqlite3.connect(db_path)
    try:
        for name, symptoms, treatments in conn.execute(sql, params):
            parts = symptoms.split(',')
            disease_list.append({
                'Disease': name,
                'Symptoms': parts,
                'Treatments': treatments,
            })
            # Lower-case so the same symptom in different case is kept once.
            for symptom in parts:
                unique_symptoms.setdefault(symptom.strip().lower(), None)
    finally:
        conn.close()

    return disease_list, list(unique_symptoms)
92
+
93
  @app.post("/find_matching_symptoms")
94
  def find_matching_symptoms(query: SymptomQuery):
 
95
  symptoms = query.symptom.split(',')
96
  all_results = []
97
 
 
102
  # Perform similarity search in ChromaDB
103
  results = collection.query(
104
  query_embeddings=query_embedding.tolist(),
105
+ n_results=3
106
  )
107
  all_results.extend(results['documents'][0])
108
 
 
109
  matching_symptoms = list(dict.fromkeys(all_results))
 
110
  return {"matching_symptoms": matching_symptoms}
111
 
112
@app.post("/find_disease_list")
def find_disease_list(query: SymptomQuery):
    """List diseases that share a symptom with the semantic matches of the input.

    ``query.symptom`` is a comma-separated string. Each non-blank symptom is
    embedded and used to look up its 5 nearest documents in the ChromaDB
    collection. Every disease in SQLite that has at least one of those
    documents among its symptoms is returned.

    Side effect: the normalized input symptoms are added to the module-level
    ``all_selected_symptoms`` set.

    Returns:
        A dict with ``disease_list`` (``Disease`` / ``Symptoms`` /
        ``Treatments`` entries, symptoms lower-cased) and
        ``unique_symptoms_list`` (sorted symptoms of the returned diseases
        that the caller did not type in).
    """
    # Drop blank tokens (e.g. from "a,,b" or a trailing comma): they carry no
    # meaning for the search and must not end up in the shared symptom set.
    selected_symptoms = [
        symptom.strip().lower()
        for symptom in query.symptom.split(',')
        if symptom.strip()
    ]
    all_selected_symptoms.update(selected_symptoms)

    all_results = []
    for symptom in selected_symptoms:
        query_embedding = model.encode([symptom])
        results = collection.query(
            query_embeddings=query_embedding.tolist(),
            n_results=5  # Top 5 similar symptoms for each input symptom
        )
        all_results.extend(results['documents'][0])

    # Normalize the retrieved documents exactly like the database symptoms
    # below; otherwise a difference in case or padding hides real matches.
    matching_symptoms = {document.strip().lower() for document in all_results}
    selected = set(selected_symptoms)

    disease_list = []
    unique_symptoms_set = set()

    conn = sqlite3.connect("diseases_symptoms.db")
    try:
        for name, symptoms, treatments in conn.execute(
                "SELECT name, symptoms, treatments FROM diseases"):
            disease_symptoms = [symptom.strip().lower() for symptom in symptoms.split(',')]

            # Keep the disease only if it shares at least one symptom.
            if matching_symptoms.isdisjoint(disease_symptoms):
                continue

            disease_list.append({
                'Disease': name,
                'Symptoms': disease_symptoms,
                'Treatments': treatments
            })

            # Offer the symptoms the user has not entered in this request.
            unique_symptoms_set.update(
                symptom for symptom in disease_symptoms if symptom not in selected
            )
    finally:
        # Release the handle even if the query or row handling fails.
        conn.close()

    return {
        "disease_list": disease_list,
        # Sorted for consistent output.
        "unique_symptoms_list": sorted(unique_symptoms_set)
    }
171
+
172
+
173
 
174
  class SelectedSymptomsQuery(BaseModel):
175
  selected_symptoms: list
176
 
177
# Module-level set of every symptom selected so far. It is shared by all
# requests; nothing in this module clears it.
all_selected_symptoms = set()  # Use a set to avoid duplicates


@app.post("/find_disease")
def find_disease(query: SelectedSymptomsQuery):
    """Return the diseases that have every symptom selected so far.

    The symptoms in ``query.selected_symptoms`` are normalized (stripped,
    lower-cased) and merged into the module-level ``all_selected_symptoms``
    set. A disease is returned only when all symptoms in that set are among
    its own symptoms.

    Returns:
        A dict with ``unique_symptoms_list`` (sorted symptoms of the returned
        diseases that are not selected yet), ``all_selected_symptoms`` (the
        accumulated selection) and ``disease_list`` (``Disease`` /
        ``Symptoms`` / ``Treatments`` entries, symptoms lower-cased).
    """
    # Skip blank and non-string entries: a value such as "" can never be a
    # disease symptom, so once stored in the shared set it would make every
    # later request return no disease at all.
    new_symptoms = [
        symptom.strip().lower()
        for symptom in query.selected_symptoms
        if isinstance(symptom, str) and symptom.strip()
    ]
    all_selected_symptoms.update(new_symptoms)

    # Work on a local copy for the rest of the request.
    selected = frozenset(all_selected_symptoms)

    disease_list = []
    unique_symptoms_set = set()

    conn = sqlite3.connect("diseases_symptoms.db")
    try:
        for name, symptoms, treatments in conn.execute(
                "SELECT name, symptoms, treatments FROM diseases"):
            disease_symptoms = [symptom.strip().lower() for symptom in symptoms.split(',')]

            # Full match: every selected symptom must belong to the disease.
            if not selected.issubset(disease_symptoms):
                continue

            disease_list.append({
                'Disease': name,
                'Symptoms': disease_symptoms,
                'Treatments': treatments
            })

            # Offer the symptoms that have not been selected yet.
            unique_symptoms_set.update(
                symptom for symptom in disease_symptoms if symptom not in selected
            )
    finally:
        # Release the handle even if the query or row handling fails.
        conn.close()

    return {
        # Sorted for consistent output.
        "unique_symptoms_list": sorted(unique_symptoms_set),
        "all_selected_symptoms": sorted(selected),  # Set converted to a list for the JSON response
        "disease_list": disease_list
    }
 
 
225
 
226
  class DiseaseDetail(BaseModel):
227
  Disease: str
 
231
 
232
  @app.post("/pass2llm")
233
  def pass2llm(query: DiseaseDetail):
 
 
 
 
234
  headers = {
235
  "Authorization": "Bearer 2npJaJjnLBj1RGPcGf0QiyAAJHJ_5qqtw2divkpoAipqN9WLG",
236
  "Ngrok-Version": "2"
237
  }
238
  response = requests.get("https://api.ngrok.com/endpoints", headers=headers)
239
 
 
240
  if response.status_code == 200:
241
  llm_api_response = response.json()
242
  public_url = llm_api_response['endpoints'][0]['public_url']
243
+ prompt = f"Here is a list of diseases and their details: {query}. Please generate a summary."
244
 
245
+ llm_headers = {"Content-Type": "application/json"}
246
+ llm_payload = {"model": "llama3", "prompt": prompt, "stream": False}
 
 
 
 
 
 
 
 
 
 
247
  llm_response = requests.post(f"{public_url}/api/generate", headers=llm_headers, json=llm_payload)
248
 
 
249
  if llm_response.status_code == 200:
250
  llm_response_json = llm_response.json()
251
  return {"message": "Successfully passed to LLM!", "llm_response": llm_response_json.get("response")}
 
253
  return {"message": "Failed to get response from LLM!", "error": llm_response.text}
254
  else:
255
  return {"message": "Failed to get public URL from Ngrok!", "error": response.text}
256
+
257
  # To run the FastAPI app with Uvicorn
258
  # if __name__ == "__main__":
259
  # uvicorn.run(app, host="0.0.0.0", port=8000)