klinic / app.py
ACMCMC
UI Changes
27d40b9
raw
history blame
5.96 kB
import streamlit as st
from streamlit_agraph import agraph, Node, Edge, Config
import os
from sqlalchemy import create_engine, text
import pandas as pd
import time
from utils import (
get_all_diseases_name,
get_most_similar_diseases_from_uri,
get_uri_from_name,
get_diseases_related_to_a_textual_description,
get_similarities_among_diseases_uris,
augment_the_set_of_diseaces,
get_clinical_trials_related_to_diseases,
get_clinical_records_by_ids
)
import json
import numpy as np
from sentence_transformers import SentenceTransformer
# Top-level container that receives the page title, the input widgets and the
# result sections written through `begin.*` further down the script.
begin = st.container()

# Connection settings for the `iris://` database. Every value can be overridden
# through the environment (as the hostname already could); the defaults are the
# values that used to be hard-coded, so an unconfigured deployment behaves
# exactly as before.
username = os.getenv("IRIS_USERNAME", "demo")
password = os.getenv("IRIS_PASSWORD", "demo")
hostname = os.getenv("IRIS_HOSTNAME", "localhost")
port = os.getenv("IRIS_PORT", "1972")
namespace = os.getenv("IRIS_NAMESPACE", "USER")
CONNECTION_STRING = f"iris://{username}:{password}@{hostname}:{port}/{namespace}"
engine = create_engine(CONNECTION_STRING)

# Page header and the free-text field the user describes a disease in.
begin.write("# Klìnic")
description_input = begin.text_input(
    label="Enter the disease description 👇",
    placeholder="A disease that causes memory loss and other cognitive impairments.",
)
@st.cache_resource
def _load_encoder():
    """Return the shared "allenai-specter" sentence encoder.

    Streamlit re-runs the whole script on every interaction, so the model is
    cached as a resource instead of being re-instantiated on each click of the
    "Analyze" button.
    """
    return SentenceTransformer("allenai-specter")


def _render_trial(trial):
    """Render one clinical-trial record as a collapsible section.

    Args:
        trial: Parsed trial JSON (a dict) exposing ``protocolSection`` with the
            ``identificationModule``, ``descriptionModule``, ``statusModule``
            and ``designModule`` entries that are read below.

    Raises:
        KeyError: If one of the fields read below is missing from the record.
    """
    protocol = trial["protocolSection"]
    with st.expander(f"{protocol['identificationModule']['nctId']}"):
        official_title = protocol["identificationModule"]["officialTitle"]
        st.write(f"##### {official_title}")

        brief_summary = protocol["descriptionModule"]["briefSummary"]
        st.write(brief_summary)

        status_module = {
            "Status": protocol["statusModule"]["overallStatus"],
            "Status Date": protocol["statusModule"]["statusVerifiedDate"],
        }
        st.write("###### Status")
        st.table(status_module)

        design_module = {
            "Study Type": protocol["designModule"]["studyType"],
            # "Phases": protocol["designModule"]["phases"],  # breaks formatting because it is an array
            "Allocation": protocol["designModule"]["designInfo"]["allocation"],
            "Participants": protocol["designModule"]["enrollmentInfo"]["count"],
        }
        st.write("###### Design")
        st.table(design_module)
        # TODO more modules?


# Run the analysis pipeline and render the results when the user clicks the button.
# TODO: also when user clicks enter
if begin.button("Analyze 🔎"):
    if not description_input.strip():
        # There is nothing meaningful to embed or search for in an empty
        # description, so ask for input instead of querying the database.
        begin.warning("Please enter a disease description before running the analysis.")
    else:
        # 1. Embed the textual description that the user entered using the model
        # 2. Get 5 diseases with the highest cosine similarity from the DB
        encoder = _load_encoder()
        diseases_related_to_the_user_text = get_diseases_related_to_a_textual_description(
            description_input, encoder
        )
        # for disease_label in diseases_related_to_the_user_text:
        #     st.text(disease_label)

        # 3. Get the similarities of the embeddings of those diseases (cosine similarity of the embeddings of the nodes of such diseases)
        diseases_uris = [disease["uri"] for disease in diseases_related_to_the_user_text]
        # NOTE(review): the return value is discarded, as in the original code;
        # confirm whether it should feed step 4 or the graph below.
        get_similarities_among_diseases_uris(diseases_uris)
        print(diseases_related_to_the_user_text)

        # 4. Potentially filter out the diseases that are not similar enough (e.g. similarity < 0.8)
        # 5. Augment the set of diseases: add new diseases that are similar to the ones that are already in the set, until we get 10-15 diseases
        augmented_set_of_diseases = augment_the_set_of_diseaces(diseases_uris)
        print(augmented_set_of_diseases)

        # 6. Query the embeddings of the diseases related to each clinical trial (also in the DB), to get the most similar clinical trials to our set of diseases
        clinical_trials_related_to_the_diseases = get_clinical_trials_related_to_diseases(
            augmented_set_of_diseases, encoder
        )
        print(f'clinical_trials_related_to_the_diseases: {clinical_trials_related_to_the_diseases}')

        # 7. Fetch the full records of those trials by their NCT ids
        json_of_clinical_trials = get_clinical_records_by_ids(
            [trial["nct_id"] for trial in clinical_trials_related_to_the_diseases]
        )
        print(f'json_of_clinical_trials: {json_of_clinical_trials}')

        # 8. Use an LLM to extract numerical data from the clinical trials (e.g. number of patients, number of deaths, etc.). Get summary statistics out of that.
        # 9. Show the results to the user: graph of the diseases chosen, summary of the clinical trials, summary statistics of the clinical trials, and list of the details of the clinical trials considered

        # Placeholder graph: ten nodes "A".."J" linked in a single chain.
        placeholder_ids = "ABCDEFGHIJ"
        graph_of_diseases = agraph(
            nodes=[
                Node(id=node_id, label=f"Node {node_id}", size=10)
                for node_id in placeholder_ids
            ],
            edges=[
                Edge(source=source_id, target=target_id)
                for source_id, target_id in zip(placeholder_ids, placeholder_ids[1:])
            ],
            config=Config(height=500, width=500),
        )

        begin.write(":red[Here should be the graph]")  # TODO remove
        chart_data = pd.DataFrame(
            np.random.randn(20, 3), columns=["a", "b", "c"]
        )  # TODO remove
        begin.scatter_chart(chart_data)  # TODO remove

        begin.write("## Disease Overview")
        disease_overview = ":red[lorem ipsum]"  # TODO
        begin.write(disease_overview)

        begin.write("## Clinical Trials Details")
        # TODO replace mock data
        # Explicit UTF-8 so the JSON parses the same way on every platform.
        with open("mock_trial.json", encoding="utf-8") as f:
            mock_trial = json.load(f)
        # The same mock record is shown five times until real records are wired in.
        trials = [mock_trial] * 5

        for trial in trials:
            _render_trial(trial)