ACMCMC committed
Commit: a6bd112 · Parent: b375334
WIP
Files changed:
- graph_analysis.m → MATLAB/get_metrics.m  +3 -11
- MATLAB/main.m  +1 -1
- MATLAB/visualize_app.mlapp  +0 -0
- app.py  +178 -113
- calculate_smilar_nodes.py  +9 -3
- llm_res.py  +16 -15
- main.ipynb  +33 -30
- utils.py  +138 -73
graph_analysis.m → MATLAB/get_metrics.m
RENAMED
@@ -1,7 +1,6 @@
 % Read the CSV file
-data = readtable('MGREL.RRF', Delimiter='|', FileType='text', NumHeaderLines=0, VariableNamingRule='preserve');
+data = readtable('../MGREL.RRF', Delimiter='|', FileType='text', NumHeaderLines=0, VariableNamingRule='preserve');
 data = renamevars(data,"#CUI1","CUI1");
-data = data(1:1000,:);
 ids_1 = data.CUI1;
 for k = 1 : length(ids_1)
     cellContents = ids_1{k};
@@ -10,7 +9,6 @@ for k = 1 : length(ids_1)
 end
 ids_1 = str2double(ids_1);
 ids_2 = data.CUI2;
-ids_2 = data.CUI1(2:end);
 for k = 1 : length(ids_2)
     cellContents = ids_2{k};
     % Truncate and stick back into the cell
@@ -18,11 +16,6 @@ for k = 1 : length(ids_2)
 end
 ids_2 = str2double(ids_2);
 
-
-ids_1 = ids_1(1:end-1);
-ids_2 = ids_2(2:end);
-
-
 % Get the number of unique nodes
 %nodes = unique([ids_1; ids_2]);
 %num_nodes = length(nodes);
@@ -36,8 +29,7 @@ ids_2 = ids_2(2:end);
 %G = digraph(A);
 G = digraph(ids_1, ids_2);
 [bin,binsize] = conncomp(G,'Type','weak');
-bin(1:
-size(unique(bin))
+bin(1:10)
 max(binsize)
 pg_ranks = centrality(G,'pagerank');
 G.Nodes.PageRank = pg_ranks;
@@ -46,4 +38,4 @@ G.Nodes.PageRank = pg_ranks;
 %G.Nodes.Hubs = hub_ranks;
 %G.Nodes.Authorities = auth_ranks;
 G.Nodes
-%plot(G);
+%plot(G);
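Note for readers without MATLAB: the metrics this script extracts (weak connected components, the largest component size, PageRank) correspond to standard graph-library calls. A minimal Python sketch, with networkx as an illustrative stand-in (it is not part of this repo):

import networkx as nx

G = nx.DiGraph(zip([1, 2, 4], [2, 3, 5]))             # digraph(ids_1, ids_2)
components = list(nx.weakly_connected_components(G))  # conncomp(G,'Type','weak')
print(max(len(c) for c in components))                # max(binsize)
print(nx.pagerank(G))                                 # centrality(G,'pagerank')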
MATLAB/main.m
CHANGED
@@ -17,7 +17,7 @@ end
 
 data = readtable('MGREL.RRF', Delimiter='|', FileType='text', NumHeaderLines=0, VariableNamingRule='preserve');
 data = renamevars(data,"#CUI1","CUI1");
-data = data(1:
+data = data(1:2000,:);
 
 % Create a Map to store connections
 connectionsMap = containers.Map('KeyType','char', 'ValueType','any');
MATLAB/visualize_app.mlapp
CHANGED
Binary files a/MATLAB/visualize_app.mlapp and b/MATLAB/visualize_app.mlapp differ
app.py
CHANGED
@@ -1,27 +1,30 @@
-import
-from streamlit_agraph import agraph, Node, Edge, Config
+import json
 import os
-from sqlalchemy import create_engine, text
-import pandas as pd
 import time
+
+import matplotlib
+import numpy as np
+import pandas as pd
+import streamlit as st
+from sentence_transformers import SentenceTransformer
+from sqlalchemy import create_engine, text
+from streamlit_agraph import Config, Edge, Node, agraph
+
+from llm_res import get_short_summary_out_of_json_files, tagging_insights_from_json
 from utils import (
+    augment_the_set_of_diseaces,
+    filter_out_less_promising_diseases,
     get_all_diseases_name,
-
-
+    get_clinical_records_by_ids,
+    get_clinical_trials_related_to_diseases,
     get_diseases_related_to_a_textual_description,
+    get_most_similar_diseases_from_uri,
     get_similarities_among_diseases_uris,
-
-
-    get_clinical_records_by_ids,
+    get_similarities_df,
+    get_uri_from_name,
     render_trial_details,
-
+    get_labels_of_diseases_from_uris,
 )
-from llm_res import get_short_summary_out_of_json_files, tagging_insights_from_json
-import json
-import numpy as np
-from sentence_transformers import SentenceTransformer
-import matplotlib
-
 
 # variables to reveal next steps
 show_graph = False
@@ -42,17 +45,22 @@ engine = create_engine(CONNECTION_STRING)
 
 st.image("img_klinic.jpeg", caption="(AI-generated image)", use_column_width=True)
 st.title("Klìnic", help="AI-powered clinical trial search engine")
-st.subheader(
+st.subheader(
+    "Find clinical trials in a scoped domain of biomedical research, guiding your research with AI-powered insights."
+)
 
-with st.container():
+with st.container():  # user input
     col1, col2 = st.columns((6, 1))
 
     with col1:
-        description_input = st.text_area(
+        description_input = st.text_area(
+            label="Enter a disease description 👇",
+            placeholder="A disorder manifested in memory loss and other cognitive impairments among elderly patients (60+ years old), especially women.",
+        )
    with col2:
-        st.text(
-        st.text(
-        st.text(
+        st.text("")  # dummy to center vertically
+        st.text("")  # dummy to center vertically
+        st.text("")  # dummy to center vertically
         show_analyze_status = st.button("Analyze 🔎")
 
 
@@ -64,45 +72,78 @@ with st.container():
         # 2. Get 5 diseases with the highest cosine silimarity from the DB
         status.write("Analyzing the description that you wrote...")
         encoder = SentenceTransformer("allenai-specter")
-        diseases_related_to_the_user_text =
-
+        diseases_related_to_the_user_text = (
+            get_diseases_related_to_a_textual_description(
+                description_input, encoder
+            )
+        )
+        status.info(
+            f"Selected {len(diseases_related_to_the_user_text)} diseases related to the description you entered."
         )
-        status.info(f'Selected {len(diseases_related_to_the_user_text)} diseases related to the description you entered.')
         status.json(diseases_related_to_the_user_text, expanded=False)
         status.divider()
         # 3. Get the similarities of the embeddings of those diseases (cosine similarity of the embeddings of the nodes of such diseases)
-        status.write(
-
+        status.write(
+            "Getting the similarities among the diseases to filter out less promising ones..."
+        )
+        diseases_uris = [
+            disease["uri"] for disease in diseases_related_to_the_user_text
+        ]
         similarities = get_similarities_among_diseases_uris(diseases_uris)
-        status.info(
+        status.info(
+            f"Obtained similarity information among the diseases by measuring the cosine similarity of their embeddings."
+        )
         status.json(similarities, expanded=False)
-        filtered_diseases_uris, df_similarities =
+        filtered_diseases_uris, df_similarities = (
+            filter_out_less_promising_diseases(similarities)
+        )
         # Apply a colormap to the table
-        status.table(
-
+        status.table(
+            df_similarities.style.background_gradient(cmap="viridis", axis=None)
+        )
+        status.info(
+            f"Filtered out less promising diseases, keeping {len(filtered_diseases_uris)} diseases."
+        )
         status.json(filtered_diseases_uris, expanded=False)
         status.divider()
         # 4. Potentially filter out the diseases that are not similar enough (e.g. similarity < 0.8)
         # 5. Augment the set of diseases: add new diseases that are similar to the ones that are already in the set, until we get 10-15 diseases
-        status.write(
+        status.write(
+            "Augmenting the set of diseases by finding others with related embeddings..."
+        )
         augmented_set_of_diseases = augment_the_set_of_diseaces(diseases_uris)
-
-
+        similarities_of_augmented_set_of_diseases = (
+            get_similarities_among_diseases_uris(augmented_set_of_diseases)
+        )
+        df_similarities_augmented_set = get_similarities_df(
+            similarities_of_augmented_set_of_diseases
+        )
+        status.table(
+            df_similarities_augmented_set.style.background_gradient(cmap="viridis", axis=None)
+        )
+        status.json(similarities_of_augmented_set_of_diseases, expanded=True)
+        status.info(
+            f"Augmented set of diseases: {len(augmented_set_of_diseases)} diseases."
+        )
         status.json(augmented_set_of_diseases, expanded=False)
         status.divider()
         # 6. Query the embeddings of the diseases related to each clinical trial (also in the DB), to get the most similar clinical trials to our set of diseases
         status.write("Getting the clinical trials related to the diseases found...")
-        clinical_trials_related_to_the_diseases =
-
+        clinical_trials_related_to_the_diseases = (
+            get_clinical_trials_related_to_diseases(
+                augmented_set_of_diseases, encoder
+            )
+        )
+        status.info(
+            f"Selected {len(clinical_trials_related_to_the_diseases)} clinical trials related to the diseases."
         )
-        status.info(f'Selected {len(clinical_trials_related_to_the_diseases)} clinical trials related to the diseases.')
         status.json(clinical_trials_related_to_the_diseases, expanded=False)
         status.divider()
         status.write("Getting the details of the clinical trials...")
         json_of_clinical_trials = get_clinical_records_by_ids(
             [trial["nct_id"] for trial in clinical_trials_related_to_the_diseases]
         )
-        status.success(f
+        status.success(f"Details of the clinical trials obtained.")
         status.json(json_of_clinical_trials, expanded=False)
         status.divider()
         # 7. Use an LLM to get a summary of the clinical trials, in plain text format.
@@ -112,22 +153,27 @@ with st.container():
             status.success("Summary of the clinical trials obtained.")
             disease_overview = response
         except Exception as e:
-            print(f
-            status.warning(
+            print(f"Error while getting a summary of the clinical trials: {e}")
+            status.warning(
+                f"Error while getting a summary of the clinical trials. This information will not be shown."
+            )
         try:
-
+            # 8. Use an LLM to extract numerical data from the clinical trials (e.g. number of patients, number of deaths, etc.). Get summary statistics out of that.
            status.write("Getting summary statistics of the clinical trials...")
            response = tagging_insights_from_json(json_of_clinical_trials)
            average_minimum_age = response["avg_min_age"]
            average_maximum_age = response["avg_max_age"]
-            most_common_gender = response[
+            most_common_gender = response["most_common_gender"]
 
-            print(f
-            status.success(f
+            print(f"Response from LLM tagging: {response}")
+            status.success(f"Summary statistics of the clinical trials obtained.")
        except Exception as e:
-
-
-
+            print(
+                f"Error while extracting numerical data from the clinical trials: {e}"
+            )
+            status.warning(
+                f"Error while extracting numerical data from the clinical trials. This information will not be shown."
+            )
         # 9. Show the results to the user: graph of the diseases chosen, summary of the clinical trials, summary statistics of the clinical trials, and list of the details of the clinical trials considered
         status.update(label="Done!", state="complete")
         status.balloons()
@@ -146,37 +192,55 @@ We use the embeddings of the diseases to determine the similarity between them.
 [TransH](https://ojs.aaai.org/index.php/AAAI/article/view/8870) utilizes hyperplanes to model relations between entities. It is a multi-relational model that can handle many-to-many relations between entities. The model is trained on the triples of the graph, where the triples are the subject, relation, and object of the graph. The model learns the embeddings of the entities and the relations, such that the embeddings of the subject and object are close to each other when the relation is true.
 
 Specifically, it optimizes the following cost function:
-
-
-    # TODO actual graph
-    graph_of_diseases = agraph(
-        nodes=[
-            Node(id="A", label="Node A", size=10),
-            Node(id="B", label="Node B", size=10),
-            Node(id="C", label="Node C", size=10),
-            Node(id="D", label="Node D", size=10),
-            Node(id="E", label="Node E", size=10),
-            Node(id="F", label="Node F", size=10),
-            Node(id="G", label="Node G", size=10),
-            Node(id="H", label="Node H", size=10),
-            Node(id="I", label="Node I", size=10),
-            Node(id="J", label="Node J", size=10),
-        ],
-        edges=[
-            Edge(source="A", target="B"),
-            Edge(source="B", target="C"),
-            Edge(source="C", target="D"),
-            Edge(source="D", target="E"),
-            Edge(source="E", target="F"),
-            Edge(source="F", target="G"),
-            Edge(source="G", target="H"),
-            Edge(source="H", target="I"),
-            Edge(source="I", target="J"),
-        ],
-        config=Config(height=500, width=500),
+$\\text{minimize} \\sum_{(h, r, t) \\in S} \\max(0, \\gamma + f(h, r, t) - f(h, r, t')) + \\sum_{(h, r, t) \\in S'} f(h, r, t)$
+"""
     )
-
-
+    try:
+        edges_to_show = []
+        labels_of_diseases = get_labels_of_diseases_from_uris(
+            df_similarities_augmented_set.index
+        )
+        uris_and_labels_of_diseases = dict(
+            zip(df_similarities_augmented_set.index, labels_of_diseases)
+        )
+        color_mapper = matplotlib.cm.get_cmap("viridis")
+        for source in df_similarities_augmented_set.index:
+            for target in df_similarities_augmented_set.columns:
+                if source != target:
+                    weight = df_similarities_augmented_set.loc[source, target]
+                    color = color_mapper(weight)
+                    # Convert from rgba to hex
+                    color = matplotlib.colors.to_hex(color)
+                    edges_to_show.append(
+                        Edge(
+                            source=source,
+                            target=target,
+                            # Dynamic color based on the weight
+                            color=color,
+                            weight=weight**10,
+                            type="CURVE_SMOOTH",
+                            label=f"{weight:.2f}",
+                        )
+                    )
+        graph_of_diseases = agraph(
+            nodes=[
+                Node(
+                    id=disease,
+                    label=disease,#uris_and_labels_of_diseases[disease],
+                    size=25,
+                    shape="circular",
+                )
+                for disease in df_similarities_augmented_set.index
+            ],
+            edges=edges_to_show,
+            config=Config(height=500, width=500),
+        )
+        time.sleep(2)
+    except Exception as e:
+        print(f"Error while showing the graph of the diseases: {e}")
+        st.error("Error while showing the graph of the diseases.")
+    finally:
+        show_overview = True
 
 
 # overview
@@ -187,7 +251,7 @@ with st.container():
             st.write(disease_overview)
             time.sleep(1)
         except Exception as e:
-            print(f
+            print(f"Error while showing the overview of the clinical trials: {e}")
         finally:
             show_metrics = True
 
@@ -196,7 +260,7 @@ with st.container():
     if show_metrics:
         try:
             st.write("## Metrics of the Clinical Trials")
-            col1, col2, col3
+            col1, col2, col3 = st.columns(3)
             with col1:
                 st.metric("Average Minimum Age", average_minimum_age)
             with col2:
@@ -205,7 +269,7 @@ with st.container():
                 st.metric("Most Common Gender", most_common_gender)
             time.sleep(2)
         except Exception as e:
-            print(f
+            print(f"Error while showing the metrics: {e}")
         finally:
             show_details = True
 
@@ -215,7 +279,10 @@ with st.container():
    if show_details:
        st.write("## Clinical Trials Details")
 
-        tab_titles = [
+        tab_titles = [
+            f"{trial['protocolSection']['identificationModule']['nctId']}"
+            for trial in trials
+        ]
 
        tabs = st.tabs(tab_titles)
 
@@ -231,7 +298,7 @@ if show_graph_of_all_diseases:
     chosen_disease_name = st.selectbox(
         "Choose a disease",
         st.session_state.disease_names,
-
+    )
 
     st.write("You selected:", chosen_disease_name)
     chosen_disease_uri = get_uri_from_name(engine, chosen_disease_name)
@@ -239,41 +306,39 @@ if show_graph_of_all_diseases:
     nodes = []
     edges = []
 
-
-
-
-
-            shape="circular")
+    nodes.append(
+        Node(
+            id=chosen_disease_uri, label=chosen_disease_name, size=25, shape="circular"
+        )
     )
 
-    similar_diseases = get_most_similar_diseases_from_uri(
+    similar_diseases = get_most_similar_diseases_from_uri(
+        engine, chosen_disease_uri, threshold=0.6
+    )
     print(similar_diseases)
     for uri, name, weight in similar_diseases:
-        nodes.append(
-            label=name,
-            size=25,
-            shape="circular")
-        )
+        nodes.append(Node(id=uri, label=name, size=25, shape="circular"))
 
         print(True if float(weight) > 0.7 else False)
-        edges.append(
-        [15 more old lines, truncated in the rendered diff]
+        edges.append(
+            Edge(
+                source=chosen_disease_uri,
+                target=uri,
+                color="red" if float(weight) > 0.7 else "blue",
+                weight=float(weight) ** 10,
+                type="CURVE_SMOOTH",
+                # type="STRAIGHT"
+            )
+        )
+
+    config = Config(
+        width=750,
+        height=950,
+        directed=False,
+        physics=True,
+        hierarchical=False,
+        collapsible=False,
+        # **kwargs
+    )
 
-    return_value = agraph(nodes=nodes,
-        edges=edges,
-        config=config)
+    return_value = agraph(nodes=nodes, edges=edges, config=config)
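Note on the cost function added to the explainer text above: the TransH objective it paraphrases is conventionally written as a margin-based ranking loss over golden triples Δ and corrupted triples Δ' (Wang et al., AAAI 2014). For reference, a cleaner LaTeX rendering of that standard form:

% Margin-based ranking loss; [x]_+ = max(0, x), gamma is the margin
\mathcal{L} = \sum_{(h,r,t) \in \Delta} \sum_{(h',r',t') \in \Delta'_{(h,r,t)}} \big[ f_r(h,t) + \gamma - f_{r'}(h',t') \big]_+

% Score: translation distance after projecting h and t onto the hyperplane w_r of relation r
f_r(h,t) = \big\lVert (\mathbf{h} - \mathbf{w}_r^{\top}\mathbf{h}\,\mathbf{w}_r) + \mathbf{d}_r - (\mathbf{t} - \mathbf{w}_r^{\top}\mathbf{t}\,\mathbf{w}_r) \big\rVert_2^2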
calculate_smilar_nodes.py
CHANGED
@@ -6,6 +6,7 @@ def transe_distance(head, tail, relation, entity_embeddings, relation_embeddings
     distance = head_embedding + relation_embeddings - tail_embedding
     return distance
 
+
 def calculate_similar_nodes(node, entity_embeddings, relation_embeddings, top_n=10):
     distances = []
     for i in range(len(entity_embeddings)):
@@ -14,6 +15,7 @@ def calculate_similar_nodes(node, entity_embeddings, relation_embeddings, top_n=
     distances.sort(key=lambda x: x[1].norm().item())
     return distances[:top_n]
 
+
 # %%
 import pandas as pd
 
@@ -55,9 +57,13 @@ print(
 )
 # %%
 # Calculate similar nodes to the head
-similar_nodes = calculate_similar_nodes(
+similar_nodes = calculate_similar_nodes(
+    head, entity_embeddings["embedding"], relation_embeddings["embedding"]
+)
 print(f"Similar nodes to {entity_embeddings['label'][head]} ({head}):")
 # Print the similar nodes
 for i, (node, distance) in enumerate(similar_nodes):
-    print(
-
+    print(
+        f"{i}: {entity_embeddings['label'][node]} ({node}) with distance {distance.norm().item()}"
+    )
+# %%
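For intuition, the ranking in calculate_similar_nodes follows the TransE assumption that head + relation ≈ tail for plausible triples, so candidates are sorted by the norm of the residual. A self-contained sketch of that scoring with toy torch tensors (names and data here are illustrative, not the script's):

import torch

# Toy embeddings: 5 entities and one relation in a 4-dimensional space.
entity_embeddings = torch.randn(5, 4)
relation_embedding = torch.randn(4)

def rank_tails(head_idx):
    # TransE residual: head + relation - tail; a smaller norm means a more plausible tail.
    residuals = entity_embeddings[head_idx] + relation_embedding - entity_embeddings
    norms = residuals.norm(dim=1)
    # Sort candidate tails by ascending distance, as calculate_similar_nodes does.
    return [(int(i), float(norms[i])) for i in norms.argsort()]

print(rank_tails(0)[:3])  # the three most plausible tails for entity 0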
llm_res.py
CHANGED
@@ -1,15 +1,19 @@
 import ast
 import json
 import os
+import statistics
+from collections import Counter
 from typing import Any, Dict, List
 
 import langchain
 import openai
 import pandas as pd
+import regex as re
 import requests
 from dotenv import load_dotenv
 from langchain import OpenAI
 from langchain.chains.combine_documents.stuff import StuffDocumentsChain
+from langchain.chains.llm import LLMChain
 from langchain.chains.qa_with_sources.loading import load_qa_with_sources_chain
 from langchain.document_loaders import UnstructuredURLLoader
 from langchain.embeddings import OpenAIEmbeddings
@@ -17,14 +21,9 @@ from langchain.text_splitter import RecursiveCharacterTextSplitter
 from langchain.vectorstores import FAISS
 from langchain_community.document_loaders import JSONLoader
 from langchain_community.document_loaders.csv_loader import CSVLoader
-from langchain_core.prompts import ChatPromptTemplate
+from langchain_core.prompts import ChatPromptTemplate, PromptTemplate
 from langchain_core.pydantic_v1 import BaseModel, Field
 from langchain_openai import ChatOpenAI
-from langchain.chains.llm import LLMChain
-from langchain_core.prompts import PromptTemplate
-from collections import Counter
-import statistics
-import regex as re
 
 load_dotenv()
 
@@ -245,7 +244,7 @@ General summary:"""
     prompt = PromptTemplate.from_template(prompt_template)
 
     llm = ChatOpenAI(
-        temperature=0.
+        temperature=0.5, model_name="gpt-4-turbo", api_key=os.environ["OPENAI_API_KEY"]
     )
     llm_chain = LLMChain(llm=llm, prompt=prompt)
 
@@ -279,8 +278,12 @@ General summary:"""
 def analyze_data(data):
     print(f"Data: {data}")
     # Extract minimum and maximum ages: Turn ['18 Years', '20 Years'] into [18, 20]
-    min_ages = [
-
+    min_ages = [
+        int(re.search(r"\d+", age).group()) for age in data["minimum_age"] if age
+    ]
+    max_ages = [
+        int(re.search(r"\d+", age).group()) for age in data["maximum_age"] if age
+    ]
     # primary_timeframe= [int(age.split()[0]) for age in data['[primary_outcome]'] if age]
 
     # Calculate average minimum and maximum ages
@@ -292,13 +295,13 @@ def analyze_data(data):
     most_common_gender = gender_counter.most_common(1)[0][0]
 
     # Flatten keywords list and find common keywords
-    #keywords = [keyword for sublist in data["keywords"] for keyword in sublist]
-    #common_keywords = [word for word, count in Counter(keywords).most_common()]
+    # keywords = [keyword for sublist in data["keywords"] for keyword in sublist]
+    # common_keywords = [word for word, count in Counter(keywords).most_common()]
 
     return {
         "avg_min_age": avg_min_age,
         "avg_max_age": avg_max_age,
-        "most_common_gender": most_common_gender
+        "most_common_gender": most_common_gender,
     }
 
 
@@ -379,9 +382,7 @@ def tagging_insights_from_json(data_json):
     res = tagging_chain.invoke({"input": processed_json})
     unprocessed_results_dict = res.get_dict()
 
-    results_dict = analyze_data(
-        unprocessed_results_dict
-    )
+    results_dict = analyze_data(unprocessed_results_dict)
 
     # stats_dict= {'Average Minimum age': avg_min_age,
     #              'Average Maximum age': avg_max_age,
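A small point on the age parsing added in analyze_data: re.search(r"\d+", ...) grabs the first digit run, so "18 Years" becomes 18, and the `if age` guard only skips empty values; an entry with no digits at all would return None and raise an AttributeError on .group(). A quick illustration (regex is the module the file imports; the stdlib re behaves identically here):

import regex as re

ages = ["18 Years", "65 Years", ""]
min_ages = [int(re.search(r"\d+", age).group()) for age in ages if age]
print(min_ages)  # [18, 65]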
main.ipynb
CHANGED
@@ -245,52 +245,55 @@
 }
 ],
 "source": [
-"df_summary = pd.read_csv(
-"df_summary = df_summary.rename(columns={
+"df_summary = pd.read_csv(\"file_db/brief_summaries.txt\", delimiter=\"|\")\n",
+"df_summary = df_summary.rename(columns={\"description\": \"summary\"})\n",
 "\n",
 "### create and merge intervention ###\n",
-"df_intervention = pd.read_csv(
+"df_intervention = pd.read_csv(\"file_db/interventions.txt\", delimiter=\"|\")\n",
 "\n",
-"intervention_grouped =
-"
+"intervention_grouped = (\n",
+"    df_intervention.groupby(\"nct_id\")[\"name\"].apply(list).reset_index()\n",
+")\n",
+"intervention_grouped = intervention_grouped.rename(\n",
+"    columns={\"name\": \"intervention_name\"}\n",
+")\n",
 "merged_df = pd.merge(\n",
-"    df_summary[[
-"    intervention_grouped[[
-"    on
+"    df_summary[[\"nct_id\", \"summary\"]],\n",
+"    intervention_grouped[[\"nct_id\", \"intervention_name\"]],\n",
+"    on=\"nct_id\",\n",
+")\n",
 "\n",
-"df_intervention = df_intervention.rename(
+"df_intervention = df_intervention.rename(\n",
+"    columns={\"description\": \"intervention_description\"}\n",
+")\n",
 "\n",
 "merged_df = pd.merge(\n",
 "    merged_df,\n",
-"    df_intervention[[
-"    on
+"    df_intervention[[\"nct_id\", \"intervention_type\", \"intervention_description\"]],\n",
+"    on=\"nct_id\",\n",
+")\n",
 "\n",
 "### create and merge keywords ###\n",
-"df_keyword = pd.read_csv(
-"keywords_grouped = df_keyword.groupby(
-"keywords_grouped = keywords_grouped.rename(columns={
+"df_keyword = pd.read_csv(\"file_db/keywords.txt\", delimiter=\"|\")\n",
+"keywords_grouped = df_keyword.groupby(\"nct_id\")[\"name\"].apply(list).reset_index()\n",
+"keywords_grouped = keywords_grouped.rename(columns={\"name\": \"keywords\"})\n",
 "\n",
-"merged_df = pd.merge(\n",
-"    merged_df,\n",
-"    keywords_grouped,\n",
-"    on='nct_id'\n",
-")\n",
+"merged_df = pd.merge(merged_df, keywords_grouped, on=\"nct_id\")\n",
 "\n",
 "### create and merge browse conditions\n",
-"df_condition = pd.read_csv(
-"conditions_grouped =
-"
-"\n",
-"
-"
-"    conditions_grouped,\n",
-"    on='nct_id'\n",
+"df_condition = pd.read_csv(\"file_db/browse_conditions.txt\", delimiter=\"|\")\n",
+"conditions_grouped = (\n",
+"    df_condition.groupby(\"nct_id\")[\"downcase_mesh_term\"].apply(list).reset_index()\n",
+")\n",
+"conditions_grouped = conditions_grouped.rename(\n",
+"    columns={\"downcase_mesh_term\": \"desease_condition\"}\n",
 ")\n",
 "\n",
-"merged_df =
+"merged_df = pd.merge(merged_df, conditions_grouped, on=\"nct_id\")\n",
 "\n",
-"merged_df.
-"\n"
+"merged_df = merged_df.drop_duplicates(subset=\"nct_id\")\n",
+"\n",
+"merged_df.head()"
 ]
 },
 {
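One behavior worth noting in this cell: merging the row-per-intervention df_intervention onto merged_df multiplies rows for trials with several interventions, which is why the cell now ends with drop_duplicates(subset="nct_id"). A toy reproduction (illustrative data only):

import pandas as pd

summaries = pd.DataFrame({"nct_id": ["NCT1"], "summary": ["..."]})
interventions = pd.DataFrame(
    {"nct_id": ["NCT1", "NCT1"], "intervention_type": ["Drug", "Device"]}
)

merged = pd.merge(summaries, interventions, on="nct_id")
print(len(merged))                                   # 2: one row per intervention
print(len(merged.drop_duplicates(subset="nct_id")))  # 1: back to one row per trial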
utils.py
CHANGED
@@ -1,11 +1,12 @@
 # %%
-from typing import List, Dict, Any
 import os
-from
+from typing import Any, Dict, List
+
+import pandas as pd
 import requests
-from sentence_transformers import SentenceTransformer
 import streamlit as st
-
+from sentence_transformers import SentenceTransformer
+from sqlalchemy import create_engine, text
 
 username = "demo"
 password = "demo"
@@ -124,16 +125,19 @@ def get_similarities_among_diseases_uris(
             result = conn.execute(text(sql))
             data = result.fetchall()
 
-    return [
-
-
-
-
+    return [
+        {
+            "uri1": row[0].split("/")[-1],
+            "uri2": row[1].split("/")[-1],
+            "distance": float(row[2]),
+        }
+        for row in data
+    ]
 
 
 def augment_the_set_of_diseaces(diseases: List[str]) -> str:
     augmented_diseases = diseases.copy()
-    for i in range(
+    for i in range(10 - len(augmented_diseases)):
         with engine.connect() as conn:
             with conn.begin():
                 sql = f"""
@@ -153,6 +157,7 @@ def augment_the_set_of_diseaces(diseases: List[str]) -> str:
 
     return augmented_diseases
 
+
 def get_embedding(string: str, encoder) -> List[float]:
     # Embed the string using sentence-transformers
     vector = encoder.encode(string, show_progress_bar=False)
@@ -176,11 +181,14 @@ def get_diseases_related_to_a_textual_description(
             result = conn.execute(text(sql))
             data = result.fetchall()
 
-    return [
+    return [
+        {"uri": row[0], "distance": float(row[1])}
+        for row in data
+        if float(row[1]) > 0.8
+    ]
 
-
-
-) -> List[str]:
+
+def get_clinical_trials_related_to_diseases(diseases: List[str], encoder) -> List[str]:
     # Embed the diseases using sentence-transformers
     diseases_string = ", ".join(diseases)
     disease_embedding = get_embedding(diseases_string, encoder)
@@ -189,7 +197,7 @@ def get_clinical_trials_related_to_diseases(
     with engine.connect() as conn:
         with conn.begin():
             sql = f"""
-            SELECT TOP
+            SELECT TOP 20 d.nct_id, VECTOR_COSINE(d.embedding, TO_VECTOR('{string_representation}', DOUBLE)) AS distance
             FROM Test.ClinicalTrials d
             ORDER BY distance DESC
             """
@@ -198,82 +206,139 @@ def get_clinical_trials_related_to_diseases(
 
     return [{"nct_id": row[0], "distance": row[1]} for row in data]
 
-
+
+def get_similarities_df(diseases: List[Dict[str, Any]]) -> pd.DataFrame:
     # Find out the score of each disease by averaging the cosine similarity of the embeddings of the diseases that include it as uri1 or uri2
-    df_diseases_similarities = pd.DataFrame(
+    df_diseases_similarities = pd.DataFrame(diseases)
     # Use uri1 as the index, and uri2 as the columns. The values are the distances.
-    df_diseases_similarities = df_diseases_similarities.pivot(
+    df_diseases_similarities = df_diseases_similarities.pivot(
+        index="uri1", columns="uri2", values="distance"
+    )
     # Fill the diagonal with 1.0
     df_diseases_similarities = df_diseases_similarities.fillna(1.0)
 
-
+    return df_diseases_similarities
+
+
+def filter_out_less_promising_diseases(info_dicts: List[Dict[str, Any]]) -> List[str]:
+    df_diseases_similarities = get_similarities_df(info_dicts)
+
+    # Filter out the diseases that are 0.2 standard deviations below the mean
     mean = df_diseases_similarities.mean().mean()
     std = df_diseases_similarities.mean().std()
-    filtered_diseases = df_diseases_similarities.mean()[
+    filtered_diseases = df_diseases_similarities.mean()[
+        df_diseases_similarities.mean() > mean - 0.2 * std
+    ].index.tolist()
     return filtered_diseases, df_diseases_similarities
 
+
+def get_labels_of_diseases_from_uris(uris: List[str]) -> List[str]:
+    with engine.connect() as conn:
+        with conn.begin():
+            joined_uris = ", ".join([f"'{uri}'" for uri in uris])
+            sql = f"""
+            SELECT label FROM Test.EntityEmbeddings
+            WHERE uri IN ({joined_uris})
+            """
+            result = conn.execute(text(sql))
+            data = result.fetchall()
+
+    return [row[0] for row in data]
+
+
 def to_capitalized_case(string: str) -> str:
     string = string.replace("_", " ")
     if string.isupper():
         return string[0] + string[1:].lower()
-
+
+
 def list_to_capitalized_case(strings: List[str]) -> str:
     strings = [to_capitalized_case(s) for s in strings]
     return ", ".join(strings)
 
+
 def render_trial_details(trial: dict) -> None:
-    [52 lines of the old render_trial_details body, truncated in the rendered diff]
+    # TODO: handle key errors for all cases (→ do not render)
+
+    official_title = trial["protocolSection"]["identificationModule"]["officialTitle"]
+    st.write(f"##### {official_title}")
+
+    try:
+        st.write(trial["protocolSection"]["descriptionModule"]["briefSummary"])
+    except KeyError:
+        try:
+            st.write(
+                trial["protocolSection"]["descriptionModule"]["detailedDescription"]
+            )
+        except KeyError:
+            st.error("No description available.")
+
+    st.write("###### Status")
+    try:
+        status_module = {
+            "Status": to_capitalized_case(
+                trial["protocolSection"]["statusModule"]["overallStatus"]
+            ),
+            "Status Date": trial["protocolSection"]["statusModule"][
+                "statusVerifiedDate"
+            ],
+            "Has Results": trial["hasResults"],
+        }
+        st.table(status_module)
+    except KeyError:
+        st.info("No status information available.")
+
+    st.write("###### Design")
+    try:
+        design_module = {
+            "Study Type": to_capitalized_case(
+                trial["protocolSection"]["designModule"]["studyType"]
+            ),
+            "Phases": list_to_capitalized_case(
+                trial["protocolSection"]["designModule"]["phases"]
+            ),
+            "Allocation": to_capitalized_case(
+                trial["protocolSection"]["designModule"]["designInfo"]["allocation"]
+            ),
+            "Primary Purpose": to_capitalized_case(
+                trial["protocolSection"]["designModule"]["designInfo"]["primaryPurpose"]
+            ),
+            "Participants": trial["protocolSection"]["designModule"]["enrollmentInfo"][
+                "count"
+            ],
+            "Masking": to_capitalized_case(
+                trial["protocolSection"]["designModule"]["designInfo"]["maskingInfo"][
+                    "masking"
+                ]
+            ),
+            "Who Masked": list_to_capitalized_case(
+                trial["protocolSection"]["designModule"]["designInfo"]["maskingInfo"][
+                    "whoMasked"
+                ]
+            ),
+        }
+        st.table(design_module)
+    except KeyError:
+        st.info("No design information available.")
+
+    st.write("###### Interventions")
+    try:
+        interventions_module = {}
+        for intervention in trial["protocolSection"]["armsInterventionsModule"][
+            "interventions"
+        ]:
+            name = intervention["name"]
+            desc = intervention["description"]
+            interventions_module[name] = desc
+        st.table(interventions_module)
+    except KeyError:
+        st.info("No interventions information available.")
+
+    # Button to go to ClinicalTrials.gov and see the trial. It takes the user to the official page of the trial.
+    st.markdown(
+        f"See more in [ClinicalTrials.gov](https://clinicaltrials.gov/study/{trial['protocolSection']['identificationModule']['nctId']})"
+    )
+
 
 if __name__ == "__main__":
     username = "demo"
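The cut in filter_out_less_promising_diseases is purely statistical: a disease survives if its mean similarity to the rest of the set is no more than 0.2 standard deviations below the overall mean. A standalone sketch of the same rule on a toy similarity matrix (illustrative data, not the project's):

import pandas as pd

# Toy symmetric similarity matrix, shaped like get_similarities_df's output.
sim = pd.DataFrame(
    [[1.0, 0.9, 0.3], [0.9, 1.0, 0.4], [0.3, 0.4, 1.0]],
    index=["d1", "d2", "d3"],
    columns=["d1", "d2", "d3"],
)

col_means = sim.mean()  # mean similarity of each disease to the set
mean, std = col_means.mean(), col_means.std()
kept = col_means[col_means > mean - 0.2 * std].index.tolist()
print(kept)  # ['d1', 'd2'] — d3's low similarity to the others drops it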