ACMCMC commited on
Commit
52ee7a9
1 Parent(s): 90c8ced

changes to gpt inference

Browse files
Files changed (2) hide show
  1. app.py +3 -2
  2. llm_res.py +83 -46
app.py CHANGED
@@ -14,7 +14,7 @@ from utils import (
14
  get_clinical_trials_related_to_diseases,
15
  get_clinical_records_by_ids
16
  )
17
- from llm_res import process_dictionaty_with_llm_to_generate_response
18
  import json
19
  import numpy as np
20
  from sentence_transformers import SentenceTransformer
@@ -81,8 +81,9 @@ with st.container():
81
  status.json(json_of_clinical_trials)
82
  # 7. Use an LLM to get a summary of the clinical trials, in plain text format.
83
  status.write("Getting a summary of the clinical trials...")
84
- response = process_dictionaty_with_llm_to_generate_response(json_of_clinical_trials)
85
  print(f'Response from LLM: {response}')
 
86
  # 8. Use an LLM to extract numerical data from the clinical trials (e.g. number of patients, number of deaths, etc.). Get summary statistics out of that.
87
  status.write("Getting summary statistics of the clinical trials...")
88
  # 9. Show the results to the user: graph of the diseases chosen, summary of the clinical trials, summary statistics of the clinical trials, and list of the details of the clinical trials considered
 
14
  get_clinical_trials_related_to_diseases,
15
  get_clinical_records_by_ids
16
  )
17
+ from llm_res import get_short_summary_out_of_json_files
18
  import json
19
  import numpy as np
20
  from sentence_transformers import SentenceTransformer
 
81
  status.json(json_of_clinical_trials)
82
  # 7. Use an LLM to get a summary of the clinical trials, in plain text format.
83
  status.write("Getting a summary of the clinical trials...")
84
+ response = get_short_summary_out_of_json_files(json_of_clinical_trials)
85
  print(f'Response from LLM: {response}')
86
+ status.write(f'Response from LLM: {response}')
87
  # 8. Use an LLM to extract numerical data from the clinical trials (e.g. number of patients, number of deaths, etc.). Get summary statistics out of that.
88
  status.write("Getting summary statistics of the clinical trials...")
89
  # 9. Show the results to the user: graph of the diseases chosen, summary of the clinical trials, summary statistics of the clinical trials, and list of the details of the clinical trials considered
llm_res.py CHANGED
@@ -1,27 +1,27 @@
 
1
  import json
2
- from langchain_community.document_loaders.csv_loader import CSVLoader
3
- from langchain.text_splitter import RecursiveCharacterTextSplitter
4
- import pandas as pd
5
- import langchain
6
  import os
 
 
 
7
  import openai
8
- import ast
 
 
9
  from langchain import OpenAI
 
10
  from langchain.chains.qa_with_sources.loading import load_qa_with_sources_chain
11
- from langchain.text_splitter import RecursiveCharacterTextSplitter
12
- from langchain_community.document_loaders import JSONLoader
13
  from langchain.document_loaders import UnstructuredURLLoader
14
  from langchain.embeddings import OpenAIEmbeddings
 
15
  from langchain.vectorstores import FAISS
 
 
16
  from langchain_core.prompts import ChatPromptTemplate
17
  from langchain_core.pydantic_v1 import BaseModel, Field
18
  from langchain_openai import ChatOpenAI
19
- from langchain_core.prompts import ChatPromptTemplate
20
- from langchain_core.pydantic_v1 import BaseModel, Field
21
- from langchain_openai import ChatOpenAI
22
- from typing import List, Dict, Any
23
- import requests
24
- from dotenv import load_dotenv
25
 
26
  load_dotenv()
27
 
@@ -78,17 +78,17 @@ def process_json_data_for_llm(data):
78
  except:
79
  status = ""
80
  try:
81
- brief_description = item["protocolSection"]["descriptionModule"][
82
  "briefSummary"
83
  ]
84
  except:
85
- brief_description = ""
86
  try:
87
- detailed_description = item["protocolSection"]["descriptionModule"][
88
  "detailedDescription"
89
  ]
90
  except:
91
- detailed_description = ""
92
  try:
93
  conditions = item["protocolSection"]["conditionsModule"]["conditions"]
94
  except:
@@ -123,8 +123,8 @@ def process_json_data_for_llm(data):
123
  "organization_name": organization_name,
124
  "project_title": project_title,
125
  "status": status,
126
- "brief_description": brief_description,
127
- "detailed_description": detailed_description,
128
  "keywords": keywords,
129
  "interventions": interventions,
130
  "primary_outcomes": primary_outcomes,
@@ -137,22 +137,56 @@ def process_json_data_for_llm(data):
137
  # print(ele)
138
 
139
 
140
- def llm_config():
141
- tagging_prompt = ChatPromptTemplate.from_template(
142
- """
143
- Extract the desired information from the following list of JSON clinical trials.
144
 
145
- Only extract the properties mentioned in the 'Classification' function.
 
146
 
147
- Passage:
148
- {input}
149
 
150
- """
 
 
 
 
 
 
 
 
 
 
 
 
 
151
  )
152
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
153
  class Classification(BaseModel):
154
  description: str = Field(
155
- description="text description grouping all the clinical trials using brief_description and detailed_description keys"
156
  )
157
  project_title: list = Field(
158
  description="Extract the project title of all the clinical trials"
@@ -160,9 +194,9 @@ def llm_config():
160
  status: list = Field(
161
  description="Extract the status of all the clinical trials"
162
  )
163
- keywords: list = Field(
164
- description="Extract the most relevant keywords regrouping all the clinical trials"
165
- )
166
  interventions: list = Field(
167
  description="describe the interventions for each clinical trial using title, name and description"
168
  )
@@ -170,17 +204,17 @@ def llm_config():
170
  description="get the primary outcomes of each clinical trial"
171
  )
172
  # secondary_outcomes: list= Field(description= "get the secondary outcomes of each clinical trial")
173
- eligibility: list = Field(
174
- description="get the eligibilityCriteria grouping all the clinical trials"
175
- )
176
  # healthy_volunteers: list= Field(description= "determine whether the clinical trial requires healthy volunteers")
177
- minimum_age: list = Field(
178
- description="get the minimum age from each experiment"
179
- )
180
- maximum_age: list = Field(
181
- description="get the maximum age from each experiment"
182
- )
183
- gender: list = Field(description="get the gender from each experiment")
184
 
185
  def get_dict(self):
186
  return {
@@ -205,9 +239,11 @@ def llm_config():
205
  openai_api_key=os.environ["OPENAI_API_KEY"],
206
  ).with_structured_output(Classification)
207
 
208
- tagging_chain = tagging_prompt | llm
209
 
210
- return tagging_chain
 
 
211
 
212
 
213
  # clinical_record_info = get_clinical_records_by_ids(['NCT00841061', 'NCT03035123', 'NCT02272751', 'NCT03035123', 'NCT03055377'])
@@ -216,9 +252,10 @@ def llm_config():
216
  # with open('data.json', 'w') as f:
217
  # json.dump(clinical_record_info, f, indent=4)
218
 
219
- tagging_chain = llm_config()
 
220
 
221
  def process_dictionaty_with_llm_to_generate_response(json_contents):
222
  processed_data = process_json_data_for_llm(json_contents)
223
- res = tagging_chain.invoke({"input": processed_data})
224
- return res
 
1
+ import ast
2
  import json
 
 
 
 
3
  import os
4
+ from typing import Any, Dict, List
5
+
6
+ import langchain
7
  import openai
8
+ import pandas as pd
9
+ import requests
10
+ from dotenv import load_dotenv
11
  from langchain import OpenAI
12
+ from langchain.chains.combine_documents.stuff import StuffDocumentsChain
13
  from langchain.chains.qa_with_sources.loading import load_qa_with_sources_chain
 
 
14
  from langchain.document_loaders import UnstructuredURLLoader
15
  from langchain.embeddings import OpenAIEmbeddings
16
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
17
  from langchain.vectorstores import FAISS
18
+ from langchain_community.document_loaders import JSONLoader
19
+ from langchain_community.document_loaders.csv_loader import CSVLoader
20
  from langchain_core.prompts import ChatPromptTemplate
21
  from langchain_core.pydantic_v1 import BaseModel, Field
22
  from langchain_openai import ChatOpenAI
23
+ from langchain.chains.llm import LLMChain
24
+ from langchain_core.prompts import PromptTemplate
 
 
 
 
25
 
26
  load_dotenv()
27
 
 
78
  except:
79
  status = ""
80
  try:
81
+ briefDescription = item["protocolSection"]["descriptionModule"][
82
  "briefSummary"
83
  ]
84
  except:
85
+ briefDescription = ""
86
  try:
87
+ detailedDescription = item["protocolSection"]["descriptionModule"][
88
  "detailedDescription"
89
  ]
90
  except:
91
+ detailedDescription = ""
92
  try:
93
  conditions = item["protocolSection"]["conditionsModule"]["conditions"]
94
  except:
 
123
  "organization_name": organization_name,
124
  "project_title": project_title,
125
  "status": status,
126
+ "briefDescription": briefDescription,
127
+ "detailedDescription": detailedDescription,
128
  "keywords": keywords,
129
  "interventions": interventions,
130
  "primary_outcomes": primary_outcomes,
 
137
  # print(ele)
138
 
139
 
140
+ def get_short_summary_out_of_json_files(data_json):
141
+ prompt_template = """ You are an expert clinician working on the analysis of reports of clinical trials.
 
 
142
 
143
+ # Task
144
+ You will be given a set of descriptions of clinical trials. Your job is to come up with a short summary (100-200 words) of the descriptions of the clinical trials. Your users are clinical researchers who are experts in medicine, so you should be technical and specific, including scientific terms. Always be faithful to the original information written in the reports.
145
 
146
+ To write your summary, you will need to read the following examples, labeled as "Report 1", "Report 2", and so on. Your answer should be a single paragraph (100-200 words) that summarizes the general content of all the reports.
 
147
 
148
+ {text}
149
+
150
+ General summary:"""
151
+
152
+ prompt = PromptTemplate.from_template(prompt_template)
153
+
154
+ llm = ChatOpenAI(
155
+ temperature=0.4, model_name="gpt-4-turbo", api_key=os.environ["OPENAI_API_KEY"]
156
+ )
157
+ llm_chain = LLMChain(llm=llm, prompt=prompt)
158
+
159
+ # Define StuffDocumentsChain
160
+ stuff_chain = StuffDocumentsChain(
161
+ llm_chain=llm_chain, document_variable_name="text"
162
  )
163
 
164
+ descriptions = [
165
+ (
166
+ x["detailedDescription"]
167
+ if "detailedDescription" in x and len(x["detailedDescription"]) > 0
168
+ else x["briefSummary"]
169
+ )
170
+ for x in data_json
171
+ if "detailedDescription" in x or "briefSummary" in x
172
+ ]
173
+
174
+ combined_descriptions = ""
175
+ for i, description in enumerate(descriptions):
176
+ combined_descriptions += f"Report {i+1}:\n{description}\n"
177
+
178
+ print(f"Combined descriptions: {combined_descriptions}")
179
+
180
+ result = stuff_chain.run(combined_descriptions)
181
+ print(f"Result: {result}")
182
+
183
+ return result
184
+
185
+
186
+ def taggingTemplate():
187
  class Classification(BaseModel):
188
  description: str = Field(
189
+ description="text description grouping all the clinical trials using briefDescription and detailedDescription keys"
190
  )
191
  project_title: list = Field(
192
  description="Extract the project title of all the clinical trials"
 
194
  status: list = Field(
195
  description="Extract the status of all the clinical trials"
196
  )
197
+ # keywords: list = Field(
198
+ # description="Extract the most relevant keywords regrouping all the clinical trials"
199
+ # )
200
  interventions: list = Field(
201
  description="describe the interventions for each clinical trial using title, name and description"
202
  )
 
204
  description="get the primary outcomes of each clinical trial"
205
  )
206
  # secondary_outcomes: list= Field(description= "get the secondary outcomes of each clinical trial")
207
+ # eligibility: list = Field(
208
+ # description="get the eligibilityCriteria grouping all the clinical trials"
209
+ # )
210
  # healthy_volunteers: list= Field(description= "determine whether the clinical trial requires healthy volunteers")
211
+ # minimum_age: list = Field(
212
+ # description="get the minimum age from each experiment"
213
+ # )
214
+ # maximum_age: list = Field(
215
+ # description="get the maximum age from each experiment"
216
+ # )
217
+ # gender: list = Field(description="get the gender from each experiment")
218
 
219
  def get_dict(self):
220
  return {
 
239
  openai_api_key=os.environ["OPENAI_API_KEY"],
240
  ).with_structured_output(Classification)
241
 
242
+ stuff_chain = StuffDocumentsChain(llm_chain=llm, document_variable_name="text")
243
 
244
+ # tagging_chain = prompt_template | llm
245
+
246
+ # return tagging_chain
247
 
248
 
249
  # clinical_record_info = get_clinical_records_by_ids(['NCT00841061', 'NCT03035123', 'NCT02272751', 'NCT03035123', 'NCT03055377'])
 
252
  # with open('data.json', 'w') as f:
253
  # json.dump(clinical_record_info, f, indent=4)
254
 
255
+ # tagging_chain = llm_config()
256
+
257
 
258
  def process_dictionaty_with_llm_to_generate_response(json_contents):
259
  processed_data = process_json_data_for_llm(json_contents)
260
+ # res = tagging_chain.invoke({"input": processed_data})
261
+ # return res