File size: 3,884 Bytes
a3d26e6
 
 
 
 
 
 
 
 
 
 
 
 
 
59ba192
a3d26e6
 
 
 
 
 
 
 
 
 
 
 
 
ea077e1
a3d26e6
 
 
59ba192
 
a3d26e6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
59ba192
89033ee
a3d26e6
 
89033ee
 
 
 
a3d26e6
 
89033ee
a3d26e6
3a8ddd8
 
a3d26e6
 
 
 
3a8ddd8
 
 
 
 
59ba192
 
 
 
 
 
 
3a8ddd8
59ba192
3a8ddd8
a3d26e6
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
import os

from dotenv import load_dotenv
from langchain_community.vectorstores import FAISS
from langchain_mistralai.chat_models import ChatMistralAI
from langchain_mistralai.embeddings import MistralAIEmbeddings
from langchain.schema.output_parser import StrOutputParser
from langchain_community.document_loaders import PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.schema.runnable import RunnablePassthrough
from langchain.prompts import PromptTemplate
from langchain_community.vectorstores.utils import filter_complex_metadata
from langchain_community.document_loaders.csv_loader import CSVLoader

from util import getYamlConfig


# Read settings from a local .env file (used during local development),
# then fetch the Mistral API key from the environment (None when unset).
load_dotenv()
env_api_key = os.getenv("MISTRAL_API_KEY")

class Rag:
    """Retrieval-augmented generation helper built on Mistral models and LangChain.

    Two optional context sources feed the prompt in ``ask``:

    * an in-memory FAISS retriever built from a single PDF by ``ingest``;
    * an external ``vector_store`` (given to the constructor) that receives
      chunks via ``addDoc`` and is queried via its ``retriever`` method.

    A chat model must be supplied with ``setModel`` before calling ``ask``.
    """

    # Class-level defaults; instances overwrite these as they are populated.
    document_vector_store = None  # FAISS index built by ``ingest``
    retriever = None  # retriever over ``document_vector_store``
    chain = None  # last prompt | model | parser chain built by ``ask``
    model = None  # chat model set via ``setModel``

    def __init__(self, vectore_store=None, chunk_size: int = 500, chunk_overlap: int = 100):
        """Create the embedding model, the text splitter and the prompt.

        Args:
            vectore_store: Optional external vector store. ``ingestToDb`` calls
                its ``addDoc``, ``getDbFiles`` its ``getDocs`` and ``ask`` its
                ``retriever`` method. (Parameter name kept for compatibility.)
            chunk_size: Target maximum chunk length, measured with ``len``
                (default 500).
            chunk_overlap: Length shared by consecutive chunks (default 100).
        """
        self.embedding = MistralAIEmbeddings(model="mistral-embed", mistral_api_key=env_api_key)

        self.text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=chunk_size,
            chunk_overlap=chunk_overlap,
            length_function=len,
        )

        # The prompt text comes from the YAML config key ``prompt_template``.
        base_template = getYamlConfig()['prompt_template']
        self.prompt = PromptTemplate.from_template(base_template)

        self.vector_store = vectore_store

    def setModel(self, model):
        """Set the chat model that ``ask`` pipes the prompt into."""
        self.model = model

    def ingestToDb(self, file_path: str, filename: str):
        """Load a PDF, split its text into chunks and hand them to ``vector_store.addDoc``.

        Args:
            file_path: Path of the PDF on disk.
            filename: Name passed to ``addDoc`` together with the chunks.

        Returns:
            Whatever ``vector_store.addDoc`` returns.
        """
        pages = PyPDFLoader(file_path=file_path).load()

        # Separate pages with a newline: plain concatenation glued the last word
        # of one page to the first word of the next.
        text = "\n".join(page.page_content for page in pages)

        chunks = self.text_splitter.split_text(text)

        return self.vector_store.addDoc(filename=filename, text_chunks=chunks, embedding=self.embedding)

    def getDbFiles(self):
        """Return the result of ``vector_store.getDocs()``."""
        return self.vector_store.getDocs()

    def ingest(self, pdf_file_path: str, k: int = 3, score_threshold: float = 0.5):
        """Build an in-memory FAISS retriever over one PDF for use by ``ask``.

        Args:
            pdf_file_path: Path of the PDF to index.
            k: Maximum number of chunks retrieved per query (default 3).
            score_threshold: Minimum similarity score for a chunk to be
                retrieved (default 0.5).
        """
        docs = PyPDFLoader(file_path=pdf_file_path).load()

        chunks = self.text_splitter.split_documents(docs)
        # Remove metadata values that filter_complex_metadata rejects before indexing.
        chunks = filter_complex_metadata(chunks)

        # Store the index on the instance: the class declares this attribute
        # and ``clear`` resets it, but it was previously only a local.
        self.document_vector_store = FAISS.from_documents(chunks, self.embedding)

        self.retriever = self.document_vector_store.as_retriever(
            search_type="similarity_score_threshold",
            search_kwargs={
                "k": k,
                "score_threshold": score_threshold,
            },
        )

    def ask(self, query: str, messages: list, variables: list = None):
        """Answer ``query`` with the chat model, injecting retrieved context.

        Args:
            query: The user question; also used as the retrieval query.
            messages: Conversation history, passed to the prompt as ``messages``.
            variables: Optional list of ``{"key": ..., "value": ...}`` dicts.
                Each item having both keys becomes an extra prompt variable,
                overriding a base variable of the same name.

        Returns:
            The model output parsed by ``StrOutputParser``.

        Raises:
            AttributeError: If no model has been set with ``setModel``.
        """
        if self.model is None:
            # AttributeError keeps the exception type the old implicit failure raised.
            raise AttributeError("No chat model configured: call setModel() before ask().")

        self.chain = self.prompt | self.model | StrOutputParser()

        # Context from the PDF indexed by ``ingest`` (empty if none was ingested).
        if self.retriever is None:
            documentContext = ''
        else:
            documentContext = self.retriever.invoke(query)

        # Context from the external vector store (empty if none is configured,
        # e.g. the constructor default or after ``clear``).
        if self.vector_store is None:
            contextCommon = ''
        else:
            contextCommon = self.vector_store.retriever(query, self.embedding)

        # Base prompt variables.
        chain_input = {
            "query": query,
            "documentContext": documentContext,
            "commonContext": contextCommon,
            "messages": messages,
        }

        # Drop entries whose value is None.
        chain_input = {k: v for k, v in chain_input.items() if v is not None}

        if variables:
            # Turn [{"key": k, "value": v}, ...] into {k: v}, skipping incomplete items,
            # then merge into the base variables.
            extra_vars = {
                item['key']: item['value']
                for item in variables
                if 'key' in item and 'value' in item
            }
            chain_input.update(extra_vars)

        return self.chain.invoke(chain_input)

    def clear(self):
        """Reset the ingested document index, the vector store, the retriever and the chain.

        The model is kept, so ``ask`` still works afterwards (without retrieval
        context); ``ingestToDb`` and ``getDbFiles`` need a new ``vector_store``.
        """
        self.document_vector_store = None
        self.vector_store = None
        self.retriever = None
        self.chain = None