|
import os |
|
import streamlit as st |
|
from dotenv import load_dotenv |
|
from langchain.chains import create_history_aware_retriever, create_retrieval_chain |
|
from langchain.chains.combine_documents import create_stuff_documents_chain |
|
from langchain_community.vectorstores import Chroma |
|
from langchain_core.messages import HumanMessage, SystemMessage |
|
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder |
|
from langchain_openai import ChatOpenAI, OpenAIEmbeddings |
|
import pandas as pd |
|
from huggingface_hub import login |
|
|
|
|
|
# Load variables from a local .env file into the process environment.
# BUG FIX: load_dotenv was imported at the top of the file but never called,
# so the os.getenv() lookups below silently returned None when the keys
# lived only in a .env file.
load_dotenv()

openai_api_key = os.getenv("OPENAI_API_KEY")

hf_key = os.getenv("huggingface")
# Only authenticate against the Hugging Face Hub when a token is present;
# login(None) raises an opaque error instead of a clear "missing key" signal.
if hf_key:
    login(hf_key)

# Product catalog export; semicolon-separated and Latin-1 encoded
# (French column names/values with accents).
product_data_path = "./db/catalog_chatbot_2024-07-08.csv"
df = pd.read_csv(product_data_path, encoding='ISO-8859-1', sep=';')
|
|
|
|
|
# Embedding model used both to index the catalog and to embed user queries.
embeddings = OpenAIEmbeddings(model="text-embedding-3-small")

# On-disk location where the Chroma collection is persisted between runs.
persistent_directory = os.path.join("./", "db", "chroma_open_ai")
|
|
|
|
|
# Build the vector store on first run; afterwards just reopen the persisted
# collection.  The sqlite file is Chroma's on-disk marker that the index
# already exists.
if not os.path.exists(os.path.join(persistent_directory, 'chroma.sqlite3')):
    db = Chroma(persist_directory=persistent_directory, embedding_function=embeddings)

    # Collect every product first so we can index with ONE batched call.
    # PERF FIX: the original called db.add_texts once per row, i.e. one
    # embeddings-API round-trip per product.
    texts = []
    metadatas = []
    for index, row in df.iterrows():
        texts.append(
            f"Nom du produit: {row['Nom du produit']} - "
            f"Catégorie: {row['Catégorie par défaut']} - "
            f"Caractéristiques: {row['Caractéristiques']} - "
            f"Prix: {row['Prix de vente TTC']} - "
            f"Description: {row['Description sans HTML']}"
        )
        metadatas.append({
            "reference": row['Référence interne'],
            "name": row['Nom du produit'],
            "price": row['Prix de vente TTC'],
            "product_url": row['URL Produit'],
        })

    # Guard against an empty catalog export (add_texts([]) is wasted work).
    if texts:
        db.add_texts(texts=texts, metadatas=metadatas)
    db.persist()
else:
    db = Chroma(persist_directory=persistent_directory, embedding_function=embeddings)
|
|
|
|
|
# Similarity search over the catalog, returning the 10 closest products
# for each query.
retriever = db.as_retriever(
    search_type="similarity",
    search_kwargs={"k": 10},
)

# Chat model used by the question-answering chain.
llm = ChatOpenAI(model="gpt-4o")
|
|
|
|
|
def format_retrieved_products(retrieved_docs):
    """Render retrieved product documents as a newline-separated Markdown list.

    Each unique product (deduplicated by name, first occurrence wins) becomes
    one entry of the form "**name** - price €" followed by a Markdown link to
    the product page.  Missing metadata fields fall back to French
    placeholder strings ("Produit inconnu", "Prix non disponible", "#").
    """
    lines = []
    already_listed = set()

    for document in retrieved_docs:
        meta = document.metadata
        name = meta.get("name", "Produit inconnu")

        # Skip duplicates: the retriever may return the same product twice.
        if name in already_listed:
            continue
        already_listed.add(name)

        price = meta.get("price", "Prix non disponible")
        url = meta.get("product_url", "#")
        lines.append(f"**{name}** - {price} €\n[Voir produit]({url})")

    return "\n".join(lines)
|
|
|
|
|
# Persona / instruction text for the sales assistant.
qa_system_prompt = (
    "You are a sales assistant helping customers purchase wine. "
    "Use the retrieved context from the Chroma DB to answer the question. "
    "Recommend 3 different items and provide the URLs of the 3 products from Calais Vins."
)

# Prompt layout: persona, running conversation, the user's question, then the
# retrieved documents injected as a trailing system message via {context}.
_prompt_messages = [
    ("system", qa_system_prompt),
    MessagesPlaceholder("chat_history"),
    ("human", "{input}"),
    ("system", "{context}"),
]
qa_prompt = ChatPromptTemplate.from_messages(_prompt_messages)

# Chain that stuffs retrieved documents into {context} and calls the LLM.
question_answer_chain = create_stuff_documents_chain(llm, qa_prompt)
|
|
|
|
|
def create_custom_retrieval_chain(retriever, llm_chain):
    """Return a callable that answers a query with a formatted product list.

    Parameters
    ----------
    retriever : vector-store retriever used to fetch matching documents.
    llm_chain : accepted for interface compatibility, but NOT invoked here —
        the answer is built directly from the retriever hits.
        NOTE(review): this means the LLM and qa_prompt above never run;
        wire llm_chain in if generated prose is wanted instead of a raw list.

    The returned callable expects a dict with an "input" key (and optionally
    "chat_history") and returns {"answer": <markdown product list>}.
    """
    def invoke(inputs):
        query = inputs["input"]
        # FIX: Retriever.get_relevant_documents() is deprecated in
        # LangChain 0.1+; .invoke() is the supported Runnable entry point.
        retrieved_docs = retriever.invoke(query)
        formatted_response = format_retrieved_products(retrieved_docs)
        return {"answer": formatted_response}

    return invoke


# App-wide chain used by the Streamlit UI below.
rag_chain = create_custom_retrieval_chain(retriever, question_answer_chain)
|
|
|
|
|
def run_streamlit_chatbot():
    """Streamlit UI: ask a question, display product recommendations.

    FIX: chat history now lives in st.session_state so it survives
    Streamlit's per-interaction script reruns.  The original used a local
    ``chat_history = []`` that was rebuilt (empty) on every rerun and was
    never appended to, so the chain always received an empty history.
    """
    st.title("Wine Sales Assistant")

    # Persist the conversation across Streamlit reruns.
    if "chat_history" not in st.session_state:
        st.session_state.chat_history = []

    user_query = st.text_input("Posez une question au chatbot (e.g., je recherche un vin blanc fruité):")

    if user_query:
        result = rag_chain({"input": user_query, "chat_history": st.session_state.chat_history})

        # Record the exchange so follow-up questions carry context.
        st.session_state.chat_history.append(HumanMessage(content=user_query))
        st.session_state.chat_history.append(SystemMessage(content=result["answer"]))

        st.write("### Chatbot's Recommendations:")
        st.write(result["answer"])

        # Same content, repeated in a collapsible section (kept from the
        # original UI layout).
        with st.expander("Voir les recommandations"):
            st.write(result["answer"])
|
|
|
|
|
if __name__ == "__main__":
    # Entry point: `streamlit run <this file>` executes the module top to
    # bottom; the guard also allows a direct `python <this file>` launch.
    run_streamlit_chatbot()
|
|