|
import os
import sqlite3
import time
from contextlib import closing
from itertools import combinations

import chromadb
import pandas as pd
import requests
import uvicorn
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from sentence_transformers import SentenceTransformer
|
|
|
|
|
# Application instance; every route in this module is registered on it.
app = FastAPI()

# Origins accepted by the CORS policy configured below.
origins = ["http://localhost:5173", "localhost:5173"]

# Allow the listed origins to call the API with credentials, using any
# HTTP method and any request header.
app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Sentence-embedding model used to turn symptom text into query vectors.
model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')

# Chroma client persisted under ./chromadb; the collection is fetched, or
# created if it does not exist yet.
client = chromadb.PersistentClient(path='./chromadb')
collection = client.get_or_create_collection(name="symptomsvector")
|
|
|
|
|
def init_db():
    """Open ``diseases_symptoms.db`` and ensure the ``diseases`` table exists.

    Returns:
        The open ``sqlite3.Connection``. The caller owns it and must close it.

    Raises:
        sqlite3.Error: If the schema statement fails. The connection is
            closed before the error propagates, so nothing is leaked.
    """
    schema = '''
        CREATE TABLE IF NOT EXISTS diseases (
            id INTEGER PRIMARY KEY,
            name TEXT,
            symptoms TEXT,
            treatments TEXT
        )
    '''
    conn = sqlite3.connect("diseases_symptoms.db")
    try:
        conn.execute(schema)
        conn.commit()
    except sqlite3.Error:
        # No caller will ever receive this handle, so release it here.
        conn.close()
        raise
    return conn
|
|
|
|
|
# One-time seeding: when the SQLite file is missing, download the dataset
# and load it into the ``diseases`` table.
if not os.path.exists("diseases_symptoms.db"):
    # Fetch and normalise the data BEFORE creating the database file. The
    # guard above only checks that the file exists, so creating it first
    # and then failing to download would leave an empty database that is
    # never reseeded.
    df = pd.read_csv("hf://datasets/QuyenAnhDE/Diseases_Symptoms/Diseases_Symptoms.csv")
    df['Symptoms'] = df['Symptoms'].str.split(',').apply(lambda x: [s.strip() for s in x])

    # Symptoms are stored as one comma-joined string per disease.
    rows = [
        (row['Name'], ",".join(row['Symptoms']), row.get('Treatments', ''))
        for _, row in df.iterrows()
    ]

    conn = init_db()
    try:
        # A single transaction: committed on success, rolled back on error.
        with conn:
            conn.executemany(
                "INSERT INTO diseases (name, symptoms, treatments) VALUES (?, ?, ?)",
                rows,
            )
    except Exception:
        # Remove the half-initialised file so the next start retries seeding.
        conn.close()
        os.remove("diseases_symptoms.db")
        raise
    conn.close()
|
|
|
# Request body shared by /find_matching_symptoms and /find_disease_list.
class SymptomQuery(BaseModel):
    # Symptom text; both endpoints split this string on commas.
    symptom: str
|
|
|
|
|
def fetch_diseases_by_symptoms(matching_symptoms):
    """Return diseases whose stored symptom text contains the given symptoms.

    The symptoms are joined with commas into ONE substring and matched with a
    single SQL ``LIKE``. With several symptoms, a row therefore only matches
    when it lists them consecutively and in the same order.

    Args:
        matching_symptoms: Iterable of symptom strings.

    Returns:
        A ``(disease_list, unique_symptoms_list)`` tuple. ``disease_list``
        holds one dict per matching row with keys ``'Disease'``,
        ``'Symptoms'`` (the stored string split on commas) and
        ``'Treatments'``. ``unique_symptoms_list`` holds the stripped,
        lower-cased symptoms of all matching rows, de-duplicated in
        first-seen order.
    """
    pattern = f"%{','.join(matching_symptoms)}%"
    disease_list = []
    # A dict keeps first-seen order and gives O(1) duplicate checks.
    seen_symptoms = {}

    # closing() guarantees the connection is released even if a row fails.
    with closing(sqlite3.connect("diseases_symptoms.db")) as conn:
        rows = conn.execute(
            "SELECT name, symptoms, treatments FROM diseases WHERE symptoms LIKE ?",
            (pattern,),
        )
        for name, symptoms, treatments in rows:
            parts = symptoms.split(',')
            disease_list.append({
                'Disease': name,
                'Symptoms': parts,
                'Treatments': treatments,
            })
            for part in parts:
                seen_symptoms.setdefault(part.strip().lower(), None)

    return disease_list, list(seen_symptoms)
|
|
|
@app.post("/find_matching_symptoms")
def find_matching_symptoms(query: SymptomQuery):
    """Suggest stored symptoms that are semantically close to the user's text.

    ``query.symptom`` is split on commas. Each non-blank fragment is embedded
    and its 3 nearest documents are fetched from the vector collection.

    Returns:
        ``{"matching_symptoms": [...]}``: the retrieved documents,
        de-duplicated, in query order. The list is empty when the input
        contains no non-blank fragment.
    """
    symptoms = [part.strip() for part in query.symptom.split(',')]
    # Blank fragments ("", "a,,b", trailing commas) carry no meaning; querying
    # them would only return arbitrary neighbours.
    symptoms = [symptom for symptom in symptoms if symptom]
    if not symptoms:
        return {"matching_symptoms": []}

    # One batched encode and one batched query instead of a round trip per
    # symptom; the result holds one document list per query, in input order.
    query_embeddings = model.encode(symptoms)
    results = collection.query(
        query_embeddings=query_embeddings.tolist(),
        n_results=3,
    )
    all_results = [doc for docs in results['documents'] for doc in docs]

    # dict.fromkeys drops duplicates while keeping first-seen order.
    matching_symptoms = list(dict.fromkeys(all_results))
    return {"matching_symptoms": matching_symptoms}
|
|
|
@app.post("/find_disease_list")
def find_disease_list(query: SymptomQuery):
    """List diseases that share a symptom with the user's text.

    ``query.symptom`` is split on commas and normalised (stripped and
    lower-cased). The symptoms are recorded in the module-level
    ``all_selected_symptoms`` set. Each one is embedded and its 5 nearest
    documents are fetched from the vector collection. Every disease whose
    stored symptoms include at least one of those documents is returned.

    Returns:
        A dict with ``"disease_list"`` (dicts with keys ``'Disease'``,
        ``'Symptoms'``, ``'Treatments'``) and ``"unique_symptoms_list"``
        (the sorted symptoms of the returned diseases that the user did
        not enter).
    """
    # Drop blank fragments: an empty string would otherwise be added to the
    # shared selection, where it can never be matched by a real symptom.
    selected_symptoms = [
        symptom
        for symptom in (part.strip().lower() for part in query.symptom.split(','))
        if symptom
    ]
    all_selected_symptoms.update(selected_symptoms)
    if not selected_symptoms:
        return {"disease_list": [], "unique_symptoms_list": []}

    # One batched encode and query; the result has one document list per symptom.
    query_embeddings = model.encode(selected_symptoms)
    results = collection.query(
        query_embeddings=query_embeddings.tolist(),
        n_results=5,
    )
    # Normalise the retrieved documents the same way as the database symptoms
    # below; otherwise a difference in case or padding prevents any match.
    matching_symptoms = {
        doc.strip().lower()
        for docs in results['documents']
        for doc in docs
        if doc
    }

    selected = set(selected_symptoms)
    disease_list = []
    unique_symptoms_set = set()

    # closing() releases the connection even if processing a row raises.
    with closing(sqlite3.connect("diseases_symptoms.db")) as conn:
        rows = conn.execute("SELECT name, symptoms, treatments FROM diseases")
        for disease_name, symptoms, treatments in rows:
            disease_symptoms = [symptom.strip().lower() for symptom in symptoms.split(',')]
            if matching_symptoms.isdisjoint(disease_symptoms):
                continue

            disease_list.append({
                'Disease': disease_name,
                'Symptoms': disease_symptoms,
                'Treatments': treatments,
            })
            # Offer the disease's remaining symptoms as follow-up choices.
            unique_symptoms_set.update(
                symptom for symptom in disease_symptoms if symptom not in selected
            )

    return {
        "disease_list": disease_list,
        "unique_symptoms_list": sorted(unique_symptoms_set),
    }
|
|
|
|
|
|
|
# Request body for /find_disease.
class SelectedSymptomsQuery(BaseModel):
    # Symptoms chosen by the client; /find_disease strips and lower-cases
    # each entry before using it.
    selected_symptoms: list


# Module-level set of the symptoms selected so far. /find_disease_list and
# /find_disease add to it and /trigger-reload empties it.
# NOTE(review): being module-level, it is shared by every request this
# process handles -- confirm that a single shared selection is intended.
all_selected_symptoms = set()
|
|
|
@app.post("/find_disease")
def find_disease(query: SelectedSymptomsQuery):
    """Narrow the disease list to those that have every selected symptom.

    The symptoms in ``query.selected_symptoms`` are normalised (stripped and
    lower-cased) and added to the module-level ``all_selected_symptoms`` set.
    A disease is returned only when its stored symptoms contain every symptom
    in that accumulated set.

    Returns:
        A dict with ``"unique_symptoms_list"`` (the sorted symptoms of the
        returned diseases that are not selected yet),
        ``"all_selected_symptoms"`` (the accumulated selection, sorted) and
        ``"disease_list"`` (dicts with keys ``'Disease'``, ``'Symptoms'``,
        ``'Treatments'``).
    """
    # The model declares a plain ``list``, so skip non-string entries instead
    # of failing on ``.strip()``. Blank entries are skipped too, because an
    # empty string in the shared set can never be matched by a real symptom.
    new_symptoms = [
        symptom.strip().lower()
        for symptom in query.selected_symptoms
        if isinstance(symptom, str)
    ]
    all_selected_symptoms.update(symptom for symptom in new_symptoms if symptom)

    # Work on a snapshot so the whole scan sees one consistent selection.
    selected = frozenset(all_selected_symptoms)

    disease_list = []
    unique_symptoms_set = set()

    # closing() releases the connection even if processing a row raises.
    with closing(sqlite3.connect("diseases_symptoms.db")) as conn:
        rows = conn.execute("SELECT name, symptoms, treatments FROM diseases")
        for disease_name, symptoms, treatments in rows:
            disease_symptoms = [symptom.strip().lower() for symptom in symptoms.split(',')]
            # Keep the disease only if it has ALL selected symptoms.
            if not selected.issubset(disease_symptoms):
                continue

            disease_list.append({
                'Disease': disease_name,
                'Symptoms': disease_symptoms,
                'Treatments': treatments,
            })
            # Offer the disease's remaining symptoms as follow-up choices.
            unique_symptoms_set.update(
                symptom for symptom in disease_symptoms if symptom not in selected
            )

    return {
        "unique_symptoms_list": sorted(unique_symptoms_set),
        # Sorted so the response is deterministic; set order is arbitrary.
        "all_selected_symptoms": sorted(selected),
        "disease_list": disease_list,
    }
|
|
|
# Request body for /pass2llm: a single disease record.
class DiseaseDetail(BaseModel):
    # Same keys as the dicts in the "disease_list" responses of this module.
    Disease: str
    Symptoms: list
    Treatments: str
    # Not produced by any endpoint visible in this file; presumably computed
    # by the client -- TODO confirm.
    MatchCount: int
|
|
|
@app.post("/pass2llm")
def pass2llm(query: DiseaseDetail):
    """Ask the remote LLM for a summary of the given disease record.

    The ngrok API is asked for the active endpoints. The first endpoint's
    public URL is used as the LLM base URL, and a prompt containing
    ``query`` is posted to ``<public_url>/api/generate``.

    Returns:
        On success, ``{"message": ..., "llm_response": ...}``. On any failure
        (network error, non-200 status, malformed reply, no active endpoint),
        ``{"message": ..., "error": ...}`` describing the stage that failed.
    """
    ngrok_failure = "Failed to get public URL from Ngrok!"
    llm_failure = "Failed to get response from LLM!"

    # NOTE(review): the fallback key below is committed to source control and
    # must be treated as leaked. Rotate it, supply the new one through the
    # NGROK_API_KEY environment variable, and delete the literal.
    api_key = os.environ.get(
        "NGROK_API_KEY", "2npJaJjnLBj1RGPcGf0QiyAAJHJ_5qqtw2divkpoAipqN9WLG"
    )
    headers = {
        "Authorization": f"Bearer {api_key}",
        "Ngrok-Version": "2",
    }

    # Without a timeout, requests waits forever and blocks the worker.
    try:
        response = requests.get("https://api.ngrok.com/endpoints", headers=headers, timeout=10)
    except requests.RequestException as exc:
        return {"message": ngrok_failure, "error": str(exc)}
    if response.status_code != 200:
        return {"message": ngrok_failure, "error": response.text}

    try:
        endpoints = response.json().get('endpoints') or []
        # With no active tunnel there is no first endpoint to read.
        public_url = endpoints[0]['public_url']
    except (ValueError, AttributeError, IndexError, KeyError, TypeError):
        return {"message": ngrok_failure, "error": "No usable endpoint in the Ngrok response."}

    prompt = f"Here is a list of diseases and their details: {query}. Please generate a summary."
    llm_headers = {"Content-Type": "application/json"}
    llm_payload = {"model": "llama3", "prompt": prompt, "stream": False}

    # Generation is not streamed, so allow a long read but a short connect.
    try:
        llm_response = requests.post(
            f"{public_url}/api/generate",
            headers=llm_headers,
            json=llm_payload,
            timeout=(10, 300),
        )
    except requests.RequestException as exc:
        return {"message": llm_failure, "error": str(exc)}
    if llm_response.status_code != 200:
        return {"message": llm_failure, "error": llm_response.text}

    try:
        llm_response_json = llm_response.json()
    except ValueError as exc:
        return {"message": llm_failure, "error": str(exc)}
    return {"message": "Successfully passed to LLM!", "llm_response": llm_response_json.get("response")}
|
|
|
|
|
|
|
|
|
@app.post("/trigger-reload")
async def trigger_reload():
    # Empty the shared selection in place. The name is never rebound here,
    # so no `global` declaration is required.
    all_selected_symptoms.clear()
    return "cleared"
|
|
|
|
|
|
|
|
|
|