|
import gradio as gr
|
|
from langchain_mistralai.chat_models import ChatMistralAI
|
|
from langchain.prompts import ChatPromptTemplate
|
|
import os
|
|
from pathlib import Path
|
|
from typing import List, Dict, Optional
|
|
import json
|
|
import faiss
|
|
import numpy as np
|
|
from langchain.schema import Document
|
|
from sentence_transformers import SentenceTransformer
|
|
import pickle
|
|
import re
|
|
|
|
class RAGLoader:
    """Build, persist and query a FAISS index over sentence chunks of ``.txt`` files.

    Pipeline:
      1. ``load_and_split_texts`` splits every ``*.txt`` file of ``docs_folder``
         into sentences and caches them in ``splits_folder/splits.json``.
      2. ``create_index`` embeds the chunks with a SentenceTransformer model and
         stores them in a flat L2 FAISS index (``index_folder/faiss.index``),
         plus a pickle of the indexed Documents (``index_folder/documents.pkl``).
      3. ``get_retriever`` returns a ``query -> List[Document]`` search function.
    """

    # Sentence boundary used for chunking: whitespace that follows
    # end-of-sentence punctuation. Besides ".", "!" and "?", it includes the
    # Arabic question mark "؟" (U+061F) and the Arabic full stop "۔" (U+06D4);
    # without them Arabic questions were glued to the following sentence.
    _SENTENCE_BOUNDARY = re.compile(r"(?<=[.!?؟۔])\s+")

    def __init__(self,
                 docs_folder: str = "./docs",
                 splits_folder: str = "./splits",
                 index_folder: str = "./index",
                 model_name: str = "intfloat/multilingual-e5-large"):
        """Initialise the RAG loader.

        Creates ``splits_folder`` and ``index_folder`` if needed. The model,
        the FAISS index and the indexed documents are loaded lazily.

        Args:
            docs_folder: Folder containing the source ``.txt`` documents.
            splits_folder: Folder where the text chunks are cached.
            index_folder: Folder where the FAISS index is stored.
            model_name: Name of the SentenceTransformer model to use.
        """
        self.docs_folder = Path(docs_folder)
        self.splits_folder = Path(splits_folder)
        self.index_folder = Path(index_folder)
        self.model_name = model_name

        self.splits_folder.mkdir(parents=True, exist_ok=True)
        self.index_folder.mkdir(parents=True, exist_ok=True)

        self.splits_path = self.splits_folder / "splits.json"
        self.index_path = self.index_folder / "faiss.index"
        self.documents_path = self.index_folder / "documents.pkl"

        # Populated lazily by _ensure_model / load_index / create_index.
        self.model = None
        self.index = None
        self.indexed_documents = None

    def _ensure_model(self):
        """Load the SentenceTransformer model on first use and return it."""
        if self.model is None:
            print("Chargement du modèle...")
            self.model = SentenceTransformer(self.model_name)
        return self.model

    def load_and_split_texts(self) -> List[Document]:
        """Return sentence chunks of the documents, using the JSON cache when possible.

        If ``splits_path`` exists and holds at least one chunk, it is returned
        as-is and the source files are not re-read. Otherwise every ``*.txt``
        file of ``docs_folder`` is read as UTF-8 and split into sentences. The
        result is written to ``splits_path`` only if at least one chunk was found.

        Returns:
            One Document per non-empty sentence, with metadata ``source``
            (file name), ``chunk_id`` (index of the sentence within its file)
            and ``total_chunks`` (number of sentences in that file).
        """
        if self._splits_exist():
            print("Chargement des splits existants...")
            cached = self._load_existing_splits()
            if cached:
                return cached
            # An empty cache (written by an earlier run that found no
            # documents) is treated as missing so new documents get picked up.

        print("Création de nouveaux splits...")

        documents = []
        # sorted(): glob() order depends on the filesystem; a stable order
        # keeps chunk positions (and therefore FAISS ids) reproducible.
        for file_path in sorted(self.docs_folder.glob("*.txt")):
            text = file_path.read_text(encoding='utf-8')
            chunks = [s.strip() for s in self._SENTENCE_BOUNDARY.split(text) if s.strip()]
            for i, chunk in enumerate(chunks):
                documents.append(Document(
                    page_content=chunk,
                    metadata={
                        'source': file_path.name,
                        'chunk_id': i,
                        'total_chunks': len(chunks),
                    },
                ))

        if documents:
            self._save_splits(documents)
        else:
            # Deliberately not cached: see the empty-cache handling above.
            print(f"Aucun fichier .txt exploitable dans {self.docs_folder}")

        print(f"Nombre total de morceaux créés: {len(documents)}")
        return documents

    def _splits_exist(self) -> bool:
        """Return True if the splits JSON cache file exists."""
        return self.splits_path.exists()

    def _save_splits(self, documents: List[Document]):
        """Write all chunks to the single splits JSON file (UTF-8, non-ASCII kept)."""
        splits_data = {
            'splits': [
                {'text': doc.page_content, 'metadata': doc.metadata}
                for doc in documents
            ]
        }
        with open(self.splits_path, 'w', encoding='utf-8') as f:
            json.dump(splits_data, f, ensure_ascii=False, indent=2)

    def _load_existing_splits(self) -> List[Document]:
        """Read the chunks back from the splits JSON file as Documents."""
        with open(self.splits_path, 'r', encoding='utf-8') as f:
            splits_data = json.load(f)

        documents = [
            Document(page_content=split['text'], metadata=split['metadata'])
            for split in splits_data['splits']
        ]
        print(f"Nombre de splits chargés: {len(documents)}")
        return documents

    def load_index(self) -> bool:
        """Load the FAISS index and its pickled Documents from disk, if present.

        The in-memory state (``index``, ``indexed_documents``) is only updated
        when both files load successfully and agree on the number of entries.

        Returns:
            True if the index was loaded, False otherwise (missing files, read
            error, or an index whose vector count differs from the number of
            pickled documents).
        """
        if not self._index_exists():
            print("Aucun index trouvé.")
            return False

        print("Chargement de l'index existant...")
        try:
            index = faiss.read_index(str(self.index_path))
            # NOTE(security): pickle.load can execute arbitrary code; only load
            # documents.pkl files written by create_index itself.
            with open(self.documents_path, 'rb') as f:
                documents = pickle.load(f)
        except Exception as e:
            print(f"Erreur lors du chargement de l'index: {e}")
            return False

        # The retriever maps FAISS ids to positions in the document list, so a
        # mismatch (e.g. an interrupted save) would return wrong documents or
        # raise IndexError. Report it and let the caller rebuild instead.
        if index.ntotal != len(documents):
            print(f"Index incohérent: {index.ntotal} vecteurs pour {len(documents)} documents")
            return False

        self.index = index
        self.indexed_documents = documents
        print(f"Index chargé avec {self.index.ntotal} vecteurs")
        return True

    def create_index(self, documents: Optional[List[Document]] = None) -> bool:
        """Build a new FAISS index from ``documents`` and persist it.

        If no documents are given, they come from ``load_and_split_texts``
        (the JSON cache or the source files). Errors are reported and turned
        into a False return value.

        Args:
            documents: Optional list of Documents to index.

        Returns:
            True if the index was created and saved, False otherwise.
        """
        try:
            model = self._ensure_model()

            if documents is None:
                documents = self.load_and_split_texts()

            if not documents:
                print("Aucun document à indexer.")
                return False

            print("Création des embeddings...")
            # NOTE(review): e5 models are documented to expect "passage: " /
            # "query: " input prefixes; adding them would require rebuilding
            # any existing index — confirm before changing.
            texts = [doc.page_content for doc in documents]
            # FAISS needs a C-contiguous float32 matrix of shape (n, dim).
            embeddings = np.ascontiguousarray(
                model.encode(texts, show_progress_bar=True), dtype=np.float32
            )

            index = faiss.IndexFlatL2(embeddings.shape[1])
            index.add(embeddings)

            # Assign index and documents together so they always stay paired.
            self.index = index
            self.indexed_documents = documents

            print("Sauvegarde de l'index...")
            faiss.write_index(self.index, str(self.index_path))
            with open(self.documents_path, 'wb') as f:
                pickle.dump(documents, f)

            print(f"Index créé avec succès : {self.index.ntotal} vecteurs")
            return True

        except Exception as e:
            print(f"Erreur lors de la création de l'index: {e}")
            return False

    def _index_exists(self) -> bool:
        """Return True if both the FAISS index file and the documents pickle exist."""
        return self.index_path.exists() and self.documents_path.exists()

    def get_retriever(self, k: int = 5):
        """Return a search function usable as a LangChain-style retriever.

        Loads the index from disk, or builds it if loading fails, and loads the
        embedding model. The returned function reads ``self.index`` and
        ``self.indexed_documents`` at call time, so a later ``create_index``
        is picked up.

        Args:
            k: Maximum number of similar documents to return per query.

        Returns:
            Callable ``(query: str) -> List[Document]``.

        Raises:
            ValueError: If ``k`` < 1 or the index can neither be loaded nor created.
        """
        if k < 1:
            raise ValueError(f"k doit être un entier positif, reçu: {k}")

        if self.index is None and not self.load_index() and not self.create_index():
            raise ValueError("Impossible de charger ou créer l'index")

        self._ensure_model()

        def retriever_function(query: str) -> List[Document]:
            """Return up to ``k`` indexed Documents in the order given by ``index.search``."""
            query_vector = np.ascontiguousarray(self.model.encode([query]), dtype=np.float32)
            _, indices = self.index.search(query_vector, k)
            # FAISS pads the result with -1 when fewer than k vectors exist.
            return [self.indexed_documents[idx] for idx in indices[0] if idx != -1]

        return retriever_function
|
|
|
|
|
|
# The Mistral API key is read from the environment instead of being hard-coded:
# a key committed to source control must be treated as leaked and revoked.
_mistral_api_key = os.environ.get("MISTRAL_API_KEY")
if not _mistral_api_key:
    raise RuntimeError("La variable d'environnement MISTRAL_API_KEY n'est pas définie.")

llm = ChatMistralAI(model="mistral-large-latest", mistral_api_key=_mistral_api_key)

# Load (or build) the FAISS index at import time; the retriever returns up to
# 5 chunks per question.
rag_loader = RAGLoader()
retriever = rag_loader.get_retriever(k=5)

# System prompt (Arabic). Translation: "You are a helpful assistant who answers
# questions in Arabic using the provided information. Use the following
# information to answer the question: {context}. If the information is not
# sufficient to answer the question fully, say so. Answer concisely and
# accurately." The user turn is the raw question.
prompt_template = ChatPromptTemplate.from_messages([
    ("system", """أنت مساعد مفيد يجيب على الأسئلة باللغة العربية باستخدام المعلومات المقدمة.


استخدم المعلومات التالية للإجابة على السؤال:




{context}




إذا لم تكن المعلومات كافية للإجابة على السؤال بشكل كامل، قم بتوضيح ذلك.


أجب بشكل موجز ودقيق."""),
    ("human", "{question}")
])
|
|
|
|
def process_question(question: str) -> tuple[str, str]:
    """Answer ``question`` with the retrieval-augmented pipeline.

    Retrieves the closest chunks with the module-level ``retriever``, joins
    them one per line into the prompt context, and sends the formatted prompt
    to the module-level ``llm``.

    Returns:
        ``(answer, context)``: the model's text answer and the context string
        that was supplied to it.
    """
    relevant_docs = retriever(question)
    context = "\n".join(doc.page_content for doc in relevant_docs)

    messages = prompt_template.format_messages(context=context, question=question)

    # ``invoke`` is the Runnable entry point; calling the chat model directly
    # (``llm(messages)``) goes through the deprecated ``__call__``.
    response = llm.invoke(messages)
    return response.content, context
|
|
|
|
def gradio_interface(question: str) -> tuple[str, str]:
    """Gradio click handler: turn the question text into (answer, context).

    Thin adapter over ``process_question``; the two strings returned fill the
    answer and context output textboxes, in that order.
    """
    answer, context = process_question(question)
    return answer, context
|
|
|
|
|
|
custom_css = """
|
|
#question-box textarea, #answer-box textarea, #context-box textarea {
|
|
text-align: right !important;
|
|
direction: rtl !important;
|
|
}
|
|
"""
|
|
|
|
|
|
if __name__ == "__main__":
    # Start-up smoke test: sends one sample question through the full
    # retrieval + LLM pipeline and prints the answer and the context used.
    # Guarded so that merely importing this module does not make an LLM call.
    # Sample question (Arabic): "Is an agent of authority allowed to acquire
    # real estate within his area of work?"
    question = "هل يجوز لرجل السلطة اقتناء عقار داخل مجال عمله"
    answer, context = process_question(question)

    # Labels: "The answer:" / "The context used:".
    print("الإجابة:", answer)
    print("\nالسياق المستخدم:", context)
|
|
|
|
|
|
# Single-column UI: a question input, the answer and retrieved-context outputs,
# and a submit button wired to gradio_interface. Labels are Arabic
# ("Question", "Answer", "Context used", "Send").
with gr.Blocks(css=custom_css) as iface:
    with gr.Column():
        input_text = gr.Textbox(
            elem_id="question-box",
            label="السؤال",
            placeholder="اكتب سؤالك هنا...",
            lines=2,
        )
        answer_box = gr.Textbox(elem_id="answer-box", label="الإجابة", lines=4)
        context_box = gr.Textbox(elem_id="context-box", label="السياق المستخدم", lines=8)
        submit_btn = gr.Button("إرسال")

    submit_btn.click(
        gradio_interface,
        inputs=[input_text],
        outputs=[answer_box, context_box],
    )
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # share=True asks Gradio for a publicly reachable share link, so anyone
    # with the URL can trigger LLM calls through this app. It stays the default
    # for backward compatibility; set GRADIO_SHARE to 0/false/no to serve
    # locally only.
    share = os.environ.get("GRADIO_SHARE", "1").strip().lower() not in {"0", "false", "no"}
    iface.launch(share=share)