# fastapi / app.py
# Captured from the Hugging Face file viewer: author Prajith04,
# commit 9a573be ("Update app.py"), 8.03 kB.
import os
from itertools import combinations

import chromadb
import pandas as pd
import requests
import uvicorn
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from sentence_transformers import SentenceTransformer
# Define FastAPI app
app = FastAPI()

# Browser origins permitted to call this API (a local front end on port 5173).
# NOTE(review): a browser's Origin header always includes the scheme, and the
# CORS middleware compares origins as exact strings, so the scheme-less second
# entry can never match — confirm whether it is still wanted.
origins = ["http://localhost:5173", "localhost:5173"]

# Let those origins send credentials with any HTTP method and any header.
app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Load the dataset and model at startup
df = pd.read_csv("hf://datasets/QuyenAnhDE/Diseases_Symptoms/Diseases_Symptoms.csv")
# 'Symptoms' holds one comma-separated string per disease; turn it into a list
# of trimmed symptom names. A missing value (NaN) becomes an empty list rather
# than crashing the comprehension at startup (iterating a float raises), and
# empty pieces left by stray or trailing commas are dropped.
df['Symptoms'] = (
    df['Symptoms']
    .fillna('')
    .str.split(',')
    .apply(lambda parts: [s.strip() for s in parts if s.strip()])
)
# The endpoints copy 'Treatments' straight into JSON responses; a float NaN is
# rejected by Starlette's JSON encoder, so replace any gaps with ''. This is a
# no-op if the column has no missing values.
df['Treatments'] = df['Treatments'].fillna('')
# Sentence-embedding model used to encode incoming symptom queries.
model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
# On-disk ChromaDB store; the endpoints search the "symptomsvector" collection.
# NOTE(review): nothing in this file populates it — presumably filled by a
# separate indexing step; confirm.
client = chromadb.PersistentClient(path='./chromadb')
collection = client.get_or_create_collection(name="symptomsvector")
class SymptomQuery(BaseModel):
    """Request body carrying free-text symptom input.

    ``/find_matching_symptoms`` treats ``symptom`` as a comma-separated list;
    the other endpoints embed the whole string as a single query.
    """

    symptom: str
# Endpoint to handle symptom query and return matching symptoms
@app.post("/find_matching_symptoms")
def find_matching_symptoms(query: SymptomQuery):
    """Return stored symptoms most similar to each comma-separated input symptom.

    ``query.symptom`` is split on commas. Each non-blank piece is embedded with
    the sentence-transformer model, and the ChromaDB collection is searched for
    its 3 nearest stored symptom documents. Results are concatenated in input
    order and de-duplicated (first occurrence wins).

    Returns:
        ``{"matching_symptoms": [str, ...]}``. The list is empty when the input
        contains no non-blank symptom.
    """
    # Drop blank pieces (e.g. from "fever,,cough" or a trailing comma) so an
    # empty string is never embedded and searched.
    symptoms = [piece.strip() for piece in query.symptom.split(',') if piece.strip()]
    if not symptoms:
        return {"matching_symptoms": []}
    # Encode every piece in one batch and run a single multi-embedding search;
    # ChromaDB returns one result list per query embedding, in query order.
    query_embeddings = model.encode(symptoms)
    results = collection.query(
        query_embeddings=query_embeddings.tolist(),
        n_results=3,  # Return top 3 similar symptoms for each symptom
    )
    all_results = [doc for docs in results['documents'] for doc in docs]
    # dict.fromkeys removes duplicates while preserving first-seen order.
    matching_symptoms = list(dict.fromkeys(all_results))
    return {"matching_symptoms": matching_symptoms}
# Endpoint to handle symptom query and return matching diseases
@app.post("/find_matching_diseases")
def find_matching_diseases(query: SymptomQuery):
    """Return names of diseases listing any of the query's 5 nearest stored symptoms.

    The whole ``query.symptom`` string is embedded as one text (it is not split
    on commas here) and used for a single ChromaDB similarity search.

    Returns:
        ``{"matching_diseases": [name, ...]}`` in dataset row order.
    """
    vector = model.encode([query.symptom]).tolist()
    search = collection.query(query_embeddings=vector, n_results=5)  # top 5 similar symptoms
    nearest_symptoms = search['documents'][0]

    def shares_symptom(disease_symptoms):
        # True when at least one of this disease's symptoms was retrieved.
        return any(item in nearest_symptoms for item in disease_symptoms)

    hits = df[df['Symptoms'].apply(shares_symptom)]
    return {"matching_diseases": hits['Name'].tolist()}
# Process-wide request history. NOTE(review): these lists are shared by every
# client and every request, and nothing visible in this file ever clears them,
# so they grow for the life of the process — consider per-session state.
# Each entry of all_symptoms is the nearest-symptom list from one
# /find_disease_list call.
all_symptoms: list = []
# Every symptom submitted to /find_disease, in submission order.
all_selected_symptoms: list = []
# Endpoint to handle symptom query and return detailed disease list
@app.post("/find_disease_list")
def find_disease_list(query: SymptomQuery):
    """Return diseases sharing at least one symptom with the query's nearest matches.

    The query text is embedded and the 5 nearest stored symptoms are fetched
    from ChromaDB. That match list is also appended to the module-level
    ``all_symptoms`` history, which ``/find_disease`` reads later.

    Returns:
        ``{"disease_list": [...], "unique_symptoms_list": [...]}``.
        Each disease entry has ``Disease``, ``Symptoms`` and ``Treatments``
        keys. ``unique_symptoms_list`` holds every symptom of the matched
        diseases, lower-cased and de-duplicated in first-seen order.
    """
    query_embedding = model.encode([query.symptom])
    results = collection.query(
        query_embeddings=query_embedding.tolist(),
        n_results=5,  # Return top 5 similar symptoms
    )
    matching_symptoms = results['documents'][0]
    # NOTE(review): shared across all clients and never cleared here.
    all_symptoms.append(matching_symptoms)

    wanted = set(matching_symptoms)
    matching_diseases = df[df['Symptoms'].apply(lambda x: any(s in wanted for s in x))]

    disease_list = [
        {
            'Disease': row['Name'],
            'Symptoms': row['Symptoms'],
            'Treatments': row['Treatments'],
        }
        for _, row in matching_diseases.iterrows()
    ]
    # De-duplicate on the lower-cased form so spellings differing only in case
    # (e.g. "Fever" and "fever") collapse into one entry; dict.fromkeys keeps
    # first-seen order in linear time.
    unique_symptoms_list = list(dict.fromkeys(
        symptom.lower() for entry in disease_list for symptom in entry['Symptoms']
    ))
    return {"disease_list": disease_list, "unique_symptoms_list": unique_symptoms_list}
class SelectedSymptomsQuery(BaseModel):
    """Request body for ``/find_disease``: the symptoms the user has picked."""

    # Items are treated as strings downstream (``.lower()`` is called on each).
    selected_symptoms: list
@app.post("/find_disease")
def find_disease(query: SelectedSymptomsQuery):
    """Find diseases consistent with the user's selected symptoms.

    Candidate symptoms come from ``all_symptoms``, with one list per earlier
    ``/find_disease_list`` call. For each such list, and each non-empty
    combination of its entries, collect every disease whose symptom list
    contains both the whole combination and every symptom in
    ``query.selected_symptoms``. Matching is exact and case-sensitive.

    A disease is emitted once per matching combination, so the same disease
    can appear several times in ``disease_list``. NOTE(review): presumably the
    client counts repeats as a match strength (see ``DiseaseDetail.MatchCount``)
    — confirm before de-duplicating.

    The selected symptoms are appended to the module-level
    ``all_selected_symptoms`` history.

    Returns:
        A dict with these keys:
        - ``unique_symptoms_list``: lower-cased symptoms of the matched
          diseases that are neither already-known candidates nor selected,
          in first-seen order.
        - the two history lists.
        - ``disease_list``.
    """
    selected = query.selected_symptoms
    # NOTE(review): shared across all clients and never cleared here.
    all_selected_symptoms.extend(selected)

    # Symptoms the user has already been offered or has chosen; these are
    # not suggested again.
    known_symptoms = {symptom.lower() for symptom_set in all_symptoms for symptom in symptom_set}
    known_symptoms.update(symptom.lower() for symptom in selected)

    # Every combination is checked together with all selected symptoms, so a
    # row lacking any selected symptom can never match. Filter those rows out
    # once (in dataset order) instead of re-scanning the whole DataFrame for
    # every combination, and keep each row's symptoms as a set for O(1) lookups.
    candidates = []
    for _, row in df.iterrows():
        row_symptoms = set(row['Symptoms'])
        if all(s in row_symptoms for s in selected):
            candidates.append((row, row_symptoms))

    disease_list = []
    unique_symptoms = {}  # used as an insertion-ordered set
    for symptoms_set in all_symptoms:
        # A non-empty combination can only match a row sharing at least one
        # symptom with this set.
        overlapping = [
            (row, row_symptoms)
            for row, row_symptoms in candidates
            if any(s in row_symptoms for s in symptoms_set)
        ]
        for size in range(1, len(symptoms_set) + 1):  # combinations of every length
            for combination in combinations(symptoms_set, size):
                for row, row_symptoms in overlapping:
                    if not all(s in row_symptoms for s in combination):
                        continue
                    disease_list.append({
                        'Disease': row['Name'],
                        'Symptoms': row['Symptoms'],
                        'Treatments': row['Treatments'],
                    })
                    for symptom in row['Symptoms']:
                        lowered = symptom.lower()
                        if lowered not in known_symptoms:
                            unique_symptoms[lowered] = None

    return {
        "unique_symptoms_list": list(unique_symptoms),
        "all_selected_symptoms": all_selected_symptoms,
        "all_symptoms": all_symptoms,
        "disease_list": disease_list,
    }
class DiseaseListQuery(BaseModel):
    """Request body wrapping a list of diseases.

    NOTE(review): not referenced by any endpoint visible in this file.
    """

    disease_list: list


class DiseaseDetail(BaseModel):
    """One disease, as posted to ``/pass2llm``.

    ``Disease``/``Symptoms``/``Treatments`` mirror the keys of the disease
    dicts built by the ``/find_disease*`` endpoints. ``MatchCount`` is
    presumably how many times the client saw the disease matched — confirm
    with the front end.
    """

    Disease: str
    Symptoms: list
    Treatments: str
    MatchCount: int
@app.post("/pass2llm")
def pass2llm(query: DiseaseDetail):
    """Ask a remote LLM, reached through an ngrok tunnel, to summarise a disease.

    1. Query the ngrok REST API for active endpoints and take the first one's
       ``public_url``.
    2. POST a prompt that embeds ``query`` to ``{public_url}/api/generate``
       with model ``"llama3"``. The server is presumably Ollama-compatible —
       confirm.

    The ngrok API key is read from the ``NGROK_API_KEY`` environment variable
    instead of being hard-coded. The key previously committed in this file is
    public and should be revoked.

    Returns:
        ``{"message": ..., "llm_response": ...}`` on success. Otherwise
        ``{"message": ..., "error": ...}``, naming the step that failed.
        Network errors, timeouts, a missing tunnel and non-JSON replies are
        reported this way rather than surfacing as HTTP 500s.
    """
    ngrok_failure = "Failed to get public URL from Ngrok!"
    llm_failure = "Failed to get response from LLM!"
    # Timeouts in seconds; the LLM call gets a generous bound because the
    # whole, non-streamed generation must finish before it responds.
    ngrok_timeout = 10
    llm_timeout = 300

    # Step 1: find the tunnel's public URL.
    api_key = os.environ.get("NGROK_API_KEY")
    if not api_key:
        return {"message": ngrok_failure, "error": "NGROK_API_KEY environment variable is not set"}
    headers = {
        "Authorization": f"Bearer {api_key}",
        "Ngrok-Version": "2",
    }
    try:
        response = requests.get("https://api.ngrok.com/endpoints", headers=headers, timeout=ngrok_timeout)
    except requests.RequestException as exc:
        return {"message": ngrok_failure, "error": str(exc)}
    if response.status_code != 200:
        return {"message": ngrok_failure, "error": response.text}
    try:
        endpoints = response.json().get('endpoints') or []
    except ValueError:  # body was not valid JSON
        return {"message": ngrok_failure, "error": response.text}
    if not endpoints:
        # No tunnel is running, so there is no URL to forward the prompt to.
        return {"message": ngrok_failure, "error": "No active ngrok endpoints"}
    public_url = endpoints[0]['public_url']

    # Step 2: send the disease details to the LLM.
    prompt = f"Here is a list of diseases and their details: {query}. Please generate a summary."
    llm_headers = {"Content-Type": "application/json"}
    llm_payload = {"model": "llama3", "prompt": prompt, "stream": False}
    try:
        llm_response = requests.post(
            f"{public_url}/api/generate",
            headers=llm_headers,
            json=llm_payload,
            timeout=llm_timeout,
        )
    except requests.RequestException as exc:
        return {"message": llm_failure, "error": str(exc)}
    if llm_response.status_code != 200:
        return {"message": llm_failure, "error": llm_response.text}
    try:
        llm_text = llm_response.json().get("response")
    except ValueError:  # body was not valid JSON
        return {"message": llm_failure, "error": llm_response.text}
    return {"message": "Successfully passed to LLM!", "llm_response": llm_text}
# To run the FastAPI app with Uvicorn
# if __name__ == "__main__":
# uvicorn.run(app, host="0.0.0.0", port=8000)