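# Dataltist Chatbot: a Gradio RAG app for questions about insurance (in French).
# Pipeline: Chroma similarity search -> Mixedbread AI reranking -> GPT-4o-mini,
# with per-session chat history handled by RunnableWithMessageHistory.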
import os
from dotenv import load_dotenv
import gradio as gr
from langchain_chroma import Chroma
from langchain.chains import create_retrieval_chain, create_history_aware_retriever
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_community.chat_message_histories import ChatMessageHistory
from langchain_core.runnables.history import RunnableWithMessageHistory
from langchain_core.documents import Document
from langchain_core.retrievers import BaseRetriever
from langchain_core.callbacks import CallbackManagerForRetrieverRun
from langchain_core.vectorstores import VectorStoreRetriever
from langchain_openai import ChatOpenAI
from langchain.callbacks.tracers import ConsoleCallbackHandler
from langchain_huggingface import HuggingFaceEmbeddings
from datasets import load_dataset
import chromadb
from typing import List
from mixedbread_ai.client import MixedbreadAI
from tqdm import tqdm
# Load API keys from a local .env file into the environment
load_dotenv()

# Global params
CHROMA_PATH = "chromadb_mem10_mxbai_800_complete"
MODEL_EMB = "mixedbread-ai/mxbai-embed-large-v1"
MODEL_RRK = "mixedbread-ai/mxbai-rerank-large-v1"
LLM_NAME = "gpt-4o-mini"
OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY")
MXBAI_API_KEY = os.environ.get("MXBAI_API_KEY")
HF_TOKEN = os.environ.get("HF_TOKEN")
HF_API_KEY = os.environ.get("HF_API_KEY")

# Mixedbread AI client, used by the reranker below
mxbai_client = MixedbreadAI(api_key=MXBAI_API_KEY)
# Set up ChromaDB
memoires_ds = load_dataset("eliot-hub/memoires_vec_800", split="data", token=HF_TOKEN, streaming=True)
batched_ds = memoires_ds.batch(batch_size=41000)
client = chromadb.Client()
collection = client.get_or_create_collection(name="embeddings_mxbai")

for batch in tqdm(batched_ds, desc="Processing dataset batches"):
    collection.upsert(
        ids=batch["id"],
        metadatas=batch["metadata"],
        documents=batch["document"],
        embeddings=batch["embedding"],
    )
print(f"Collection complete: {collection.count()}")
db = Chroma(
    client=client,
    collection_name="embeddings_mxbai",
    embedding_function=HuggingFaceEmbeddings(model_name=MODEL_EMB),
)
# Reranker: a retriever that wraps a vector-store retriever and reranks its
# results with the Mixedbread AI reranking API, keeping the top k documents
class Reranker(BaseRetriever):
    retriever: VectorStoreRetriever
    k: int

    def _get_relevant_documents(
        self, query: str, *, run_manager: CallbackManagerForRetrieverRun
    ) -> List[Document]:
        docs = self.retriever.invoke(query)
        results = mxbai_client.reranking(
            model=MODEL_RRK,
            query=query,
            input=[doc.page_content for doc in docs],
            return_input=True,
            top_k=self.k,
        )
        return [Document(page_content=res.input) for res in results.data]
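# Two-stage retrieval: the base retriever casts a wide net (25 nearest
# neighbours), then the reranker keeps only the 4 most relevant documents.
# Hypothetical usage, assuming the collection is populated:
#   docs = reranker.invoke("Qu'est-ce que la garantie DTA ?")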
# Set up reranker + LLM
retriever = db.as_retriever(search_type="similarity", search_kwargs={"k": 25})
reranker = Reranker(retriever=retriever, k=4)
llm = ChatOpenAI(model=LLM_NAME, verbose=True)  # reads OPENAI_API_KEY from the environment
# Contextualization prompt (French): rewrite the latest user question as a
# standalone question understandable without the chat history, without answering it
contextualize_q_system_prompt = (
    "Compte tenu de l'historique de la conversation et de la dernière question de l'utilisateur, "
    "qui peut faire référence à un contexte de cet historique, "
    "formule une question autonome qui peut être comprise "
    "sans l'historique de la conversation. Ne réponds PAS à la question : "
    "reformule-la si nécessaire, sinon renvoie-la telle quelle."
)
contextualize_q_prompt = ChatPromptTemplate.from_messages(
    [
        ("system", contextualize_q_system_prompt),
        MessagesPlaceholder("chat_history"),
        ("human", "{input}"),
    ]
)
# Create the history-aware retriever
history_aware_retriever = create_history_aware_retriever(
    llm, reranker, contextualize_q_prompt
)
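# When chat history is present, this first asks the LLM to rewrite the question
# via contextualize_q_prompt, then feeds the standalone question to the reranker;
# with no history, the raw input goes straight to the reranker.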
# QA prompt (French): answer only from the retrieved context, or say you don't know
system_prompt = (
    "Réponds à la question en te basant uniquement sur le contexte suivant :\n\n{context}\n\n"
    "Si tu ne connais pas la réponse, dis que tu ne sais pas."
)
qa_prompt = ChatPromptTemplate.from_messages(
    [
        ("system", system_prompt),
        MessagesPlaceholder("chat_history"),
        ("human", "{input}"),
    ]
)
# Create the question-answer chain
question_answer_chain = create_stuff_documents_chain(llm, qa_prompt)
rag_chain = create_retrieval_chain(history_aware_retriever, question_answer_chain)
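# rag_chain returns a dict containing the original "input", the retrieved
# "context" documents, and the generated "answer" consumed below.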
# Set up the conversation history: one in-memory ChatMessageHistory per session id
store = {}

def get_session_history(session_id: str) -> ChatMessageHistory:
    if session_id not in store:
        store[session_id] = ChatMessageHistory()
    return store[session_id]
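# RunnableWithMessageHistory injects past messages into "chat_history" before
# each call, then appends the new ("input", "answer") pair to the session history.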
conversational_rag_chain = RunnableWithMessageHistory(
    rag_chain,
    get_session_history,
    input_messages_key="input",
    history_messages_key="chat_history",
    output_messages_key="answer",
)
# Gradio interface
def chatbot(message, history):
    session_id = "gradio_session"  # single hard-coded session: all visitors share one history
    response = conversational_rag_chain.invoke(
        {"input": message},
        config={
            "configurable": {"session_id": session_id},
            "callbacks": [ConsoleCallbackHandler()],
        },
    )["answer"]
    return response
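# Gradio's own `history` argument is ignored here: conversation state lives
# server-side in `store`, managed by RunnableWithMessageHistory.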
# Chat UI (retry_btn, undo_btn and clear_btn follow the Gradio 4.x ChatInterface API)
iface = gr.ChatInterface(
    chatbot,
    title="Dataltist Chatbot",
    description="Posez vos questions sur l'assurance",
    textbox=gr.Textbox(placeholder="Qu'est-ce que l'assurance multirisque habitation ?", container=False, scale=9),
    theme=gr.themes.Soft(primary_hue="red", secondary_hue="pink"),
    # examples=[
    #     "Qu'est-ce que l'assurance multirisque habitation ?",
    #     "Qu'est-ce que la garantie DTA ?",
    # ],
    retry_btn=None,
    undo_btn=None,
    submit_btn=gr.Button(value="Envoyer", icon="./send_icon.png", variant="primary"),
    clear_btn="Effacer la conversation",
)
if __name__ == "__main__":
    iface.launch()  # set share=True for a public link
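# Serves on Gradio's default local address (http://127.0.0.1:7860).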