import os
import sqlite3
import time
from contextlib import closing
from itertools import combinations

import chromadb
import pandas as pd
import requests
import uvicorn
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from sentence_transformers import SentenceTransformer

# SQLite database holding the disease / symptoms / treatments table.
DB_PATH = "diseases_symptoms.db"
# Dataset used to seed DB_PATH whenever the diseases table is empty.
DATASET_CSV = "hf://datasets/QuyenAnhDE/Diseases_Symptoms/Diseases_Symptoms.csv"
# ngrok REST endpoint used to discover the public URL of the tunnelled LLM server.
NGROK_API_URL = "https://api.ngrok.com/endpoints"
# (connect, read) timeouts in seconds; without a timeout `requests` waits forever
# on a stalled peer and ties up a worker thread.
NGROK_TIMEOUT = (10, 30)
LLM_TIMEOUT = (10, 300)  # generation can be slow, so allow a long read

# Define FastAPI app
app = FastAPI()

origins = [
    "http://localhost:5173",
    # NOTE(review): browsers always include the scheme in the Origin header, so
    # this scheme-less entry probably never matches; kept for compatibility.
    "localhost:5173",
]

app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Load the embedding model and the persistent vector store once at startup.
model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
client = chromadb.PersistentClient(path='./chromadb')
collection = client.get_or_create_collection(name="symptomsvector")


def init_db():
    """Open DB_PATH, create the ``diseases`` table if missing, and return the connection.

    The caller owns the returned connection and is responsible for closing it.
    """
    conn = sqlite3.connect(DB_PATH)
    conn.execute(
        '''
        CREATE TABLE IF NOT EXISTS diseases (
            id INTEGER PRIMARY KEY,
            name TEXT,
            symptoms TEXT,
            treatments TEXT
        )
        '''
    )
    conn.commit()
    return conn


def _populate_db_if_empty():
    """Seed the ``diseases`` table from DATASET_CSV when it contains no rows.

    Checking the row count (instead of whether the database file exists) means
    a start-up that failed while seeding is retried next time rather than
    leaving an empty database behind permanently. All rows are inserted in a
    single transaction, so a failure leaves the table empty, not half-filled.
    """
    with closing(init_db()) as conn:
        (row_count,) = conn.execute("SELECT COUNT(*) FROM diseases").fetchone()
        if row_count:
            return
        df = pd.read_csv(DATASET_CSV)
        records = []
        for _, row in df.iterrows():
            raw_symptoms = row['Symptoms']
            # A blank Symptoms cell is read by pandas as NaN (a float), which
            # has no .split(); store it as an empty symptom string instead.
            if pd.isna(raw_symptoms):
                symptoms = []
            else:
                symptoms = [s.strip() for s in str(raw_symptoms).split(',')]
            records.append((row['Name'], ",".join(symptoms), row.get('Treatments', '')))
        with conn:  # commits on success, rolls back on error
            conn.executemany(
                "INSERT INTO diseases (name, symptoms, treatments) VALUES (?, ?, ?)",
                records,
            )


# Populate database from CSV if it's the first time (or a previous seed failed).
_populate_db_if_empty()


class SymptomQuery(BaseModel):
    # Comma-separated free-text symptoms, e.g. "fever, cough".
    symptom: str


class SelectedSymptomsQuery(BaseModel):
    # Symptoms the user has picked; each entry is stripped and lower-cased.
    selected_symptoms: list


# Symptoms accumulated by /find_disease_list and /find_disease.
# NOTE(review): this is process-wide state shared by every client and never
# reset in this file, so one user's selections affect everyone's results.
all_selected_symptoms = set()  # Use a set to avoid duplicates


def _split_symptoms(text, *, lower=False):
    """Split comma-separated ``text`` into stripped, non-empty symptoms.

    Empty pieces (from "a,,b" or a trailing comma) are dropped: otherwise ""
    would be embedded and, once stored in ``all_selected_symptoms``, would make
    /find_disease match no disease at all.

    Args:
        text: Comma-separated symptom string.
        lower: Also lower-case each symptom.

    Returns:
        List of symptoms in input order.
    """
    pieces = (piece.strip() for piece in text.split(','))
    return [piece.lower() if lower else piece for piece in pieces if piece]


def _similar_symptoms(symptoms, n_results):
    """Return the ChromaDB documents nearest to each symptom.

    All symptoms are embedded and queried in one batch. Results are flattened
    in input order and de-duplicated, keeping each document's first occurrence.

    Args:
        symptoms: Symptom strings to look up.
        n_results: Number of neighbours requested per symptom.

    Returns:
        List of unique document strings; empty when ``symptoms`` is empty.
    """
    if not symptoms:
        return []
    embeddings = model.encode(symptoms)
    results = collection.query(
        query_embeddings=embeddings.tolist(),
        n_results=n_results,
    )
    documents = [doc for per_query in results['documents'] for doc in per_query]
    return list(dict.fromkeys(documents))


def _iter_diseases(conn):
    """Yield ``(name, symptoms, treatments)`` for every row of ``diseases``.

    ``symptoms`` is the stored comma-separated string split into stripped,
    lower-cased, non-empty entries.
    """
    rows = conn.execute("SELECT name, symptoms, treatments FROM diseases")
    for name, symptoms, treatments in rows:
        normalized = [s.strip().lower() for s in (symptoms or '').split(',')]
        yield name, [s for s in normalized if s], treatments


def fetch_diseases_by_symptoms(matching_symptoms):
    """Return diseases whose stored symptom string contains ``matching_symptoms``.

    NOTE(review): the symptoms are joined with "," and matched as ONE LIKE
    substring, so a disease matches only if it lists them contiguously and in
    exactly this order. This helper is not referenced by the endpoints below.

    Returns:
        ``(disease_list, unique_symptoms_list)``: matching disease dicts, and
        the lower-cased symptoms of those diseases in first-seen order.
    """
    pattern = f"%{','.join(matching_symptoms)}%"
    disease_list = []
    unique_symptoms = {}  # dict as an insertion-ordered set
    with closing(sqlite3.connect(DB_PATH)) as conn:
        rows = conn.execute(
            "SELECT name, symptoms, treatments FROM diseases WHERE symptoms LIKE ?",
            (pattern,),
        )
        for name, symptoms, treatments in rows:
            disease_list.append({
                'Disease': name,
                'Symptoms': symptoms.split(','),
                'Treatments': treatments,
            })
            for symptom in symptoms.split(','):
                unique_symptoms.setdefault(symptom.strip().lower(), None)
    return disease_list, list(unique_symptoms)


@app.post("/find_matching_symptoms")
def find_matching_symptoms(query: SymptomQuery):
    """Return up to 3 stored symptoms most similar to each input symptom."""
    symptoms = _split_symptoms(query.symptom)
    return {"matching_symptoms": _similar_symptoms(symptoms, n_results=3)}


@app.post("/find_disease_list")
def find_disease_list(query: SymptomQuery):
    """List diseases sharing at least one symptom with the input's nearest matches.

    Each input symptom is expanded to its 5 nearest stored symptoms. A disease is
    returned when any of those appears in its symptom list. The input symptoms
    are also added to the shared ``all_selected_symptoms`` set.

    Returns:
        ``{"disease_list": [...], "unique_symptoms_list": [...]}``. The second
        list holds the sorted symptoms of the returned diseases that the user
        did not type.
    """
    selected_symptoms = _split_symptoms(query.symptom, lower=True)
    all_selected_symptoms.update(selected_symptoms)  # Add new symptoms to the set

    # Chroma documents come back exactly as stored; lower-case them so they
    # compare equal to the normalized database symptoms below.
    matching_symptoms = {
        doc.strip().lower() for doc in _similar_symptoms(selected_symptoms, n_results=5)
    }
    selected = set(selected_symptoms)

    disease_list = []
    unique_symptoms_set = set()
    with closing(sqlite3.connect(DB_PATH)) as conn:
        for name, disease_symptoms, treatments in _iter_diseases(conn):
            if matching_symptoms.isdisjoint(disease_symptoms):
                continue  # no symptom in common
            disease_list.append({
                'Disease': name,
                'Symptoms': disease_symptoms,
                'Treatments': treatments,
            })
            unique_symptoms_set.update(s for s in disease_symptoms if s not in selected)

    return {
        "disease_list": disease_list,
        "unique_symptoms_list": sorted(unique_symptoms_set),
    }


@app.post("/find_disease")
def find_disease(query: SelectedSymptomsQuery):
    """Return diseases that exhibit every symptom selected so far.

    The request's symptoms are added to the shared ``all_selected_symptoms``
    set. A disease matches only if its symptom list contains all of them.

    Returns:
        ``{"unique_symptoms_list", "all_selected_symptoms", "disease_list"}``.
        The two symptom lists are sorted for stable output.
    """
    new_symptoms = [symptom.strip().lower() for symptom in query.selected_symptoms]
    all_selected_symptoms.update(s for s in new_symptoms if s)
    # Snapshot the shared set: sync endpoints run in a thread pool, and
    # iterating a set another request is mutating raises RuntimeError.
    selected = frozenset(all_selected_symptoms)

    disease_list = []
    unique_symptoms_set = set()
    with closing(sqlite3.connect(DB_PATH)) as conn:
        for name, disease_symptoms, treatments in _iter_diseases(conn):
            if not selected.issubset(disease_symptoms):
                continue  # missing at least one selected symptom
            disease_list.append({
                'Disease': name,
                'Symptoms': disease_symptoms,
                'Treatments': treatments,
            })
            unique_symptoms_set.update(s for s in disease_symptoms if s not in selected)

    return {
        "unique_symptoms_list": sorted(unique_symptoms_set),
        "all_selected_symptoms": sorted(selected),
        "disease_list": disease_list,
    }


class DiseaseDetail(BaseModel):
    Disease: str
    Symptoms: list
    Treatments: str
    MatchCount: int


@app.post("/pass2llm")
def pass2llm(query: DiseaseDetail):
    """Ask the LLM behind the first ngrok endpoint to summarize ``query``.

    The ngrok API key is read from the ``NGROK_API_KEY`` environment variable.
    The token previously hard-coded here was committed to source and should be
    revoked. Failures are reported in the response body rather than raised.
    """
    ngrok_error = "Failed to get public URL from Ngrok!"
    api_key = os.environ.get("NGROK_API_KEY")
    if not api_key:
        return {"message": ngrok_error, "error": "NGROK_API_KEY environment variable is not set"}

    headers = {
        "Authorization": f"Bearer {api_key}",
        "Ngrok-Version": "2",
    }
    try:
        response = requests.get(NGROK_API_URL, headers=headers, timeout=NGROK_TIMEOUT)
    except requests.RequestException as exc:
        return {"message": ngrok_error, "error": str(exc)}
    if response.status_code != 200:
        return {"message": ngrok_error, "error": response.text}

    endpoints = response.json().get('endpoints') or []
    if not endpoints:
        return {"message": ngrok_error, "error": "No active ngrok endpoints found"}
    public_url = endpoints[0]['public_url']

    prompt = f"Here is a list of diseases and their details: {query}. Please generate a summary."
    llm_headers = {"Content-Type": "application/json"}
    llm_payload = {"model": "llama3", "prompt": prompt, "stream": False}
    try:
        llm_response = requests.post(
            f"{public_url}/api/generate",
            headers=llm_headers,
            json=llm_payload,
            timeout=LLM_TIMEOUT,
        )
    except requests.RequestException as exc:
        return {"message": "Failed to get response from LLM!", "error": str(exc)}
    if llm_response.status_code != 200:
        return {"message": "Failed to get response from LLM!", "error": llm_response.text}

    llm_response_json = llm_response.json()
    return {"message": "Successfully passed to LLM!", "llm_response": llm_response_json.get("response")}


@app.post("/trigger-reload")
async def trigger_reload():
    """Rewrite reload_trigger.txt with the current timestamp.

    NOTE(review): presumably a file watcher restarts the app when this file
    changes. uvicorn's --reload watches *.py by default, so confirm the
    watcher is configured to include this .txt file.
    """
    with open("reload_trigger.txt", "w", encoding="utf-8") as f:
        f.write(f"Trigger reload at {time.time()}")
    return {"message": "Reload triggered."}


# To run the FastAPI app with Uvicorn
# if __name__ == "__main__":
#     uvicorn.run(app, host="0.0.0.0", port=8000)