import os

import pandas as pd
import streamlit as st
from dotenv import load_dotenv
from huggingface_hub import login
from langchain.chains import create_history_aware_retriever, create_retrieval_chain
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain_community.vectorstores import Chroma
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder, PromptTemplate
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
# Populate os.environ from a local .env file (if one exists) before any key is
# read; python-dotenv does not override variables that are already set.
load_dotenv()
openai_api_key = os.getenv("OPENAI_API_KEY")
hf_key = os.getenv("huggingface")
# Only authenticate to the Hugging Face Hub when a token is configured:
# login(None) falls back to an interactive prompt, which nobody can answer
# from inside a Streamlit app.
if hf_key:
    login(hf_key)
product_data_path = "./db/catalog_chatbot_2024-07-08.csv"
df = pd.read_csv(product_data_path, encoding='ISO-8859-1', sep=';')
embeddings = OpenAIEmbeddings(model="text-embedding-3-small")
persistent_directory = os.path.join("./", "db", "chroma_open_ai")
# Products are indexed in batches: calling add_texts once per row triggers one
# embeddings request (and one DB write) per product. The batch size is a
# conservative guess meant to stay under per-request limits of the embedding
# API / Chroma client — TODO tune if the catalogue grows a lot.
_INDEX_BATCH_SIZE = 500


def _product_text(row):
    """Return the text embedded for one catalogue row (labels kept in French)."""
    return (
        f"Nom du produit: {row['Nom du produit']} - "
        f"Catégorie: {row['Catégorie par défaut']} - "
        f"Caractéristiques: {row['Caractéristiques']} - "
        f"Prix: {row['Prix de vente TTC']} - "
        f"Description: {row['Description sans HTML']}"
    )


def _product_metadata(row):
    """Return the metadata stored alongside one catalogue row's embedding."""
    return {
        "reference": row['Référence interne'],
        "name": row['Nom du produit'],
        "price": row['Prix de vente TTC'],
        "product_url": row['URL Produit'],
    }


# Decide whether the index must be built BEFORE constructing the store:
# opening a persistent Chroma collection may itself create chroma.sqlite3.
index_exists = os.path.exists(os.path.join(persistent_directory, 'chroma.sqlite3'))
db = Chroma(persist_directory=persistent_directory, embedding_function=embeddings)
if not index_exists:
    texts = []
    metadatas = []
    for _, row in df.iterrows():
        texts.append(_product_text(row))
        metadatas.append(_product_metadata(row))
    for start in range(0, len(texts), _INDEX_BATCH_SIZE):
        stop = start + _INDEX_BATCH_SIZE
        db.add_texts(texts=texts[start:stop], metadatas=metadatas[start:stop])
    # Recent Chroma versions persist automatically; kept for older clients.
    db.persist()
retriever = db.as_retriever(search_type="similarity", search_kwargs={"k": 10})
llm = ChatOpenAI(model="gpt-4o")
def format_retrieved_products(retrieved_docs):
    """Render retrieved product documents as Markdown for display.

    Each document's ``metadata`` is read for ``name``, ``price`` and
    ``product_url``; missing keys fall back to placeholder values. Products are
    de-duplicated by name, keeping the first occurrence (retrieval order).

    Args:
        retrieved_docs: Iterable of objects exposing a ``metadata`` mapping
            (LangChain documents returned by the retriever in this app).

    Returns:
        A Markdown string with one paragraph per distinct product, or ``""``
        when nothing was retrieved.
    """
    recommendations = []
    seen_products = set()
    for doc in retrieved_docs:
        metadata = doc.metadata
        product_name = metadata.get("name", "Produit inconnu")
        if product_name in seen_products:
            continue
        seen_products.add(product_name)
        price = metadata.get("price", "Prix non disponible")
        product_url = metadata.get("product_url", "#")
        # Markdown renders a lone "\n" as a space; two trailing spaces force a
        # real line break between the product line and its link.
        recommendations.append(
            f"**{product_name}** - {price} €  \n[Voir produit]({product_url})"
        )
    # A blank line between entries makes each product its own paragraph
    # instead of one run-on line.
    return "\n\n".join(recommendations)
# System instructions for the answer-generating LLM chain.
qa_system_prompt = (
    "You are a sales assistant helping customers purchase wine. "
    "Use the retrieved context from the Chroma DB to answer the question. "
    "Recommend 3 different items and provide the URLs of the 3 products from Calais Vins."
)
qa_prompt = ChatPromptTemplate.from_messages(
    [
        ("system", qa_system_prompt),
        MessagesPlaceholder("chat_history"),
        ("human", "{input}"),
        # Filled by the stuff-documents chain with the formatted retrieved docs.
        ("system", "{context}"),
    ]
)
# The stuff-documents chain's default per-document prompt renders only
# ``page_content``. The product text indexed into Chroma in this file contains
# no URL, so the model could not follow the "provide the URLs" instruction.
# Expose the URL stored in each document's metadata as well.
product_document_prompt = PromptTemplate.from_template(
    "{page_content}\nURL: {product_url}"
)
question_answer_chain = create_stuff_documents_chain(
    llm, qa_prompt, document_prompt=product_document_prompt
)
def create_custom_retrieval_chain(retriever, llm_chain):
    """Build a callable that retrieves products for a query and formats them.

    Args:
        retriever: LangChain retriever queried with ``inputs["input"]``.
        llm_chain: Accepted for signature compatibility but currently NOT
            invoked; the answer is built from retrieved metadata only, with
            no LLM call.

    Returns:
        A function ``invoke(inputs)`` taking a dict with an ``"input"`` key
        (and optionally ``"chat_history"``). It returns a dict with
        ``"answer"`` (Markdown product list), plus ``"input"``,
        ``"chat_history"`` and ``"context"`` (the retrieved documents). These
        mirror the keys produced by LangChain's ``create_retrieval_chain``.
    """
    def invoke(inputs):
        query = inputs["input"]
        # ``invoke`` replaces the deprecated ``get_relevant_documents``.
        retrieved_docs = retriever.invoke(query)
        return {
            "input": query,
            "chat_history": inputs.get("chat_history", []),
            "context": retrieved_docs,
            "answer": format_retrieved_products(retrieved_docs),
        }
    return invoke


rag_chain = create_custom_retrieval_chain(retriever, question_answer_chain)
def run_streamlit_chatbot():
    """Render the wine-assistant page and answer the user's query.

    Streamlit re-executes the whole script on every interaction. The
    conversation is therefore kept in ``st.session_state`` rather than in a
    local list, which would be reset to empty on every rerun.
    """
    st.title("Wine Sales Assistant")
    if "chat_history" not in st.session_state:
        st.session_state["chat_history"] = []
    chat_history = st.session_state["chat_history"]
    user_query = st.text_input("Posez une question au chatbot (e.g., je recherche un vin blanc fruité):")
    if not user_query:
        # Input cleared: allow the next query (even a repeat) to be recorded.
        st.session_state["last_recorded_query"] = None
        return
    # Pass a snapshot so the chain sees the history *before* this turn.
    result = rag_chain({"input": user_query, "chat_history": list(chat_history)})
    answer = result["answer"]
    # text_input keeps its value across reruns; record each submitted query
    # once instead of appending the same exchange on every rerun.
    if st.session_state.get("last_recorded_query") != user_query:
        chat_history.extend([HumanMessage(content=user_query), AIMessage(content=answer)])
        st.session_state["last_recorded_query"] = user_query
    st.write("### Chatbot's Recommendations:")
    st.write(answer)
    with st.expander("Voir les recommandations"):
        st.write(answer)


if __name__ == "__main__":
    run_streamlit_chatbot()