ACMCMC committed on
Commit
1e2e3b8
1 Parent(s): dda0120
Files changed (4) hide show
  1. app.py +13 -10
  2. clinical_trials_embeddings.ipynb +3 -3
  3. llm_res.py +105 -78
  4. requirements.txt +1 -0
app.py CHANGED
@@ -14,6 +14,7 @@ from utils import (
14
  get_clinical_trials_related_to_diseases,
15
  get_clinical_records_by_ids
16
  )
 
17
  import json
18
  import numpy as np
19
  from sentence_transformers import SentenceTransformer
@@ -35,10 +36,6 @@ CONNECTION_STRING = f"iris://{username}:{password}@{hostname}:{port}/{namespace}
35
  engine = create_engine(CONNECTION_STRING)
36
 
37
 
38
- st.title("Klìnic")
39
- st.header("", divider='rainbow')
40
- st.text('') # dummy to add spacing
41
-
42
  with st.container(): # user input
43
  col1, col2 = st.columns((6, 1))
44
 
@@ -58,30 +55,36 @@ with st.container():
58
  with st.status("Analyzing...") as status:
59
  # 1. Embed the textual description that the user entered using the model
60
  # 2. Get 5 diseases with the highest cosine similarity from the DB
 
61
  encoder = SentenceTransformer("allenai-specter")
62
  diseases_related_to_the_user_text = get_diseases_related_to_a_textual_description(
63
  description_input, encoder
64
  )
65
- # for disease_label in diseases_related_to_the_user_text:
66
- # st.text(disease_label)
67
  # 3. Get the similarities of the embeddings of those diseases (cosine similarity of the embeddings of the nodes of such diseases)
 
68
  diseases_uris = [disease["uri"] for disease in diseases_related_to_the_user_text]
69
  get_similarities_among_diseases_uris(diseases_uris)
70
- #print(diseases_related_to_the_user_text)
71
  # 4. Potentially filter out the diseases that are not similar enough (e.g. similarity < 0.8)
72
  # 5. Augment the set of diseases: add new diseases that are similar to the ones that are already in the set, until we get 10-15 diseases
 
73
  augmented_set_of_diseases = augment_the_set_of_diseaces(diseases_uris)
74
- #print(augmented_set_of_diseases)
75
  # 6. Query the embeddings of the diseases related to each clinical trial (also in the DB), to get the most similar clinical trials to our set of diseases
 
76
  clinical_trials_related_to_the_diseases = get_clinical_trials_related_to_diseases(
77
  augmented_set_of_diseases, encoder
78
  )
79
- #print(f'clinical_trials_related_to_the_diseases: {clinical_trials_related_to_the_diseases}')
80
  json_of_clinical_trials = get_clinical_records_by_ids(
81
  [trial["nct_id"] for trial in clinical_trials_related_to_the_diseases]
82
  )
83
- #print(f'json_of_clinical_trials: {json_of_clinical_trials}')
 
 
 
 
84
  # 8. Use an LLM to extract numerical data from the clinical trials (e.g. number of patients, number of deaths, etc.). Get summary statistics out of that.
 
85
  # 9. Show the results to the user: graph of the diseases chosen, summary of the clinical trials, summary statistics of the clinical trials, and list of the details of the clinical trials considered
86
  status.update(label="Done!", state="complete")
87
  time.sleep(1)
 
14
  get_clinical_trials_related_to_diseases,
15
  get_clinical_records_by_ids
16
  )
17
+ from llm_res import process_dictionaty_with_llm_to_generate_response
18
  import json
19
  import numpy as np
20
  from sentence_transformers import SentenceTransformer
 
36
  engine = create_engine(CONNECTION_STRING)
37
 
38
 
 
 
 
 
39
  with st.container(): # user input
40
  col1, col2 = st.columns((6, 1))
41
 
 
55
  with st.status("Analyzing...") as status:
56
  # 1. Embed the textual description that the user entered using the model
57
  # 2. Get 5 diseases with the highest cosine similarity from the DB
58
+ status.write("Analyzing the description that you wrote...")
59
  encoder = SentenceTransformer("allenai-specter")
60
  diseases_related_to_the_user_text = get_diseases_related_to_a_textual_description(
61
  description_input, encoder
62
  )
 
 
63
  # 3. Get the similarities of the embeddings of those diseases (cosine similarity of the embeddings of the nodes of such diseases)
64
+ status.write("Getting the similarities among the diseases to filter out less promising ones...")
65
  diseases_uris = [disease["uri"] for disease in diseases_related_to_the_user_text]
66
  get_similarities_among_diseases_uris(diseases_uris)
 
67
  # 4. Potentially filter out the diseases that are not similar enough (e.g. similarity < 0.8)
68
  # 5. Augment the set of diseases: add new diseases that are similar to the ones that are already in the set, until we get 10-15 diseases
69
+ status.write("Augmenting the set of diseases by finding others with related embeddings...")
70
  augmented_set_of_diseases = augment_the_set_of_diseaces(diseases_uris)
71
+ # print(augmented_set_of_diseases)
72
  # 6. Query the embeddings of the diseases related to each clinical trial (also in the DB), to get the most similar clinical trials to our set of diseases
73
+ status.write("Getting the clinical trials related to the diseases found...")
74
  clinical_trials_related_to_the_diseases = get_clinical_trials_related_to_diseases(
75
  augmented_set_of_diseases, encoder
76
  )
77
+ status.write("Getting the details of the clinical trials...")
78
  json_of_clinical_trials = get_clinical_records_by_ids(
79
  [trial["nct_id"] for trial in clinical_trials_related_to_the_diseases]
80
  )
81
+ status.json(json_of_clinical_trials)
82
+ # 7. Use an LLM to get a summary of the clinical trials, in plain text format.
83
+ status.write("Getting a summary of the clinical trials...")
84
+ response = process_dictionaty_with_llm_to_generate_response(json_of_clinical_trials)
85
+ print(f'Response from LLM: {response}')
86
  # 8. Use an LLM to extract numerical data from the clinical trials (e.g. number of patients, number of deaths, etc.). Get summary statistics out of that.
87
+ status.write("Getting summary statistics of the clinical trials...")
88
  # 9. Show the results to the user: graph of the diseases chosen, summary of the clinical trials, summary statistics of the clinical trials, and list of the details of the clinical trials considered
89
  status.update(label="Done!", state="complete")
90
  time.sleep(1)
clinical_trials_embeddings.ipynb CHANGED
@@ -61,9 +61,9 @@
61
  "metadata": {},
62
  "outputs": [],
63
  "source": [
64
- "os.environ[\"OPENAI_API_KEY\"] = (\n",
65
- " \"sk-proj-CG2E98bSWs53X2eWO0Z4T3BlbkFJLm7H1vfkbua0zP548CKQ\"\n",
66
- ")"
67
  ]
68
  },
69
  {
 
61
  "metadata": {},
62
  "outputs": [],
63
  "source": [
64
+ "from dotenv import load_dotenv\n",
65
+ "\n",
66
+ "load_dotenv()"
67
  ]
68
  },
69
  {
llm_res.py CHANGED
@@ -21,6 +21,10 @@ from langchain_core.pydantic_v1 import BaseModel, Field
21
  from langchain_openai import ChatOpenAI
22
  from typing import List, Dict, Any
23
  import requests
 
 
 
 
24
 
25
  # getting the json files
26
  def get_clinical_record_info(clinical_record_id: str) -> Dict[str, Any]:
@@ -31,6 +35,7 @@ def get_clinical_record_info(clinical_record_id: str) -> Dict[str, Any]:
31
  response = requests.get(request_url, headers={"accept": "application/json"})
32
  return response.json()
33
 
 
34
  def get_clinical_records_by_ids(clinical_record_ids: List[str]) -> List[Dict[str, Any]]:
35
  clinical_records = []
36
  for clinical_record_id in clinical_record_ids:
@@ -38,80 +43,99 @@ def get_clinical_records_by_ids(clinical_record_ids: List[str]) -> List[Dict[str
38
  clinical_records.append(clinical_record_info)
39
  return clinical_records
40
 
41
- def process_json(json_file):
42
- # processing the files and getting the info needed
43
- # Open the JSON file for reading
44
- with open(json_file, 'r') as f:
45
- data = json.load(f) # Parse JSON data into a Python dictionary
46
 
47
  # Define the fields you want to keep
48
- fields_to_keep = ['class_of_organization', 'title', 'overallStatus', 'descriptionModule', 'conditions', 'interventions', 'outcomesModule', 'eligibilityModule']
 
 
 
 
 
 
 
 
 
49
 
50
  # Iterate through the dictionary and keep only the desired fields
51
  filtered_data = []
52
  for item in data:
53
  try:
54
- organization_name= item['protocolSection']['identificationModule']['organization']['fullName']
 
 
55
  except:
56
- organization_name= ""
57
  try:
58
- project_title= item['protocolSection']['identificationModule']['officialTitle']
 
 
59
  except:
60
- project_title= ""
61
- try:
62
- status= item['protocolSection']['statusModule']['overallStatus']
63
  except:
64
- status= ""
65
  try:
66
- brief_description= item['protocolSection']['descriptionModule']['briefSummary']
 
 
67
  except:
68
- brief_description= ""
69
  try:
70
- detailed_description= item['protocolSection']['descriptionModule']['detailedDescription']
 
 
71
  except:
72
- detailed_description= ""
73
  try:
74
- conditions= item['protocolSection']['conditionsModule']['conditions']
75
  except:
76
- conditions= []
77
  try:
78
- keywords= item['protocolSection']['conditionsModule']['keywords']
79
  except:
80
- keywords= []
81
  try:
82
- interventions= item['protocolSection']['armsInterventionsModule']['interventions']
 
 
83
  except:
84
- interventions= []
85
  try:
86
- primary_outcomes= item['protocolSection']['outcomesModule']['primaryOutcomes']
 
 
87
  except:
88
- primary_outcomes= []
89
  try:
90
- secondary_outcomes= item['protocolSection']['outcomesModule']['secondaryOutcomes']
 
 
91
  except:
92
- secondary_outcomes= []
93
  try:
94
- eligibility= item['protocolSection']['eligibilityModule']
95
  except:
96
- eligibility= {}
97
- filtered_item = {"organization_name": organization_name,
98
- "project_title": project_title,
99
- "status": status,
100
- "brief_description": brief_description,
101
- "detailed_description": detailed_description,
102
- "keywords":keywords,
103
- "interventions": interventions,
104
- "primary_outcomes": primary_outcomes,
105
- "secondary_outcomes": secondary_outcomes,
106
- "eligibility": eligibility}
 
 
107
  filtered_data.append(filtered_item)
108
 
109
  # for ele in filtered_data:
110
  # print(ele)
111
 
112
- # Write the filtered data to a new JSON file
113
- with open('output.json', 'w') as f:
114
- json.dump(filtered_data, f, indent=4)
115
 
116
  def llm_config():
117
  tagging_prompt = ChatPromptTemplate.from_template(
@@ -127,20 +151,38 @@ def llm_config():
127
  )
128
 
129
  class Classification(BaseModel):
130
- description: str = Field(description= "text description grouping all the clinical trials using brief_description and detailed_description keys")
131
- project_title: list = Field(description="Extract the project title of all the clinical trials")
132
- status: list= Field(description="Extract the status of all the clinical trials")
133
- keywords: list= Field(description="Extract the most relevant keywords regrouping all the clinical trials")
134
- interventions: list= Field(description="describe the interventions for each clinical trial using title, name and description")
135
- primary_outcomes: list= Field(description= "get the primary outcomes of each clinical trial")
 
 
 
 
 
 
 
 
 
 
 
 
136
  # secondary_outcomes: list= Field(description= "get the secondary outcomes of each clinical trial")
137
- eligibility: list= Field(description= "get the eligibilityCriteria grouping all the clinical trials")
 
 
138
  # healthy_volunteers: list= Field(description= "determine whether the clinical trial requires healthy volunteers")
139
- minimum_age: list = Field(description="get the minimum age from each experiment")
140
- maximum_age: list = Field(description="get the maximum age from each experiment")
 
 
 
 
141
  gender: list = Field(description="get the gender from each experiment")
142
 
143
- def get_dict(self):
144
  return {
145
  "summary": self.description,
146
  "project_title": self.project_title,
@@ -153,45 +195,30 @@ def llm_config():
153
  # "healthy_volunteers": self.healthy_volunteers,
154
  "minimum_age": self.minimum_age,
155
  "maximum_age": self.maximum_age,
156
- "gender": self.gender
157
  }
158
-
159
  # LLM
160
  llm = ChatOpenAI(
161
- temperature=0.6,
162
  model="gpt-4",
163
- openai_api_key="sk-proj-CG2E98bSWs53X2eWO0Z4T3BlbkFJLm7H1vfkbua0zP548CKQ"
164
- ).with_structured_output(
165
- Classification
166
- )
167
 
168
  tagging_chain = tagging_prompt | llm
169
-
170
  return tagging_chain
171
 
172
- def get_llm_results(results):
173
- result_dict= results.get_dict()
174
- return result_dict
175
 
176
- def save_llm_results(results_json):
177
- with open('llm_results.json', 'w') as f:
178
- json.dump(results_json, f, indent=4)
179
-
180
  # clinical_record_info = get_clinical_records_by_ids(['NCT00841061', 'NCT03035123', 'NCT02272751', 'NCT03035123', 'NCT03055377'])
181
  # print(clinical_record_info)
182
 
183
  # with open('data.json', 'w') as f:
184
  # json.dump(clinical_record_info, f, indent=4)
185
 
186
- # change the json file here and run it to get the output
187
- json_file= "D:/HACKUPC/hupc/klinic/data.json"
188
- process_json(json_file)
189
-
190
- with open('output.json', 'r') as file:
191
- data = json.load(file)
192
 
193
- tagging_chain= llm_config()
194
- res= tagging_chain.invoke({"input": data})
195
- result_json= get_llm_results(res)
196
- save_llm_results(result_json)
197
- print(result_json)
 
21
  from langchain_openai import ChatOpenAI
22
  from typing import List, Dict, Any
23
  import requests
24
+ from dotenv import load_dotenv
25
+
26
+ load_dotenv()
27
+
28
 
29
  # getting the json files
30
  def get_clinical_record_info(clinical_record_id: str) -> Dict[str, Any]:
 
35
  response = requests.get(request_url, headers={"accept": "application/json"})
36
  return response.json()
37
 
38
+
39
  def get_clinical_records_by_ids(clinical_record_ids: List[str]) -> List[Dict[str, Any]]:
40
  clinical_records = []
41
  for clinical_record_id in clinical_record_ids:
 
43
  clinical_records.append(clinical_record_info)
44
  return clinical_records
45
 
46
+
47
+ def process_json_data_for_llm(data):
 
 
 
48
 
49
  # Define the fields you want to keep
50
+ fields_to_keep = [
51
+ "class_of_organization",
52
+ "title",
53
+ "overallStatus",
54
+ "descriptionModule",
55
+ "conditions",
56
+ "interventions",
57
+ "outcomesModule",
58
+ "eligibilityModule",
59
+ ]
60
 
61
  # Iterate through the dictionary and keep only the desired fields
62
  filtered_data = []
63
  for item in data:
64
  try:
65
+ organization_name = item["protocolSection"]["identificationModule"][
66
+ "organization"
67
+ ]["fullName"]
68
  except:
69
+ organization_name = ""
70
  try:
71
+ project_title = item["protocolSection"]["identificationModule"][
72
+ "officialTitle"
73
+ ]
74
  except:
75
+ project_title = ""
76
+ try:
77
+ status = item["protocolSection"]["statusModule"]["overallStatus"]
78
  except:
79
+ status = ""
80
  try:
81
+ brief_description = item["protocolSection"]["descriptionModule"][
82
+ "briefSummary"
83
+ ]
84
  except:
85
+ brief_description = ""
86
  try:
87
+ detailed_description = item["protocolSection"]["descriptionModule"][
88
+ "detailedDescription"
89
+ ]
90
  except:
91
+ detailed_description = ""
92
  try:
93
+ conditions = item["protocolSection"]["conditionsModule"]["conditions"]
94
  except:
95
+ conditions = []
96
  try:
97
+ keywords = item["protocolSection"]["conditionsModule"]["keywords"]
98
  except:
99
+ keywords = []
100
  try:
101
+ interventions = item["protocolSection"]["armsInterventionsModule"][
102
+ "interventions"
103
+ ]
104
  except:
105
+ interventions = []
106
  try:
107
+ primary_outcomes = item["protocolSection"]["outcomesModule"][
108
+ "primaryOutcomes"
109
+ ]
110
  except:
111
+ primary_outcomes = []
112
  try:
113
+ secondary_outcomes = item["protocolSection"]["outcomesModule"][
114
+ "secondaryOutcomes"
115
+ ]
116
  except:
117
+ secondary_outcomes = []
118
  try:
119
+ eligibility = item["protocolSection"]["eligibilityModule"]
120
  except:
121
+ eligibility = {}
122
+ filtered_item = {
123
+ "organization_name": organization_name,
124
+ "project_title": project_title,
125
+ "status": status,
126
+ "brief_description": brief_description,
127
+ "detailed_description": detailed_description,
128
+ "keywords": keywords,
129
+ "interventions": interventions,
130
+ "primary_outcomes": primary_outcomes,
131
+ "secondary_outcomes": secondary_outcomes,
132
+ "eligibility": eligibility,
133
+ }
134
  filtered_data.append(filtered_item)
135
 
136
  # for ele in filtered_data:
137
  # print(ele)
138
 
 
 
 
139
 
140
  def llm_config():
141
  tagging_prompt = ChatPromptTemplate.from_template(
 
151
  )
152
 
153
  class Classification(BaseModel):
154
+ description: str = Field(
155
+ description="text description grouping all the clinical trials using brief_description and detailed_description keys"
156
+ )
157
+ project_title: list = Field(
158
+ description="Extract the project title of all the clinical trials"
159
+ )
160
+ status: list = Field(
161
+ description="Extract the status of all the clinical trials"
162
+ )
163
+ keywords: list = Field(
164
+ description="Extract the most relevant keywords regrouping all the clinical trials"
165
+ )
166
+ interventions: list = Field(
167
+ description="describe the interventions for each clinical trial using title, name and description"
168
+ )
169
+ primary_outcomes: list = Field(
170
+ description="get the primary outcomes of each clinical trial"
171
+ )
172
  # secondary_outcomes: list= Field(description= "get the secondary outcomes of each clinical trial")
173
+ eligibility: list = Field(
174
+ description="get the eligibilityCriteria grouping all the clinical trials"
175
+ )
176
  # healthy_volunteers: list= Field(description= "determine whether the clinical trial requires healthy volunteers")
177
+ minimum_age: list = Field(
178
+ description="get the minimum age from each experiment"
179
+ )
180
+ maximum_age: list = Field(
181
+ description="get the maximum age from each experiment"
182
+ )
183
  gender: list = Field(description="get the gender from each experiment")
184
 
185
+ def get_dict(self):
186
  return {
187
  "summary": self.description,
188
  "project_title": self.project_title,
 
195
  # "healthy_volunteers": self.healthy_volunteers,
196
  "minimum_age": self.minimum_age,
197
  "maximum_age": self.maximum_age,
198
+ "gender": self.gender,
199
  }
200
+
201
  # LLM
202
  llm = ChatOpenAI(
203
+ temperature=0.6,
204
  model="gpt-4",
205
+ openai_api_key=os.environ["OPENAI_API_KEY"],
206
+ ).with_structured_output(Classification)
 
 
207
 
208
  tagging_chain = tagging_prompt | llm
209
+
210
  return tagging_chain
211
 
 
 
 
212
 
 
 
 
 
213
  # clinical_record_info = get_clinical_records_by_ids(['NCT00841061', 'NCT03035123', 'NCT02272751', 'NCT03035123', 'NCT03055377'])
214
  # print(clinical_record_info)
215
 
216
  # with open('data.json', 'w') as f:
217
  # json.dump(clinical_record_info, f, indent=4)
218
 
219
+ tagging_chain = llm_config()
 
 
 
 
 
220
 
221
+ def process_dictionaty_with_llm_to_generate_response(json_contents):
222
+ processed_data = process_json_data_for_llm(json_contents)
223
+ res = tagging_chain.invoke({"input": processed_data})
224
+ return res
 
requirements.txt CHANGED
@@ -10,3 +10,4 @@ openai==1.25.1
10
  sentence_transformers==2.7.0
11
  streamlit-agraph
12
  streamlit==1.34.0
 
 
10
  sentence_transformers==2.7.0
11
  streamlit-agraph
12
  streamlit==1.34.0
13
+ langchain-openai==0.1.6