# enterprise_knowledge_retriever/src/document_retrieval.py
# Last commit a9bf317 (petrojm): "changes to app and document_retrieval"
import os
import shutil
import sys
from typing import Any, Dict, List, Optional
import torch
import yaml
from dotenv import load_dotenv
from langchain.chains.base import Chain
from langchain.docstore.document import Document
from langchain.prompts import BasePromptTemplate, load_prompt
from langchain_core.callbacks import CallbackManagerForChainRun
from langchain_core.language_models import BaseLanguageModel
from langchain_core.output_parsers import StrOutputParser
from langchain_core.retrievers import BaseRetriever
from transformers import AutoModelForSequenceClassification, AutoTokenizer
# Resolve the directories this module relies on and make the two parent
# directories importable (the `utils.*` imports below depend on this).
current_dir = os.path.dirname(os.path.abspath(__file__))  # src/ directory
kit_dir = os.path.abspath(os.path.join(current_dir, os.pardir))  # EKR/ directory
repo_dir = os.path.abspath(os.path.join(kit_dir, os.pardir))  # parent of the kit directory
sys.path.extend([kit_dir, repo_dir])
#import streamlit as st
from utils.model_wrappers.api_gateway import APIGateway
from utils.vectordb.vector_db import VectorDb
from utils.visual.env_utils import get_wandb_key
# YAML configuration file of the kit; read by DocumentRetrieval.get_config_info.
CONFIG_PATH = os.path.join(kit_dir, 'config.yaml')
# Folder under the kit's data directory; presumably the default location of the
# persisted vector DB (not referenced in the visible part of this module — TODO confirm).
PERSIST_DIRECTORY = os.path.join(kit_dir, 'data/my-vector-db')
#load_dotenv(os.path.join(kit_dir, '.env'))
from utils.parsing.sambaparse import parse_doc_universal
# Handle the WANDB_API_KEY resolution before importing weave
#wandb_api_key = get_wandb_key()
# If WANDB_API_KEY is set, proceed with weave initialization
#if wandb_api_key:
# import weave
# Initialize Weave with your project name
# weave.init('sambanova_ekr')
#else:
# print('WANDB_API_KEY is not set. Weave initialization skipped.')
class RetrievalQAChain(Chain):
    """Question-answering chain: retrieve documents, optionally rerank, then ask the LLM.

    The chain takes a ``question`` and returns the generated ``answer`` together
    with the ``source_documents`` that were passed to the prompt as context.
    """

    # Retriever used to fetch candidate documents for the question.
    retriever: BaseRetriever
    # When True, candidates are re-scored with a cross-encoder and truncated.
    rerank: bool = True
    # Language model that produces the answer.
    llm: BaseLanguageModel
    # Prompt template; it is invoked with the keys 'question' and 'context'.
    qa_prompt: BasePromptTemplate
    # Number of documents kept after reranking.
    final_k_retrieved_documents: int = 3

    @property
    def input_keys(self) -> List[str]:
        """Input keys.

        :meta private:
        """
        return ['question']

    @property
    def output_keys(self) -> List[str]:
        """Output keys.

        :meta private:
        """
        return ['answer', 'source_documents']

    def _format_docs(self, docs):
        """Join the page contents of ``docs`` into one context string, separated by blank lines."""
        return '\n\n'.join(doc.page_content for doc in docs)

    def rerank_docs(self, query, docs, final_k, model_name='BAAI/bge-reranker-large'):
        """Re-score ``docs`` against ``query`` and keep the ``final_k`` best ones.

        Args:
            query: The user question.
            docs: Retrieved documents; each must expose ``page_content``.
            final_k: Maximum number of documents to return.
            model_name: Hugging Face model id of the sequence-classification
                reranker. Defaults to the model that was previously hard-coded.

        Returns:
            list: Up to ``final_k`` documents ordered by descending score.
            An empty list when ``docs`` is empty.
        """
        # Nothing to score: skip the (expensive) model load and avoid handing
        # an empty batch to the tokenizer.
        if not docs:
            return []

        # NOTE: the tokenizer and model are loaded on every call.
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        reranker = AutoModelForSequenceClassification.from_pretrained(model_name)

        pairs = [[query, d.page_content] for d in docs]
        with torch.no_grad():
            # Pairs are padded to a common length and truncated to 512 tokens.
            inputs = tokenizer(
                pairs,
                padding=True,
                truncation=True,
                return_tensors='pt',
                max_length=512,
            )
            # Flatten the logits to 1-D; assumes the model emits a single
            # logit per (query, document) pair.
            scores = reranker(**inputs, return_dict=True).logits.view(-1).float()

        scores_list = scores.tolist()
        # Stable sort on the indices so documents with equal scores keep
        # their retrieval order.
        scores_sorted_idx = sorted(range(len(scores_list)), key=lambda k: scores_list[k], reverse=True)
        docs_sorted = [docs[k] for k in scores_sorted_idx]
        return docs_sorted[:final_k]

    def _call(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, Any]:
        """Answer ``inputs['question']`` using retrieved context.

        Args:
            inputs: Mapping that must contain the key ``'question'``.
            run_manager: Accepted for interface compatibility; not used here.

        Returns:
            Dict[str, Any]: ``'answer'`` (the parsed LLM output) and
            ``'source_documents'`` (the documents used as context).
        """
        qa_chain = self.qa_prompt | self.llm | StrOutputParser()
        response = {}
        documents = self.retriever.invoke(inputs['question'])
        if self.rerank:
            documents = self.rerank_docs(inputs['question'], documents, self.final_k_retrieved_documents)
        docs = self._format_docs(documents)
        response['answer'] = qa_chain.invoke({'question': inputs['question'], 'context': docs})
        response['source_documents'] = documents
        return response
class DocumentRetrieval:
def __init__(self, sambanova_api_key):
    """Create the vector DB helper, load the configuration and build the LLM.

    Args:
        sambanova_api_key: API key forwarded to ``set_llm``.
    """
    self.vectordb = VectorDb()
    # get_config_info returns the six config sections in this fixed order.
    (
        self.api_info,
        self.llm_info,
        self.embedding_model_info,
        self.retrieval_info,
        self.prompts,
        self.prod_mode,
    ) = self.get_config_info()
    # The retriever is only available after init_retriever has been called.
    self.retriever = None
    self.llm = self.set_llm(sambanova_api_key)
def get_config_info(self):
    """Load the kit configuration from the YAML file at ``CONFIG_PATH``.

    Returns:
        tuple: ``(api_info, llm_info, embedding_model_info, retrieval_info,
        prompts, prod_mode)``, i.e. the values stored under the config keys
        ``'api'``, ``'llm'``, ``'embedding_model'``, ``'retrieval'``,
        ``'prompts'`` and ``'prod_mode'``.

    Raises:
        KeyError: If the loaded mapping lacks one of these keys.
    """
    # Decode explicitly as UTF-8 so the result does not depend on the
    # platform's default locale encoding.
    with open(CONFIG_PATH, 'r', encoding='utf-8') as yaml_file:
        config = yaml.safe_load(yaml_file)

    sections = ('api', 'llm', 'embedding_model', 'retrieval', 'prompts', 'prod_mode')
    return tuple(config[section] for section in sections)
def set_llm(self, sambanova_api_key):
    """Build the LLM through ``APIGateway.load_llm`` from the loaded configuration.

    The API key is supplied by the caller; this method does not read it from
    the environment or from Streamlit session state.

    Args:
        sambanova_api_key: API key forwarded to ``APIGateway.load_llm``.

    Returns:
        The object returned by ``APIGateway.load_llm``.
    """
    # Generation settings copied verbatim from the 'llm' section of the config.
    forwarded_keys = ('coe', 'do_sample', 'max_tokens_to_generate', 'temperature', 'select_expert')
    generation_settings = {key: self.llm_info[key] for key in forwarded_keys}
    return APIGateway.load_llm(
        type=self.api_info,
        streaming=True,
        **generation_settings,
        process_prompt=False,
        sambanova_api_key=sambanova_api_key,
    )
def parse_doc(self, docs: List, additional_metadata: Optional[Dict] = None) -> List[Document]:
    """
    Parse the uploaded documents and return a list of LangChain documents.

    The files are first written to the kit's ``data/tmp`` folder (which is
    emptied beforehand) and that folder is then handed to
    ``parse_doc_universal``.

    Args:
        docs (List): ``(file_name, file_obj)`` pairs. ``file_obj`` must provide
            ``read()`` returning the file content as bytes.
        additional_metadata (Optional[Dict], optional): Additional metadata to include in the processed documents.
            ``None`` (the default) is treated as an empty dictionary.

    Returns:
        List[Document]: A list of LangChain documents.

    Raises:
        ValueError: If a file name has no usable final path component.
    """
    if additional_metadata is None:
        additional_metadata = {}

    # Create the data/tmp folder if it doesn't exist
    temp_folder = os.path.join(kit_dir, 'data/tmp')
    if not os.path.exists(temp_folder):
        os.makedirs(temp_folder)
    else:
        # If there are already files there, delete them (best effort: a
        # failure is reported and the remaining entries are still processed)
        for filename in os.listdir(temp_folder):
            file_path = os.path.join(temp_folder, filename)
            try:
                if os.path.isfile(file_path) or os.path.islink(file_path):
                    os.unlink(file_path)
                elif os.path.isdir(file_path):
                    shutil.rmtree(file_path)
            except Exception as e:
                print(f'Failed to delete {file_path}. Reason: {e}')

    # Save all selected files to the tmp dir with their file names
    for file_name, file_obj in docs:
        # The name comes from the uploader: keep only its last component so
        # that '../' segments or absolute paths cannot escape temp_folder.
        safe_name = os.path.basename(file_name)
        if safe_name in ('', '.', '..'):
            raise ValueError(f'Invalid file name: {file_name!r}')
        temp_file = os.path.join(temp_folder, safe_name)
        with open(temp_file, 'wb') as f:
            f.write(file_obj.read())

    # Pass in the temp folder for processing into the parse_doc_universal function
    _, _, langchain_docs = parse_doc_universal(doc=temp_folder, additional_metadata=additional_metadata)
    return langchain_docs
def load_embedding_model(self):
    """Build the embedding model through ``APIGateway.load_embedding_model``.

    All arguments are taken from the 'embedding_model' section of the config.

    Returns:
        The object returned by ``APIGateway.load_embedding_model``.
    """
    settings = self.embedding_model_info
    return APIGateway.load_embedding_model(
        type=settings['type'],
        batch_size=settings['batch_size'],
        coe=settings['coe'],
        select_expert=settings['select_expert'],
    )
def create_vector_store(self, text_chunks, embeddings, output_db=None, collection_name=None):
    """Create a Chroma vector store by delegating to ``self.vectordb``.

    Args:
        text_chunks: Chunks to index; forwarded unchanged.
        embeddings: Embedding model; forwarded unchanged.
        output_db: Forwarded as ``output_db``.
        collection_name: Forwarded as ``collection_name``.

    Returns:
        Whatever ``self.vectordb.create_vector_store`` returns.
    """
    print(f'Collection name is {collection_name}')
    return self.vectordb.create_vector_store(
        text_chunks,
        embeddings,
        output_db=output_db,
        collection_name=collection_name,
        db_type='chroma',
    )
def load_vdb(self, db_path, embeddings, collection_name=None):
    """Load an existing Chroma vector store by delegating to ``self.vectordb``.

    Args:
        db_path: Location of the stored database; forwarded unchanged.
        embeddings: Embedding model; forwarded unchanged.
        collection_name: Forwarded as ``collection_name``.

    Returns:
        Whatever ``self.vectordb.load_vdb`` returns.
    """
    print(f'Loading collection name is {collection_name}')
    return self.vectordb.load_vdb(
        db_path,
        embeddings,
        db_type='chroma',
        collection_name=collection_name,
    )
def init_retriever(self, vectorstore):
    """Create ``self.retriever`` from ``vectorstore`` using the retrieval config.

    The retriever always uses the 'similarity_score_threshold' search type.
    With reranking enabled it fetches ``k_retrieved_documents`` candidates;
    otherwise it fetches ``final_k_retrieved_documents`` directly.

    Args:
        vectorstore: Object exposing ``as_retriever(search_type=..., search_kwargs=...)``.
    """
    settings = self.retrieval_info
    # Only the config key that supplies 'k' differs between the two modes.
    k_key = 'k_retrieved_documents' if settings['rerank'] else 'final_k_retrieved_documents'
    self.retriever = vectorstore.as_retriever(
        search_type='similarity_score_threshold',
        search_kwargs={
            'score_threshold': settings['score_threshold'],
            'k': settings[k_key],
        },
    )
def get_qa_retrieval_chain(self):
    """
    Build a ``RetrievalQAChain`` (question answering without memory).

    The chain is assembled from ``self.retriever``, ``self.llm``, the QA prompt
    file named by ``self.prompts['qa_prompt']`` (resolved relative to the kit
    directory) and the rerank settings of the retrieval config.

    Returns:
        RetrievalQAChain: A chain ready for QA without memory
    """
    retriever = self.retriever
    llm = self.llm
    qa_prompt = load_prompt(os.path.join(kit_dir, self.prompts['qa_prompt']))
    retrieval_settings = self.retrieval_info
    return RetrievalQAChain(
        retriever=retriever,
        llm=llm,
        qa_prompt=qa_prompt,
        rerank=retrieval_settings['rerank'],
        final_k_retrieved_documents=retrieval_settings['final_k_retrieved_documents'],
    )
def get_conversational_qa_retrieval_chain(self):
"""
Generate a conversational retrieval qa chain using a language model.
This function uses a language model, specifically a SambaNova LLM, to generate a conversational_qa_retrieval chain
based on the chat history and the relevant retrieved content from the input vector store of text chunks.
Parameters:
vectorstore (Chroma): A Vector Store containing embeddings of text chunks used as context
for generating the conversation chain.
Returns:
RetrievalQA: A chain ready for QA with memory
"""