|
import os |
|
import re |
|
import json |
|
import torch |
|
|
|
import openai |
|
import logging |
|
import asyncio |
|
import aiohttp |
|
import pandas as pd |
|
import numpy as np |
|
import evaluate |
|
import qdrant_client |
|
from pypdf import PdfReader |
|
from pydantic import BaseModel, Field |
|
from typing import Any, List, Tuple, Set, Dict, Optional, Union |
|
from sklearn.metrics.pairwise import cosine_similarity |
|
|
|
from unstructured.partition.pdf import partition_pdf |
|
|
|
import llama_index |
|
from llama_index import PromptTemplate |
|
from llama_index.retrievers import VectorIndexRetriever, BaseRetriever, BM25Retriever |
|
from llama_index.query_engine import RetrieverQueryEngine |
|
from llama_index import get_response_synthesizer |
|
from llama_index.schema import NodeWithScore |
|
from llama_index.query_engine import RetrieverQueryEngine |
|
from llama_index import VectorStoreIndex, ServiceContext |
|
from llama_index.embeddings import OpenAIEmbedding |
|
from llama_index.llms import HuggingFaceLLM |
|
import requests |
|
from llama_index.llms import ( |
|
CustomLLM, |
|
CompletionResponse, |
|
CompletionResponseGen, |
|
LLMMetadata, |
|
) |
|
from llama_index.query_engine import RetrieverQueryEngine |
|
from llama_index.llms.base import llm_completion_callback |
|
from llama_index.vector_stores.qdrant import QdrantVectorStore |
|
from llama_index.storage.storage_context import StorageContext |
|
from llama_index.postprocessor import SentenceTransformerRerank, LLMRerank |
|
|
|
from tempfile import NamedTemporaryFile |
|
|
|
# Root logging configuration applied at import time: INFO and above, with a
# timestamp/logger-name/level prefix. (basicConfig is a no-op when the root
# logger already has handlers.)
_LOG_FORMAT = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)

# Module-level logger used by the classes in this file.
logger = logging.getLogger(__name__)
|
|
|
class ConfigManager:
    """
    Load named JSON configuration files and look up values in them.

    Attributes:
        configs (Dict[str, Any]): Parsed configurations, keyed by the name each
            one was loaded under.

    Methods:
        load_config(config_name: str, config_path: str): Parse a JSON file and
            store it under ``config_name``.
        get_config_value(config_name: str, key: str): Retrieve one value from a
            named configuration.
    """

    # Placeholder shipped in template config files; treated as "not set".
    _PLACEHOLDER = "ENTER_YOUR_TOKEN_HERE"

    def __init__(self):
        self.configs: Dict[str, Any] = {}

    def load_config(self, config_name: str, config_path: str) -> None:
        """
        Load configuration settings from a JSON file into a named configuration.

        Args:
            config_name (str): The name to store this configuration under.
                Loading again under the same name replaces the previous one.
            config_path (str): Path to the JSON configuration file (read as UTF-8).

        Raises:
            FileNotFoundError: If the config file is not found.
            json.JSONDecodeError: If the file is not valid JSON.
        """
        try:
            # JSON text is UTF-8; do not depend on the platform's default encoding.
            with open(config_path, 'r', encoding='utf-8') as f:
                self.configs[config_name] = json.load(f)
        except FileNotFoundError:
            logger.error(f"Config file not found at {config_path}")
            raise
        except json.JSONDecodeError as e:
            logger.error(f"Error decoding config file: {e}")
            raise

    def get_config_value(self, config_name: str, key: str) -> Any:
        """
        Retrieve a value from a previously loaded configuration.

        Args:
            config_name (str): Name the configuration was loaded under.
            key (str): Key of the setting inside that configuration.

        Returns:
            Any: The stored JSON value.

        Raises:
            ValueError: If the configuration or key is missing, the value is
                ``null``, or it still holds the template placeholder.
        """
        value = self.configs.get(config_name, {}).get(key)
        if value is None or value == self._PLACEHOLDER:
            raise ValueError(f"Please set your '{key}' in the config.json file.")
        return value
|
|
|
class base_utils:
    """
    Static helpers for reading files, parsing LLM answers and matching PDF
    filenames to rows of a labelled DataFrame.

    Attributes:
        None (This class contains only static methods and does not maintain any state)

    Methods:
        read_from_file(file_path: str) -> str:
            Reads a UTF-8 text file and returns its content.

        extract_id_from_filename(filename: str) -> Optional[int]:
            Extracts the numeric ID from a filename of the form 'Id_{I}.pdf'.

        extract_score_reasoning(text: str) -> Dict[str, Optional[str]]:
            Extracts the "Score:" and "Reasoning:" fields from an LLM answer.

        find_row_for_pdf(pdf_filename: str, dataframe: pd.DataFrame) -> Union[pd.DataFrame, str]:
            Returns the rows whose first column equals the ID in a PDF filename.
    """

    @staticmethod
    def read_from_file(file_path: str) -> str:
        """
        Read the content of a text file.

        Args:
            file_path (str): Path to the file (decoded as UTF-8).

        Returns:
            str: The full content of the file.
        """
        with open(file_path, 'r', encoding='utf-8') as prompt_file:
            return prompt_file.read()

    @staticmethod
    def extract_id_from_filename(filename: str) -> Optional[int]:
        """
        Extract the ID from a filename of the form 'Id_{I}.pdf'.

        Args:
            filename (str): Filename (or path) to parse.

        Returns:
            Optional[int]: The ID, or None if the pattern is not found.
        """
        # The dot is escaped so that only a literal ".pdf" matches; an unescaped
        # "." would accept e.g. "Id_1234pdf" and return the wrong ID (123).
        match = re.search(r'Id_(\d+)\.pdf', filename)
        return int(match.group(1)) if match else None

    @staticmethod
    def extract_score_reasoning(text: str) -> Dict[str, Optional[str]]:
        """
        Extract the score and the longest reasoning from an LLM answer.

        Args:
            text (str): Answer text containing "Score: <digits>" and
                "Reasoning: <text>" markers.

        Returns:
            dict: ``{"score": str | None, "reasoning": str | None}``. The score is
            the digits after the first "Score: " (as a string); the reasoning is
            stripped text after "Reasoning: ".
        """
        score_match = re.search(r"Score: (\d+)", text)

        # With DOTALL, ".+" also consumes newlines, so a match runs from the first
        # "Reasoning: " to the end of the text.
        reasoning_matches = re.findall(r"Reasoning: (\S.+)", text, re.DOTALL)
        longest_reasoning = max(reasoning_matches, key=len) if reasoning_matches else None

        return {
            "score": score_match.group(1) if score_match else None,
            "reasoning": longest_reasoning.strip() if longest_reasoning else None,
        }

    @staticmethod
    def find_row_for_pdf(pdf_filename: str, dataframe: pd.DataFrame) -> Union[pd.DataFrame, str]:
        """
        Find the rows of ``dataframe`` whose first column equals the PDF's ID.

        Args:
            pdf_filename (str): The filename of the PDF ('Id_{I}.pdf').
            dataframe (pandas.DataFrame): Table whose first column holds the IDs.

        Returns:
            pandas.DataFrame or str: The matching rows, or the message
            "No matching row found." / "Invalid file name.".
        """
        pdf_id = base_utils.extract_id_from_filename(pdf_filename)
        if pdf_id is None:
            return "Invalid file name."

        matched_rows = dataframe[dataframe.iloc[:, 0] == pdf_id]
        if matched_rows.empty:
            return "No matching row found."
        return matched_rows
|
|
|
|
|
class PDFProcessor_Unstructured:
    """
    Process PDF files with ``unstructured``: extract elements, sort them into
    text chunks and tables, and merge text chunks that were split mid-sentence.

    Attributes:
        file_path (Optional[str]): Full path of the PDF currently being processed.
        folder_path (Optional[str]): Directory containing ``file_path``; passed to
            ``partition_pdf`` as ``image_output_dir_path``.
        file_name (Optional[str]): Base name of ``file_path``.
        texts (List[str]): Text chunks (``CompositeElement``s) of the current PDF.
        tables (List[str]): Tables (``Table`` elements) of the current PDF.
        config (Dict[str, Any]): Options for ``partition_pdf``; see ``default_config``.

    Methods:
        default_config() -> Dict[str, Any]:
            Returns the default extraction options.
        extract_pdf_elements() -> List:
            Runs ``partition_pdf`` on ``file_path``.
        categorize_elements(raw_pdf_elements: List) -> None:
            Appends tables to ``tables`` and composite text elements to ``texts``.
        merge_chunks() -> List[str]:
            Drops caption/list-like chunks and joins chunks split mid-sentence.
        should_skip_chunk(chunk: str) -> bool:
            Whether a chunk should be dropped.
        should_merge_with_next(current_chunk: str, next_chunk: str) -> bool:
            Whether a chunk should be joined with the following one.
        extract_title_from_pdf(uploaded_file) -> str:
            Reads the title from the PDF metadata.
        process_pdf() -> Tuple[List[str], List[str]]:
            Extracts, categorizes and merges the elements of ``file_path``.
        process_pdf_file(uploaded_file) -> Tuple[List[str], List[str], str]:
            Processes a newly uploaded PDF and also returns its title.
    """

    def __init__(self, config: Optional[Dict[str, Any]] = None):
        """
        Args:
            config (Optional[Dict[str, Any]]): Extraction options; ``None`` uses
                ``default_config()``. Missing keys fall back to the defaults.
        """
        self.file_path = None
        self.folder_path = None
        self.file_name = None
        self.texts = []
        self.tables = []
        self.config = config if config is not None else self.default_config()
        logger.info("Initialized PdfProcessor_Unstructured.")

    @staticmethod
    def default_config() -> Dict[str, Any]:
        """
        Returns the default configuration for PDF processing.

        Returns:
            Dict[str, Any]: Default configuration options.
        """
        return {
            "extract_images": False,
            "infer_table_structure": True,
            "chunking_strategy": "by_title",
            "max_characters": 10000,
            "combine_text_under_n_chars": 100,
            "strategy": "fast",
            "model_name": "yolox"
        }

    def extract_pdf_elements(self) -> List:
        """
        Extract elements (text chunks, tables) from ``file_path`` with ``partition_pdf``.

        The options come from ``self.config``; any key it lacks uses the value from
        ``default_config()``.

        Returns:
            List: The elements returned by ``partition_pdf``.

        Raises:
            Exception: Whatever ``partition_pdf`` raises (logged, then re-raised).
        """
        logger.info("Starting extraction of PDF elements.")
        options = {**self.default_config(), **(self.config or {})}
        # NOTE(review): "model_name" is not forwarded to partition_pdf (it never was).
        try:
            extracted_elements = partition_pdf(
                filename=self.file_path,
                extract_images_in_pdf=options["extract_images"],
                infer_table_structure=options["infer_table_structure"],
                chunking_strategy=options["chunking_strategy"],
                strategy=options["strategy"],
                max_characters=options["max_characters"],
                combine_text_under_n_chars=options["combine_text_under_n_chars"],
                image_output_dir_path=self.folder_path,
            )
        except Exception as e:
            logger.error(f"Error extracting PDF elements: {e}", exc_info=True)
            raise
        logger.info("Extraction of PDF elements completed successfully.")
        return extracted_elements

    def categorize_elements(self, raw_pdf_elements: List) -> None:
        """
        Append table elements to ``tables`` and composite text elements to ``texts``.

        Elements are classified by their fully-qualified class name; other element
        types are ignored.

        Args:
            raw_pdf_elements (List): Elements returned by ``extract_pdf_elements``.
        """
        logger.debug("Starting categorization of PDF elements.")
        for element in raw_pdf_elements:
            element_type = str(type(element))
            if "unstructured.documents.elements.Table" in element_type:
                self.tables.append(str(element))
            elif "unstructured.documents.elements.CompositeElement" in element_type:
                self.texts.append(str(element))
        logger.debug("Categorization of PDF elements completed.")

    def merge_chunks(self) -> List[str]:
        """
        Filter and merge ``self.texts``.

        Chunks for which ``should_skip_chunk`` is true are dropped. A kept chunk is
        joined (with a space) to the chunk right after it when
        ``should_merge_with_next`` is true; that following chunk is then consumed
        and not emitted again.

        Returns:
            List[str]: The merged text chunks (empty if ``self.texts`` is empty).
        """
        logger.debug("Starting merging of text chunks.")
        texts = self.texts
        merged_chunks = []
        i = 0
        while i < len(texts):
            current_chunk = texts[i]
            if self.should_skip_chunk(current_chunk):
                i += 1
                continue

            has_next = i + 1 < len(texts)
            if has_next and self.should_merge_with_next(current_chunk, texts[i + 1]):
                merged_chunks.append(current_chunk + " " + texts[i + 1])
                i += 2  # the next chunk was consumed by the merge
            else:
                merged_chunks.append(current_chunk)
                i += 1

        logger.debug("Merging of text chunks completed.")
        return merged_chunks

    @staticmethod
    def should_skip_chunk(chunk: str) -> bool:
        """
        Decide whether a text chunk should be dropped.

        Args:
            chunk (str): The text chunk to be evaluated.

        Returns:
            bool: True for empty chunks, chunks starting with "figure"/"fig"/"table"
            (case-insensitive), chunks whose first character is not alphanumeric,
            and chunks starting with a number followed by a dot (e.g. "1.").
        """
        if not chunk:
            return True
        return bool(
            chunk.lower().startswith(("figure", "fig", "table"))
            or not chunk[0].isalnum()
            or re.match(r'^\d+\.', chunk)
        )

    @staticmethod
    def should_merge_with_next(current_chunk: str, next_chunk: str) -> bool:
        """
        Decide whether ``current_chunk`` continues into ``next_chunk``.

        Args:
            current_chunk (str): The current text chunk.
            next_chunk (str): The next text chunk.

        Returns:
            bool: True if ``current_chunk`` ends with a comma, or ends with a
            lowercase letter while ``next_chunk`` starts with one; False if either
            chunk is empty.
        """
        if not current_chunk or not next_chunk:
            return False
        return (current_chunk.endswith(",") or
                (current_chunk[-1].islower() and next_chunk[0].islower()))

    def extract_title_from_pdf(self, uploaded_file) -> str:
        """
        Extract the title from a PDF file's metadata using ``pypdf.PdfReader``.

        Parameters:
            uploaded_file: A path to the PDF file, or a file object opened in
                binary mode.

        Returns:
            str: The metadata title, or 'Title not found' if there is none.
        """
        meta = PdfReader(uploaded_file).metadata
        return meta.title if meta and meta.title else 'Title not found'

    def process_pdf(self) -> Tuple[List[str], List[str]]:
        """
        Extract, categorize and merge the elements of ``file_path``.

        ``texts`` and ``tables`` are reset first so that results from a previously
        processed PDF do not leak into this one.

        Returns:
            Tuple[List[str], List[str]]: The merged text chunks and the tables.
        """
        logger.info("Starting processing of the PDF.")
        self.texts = []
        self.tables = []
        try:
            raw_pdf_elements = self.extract_pdf_elements()
            self.categorize_elements(raw_pdf_elements)
            merged_chunks = self.merge_chunks()
            return merged_chunks, self.tables
        except Exception as e:
            logger.error(f"Error processing PDF: {e}", exc_info=True)
            raise

    def process_pdf_file(self, uploaded_file) -> Tuple[List[str], List[str], str]:
        """
        Process an uploaded PDF file.

        The previously processed file (if any, and if it is a different path) is
        deleted from disk; deletion failures are logged and ignored.

        Parameters:
            uploaded_file: Path (or object whose ``str()`` is the path) of the PDF.

        Returns:
            Tuple[List[str], List[str], str]: Merged text chunks, tables, and the
            title from the PDF metadata.
        """
        new_path = str(uploaded_file)

        # Never delete the file we are about to read: re-processing the same path
        # would otherwise remove the input before it is parsed.
        previous = self.file_path
        if (previous and os.path.abspath(previous) != os.path.abspath(new_path)
                and os.path.exists(previous)):
            try:
                os.remove(previous)
                logger.debug(f"Previous file {previous} deleted.")
            except OSError as e:
                logger.warning(f"Error deleting previous file: {e}", exc_info=True)

        self.file_path = new_path
        self.folder_path = os.path.dirname(self.file_path)
        self.file_name = os.path.basename(self.file_path)
        logger.info(f"Starting to process the PDF file: {self.file_path}")

        try:
            logger.debug(f"Processing PDF at {self.file_path}")
            results = self.process_pdf()
            title = self.extract_title_from_pdf(self.file_path)
            logger.info("PDF processing completed successfully.")
            return (*results, title)
        except Exception as e:
            logger.error(f"Error processing PDF file: {e}", exc_info=True)
            raise
|
|
|
|
|
class HybridRetriever(BaseRetriever):
    """
    A hybrid retriever that combines results from vector-based and BM25 retrieval methods.
    Inherits from BaseRetriever.

    Both retrievers receive the same query. The output lists the BM25 hits first,
    followed by vector hits whose ``node_id`` was not already returned by BM25.
    De-duplication therefore only works when both retrievers are built over the
    same node objects (same node_ids).

    Attributes:
        vector_retriever: Retriever over a vector index.
        bm25_retriever: BM25 keyword retriever.

    Methods:
        __init__(vector_retriever, bm25_retriever): Stores both retrievers.
        _retrieve(query, **kwargs): Queries both retrievers and merges the results.
        _combine_results(bm25_nodes, vector_nodes): Concatenates and de-duplicates by node_id.
    """

    def __init__(self, vector_retriever, bm25_retriever):
        super().__init__()
        self.vector_retriever = vector_retriever
        self.bm25_retriever = bm25_retriever
        logger.info("HybridRetriever initialized with vector and BM25 retrievers.")

    def _retrieve(self, query: Any, **kwargs) -> List:
        """
        Retrieve from both the BM25 and the vector retriever and merge the results.

        Args:
            query: The query, forwarded unchanged to both retrievers' ``retrieve``.
                NOTE(review): when invoked through ``BaseRetriever.retrieve`` this is
                presumably a ``QueryBundle`` rather than a plain ``str`` — confirm
                against the installed llama_index version.
            **kwargs: Extra keyword arguments forwarded to both ``retrieve`` calls.

        Returns:
            List: Unique nodes, BM25 hits first.

        Raises:
            Exception: Any error from either retriever (logged with traceback, then re-raised).
        """
        logger.info(f"Retrieving documents for query: {query}")
        try:
            bm25_nodes = self.bm25_retriever.retrieve(query, **kwargs)
            vector_nodes = self.vector_retriever.retrieve(query, **kwargs)
        except Exception as e:
            # logger.exception keeps the traceback, which a plain error message loses.
            logger.exception(f"Error in retrieval: {e}")
            raise

        combined_nodes = self._combine_results(bm25_nodes, vector_nodes)
        logger.info(f"Retrieved {len(combined_nodes)} unique nodes combining vector and BM25 retrievers.")
        return combined_nodes

    @staticmethod
    def _combine_results(bm25_nodes: List, vector_nodes: List) -> List:
        """
        Concatenate BM25 and vector results, keeping the first node seen per node_id.

        Args:
            bm25_nodes: Nodes retrieved by the BM25 retriever.
            vector_nodes: Nodes retrieved by the vector retriever.

        Returns:
            List: Unique nodes in first-seen order.
        """
        seen_ids: Set = set()
        combined_nodes = []
        for node in (*bm25_nodes, *vector_nodes):
            if node.node_id not in seen_ids:
                seen_ids.add(node.node_id)
                combined_nodes.append(node)
        return combined_nodes
|
|
|
|
|
|
|
class PDFQueryEngine:
    """
    Set up a hybrid (vector + BM25) retrieval query engine over a set of documents
    and score the documents against a list of evaluation queries.

    Attributes:
        documents (List[Any]): Documents to be indexed.
        llm (Any): Language model used for response synthesis.
        embed_model (Any): Embedding model used to build the vector index.
        qa_prompt_tmpl (Any): Template string for the QA prompt.
        base_utils (base_utils): Helper used to parse score/reasoning from answers.
        config_manager (ConfigManager): Loads the few-shot configuration.

    Methods:
        format_example(example): Formats a few-shot example into a string.
        setup_query_engine(): Builds and returns the configured query engine.
        evaluate_with_llm(reg_result, peer_result, guidelines_result, queries):
            Runs the queries and aggregates the scores.
    """

    def __init__(self, documents: List[Any], llm: Any, embed_model: Any, qa_prompt_tmpl: Any):
        self.documents = documents
        self.llm = llm
        self.embed_model = embed_model
        self.qa_prompt_tmpl = qa_prompt_tmpl
        self.base_utils = base_utils()
        self.config_manager = ConfigManager()
        logger.info("PDFQueryEngine initialized.")

    def format_example(self, example):
        """
        Formats a few-shot example into a string.

        Args:
            example (dict): A dictionary containing 'query', 'score', and 'reasoning'.

        Returns:
            str: "Example:\\nQuery: ...\\nScore: ...\\nReasoning: ...\\n".
        """
        return "Example:\nQuery: {}\nScore: {}\nReasoning: {}\n".format(
            example['query'], example['score'], example['reasoning']
        )

    def setup_query_engine(self):
        """
        Build a query engine over ``self.documents``.

        Components: a service context (``self.llm`` / ``self.embed_model``), an
        in-memory Qdrant vector index, a hybrid retriever (vector top-3 + BM25
        top-3), a ``SentenceTransformerRerank`` (top 3) and a "compact" response
        synthesizer using ``self.qa_prompt_tmpl``.

        Returns:
            RetrieverQueryEngine: The configured query engine.

        Raises:
            Exception: Any error during setup (logged, then re-raised).
        """
        try:
            logger.info("Initializing the service context for query engine setup.")
            service_context = ServiceContext.from_defaults(llm=self.llm, embed_model=self.embed_model)
            # In-memory Qdrant: a new, empty store is created on every call.
            client = qdrant_client.QdrantClient(location=":memory:")
            vector_store = QdrantVectorStore(client=client, collection_name="med_library")
            storage_context = StorageContext.from_defaults(vector_store=vector_store)

            # Parse the documents once and give the SAME nodes to both retrievers.
            # Parsing twice assigns fresh node_ids, so identical chunks found by both
            # retrievers could not be de-duplicated by HybridRetriever.
            nodes = service_context.node_parser.get_nodes_from_documents(self.documents)

            logger.info("Creating an index from documents.")
            index = VectorStoreIndex(nodes=nodes, storage_context=storage_context, service_context=service_context)

            logger.info("Setting up vector and BM25 retrievers.")
            vector_retriever = index.as_retriever(similarity_top_k=3)
            bm25_retriever = BM25Retriever.from_defaults(nodes=nodes, similarity_top_k=3)
            hybrid_retriever = HybridRetriever(vector_retriever, bm25_retriever)

            logger.info("Configuring the response synthesizer with the prompt template.")
            qa_prompt = PromptTemplate(self.qa_prompt_tmpl)
            response_synthesizer = get_response_synthesizer(
                service_context=service_context,
                text_qa_template=qa_prompt,
                response_mode="compact",
            )

            logger.info("Assembling the query engine with reranker and synthesizer.")
            reranker = SentenceTransformerRerank(top_n=3, model="BAAI/bge-reranker-base")
            query_engine = RetrieverQueryEngine.from_args(
                retriever=hybrid_retriever,
                node_postprocessors=[reranker],
                response_synthesizer=response_synthesizer,
            )

            logger.info("Query engine setup complete.")
            return query_engine
        except Exception as e:
            logger.error(f"Error during query engine setup: {e}", exc_info=True)
            raise

    def evaluate_with_llm(self, reg_result: Any, peer_result: Any, guidelines_result: Any, queries: List[str]) -> Tuple[int, int, float, List[str]]:
        """
        Score a document against a list of evaluation queries.

        Some queries are answered from pre-computed results instead of the LLM:
            * index 1, when ``reg_result`` is non-empty (its first entry becomes the reasoning);
            * index 2, when ``guidelines_result`` is truthy;
            * index 8, when ``guidelines_result`` or ``peer_result`` is truthy.
        Every other query is sent to the query engine and its "Score:"/"Reasoning:"
        fields are parsed from the answer.

        Args:
            reg_result (Any): Registration search result (e.g. matching sentences).
            peer_result (Any): Peer-review check result.
            guidelines_result (Any): Reporting-guidelines check result.
            queries (List[str]): Queries to evaluate, in order.

        Returns:
            Tuple[int, int, float, List[str]]: Total score, number of queries with a
            positive score, the total as a percentage of ``len(queries)`` (0.0 when
            ``queries`` is empty), and the reasoning for each query.

        Raises:
            FileNotFoundError, json.JSONDecodeError: If few_shot.json cannot be loaded.
        """
        logger.info("Starting evaluation with LLM.")
        # NOTE(review): the "few_shot" config is not read in this method; loading it
        # here only makes a missing/invalid few_shot.json fail early.
        self.config_manager.load_config("few_shot", "few_shot.json")
        query_engine = self.setup_query_engine()

        total_score = 0
        criteria_met = 0
        reasoning = []

        for j, query in enumerate(queries):
            if j == 1 and reg_result:
                extracted_data = {"score": 1, "reasoning": reg_result[0]}
            elif j == 2 and guidelines_result:
                extracted_data = {"score": 1, "reasoning": "The article is published in a journal following EQUATOR-NETWORK reporting guidelines"}
            elif j == 8 and (guidelines_result or peer_result):
                extracted_data = {"score": 1, "reasoning": "The article is published in a peer-reviewed journal."}
            else:
                result = query_engine.query(query).response
                logger.debug(f"LLM response for query {j}: {result}")
                extracted_data = self.base_utils.extract_score_reasoning(result)

            raw_score = extracted_data.get("score")
            extracted_data_score = 0 if raw_score is None else int(raw_score)
            if extracted_data_score > 0:
                criteria_met += 1
            reasoning.append(extracted_data["reasoning"])
            total_score += extracted_data_score

        score_percentage = (float(total_score) / len(queries)) * 100 if queries else 0.0
        logger.info("Evaluation completed.")
        return total_score, criteria_met, score_percentage, reasoning
|
|
|
|
|
|
|
class MixtralLLM(CustomLLM):
    """
    A custom language model class for interfacing with the Hugging Face Inference API,
    specifically using the Mixtral model.

    Attributes:
        context_window (int): Number of tokens used for context during inference.
        num_output (int): Number of tokens to generate as output.
        temperature (float): Sampling temperature for token generation.
        model_name (str): Name of the model on Hugging Face's model hub.
        api_key (str): API key for authenticating with the Hugging Face API.

    Methods:
        metadata: Retrieves metadata about the model.
        do_hf_call: Makes an API call to the Hugging Face model.
        complete: Generates a complete response for a given prompt.
        stream_complete: Streams the response character by character.
    """
    context_window: int = Field(..., description="Number of tokens used for context during inference.")
    num_output: int = Field(..., description="Number of tokens to generate as output.")
    temperature: float = Field(..., description="Sampling temperature for token generation.")
    model_name: str = Field(..., description="Name of the model on Hugging Face's model hub.")
    api_key: str = Field(..., description="API key for authenticating with the Hugging Face API.")

    @property
    def metadata(self) -> LLMMetadata:
        """
        Retrieves metadata for the Mixtral LLM.

        Returns:
            LLMMetadata: Context window, number of output tokens and model name.
        """
        return LLMMetadata(
            context_window=self.context_window,
            num_output=self.num_output,
            model_name=self.model_name,
        )

    @staticmethod
    def _extract_answer(full_txt: str) -> str:
        """
        Cut the answer out of the generated text.

        NOTE(review): presumably the API echoes the prompt, whose QA template
        contains a "---------------------" separator and ends with "Answer:" —
        confirm against the template in use. Everything after the first "Answer:"
        following the separator is returned. When a marker is missing, the text
        is kept from that point on instead of being sliced at offset -1.
        """
        separator_at = full_txt.find("---------------------")
        tail = full_txt[separator_at:] if separator_at >= 0 else full_txt
        answer_at = tail.find("Answer:")
        if answer_at >= 0:
            tail = tail[answer_at + len("Answer:"):]
        return tail.strip()

    def do_hf_call(self, prompt: str) -> str:
        """
        Makes an API call to the Hugging Face model and retrieves the generated answer.

        Args:
            prompt (str): The input prompt for the model.

        Returns:
            str: The answer text, or "Unable to answer for technical reasons." when the
            API responds with a non-200 status, an empty body, or an "error" entry.

        Raises:
            requests.RequestException: On network failures or after a 120 s timeout.
            ValueError: If a 200 response body is not valid JSON.
        """
        parameters = {}
        # The Inference API parameter is lower-case "temperature" ("Temperature" was
        # silently ignored). Non-positive values are left out: they may be rejected
        # by the API, and omitting the parameter keeps greedy decoding.
        if self.temperature > 0:
            parameters["temperature"] = self.temperature
        data = {"inputs": prompt, "parameters": parameters}

        response = requests.post(
            f'https://api-inference.huggingface.co/models/{self.model_name}',
            headers={
                'authorization': f'Bearer {self.api_key}',
                'content-type': 'application/json',
            },
            json=data,
            timeout=120,  # never hang forever on a stalled connection
        )

        if response.status_code != 200:
            logger.error(f"Error: {response}")
            return "Unable to answer for technical reasons."
        payload = response.json()
        if not payload or 'error' in payload:
            logger.error(f"Error: {response} {payload}")
            return "Unable to answer for technical reasons."

        return self._extract_answer(payload[0]['generated_text'])

    @llm_completion_callback()
    def complete(self, prompt: str, **kwargs: Any) -> CompletionResponse:
        """
        Generates a complete response for a given prompt using the Hugging Face API.

        Args:
            prompt (str): The input prompt for the model.
            **kwargs: Accepted for interface compatibility; unused.

        Returns:
            CompletionResponse: The complete response from the model.
        """
        return CompletionResponse(text=self.do_hf_call(prompt))

    @llm_completion_callback()
    def stream_complete(
        self, prompt: str, **kwargs: Any
    ) -> CompletionResponseGen:
        """
        Emulates streaming: fetches the full answer, then yields it one character
        at a time with the accumulated text so far.

        Args:
            prompt (str): The input prompt for the model.
            **kwargs: Accepted for interface compatibility; unused.

        Yields:
            CompletionResponse: Accumulated text with the newest character as ``delta``.
        """
        response = ""
        for token in self.do_hf_call(prompt):
            response += token
            yield CompletionResponse(text=response, delta=token)
|
|
|
|
|
|
|
class KeywordSearch:
    """
    Rule-based searches over the text chunks of a paper.

    Attributes:
        chunks (List[str]): Text chunks to search.
    """

    # Registry IDs, matched case-sensitively within single sentences.
    _REGISTRATION_PATTERNS = (
        r"\(?(NCT#?\s*(No\s*)?)(\d{8})\)?",   # ClinicalTrials.gov
        r"(ISRCTN\d{8})",                     # ISRCTN
        r"(\d{4}-\d{6}-\d{2})",               # EudraCT
        r"(UMIN\d{9})",                       # UMIN-CTR
        r"(CTRI/\d{4}/\d{2}/\d{6})",          # CTRI
    )

    # Registry hosts, lower-case, matched case-insensitively as substrings
    # (so "www."-prefixed forms are covered too).
    # NOTE(review): "umin.ac.jp/ctr/index/htm" looks like a typo for
    # "umin.ac.jp/ctr/index.htm" — confirm before changing.
    _REGISTRY_URLS = (
        "anzctr.org.au",
        "clinicaltrials.gov",
        "isrctn.org",
        "umin.ac.jp/ctr/index/htm",
        "onderzoekmetmensen.nl/en",
        "eudract.ema.europa.eu",
    )

    def __init__(self, chunks):
        self.chunks = chunks

    def find_journal_name(self, response: str, journal_list: list) -> bool:
        """
        Check whether any known journal name occurs in ``response``.

        Matching is case-insensitive substring search. Blank or non-string entries
        (e.g. NaN if the list comes from a pandas column) are skipped, because an
        empty string would match every response.

        Args:
            response (str): Text to search.
            journal_list (list): Journal names to look for.

        Returns:
            bool: True if at least one journal name is found, False otherwise.
        """
        response_lower = response.lower()
        for journal in journal_list:
            if not isinstance(journal, str) or not journal.strip():
                continue
            journal_lower = journal.lower()
            if journal_lower in response_lower:
                logger.debug(f"Matched journal '{journal_lower}' in response.")
                return True
        return False

    def check_registration(self):
        """
        Look for trial-registration evidence in ``self.chunks``.

        Each chunk is split into sentences; the first sentence containing a
        registry ID (NCT, ISRCTN, EudraCT, UMIN, CTRI) is returned on its own.
        Otherwise, every chunk mentioning a known registry URL (case-insensitive)
        is returned.

        Returns:
            list of str: ``[sentence]``, the matching chunks, or ``[]``.
        """
        for chunk in self.chunks:
            for sentence in re.split(r'(?<=[.!?]) +', chunk):
                if any(re.search(pattern, sentence) for pattern in self._REGISTRATION_PATTERNS):
                    return [sentence]

        return [
            chunk for chunk in self.chunks
            if any(url in chunk.lower() for url in self._REGISTRY_URLS)
        ]
|
|
|
|
|
|
|
class StringExtraction:
    """
    Extract answer strings from LLM responses and ground truth from labelled data.

    Note that LLMs may format answers differently depending on the model or the
    prompting technique; ``extract_original_prompt`` is a heuristic, and a custom
    extraction method may be needed in such cases.

    Methods:
        extract_original_prompt(result): Splits a response into its Yes/No line and reasoning.
        extraction_ground_truth(paper_name, labelled_data): Reads the labelled answers for a paper.
    """

    # Whole-word Yes/No, so "Not", "None", "Yesterday" do not count as answers.
    _YES_NO = re.compile(r"\b(?:Yes|No)\b")

    def extract_original_prompt(self, result):
        """
        Split a response into its binary-answer line and its reasoning.

        Args:
            result: Object with a ``response`` string attribute.

        Returns:
            Tuple[str, str]: The first line containing a whole-word "Yes"/"No"
            before any "Reasoning:" marker (whole line, or "" if none), and the text
            after every "Reasoning:" marker, stripped and joined with spaces.
        """
        binary_response = ""
        explanation_parts = []
        for line in result.response.strip().split("\n"):
            # Only the text before "Reasoning:" may hold the answer, so a line such as
            # "Reasoning: No registration was found" is not mistaken for one.
            head, marker, tail = line.partition("Reasoning:")
            if marker:
                explanation_parts.append(tail.strip())
            if binary_response == "" and self._YES_NO.search(head):
                binary_response = line
        return binary_response, " ".join(explanation_parts)

    def extraction_ground_truth(self, paper_name, labelled_data):
        """
        Read the labelled answers for a paper.

        Args:
            paper_name (str): File name or path ending in "_<id>.pdf".
            labelled_data (pandas.DataFrame): Table with an "id" column; columns 2-10
                hold the per-question evidence text (empty/NaN when absent).

        Returns:
            Tuple[List[str], List[str]]: "Yes"/"No" per question, and the evidence
            text (or a fixed "not provided" sentence) per question.

        Raises:
            ValueError: If no id can be parsed from ``paper_name``.
            IndexError: If ``labelled_data`` has no row with that id.
        """
        match = re.search(r"_(\d+)\.pdf", os.path.basename(paper_name))
        if match is None:
            raise ValueError(f"Cannot parse a paper id from {paper_name!r}")
        paper_id = int(match.group(1))

        id_row = labelled_data[labelled_data["id"] == paper_id]
        if id_row.empty:
            raise IndexError(f"No labelled row with id {paper_id}")
        ground_truth = id_row.iloc[:, 2:11].values.tolist()[0]

        binary_ground_truth = []
        explanation_ground_truth = []
        for g in ground_truth:
            # Empty CSV cells arrive as NaN (a float), which has no len().
            text = "" if pd.isna(g) else str(g)
            if len(text) > 0:
                binary_ground_truth.append("Yes")
                explanation_ground_truth.append(text)
            else:
                binary_ground_truth.append("No")
                explanation_ground_truth.append("The article does not provide any relevant information.")
        return binary_ground_truth, explanation_ground_truth
|
|
|
|
|
|
|
class EvaluationMetrics:
    """
    Evaluation metrics comparing LLM explanations/answers with ground truth.

    Attributes:
        explanation_response (List[str]): Detailed LLM response for each query.
        explanation_ground_truth (List[str]): Ground-truth text for each query.
        embedding_model: Object with an ``encode(list_of_str)`` method returning
            one embedding per string.

    Methods:
        metric_cosine_similarity(): Cosine similarity of each response with its
            ground truth (i-th with i-th).
        metric_rouge(): ROUGE scores of the responses against the ground truth.
        binary_accuracy(binary_response, binary_ground_truth): Fraction of equal
            Yes/No answers, rounded to 2 decimals.
    """

    def __init__(self, explanation_response, explanation_ground_truth, embedding_model):
        self.explanation_response = explanation_response
        self.explanation_ground_truth = explanation_ground_truth
        self.embedding_model = embedding_model

    def metric_cosine_similarity(self):
        """
        Return the cosine similarity between the i-th ground truth and i-th response.

        Returns:
            numpy.ndarray: The diagonal of the ground-truth × response similarity matrix.
        """
        ground_truth_embedding = self.embedding_model.encode(self.explanation_ground_truth)
        explanation_response_embedding = self.embedding_model.encode(self.explanation_response)
        return np.diag(cosine_similarity(ground_truth_embedding, explanation_response_embedding))

    def metric_rouge(self):
        """
        Return ROUGE scores (via the ``evaluate`` library) of the responses
        against the ground truth.
        """
        rouge = evaluate.load("rouge")
        return rouge.compute(predictions=self.explanation_response, references=self.explanation_ground_truth)

    def binary_accuracy(self, binary_response, binary_ground_truth):
        """
        Fraction of positions where the two answer lists agree.

        Args:
            binary_response (Sequence[str]): Predicted answers.
            binary_ground_truth (Sequence[str]): Expected answers.

        Returns:
            The accuracy rounded to 2 decimals, NaN for two empty lists, or an
            error message string if the lengths differ.
        """
        if len(binary_response) != len(binary_ground_truth):
            return "Arrays which are to be compared has different lengths."
        if len(binary_response) == 0:
            # Accuracy over zero items is undefined; avoid ZeroDivisionError.
            return np.nan
        matches = sum(pred == truth for pred, truth in zip(binary_response, binary_ground_truth))
        return np.round(matches / len(binary_response), 2)