# ASSISTANT_PAC / rag_module.py
# Ilyas KHIAT
# test
# e346593
# raw
# history blame
# 6.43 kB
#load & split data
from langchain.text_splitter import RecursiveCharacterTextSplitter
# embed data
from langchain_mistralai import MistralAIEmbeddings
# vector store
from langchain_community.vectorstores import FAISS
# prompt
from langchain.prompts import PromptTemplate
# memory
from langchain.memory import ConversationBufferMemory
#llm
from langchain_mistralai.chat_models import ChatMistralAI
#chain modules
from langchain.chains import RetrievalQA
from langchain.embeddings import CacheBackedEmbeddings
from langchain.storage import LocalFileStore
from langchain_community.document_loaders import PyPDFLoader
# import PyPDF2
import os
import re
from dotenv import load_dotenv
load_dotenv()
from collections import defaultdict
# Mistral API key taken from the process environment (None when unset).
api_key = os.getenv("MISTRAL_API_KEY")
def extract_pdfs_from_folder(folder_path):
    """Load every PDF file located directly inside ``folder_path``.

    Files are matched on a case-insensitive ``.pdf`` extension and processed in
    sorted filename order, so repeated runs yield the same document sequence.
    Sub-directories are skipped (even if their name ends in ``.pdf``); the scan
    is not recursive.

    Args:
        folder_path (str): directory to scan.

    Returns:
        list: the items returned by ``PyPDFLoader(path).load()`` for each PDF,
        concatenated in file order.
    """
    extracted_texts = []
    # os.listdir order is arbitrary; sort so the resulting document order is
    # deterministic across runs and platforms.
    for file_name in sorted(os.listdir(folder_path)):
        pdf_path = os.path.join(folder_path, file_name)
        # Case-insensitive so "Report.PDF" is included; isfile() avoids handing
        # a directory that merely ends in ".pdf" to the loader.
        if not file_name.lower().endswith(".pdf") or not os.path.isfile(pdf_path):
            continue
        extracted_texts.extend(PyPDFLoader(pdf_path).load())
    return extracted_texts
class RagModule():
    """Retrieval-augmented generation helpers built on Mistral AI and FAISS.

    Holds the embedding, text-splitting and LLM configuration, and exposes
    helpers to build the FAISS vector store, assemble RetrievalQA chains and
    format the source documents of an answer.
    """

    def __init__(self):
        self.mistral_api_key = api_key
        self.model_name_embedding = "mistral-embed"
        # The API key is deliberately not printed: echoing it leaks the secret
        # into consoles and logs.
        self.embedding_model = MistralAIEmbeddings(model=self.model_name_embedding, mistral_api_key=self.mistral_api_key)
        self.chunk_size = 1000
        self.chunk_overlap = 120
        self.text_splitter = RecursiveCharacterTextSplitter(chunk_size=self.chunk_size, chunk_overlap=self.chunk_overlap)
        # NOTE(review): not read anywhere in this class; get_faiss_db saves to
        # "faiss_index" instead.
        self.db_faiss_path = "data/vector_store"
        # params llm
        self.llm_model = "mistral-small"
        self.max_new_tokens = 512
        self.top_p = 0.5
        self.temperature = 0.1

    def split_text(self, text: str) -> list:
        """Split ``text`` into chunks with the instance's text splitter.

        Args:
            text (str): raw text to split.

        Returns:
            list: chunks of at most ``self.chunk_size`` characters, overlapping
            by ``self.chunk_overlap``.
        """
        return self.text_splitter.split_text(text)

    def get_metadata(self, texts: list) -> list:
        """Build one metadata dict per chunk, labelling it by its position.

        Args:
            texts (list): chunks whose metadata is generated.

        Returns:
            list: ``[{"source": "Paragraphe: 0"}, {"source": "Paragraphe: 1"}, ...]``.
        """
        return [{"source": f'Paragraphe: {i}'} for i in range(len(texts))]

    def get_faiss_db(self):
        """Build a FAISS vector store from every PDF in ``./data/``.

        The PDFs are split into 1000-character chunks (100 overlap), embedded
        with a disk cache in ``./cache/`` so unchanged chunks are not re-sent to
        the embedding API, and the index is saved to ``faiss_index``.

        Returns:
            FAISS: the freshly built vector store.
        """
        data = extract_pdfs_from_folder("./data/")
        # Kept at overlap=100 (not self.chunk_overlap) so chunk boundaries,
        # and therefore cached embeddings, stay identical to earlier builds.
        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=1000,
            chunk_overlap=100
        )
        chunked_documents = text_splitter.split_documents(data)
        embedding_model = self.embedding_model
        store = LocalFileStore("./cache/")
        embedder = CacheBackedEmbeddings.from_bytes_store(embedding_model, store, namespace=embedding_model.model)
        vector_store = FAISS.from_documents(chunked_documents, embedder)
        vector_store.save_local("faiss_index")
        return vector_store

    def set_custom_prompt(self, prompt_template: str):
        """Create a prompt template for Q&A retrieval.

        Args:
            prompt_template (str): template text; its ``{placeholders}`` become
                the prompt's input variables.

        Returns:
            PromptTemplate: the compiled template.
        """
        return PromptTemplate.from_template(
            template=prompt_template,
        )

    def load_mistral(self):
        """Instantiate the Mistral chat model from the instance settings.

        Returns:
            ChatMistralAI: chat model configured with model name, output token
            limit, top-p and temperature.
        """
        model_kwargs = {
            "mistral_api_key": self.mistral_api_key,
            "model": self.llm_model,
            # ChatMistralAI's output limit field is ``max_tokens``; the previous
            # ``max_new_tokens`` key is not a field, so the limit was ignored.
            "max_tokens": self.max_new_tokens,
            "top_p": self.top_p,
            "temperature": self.temperature,
        }
        return ChatMistralAI(**model_kwargs)

    def retrieval_qa_memory_chain(self, db, prompt_template):
        """Build a 'stuff' RetrievalQA chain with conversation memory.

        Args:
            db: vector store providing ``as_retriever`` (top 5 documents).
            prompt_template (str): template; must reference ``{history}``,
                ``{context}`` and ``{question}`` (memory writes to ``history``
                and reads the ``question`` input).

        Returns:
            RetrievalQA: chain that also returns its source documents.
        """
        llm = self.load_mistral()
        prompt = self.set_custom_prompt(prompt_template)
        memory = ConversationBufferMemory(
            memory_key='history',
            input_key='question'
        )
        chain_type_kwargs = {
            "prompt": prompt,
            "memory": memory
        }
        return RetrievalQA.from_chain_type(
            llm=llm,
            chain_type='stuff',
            retriever=db.as_retriever(search_kwargs={"k": 5}),
            chain_type_kwargs=chain_type_kwargs,
            return_source_documents=True,
        )

    def retrieval_qa_chain(self, db, prompt_template):
        """Build a stateless 'stuff' RetrievalQA chain.

        Args:
            db: vector store providing ``as_retriever`` (top 3 documents).
            prompt_template (str): template for the combine-documents prompt.

        Returns:
            RetrievalQA: chain that also returns its source documents.
        """
        # Was ``self.load_llm()``, which does not exist (AttributeError).
        llm = self.load_mistral()
        prompt = self.set_custom_prompt(prompt_template)
        chain_type_kwargs = {
            "prompt": prompt,
        }
        return RetrievalQA.from_chain_type(
            llm=llm,
            chain_type='stuff',
            retriever=db.as_retriever(search_kwargs={"k": 3}),
            chain_type_kwargs=chain_type_kwargs,
            return_source_documents=True,
        )

    def get_sources_document(self, source_documents: list) -> dict:
        """Group the pages of the retrieved documents by source path.

        Args:
            source_documents (list): documents whose ``metadata`` holds
                ``"source"`` and ``"page"`` keys.

        Returns:
            dict: ``{path/to/file1: [0, 1, 3], path/to/file2: [5, 2]}``, with
            pages in retrieval order (a defaultdict(list)).
        """
        sources = defaultdict(list)
        for doc in source_documents:
            sources[doc.metadata["source"]].append(doc.metadata["page"])
        return sources

    def shape_answer_with_source(self, answer: str, sources: dict):
        """Append a "Fichier: <name> - Page: <pages>" line per source to ``answer``.

        Args:
            answer (str): generated answer text.
            sources (dict): mapping of file path to its list of pages, as
                returned by ``get_sources_document``.

        Returns:
            str: ``answer`` followed by a blank line and one line per source.
        """
        # basename also handles bare file names ("a.pdf"), on which the former
        # "dir/file" regex found no match and raised IndexError.
        source_msg = "".join(
            f"\nFichier: {os.path.basename(path)} - Page: {pages}"
            for path, pages in sources.items()
        )
        return answer + f"\n{source_msg}"