"""Streamlit front end for Klìnic, an AI-powered clinical trial search engine.

The page is rendered top to bottom in stages:

1. The user enters a free-text disease description and presses "Analyze".
2. The description is embedded and matched against diseases, the disease set
   is augmented, related clinical trials are fetched and an LLM summary of
   those trials is written into the status panel.
3. A (placeholder) disease graph, a (placeholder) overview and the details of
   (mock) clinical trials are shown.

Each stage is gated by a ``show_*`` flag that the previous stage sets, so
nothing below the input form is rendered until "Analyze" has been pressed.
"""

import json
import os
import time

import numpy as np
import pandas as pd
import streamlit as st
from sentence_transformers import SentenceTransformer
from sqlalchemy import create_engine, text
from streamlit_agraph import agraph, Node, Edge, Config

from llm_res import get_short_summary_out_of_json_files
from utils import (
    get_all_diseases_name,
    get_most_similar_diseases_from_uri,
    get_uri_from_name,
    get_diseases_related_to_a_textual_description,
    get_similarities_among_diseases_uris,
    augment_the_set_of_diseaces,
    get_clinical_trials_related_to_diseases,
    get_clinical_records_by_ids,
)


@st.cache_resource(show_spinner=False)
def _load_encoder():
    """Return the sentence encoder, creating it only on the first call.

    Streamlit re-executes this script on every interaction, so without the
    cache the model would be instantiated again on each "Analyze" click.
    The cached object is shared between reruns and sessions.
    """
    return SentenceTransformer("allenai-specter")


# variables to reveal next steps
show_graph = False
show_analyze_status = False
show_overview = False
show_details = False

# IRIS connection
username = "demo"
password = "demo"
hostname = os.getenv("IRIS_HOSTNAME", "localhost")
port = "1972"
namespace = "USER"
CONNECTION_STRING = f"iris://{username}:{password}@{hostname}:{port}/{namespace}"
engine = create_engine(CONNECTION_STRING)

st.image("img_klinic.jpeg", caption="(AI-generated image)", use_column_width=True)
st.title("Klìnic", help="AI-powered clinical trial search engine")

# user input
with st.container():
    col1, col2 = st.columns((6, 1))
    with col1:
        description_input = st.text_area(
            label="Enter the disease description 👇",
            placeholder='A disease that causes memory loss and other cognitive impairments.',
        )
    with col2:
        # Three empty text elements act as spacers to center the button
        # vertically next to the text area.
        for _ in range(3):
            st.text('')
        show_analyze_status = st.button("Analyze 🔎")

# analyze
with st.container():
    if show_analyze_status and not (description_input or "").strip():
        # Nothing meaningful can be embedded from an empty description, so
        # skip the whole pipeline instead of querying with an empty string.
        st.warning("Please enter a disease description before analyzing.")
    elif show_analyze_status:
        with st.status("Analyzing...") as status:
            # 1. Embed the textual description that the user entered using the model
            # 2. Get 5 diseases with the highest cosine similarity from the DB
            status.write("Analyzing the description that you wrote...")
            encoder = _load_encoder()
            diseases_related_to_the_user_text = get_diseases_related_to_a_textual_description(
                description_input, encoder
            )

            # 3. Get the similarities of the embeddings of those diseases
            #    (cosine similarity of the embeddings of the nodes of such diseases)
            status.write(
                "Getting the similarities among the diseases to filter out less promising ones..."
            )
            diseases_uris = [disease["uri"] for disease in diseases_related_to_the_user_text]
            # NOTE(review): the return value is currently discarded, so the
            # filtering described in step 4 is not applied yet.
            get_similarities_among_diseases_uris(diseases_uris)

            # 4. Potentially filter out the diseases that are not similar enough
            #    (e.g. similarity < 0.8)
            # 5. Augment the set of diseases: add new diseases that are similar to the
            #    ones that are already in the set, until we get 10-15 diseases
            status.write("Augmenting the set of diseases by finding others with related embeddings...")
            augmented_set_of_diseases = augment_the_set_of_diseaces(diseases_uris)

            # 6. Query the embeddings of the diseases related to each clinical trial
            #    (also in the DB), to get the most similar clinical trials to our set
            #    of diseases
            status.write("Getting the clinical trials related to the diseases found...")
            clinical_trials_related_to_the_diseases = get_clinical_trials_related_to_diseases(
                augmented_set_of_diseases, encoder
            )

            status.write("Getting the details of the clinical trials...")
            json_of_clinical_trials = get_clinical_records_by_ids(
                [trial["nct_id"] for trial in clinical_trials_related_to_the_diseases]
            )
            status.json(json_of_clinical_trials, expanded=False)

            # 7. Use an LLM to get a summary of the clinical trials, in plain text format.
            status.write("Getting a summary of the clinical trials...")
            response = get_short_summary_out_of_json_files(json_of_clinical_trials)
            print(f'Response from LLM: {response}')
            status.write(f'Response from LLM: {response}')

            # 8. Use an LLM to extract numerical data from the clinical trials
            #    (e.g. number of patients, number of deaths, etc.). Get summary
            #    statistics out of that.
            # TODO: only the progress message exists so far; no statistics are computed.
            status.write("Getting summary statistics of the clinical trials...")

            # 9. Show the results to the user: graph of the diseases chosen, summary
            #    of the clinical trials, summary statistics of the clinical trials,
            #    and list of the details of the clinical trials considered
            status.update(label="Done!", state="complete")
            time.sleep(1)
        show_graph = True

# graph
with st.container():
    if show_graph:
        # TODO actual graph
        # Placeholder: ten nodes "A".."J" connected as a simple chain.
        placeholder_ids = "ABCDEFGHIJ"
        graph_of_diseases = agraph(
            nodes=[
                Node(id=node_id, label=f"Node {node_id}", size=10)
                for node_id in placeholder_ids
            ],
            edges=[
                Edge(source=src, target=dst)
                for src, dst in zip(placeholder_ids, placeholder_ids[1:])
            ],
            config=Config(height=500, width=500),
        )
        time.sleep(2)
        show_overview = True

# overview
with st.container():
    if show_overview:
        st.write("## Disease Overview")
        disease_overview = ":red[lorem ipsum]"  # TODO
        st.write(disease_overview)
        time.sleep(2)
        show_details = True

# details
with st.container():
    if show_details:
        st.write("## Clinical Trials Details")

        # TODO replace mock data
        # JSON is read explicitly as UTF-8 so the result does not depend on the
        # platform's default encoding.
        with open("mock_trial.json", encoding="utf-8") as f:
            mock_trial = json.load(f)
        # The same mock record is shown five times.
        trials = [mock_trial] * 5

        for trial in trials:
            protocol = trial["protocolSection"]
            identification = protocol["identificationModule"]
            with st.expander(f"{identification['nctId']}"):
                official_title = identification["officialTitle"]
                st.write(f"##### {official_title}")

                brief_summary = protocol["descriptionModule"]["briefSummary"]
                st.write(brief_summary)

                status_info = protocol["statusModule"]
                status_module = {
                    "Status": status_info["overallStatus"],
                    "Status Date": status_info["statusVerifiedDate"],
                }
                st.write("###### Status")
                st.table(status_module)

                design_info = protocol["designModule"]
                # "Phases" (design_info["phases"]) is intentionally left out:
                # it is an array and breaks the table formatting.
                design_module = {
                    "Study Type": design_info["studyType"],
                    "Allocation": design_info["designInfo"]["allocation"],
                    "Participants": design_info["enrollmentInfo"]["count"],
                }
                st.write("###### Design")
                st.table(design_module)

                # TODO more modules?