# klinic / utils.py
# Commit 2bb81aa by ACMCMC: "Final version"
# File size: 14 kB (hosting-page security scan: no virus found)
# %%
import os
from typing import Any, Dict, List
import pandas as pd
import requests
import streamlit as st
from sentence_transformers import SentenceTransformer
from sqlalchemy import create_engine, text
# Connection settings for the InterSystems IRIS database ("iris://" dialect).
# Every value can be overridden through an environment variable, like the
# hostname already could; the defaults are the local demo settings.
username = os.getenv("IRIS_USERNAME", "demo")
password = os.getenv("IRIS_PASSWORD", "demo")
hostname = os.getenv("IRIS_HOSTNAME", "localhost")
port = os.getenv("IRIS_PORT", "1972")
namespace = os.getenv("IRIS_NAMESPACE", "USER")
CONNECTION_STRING = f"iris://{username}:{password}@{hostname}:{port}/{namespace}"
# Module-level engine shared by the query helpers below that do not take an
# explicit `engine` argument.
engine = create_engine(CONNECTION_STRING)
def get_all_diseases_name(engine) -> List[str]:
    """Return the label of every entity stored in ``Test.EntityEmbeddings``.

    Args:
        engine: SQLAlchemy engine connected to the IRIS database.

    Returns:
        A flat list of label strings. Rows whose label is the literal string
        ``"nan"`` are skipped (presumably missing labels that were serialized
        as text — TODO confirm).
    """
    print("Fetching all disease names...")
    sql = "SELECT label FROM Test.EntityEmbeddings"
    with engine.connect() as conn:
        with conn.begin():
            data = conn.execute(text(sql)).fetchall()
    return [row[0] for row in data if row[0] != "nan"]
def get_uri_from_name(engine, name: str) -> str:
    """Return the identifier of the entity whose label is exactly *name*.

    The identifier is the last path segment of the stored URI (for
    ``http://identifiers.org/medgen/C1843013`` it is ``"C1843013"``). When
    several rows share the label, the first row returned is used.

    Args:
        engine: SQLAlchemy engine connected to the IRIS database.
        name: exact label to look up in ``Test.EntityEmbeddings``.

    Raises:
        IndexError: if no row has this label.
    """
    # The label is passed as a bound parameter rather than spliced into the
    # SQL text: a label containing a single quote would otherwise produce
    # invalid SQL (and arbitrary input could inject SQL).
    sql = text("SELECT uri FROM Test.EntityEmbeddings WHERE label = :name")
    with engine.connect() as conn:
        with conn.begin():
            data = conn.execute(sql, {"name": name}).fetchall()
    if not data:
        raise IndexError(f"No entity found with label {name!r}")
    return data[0][0].split("/")[-1]
def get_most_similar_diseases_from_uri(
    engine, original_disease_uri: str, threshold: float = 0.8
) -> List[tuple[str, str, float]]:
    """Return up to 10 entities whose embeddings are most similar to a disease.

    Args:
        engine: SQLAlchemy engine connected to the IRIS database.
        original_disease_uri: MedGen identifier of the reference disease, i.e.
            the part after ``http://identifiers.org/medgen/`` (e.g.
            ``"C1843013"``).
        threshold: only pairs with a cosine similarity strictly greater than
            this value are returned.

    Returns:
        ``(identifier, label, similarity)`` tuples sorted by decreasing cosine
        similarity, where ``identifier`` is the last URI path segment of the
        similar entity. Entities labelled with the string ``"nan"`` are dropped.
    """
    # Note: the SQL alias "distance" actually holds the cosine *similarity*
    # (higher = closer), which is why results are ordered DESC.
    # The threshold is coerced to float so only a numeric literal can end up
    # in the SQL text; the URI is a bound parameter.
    sql = text(
        f"""
        SELECT TOP 10 e1.uri AS uri1, e2.uri AS uri2, e1.label AS label1, e2.label AS label2,
        VECTOR_COSINE(e1.embedding, e2.embedding) AS distance
        FROM Test.EntityEmbeddings e1, Test.EntityEmbeddings e2
        WHERE e1.uri = :uri
        AND VECTOR_COSINE(e1.embedding, e2.embedding) > {float(threshold)}
        AND e1.uri != e2.uri
        ORDER BY distance DESC
        """
    )
    params = {"uri": f"http://identifiers.org/medgen/{original_disease_uri}"}
    with engine.connect() as conn:
        with conn.begin():
            data = conn.execute(sql, params).fetchall()
    return [
        (row[1].split("/")[-1], row[3], float(row[4]))
        for row in data
        if row[3] != "nan"
    ]
def get_clinical_record_info(
    clinical_record_id: str, timeout: float = 30.0
) -> Dict[str, Any]:
    """Fetch one study record from the ClinicalTrials.gov v2 API as parsed JSON.

    Equivalent request:
        curl -X GET "https://clinicaltrials.gov/api/v2/studies/NCT00841061" \\
             -H "accept: application/json"

    Args:
        clinical_record_id: NCT identifier of the study, e.g. ``"NCT00841061"``.
        timeout: seconds to wait for the server before giving up (forwarded
            to ``requests.get``); without it a stalled server hangs the app.

    Returns:
        The decoded JSON body of the response.

    Raises:
        requests.HTTPError: if the server answers with a 4xx/5xx status.
        requests.Timeout: if the server does not respond within *timeout*.
    """
    request_url = f"https://clinicaltrials.gov/api/v2/studies/{clinical_record_id}"
    response = requests.get(
        request_url, headers={"accept": "application/json"}, timeout=timeout
    )
    # Fail loudly on error statuses instead of returning an error payload that
    # later surfaces as a confusing KeyError in the renderer.
    response.raise_for_status()
    return response.json()
def get_clinical_records_by_ids(clinical_record_ids: List[str]) -> List[Dict[str, Any]]:
    """Fetch the record of each id via ``get_clinical_record_info``.

    Records are fetched one after another and returned in the same order as
    *clinical_record_ids*.
    """
    return [get_clinical_record_info(record_id) for record_id in clinical_record_ids]
def get_similarities_among_diseases_uris(
    uri_list: List[str],
) -> List[Dict[str, Any]]:
    """Return the cosine similarity of every ordered pair of distinct URIs.

    Uses the module-level ``engine``.

    Args:
        uri_list: full entity URIs as stored in ``Test.EntityEmbeddings.uri``
            (e.g. ``"http://identifiers.org/medgen/C4553765"``).

    Returns:
        One dict per ordered pair ``(uri1, uri2)`` with ``uri1 != uri2`` and
        both in *uri_list* (so each unordered pair appears twice):
        ``{"uri1": id1, "uri2": id2, "distance": similarity}``. The ids are
        the last URI path segment; ``distance`` is the ``VECTOR_COSINE``
        value as a float (higher = more similar). An empty *uri_list*
        returns ``[]`` without querying the database.
    """
    if not uri_list:
        # "IN ()" is invalid SQL, and there are no pairs to compare anyway.
        return []
    # One named bind parameter per URI for each IN list, so URI text is never
    # spliced into the SQL (a quote inside a URI would otherwise break it).
    params: Dict[str, Any] = {}
    for i, uri in enumerate(uri_list):
        params[f"a{i}"] = uri
        params[f"b{i}"] = uri
    left = ", ".join(f":a{i}" for i in range(len(uri_list)))
    right = ", ".join(f":b{i}" for i in range(len(uri_list)))
    sql = text(
        f"""
        SELECT e1.uri AS uri1, e2.uri AS uri2, VECTOR_COSINE(e1.embedding, e2.embedding) AS distance
        FROM Test.EntityEmbeddings e1, Test.EntityEmbeddings e2
        WHERE e1.uri IN ({left}) AND e2.uri IN ({right}) AND e1.uri != e2.uri
        """
    )
    with engine.connect() as conn:
        with conn.begin():
            data = conn.execute(sql, params).fetchall()
    return [
        {
            "uri1": row[0].split("/")[-1],
            "uri2": row[1].split("/")[-1],
            "distance": float(row[2]),
        }
        for row in data
    ]
def augment_the_set_of_diseaces(
    diseases: List[str], target_size: int = 10
) -> List[str]:
    """Greedily grow a set of disease URIs up to *target_size* entries.

    At each step the entity (with a label other than ``"nan"``) whose summed
    cosine similarity to all URIs already in the set is highest is appended.
    Uses the module-level ``engine``.

    Args:
        diseases: seed URIs as stored in ``Test.EntityEmbeddings.uri``. Not
            modified.
        target_size: stop once the set holds this many URIs (default 10).

    Returns:
        A new list: the seed URIs followed by the added URIs in the order they
        were chosen. It is shorter than *target_size* when no candidate is
        left, and is a plain copy when the seed is empty or already large
        enough.
    """
    augmented_diseases = list(diseases)
    # An empty seed has nothing to compare against (and would produce "IN ()").
    if not augmented_diseases or len(augmented_diseases) >= target_size:
        return augmented_diseases
    # One connection for the whole greedy loop instead of one per step.
    with engine.connect() as conn:
        with conn.begin():
            while len(augmented_diseases) < target_size:
                count = len(augmented_diseases)
                # Separate bind names for the two IN lists; URIs are never
                # spliced into the SQL text.
                params: Dict[str, Any] = {}
                for i, uri in enumerate(augmented_diseases):
                    params[f"s{i}"] = uri
                    params[f"x{i}"] = uri
                seed_list = ", ".join(f":s{i}" for i in range(count))
                excluded_list = ", ".join(f":x{i}" for i in range(count))
                # Group by the URI that is selected (grouping by label would
                # merge distinct URIs sharing a label and inflate their score).
                # Dividing by `count` does not change the ranking; it only
                # turns the sum into the mean similarity to the set.
                sql = text(
                    f"""
                    SELECT TOP 1 e2.uri AS new_disease,
                    (SUM(VECTOR_COSINE(e1.embedding, e2.embedding)) / {count}) AS score
                    FROM Test.EntityEmbeddings e1, Test.EntityEmbeddings e2
                    WHERE e1.uri IN ({seed_list})
                    AND e2.uri NOT IN ({excluded_list})
                    AND e2.label != 'nan'
                    GROUP BY e2.uri
                    ORDER BY score DESC
                    """
                )
                data = conn.execute(sql, params).fetchall()
                if not data:
                    # No remaining candidate: return what we have.
                    break
                augmented_diseases.append(data[0][0])
    return augmented_diseases
def get_embedding(string: str, encoder) -> Any:
    """Embed *string* with a sentence-transformers style *encoder*.

    Calls ``encoder.encode(string, show_progress_bar=False)`` and returns its
    result unchanged. Callers in this module call ``.tolist()`` on it, so it is
    an array-like (e.g. a NumPy array from ``SentenceTransformer.encode``)
    rather than a Python list. That is why the return annotation is ``Any``
    and not ``List[float]``.
    """
    return encoder.encode(string, show_progress_bar=False)
def get_diseases_related_to_a_textual_description(
    description: str, encoder, threshold: float = 0.8
) -> List[Dict[str, Any]]:
    """Find diseases whose stored description embedding matches free text.

    The text is embedded with *encoder* and compared with
    ``Test.DiseaseDescriptions.embedding`` by cosine similarity. Uses the
    module-level ``engine``.

    Args:
        description: free-text description of symptoms or a disease.
        encoder: sentence-transformers style model (see ``get_embedding``).
        threshold: only matches with similarity strictly greater than this are
            kept (default 0.8).

    Returns:
        Up to 10 dicts ``{"uri": full_uri, "distance": similarity}`` sorted by
        decreasing similarity. ``distance`` is a cosine similarity, so higher
        means closer.
    """
    description_embedding = get_embedding(description, encoder)
    # TO_VECTOR takes comma-separated values; strip the list brackets.
    vector_literal = str(description_embedding.tolist())[1:-1]
    # The vector goes in as a bound parameter rather than as SQL text, which
    # keeps the statement small and identical across calls.
    sql = text(
        """
        SELECT TOP 10 d.uri, VECTOR_COSINE(d.embedding, TO_VECTOR(:vector, DOUBLE)) AS distance
        FROM Test.DiseaseDescriptions d
        ORDER BY distance DESC
        """
    )
    with engine.connect() as conn:
        with conn.begin():
            data = conn.execute(sql, {"vector": vector_literal}).fetchall()
    return [
        {"uri": row[0], "distance": float(row[1])}
        for row in data
        if float(row[1]) > threshold
    ]
def get_clinical_trials_related_to_diseases(
    diseases: List[str], encoder
) -> List[Dict[str, Any]]:
    """Return the 20 clinical trials whose embedding best matches *diseases*.

    The disease strings (presumably labels — TODO confirm against callers)
    are joined with ", " into one string. That string is embedded with
    *encoder* and compared with ``Test.ClinicalTrials.embedding`` by cosine
    similarity. Uses the module-level ``engine``.

    Returns:
        Dicts ``{"nct_id": id, "distance": similarity}`` sorted by decreasing
        similarity. ``distance`` is a float, consistent with the other query
        helpers.
    """
    diseases_string = ", ".join(diseases)
    disease_embedding = get_embedding(diseases_string, encoder)
    # TO_VECTOR takes comma-separated values; strip the list brackets.
    vector_literal = str(disease_embedding.tolist())[1:-1]
    sql = text(
        """
        SELECT TOP 20 d.nct_id, VECTOR_COSINE(d.embedding, TO_VECTOR(:vector, DOUBLE)) AS distance
        FROM Test.ClinicalTrials d
        ORDER BY distance DESC
        """
    )
    with engine.connect() as conn:
        with conn.begin():
            data = conn.execute(sql, {"vector": vector_literal}).fetchall()
    return [{"nct_id": row[0], "distance": float(row[1])} for row in data]
def get_similarities_df(diseases: List[Dict[str, Any]]) -> pd.DataFrame:
    """Pivot pairwise similarity records into a square similarity matrix.

    Args:
        diseases: records with keys ``"uri1"``, ``"uri2"`` and ``"distance"``
            (as produced by ``get_similarities_among_diseases_uris``).

    Returns:
        DataFrame indexed by ``uri1`` with ``uri2`` columns holding the
        similarity. Self-pairs, which the query never returns, are set to 1.0
        on the diagonal. Any other missing pair stays NaN, so it is ignored by
        ``mean()`` instead of counting as a perfect match. Empty input gives
        an empty DataFrame.
    """
    df_pairs = pd.DataFrame(diseases)
    if df_pairs.empty:
        # pivot() would fail on a frame without the uri1/uri2 columns.
        return pd.DataFrame(dtype=float)
    matrix = df_pairs.pivot(index="uri1", columns="uri2", values="distance")
    # A disease is identical to itself: fill only the diagonal cells.
    for uri in matrix.index.intersection(matrix.columns):
        matrix.loc[uri, uri] = 1.0
    return matrix
def filter_out_less_promising_diseases(
    info_dicts: List[Dict[str, Any]], num_std: float = 0.2
) -> tuple[List[str], pd.DataFrame]:
    """Keep the diseases that are, on average, similar to the rest of the set.

    Each disease is scored by the mean of its column in the matrix returned by
    ``get_similarities_df``. Diseases scoring more than ``num_std`` standard
    deviations (of those scores) below the mean score are dropped.

    Args:
        info_dicts: pairwise similarity records (keys ``uri1``, ``uri2``,
            ``distance``).
        num_std: tolerance below the mean, in standard deviations (default 0.2).

    Returns:
        ``(kept_uris, similarity_matrix)``. ``kept_uris`` holds full
        ``http://identifiers.org/medgen/<id>`` URIs.
    """
    df_diseases_similarities = get_similarities_df(info_dicts)
    scores = df_diseases_similarities.mean()  # one mean per column (uri2)
    cutoff = scores.mean() - num_std * scores.std()
    # ">=" rather than ">": when all scores tie (std == 0, e.g. exactly two
    # diseases) none is below the mean, so none should be dropped.
    filtered_diseases = scores[scores >= cutoff].index.tolist()
    return (
        [f"http://identifiers.org/medgen/{d}" for d in filtered_diseases],
        df_diseases_similarities,
    )
def get_labels_of_diseases_from_uris(uris: List[str]) -> List[str]:
    """Return the labels of the entities whose URI is in *uris*.

    Uses the module-level ``engine``. The query has no ORDER BY, so labels come
    back in database order, not necessarily the order of *uris*. URIs that are
    not in the table contribute nothing.

    Args:
        uris: full URIs as stored in ``Test.EntityEmbeddings.uri``.

    Returns:
        The matching labels; ``[]`` for empty *uris*, without querying.
    """
    if not uris:
        # "IN ()" is invalid SQL.
        return []
    # Bind each URI instead of splicing it into the SQL text.
    params = {f"uri{i}": uri for i, uri in enumerate(uris)}
    placeholders = ", ".join(f":{key}" for key in params)
    sql = text(
        f"""
        SELECT label FROM Test.EntityEmbeddings
        WHERE uri IN ({placeholders})
        """
    )
    with engine.connect() as conn:
        with conn.begin():
            data = conn.execute(sql, params).fetchall()
    return [row[0] for row in data]
def to_capitalized_case(string: str) -> str:
    """Format an API enum value such as ``"NOT_YET_RECRUITING"`` for display.

    Underscores become spaces. A fully upper-case value keeps its first
    character and lower-cases the rest (``"NOT_YET_RECRUITING"`` ->
    ``"Not yet recruiting"``). Any other value is returned with only the
    underscore substitution applied, never ``None``.
    """
    string = string.replace("_", " ")
    if string.isupper():
        return string[0] + string[1:].lower()
    # Mixed-case, lower-case or empty input: nothing to re-case.
    return string
def list_to_capitalized_case(strings: List[str]) -> str:
strings = [to_capitalized_case(s) for s in strings]
return ", ".join(strings)
def render_trial_details(trial: dict) -> None:
    """Render one ClinicalTrials.gov study record (v2 API JSON) with Streamlit.

    The record is shown as sections: title, description, status, design,
    interventions, and a link to the official page. A section whose fields are
    missing from *trial* is skipped or replaced by a message instead of
    raising ``KeyError``.
    """
    protocol = trial.get("protocolSection", {})
    identification = protocol.get("identificationModule", {})

    # officialTitle may be absent from a record: fall back to briefTitle, and
    # skip the heading when neither is present.
    title = identification.get("officialTitle") or identification.get("briefTitle")
    if title:
        st.write(f"##### {title}")

    # Prefer the brief summary; fall back to the detailed description.
    try:
        st.write(protocol["descriptionModule"]["briefSummary"])
    except KeyError:
        try:
            st.write(protocol["descriptionModule"]["detailedDescription"])
        except KeyError:
            st.error("No description available.")

    st.write("###### Status")
    try:
        status_module = {
            "Status": to_capitalized_case(protocol["statusModule"]["overallStatus"]),
            "Status Date": protocol["statusModule"]["statusVerifiedDate"],
            "Has Results": trial["hasResults"],
        }
        st.table(status_module)
    except KeyError:
        st.info("No status information available.")

    # Any missing design field hides the whole table (see TODO: do not render
    # partial information).
    st.write("###### Design")
    try:
        design = protocol["designModule"]
        design_module = {
            "Study Type": to_capitalized_case(design["studyType"]),
            "Phases": list_to_capitalized_case(design["phases"]),
            "Allocation": to_capitalized_case(design["designInfo"]["allocation"]),
            "Primary Purpose": to_capitalized_case(
                design["designInfo"]["primaryPurpose"]
            ),
            "Participants": design["enrollmentInfo"]["count"],
            "Masking": to_capitalized_case(
                design["designInfo"]["maskingInfo"]["masking"]
            ),
            "Who Masked": list_to_capitalized_case(
                design["designInfo"]["maskingInfo"]["whoMasked"]
            ),
        }
        st.table(design_module)
    except KeyError:
        st.info("No design information available.")

    st.write("###### Interventions")
    try:
        interventions_module = {
            intervention["name"]: intervention["description"]
            for intervention in protocol["armsInterventionsModule"]["interventions"]
        }
        st.table(interventions_module)
    except KeyError:
        st.info("No interventions information available.")

    # Link to the trial's official page, only when the NCT id is known.
    nct_id = identification.get("nctId")
    if nct_id:
        st.markdown(
            f"See more in [ClinicalTrials.gov](https://clinicaltrials.gov/study/{nct_id})"
        )
if __name__ == "__main__":
    # Manual smoke test of the helpers above. It uses the module-level
    # `engine`, which is built from the same connection settings.
    try:
        # get_most_similar_diseases_from_uri takes the engine as first argument.
        diseases = get_most_similar_diseases_from_uri(engine, "C1843013")
        for disease in diseases:
            print(disease)
    except Exception as e:
        print(e)
    try:
        print(get_uri_from_name(engine, "Alzheimer disease 3"))
    except Exception as e:
        print(e)
    clinical_record_info = get_clinical_records_by_ids(["NCT00841061"])
    print(clinical_record_info)
    textual_description = (
        "A disease that causes memory loss and other cognitive impairments."
    )
    encoder = SentenceTransformer("allenai-specter")
    diseases = get_diseases_related_to_a_textual_description(
        textual_description, encoder
    )
    for disease in diseases:
        print(disease)
    try:
        similarities = get_similarities_among_diseases_uris(
            [
                "http://identifiers.org/medgen/C4553765",
                "http://identifiers.org/medgen/C4553176",
                "http://identifiers.org/medgen/C4024935",
            ]
        )
        # Each result is a dict whose uri1/uri2 are already short identifiers.
        for similarity in similarities:
            print(
                f'{similarity["uri1"]} and {similarity["uri2"]} have a similarity of {similarity["distance"]}'
            )
    except Exception as e:
        print(e)
# %%