# %%
"""Query helpers for disease embeddings stored in InterSystems IRIS and for
ClinicalTrials.gov study records.

Connection settings come from the environment (``IRIS_USERNAME``,
``IRIS_PASSWORD``, ``IRIS_HOSTNAME``, ``IRIS_PORT``, ``IRIS_NAMESPACE``) and
default to the local demo instance.
"""
import os
from typing import Any, Dict, List, Optional
from urllib.parse import quote, quote_plus

import requests
from sentence_transformers import SentenceTransformer
from sqlalchemy import bindparam, create_engine, text

username = os.getenv("IRIS_USERNAME", "demo")
password = os.getenv("IRIS_PASSWORD", "demo")
hostname = os.getenv("IRIS_HOSTNAME", "localhost")
port = os.getenv("IRIS_PORT", "1972")
namespace = os.getenv("IRIS_NAMESPACE", "USER")
# Credentials are URL-escaped so characters such as '@' or '/' in a password
# cannot corrupt the connection URL.
CONNECTION_STRING = (
    f"iris://{quote_plus(username)}:{quote_plus(password)}"
    f"@{hostname}:{port}/{namespace}"
)
engine = create_engine(CONNECTION_STRING)

# Entities in Test.EntityEmbeddings are keyed by full MedGen URIs such as
# "http://identifiers.org/medgen/C1843013".
MEDGEN_URI_PREFIX = "http://identifiers.org/medgen/"
CLINICAL_TRIALS_API_URL = "https://clinicaltrials.gov/api/v2/studies"


def _fetch_all(db_engine, statement, params: Optional[Dict[str, Any]] = None) -> list:
    """Execute *statement* with bound *params* in a transaction; return all rows."""
    with db_engine.connect() as conn:
        with conn.begin():
            return conn.execute(statement, params or {}).fetchall()


def _to_medgen_uri(disease: str) -> str:
    """Return *disease* as a full URI, prefixing bare MedGen ids like "C1843013"."""
    return disease if "://" in disease else MEDGEN_URI_PREFIX + disease


def _uri_suffix(uri: str) -> str:
    """Return the last path segment of *uri* (the bare MedGen id)."""
    return uri.split("/")[-1]


def get_all_diseases_name(engine) -> List[str]:
    """Return every entity label in Test.EntityEmbeddings, skipping "nan" labels.

    Only the label column is selected, so embedding vectors are not transferred.
    """
    rows = _fetch_all(engine, text("SELECT label FROM Test.EntityEmbeddings"))
    return [row[0] for row in rows if row[0] != "nan"]


def get_uri_from_name(engine, name: str) -> str:
    """Return the bare MedGen id of the entity whose label equals *name*.

    *name* is sent as a bound parameter, so labels containing quotes
    (e.g. "Alzheimer's disease") are matched correctly.

    Raises:
        IndexError: if no entity has that label.
    """
    rows = _fetch_all(
        engine,
        text("SELECT uri FROM Test.EntityEmbeddings WHERE label = :name"),
        {"name": name},
    )
    if not rows:
        raise IndexError(f"no entity labelled {name!r} in Test.EntityEmbeddings")
    return _uri_suffix(rows[0][0])


def get_most_similar_diseases_from_uri(
    engine, original_disease_uri: str, threshold: float = 0.8, top_k: int = 10
) -> List[tuple[str, str, float]]:
    """Return the entities most similar to *original_disease_uri*.

    Args:
        engine: SQLAlchemy engine connected to IRIS.
        original_disease_uri: bare MedGen id (e.g. "C1843013"); a full URI is
            also accepted.
        threshold: minimum VECTOR_COSINE similarity (exclusive).
        top_k: maximum number of rows requested from the database.

    Returns:
        ``(bare_id, label, similarity)`` tuples, most similar first. Rows whose
        label is "nan" are dropped after the TOP limit is applied, so fewer than
        *top_k* tuples may be returned.
    """
    sql = text(
        f"""
        SELECT TOP {int(top_k)} e1.uri AS uri1, e2.uri AS uri2,
            e1.label AS label1, e2.label AS label2,
            VECTOR_COSINE(e1.embedding, e2.embedding) AS distance
        FROM Test.EntityEmbeddings e1, Test.EntityEmbeddings e2
        WHERE e1.uri = :source_uri
            AND VECTOR_COSINE(e1.embedding, e2.embedding) > :threshold
            AND e1.uri != e2.uri
        ORDER BY distance DESC
        """
    )
    rows = _fetch_all(
        engine,
        sql,
        {"source_uri": _to_medgen_uri(original_disease_uri), "threshold": float(threshold)},
    )
    # Column 1 is uri2, column 3 is label2, column 4 is the similarity.
    return [(_uri_suffix(row[1]), row[3], row[4]) for row in rows if row[3] != "nan"]


def get_clinical_record_info(clinical_record_id: str, timeout: float = 30.0) -> Dict[str, Any]:
    """Fetch one study from the ClinicalTrials.gov v2 API as parsed JSON.

    Equivalent request:
        curl -X GET "https://clinicaltrials.gov/api/v2/studies/NCT00841061" -H "accept: application/json"

    Args:
        clinical_record_id: NCT identifier, e.g. "NCT00841061".
        timeout: seconds to wait for the server before giving up.

    Raises:
        requests.HTTPError: if the API answers with a non-2xx status.
        requests.Timeout: if no response arrives within *timeout*.
    """
    # Percent-encode the id so it can only ever address a single study path.
    request_url = f"{CLINICAL_TRIALS_API_URL}/{quote(clinical_record_id, safe='')}"
    response = requests.get(request_url, headers={"accept": "application/json"}, timeout=timeout)
    response.raise_for_status()
    return response.json()


def get_clinical_records_by_ids(clinical_record_ids: List[str]) -> List[Dict[str, Any]]:
    """Fetch every study in *clinical_record_ids*, preserving input order."""
    return [get_clinical_record_info(record_id) for record_id in clinical_record_ids]


def get_similarities_among_diseases_uris(
    uri_list: List[str],
) -> List[tuple[str, str, float]]:
    """Return pairwise cosine similarities among the given full entity URIs.

    Each row is ``(uri1, uri2, similarity)``. Both orderings of every pair are
    included, and self-pairs are excluded. Uses the module-level ``engine``.
    An empty *uri_list* returns ``[]`` without querying.
    """
    if not uri_list:
        return []
    sql = text(
        """
        SELECT e1.uri AS uri1, e2.uri AS uri2,
            VECTOR_COSINE(e1.embedding, e2.embedding) AS distance
        FROM Test.EntityEmbeddings e1, Test.EntityEmbeddings e2
        WHERE e1.uri IN :left_uris
            AND e2.uri IN :right_uris
            AND e1.uri != e2.uri
        """
    ).bindparams(
        bindparam("left_uris", expanding=True),
        bindparam("right_uris", expanding=True),
    )
    uris = list(uri_list)
    return _fetch_all(engine, sql, {"left_uris": uris, "right_uris": uris})


def augment_the_set_of_diseaces(diseases: List[str], target_size: int = 15) -> List[str]:
    """Greedily grow *diseases* until it holds *target_size* entries.

    Each round adds the entity (with a label other than "nan") whose summed
    cosine similarity to all current members is highest. This ranks the same
    way as the mean similarity. The new entity then becomes a member for the
    following rounds. Uses the module-level ``engine``.

    Args:
        diseases: bare MedGen ids or full URIs. The list is not modified.
        target_size: desired final length.

    Returns:
        A new list: the input entries followed by the bare ids of the added
        entities. It may be shorter than *target_size* if the candidates run
        out. An empty input is returned as an empty list.
    """
    augmented = list(diseases)
    if not augmented:
        return augmented
    # The table is keyed by full URIs, so every member is normalised before
    # querying. Otherwise freshly added ids would never join the centroid or the
    # exclusion list, and the same entity would be re-added each round.
    member_uris = [_to_medgen_uri(disease) for disease in augmented]
    sql = text(
        """
        SELECT TOP 1 e2.uri AS new_disease,
            SUM(VECTOR_COSINE(e1.embedding, e2.embedding)) AS score
        FROM Test.EntityEmbeddings e1, Test.EntityEmbeddings e2
        WHERE e1.uri IN :members
            AND e2.uri NOT IN :excluded
            AND e2.label != 'nan'
        GROUP BY e2.uri
        ORDER BY score DESC
        """
    ).bindparams(
        bindparam("members", expanding=True),
        bindparam("excluded", expanding=True),
    )
    with engine.connect() as conn:
        with conn.begin():
            while len(augmented) < target_size:
                row = conn.execute(
                    sql, {"members": member_uris, "excluded": member_uris}
                ).first()
                if row is None:
                    break
                member_uris.append(row[0])
                augmented.append(_uri_suffix(row[0]))
    return augmented


def get_embedding(string: str, encoder) -> Any:
    """Encode *string* with *encoder* (e.g. a SentenceTransformer), without a progress bar.

    The callers in this module call ``.tolist()`` on the result, so an
    array-like such as a numpy array is expected.
    """
    return encoder.encode(string, show_progress_bar=False)


def _top_matches_by_embedding(
    table: str, key_column: str, query_text: str, encoder, top_k: int
) -> list:
    """Return ``(key, similarity)`` rows of *table* ranked by cosine similarity to *query_text*.

    *table* and *key_column* are interpolated into the SQL text, so they must be
    trusted identifiers (constants in this module), never user input.
    """
    vector = get_embedding(query_text, encoder)
    # Comma-separated components without list brackets, e.g. "0.1, -0.2, ...".
    vector_literal = str(vector.tolist())[1:-1]
    sql = text(
        f"SELECT TOP {int(top_k)} d.{key_column}, "
        "VECTOR_COSINE(d.embedding, TO_VECTOR(:query_vector, DOUBLE)) AS distance "
        f"FROM {table} d ORDER BY distance DESC"
    )
    return _fetch_all(engine, sql, {"query_vector": vector_literal})


def get_diseases_related_to_a_textual_description(
    description: str, encoder, top_k: int = 5
) -> List[Dict[str, Any]]:
    """Return the *top_k* disease descriptions most similar to *description*.

    Each item is ``{"uri": ..., "distance": similarity}``, most similar first.
    """
    rows = _top_matches_by_embedding(
        "Test.DiseaseDescriptions", "uri", description, encoder, top_k
    )
    return [{"uri": row[0], "distance": row[1]} for row in rows]


def get_clinical_trials_related_to_diseases(
    diseases: List[str], encoder, top_k: int = 5
) -> List[Dict[str, Any]]:
    """Return the *top_k* clinical trials most similar to the joined disease names.

    Each item is ``{"nct_id": ..., "distance": similarity}``, most similar first.
    """
    rows = _top_matches_by_embedding(
        "Test.ClinicalTrials", "nct_id", ", ".join(diseases), encoder, top_k
    )
    return [{"nct_id": row[0], "distance": row[1]} for row in rows]


if __name__ == "__main__":
    # Demo run against the module-level `engine` configured above.
    try:
        for disease in get_most_similar_diseases_from_uri(engine, "C1843013"):
            print(disease)
    except Exception as e:
        print(e)

    try:
        print(get_uri_from_name(engine, "Alzheimer disease 3"))
    except Exception as e:
        print(e)

    clinical_record_info = get_clinical_records_by_ids(["NCT00841061"])
    print(clinical_record_info)

    textual_description = (
        "A disease that causes memory loss and other cognitive impairments."
    )
    encoder = SentenceTransformer("allenai-specter")
    for disease in get_diseases_related_to_a_textual_description(
        textual_description, encoder
    ):
        print(disease)

    try:
        similarities = get_similarities_among_diseases_uris(
            [
                "http://identifiers.org/medgen/C4553765",
                "http://identifiers.org/medgen/C4553176",
                "http://identifiers.org/medgen/C4024935",
            ]
        )
        for uri1, uri2, similarity in similarities:
            print(
                f"{_uri_suffix(uri1)} and {_uri_suffix(uri2)} have a similarity of {similarity}"
            )
    except Exception as e:
        print(e)
# %%