ACMCMC committed on
Commit
27d40b9
1 Parent(s): 7833461

UI Changes

Browse files
Files changed (3) hide show
  1. app.py +72 -24
  2. database.ipynb +83 -2
  3. utils.py +7 -11
app.py CHANGED
@@ -3,58 +3,106 @@ from streamlit_agraph import agraph, Node, Edge, Config
3
  import os
4
  from sqlalchemy import create_engine, text
5
  import pandas as pd
6
- from utils import get_all_diseases_name, get_most_similar_diseases_from_uri, get_uri_from_name, get_diseases_related_to_a_textual_description, get_similarities_among_diseases_uris
 
 
 
 
 
 
 
 
 
 
7
  import json
8
  import numpy as np
9
  from sentence_transformers import SentenceTransformer
10
 
11
 
12
- username = 'demo'
13
- password = 'demo'
14
- hostname = os.getenv('IRIS_HOSTNAME', 'localhost')
15
- port = '1972'
16
- namespace = 'USER'
 
 
 
17
  CONNECTION_STRING = f"iris://{username}:{password}@{hostname}:{port}/{namespace}"
18
  engine = create_engine(CONNECTION_STRING)
19
 
20
- def handle_click_on_analyze_button(user_text):
 
 
 
 
 
 
21
  # 1. Embed the textual description that the user entered using the model
22
  # 2. Get 5 diseases with the highest cosine similarity from the DB
23
  encoder = SentenceTransformer("allenai-specter")
24
- diseases_related_to_the_user_text = get_diseases_related_to_a_textual_description(user_text, encoder)
25
- #for disease_label in diseases_related_to_the_user_text:
 
 
26
  # st.text(disease_label)
27
  # 3. Get the similarities of the embeddings of those diseases (cosine similarity of the embeddings of the nodes of such diseases)
28
- diseases_uris = [disease['uri'] for disease in diseases_related_to_the_user_text]
29
  get_similarities_among_diseases_uris(diseases_uris)
30
  print(diseases_related_to_the_user_text)
31
  # 4. Potentially filter out the diseases that are not similar enough (e.g. similarity < 0.8)
32
  # 5. Augment the set of diseases: add new diseases that are similar to the ones that are already in the set, until we get 10-15 diseases
 
 
33
  # 6. Query the embeddings of the diseases related to each clinical trial (also in the DB), to get the most similar clinical trials to our set of diseases
34
- # 7. Use an LLM to get a summary of the clinical trials, in plain text format
 
 
 
 
 
 
 
35
  # 8. Use an LLM to extract numerical data from the clinical trials (e.g. number of patients, number of deaths, etc.). Get summary statistics out of that.
36
  # 9. Show the results to the user: graph of the diseases chosen, summary of the clinical trials, summary statistics of the clinical trials, and list of the details of the clinical trials considered
37
- pass
38
-
39
-
40
- st.write("# Klìnic")
41
-
42
- description_input = st.text_input(label="Enter the disease description 👇", placeholder='A disease that causes memory loss and other cognitive impairments.')
43
- if st.button("Analyze"):
44
- handle_click_on_analyze_button(description_input)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
45
  # TODO: also when user clicks enter
46
 
47
- st.write(":red[Here should be the graph]") # TODO remove
48
  chart_data = pd.DataFrame(
49
  np.random.randn(20, 3), columns=["a", "b", "c"]
50
  ) # TODO remove
51
- st.scatter_chart(chart_data) # TODO remove
52
 
53
- st.write("## Disease Overview")
54
  disease_overview = ":red[lorem ipsum]" # TODO
55
- st.write(disease_overview)
56
 
57
- st.write("## Clinical Trials Details")
58
  trials = []
59
  # TODO replace mock data
60
  with open("mock_trial.json") as f:
 
3
  import os
4
  from sqlalchemy import create_engine, text
5
  import pandas as pd
6
+ import time
7
+ from utils import (
8
+ get_all_diseases_name,
9
+ get_most_similar_diseases_from_uri,
10
+ get_uri_from_name,
11
+ get_diseases_related_to_a_textual_description,
12
+ get_similarities_among_diseases_uris,
13
+ augment_the_set_of_diseaces,
14
+ get_clinical_trials_related_to_diseases,
15
+ get_clinical_records_by_ids
16
+ )
17
  import json
18
  import numpy as np
19
  from sentence_transformers import SentenceTransformer
20
 
21
 
22
+ begin = st.container()
23
+
24
+
25
+ username = "demo"
26
+ password = "demo"
27
+ hostname = os.getenv("IRIS_HOSTNAME", "localhost")
28
+ port = "1972"
29
+ namespace = "USER"
30
  CONNECTION_STRING = f"iris://{username}:{password}@{hostname}:{port}/{namespace}"
31
  engine = create_engine(CONNECTION_STRING)
32
 
33
+ begin.write("# Klìnic")
34
+
35
+ description_input = begin.text_input(
36
+ label="Enter the disease description 👇",
37
+ placeholder="A disease that causes memory loss and other cognitive impairments.",
38
+ )
39
+ if begin.button("Analyze 🔎"):
40
  # 1. Embed the textual description that the user entered using the model
41
  # 2. Get 5 diseases with the highest cosine similarity from the DB
42
  encoder = SentenceTransformer("allenai-specter")
43
+ diseases_related_to_the_user_text = get_diseases_related_to_a_textual_description(
44
+ description_input, encoder
45
+ )
46
+ # for disease_label in diseases_related_to_the_user_text:
47
  # st.text(disease_label)
48
  # 3. Get the similarities of the embeddings of those diseases (cosine similarity of the embeddings of the nodes of such diseases)
49
+ diseases_uris = [disease["uri"] for disease in diseases_related_to_the_user_text]
50
  get_similarities_among_diseases_uris(diseases_uris)
51
  print(diseases_related_to_the_user_text)
52
  # 4. Potentially filter out the diseases that are not similar enough (e.g. similarity < 0.8)
53
  # 5. Augment the set of diseases: add new diseases that are similar to the ones that are already in the set, until we get 10-15 diseases
54
+ augmented_set_of_diseases = augment_the_set_of_diseaces(diseases_uris)
55
+ print(augmented_set_of_diseases)
56
  # 6. Query the embeddings of the diseases related to each clinical trial (also in the DB), to get the most similar clinical trials to our set of diseases
57
+ clinical_trials_related_to_the_diseases = get_clinical_trials_related_to_diseases(
58
+ augmented_set_of_diseases, encoder
59
+ )
60
+ print(f'clinical_trials_related_to_the_diseases: {clinical_trials_related_to_the_diseases}')
61
+ json_of_clinical_trials = get_clinical_records_by_ids(
62
+ [trial["nct_id"] for trial in clinical_trials_related_to_the_diseases]
63
+ )
64
+ print(f'json_of_clinical_trials: {json_of_clinical_trials}')
65
  # 8. Use an LLM to extract numerical data from the clinical trials (e.g. number of patients, number of deaths, etc.). Get summary statistics out of that.
66
  # 9. Show the results to the user: graph of the diseases chosen, summary of the clinical trials, summary statistics of the clinical trials, and list of the details of the clinical trials considered
67
+ graph_of_diseases = agraph(
68
+ nodes=[
69
+ Node(id="A", label="Node A", size=10),
70
+ Node(id="B", label="Node B", size=10),
71
+ Node(id="C", label="Node C", size=10),
72
+ Node(id="D", label="Node D", size=10),
73
+ Node(id="E", label="Node E", size=10),
74
+ Node(id="F", label="Node F", size=10),
75
+ Node(id="G", label="Node G", size=10),
76
+ Node(id="H", label="Node H", size=10),
77
+ Node(id="I", label="Node I", size=10),
78
+ Node(id="J", label="Node J", size=10),
79
+ ],
80
+ edges=[
81
+ Edge(source="A", target="B"),
82
+ Edge(source="B", target="C"),
83
+ Edge(source="C", target="D"),
84
+ Edge(source="D", target="E"),
85
+ Edge(source="E", target="F"),
86
+ Edge(source="F", target="G"),
87
+ Edge(source="G", target="H"),
88
+ Edge(source="H", target="I"),
89
+ Edge(source="I", target="J"),
90
+ ],
91
+ config=Config(height=500, width=500),
92
+ )
93
  # TODO: also when user clicks enter
94
 
95
+ begin.write(":red[Here should be the graph]") # TODO remove
96
  chart_data = pd.DataFrame(
97
  np.random.randn(20, 3), columns=["a", "b", "c"]
98
  ) # TODO remove
99
+ begin.scatter_chart(chart_data) # TODO remove
100
 
101
+ begin.write("## Disease Overview")
102
  disease_overview = ":red[lorem ipsum]" # TODO
103
+ begin.write(disease_overview)
104
 
105
+ begin.write("## Clinical Trials Details")
106
  trials = []
107
  # TODO replace mock data
108
  with open("mock_trial.json") as f:
database.ipynb CHANGED
@@ -288,9 +288,90 @@
288
  },
289
  {
290
  "cell_type": "code",
291
- "execution_count": null,
292
  "metadata": {},
293
- "outputs": [],
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
294
  "source": [
295
  "# Load knowledge graph\n",
296
  "clinical_trials = pd.read_csv(\"clinical_trials_embeddings.csv\")\n",
 
288
  },
289
  {
290
  "cell_type": "code",
291
+ "execution_count": 22,
292
  "metadata": {},
293
+ "outputs": [
294
+ {
295
+ "data": {
296
+ "text/html": [
297
+ "<div>\n",
298
+ "<style scoped>\n",
299
+ " .dataframe tbody tr th:only-of-type {\n",
300
+ " vertical-align: middle;\n",
301
+ " }\n",
302
+ "\n",
303
+ " .dataframe tbody tr th {\n",
304
+ " vertical-align: top;\n",
305
+ " }\n",
306
+ "\n",
307
+ " .dataframe thead th {\n",
308
+ " text-align: right;\n",
309
+ " }\n",
310
+ "</style>\n",
311
+ "<table border=\"1\" class=\"dataframe\">\n",
312
+ " <thead>\n",
313
+ " <tr style=\"text-align: right;\">\n",
314
+ " <th></th>\n",
315
+ " <th>desease_condition</th>\n",
316
+ " <th>embeddings</th>\n",
317
+ " <th>nct_id</th>\n",
318
+ " </tr>\n",
319
+ " </thead>\n",
320
+ " <tbody>\n",
321
+ " <tr>\n",
322
+ " <th>0</th>\n",
323
+ " <td>marijuana abuse, substance-related disorders, ...</td>\n",
324
+ " <td>-0.8323991298675537, 1.47855544090271, 0.00130...</td>\n",
325
+ " <td>NCT03055377</td>\n",
326
+ " </tr>\n",
327
+ " <tr>\n",
328
+ " <th>1</th>\n",
329
+ " <td>tuberculosis, latent tuberculosis, infections,...</td>\n",
330
+ " <td>-0.43443307280540466, 0.9625586271286011, -0.1...</td>\n",
331
+ " <td>NCT03042754</td>\n",
332
+ " </tr>\n",
333
+ " <tr>\n",
334
+ " <th>2</th>\n",
335
+ " <td>heart failure, heart diseases, cardiovascular ...</td>\n",
336
+ " <td>-0.5791705250740051, 0.13008448481559753, 0.13...</td>\n",
337
+ " <td>NCT03035123</td>\n",
338
+ " </tr>\n",
339
+ " <tr>\n",
340
+ " <th>3</th>\n",
341
+ " <td>lymphoma, neoplasms by histologic type, neopla...</td>\n",
342
+ " <td>-0.1608569175004959, 0.8489153981208801, -0.55...</td>\n",
343
+ " <td>NCT02272751</td>\n",
344
+ " </tr>\n",
345
+ " <tr>\n",
346
+ " <th>4</th>\n",
347
+ " <td>anemia, hematologic diseases</td>\n",
348
+ " <td>0.21379394829273224, 0.17073844373226166, -0.1...</td>\n",
349
+ " <td>NCT00931606</td>\n",
350
+ " </tr>\n",
351
+ " </tbody>\n",
352
+ "</table>\n",
353
+ "</div>"
354
+ ],
355
+ "text/plain": [
356
+ " desease_condition \\\n",
357
+ "0 marijuana abuse, substance-related disorders, ... \n",
358
+ "1 tuberculosis, latent tuberculosis, infections,... \n",
359
+ "2 heart failure, heart diseases, cardiovascular ... \n",
360
+ "3 lymphoma, neoplasms by histologic type, neopla... \n",
361
+ "4 anemia, hematologic diseases \n",
362
+ "\n",
363
+ " embeddings nct_id \n",
364
+ "0 -0.8323991298675537, 1.47855544090271, 0.00130... NCT03055377 \n",
365
+ "1 -0.43443307280540466, 0.9625586271286011, -0.1... NCT03042754 \n",
366
+ "2 -0.5791705250740051, 0.13008448481559753, 0.13... NCT03035123 \n",
367
+ "3 -0.1608569175004959, 0.8489153981208801, -0.55... NCT02272751 \n",
368
+ "4 0.21379394829273224, 0.17073844373226166, -0.1... NCT00931606 "
369
+ ]
370
+ },
371
+ "metadata": {},
372
+ "output_type": "display_data"
373
+ }
374
+ ],
375
  "source": [
376
  "# Load knowledge graph\n",
377
  "clinical_trials = pd.read_csv(\"clinical_trials_embeddings.csv\")\n",
utils.py CHANGED
@@ -123,16 +123,16 @@ def get_similarities_among_diseases_uris(
123
  return data
124
 
125
 
126
- def augment_the_set_of_diseaces(engine, diseases: List[str]) -> str:
127
-
128
  for i in range(15-len(diseases)):
129
  with engine.connect() as conn:
130
  with conn.begin():
131
  sql = f"""
132
  SELECT TOP 1 e2.uri AS new_disease, (SUM(VECTOR_COSINE(e1.embedding, e2.embedding))/ {len(diseases)}) AS score
133
  FROM Test.EntityEmbeddings e1, Test.EntityEmbeddings e2
134
- WHERE e1.uri IN ({','.join([f"'http://identifiers.org/medgen/{disease}'" for disease in diseases])})
135
- AND e2.uri NOT IN ({','.join([f"'http://identifiers.org/medgen/{disease}'" for disease in diseases])})
136
  AND e2.label != 'nan'
137
  GROUP BY e2.label
138
  ORDER BY score DESC
@@ -156,9 +156,7 @@ def get_diseases_related_to_a_textual_description(
156
  ) -> List[str]:
157
  # Embed the description using sentence-transformers
158
  description_embedding = get_embedding(description, encoder)
159
- print(f"Size of the embedding: {len(description_embedding)}")
160
  string_representation = str(description_embedding.tolist())[1:-1]
161
- print(f"String representation: {string_representation}")
162
 
163
  with engine.connect() as conn:
164
  with conn.begin():
@@ -172,27 +170,25 @@ def get_diseases_related_to_a_textual_description(
172
 
173
  return [{"uri": row[0], "distance": row[1]} for row in data]
174
 
175
- def get_diseases_related_to_clinical_trials(
176
  diseases: List[str], encoder
177
  ) -> List[str]:
178
  # Embed the diseases using sentence-transformers
179
  diseases_string = ", ".join(diseases)
180
  disease_embedding = get_embedding(diseases_string, encoder)
181
- print(f"Size of the embedding: {len(disease_embedding)}")
182
  string_representation = str(disease_embedding.tolist())[1:-1]
183
- print(f"String representation: {string_representation}")
184
 
185
  with engine.connect() as conn:
186
  with conn.begin():
187
  sql = f"""
188
- SELECT TOP 5 d.uri, VECTOR_COSINE(d.embedding, TO_VECTOR('{string_representation}', DOUBLE)) AS distance
189
  FROM Test.ClinicalTrials d
190
  ORDER BY distance DESC
191
  """
192
  result = conn.execute(text(sql))
193
  data = result.fetchall()
194
 
195
- return [{"uri": row[0], "distance": row[1]} for row in data]
196
 
197
 
198
  if __name__ == "__main__":
 
123
  return data
124
 
125
 
126
+ def augment_the_set_of_diseaces(diseases: List[str]) -> str:
127
+ print(diseases)
128
  for i in range(15-len(diseases)):
129
  with engine.connect() as conn:
130
  with conn.begin():
131
  sql = f"""
132
  SELECT TOP 1 e2.uri AS new_disease, (SUM(VECTOR_COSINE(e1.embedding, e2.embedding))/ {len(diseases)}) AS score
133
  FROM Test.EntityEmbeddings e1, Test.EntityEmbeddings e2
134
+ WHERE e1.uri IN ({','.join([f"'{disease}'" for disease in diseases])})
135
+ AND e2.uri NOT IN ({','.join([f"'{disease}'" for disease in diseases])})
136
  AND e2.label != 'nan'
137
  GROUP BY e2.label
138
  ORDER BY score DESC
 
156
  ) -> List[str]:
157
  # Embed the description using sentence-transformers
158
  description_embedding = get_embedding(description, encoder)
 
159
  string_representation = str(description_embedding.tolist())[1:-1]
 
160
 
161
  with engine.connect() as conn:
162
  with conn.begin():
 
170
 
171
  return [{"uri": row[0], "distance": row[1]} for row in data]
172
 
173
+ def get_clinical_trials_related_to_diseases(
174
  diseases: List[str], encoder
175
  ) -> List[str]:
176
  # Embed the diseases using sentence-transformers
177
  diseases_string = ", ".join(diseases)
178
  disease_embedding = get_embedding(diseases_string, encoder)
 
179
  string_representation = str(disease_embedding.tolist())[1:-1]
 
180
 
181
  with engine.connect() as conn:
182
  with conn.begin():
183
  sql = f"""
184
+ SELECT TOP 5 d.nct_id, VECTOR_COSINE(d.embedding, TO_VECTOR('{string_representation}', DOUBLE)) AS distance
185
  FROM Test.ClinicalTrials d
186
  ORDER BY distance DESC
187
  """
188
  result = conn.execute(text(sql))
189
  data = result.fetchall()
190
 
191
+ return [{"nct_id": row[0], "distance": row[1]} for row in data]
192
 
193
 
194
  if __name__ == "__main__":