ACMCMC commited on
Commit
1f35211
1 Parent(s): 2408e3d
Files changed (2) hide show
  1. app.py +15 -4
  2. utils.py +47 -16
app.py CHANGED
@@ -3,8 +3,10 @@ from streamlit_agraph import agraph, Node, Edge, Config
3
  import os
4
  from sqlalchemy import create_engine, text
5
  import pandas as pd
6
- from utils import get_all_diseases_name, get_most_similar_diseases_from_uri, get_uri_from_name, get_diseases_related_to_a_textual_description
7
  import json
 
 
8
 
9
 
10
  username = 'demo'
@@ -15,11 +17,17 @@ namespace = 'USER'
15
  CONNECTION_STRING = f"iris://{username}:{password}@{hostname}:{port}/{namespace}"
16
  engine = create_engine(CONNECTION_STRING)
17
 
18
- def handle_click_on_analyze_button():
19
  # 1. Embed the textual description that the user entered using the model
20
- diseases_related_to_the_user_text = get_diseases_related_to_a_textual_description(description_input)
21
  # 2. Get 5 diseases with the highest cosine silimarity from the DB
 
 
 
 
22
  # 3. Get the similarities of the embeddings of those diseases (cosine similarity of the embeddings of the nodes of such diseases)
 
 
 
23
  # 4. Potentially filter out the diseases that are not similar enough (e.g. similarity < 0.8)
24
  # 5. Augment the set of diseases: add new diseases that are similar to the ones that are already in the set, until we get 10-15 diseases
25
  # 6. Query the embeddings of the diseases related to each clinical trial (also in the DB), to get the most similar clinical trials to our set of diseases
@@ -31,7 +39,10 @@ def handle_click_on_analyze_button():
31
 
32
  st.write("# Klìnic")
33
 
34
- description_input = st.text_input(label="Enter the disease description 👇")
 
 
 
35
 
36
  st.write(":red[Here should be the graph]") # TODO remove
37
  chart_data = pd.DataFrame(
 
3
  import os
4
  from sqlalchemy import create_engine, text
5
  import pandas as pd
6
+ from utils import get_all_diseases_name, get_most_similar_diseases_from_uri, get_uri_from_name, get_diseases_related_to_a_textual_description, get_similarities_among_diseases_uris
7
  import json
8
+ import numpy as np
9
+ from sentence_transformers import SentenceTransformer
10
 
11
 
12
  username = 'demo'
 
17
  CONNECTION_STRING = f"iris://{username}:{password}@{hostname}:{port}/{namespace}"
18
  engine = create_engine(CONNECTION_STRING)
19
 
20
+ def handle_click_on_analyze_button(user_text):
21
  # 1. Embed the textual description that the user entered using the model
 
22
  # 2. Get 5 diseases with the highest cosine silimarity from the DB
23
+ encoder = SentenceTransformer("allenai-specter")
24
+ diseases_related_to_the_user_text = get_diseases_related_to_a_textual_description(user_text, encoder)
25
+ #for disease_label in diseases_related_to_the_user_text:
26
+ # st.text(disease_label)
27
  # 3. Get the similarities of the embeddings of those diseases (cosine similarity of the embeddings of the nodes of such diseases)
28
+ diseases_uris = [disease['uri'] for disease in diseases_related_to_the_user_text]
29
+ get_similarities_among_diseases_uris(diseases_uris)
30
+ print(diseases_related_to_the_user_text)
31
  # 4. Potentially filter out the diseases that are not similar enough (e.g. similarity < 0.8)
32
  # 5. Augment the set of diseases: add new diseases that are similar to the ones that are already in the set, until we get 10-15 diseases
33
  # 6. Query the embeddings of the diseases related to each clinical trial (also in the DB), to get the most similar clinical trials to our set of diseases
 
39
 
40
  st.write("# Klìnic")
41
 
42
+ description_input = st.text_input(label="Enter the disease description 👇", placeholder='A disease that causes memory loss and other cognitive impairments.')
43
+ if st.button("Analyze"):
44
+ handle_click_on_analyze_button(description_input)
45
+ # TODO: also when user clicks enter
46
 
47
  st.write(":red[Here should be the graph]") # TODO remove
48
  chart_data = pd.DataFrame(
utils.py CHANGED
@@ -5,6 +5,15 @@ from sqlalchemy import create_engine, text
5
  import requests
6
  from sentence_transformers import SentenceTransformer
7
 
 
 
 
 
 
 
 
 
 
8
 
9
  def get_all_diseases_name(engine) -> List[List[str]]:
10
  with engine.connect() as conn:
@@ -98,46 +107,48 @@ def get_clinical_records_by_ids(clinical_record_ids: List[str]) -> List[Dict[str
98
  return clinical_records
99
 
100
 
101
- def get_uris_of_similar_diseases(uri_list: List[str]) -> List[tuple[str, str, float]]:
102
- uri_list = tuple(uri_list)
 
 
103
  with engine.connect() as conn:
104
  with conn.begin():
105
  sql = f"""
106
  SELECT e1.uri AS uri1, e2.uri AS uri2, VECTOR_COSINE(e1.embedding, e2.embedding) AS distance
107
  FROM Test.EntityEmbeddings e1, Test.EntityEmbeddings e2
108
- WHERE e1.uri IN {uri_list} AND e2.uri IN {uri_list} AND e1.uri != e2.uri
109
  """
110
  result = conn.execute(text(sql))
111
  data = result.fetchall()
112
  return data
113
 
114
 
115
- encoder = SentenceTransformer("allenai-specter")
116
-
117
-
118
- def get_embedding(string: str) -> List[float]:
119
  # Embed the string using sentence-transformers
120
  vector = encoder.encode(string, show_progress_bar=False)
121
  return vector
122
 
123
 
124
- def get_diseases_related_to_a_textual_description(description: str) -> List[str]:
 
 
125
  # Embed the description using sentence-transformers
126
- description_embedding = get_embedding(description)
127
- print(f'Size of the embedding: {len(description_embedding)}')
128
  string_representation = str(description_embedding.tolist())[1:-1]
129
- print(f'String representation: {string_representation}')
130
 
131
  with engine.connect() as conn:
132
  with conn.begin():
133
  sql = f"""
134
- SELECT TOP 5 uri, VECTOR_COSINE(e.embedding, TO_VECTOR('{string_representation}', DOUBLE)) AS distance
135
- FROM Test.DiseaseDescriptions e
136
  ORDER BY distance DESC
137
  """
138
  result = conn.execute(text(sql))
139
  data = result.fetchall()
140
- return data
 
141
 
142
 
143
  if __name__ == "__main__":
@@ -164,9 +175,29 @@ if __name__ == "__main__":
164
  clinical_record_info = get_clinical_records_by_ids(["NCT00841061"])
165
  print(clinical_record_info)
166
 
167
- textual_description = "A disease that causes memory loss and other cognitive impairments."
168
- diseases = get_diseases_related_to_a_textual_description(textual_description)
 
 
 
 
 
169
  for disease in diseases:
170
  print(disease)
171
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
172
  # %%
 
5
  import requests
6
  from sentence_transformers import SentenceTransformer
7
 
8
+ username = "demo"
9
+ password = "demo"
10
+ hostname = os.getenv("IRIS_HOSTNAME", "localhost")
11
+ port = "1972"
12
+ namespace = "USER"
13
+ CONNECTION_STRING = f"iris://{username}:{password}@{hostname}:{port}/{namespace}"
14
+
15
+ engine = create_engine(CONNECTION_STRING)
16
+
17
 
18
  def get_all_diseases_name(engine) -> List[List[str]]:
19
  with engine.connect() as conn:
 
107
  return clinical_records
108
 
109
 
110
+ def get_similarities_among_diseases_uris(
111
+ uri_list: List[str],
112
+ ) -> List[tuple[str, str, float]]:
113
+ uri_list = ", ".join([f"'{uri}'" for uri in uri_list])
114
  with engine.connect() as conn:
115
  with conn.begin():
116
  sql = f"""
117
  SELECT e1.uri AS uri1, e2.uri AS uri2, VECTOR_COSINE(e1.embedding, e2.embedding) AS distance
118
  FROM Test.EntityEmbeddings e1, Test.EntityEmbeddings e2
119
+ WHERE e1.uri IN ({uri_list}) AND e2.uri IN ({uri_list}) AND e1.uri != e2.uri
120
  """
121
  result = conn.execute(text(sql))
122
  data = result.fetchall()
123
  return data
124
 
125
 
126
+ def get_embedding(string: str, encoder) -> List[float]:
 
 
 
127
  # Embed the string using sentence-transformers
128
  vector = encoder.encode(string, show_progress_bar=False)
129
  return vector
130
 
131
 
132
+ def get_diseases_related_to_a_textual_description(
133
+ description: str, encoder
134
+ ) -> List[str]:
135
  # Embed the description using sentence-transformers
136
+ description_embedding = get_embedding(description, encoder)
137
+ print(f"Size of the embedding: {len(description_embedding)}")
138
  string_representation = str(description_embedding.tolist())[1:-1]
139
+ print(f"String representation: {string_representation}")
140
 
141
  with engine.connect() as conn:
142
  with conn.begin():
143
  sql = f"""
144
+ SELECT TOP 5 d.uri, VECTOR_COSINE(d.embedding, TO_VECTOR('{string_representation}', DOUBLE)) AS distance
145
+ FROM Test.DiseaseDescriptions d
146
  ORDER BY distance DESC
147
  """
148
  result = conn.execute(text(sql))
149
  data = result.fetchall()
150
+
151
+ return [{"uri": row[0], "distance": row[1]} for row in data]
152
 
153
 
154
  if __name__ == "__main__":
 
175
  clinical_record_info = get_clinical_records_by_ids(["NCT00841061"])
176
  print(clinical_record_info)
177
 
178
+ textual_description = (
179
+ "A disease that causes memory loss and other cognitive impairments."
180
+ )
181
+ encoder = SentenceTransformer("allenai-specter")
182
+ diseases = get_diseases_related_to_a_textual_description(
183
+ textual_description, encoder
184
+ )
185
  for disease in diseases:
186
  print(disease)
187
 
188
+ try:
189
+ similarities = get_similarities_among_diseases_uris(
190
+ [
191
+ "http://identifiers.org/medgen/C4553765",
192
+ "http://identifiers.org/medgen/C4553176",
193
+ "http://identifiers.org/medgen/C4024935",
194
+ ]
195
+ )
196
+ for similarity in similarities:
197
+ print(
198
+ f'{similarity[0].split("/")[-1]} and {similarity[1].split("/")[-1]} have a similarity of {similarity[2]}'
199
+ )
200
+ except Exception as e:
201
+ print(e)
202
+
203
  # %%