Spaces:
Running
Running
#load & split data | |
from langchain.text_splitter import RecursiveCharacterTextSplitter | |
# embed data | |
from langchain_mistralai import MistralAIEmbeddings | |
# vector store | |
from langchain_community.vectorstores import FAISS | |
# prompt | |
from langchain.prompts import PromptTemplate | |
# memory | |
from langchain.memory import ConversationBufferMemory | |
#llm | |
from langchain_mistralai.chat_models import ChatMistralAI | |
#chain modules | |
from langchain.chains import RetrievalQA | |
from langchain.embeddings import CacheBackedEmbeddings | |
from langchain.storage import LocalFileStore | |
from langchain_community.document_loaders import PyPDFLoader | |
# import PyPDF2 | |
import os | |
import re | |
from dotenv import load_dotenv | |
load_dotenv() | |
from collections import defaultdict | |
# Mistral API key taken from the process environment (which load_dotenv() above
# may have populated from a .env file); None when the variable is not defined.
api_key = os.getenv("MISTRAL_API_KEY")
def extract_pdfs_from_folder(folder_path):
    """Load every PDF file directly inside ``folder_path`` with PyPDFLoader.

    Files are matched by a case-insensitive ``.pdf`` extension and processed in
    sorted file-name order, so the returned document list (and any index built
    from it) is reproducible across runs. Subdirectories are not searched, and
    directory entries whose name merely ends in ``.pdf`` are skipped.

    Args:
        folder_path (str): directory to scan.

    Returns:
        list: the documents returned by ``PyPDFLoader(path).load()`` for each
        matching file, concatenated in file order. Empty if no PDFs are found.
    """
    pdf_files = [
        os.path.join(folder_path, file_name)
        # os.listdir order is arbitrary; sort for deterministic document order.
        for file_name in sorted(os.listdir(folder_path))
        if file_name.lower().endswith(".pdf")
        and os.path.isfile(os.path.join(folder_path, file_name))
    ]
    documents = []
    for pdf_file in pdf_files:
        documents.extend(PyPDFLoader(pdf_file).load())
    return documents
class RagModule:
    """RAG helper: Mistral embeddings + FAISS retrieval + ChatMistralAI answers over local PDFs."""

    def __init__(self):
        self.mistral_api_key = api_key
        self.model_name_embedding = "mistral-embed"
        # Never print the key itself: application logs may be readable by others.
        if not self.mistral_api_key:
            print("Warning: MISTRAL_API_KEY is not set")
        self.embedding_model = MistralAIEmbeddings(
            model=self.model_name_embedding, mistral_api_key=self.mistral_api_key
        )
        self.chunk_size = 1000
        self.chunk_overlap = 120
        self.text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=self.chunk_size, chunk_overlap=self.chunk_overlap
        )
        # NOTE(review): not read by any method here; get_faiss_db saves to "faiss_index".
        self.db_faiss_path = "data/vector_store"
        # LLM generation parameters (see load_mistral).
        self.llm_model = "mistral-small"
        self.max_new_tokens = 512
        self.top_p = 0.5
        self.temperature = 0.1

    def split_text(self, text: str) -> list:
        """Split ``text`` into chunks with the instance splitter.

        Args:
            text (str): raw text to split.

        Returns:
            list: chunk strings (size/overlap from ``chunk_size``/``chunk_overlap``).
        """
        return self.text_splitter.split_text(text)

    def get_metadata(self, texts: list) -> list:
        """Build one metadata dict per chunk, labelling it by its position.

        Args:
            texts (list): chunks, e.g. the output of ``split_text``.

        Returns:
            list: ``[{"source": "Paragraphe: <i>"}, ...]`` aligned with ``texts``.
        """
        return [{"source": f'Paragraphe: {i}'} for i in range(len(texts))]

    def get_faiss_db(self):
        """Build a FAISS vector store from every PDF in ``./data/`` and save it.

        The index is rebuilt on every call; nothing is loaded from disk. PDFs are
        split into 1000-character chunks (overlap 100), embedded with the
        instance's Mistral embedding model through a file-backed embedding cache
        in ``./cache/``, and the resulting index is written to ``faiss_index``.

        Returns:
            FAISS: the freshly built vector store.

        Raises:
            ValueError: if no PDF content was found in ``./data/``.
        """
        data = extract_pdfs_from_folder("./data/")
        # NOTE(review): overlap 100 differs from self.chunk_overlap (120); kept so
        # existing chunking and cache entries stay the same. Confirm which is intended.
        text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
        chunked_documents = text_splitter.split_documents(data)
        if not chunked_documents:
            # FAISS.from_documents fails with an obscure error on an empty list.
            raise ValueError("No PDF content found in ./data/ to build the FAISS index")
        store = LocalFileStore("./cache/")
        # Reuse the embedding model built in __init__ (same model and key).
        embedder = CacheBackedEmbeddings.from_bytes_store(
            self.embedding_model, store, namespace=self.embedding_model.model
        )
        vector_store = FAISS.from_documents(chunked_documents, embedder)
        vector_store.save_local("faiss_index")
        return vector_store

    def set_custom_prompt(self, prompt_template: str):
        """Create a PromptTemplate for Q&A retrieval.

        Args:
            prompt_template (str): template text; its ``{placeholders}`` become
                the prompt's input variables.

        Returns:
            PromptTemplate: the compiled prompt.
        """
        return PromptTemplate.from_template(template=prompt_template)

    def load_mistral(self):
        """Instantiate the ChatMistralAI LLM from the instance parameters.

        Returns:
            ChatMistralAI: chat model configured with model name, token limit,
            top_p and temperature.
        """
        model_kwargs = {
            "mistral_api_key": self.mistral_api_key,
            "model": self.llm_model,
            # ChatMistralAI's output-length field is ``max_tokens``; the former
            # ``max_new_tokens`` key is not a field of the model and had no effect.
            "max_tokens": self.max_new_tokens,
            "top_p": self.top_p,
            "temperature": self.temperature,
        }
        return ChatMistralAI(**model_kwargs)

    def _build_retrieval_qa(self, db, prompt_template, k, memory=None):
        """Assemble a 'stuff' RetrievalQA chain over ``db`` that returns source documents."""
        llm = self.load_mistral()
        chain_type_kwargs = {"prompt": self.set_custom_prompt(prompt_template)}
        if memory is not None:
            chain_type_kwargs["memory"] = memory
        return RetrievalQA.from_chain_type(
            llm=llm,
            chain_type='stuff',
            retriever=db.as_retriever(search_kwargs={"k": k}),
            chain_type_kwargs=chain_type_kwargs,
            return_source_documents=True,
        )

    def retrieval_qa_memory_chain(self, db, prompt_template):
        """Build a RetrievalQA chain (top-5 chunks) with conversation memory.

        Memory is a ConversationBufferMemory stored under ``history`` and keyed on
        the ``question`` input; presumably ``prompt_template`` references
        ``{history}`` for it to be used. Verify against callers.

        Args:
            db: vector store exposing ``as_retriever`` (e.g. from ``get_faiss_db``).
            prompt_template (str): prompt text passed to ``set_custom_prompt``.

        Returns:
            RetrievalQA: chain returning the answer and its source documents.
        """
        memory = ConversationBufferMemory(memory_key='history', input_key='question')
        return self._build_retrieval_qa(db, prompt_template, k=5, memory=memory)

    def retrieval_qa_chain(self, db, prompt_template):
        """Build a stateless RetrievalQA chain over the top-3 retrieved chunks.

        Args:
            db: vector store exposing ``as_retriever`` (e.g. from ``get_faiss_db``).
            prompt_template (str): prompt text passed to ``set_custom_prompt``.

        Returns:
            RetrievalQA: chain returning the answer and its source documents.
        """
        # Previously called the nonexistent self.load_llm(); the helper uses load_mistral().
        return self._build_retrieval_qa(db, prompt_template, k=3)

    def get_sources_document(self, source_documents: list) -> dict:
        """Group the pages of retrieved documents by source path.

        Each page is listed once per path, in first-seen order, even when several
        retrieved chunks come from the same page.

        Args:
            source_documents (list): documents whose ``metadata`` has ``"source"``
                and ``"page"`` keys (e.g. the RAG response's source documents).

        Returns:
            dict: {
                path/to/file1 : [0, 1, 3],
                path/to/file2 : [5, 2]
            }
        """
        sources = defaultdict(list)
        for doc in source_documents:
            pages = sources[doc.metadata["source"]]
            page = doc.metadata["page"]
            if page not in pages:
                pages.append(page)
        return sources

    def shape_answer_with_source(self, answer: str, sources: dict):
        """Append a source listing to an answer.

        Each source becomes a line ``Fichier: <file name> - Page: <pages>``. The
        file name is the last path component, and ``<pages>`` is the list as
        Python formats it. A newline always follows the answer.

        Args:
            answer (str): LLM answer text.
            sources (dict): mapping path -> list of pages (see get_sources_document).

        Returns:
            str: the answer followed by the source listing.
        """
        # os.path.basename replaces a "/"-only regex that raised IndexError for
        # bare file names, root-level paths and Windows-style paths.
        source_msg = "".join(
            f"\nFichier: {os.path.basename(path)} - Page: {pages}"
            for path, pages in sources.items()
        )
        return f"{answer}\n{source_msg}"