# med-bot/utils/chatbot.py
import os
import re
import ast
import html
from typing import List, Tuple

import requests
import torch
import gradio as gr
# NOTE: on newer LangChain releases (>= 0.2) these classes live in the
# `langchain_community` package instead of `langchain`.
from langchain.vectorstores import Chroma
from langchain.embeddings import HuggingFaceEmbeddings

from utils.load_config import LoadConfig
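# Local Flask service that performs the actual LLM text generation; it
# receives the prompt payload built in `respond` and returns JSON with a
# "response" field.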
FLASK_APP_ENDPOINT = "http://127.0.0.1:8888/generate_text"
APPCFG = LoadConfig()
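# Project URL for the hyperlink used in error messages (left empty in the
# source).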
URL = ""
hyperlink = f"[RAG]({URL})"
class ChatBot:
    """
    Class representing a chatbot with document retrieval and response
    generation capabilities.

    This class provides static methods for answering user queries via
    retrieval-augmented generation and for cleaning references from
    retrieved documents.
    """
@staticmethod
def respond(chatbot: List,
message: str,
data_type: str = "Preprocessed doc",
temperature: float = 0.1,
top_k: int = 10,
                top_p: float = 0.1) -> Tuple:
        """
        Generate a response to a user query using document retrieval and
        language model completion.

        Parameters:
            chatbot (List): The chatbot's conversation history as a list of
                (user message, bot response) pairs.
            message (str): The user's query.
            data_type (str): Type of data used for document retrieval
                ("Preprocessed doc" or "Upload doc: Process for RAG").
            temperature (float): Sampling temperature for language model
                completion.
            top_k (int): Top-k sampling parameter for language model
                completion.
            top_p (float): Top-p (nucleus) sampling parameter for language
                model completion.

        Returns:
            Tuple: An empty string (clearing the input textbox), the updated
            chat history, and markdown-formatted references from the
            retrieved documents.
        """
# Retrieve embedding function from code env resources
# emb_model = "sentence-transformers/all-MiniLM-L6-v2"
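        # Biomedical sentence-embedding model (PubMedBERT-based); the embedding
        # model must match the one used when the vector database was built.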
embedding_function = HuggingFaceEmbeddings(
model_name="NeuML/pubmedbert-base-embeddings",
# cache_folder=os.getenv('SENTENCE_TRANSFORMERS_HOME')
)
if data_type == "Preprocessed doc":
            # Use the prebuilt vector database if it exists.
if os.path.exists(APPCFG.persist_directory):
vectordb = Chroma(persist_directory=APPCFG.persist_directory,
embedding_function=embedding_function)
else:
chatbot.append(
(message, f"VectorDB does not exist. Please first execute the 'upload_data_manually.py' module. For further information please visit {hyperlink}."))
return "", chatbot, None
elif data_type == "Upload doc: Process for RAG":
if os.path.exists(APPCFG.custom_persist_directory):
vectordb = Chroma(persist_directory=APPCFG.custom_persist_directory,
embedding_function=embedding_function)
else:
                chatbot.append(
                    (message, "No file was uploaded. Please first upload your files using the 'upload' button."))
return "", chatbot, None
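        # Retrieve the k most similar chunks to the user's query.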
docs = vectordb.similarity_search(message, k=APPCFG.k)
question = "# Prompt that you have to answer:\n" + message
retrieved_content, markdown_documents = ChatBot.clean_references(docs)
        # Memory: include the most recent Q&A pairs (count set in the config).
        chat_history = f"Chat history:\n {str(chatbot[-APPCFG.number_of_q_a_pairs:])}\n\n"
if APPCFG.add_history:
prompt_wrapper = f"{APPCFG.llm_system_role_with_history}\n\n{chat_history}\n\n{retrieved_content}{question}"
else:
prompt_wrapper = f"{APPCFG.llm_system_role_without_history}\n\n{question}\n\n{retrieved_content}"
print("========================")
print(prompt_wrapper)
print("========================")
messages = [
{"role": "user", "content": prompt_wrapper},
]
data = {
"prompt": messages,
"max_new_tokens": APPCFG.max_new_tokens,
"do_sample": APPCFG.do_sample,
"temperature": temperature,
"top_k": top_k,
"top_p": top_p
}
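        # The Flask endpoint is expected to accept this JSON payload and reply
        # with {"response": "<generated text>"}.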
        response = requests.post(FLASK_APP_ENDPOINT, json=data, timeout=600)  # timeout (seconds) is an assumed value
        # Fail fast on HTTP errors instead of parsing an error page as JSON.
        response.raise_for_status()
        # print(response.text)
        response_json = response.json()
        chatbot.append(
            (message, response_json["response"]))
        # Clean up GPU memory
        del vectordb
        del docs
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
return "", chatbot, markdown_documents
@staticmethod
    def clean_references(documents: List) -> Tuple[str, str]:
        """
        Clean and format references from retrieved documents.

        Parameters:
            documents (List): List of retrieved documents.

        Returns:
            Tuple[str, str]: The cleaned content that is injected into the
            prompt, and a markdown-formatted string of references (source
            file, page number, and a link to the PDF).
        """
server_url = "http://localhost:8000"
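        # `server_url` above is the static file server hosting the source PDFs
        # linked in the references. Each Document is rendered to its string
        # form below so content and metadata can be parsed out with a regex.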
documents = [str(x)+"\n\n" for x in documents]
markdown_documents = ""
retrieved_content = ""
counter = 1
        for doc in documents:
            # Extract content and metadata from the document's string form,
            # which looks like: page_content='...' metadata={...}
            match = re.match(
                r"page_content=(.*?)( metadata=\{.*\})", doc)
            if match is None:
                # Skip documents whose string form doesn't match the expected layout.
                continue
            content, metadata = match.groups()
            metadata = metadata.split('=', 1)[1]
            metadata_dict = ast.literal_eval(metadata)
# Decode newlines and other escape sequences
content = bytes(content, "utf-8").decode("unicode_escape")
# Replace escaped newlines with actual newlines
content = re.sub(r'\\n', '\n', content)
content = re.sub(r'\s*<EOS>\s*<pad>\s*', ' ', content)
content = re.sub(r'\s+', ' ', content).strip()
# Decode HTML entities
content = html.unescape(content)
            # Normalize characters that PDF extraction commonly garbles
            # (dashes and ligatures). This step may need to be customized
            # based on the specific symbols in your documents.
            content = re.sub(r'–', '-', content)   # en dash -> plain hyphen
            content = re.sub(r'ﬁ', 'fi', content)  # 'fi' ligature
            content = re.sub(r'ﬂ', 'fl', content)  # 'fl' ligature
pdf_url = f"{server_url}/{os.path.basename(metadata_dict['source'])}"
retrieved_content += f"# Content {counter}:\n" + \
content + "\n\n"
# Append cleaned content to the markdown string with two newlines between documents
markdown_documents += f"# Retrieved content {counter}:\n" + content + "\n\n" + \
f"Source: {os.path.basename(metadata_dict['source'])}" + " | " +\
f"Page number: {str(metadata_dict['page'])}" + " | " +\
f"[View PDF]({pdf_url})" "\n\n"
counter += 1
return retrieved_content, markdown_documents
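

# ---------------------------------------------------------------------------
# Minimal usage sketch: wiring `ChatBot.respond` into a Gradio Blocks app.
# This is an illustrative assumption, not the project's actual UI; the
# component names and labels below are made up, and `respond`'s sampling
# parameters are left at their defaults.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    with gr.Blocks() as demo:
        chatbot_ui = gr.Chatbot()
        references = gr.Markdown()
        msg = gr.Textbox(label="Your question")
        # `respond` takes (history, message, ...) and returns
        # (cleared textbox, updated history, reference markdown).
        msg.submit(ChatBot.respond,
                   inputs=[chatbot_ui, msg],
                   outputs=[msg, chatbot_ui, references])
    demo.launch()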