import json
import os
import time

import numpy as np
import pandas as pd
import streamlit as st
from sentence_transformers import SentenceTransformer
from sqlalchemy import create_engine, text
from streamlit_agraph import agraph, Node, Edge, Config

from utils import (
    get_all_diseases_name,
    get_most_similar_diseases_from_uri,
    get_uri_from_name,
    get_diseases_related_to_a_textual_description,
    get_similarities_among_diseases_uris,
    augment_the_set_of_diseaces,
    get_clinical_trials_related_to_diseases,
    get_clinical_records_by_ids,
)

begin = st.container()

username = "demo"
password = "demo"
hostname = os.getenv("IRIS_HOSTNAME", "localhost")
port = "1972"
namespace = "USER"
CONNECTION_STRING = f"iris://{username}:{password}@{hostname}:{port}/{namespace}"
engine = create_engine(CONNECTION_STRING)

begin.write("# Klìnic")
description_input = begin.text_input(
    label="Enter the disease description 👇",
    placeholder="A disease that causes memory loss and other cognitive impairments.",
)

if begin.button("Analyze 🔎"):
    # 1. Embed the textual description that the user entered using the model
    # 2. Get the 5 diseases with the highest cosine similarity from the DB
    encoder = SentenceTransformer("allenai-specter")
    diseases_related_to_the_user_text = get_diseases_related_to_a_textual_description(
        description_input, encoder
    )
    # for disease_label in diseases_related_to_the_user_text:
    #     st.text(disease_label)

    # 3. Get the similarities among those diseases (cosine similarity of the embeddings of their nodes)
    diseases_uris = [disease["uri"] for disease in diseases_related_to_the_user_text]
    get_similarities_among_diseases_uris(diseases_uris)
    print(diseases_related_to_the_user_text)

    # 4. Potentially filter out the diseases that are not similar enough (e.g. similarity < 0.8)

    # 5. Augment the set of diseases: add new diseases that are similar to the ones
    #    already in the set, until we get 10-15 diseases
    augmented_set_of_diseases = augment_the_set_of_diseaces(diseases_uris)
    print(augmented_set_of_diseases)

    # 6. Query the embeddings of the diseases related to each clinical trial (also in the DB)
    #    to get the clinical trials most similar to our set of diseases
    clinical_trials_related_to_the_diseases = get_clinical_trials_related_to_diseases(
        augmented_set_of_diseases, encoder
    )
    print(f"clinical_trials_related_to_the_diseases: {clinical_trials_related_to_the_diseases}")

    # 7. Fetch the full clinical trial records for the matched trials
    json_of_clinical_trials = get_clinical_records_by_ids(
        [trial["nct_id"] for trial in clinical_trials_related_to_the_diseases]
    )
    print(f"json_of_clinical_trials: {json_of_clinical_trials}")

    # 8. Use an LLM to extract numerical data from the clinical trials (e.g. number of
    #    patients, number of deaths, etc.) and compute summary statistics out of that.
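    # Minimal sketch for step 8, assuming get_clinical_records_by_ids returns a list of
    # dicts following the ClinicalTrials.gov v2 schema used further below
    # (protocolSection.designModule.enrollmentInfo.count). It reads the enrollment
    # counts directly from the structured JSON instead of an LLM and computes simple
    # summary statistics with pandas.
    enrollment_counts = [
        record.get("protocolSection", {})
        .get("designModule", {})
        .get("enrollmentInfo", {})
        .get("count")
        for record in json_of_clinical_trials
    ]
    enrollment_series = pd.Series(
        [count for count in enrollment_counts if count is not None], dtype="float"
    )
    print(enrollment_series.describe())  # count / mean / std / min / max of enrollment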
    # 9. Show the results to the user: graph of the chosen diseases, summary of the
    #    clinical trials, summary statistics of the clinical trials, and a list with
    #    the details of the clinical trials considered
    graph_of_diseases = agraph(
        nodes=[
            Node(id="A", label="Node A", size=10),
            Node(id="B", label="Node B", size=10),
            Node(id="C", label="Node C", size=10),
            Node(id="D", label="Node D", size=10),
            Node(id="E", label="Node E", size=10),
            Node(id="F", label="Node F", size=10),
            Node(id="G", label="Node G", size=10),
            Node(id="H", label="Node H", size=10),
            Node(id="I", label="Node I", size=10),
            Node(id="J", label="Node J", size=10),
        ],
        edges=[
            Edge(source="A", target="B"),
            Edge(source="B", target="C"),
            Edge(source="C", target="D"),
            Edge(source="D", target="E"),
            Edge(source="E", target="F"),
            Edge(source="F", target="G"),
            Edge(source="G", target="H"),
            Edge(source="H", target="I"),
            Edge(source="I", target="J"),
        ],
        config=Config(height=500, width=500),
    )
    # TODO: also run this when the user presses Enter

    begin.write(":red[Here should be the graph]")  # TODO remove
    chart_data = pd.DataFrame(np.random.randn(20, 3), columns=["a", "b", "c"])  # TODO remove
    begin.scatter_chart(chart_data)  # TODO remove

    begin.write("## Disease Overview")
    disease_overview = ":red[lorem ipsum]"  # TODO
    begin.write(disease_overview)

    begin.write("## Clinical Trials Details")

    trials = []
    # TODO replace mock data
    with open("mock_trial.json") as f:
        d = json.load(f)
        for i in range(0, 5):
            trials.append(d)

    for trial in trials:
        with st.expander(f"{trial['protocolSection']['identificationModule']['nctId']}"):
            official_title = trial["protocolSection"]["identificationModule"][
                "officialTitle"
            ]
            st.write(f"##### {official_title}")

            brief_summary = trial["protocolSection"]["descriptionModule"]["briefSummary"]
            st.write(brief_summary)

            status_module = {
                "Status": trial["protocolSection"]["statusModule"]["overallStatus"],
                "Status Date": trial["protocolSection"]["statusModule"][
                    "statusVerifiedDate"
                ],
            }
            st.write("###### Status")
            st.table(status_module)

            design_module = {
                "Study Type": trial["protocolSection"]["designModule"]["studyType"],
                # "Phases": trial["protocolSection"]["designModule"]["phases"],  # breaks formatting because it is an array
                "Allocation": trial["protocolSection"]["designModule"]["designInfo"][
                    "allocation"
                ],
                "Participants": trial["protocolSection"]["designModule"]["enrollmentInfo"][
                    "count"
                ],
            }
            st.write("###### Design")
            st.table(design_module)
            # TODO: more modules?
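            # Sketch of one more module, assuming the record also carries
            # protocolSection.conditionsModule.conditions (a list of condition names in
            # the ClinicalTrials.gov v2 schema); .get() keeps it safe if the mock record
            # does not include that module.
            conditions = (
                trial["protocolSection"]
                .get("conditionsModule", {})
                .get("conditions", [])
            )
            if conditions:
                st.write("###### Conditions")
                st.write(", ".join(conditions))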