Spaces:
Sleeping
Sleeping
File size: 15,934 Bytes
a6bd112 93e1b64 27d40b9 a6bd112 27d40b9 a6bd112 27d40b9 a6bd112 27d40b9 a6bd112 27d40b9 a6bd112 e7d7b51 a6bd112 27d40b9 93e1b64 9b5c4aa 4bb7c94 27d40b9 9b5c4aa 27d40b9 93e1b64 27d40b9 ee05396 a6bd112 ee05396 a6bd112 9b5c4aa a6bd112 9b5c4aa a6bd112 9b5c4aa 1e2e3b8 9b5c4aa a6bd112 9b5c4aa 4021316 47c6369 9b5c4aa a6bd112 47c6369 a6bd112 47c6369 a6bd112 e7d7b51 a6bd112 e7d7b51 47c6369 9b5c4aa a6bd112 551646a a6bd112 551646a a6bd112 551646a e7d7b51 9b5c4aa 1e2e3b8 a6bd112 9b5c4aa 47c6369 1e2e3b8 9b5c4aa a6bd112 ee05396 47c6369 1e2e3b8 4bb7c94 a6bd112 1b1c01c a6bd112 1b1c01c 4bb7c94 a6bd112 4bb7c94 a6bd112 1b1c01c a6bd112 9b5c4aa 4021316 9b5c4aa e7d7b51 9b5c4aa 3cfdd19 47c6369 3cfdd19 47c6369 3cfdd19 a6bd112 551646a a6bd112 9b5c4aa a6bd112 2bb81aa a6bd112 2bb81aa a6bd112 2bb81aa a6bd112 2bb81aa a6bd112 2bb81aa 551646a a6bd112 9b5c4aa 4bb7c94 a6bd112 4bb7c94 a6bd112 4bb7c94 a6bd112 4bb7c94 9b5c4aa a6bd112 796a53b ec6a815 76919d3 551646a 2bb81aa 551646a 1b1c01c a6bd112 76919d3 1b1c01c 76919d3 1b1c01c 76919d3 a6bd112 76919d3 a6bd112 1b1c01c a6bd112 1b1c01c a6bd112 1b1c01c a6bd112 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 |
import json
import os
import time
from urllib.parse import quote_plus

import matplotlib
import numpy as np
import pandas as pd
import streamlit as st
from sentence_transformers import SentenceTransformer
from sqlalchemy import create_engine, text
from streamlit_agraph import Config, Edge, Node, agraph

from llm_res import get_short_summary_out_of_json_files, tagging_insights_from_json
from utils import (
    augment_the_set_of_diseaces,
    filter_out_less_promising_diseases,
    get_all_diseases_name,
    get_clinical_records_by_ids,
    get_clinical_trials_related_to_diseases,
    get_diseases_related_to_a_textual_description,
    get_most_similar_diseases_from_uri,
    get_similarities_among_diseases_uris,
    get_similarities_df,
    get_uri_from_name,
    render_trial_details,
    get_labels_of_diseases_from_uris,
)
# Page sections are revealed one after another: each section flips the flag
# of the next one once it has been rendered. Everything starts hidden.
show_graph = show_analyze_status = show_overview = show_details = show_metrics = False
# IRIS connection. Every setting can be overridden through the environment
# (the hostname already could); the defaults match the demo instance.
username = os.getenv("IRIS_USERNAME", "demo")
password = os.getenv("IRIS_PASSWORD", "demo")
hostname = os.getenv("IRIS_HOSTNAME", "localhost")
port = os.getenv("IRIS_PORT", "1972")
namespace = os.getenv("IRIS_NAMESPACE", "USER")
# Percent-encode the credentials so characters such as "@", ":" or "/" in a
# password cannot break the URL. For the defaults this is a no-op.
CONNECTION_STRING = (
    f"iris://{quote_plus(username)}:{quote_plus(password)}"
    f"@{hostname}:{port}/{namespace}"
)
engine = create_engine(CONNECTION_STRING)
st.image("img_klinic.jpeg", caption="(AI-generated image)", use_column_width=True)
st.title("Klìnic", help="AI-powered clinical trial search engine")
st.subheader(
    "Find clinical trials in a scoped domain of biomedical research, guiding your research with AI-powered insights."
)

# User input: a wide text area on the left, the "Analyze" button on the right.
with st.container():
    text_column, button_column = st.columns((6, 1))
    with text_column:
        description_input = st.text_area(
            label="Enter a disease description 👇",
            placeholder="A disorder manifested in memory loss and other cognitive impairments among elderly patients (60+ years old), especially women.",
        )
    with button_column:
        # Three empty text rows push the button down so it sits level with
        # the text area.
        for _ in range(3):
            st.text("")
        show_analyze_status = st.button("Analyze 🔎")
@st.cache_resource(show_spinner=False)
def _load_encoder():
    """Return the ``allenai-specter`` SentenceTransformer used for embeddings.

    Streamlit re-executes this script on every interaction, so the model is
    cached as a shared resource instead of being reloaded from disk each time
    the user clicks "Analyze".
    """
    return SentenceTransformer("allenai-specter")


# analyze
with st.container():
    if show_analyze_status and not description_input.strip():
        # Embedding an empty description gives meaningless neighbours, so ask
        # for input instead of running the whole pipeline.
        st.warning("Please enter a disease description before analyzing.")
    elif show_analyze_status:
        with st.status("Analyzing...") as status:
            # 1. Embed the textual description that the user entered using the model.
            # 2. Get the diseases with the highest cosine similarity from the DB
            #    (how many is decided by the helper).
            status.write("Analyzing the description that you wrote...")
            encoder = _load_encoder()
            diseases_related_to_the_user_text = (
                get_diseases_related_to_a_textual_description(
                    description_input, encoder
                )
            )
            status.info(
                f"Selected {len(diseases_related_to_the_user_text)} diseases related to the description you entered."
            )
            status.json(diseases_related_to_the_user_text, expanded=False)
            status.divider()

            # 3. Get the similarities among those diseases (cosine similarity
            #    of the embeddings of their nodes).
            status.write(
                "Getting the similarities among the diseases to filter out less promising ones..."
            )
            diseases_uris = [
                disease["uri"] for disease in diseases_related_to_the_user_text
            ]
            similarities = get_similarities_among_diseases_uris(diseases_uris)
            status.info(
                "Obtained similarity information among the diseases by measuring the cosine similarity of their embeddings."
            )
            status.json(similarities, expanded=False)

            # 4. Filter out the diseases that are not similar enough to the rest.
            filtered_diseases_uris, df_similarities = (
                filter_out_less_promising_diseases(similarities)
            )
            # Apply a colormap to the similarity table.
            status.table(
                df_similarities.style.background_gradient(cmap="viridis", axis=None)
            )
            status.info(
                f"Filtered out less promising diseases, keeping {len(filtered_diseases_uris)} diseases."
            )
            status.json(filtered_diseases_uris, expanded=False)
            status.divider()

            # 5. Augment the set of diseases with others whose embeddings are
            #    similar to the ones already in the set.
            status.write(
                "Augmenting the set of diseases by finding others with related embeddings..."
            )
            augmented_set_of_diseases = augment_the_set_of_diseaces(
                filtered_diseases_uris
            )
            similarities_of_augmented_set_of_diseases = (
                get_similarities_among_diseases_uris(augmented_set_of_diseases)
            )
            df_similarities_augmented_set = get_similarities_df(
                similarities_of_augmented_set_of_diseases
            )
            status.info(
                f"Augmented set of diseases: {len(augmented_set_of_diseases)} diseases."
            )
            status.table(
                df_similarities_augmented_set.style.background_gradient(
                    cmap="viridis", axis=None
                )
            )
            status.json(augmented_set_of_diseases, expanded=False)
            status.divider()

            # 6. Find the clinical trials whose related diseases are most
            #    similar to our set of diseases, then fetch their full records.
            status.write("Getting the clinical trials related to the diseases found...")
            clinical_trials_related_to_the_diseases = (
                get_clinical_trials_related_to_diseases(
                    augmented_set_of_diseases, encoder
                )
            )
            status.info(
                f"Selected {len(clinical_trials_related_to_the_diseases)} clinical trials related to the diseases."
            )
            status.json(clinical_trials_related_to_the_diseases, expanded=False)
            status.divider()
            status.write("Getting the details of the clinical trials...")
            json_of_clinical_trials = get_clinical_records_by_ids(
                [trial["nct_id"] for trial in clinical_trials_related_to_the_diseases]
            )
            status.success("Details of the clinical trials obtained.")
            status.json(json_of_clinical_trials, expanded=False)
            status.divider()

            # 7. Use an LLM to get a plain-text summary of the clinical trials.
            #    Best effort: on failure `disease_overview` stays unbound.
            try:
                status.write("Getting a summary of the clinical trials...")
                response = get_short_summary_out_of_json_files(json_of_clinical_trials)
                status.success("Summary of the clinical trials obtained.")
                disease_overview = response
            except Exception as e:
                print(f"Error while getting a summary of the clinical trials: {e}")
                status.warning(
                    "Error while getting a summary of the clinical trials. This information will not be shown."
                )

            # 8. Use an LLM to extract summary statistics from the trials.
            #    Best effort: on failure some or all of the metric names stay
            #    unbound.
            try:
                status.write("Getting summary statistics of the clinical trials...")
                response = tagging_insights_from_json(json_of_clinical_trials)
                average_minimum_age = response["avg_min_age"]
                average_maximum_age = response["avg_max_age"]
                most_common_gender = response["most_common_gender"]
                print(f"Response from LLM tagging: {response}")
                status.success("Summary statistics of the clinical trials obtained.")
            except Exception as e:
                print(
                    f"Error while extracting numerical data from the clinical trials: {e}"
                )
                status.warning(
                    "Error while extracting numerical data from the clinical trials. This information will not be shown."
                )

            # 9. Reveal the results: graph of diseases, summary, statistics
            #    and per-trial details.
            status.update(label="Done!", state="complete")
            status.balloons()
            show_graph = True
            trials = json_of_clinical_trials
# graph
with st.container():
    if show_graph:
        st.info(
            """This is a graph of the relevant diseases that we found, based on the description that you entered. The diseases are connected by edges if they are similar to each other. The color of the edges represents the similarity of the diseases.

We use the embeddings of the diseases to determine the similarity between them. The embeddings are generated using a Representation Learning algorithm that learns the topological relations among the nodes in the graph, depending on how they are connected. We utilize the [PyKeen](https://github.com/pykeen/pykeen) implementation of TransH to train an embedding model.

[TransH](https://ojs.aaai.org/index.php/AAAI/article/view/8870) utilizes hyperplanes to model relations between entities. It is a multi-relational model that can handle one-to-many, many-to-one, and many-to-many relations between entities. The model is trained on the triples of the graph, where each triple consists of a head (subject), a relation, and a tail (object). Every relation $r$ has its own hyperplane, with normal vector $w_r$, and a translation vector $d_r$ on that hyperplane. The model learns the embeddings so that, for a true triple, the projection of the head onto the hyperplane, translated by $d_r$, lands close to the projection of the tail.

Specifically, it minimizes the following margin-based ranking loss, where $S$ is the set of true triples, $S'$ is a set of corrupted (negative) triples, $\\gamma$ is the margin, and $f_r(h, t) = \\lVert h_\\perp + d_r - t_\\perp \\rVert_2^2$ with $h_\\perp = h - (w_r^\\top h)\\, w_r$ (and likewise for $t$). The full objective also adds soft norm constraints, omitted here for brevity:

$$\\min \\sum_{(h, r, t) \\in S} \\sum_{(h', r, t') \\in S'} \\max(0, \\gamma + f_r(h, t) - f_r(h', t'))$$

By minimizing this cost function, the model learns the embeddings of the entities and relations that best represent the graph. The embeddings are then used to calculate the similarity between the diseases, which is shown in the graph.
"""
        )
        try:
            labels_of_diseases = get_labels_of_diseases_from_uris(
                augmented_set_of_diseases
            )
            # Pair each label with the URI it was requested for. Assumes the
            # helper returns labels in the same order as its input list —
            # TODO confirm in utils.
            uris_and_labels_of_diseases = dict(
                zip(augmented_set_of_diseases, labels_of_diseases)
            )
            # `matplotlib.cm.get_cmap` was removed in Matplotlib 3.9; the
            # colormap registry is the supported lookup.
            color_mapper = matplotlib.colormaps["viridis"]

            edges_to_show = []
            for source in df_similarities_augmented_set.index:
                for target in df_similarities_augmented_set.columns:
                    if source == target:
                        continue
                    weight = df_similarities_augmented_set.loc[source, target]
                    edges_to_show.append(
                        Edge(
                            source=source,
                            target=target,
                            # Edge colour follows the similarity on the viridis scale.
                            color=matplotlib.colors.to_hex(color_mapper(weight)),
                            # Presumably raised to a high power to exaggerate
                            # differences among similarities close to 1.
                            weight=weight**10,
                            type="CURVE_SMOOTH",
                            label=f"{weight:.2f}",
                        )
                    )

            nodes_to_show = []
            for disease in df_similarities_augmented_set.index:
                label = uris_and_labels_of_diseases.get(disease)
                # Fall back to the URI when no usable label is available.
                if pd.isna(label) or label == "nan":
                    label = disease
                nodes_to_show.append(
                    Node(id=disease, label=label, size=50, shape="circular")
                )

            graph_of_diseases = agraph(
                nodes=nodes_to_show,
                edges=edges_to_show,
                config=Config(height=500, width=500),
            )
            time.sleep(2)
        except Exception as e:
            print(f"Error while showing the graph of the diseases: {e}")
            st.error("Error while showing the graph of the diseases.")
        finally:
            show_overview = True
# overview
with st.container():
    if show_overview:
        # `disease_overview` is only bound when the LLM summary step succeeded.
        # Look it up without raising NameError so a failed step does not
        # leave an empty section header on the page.
        _unavailable = object()
        overview_text = globals().get("disease_overview", _unavailable)
        try:
            if overview_text is _unavailable:
                print("Overview of the clinical trials is unavailable; skipping section.")
            else:
                st.write("## Overview of Related Clinical Trials")
                st.write(overview_text)
                time.sleep(1)
        except Exception as e:
            print(f"Error while showing the overview of the clinical trials: {e}")
        finally:
            show_metrics = True
# metrics
with st.container():
    if show_metrics:
        # The metric values are only bound when the LLM tagging step
        # succeeded, and a partial failure may leave some of them unbound.
        # Look them up first and show the section only when all are present.
        _unavailable = object()
        metric_values = [
            ("Average Minimum Age", globals().get("average_minimum_age", _unavailable)),
            ("Average Maximum Age", globals().get("average_maximum_age", _unavailable)),
            ("Most Common Gender", globals().get("most_common_gender", _unavailable)),
        ]
        try:
            if any(value is _unavailable for _, value in metric_values):
                print("Metrics of the clinical trials are unavailable; skipping section.")
            else:
                st.write("## Metrics of the Clinical Trials")
                for column, (label, value) in zip(st.columns(3), metric_values):
                    with column:
                        st.metric(label, value)
                time.sleep(2)
        except Exception as e:
            print(f"Error while showing the metrics: {e}")
        finally:
            show_details = True
# details
with st.container():
    if show_details:
        st.write("## Clinical Trials Details")
        if not trials:
            # st.tabs() refuses an empty list of labels, so report the empty
            # result instead of crashing the page.
            st.info("No clinical trials were found for the selected diseases.")
        else:
            # One tab per trial, titled with its NCT identifier.
            tab_titles = [
                str(trial["protocolSection"]["identificationModule"]["nctId"])
                for trial in trials
            ]
            for tab, trial in zip(st.tabs(tab_titles), trials):
                with tab:
                    render_trial_details(trial)
# Footer crediting the team. Names match the linked profiles.
st.divider()
st.markdown(
    """This app was created at HackUPC 2024 by the team 'Klìnic'. The team members are:
- [Aldan Creo](https://acmc-website.web.app)
- [Matthias Seiler](https://www.linkedin.com/in/maseiler/)
- [Tanguy Vansnick](https://www.linkedin.com/in/tanguy-vansnick-44186a199/)
- [Arijit Samal](https://www.linkedin.com/in/arijit-samal1/)
"""
)
# Explorer for the neighbourhood of a single disease. Currently disabled:
# set the flag to True to render it below the main app.
show_graph_of_all_diseases = False
if show_graph_of_all_diseases:
    # Load the list of disease names only once per session.
    if "disease_names" not in st.session_state:
        st.session_state.disease_names = get_all_diseases_name(engine)
    chosen_disease_name = st.selectbox(
        "Choose a disease",
        st.session_state.disease_names,
    )
    st.write("You selected:", chosen_disease_name)
    chosen_disease_uri = get_uri_from_name(engine, chosen_disease_name)

    # Similarity above which an edge is drawn red instead of blue.
    strong_similarity_threshold = 0.7
    nodes = [
        Node(
            id=chosen_disease_uri, label=chosen_disease_name, size=25, shape="circular"
        )
    ]
    edges = []
    similar_diseases = get_most_similar_diseases_from_uri(
        engine, chosen_disease_uri, threshold=0.6
    )
    for uri, name, weight in similar_diseases:
        similarity = float(weight)
        nodes.append(Node(id=uri, label=name, size=25, shape="circular"))
        edges.append(
            Edge(
                source=chosen_disease_uri,
                target=uri,
                color="red" if similarity > strong_similarity_threshold else "blue",
                # Presumably raised to a high power to exaggerate differences
                # among similarities close to 1.
                weight=similarity**10,
                type="CURVE_SMOOTH",
            )
        )
    config = Config(
        width=750,
        height=950,
        directed=False,
        physics=True,
        hierarchical=False,
        collapsible=False,
    )
    return_value = agraph(nodes=nodes, edges=edges, config=config)
|