File size: 8,030 Bytes
93e1b64
 
 
 
 
27d40b9
 
 
 
 
 
 
 
 
 
 
52ee7a9
93e1b64
1f35211
 
93e1b64
 
9b5c4aa
 
 
 
 
27d40b9
9b5c4aa
27d40b9
 
 
 
 
93e1b64
 
 
27d40b9
ee05396
 
 
9b5c4aa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1e2e3b8
9b5c4aa
 
 
 
 
1e2e3b8
9b5c4aa
 
 
 
1e2e3b8
9b5c4aa
1e2e3b8
9b5c4aa
1e2e3b8
9b5c4aa
 
 
1e2e3b8
9b5c4aa
 
 
ee05396
1e2e3b8
 
52ee7a9
1e2e3b8
52ee7a9
9b5c4aa
1e2e3b8
9b5c4aa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
93e1b64
9b5c4aa
 
 
 
 
 
 
 
 
 
93e1b64
9b5c4aa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
93e1b64
9b5c4aa
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
import json
import os
import time
from urllib.parse import quote

import numpy as np
import pandas as pd
import streamlit as st
from sentence_transformers import SentenceTransformer
from sqlalchemy import create_engine, text
from streamlit_agraph import agraph, Node, Edge, Config

from llm_res import get_short_summary_out_of_json_files
from utils import (
    get_all_diseases_name,
    get_most_similar_diseases_from_uri,
    get_uri_from_name,
    get_diseases_related_to_a_textual_description,
    get_similarities_among_diseases_uris,
    augment_the_set_of_diseaces,
    get_clinical_trials_related_to_diseases,
    get_clinical_records_by_ids
)


# Stage flags: every stage starts hidden; the sections below switch them on
# one after another to reveal the next part of the page.
show_graph = show_analyze_status = show_overview = show_details = False

# IRIS connection
# Each setting is read from an environment variable and falls back to the
# default shown here when that variable is not set.
username = os.getenv("IRIS_USERNAME", "demo")
password = os.getenv("IRIS_PASSWORD", "demo")
hostname = os.getenv("IRIS_HOSTNAME", "localhost")
port = os.getenv("IRIS_PORT", "1972")
namespace = os.getenv("IRIS_NAMESPACE", "USER")
# Percent-encode the credentials so that characters such as "@", ":" or "/"
# in a configured value cannot be mistaken for URL delimiters.
CONNECTION_STRING = (
    f"iris://{quote(username, safe='')}:{quote(password, safe='')}"
    f"@{hostname}:{port}/{namespace}"
)
engine = create_engine(CONNECTION_STRING)


# Page header: banner image followed by the application title.
st.image("img_klinic.jpeg", caption="(AI-generated image)", use_column_width=True)
st.title("Klìnic", help="AI-powered clinical trial search engine")

# User input: a wide column for the description and a narrow one for the button.
with st.container():
    col1, col2 = st.columns((6, 1))

    description_input = col1.text_area(
        label="Enter the disease description 👇",
        placeholder='A disease that causes memory loss and other cognitive impairments.',
    )

    # Three empty text elements act as spacers to center the button vertically.
    for _ in range(3):
        col2.text('')
    show_analyze_status = col2.button("Analyze 🔎")


# analyze
@st.cache_resource(show_spinner=False)
def _load_encoder(model_name="allenai-specter"):
    """Return the sentence-embedding model, building it only once.

    Streamlit re-executes the script on every interaction, so constructing the
    model inline would reload it on each click of the "Analyze" button.
    """
    return SentenceTransformer(model_name)


with st.container():
    if show_analyze_status and not description_input.strip():
        # An empty description gives the pipeline nothing meaningful to embed,
        # so ask for input instead of running the analysis.
        st.warning("Please enter a disease description before analyzing.")
    elif show_analyze_status:
        with st.status("Analyzing...") as status:
            # 1. Embed the textual description that the user entered using the model
            # 2. Get 5 diseases with the highest cosine similarity from the DB
            status.write("Analyzing the description that you wrote...")
            encoder = _load_encoder()
            diseases_related_to_the_user_text = get_diseases_related_to_a_textual_description(
                description_input, encoder
            )
            # 3. Get the similarities of the embeddings of those diseases (cosine similarity of the embeddings of the nodes of such diseases)
            status.write("Getting the similarities among the diseases to filter out less promising ones...")
            diseases_uris = [disease["uri"] for disease in diseases_related_to_the_user_text]
            # NOTE(review): the return value is discarded; the filtering described
            # in step 4 is not implemented yet.
            get_similarities_among_diseases_uris(diseases_uris)
            # 4. Potentially filter out the diseases that are not similar enough (e.g. similarity < 0.8)
            # 5. Augment the set of diseases: add new diseases that are similar to the ones that are already in the set, until we get 10-15 diseases
            status.write("Augmenting the set of diseases by finding others with related embeddings...")
            augmented_set_of_diseases = augment_the_set_of_diseaces(diseases_uris)
            # 6. Query the embeddings of the diseases related to each clinical trial (also in the DB), to get the most similar clinical trials to our set of diseases
            status.write("Getting the clinical trials related to the diseases found...")
            clinical_trials_related_to_the_diseases = get_clinical_trials_related_to_diseases(
                augmented_set_of_diseases, encoder
            )
            status.write("Getting the details of the clinical trials...")
            json_of_clinical_trials = get_clinical_records_by_ids(
                [trial["nct_id"] for trial in clinical_trials_related_to_the_diseases]
            )
            status.json(json_of_clinical_trials, expanded=False)
            # 7. Use an LLM to get a summary of the clinical trials, in plain text format.
            status.write("Getting a summary of the clinical trials...")
            response = get_short_summary_out_of_json_files(json_of_clinical_trials)
            print(f'Response from LLM: {response}')
            status.write(f'Response from LLM: {response}')
            # 8. Use an LLM to extract numerical data from the clinical trials (e.g. number of patients, number of deaths, etc.). Get summary statistics out of that.
            status.write("Getting summary statistics of the clinical trials...")
            # 9. Show the results to the user: graph of the diseases chosen, summary of the clinical trials, summary statistics of the clinical trials, and list of the details of the clinical trials considered
            status.update(label="Done!", state="complete")
            # Brief pause so the completed status is visible before the next section appears.
            time.sleep(1)
            show_graph = True


# graph
with st.container():
    if show_graph:
        # TODO actual graph
        # Placeholder graph: ten nodes "A".."J" linked into a single chain.
        placeholder_ids = "ABCDEFGHIJ"
        placeholder_nodes = [
            Node(id=letter, label=f"Node {letter}", size=10)
            for letter in placeholder_ids
        ]
        # Pair every id with its successor: A->B, B->C, ..., I->J.
        placeholder_edges = [
            Edge(source=src, target=dst)
            for src, dst in zip(placeholder_ids, placeholder_ids[1:])
        ]
        graph_of_diseases = agraph(
            nodes=placeholder_nodes,
            edges=placeholder_edges,
            config=Config(height=500, width=500),
        )
        # Pause before revealing the overview section.
        time.sleep(2)
        show_overview = True


# overview
with st.container():
    if show_overview:
        disease_overview = ":red[lorem ipsum]"  # TODO
        # Write the heading and then the placeholder text as separate elements.
        for markdown_chunk in ("## Disease Overview", disease_overview):
            st.write(markdown_chunk)
        # Pause before revealing the details section.
        time.sleep(2)
        show_details = True


# details
with st.container():
    if show_details:
        st.write("## Clinical Trials Details")
        # TODO replace mock data
        # Read the file as UTF-8 explicitly so decoding does not depend on the
        # platform's default encoding.
        with open("mock_trial.json", encoding="utf-8") as f:
            mock_trial = json.load(f)
        # The same mock record is listed five times as a stand-in for real results.
        trials = [mock_trial] * 5

        for trial in trials:
            protocol = trial["protocolSection"]
            identification = protocol["identificationModule"]

            # One collapsible panel per trial, titled with its NCT identifier.
            with st.expander(f"{identification['nctId']}"):
                official_title = identification["officialTitle"]
                st.write(f"##### {official_title}")

                brief_summary = protocol["descriptionModule"]["briefSummary"]
                st.write(brief_summary)

                status_section = protocol["statusModule"]
                status_module = {
                    "Status": status_section["overallStatus"],
                    "Status Date": status_section["statusVerifiedDate"],
                }
                st.write("###### Status")
                st.table(status_module)

                design_section = protocol["designModule"]
                # "Phases" is left out on purpose: it is an array and breaks the
                # table formatting.
                design_module = {
                    "Study Type": design_section["studyType"],
                    "Allocation": design_section["designInfo"]["allocation"],
                    "Participants": design_section["enrollmentInfo"]["count"],
                }
                st.write("###### Design")
                st.table(design_module)

                # TODO more modules?