Spaces:
Runtime error
Runtime error
import os | |
import re | |
import json | |
import torch | |
import openai | |
import logging | |
import asyncio | |
import aiohttp | |
import pandas as pd | |
import numpy as np | |
import evaluate | |
import qdrant_client | |
from pydantic import BaseModel, Field | |
from typing import Any, List, Tuple, Set, Dict, Optional, Union | |
from sklearn.metrics.pairwise import cosine_similarity | |
from unstructured.partition.pdf import partition_pdf | |
import llama_index | |
from llama_index import PromptTemplate | |
from llama_index.retrievers import VectorIndexRetriever, BaseRetriever, BM25Retriever | |
from llama_index.query_engine import RetrieverQueryEngine | |
from llama_index import get_response_synthesizer | |
from llama_index.schema import NodeWithScore | |
from llama_index.query_engine import RetrieverQueryEngine | |
from llama_index import VectorStoreIndex, ServiceContext | |
from llama_index.embeddings import OpenAIEmbedding | |
from llama_index.llms import HuggingFaceLLM | |
import requests | |
from llama_index.llms import ( | |
CustomLLM, | |
CompletionResponse, | |
CompletionResponseGen, | |
LLMMetadata, | |
) | |
from llama_index.query_engine import RetrieverQueryEngine | |
from llama_index.llms.base import llm_completion_callback | |
from llama_index.vector_stores.qdrant import QdrantVectorStore | |
from llama_index.storage.storage_context import StorageContext | |
from llama_index.postprocessor import SentenceTransformerRerank, LLMRerank | |
from tempfile import NamedTemporaryFile | |
# Configure the root logger once at import time: INFO and above, timestamped records.
logging.basicConfig(
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
# Module-level logger for this file.
logger = logging.getLogger(__name__)
class ConfigManager:
    """
    Loads named JSON configuration files and serves individual values from them.

    Attributes:
        configs (Dict[str, Dict[str, Any]]): Parsed configurations, keyed by the
            name passed to ``load_config``.

    Methods:
        load_config(config_name, config_path): Parses a JSON file and stores it under ``config_name``.
        get_config_value(config_name, key): Returns one value from a loaded configuration.
    """

    def __init__(self) -> None:
        self.configs = {}

    def load_config(self, config_name: str, config_path: str) -> None:
        """
        Loads configuration settings from a JSON file into a named configuration.

        Args:
            config_name (str): The name to store this configuration under.
            config_path (str): The path to the JSON configuration file.

        Raises:
            FileNotFoundError: If the config file is not found.
            json.JSONDecodeError: If the config file is not valid JSON.
        """
        try:
            # JSON text is UTF-8 by specification; don't depend on the platform locale.
            with open(config_path, 'r', encoding='utf-8') as f:
                self.configs[config_name] = json.load(f)
        except FileNotFoundError:
            logger.error(f"Config file not found at {config_path}")
            raise
        except json.JSONDecodeError as e:
            logger.error(f"Error decoding config file: {e}")
            raise

    def get_config_value(self, config_name: str, key: str) -> Any:
        """
        Retrieves a single value from a previously loaded configuration.

        Args:
            config_name (str): The name the configuration was loaded under.
            key (str): The key of the configuration setting.

        Returns:
            Any: The JSON value stored under ``key``.

        Raises:
            ValueError: If the configuration or key is missing, the value is
                null, or it is still the "ENTER_YOUR_TOKEN_HERE" placeholder.
        """
        value = self.configs.get(config_name, {}).get(key)
        if value is None or value == "ENTER_YOUR_TOKEN_HERE":
            raise ValueError(f"Please set your '{key}' in the config.json file.")
        return value
class base_utils:
    """
    Stateless helpers for parsing LLM output, mapping PDF filenames to dataset
    rows, and reading text files.

    All methods are static, so they can be called on the class
    (``base_utils.extract_score_reasoning(text)``) or on an instance.

    Methods:
        extract_score_reasoning(text) -> Dict[str, Optional[str]]:
            Extracts the "Score:" and "Reasoning:" parts of a response.
        extract_id_from_filename(filename) -> Optional[int]:
            Extracts the numeric ID from a name containing 'Id_{N}.pdf'.
        find_row_for_pdf(pdf_filename, dataframe) -> Union[pd.DataFrame, str]:
            Selects the rows whose first column equals the ID from the filename.
        read_from_file(file_path) -> str:
            Reads a UTF-8 text file and returns its content.
    """

    @staticmethod
    def read_from_file(file_path: str) -> str:
        """
        Reads the content of a text file.

        Args:
            file_path (str): The path to the file to be read.

        Returns:
            str: The content of the file.
        """
        with open(file_path, 'r', encoding='utf-8') as prompt_file:
            return prompt_file.read()

    @staticmethod
    def extract_id_from_filename(filename: str) -> Optional[int]:
        """
        Extracts an ID from a filename containing 'Id_{N}.pdf' (N is the ID).

        Args:
            filename (str): The filename (or path) from which to extract the ID.

        Returns:
            Optional[int]: The ID as an integer, or None if the pattern is not found.
        """
        # The dot is escaped so that only a literal ".pdf" suffix matches.
        match = re.search(r'Id_(\d+)\.pdf', filename)
        return int(match.group(1)) if match else None

    @staticmethod
    def extract_score_reasoning(text: str) -> Dict[str, Optional[str]]:
        """
        Extracts the score and reasoning from an LLM response.

        Args:
            text (str): Text expected to contain "Score: <digits>" and "Reasoning: <text>".

        Returns:
            dict: {'score': digits as a string or None,
                   'reasoning': everything after "Reasoning: " (stripped) or None}.
        """
        score_match = re.search(r"Score: (\d+)", text)
        # re.DOTALL lets the reasoning span several lines.
        reasoning_match = re.search(r"Reasoning: (.+)", text, re.DOTALL)
        return {
            "score": score_match.group(1) if score_match else None,
            "reasoning": reasoning_match.group(1).strip() if reasoning_match else None,
        }

    @staticmethod
    def find_row_for_pdf(pdf_filename: str, dataframe: pd.DataFrame) -> Union[pd.DataFrame, str]:
        """
        Finds the rows of ``dataframe`` matching the ID extracted from a PDF filename.

        Args:
            pdf_filename (str): The filename of the PDF.
            dataframe (pandas.DataFrame): Data whose first column holds the IDs.

        Returns:
            pandas.DataFrame or str: The matching rows, or "No matching row found."
            / "Invalid file name." when nothing matches or no ID can be parsed.
        """
        pdf_id = base_utils.extract_id_from_filename(pdf_filename)
        if pdf_id is None:
            return "Invalid file name."
        # Assuming the first column contains the ID
        matched_row = dataframe[dataframe.iloc[:, 0] == pdf_id]
        if matched_row.empty:
            return "No matching row found."
        return matched_row
class PDFProcessor_Unstructured:
    """
    Extracts text chunks and tables from a PDF with ``unstructured``'s
    ``partition_pdf`` and merges text chunks that were split mid-sentence.

    Attributes:
        file_path (str): Full path of the PDF currently being processed.
        folder_path (str): Directory of that PDF (passed as ``image_output_dir_path``).
        file_name (str): Base name of that PDF.
        texts (List[str]): Text chunks (``CompositeElement``s) of the last processed PDF.
        tables (List[str]): Tables of the last processed PDF.
        config (Dict[str, Any]): Options for ``partition_pdf``; keys missing from a
            caller-supplied config fall back to ``default_config()``.

    Methods:
        default_config() -> Dict[str, Any]: Default extraction options.
        extract_pdf_elements() -> List: Runs ``partition_pdf`` on ``file_path``.
        categorize_elements(raw_pdf_elements) -> None: Sorts elements into tables and texts.
        merge_chunks() -> List[str]: Filters and pairwise-merges text chunks.
        should_skip_chunk(chunk) -> bool: Whether a chunk is dropped.
        should_merge_with_next(current_chunk, next_chunk) -> bool: Whether two chunks are joined.
        process_pdf() -> Tuple[List[str], List[str]]: Extract, categorize, merge.
        process_pdf_file(uploaded_file) -> Tuple[List[str], List[str]]: Process a newly uploaded PDF.
    """

    def __init__(self, config: Optional[Dict[str, Any]] = None):
        self.file_path = None
        self.folder_path = None
        self.file_name = None
        self.texts = []
        self.tables = []
        self.config = config if config is not None else self.default_config()
        logger.info("Initialized PDFProcessor_Unstructured.")

    @staticmethod
    def default_config() -> Dict[str, Any]:
        """
        Returns the default configuration for PDF processing.

        Returns:
            Dict[str, Any]: Default configuration options.
        """
        return {
            "extract_images": False,
            "infer_table_structure": True,
            "chunking_strategy": "by_title",
            "max_characters": 10000,
            "combine_text_under_n_chars": 100,
            "strategy": "fast",
            "model_name": "yolox"
        }

    def extract_pdf_elements(self) -> List:
        """
        Extracts tables and text chunks from ``self.file_path`` using ``self.config``.

        Returns:
            List: The elements returned by ``partition_pdf``.

        Raises:
            Exception: Whatever ``partition_pdf`` raises; it is logged and re-raised.
        """
        logger.info("Starting extraction of PDF elements.")
        # Caller-supplied options override the defaults; missing keys fall back.
        # "model_name" is not forwarded (it was never passed to partition_pdf).
        options = {**self.default_config(), **(self.config or {})}
        try:
            extracted_elements = partition_pdf(
                filename=self.file_path,
                extract_images_in_pdf=options["extract_images"],
                infer_table_structure=options["infer_table_structure"],
                chunking_strategy=options["chunking_strategy"],
                strategy=options["strategy"],
                max_characters=options["max_characters"],
                combine_text_under_n_chars=options["combine_text_under_n_chars"],
                image_output_dir_path=self.folder_path,
            )
            logger.info("Extraction of PDF elements completed successfully.")
            return extracted_elements
        except Exception as e:
            logger.error(f"Error extracting PDF elements: {e}", exc_info=True)
            raise

    def categorize_elements(self, raw_pdf_elements: List) -> None:
        """
        Appends table elements to ``self.tables`` and composite text elements to
        ``self.texts`` (as strings); other element types are ignored.

        Args:
            raw_pdf_elements (List): Elements returned by ``extract_pdf_elements``.
        """
        logger.debug("Starting categorization of PDF elements.")
        for element in raw_pdf_elements:
            element_type = str(type(element))
            if "unstructured.documents.elements.Table" in element_type:
                self.tables.append(str(element))
            elif "unstructured.documents.elements.CompositeElement" in element_type:
                self.texts.append(str(element))
        logger.debug("Categorization of PDF elements completed.")

    def merge_chunks(self) -> List[str]:
        """
        Drops unwanted chunks and joins a chunk with its successor when the text
        looks like it was split mid-sentence. Each chunk is used at most once.

        Returns:
            List[str]: The filtered and merged text chunks, in document order.
        """
        logger.debug("Starting merging of text chunks.")
        merged_chunks: List[str] = []
        texts = self.texts
        i = 0
        while i < len(texts):
            current_chunk = texts[i]
            if self.should_skip_chunk(current_chunk):
                i += 1
                continue
            has_next = i + 1 < len(texts)
            if has_next and self.should_merge_with_next(current_chunk, texts[i + 1]):
                merged_chunks.append(current_chunk + " " + texts[i + 1])
                i += 2  # the next chunk was consumed by the merge
            else:
                merged_chunks.append(current_chunk)
                i += 1
        logger.debug("Merging of text chunks completed.")
        return merged_chunks

    @staticmethod
    def should_skip_chunk(chunk: str) -> bool:
        """
        Determines if a chunk should be skipped.

        Args:
            chunk (str): The text chunk to be evaluated.

        Returns:
            bool: True for empty chunks, figure/table captions, chunks starting with
            a non-alphanumeric character, and chunks starting with "<digits>.".
        """
        if not chunk:
            return True
        return bool(
            chunk.lower().startswith(("figure", "fig", "table"))
            or not chunk[0].isalnum()
            or re.match(r'^\d+\.', chunk)
        )

    @staticmethod
    def should_merge_with_next(current_chunk: str, next_chunk: str) -> bool:
        """
        Determines if the current chunk should be merged with the next one.

        Args:
            current_chunk (str): The current text chunk.
            next_chunk (str): The next text chunk.

        Returns:
            bool: True if the current chunk ends with a comma, or ends with a
            lowercase letter and the next chunk starts with one.
        """
        if not current_chunk:
            return False
        return current_chunk.endswith(",") or (
            current_chunk[-1].islower() and bool(next_chunk) and next_chunk[0].islower()
        )

    def process_pdf(self) -> Tuple[List[str], List[str]]:
        """
        Processes ``self.file_path`` by extracting, categorizing, and merging elements.

        Returns:
            Tuple[List[str], List[str]]: The merged text chunks and the tables.
        """
        logger.info("Starting processing of the PDF.")
        # Start clean so chunks from a previously processed PDF don't leak in.
        self.texts = []
        self.tables = []
        try:
            raw_pdf_elements = self.extract_pdf_elements()
            self.categorize_elements(raw_pdf_elements)
            merged_chunks = self.merge_chunks()
            return merged_chunks, self.tables
        except Exception as e:
            logger.error(f"Error processing PDF: {e}", exc_info=True)
            raise

    def process_pdf_file(self, uploaded_file) -> Tuple[List[str], List[str]]:
        """
        Processes an uploaded PDF file.

        The previously processed file is deleted from disk (best effort) unless it
        is the same path as the new one.

        Args:
            uploaded_file: The new PDF; ``str(uploaded_file)`` must be its path.

        Returns:
            Tuple[List[str], List[str]]: The merged text chunks and the tables.
        """
        new_path = str(uploaded_file)
        previous_path = self.file_path
        # Never delete the file we are about to process.
        if (previous_path and os.path.exists(previous_path)
                and os.path.abspath(previous_path) != os.path.abspath(new_path)):
            try:
                os.remove(previous_path)
                logger.debug(f"Previous file {previous_path} deleted.")
            except OSError as e:
                logger.warning(f"Error deleting previous file: {e}", exc_info=True)

        self.file_path = new_path
        self.folder_path = os.path.dirname(self.file_path)
        self.file_name = os.path.basename(self.file_path)
        logger.info(f"Starting to process the PDF file: {self.file_path}")
        try:
            logger.debug(f"Processing PDF at {self.file_path}")
            results = self.process_pdf()
            logger.info("PDF processing completed successfully.")
            return results
        except Exception as e:
            logger.error(f"Error processing PDF file: {e}", exc_info=True)
            raise
class HybridRetriever(BaseRetriever):
    """
    A retriever that unions the results of a vector retriever and a BM25 retriever.

    BM25 results come first, followed by vector results. Nodes whose ``node_id``
    was already seen are dropped, so overlap between the two retrievers is only
    removed when both were built from the same nodes.

    Attributes:
        vector_retriever: A vector-based retriever.
        bm25_retriever: A BM25 retriever.

    Methods:
        _retrieve(query, **kwargs): Queries both retrievers and combines the results.
        _combine_results(bm25_nodes, vector_nodes): Order-preserving de-duplication by node_id.
    """

    def __init__(self, vector_retriever, bm25_retriever):
        super().__init__()
        self.vector_retriever = vector_retriever
        self.bm25_retriever = bm25_retriever
        logger.info("HybridRetriever initialized with vector and BM25 retrievers.")

    def _retrieve(self, query, **kwargs) -> List:
        """
        Retrieves from both retrievers and combines the results.

        Args:
            query: The query; forwarded unchanged to both retrievers' ``retrieve``
                (presumably a QueryBundle when invoked through BaseRetriever.retrieve).
            **kwargs: Forwarded to both retrievers' ``retrieve``.

        Returns:
            List: Unique nodes, BM25 results first.

        Raises:
            Exception: Any retriever error, logged and re-raised.
        """
        logger.info(f"Retrieving documents for query: {query}")
        try:
            bm25_nodes = self.bm25_retriever.retrieve(query, **kwargs)
            vector_nodes = self.vector_retriever.retrieve(query, **kwargs)
            combined_nodes = self._combine_results(bm25_nodes, vector_nodes)
            logger.info(f"Retrieved {len(combined_nodes)} unique nodes combining vector and BM25 retrievers.")
            return combined_nodes
        except Exception as e:
            logger.error(f"Error in retrieval: {e}", exc_info=True)
            raise

    @staticmethod
    def _combine_results(bm25_nodes: List, vector_nodes: List) -> List:
        """
        Concatenates two node lists, keeping only the first node for each ``node_id``.

        Args:
            bm25_nodes: Nodes retrieved by the BM25 retriever.
            vector_nodes: Nodes retrieved by the vector retriever.

        Returns:
            List: Combined list of unique nodes, in first-seen order.
        """
        node_ids: Set = set()
        combined_nodes = []
        for node in [*bm25_nodes, *vector_nodes]:
            if node.node_id not in node_ids:
                combined_nodes.append(node)
                node_ids.add(node.node_id)
        return combined_nodes
class PDFQueryEngine:
    """
    Builds a hybrid (vector + BM25) query engine over a set of documents and
    scores the documents against a list of evaluation queries.

    Attributes:
        documents (List[Any]): llama_index documents to index.
        llm (Any): The language model used to synthesize answers.
        embed_model (Any): The embedding model used for the vector index.
        qa_prompt_tmpl (Any): Template text wrapped in a ``PromptTemplate`` for QA.
        base_utils (base_utils): Text-parsing helpers.
        config_manager (ConfigManager): Loader for the few-shot configuration.

    Methods:
        format_example(example): Formats a few-shot example as prompt text.
        setup_query_engine(): Builds and returns the configured query engine.
        evaluate_with_llm_async(...): Scores the documents against the queries.
    """

    def __init__(self, documents: List[Any], llm: Any, embed_model: Any, qa_prompt_tmpl: Any):
        self.documents = documents
        self.llm = llm
        self.embed_model = embed_model
        self.qa_prompt_tmpl = qa_prompt_tmpl
        self.base_utils = base_utils()
        self.config_manager = ConfigManager()
        logger.info("PDFQueryEngine initialized.")

    def format_example(self, example: Dict[str, Any]) -> str:
        """
        Formats a few-shot example into a string.

        Args:
            example (dict): Contains 'query', 'score', and 'reasoning'.

        Returns:
            str: "Example:\\nQuery: ...\\nScore: ...\\nReasoning: ...\\n".
        """
        return "Example:\nQuery: {}\nScore: {}\nReasoning: {}\n".format(
            example['query'], example['score'], example['reasoning']
        )

    def setup_query_engine(self) -> Any:
        """
        Builds a query engine over ``self.documents``.

        The documents are indexed in an in-memory Qdrant collection. A vector
        retriever and a BM25 retriever (top 3 each) are combined by
        ``HybridRetriever``, reranked with ``BAAI/bge-reranker-base`` (top 3), and
        answered with a "compact" response synthesizer using ``qa_prompt_tmpl``.

        Returns:
            Any: The configured ``RetrieverQueryEngine``.

        Raises:
            Exception: Any setup error, logged and re-raised.
        """
        # ":memory:" needs no Qdrant deployment (qdrant-client >= 1.1.1). For a real
        # instance use uri="http://<host>:<port>" (and api_key=... for Qdrant Cloud).
        client = qdrant_client.QdrantClient(location=":memory:")
        try:
            logger.info("Initializing the service context for query engine setup.")
            service_context = ServiceContext.from_defaults(llm=self.llm, embed_model=self.embed_model)
            vector_store = QdrantVectorStore(client=client, collection_name="med_library")
            storage_context = StorageContext.from_defaults(vector_store=vector_store)

            # Parse once and build both retrievers from the same nodes so that the
            # same chunk has the same node_id in both, letting HybridRetriever
            # de-duplicate overlapping results.
            nodes = service_context.node_parser.get_nodes_from_documents(self.documents)
            logger.info("Creating an index from documents.")
            index = VectorStoreIndex(nodes=nodes, storage_context=storage_context, service_context=service_context)

            logger.info("Setting up vector and BM25 retrievers.")
            vector_retriever = index.as_retriever(similarity_top_k=3)
            bm25_retriever = BM25Retriever.from_defaults(nodes=nodes, similarity_top_k=3)
            hybrid_retriever = HybridRetriever(vector_retriever, bm25_retriever)

            logger.info("Configuring the response synthesizer with the prompt template.")
            qa_prompt = PromptTemplate(self.qa_prompt_tmpl)
            response_synthesizer = get_response_synthesizer(
                service_context=service_context,
                text_qa_template=qa_prompt,
                response_mode="compact",
            )

            logger.info("Assembling the query engine with reranker and synthesizer.")
            reranker = SentenceTransformerRerank(top_n=3, model="BAAI/bge-reranker-base")
            query_engine = RetrieverQueryEngine.from_args(
                retriever=hybrid_retriever,
                node_postprocessors=[reranker],
                response_synthesizer=response_synthesizer,
            )
            logger.info("Query engine setup complete.")
            return query_engine
        except Exception as e:
            logger.error(f"Error during query engine setup: {e}", exc_info=True)
            raise

    async def evaluate_with_llm_async(self, reg_result: Any, peer_result: Any, guidelines_result: Any, queries: List[str]) -> Tuple[int, int, int, float, List[str]]:
        """
        Scores the documents against ``queries``.

        Query index 1 scores 1 when ``reg_result`` is truthy (its first item is the
        reasoning). Index 2 scores 1 when ``guidelines_result`` is truthy. Index 8
        scores 1 when ``guidelines_result`` or ``peer_result`` is truthy. Every other
        query goes to the query engine, and its "Score:"/"Reasoning:" lines are parsed.

        Args:
            reg_result: Registration evidence (e.g. the list from ``check_registration``).
            peer_result: Truthy if the journal is known to be peer-reviewed.
            guidelines_result: Truthy if the journal follows EQUATOR reporting guidelines.
            queries (List[str]): The evaluation queries, in criterion order.

        Returns:
            Tuple[int, int, int, float, List[str]]: The total score, the number of
            criteria scoring > 0, the number of queries, the total score as a
            percentage of the number of queries, and the reasoning per query (None
            where none was parsed).

        Raises:
            FileNotFoundError, json.JSONDecodeError: If ``few_shot.json`` cannot be loaded.
        """
        logger.info("Starting evaluation with LLM.")
        self.config_manager.load_config("few_shot", "few_shot.json")
        if not queries:
            logger.warning("No queries supplied; nothing to evaluate.")
            return 0, 0, 0, 0.0, []

        query_engine = self.setup_query_engine()

        async def handle_query(j: int, query: str) -> Dict[str, Any]:
            # Criteria already settled by metadata checks skip the LLM.
            if j == 1 and reg_result:
                return {"score": 1, "reasoning": reg_result[0]}
            if j == 2 and guidelines_result:
                return {"score": 1, "reasoning": "The article is published in a journal following EQUATOR-NETWORK reporting guidelines"}
            if j == 8 and (guidelines_result or peer_result):
                return {"score": 1, "reasoning": "The article is published in a peer-reviewed journal."}
            response = await query_engine.aquery(query)
            # Called on the class so it works whether or not the helpers are static.
            return base_utils.extract_score_reasoning(str(response.response))

        # gather() returns results in the order of the queries.
        results = await asyncio.gather(*(handle_query(j, query) for j, query in enumerate(queries)))

        total_score = 0
        criteria_met = 0
        reasoning = []
        for extracted_data in results:
            raw_score = extracted_data.get("score")
            extracted_data_score = 0 if raw_score is None else int(raw_score)
            if extracted_data_score > 0:
                criteria_met += 1
            reasoning.append(extracted_data["reasoning"])
            total_score += extracted_data_score

        score_percentage = (float(total_score) / len(queries)) * 100
        logger.info("Evaluation completed.")
        return total_score, criteria_met, len(queries), score_percentage, reasoning
class MixtralLLM(CustomLLM):
    """
    A llama_index ``CustomLLM`` backed by the Hugging Face Inference API
    (``https://api-inference.huggingface.co/models/<model_name>``).

    Attributes:
        context_window (int): Context size reported in ``metadata``.
        num_output (int): Output size reported in ``metadata``.
        temperature (float): Sampling temperature sent with each request.
        model_name (str): Model id on the Hugging Face hub.
        api_key (str): Bearer token for the Inference API.
        request_timeout (float): Seconds to wait for an API response.

    Methods:
        metadata: Model metadata for llama_index (a property).
        do_hf_call(prompt): Sends one request and returns the extracted answer.
        complete(prompt): Returns the whole answer as one ``CompletionResponse``.
        stream_complete(prompt): Yields the answer one character at a time.
    """

    context_window: int = Field(..., description="Number of tokens used for context during inference.")
    num_output: int = Field(..., description="Number of tokens to generate as output.")
    temperature: float = Field(..., description="Sampling temperature for token generation.")
    model_name: str = Field(..., description="Name of the model on Hugging Face's model hub.")
    api_key: str = Field(..., description="API key for authenticating with the Hugging Face API.")
    request_timeout: float = Field(default=120.0, description="Seconds to wait for a response from the Hugging Face API.")

    @property
    def metadata(self) -> LLMMetadata:
        """
        Metadata for the Mixtral LLM (llama_index reads this as an attribute).

        Returns:
            LLMMetadata: Context window, number of outputs, and model name.
        """
        return LLMMetadata(
            context_window=self.context_window,
            num_output=self.num_output,
            model_name=self.model_name,
        )

    @staticmethod
    def _extract_answer(full_txt: str) -> str:
        """Returns the text after the first "Answer:" that follows the context separator."""
        separator = "---------------------"
        marker = "Answer:"
        sep_offset = full_txt.find(separator)
        # Without a separator, search the whole text (not just its last character).
        tail = full_txt[sep_offset:] if sep_offset >= 0 else full_txt
        marker_offset = tail.find(marker)
        if marker_offset < 0:
            # No marker: return the whole generation rather than a truncated slice.
            return full_txt.strip()
        return tail[marker_offset + len(marker):].strip()

    def do_hf_call(self, prompt: str) -> str:
        """
        Sends ``prompt`` to the Hugging Face model and extracts the answer.

        Args:
            prompt (str): The input prompt for the model.

        Returns:
            str: The answer text, or "Unable to answer for technical reasons." if the
            request fails, times out, or returns an error or unexpected payload.
        """
        data = {
            "inputs": prompt,
            # The Inference API expects lower-case parameter names.
            "parameters": {"temperature": self.temperature},
        }
        try:
            response = requests.post(
                f'https://api-inference.huggingface.co/models/{self.model_name}',
                headers={
                    'authorization': f'Bearer {self.api_key}',
                    'content-type': 'application/json',
                },
                json=data,
                timeout=self.request_timeout,
            )
            payload = response.json()
        except (requests.RequestException, ValueError) as e:
            # ValueError covers a non-JSON body.
            logger.error(f"Error calling Hugging Face API: {e}", exc_info=True)
            return "Unable to answer for technical reasons."

        # A successful text-generation call returns [{"generated_text": ...}];
        # errors come back as a dict such as {"error": ...}.
        if (response.status_code != 200 or not isinstance(payload, list) or not payload
                or not isinstance(payload[0], dict) or 'generated_text' not in payload[0]):
            logger.error(f"Error: {response} {payload}")
            return "Unable to answer for technical reasons."

        return self._extract_answer(payload[0]['generated_text'])

    @llm_completion_callback()
    def complete(self, prompt: str, **kwargs: Any) -> CompletionResponse:
        """
        Generates a complete response for a given prompt.

        Args:
            prompt (str): The input prompt for the model.
            **kwargs: Accepted for interface compatibility; ignored.

        Returns:
            CompletionResponse: The complete response from the model.
        """
        return CompletionResponse(text=self.do_hf_call(prompt))

    @llm_completion_callback()
    def stream_complete(
        self, prompt: str, **kwargs: Any
    ) -> CompletionResponseGen:
        """
        Streams the response for the given prompt.

        The whole answer is fetched with one request and then yielded one
        character at a time.

        Args:
            prompt (str): The input prompt for the model.
            **kwargs: Accepted for interface compatibility; ignored.

        Yields:
            CompletionResponse: The accumulated text so far, with the new character as ``delta``.
        """
        response = ""
        for token in self.do_hf_call(prompt):
            response += token
            yield CompletionResponse(text=response, delta=token)
class KeywordSearch:
    """
    Keyword and regex searches over a document's text chunks.

    Attributes:
        chunks (list of str): The text chunks searched by ``check_registration``.

    Methods:
        find_journal_name(response, journal_list): Whether any listed journal name occurs in a text.
        check_registration(): Sentences with a trial registration number, or chunks citing a registry.
    """

    # Registration-number formats, compiled once for all instances.
    _REGISTRATION_PATTERNS = {
        "NCT": re.compile(r"\(?(NCT#?\s*(No\s*)?)(\d{8})\)?"),
        "ISRCTN": re.compile(r"(ISRCTN\d{8})"),
        "EudraCT": re.compile(r"(\d{4}-\d{6}-\d{2})"),
        "UMIN-CTR": re.compile(r"(UMIN\d{9})"),
        "CTRI": re.compile(r"(CTRI/\d{4}/\d{2}/\d{6})"),
    }

    # Registry URLs, lower-case; they are matched case-insensitively.
    _REGISTRY_URLS = (
        "www.anzctr.org.au",
        "anzctr.org.au",
        "www.clinicaltrials.gov",
        "clinicaltrials.gov",
        "www.isrctn.org",
        "isrctn.org",
        "www.umin.ac.jp/ctr",
        "umin.ac.jp/ctr",
        "www.onderzoekmetmensen.nl/en",
        "onderzoekmetmensen.nl/en",
        "eudract.ema.europa.eu",
        "www.eudract.ema.europa.eu",
    )

    def __init__(self, chunks):
        self.chunks = chunks

    def find_journal_name(self, response: str, journal_list: list) -> bool:
        """
        Checks whether any journal name from ``journal_list`` occurs in ``response``.

        Matching is a case-insensitive substring test. Empty and non-string
        entries are ignored, since an empty name would match every response.

        Args:
            response (str): The text to search.
            journal_list (list): Journal names to look for.

        Returns:
            bool: True if at least one journal name is found, False otherwise.
        """
        response_lower = response.lower()
        for journal in journal_list:
            if not isinstance(journal, str) or not journal:
                continue
            journal_lower = journal.lower()
            if journal_lower in response_lower:
                logger.debug(f"Journal name found in response: {journal_lower}")
                return True
        return False

    def check_registration(self) -> List[str]:
        """
        Searches ``self.chunks`` for clinical-trial registration evidence.

        Each chunk is split into sentences. The first sentence containing a
        registration number (NCT, ISRCTN, EudraCT, UMIN-CTR or CTRI format) is
        returned on its own. Otherwise, all chunks mentioning a known registry
        URL (case-insensitive) are returned.

        Returns:
            list of str: ``[sentence]``, the matching chunks, or ``[]`` if nothing is found.
        """
        for chunk in self.chunks:
            for sentence in re.split(r'(?<=[.!?]) +', chunk):
                if any(pattern.search(sentence) for pattern in self._REGISTRATION_PATTERNS.values()):
                    return [sentence]  # Return immediately if a registration number is found

        # If no registration number found, check for URLs in chunks
        return [
            chunk for chunk in self.chunks
            if any(url in chunk.lower() for url in self._REGISTRY_URLS)
        ]
class StringExtraction:
    """
    Extracts answers from LLM responses and ground truth from a labelled dataset.

    LLMs may format answers differently depending on the model or prompting
    technique, so ``extract_original_prompt`` may not give satisfactory results
    for every setup. In that case, write a string-extraction method tailored to
    the output format.

    Methods:
        extract_original_prompt(result): Splits a response into a Yes/No line and reasoning.
        extraction_ground_truth(paper_name, labelled_data): Yes/No labels and explanations for one paper.
    """

    def extract_original_prompt(self, result):
        """
        Splits an LLM response into its binary answer and its reasoning.

        Args:
            result: An object with a ``response`` string attribute (e.g. a llama_index Response).

        Returns:
            Tuple[str, str]: The first line containing "Yes" or "No" (or "" if none),
            and the concatenated text after the colon of every other line
            containing "Reasoning:".
        """
        lines = result.response.strip().split("\n")
        binary_response = ""
        explanation_response = ""
        for line in lines:
            if binary_response == "" and (line.find("Yes") >= 0 or line.find("No") >= 0):
                binary_response = line
            elif line.find("Reasoning:") >= 0:
                cut = line.find(":")
                explanation_response += line[cut + 1:].strip()
        return binary_response, explanation_response

    def extraction_ground_truth(self, paper_name, labelled_data):
        """
        Builds the ground truth for one paper from the labelled data.

        Args:
            paper_name (str): File name or path such as "Id_12.pdf"; the integer
                between the first "_" and ".pdf" of the base name is the id.
            labelled_data (pandas.DataFrame): Has an "id" column. Columns 2-10
                (by position) hold one explanation per criterion, empty if unmet.

        Returns:
            Tuple[List[str], List]: Per criterion, "Yes"/"No", and the explanation
            (a fixed "no information" sentence for empty cells).

        Raises:
            IndexError: If no row has the paper's id.
            ValueError: If no integer id can be parsed from ``paper_name``.
        """
        # Only the base name carries the id; directory names may contain "_" too.
        base_name = os.path.basename(paper_name)
        paper_id = int(base_name[base_name.find("_") + 1:base_name.find(".pdf")])
        id_row = labelled_data[labelled_data["id"] == paper_id]
        if id_row.empty:
            raise IndexError(f"No labelled row found for id {paper_id}.")
        ground_truth = id_row.iloc[:, 2:11].values.tolist()[0]

        binary_ground_truth = []
        explanation_ground_truth = []
        for g in ground_truth:
            # pandas reads empty CSV cells as NaN (a float with no len()).
            if not pd.isna(g) and len(str(g)) > 0:
                binary_ground_truth.append("Yes")
                explanation_ground_truth.append(g)
            else:
                binary_ground_truth.append("No")
                explanation_ground_truth.append("The article does not provide any relevant information.")
        return binary_ground_truth, explanation_ground_truth
class EvaluationMetrics:
    """
    Metrics comparing LLM outputs against ground truth.

    Attributes:
        explanation_response: The LLM's explanation for each query.
        explanation_ground_truth: The reference explanation for each query.
        embedding_model: A model exposing ``encode(list_of_str)`` that returns one
            embedding per string (used for cosine similarity).

    Methods:
        metric_cosine_similarity(): Cosine similarity between each response and its ground truth.
        metric_rouge(): ROUGE scores of the responses against the ground truth.
        binary_accuracy(binary_response, binary_ground_truth): Share of matching Yes/No answers.
    """

    def __init__(self, explanation_response, explanation_ground_truth, embedding_model):
        self.explanation_response = explanation_response
        self.explanation_ground_truth = explanation_ground_truth
        self.embedding_model = embedding_model

    def metric_cosine_similarity(self):
        """
        Returns the cosine similarity of each (ground truth, response) pair.

        The diagonal of the pairwise similarity matrix pairs item i with item i.
        With unequal list lengths, only the first min(len) pairs are compared.
        """
        ground_truth_embedding = self.embedding_model.encode(self.explanation_ground_truth)
        explanation_response_embedding = self.embedding_model.encode(self.explanation_response)
        return np.diag(cosine_similarity(ground_truth_embedding, explanation_response_embedding))

    def metric_rouge(self):
        """Returns the ROUGE scores from Hugging Face ``evaluate`` for responses vs. ground truth."""
        rouge = evaluate.load("rouge")
        return rouge.compute(predictions=self.explanation_response, references=self.explanation_ground_truth)

    def binary_accuracy(self, binary_response, binary_ground_truth):
        """
        Returns the fraction of positions where the two answer lists agree.

        Args:
            binary_response: Predicted answers (e.g. "Yes"/"No").
            binary_ground_truth: Expected answers, same length.

        Returns:
            The accuracy rounded to 2 decimals, NaN for two empty lists, or an
            explanatory message string when the lengths differ.
        """
        if len(binary_response) != len(binary_ground_truth):
            return "Arrays which are to be compared has different lengths."
        if len(binary_response) == 0:
            # Accuracy over zero items is undefined; avoid ZeroDivisionError.
            return float("nan")
        matches = sum(1 for predicted, expected in zip(binary_response, binary_ground_truth) if predicted == expected)
        return np.round(matches / len(binary_response), 2)