Prajithr04 commited on
Commit
9889643
·
1 Parent(s): 7671fea
Files changed (4) hide show
  1. Dockerfile +13 -0
  2. app.py +201 -0
  3. chromadb/chroma.sqlite3 +0 -0
  4. requirements.txt +0 -0
Dockerfile ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ FROM python:3.12
2
+
3
+ RUN useradd -m -u 1000 user
4
+ USER user
5
+ ENV PATH="/home/user/.local/bin:$PATH"
6
+
7
+ WORKDIR /app
8
+
9
+ COPY --chown=user ./requirements.txt requirements.txt
10
+ RUN pip install --no-cache-dir --upgrade -r requirements.txt
11
+
12
+ COPY --chown=user . /app
13
+ CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "7860"]
app.py ADDED
@@ -0,0 +1,201 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI
2
+ from pydantic import BaseModel
3
+ import pandas as pd
4
+ from sentence_transformers import SentenceTransformer
5
+ import chromadb
6
+ from fastapi.middleware.cors import CORSMiddleware
7
+ import uvicorn
8
+ import requests
9
+ # Define FastAPI app
10
+ app = FastAPI()
11
+
12
+ origins = [
13
+ "http://localhost:5173",
14
+ "localhost:5173"
15
+ ]
16
+
17
+
18
+ app.add_middleware(
19
+ CORSMiddleware,
20
+ allow_origins=origins,
21
+ allow_credentials=True,
22
+ allow_methods=["*"],
23
+ allow_headers=["*"]
24
+ )
25
+
26
+ # Load the dataset and model at startup
27
+ df = pd.read_csv("hf://datasets/QuyenAnhDE/Diseases_Symptoms/Diseases_Symptoms.csv")
28
+ df['Symptoms'] = df['Symptoms'].str.split(',')
29
+ df['Symptoms'] = df['Symptoms'].apply(lambda x: [s.strip() for s in x])
30
+
31
+ model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
32
+ client = chromadb.PersistentClient(path='./chromadb')
33
+ collection = client.get_or_create_collection(name="symptomsvector")
34
+
35
+ class SymptomQuery(BaseModel):
36
+ symptom: str
37
+
38
+ # Endpoint to handle symptom query and return matching symptoms
39
+ @app.post("/find_matching_symptoms")
40
+ def find_matching_symptoms(query: SymptomQuery):
41
+ # Generate embedding for the symptom query
42
+ symptoms = query.symptom.split(',')
43
+ all_results = []
44
+
45
+ for symptom in symptoms:
46
+ symptom = symptom.strip()
47
+ query_embedding = model.encode([symptom])
48
+
49
+ # Perform similarity search in ChromaDB
50
+ results = collection.query(
51
+ query_embeddings=query_embedding.tolist(),
52
+ n_results=3 # Return top 3 similar symptoms for each symptom
53
+ )
54
+ all_results.extend(results['documents'][0])
55
+
56
+ # Remove duplicates while preserving order
57
+ matching_symptoms = list(dict.fromkeys(all_results))
58
+
59
+ return {"matching_symptoms": matching_symptoms}
60
+
61
+ # Endpoint to handle symptom query and return matching diseases
62
+ @app.post("/find_matching_diseases")
63
+ def find_matching_diseases(query: SymptomQuery):
64
+ # Generate embedding for the symptom query
65
+ query_embedding = model.encode([query.symptom])
66
+
67
+ # Perform similarity search in ChromaDB
68
+ results = collection.query(
69
+ query_embeddings=query_embedding.tolist(),
70
+ n_results=5 # Return top 5 similar symptoms
71
+ )
72
+
73
+ # Extract matching symptoms
74
+ matching_symptoms = results['documents'][0]
75
+
76
+ # Filter diseases that match the symptoms
77
+ matching_diseases = df[df['Symptoms'].apply(lambda x: any(s in matching_symptoms for s in x))]
78
+
79
+ return {"matching_diseases": matching_diseases['Name'].tolist()}
80
+
81
+ # Endpoint to handle symptom query and return detailed disease list
82
+ @app.post("/find_disease_list")
83
+ def find_disease_list(query: SymptomQuery):
84
+ # Generate embedding for the symptom query
85
+ query_embedding = model.encode([query.symptom])
86
+
87
+ # Perform similarity search in ChromaDB
88
+ results = collection.query(
89
+ query_embeddings=query_embedding.tolist(),
90
+ n_results=5 # Return top 5 similar symptoms
91
+ )
92
+
93
+ # Extract matching symptoms
94
+ matching_symptoms = results['documents'][0]
95
+
96
+ # Filter diseases that match the symptoms
97
+ matching_diseases = df[df['Symptoms'].apply(lambda x: any(s in matching_symptoms for s in x))]
98
+
99
+ # Create a list of disease information
100
+ disease_list = []
101
+ symptoms_list = []
102
+ unique_symptoms_list = []
103
+ for _, row in matching_diseases.iterrows():
104
+ disease_info = {
105
+ 'Disease': row['Name'],
106
+ 'Symptoms': row['Symptoms'],
107
+ 'Treatments': row['Treatments']
108
+ }
109
+ disease_list.append(disease_info)
110
+ symptoms_info = row['Symptoms']
111
+ symptoms_list.append(symptoms_info)
112
+ for i in range(len(symptoms_list)):
113
+ for j in range(len(symptoms_list[i])):
114
+ if symptoms_list[i][j] not in unique_symptoms_list:
115
+ unique_symptoms_list.append(symptoms_list[i][j])
116
+ return {"disease_list": disease_list, "unique_symptoms_list": unique_symptoms_list}
117
+
118
+ class SelectedSymptomsQuery(BaseModel):
119
+ selected_symptoms: list
120
+
121
+ @app.post("/find_disease")
122
+ def find_disease(query: SelectedSymptomsQuery):
123
+ selected_symptoms = query.selected_symptoms
124
+ # Filter diseases that match at least one of the selected symptoms
125
+ matching_diseases = df[df['Symptoms'].apply(lambda x: any(s in x for s in selected_symptoms))]
126
+
127
+ # Sort diseases by the number of matching symptoms in descending order
128
+ matching_diseases['match_count'] = matching_diseases['Symptoms'].apply(lambda x: sum(s in selected_symptoms for s in x))
129
+ matching_diseases = matching_diseases.sort_values(by='match_count', ascending=False)
130
+
131
+ # Create a list of disease information
132
+ disease_list = []
133
+ max_match_count_disease = None
134
+ max_match_count = -1
135
+
136
+ for _, row in matching_diseases.iterrows():
137
+ disease_info = {
138
+ 'Disease': row['Name'],
139
+ 'Symptoms': row['Symptoms'],
140
+ 'Treatments': row['Treatments'],
141
+ 'MatchCount': row['match_count']
142
+ }
143
+ disease_list.append(disease_info)
144
+
145
+ # Check if this disease has the maximum match count
146
+ if row['match_count'] > max_match_count:
147
+ max_match_count = row['match_count']
148
+ max_match_count_disease = disease_info
149
+
150
+ return {"disease_list": disease_list, "max_match_count_disease": max_match_count_disease}
151
+ class DiseaseListQuery(BaseModel):
152
+ disease_list: list
153
+
154
+ class DiseaseDetail(BaseModel):
155
+ Disease: str
156
+ Symptoms: list
157
+ Treatments: str
158
+ MatchCount: int
159
+
160
+ @app.post("/pass2llm")
161
+ def pass2llm(query: DiseaseDetail):
162
+ # Prepare the data to be sent to the LLM API
163
+ disease_list_details = query
164
+
165
+ # Make the API request to the Ngrok endpoint to get the public URL
166
+ headers = {
167
+ "Authorization": "Bearer 2npJaJjnLBj1RGPcGf0QiyAAJHJ_5qqtw2divkpoAipqN9WLG",
168
+ "Ngrok-Version": "2"
169
+ }
170
+ response = requests.get("https://api.ngrok.com/endpoints", headers=headers)
171
+
172
+ # Check if the request was successful
173
+ if response.status_code == 200:
174
+ llm_api_response = response.json()
175
+ public_url = llm_api_response['endpoints'][0]['public_url']
176
+
177
+ # Prepare the prompt with the disease list details
178
+ prompt = f"Here is a list of diseases and their details: {disease_list_details}. Please generate a summary."
179
+
180
+ # Make the request to the LLM API
181
+ llm_headers = {
182
+ "Content-Type": "application/json"
183
+ }
184
+ llm_payload = {
185
+ "model": "llama3",
186
+ "prompt": prompt,
187
+ "stream": False
188
+ }
189
+ llm_response = requests.post(f"{public_url}/api/generate", headers=llm_headers, json=llm_payload)
190
+
191
+ # Check if the request to the LLM API was successful
192
+ if llm_response.status_code == 200:
193
+ llm_response_json = llm_response.json()
194
+ return {"message": "Successfully passed to LLM!", "llm_response": llm_response_json.get("response")}
195
+ else:
196
+ return {"message": "Failed to get response from LLM!", "error": llm_response.text}
197
+ else:
198
+ return {"message": "Failed to get public URL from Ngrok!", "error": response.text}
199
+ # To run the FastAPI app with Uvicorn
200
+ # if __name__ == "__main__":
201
+ # uvicorn.run(app, host="0.0.0.0", port=8000)
chromadb/chroma.sqlite3 ADDED
Binary file (168 kB). View file
 
requirements.txt ADDED
Binary file (5.88 kB). View file