klinic / llm_res.py
ACMCMC
WIP
1e2e3b8
raw
history blame
No virus
7.87 kB
import json
from langchain_community.document_loaders.csv_loader import CSVLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
import pandas as pd
import langchain
import os
import openai
import ast
from langchain import OpenAI
from langchain.chains.qa_with_sources.loading import load_qa_with_sources_chain
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import JSONLoader
from langchain.document_loaders import UnstructuredURLLoader
from langchain.embeddings import OpenAIEmbeddings
from langchain.vectorstores import FAISS
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.pydantic_v1 import BaseModel, Field
from langchain_openai import ChatOpenAI
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.pydantic_v1 import BaseModel, Field
from langchain_openai import ChatOpenAI
from typing import List, Dict, Any
import requests
from dotenv import load_dotenv
load_dotenv()
# Fetching raw study records (JSON) from the ClinicalTrials.gov v2 API
def get_clinical_record_info(
    clinical_record_id: str, *, timeout: float = 30.0
) -> Dict[str, Any]:
    """Fetch a single study record from the ClinicalTrials.gov v2 API.

    Equivalent request:
        curl -X GET "https://clinicaltrials.gov/api/v2/studies/NCT00841061" -H "accept: application/json"

    Args:
        clinical_record_id: Study identifier placed in the URL path,
            e.g. "NCT00841061".
        timeout: Seconds to wait for the server before giving up. Without a
            timeout, ``requests`` can block indefinitely on a stalled connection.

    Returns:
        The decoded JSON body of the response.

    Raises:
        requests.HTTPError: If the API answers with a 4xx/5xx status
            (e.g. an unknown study id).
        requests.Timeout: If the server does not respond within ``timeout``.
    """
    request_url = f"https://clinicaltrials.gov/api/v2/studies/{clinical_record_id}"
    response = requests.get(
        request_url,
        headers={"accept": "application/json"},
        timeout=timeout,
    )
    # Fail loudly on an error status instead of handing an error body
    # (or an opaque JSON decode error) to callers.
    response.raise_for_status()
    return response.json()
def get_clinical_records_by_ids(clinical_record_ids: List[str]) -> List[Dict[str, Any]]:
    """Fetch several studies, issuing one ``get_clinical_record_info`` call per id.

    Results are returned in the same order as ``clinical_record_ids``; a
    duplicated id is fetched and returned once per occurrence. Any exception
    raised while fetching propagates immediately.
    """
    return [get_clinical_record_info(record_id) for record_id in clinical_record_ids]
def _get_nested(mapping, keys, default):
    """Return ``mapping[k0][k1]...`` for ``keys``, or ``default`` if the path is absent.

    A missing key or a non-subscriptable intermediate value (e.g. ``None`` or a
    string) falls back to ``default``; any other exception propagates.
    """
    current = mapping
    try:
        for key in keys:
            current = current[key]
    except (KeyError, IndexError, TypeError):
        return default
    return current


def process_json_data_for_llm(data):
    """Reduce raw ClinicalTrials.gov v2 study records to the fields sent to the LLM.

    Args:
        data: A list of study records (each expected to carry a
            ``protocolSection`` mapping), such as the output of
            ``get_clinical_records_by_ids``. A single study dict is accepted
            and treated as a one-element list; ``None`` yields ``[]``.

    Returns:
        A list with one flat dict per study. Its keys are
        ``organization_name``, ``project_title``, ``status``,
        ``brief_description``, ``detailed_description``, ``conditions``,
        ``keywords``, ``interventions``, ``primary_outcomes``,
        ``secondary_outcomes`` and ``eligibility``. A field missing from the
        record defaults to ``""`` (text), ``[]`` (lists) or ``{}``
        (eligibility).
    """
    if data is None:
        return []
    if isinstance(data, dict):
        # A single study (e.g. straight from get_clinical_record_info) rather
        # than a list; iterating it directly would walk its keys instead.
        data = [data]

    filtered_data = []
    for item in data:
        protocol = _get_nested(item, ("protocolSection",), {})
        filtered_data.append(
            {
                "organization_name": _get_nested(
                    protocol, ("identificationModule", "organization", "fullName"), ""
                ),
                "project_title": _get_nested(
                    protocol, ("identificationModule", "officialTitle"), ""
                ),
                "status": _get_nested(protocol, ("statusModule", "overallStatus"), ""),
                "brief_description": _get_nested(
                    protocol, ("descriptionModule", "briefSummary"), ""
                ),
                "detailed_description": _get_nested(
                    protocol, ("descriptionModule", "detailedDescription"), ""
                ),
                "conditions": _get_nested(
                    protocol, ("conditionsModule", "conditions"), []
                ),
                "keywords": _get_nested(protocol, ("conditionsModule", "keywords"), []),
                "interventions": _get_nested(
                    protocol, ("armsInterventionsModule", "interventions"), []
                ),
                "primary_outcomes": _get_nested(
                    protocol, ("outcomesModule", "primaryOutcomes"), []
                ),
                "secondary_outcomes": _get_nested(
                    protocol, ("outcomesModule", "secondaryOutcomes"), []
                ),
                "eligibility": _get_nested(protocol, ("eligibilityModule",), {}),
            }
        )
    return filtered_data
def llm_config(*, model="gpt-4", temperature=0.6, openai_api_key=None):
    """Build the runnable that extracts a structured summary from clinical trials.

    The result is ``tagging_prompt | llm``, where ``llm`` is an OpenAI chat
    model wrapped with ``with_structured_output(Classification)``. Invoke it
    with ``{"input": <trials text>}``. Per LangChain's structured-output
    support for pydantic schemas, the output is a ``Classification`` instance;
    call ``get_dict()`` on it for a plain dict.

    Args:
        model: OpenAI chat model name. Defaults to "gpt-4".
        temperature: Sampling temperature. Defaults to 0.6. NOTE(review): a
            lower value (e.g. 0) is usually preferable for extraction if
            reproducible output matters.
        openai_api_key: API key. When ``None``, it is read from the
            ``OPENAI_API_KEY`` environment variable, which raises ``KeyError``
            if that variable is unset.

    Returns:
        The composed prompt | structured-LLM runnable.
    """
    tagging_prompt = ChatPromptTemplate.from_template(
        """
Extract the desired information from the following list of JSON clinical trials.
Only extract the properties mentioned in the 'Classification' function.
Passage:
{input}
"""
    )

    # Output schema for the LLM. The field descriptions are sent to the model
    # as part of the schema, so edit them with care. No class docstring is
    # added here on purpose: LangChain would forward it as the schema
    # description and change the prompt. secondary_outcomes and
    # healthy_volunteers are currently not part of the schema.
    class Classification(BaseModel):
        description: str = Field(
            description="text description grouping all the clinical trials using brief_description and detailed_description keys"
        )
        project_title: list = Field(
            description="Extract the project title of all the clinical trials"
        )
        status: list = Field(
            description="Extract the status of all the clinical trials"
        )
        keywords: list = Field(
            description="Extract the most relevant keywords regrouping all the clinical trials"
        )
        interventions: list = Field(
            description="describe the interventions for each clinical trial using title, name and description"
        )
        primary_outcomes: list = Field(
            description="get the primary outcomes of each clinical trial"
        )
        eligibility: list = Field(
            description="get the eligibilityCriteria grouping all the clinical trials"
        )
        minimum_age: list = Field(
            description="get the minimum age from each experiment"
        )
        maximum_age: list = Field(
            description="get the maximum age from each experiment"
        )
        gender: list = Field(description="get the gender from each experiment")

        def get_dict(self):
            """Return the extracted fields as a plain dict (``description`` is exposed as ``summary``)."""
            return {
                "summary": self.description,
                "project_title": self.project_title,
                "status": self.status,
                "keywords": self.keywords,
                "interventions": self.interventions,
                "primary_outcomes": self.primary_outcomes,
                "eligibility": self.eligibility,
                "minimum_age": self.minimum_age,
                "maximum_age": self.maximum_age,
                "gender": self.gender,
            }

    api_key = (
        openai_api_key
        if openai_api_key is not None
        else os.environ["OPENAI_API_KEY"]
    )
    llm = ChatOpenAI(
        temperature=temperature,
        model=model,
        openai_api_key=api_key,
    ).with_structured_output(Classification)
    tagging_chain = tagging_prompt | llm
    return tagging_chain
# Example (disabled): fetch a few studies by id and dump them to data.json.
# NOTE(review): 'NCT03035123' appears twice in this list, so it would be
# fetched and written twice.
# clinical_record_info = get_clinical_records_by_ids(['NCT00841061', 'NCT03035123', 'NCT02272751', 'NCT03035123', 'NCT03055377'])
# print(clinical_record_info)
# with open('data.json', 'w') as f:
# json.dump(clinical_record_info, f, indent=4)
# Module-level chain used by process_dictionaty_with_llm_to_generate_response.
# NOTE: this runs at import time. Importing this module therefore constructs
# the ChatOpenAI model and needs OPENAI_API_KEY in the environment, because
# llm_config reads os.environ["OPENAI_API_KEY"].
tagging_chain = llm_config()
def process_dictionaty_with_llm_to_generate_response(json_contents):
    """Summarise raw clinical-trial records with the module-level LLM chain.

    (The misspelling "dictionaty" in the name is kept for existing callers.)

    Args:
        json_contents: Raw ClinicalTrials.gov study records, in the form
            accepted by ``process_json_data_for_llm``.

    Returns:
        The result of ``tagging_chain.invoke``, i.e. the structured output
        configured in ``llm_config``.
    """
    processed_data = process_json_data_for_llm(json_contents)
    # The prompt asks for "JSON clinical trials". Formatting a Python list
    # straight into the template would embed its repr (single quotes,
    # None/True) instead, so serialise it as real JSON first.
    payload = json.dumps(processed_data, ensure_ascii=False)
    return tagging_chain.invoke({"input": payload})