# %%
"""Helpers for querying disease and clinical-trial embeddings stored in IRIS.

The module exposes small query functions over the ``Test.EntityEmbeddings``,
``Test.DiseaseDescriptions`` and ``Test.ClinicalTrials`` tables, a thin client
for the ClinicalTrials.gov v2 REST API, and a Streamlit renderer for a single
trial record.
"""
import os
from typing import Any, Dict, List, Optional

import pandas as pd
import requests
import streamlit as st
from sentence_transformers import SentenceTransformer
from sqlalchemy import bindparam, create_engine, text

username = "demo"
password = "demo"
hostname = os.getenv("IRIS_HOSTNAME", "localhost")
port = "1972"
namespace = "USER"
CONNECTION_STRING = f"iris://{username}:{password}@{hostname}:{port}/{namespace}"

# Module-level engine shared by the helpers that do not take an ``engine``
# argument.
engine = create_engine(CONNECTION_STRING)

# URI prefix prepended to a bare disease identifier when looking up an entity.
MEDGEN_URI_PREFIX = "http://identifiers.org/medgen/"

# Upper bound (seconds) for a single ClinicalTrials.gov request so that a
# stalled connection cannot block the caller forever.
REQUEST_TIMEOUT_SECONDS = 30


def _fetch_all(db_engine, statement, params: Optional[Dict[str, Any]] = None) -> list:
    """Run *statement* on *db_engine* with bound *params* and return all rows."""
    with db_engine.connect() as conn:
        with conn.begin():
            if params is None:
                result = conn.execute(statement)
            else:
                result = conn.execute(statement, params)
            return result.fetchall()


def _last_uri_segment(uri: str) -> str:
    """Return the part of *uri* after its final ``/``."""
    return uri.split("/")[-1]


def get_all_diseases_name(engine) -> List[str]:
    """Return every label in ``Test.EntityEmbeddings`` except the string "nan".

    Args:
        engine: SQLAlchemy engine used to run the query.

    Returns:
        The labels, in the order the database returned them.
    """
    print("Fetching all disease names...")
    statement = text(
        """
        SELECT label
        FROM Test.EntityEmbeddings
        """
    )
    rows = _fetch_all(engine, statement)
    return [row[0] for row in rows if row[0] != "nan"]


def get_uri_from_name(engine, name: str) -> str:
    """Return the last URI path segment of the entity labelled *name*.

    The label is passed as a bound parameter, so labels containing quotes
    (e.g. an apostrophe) are matched literally instead of breaking the query.

    Args:
        engine: SQLAlchemy engine used to run the query.
        name: Exact label to look up.

    Returns:
        The text after the final ``/`` of the first matching row's URI.

    Raises:
        IndexError: If no row has the given label.
    """
    statement = text(
        """
        SELECT uri
        FROM Test.EntityEmbeddings
        WHERE label = :label
        """
    )
    rows = _fetch_all(engine, statement, {"label": name})
    if not rows:
        raise IndexError(f"No entity found with label {name!r}")
    return _last_uri_segment(rows[0][0])


def get_most_similar_diseases_from_uri(
    engine, original_disease_uri: str, threshold: float = 0.8
) -> List[tuple]:
    """Return up to ten entities whose embedding is closest to a given one.

    Args:
        engine: SQLAlchemy engine used to run the query.
        original_disease_uri: Bare identifier; ``MEDGEN_URI_PREFIX`` is
            prepended to form the URI that is looked up.
        threshold: Only pairs whose ``VECTOR_COSINE`` value is strictly
            greater than this are returned.

    Returns:
        A list of ``(identifier, label, score)`` tuples ordered by descending
        score, where ``identifier`` is the last URI segment of the similar
        entity. Rows whose label is the string "nan" are dropped.
    """
    statement = text(
        """
        SELECT TOP 10 e1.uri AS uri1, e2.uri AS uri2, e1.label AS label1, e2.label AS label2,
        VECTOR_COSINE(e1.embedding, e2.embedding) AS distance
        FROM Test.EntityEmbeddings e1, Test.EntityEmbeddings e2
        WHERE e1.uri = :uri
        AND VECTOR_COSINE(e1.embedding, e2.embedding) > :threshold
        AND e1.uri != e2.uri
        ORDER BY distance DESC
        """
    )
    rows = _fetch_all(
        engine,
        statement,
        {
            "uri": f"{MEDGEN_URI_PREFIX}{original_disease_uri}",
            "threshold": float(threshold),
        },
    )
    return [
        (_last_uri_segment(row[1]), row[3], row[4])
        for row in rows
        if row[3] != "nan"
    ]


def get_clinical_record_info(clinical_record_id: str) -> Dict[str, Any]:
    """Fetch one study from the ClinicalTrials.gov v2 API as parsed JSON.

    Equivalent request:
        curl -X GET "https://clinicaltrials.gov/api/v2/studies/NCT00841061" \
            -H "accept: application/json"

    Args:
        clinical_record_id: Study identifier appended to the studies endpoint.

    Returns:
        The decoded JSON body of the response.

    Raises:
        requests.HTTPError: If the server answers with an error status.
        requests.Timeout: If no response arrives within
            ``REQUEST_TIMEOUT_SECONDS``.
    """
    request_url = f"https://clinicaltrials.gov/api/v2/studies/{clinical_record_id}"
    response = requests.get(
        request_url,
        headers={"accept": "application/json"},
        timeout=REQUEST_TIMEOUT_SECONDS,
    )
    # Surface HTTP errors explicitly rather than failing later while trying to
    # decode an error page as JSON.
    response.raise_for_status()
    return response.json()


def get_clinical_records_by_ids(clinical_record_ids: List[str]) -> List[Dict[str, Any]]:
    """Fetch each study in *clinical_record_ids*, preserving input order."""
    return [get_clinical_record_info(record_id) for record_id in clinical_record_ids]


def get_similarities_among_diseases_uris(
    uri_list: List[str],
) -> List[Dict[str, Any]]:
    """Return the pairwise ``VECTOR_COSINE`` values among the given URIs.

    Uses the module-level ``engine``.

    Args:
        uri_list: Full entity URIs to compare with each other.

    Returns:
        One dict per ordered pair of distinct URIs found in the table, with
        keys ``uri1`` and ``uri2`` (last URI segments) and ``distance``
        (float). An empty input yields an empty list.
    """
    uris = list(uri_list)
    if not uris:
        # An empty IN list is not valid SQL; there is nothing to compare anyway.
        return []

    # Two distinct expanding parameters are used (rather than one referenced
    # twice) so each IN clause gets its own positional placeholders.
    statement = text(
        """
        SELECT e1.uri AS uri1, e2.uri AS uri2, VECTOR_COSINE(e1.embedding, e2.embedding) AS distance
        FROM Test.EntityEmbeddings e1, Test.EntityEmbeddings e2
        WHERE e1.uri IN :left_uris
        AND e2.uri IN :right_uris
        AND e1.uri != e2.uri
        """
    ).bindparams(
        bindparam("left_uris", expanding=True),
        bindparam("right_uris", expanding=True),
    )
    rows = _fetch_all(engine, statement, {"left_uris": uris, "right_uris": uris})
    return [
        {
            "uri1": _last_uri_segment(row[0]),
            "uri2": _last_uri_segment(row[1]),
            "distance": float(row[2]),
        }
        for row in rows
    ]


def augment_the_set_of_diseaces(
    diseases: List[str], target_size: int = 10
) -> List[str]:
    """Grow a list of disease URIs up to *target_size* by similarity.

    One URI is added per round: the entity not yet in the list (and whose
    label is not "nan") that ranks first by summed ``VECTOR_COSINE`` against
    the current members. Uses the module-level ``engine``.

    Args:
        diseases: Seed URIs. The list is copied, not modified.
        target_size: Desired length of the result; defaults to 10.

    Returns:
        The seed URIs followed by the added URIs. The result can be shorter
        than *target_size* when the seed list is empty or when the database
        runs out of candidates.
    """
    augmented_diseases = list(diseases)
    if not augmented_diseases:
        # With no seed there is nothing to be similar to (and the query would
        # divide by zero and contain an empty IN list).
        return augmented_diseases

    for _ in range(target_size - len(augmented_diseases)):
        # len() is an int computed here, so inlining it is injection-safe; the
        # URIs themselves are always sent as bound parameters.
        statement = text(
            f"""
            SELECT TOP 1 e2.uri AS new_disease, (SUM(VECTOR_COSINE(e1.embedding, e2.embedding))/ {len(augmented_diseases)}) AS score
            FROM Test.EntityEmbeddings e1, Test.EntityEmbeddings e2
            WHERE e1.uri IN :known_uris
            AND e2.uri NOT IN :excluded_uris
            AND e2.label != 'nan'
            GROUP BY e2.label
            ORDER BY score DESC
            """
        ).bindparams(
            bindparam("known_uris", expanding=True),
            bindparam("excluded_uris", expanding=True),
        )
        rows = _fetch_all(
            engine,
            statement,
            {
                "known_uris": list(augmented_diseases),
                "excluded_uris": list(augmented_diseases),
            },
        )
        if not rows:
            # No candidate left: stop instead of raising IndexError.
            break
        augmented_diseases.append(rows[0][0])
    return augmented_diseases


def get_embedding(string: str, encoder) -> List[float]:
    """Embed *string* with *encoder* (progress bar disabled).

    Args:
        string: Text to embed.
        encoder: Object exposing ``encode(text, show_progress_bar=...)``,
            presumably a ``SentenceTransformer`` — callers in this module
            call ``.tolist()`` on the result.

    Returns:
        Whatever ``encoder.encode`` returns for the input.
    """
    return encoder.encode(string, show_progress_bar=False)


def _embedding_to_sql_literal(embedding) -> str:
    """Render an embedding as comma-separated numbers without brackets."""
    return str(embedding.tolist())[1:-1]


def get_diseases_related_to_a_textual_description(
    description: str, encoder, threshold: float = 0.8
) -> List[Dict[str, Any]]:
    """Find disease descriptions whose embedding is close to a free text.

    Uses the module-level ``engine``.

    Args:
        description: Free-text description to embed and search with.
        encoder: Encoder passed to :func:`get_embedding`.
        threshold: Rows are kept only when their ``VECTOR_COSINE`` value is
            strictly greater than this; defaults to 0.8.

    Returns:
        Dicts with keys ``uri`` and ``distance`` (float), taken from the top
        ten rows ordered by descending score.
    """
    description_embedding = get_embedding(description, encoder)
    statement = text(
        """
        SELECT TOP 10 d.uri, VECTOR_COSINE(d.embedding, TO_VECTOR(:embedding, DOUBLE)) AS distance
        FROM Test.DiseaseDescriptions d
        ORDER BY distance DESC
        """
    )
    rows = _fetch_all(
        engine,
        statement,
        {"embedding": _embedding_to_sql_literal(description_embedding)},
    )
    return [
        {"uri": row[0], "distance": float(row[1])}
        for row in rows
        if float(row[1]) > threshold
    ]


def get_clinical_trials_related_to_diseases(
    diseases: List[str], encoder
) -> List[Dict[str, Any]]:
    """Return the twenty clinical trials closest to a set of disease names.

    The names are joined with ", " and embedded as a single string. Uses the
    module-level ``engine``.

    Args:
        diseases: Disease names to search with.
        encoder: Encoder passed to :func:`get_embedding`.

    Returns:
        Dicts with keys ``nct_id`` and ``distance`` (as returned by the
        driver), ordered by descending score.
    """
    diseases_string = ", ".join(diseases)
    disease_embedding = get_embedding(diseases_string, encoder)
    statement = text(
        """
        SELECT TOP 20 d.nct_id, VECTOR_COSINE(d.embedding, TO_VECTOR(:embedding, DOUBLE)) AS distance
        FROM Test.ClinicalTrials d
        ORDER BY distance DESC
        """
    )
    rows = _fetch_all(
        engine,
        statement,
        {"embedding": _embedding_to_sql_literal(disease_embedding)},
    )
    return [{"nct_id": row[0], "distance": row[1]} for row in rows]


def get_similarities_df(diseases: List[Dict[str, Any]]) -> pd.DataFrame:
    """Pivot pairwise similarity records into a square-style matrix.

    Args:
        diseases: Records with keys ``uri1``, ``uri2`` and ``distance``.

    Returns:
        A DataFrame indexed by ``uri1`` with one column per ``uri2`` holding
        the ``distance`` values. Every missing cell is filled with 1.0; when
        the input holds all pairs of distinct URIs these are exactly the
        diagonal cells. An empty input yields an empty DataFrame.
    """
    if not diseases:
        # pivot() would raise KeyError on a frame that has no columns.
        return pd.DataFrame()

    similarities = pd.DataFrame(diseases).pivot(
        index="uri1", columns="uri2", values="distance"
    )
    return similarities.fillna(1.0)


def filter_out_less_promising_diseases(
    info_dicts: List[Dict[str, Any]],
) -> tuple:
    """Keep the diseases whose mean similarity is not clearly below average.

    A disease (matrix column) is kept when its mean similarity is greater
    than ``mean - 0.2 * std``, where both statistics are taken over the
    per-column means.

    Args:
        info_dicts: Pairwise records accepted by :func:`get_similarities_df`.

    Returns:
        A ``(kept_identifiers, similarity_matrix)`` tuple: the list of column
        labels that passed the filter and the full similarity DataFrame.
    """
    df_diseases_similarities = get_similarities_df(info_dicts)
    # Compute the per-column means once and reuse them.
    column_means = df_diseases_similarities.mean()
    mean = column_means.mean()
    std = column_means.std()
    filtered_diseases = column_means[column_means > mean - 0.2 * std].index.tolist()
    return filtered_diseases, df_diseases_similarities


def get_labels_of_diseases_from_uris(uris: List[str]) -> List[str]:
    """Return the labels of the entities whose URI is in *uris*.

    Uses the module-level ``engine``. The order of the result is whatever the
    database returns and is not tied to the order of *uris*.

    Args:
        uris: Full entity URIs to look up.

    Returns:
        The matching labels; an empty input yields an empty list.
    """
    wanted = list(uris)
    if not wanted:
        # An empty IN list is not valid SQL.
        return []

    statement = text(
        """
        SELECT label
        FROM Test.EntityEmbeddings
        WHERE uri IN :uris
        """
    ).bindparams(bindparam("uris", expanding=True))
    rows = _fetch_all(engine, statement, {"uris": wanted})
    return [row[0] for row in rows]


def to_capitalized_case(string: str) -> str:
    """Turn an ``UPPER_SNAKE`` token into "Upper snake".

    Underscores become spaces. If the result is entirely upper-case, all
    characters after the first are lower-cased; otherwise the text is returned
    with only the underscore replacement applied.
    """
    string = string.replace("_", " ")
    if string.isupper():
        return string[0] + string[1:].lower()
    # Previously this path fell through and returned None.
    return string


def list_to_capitalized_case(strings: List[str]) -> str:
    """Apply :func:`to_capitalized_case` to each item and join with ", "."""
    return ", ".join(to_capitalized_case(s) for s in strings)


def render_trial_details(trial: dict) -> None:
    """Render one ClinicalTrials.gov study record with Streamlit.

    Shows, in order: title, description, status table, design table,
    interventions table and a link to the study page. A section whose data
    is missing is replaced by a short notice (or skipped for the title and
    link) instead of raising.

    Args:
        trial: Parsed study JSON; the sections are read from
            ``trial["protocolSection"]`` and ``trial["hasResults"]``.
    """
    protocol = trial.get("protocolSection", {})
    identification = protocol.get("identificationModule", {})

    # Title: prefer the official title, fall back to the brief one, and skip
    # the heading entirely when neither is present.
    title = identification.get("officialTitle") or identification.get("briefTitle")
    if title:
        st.write(f"##### {title}")

    try:
        st.write(protocol["descriptionModule"]["briefSummary"])
    except KeyError:
        try:
            st.write(protocol["descriptionModule"]["detailedDescription"])
        except KeyError:
            st.error("No description available.")

    st.write("###### Status")
    try:
        status_module = {
            "Status": to_capitalized_case(protocol["statusModule"]["overallStatus"]),
            "Status Date": protocol["statusModule"]["statusVerifiedDate"],
            "Has Results": trial["hasResults"],
        }
        st.table(status_module)
    except KeyError:
        st.info("No status information available.")

    st.write("###### Design")
    try:
        design = protocol["designModule"]
        design_info = design["designInfo"]
        masking_info = design_info["maskingInfo"]
        design_module = {
            "Study Type": to_capitalized_case(design["studyType"]),
            "Phases": list_to_capitalized_case(design["phases"]),
            "Allocation": to_capitalized_case(design_info["allocation"]),
            "Primary Purpose": to_capitalized_case(design_info["primaryPurpose"]),
            "Participants": design["enrollmentInfo"]["count"],
            "Masking": to_capitalized_case(masking_info["masking"]),
            "Who Masked": list_to_capitalized_case(masking_info["whoMasked"]),
        }
        st.table(design_module)
    except KeyError:
        st.info("No design information available.")

    st.write("###### Interventions")
    try:
        interventions_module = {
            intervention["name"]: intervention["description"]
            for intervention in protocol["armsInterventionsModule"]["interventions"]
        }
        st.table(interventions_module)
    except KeyError:
        st.info("No interventions information available.")

    # Link to the official page of the trial on ClinicalTrials.gov; omitted
    # when the record carries no NCT identifier.
    nct_id = identification.get("nctId")
    if nct_id:
        st.markdown(
            f"See more in [ClinicalTrials.gov](https://clinicaltrials.gov/study/{nct_id})"
        )


if __name__ == "__main__":
    # Manual smoke test; reuses the module-level engine.
    try:
        diseases = get_most_similar_diseases_from_uri(engine, "C1843013")
        for disease in diseases:
            print(disease)
    except Exception as e:
        print(e)

    try:
        print(get_uri_from_name(engine, "Alzheimer disease 3"))
    except Exception as e:
        print(e)

    clinical_record_info = get_clinical_records_by_ids(["NCT00841061"])
    print(clinical_record_info)

    textual_description = (
        "A disease that causes memory loss and other cognitive impairments."
    )
    encoder = SentenceTransformer("allenai-specter")
    diseases = get_diseases_related_to_a_textual_description(
        textual_description, encoder
    )
    for disease in diseases:
        print(disease)

    try:
        similarities = get_similarities_among_diseases_uris(
            [
                "http://identifiers.org/medgen/C4553765",
                "http://identifiers.org/medgen/C4553176",
                "http://identifiers.org/medgen/C4024935",
            ]
        )
        for similarity in similarities:
            # Each record is a dict whose URIs are already reduced to their
            # last path segment.
            print(
                f'{similarity["uri1"]} and {similarity["uri2"]} have a similarity of {similarity["distance"]}'
            )
    except Exception as e:
        print(e)

# %%