tdecae commited on
Commit
f2dfb83
·
verified ·
1 Parent(s): 88dc77e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +126 -0
app.py CHANGED
@@ -0,0 +1,126 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import streamlit as st
3
+ from dotenv import load_dotenv
4
+ from langchain.chains import create_history_aware_retriever, create_retrieval_chain
5
+ from langchain.chains.combine_documents import create_stuff_documents_chain
6
+ from langchain_community.vectorstores import Chroma
7
+ from langchain_core.messages import HumanMessage, SystemMessage
8
+ from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
9
+ from langchain_openai import ChatOpenAI, OpenAIEmbeddings
10
+ import pandas as pd
11
+ from huggingface_hub import login
12
+
13
+ # Load environment variables
14
+ openai_api_key = os.getenv("OPENAI_API_KEY")
15
+
16
+ hf_key = os.getenv("huggingface")
17
+ login(hf_key)
18
+
19
+ # Load product data from CSV
20
+ product_data_path = "./db/catalog_chatbot_2024-07-08.csv"
21
+ df = pd.read_csv(product_data_path, encoding='ISO-8859-1', sep=';')
22
+
23
+ # Define the embedding model
24
+ embeddings = OpenAIEmbeddings(model="text-embedding-3-small")
25
+
26
+ # Create a persistent directory for ChromaDB
27
+ persistent_directory = os.path.join("./","db", "chroma_open_ai")
28
+
29
+ # Check if the vector store already exists
30
+ if not os.path.exists(os.path.join(persistent_directory, 'chroma.sqlite3')):
31
+ db = Chroma(persist_directory=persistent_directory, embedding_function=embeddings)
32
+ for index, row in df.iterrows():
33
+ product_info = (
34
+ f"Nom du produit: {row['Nom du produit']} - "
35
+ f"Catégorie: {row['Catégorie par défaut']} - "
36
+ f"Caractéristiques: {row['Caractéristiques']} - "
37
+ f"Prix: {row['Prix de vente TTC']} - "
38
+ f"Description: {row['Description sans HTML']}"
39
+ )
40
+ metadata = {
41
+ "reference": row['Référence interne'],
42
+ "name": row['Nom du produit'],
43
+ "price": row['Prix de vente TTC'],
44
+ "product_url": row['URL Produit']
45
+ }
46
+ db.add_texts(texts=[product_info], metadatas=[metadata])
47
+ db.persist()
48
+ else:
49
+ db = Chroma(persist_directory=persistent_directory, embedding_function=embeddings)
50
+
51
+ # Create a retriever
52
+ retriever = db.as_retriever(search_type="similarity", search_kwargs={"k": 10})
53
+
54
+ # Create a ChatOpenAI model
55
+ llm = ChatOpenAI(model="gpt-4o")
56
+
57
+ # Function to format the products
58
+ def format_retrieved_products(retrieved_docs):
59
+ recommendations = []
60
+ seen_products = set()
61
+ for doc in retrieved_docs:
62
+ metadata = doc.metadata
63
+ product_name = metadata.get("name", "Produit inconnu")
64
+ price = metadata.get("price", "Prix non disponible")
65
+ product_url = metadata.get("product_url", "#")
66
+
67
+ if product_name not in seen_products:
68
+ recommendation = f"**{product_name}** - {price} €\n[Voir produit]({product_url})"
69
+ recommendations.append(recommendation)
70
+ seen_products.add(product_name)
71
+
72
+ return "\n".join(recommendations)
73
+
74
+ # Update the system prompt
75
+ qa_system_prompt = (
76
+ "You are a sales assistant helping customers purchase wine. "
77
+ "Use the retrieved context from the Chroma DB to answer the question. "
78
+ "Recommend 3 different items and provide the URLs of the 3 products from Calais Vins."
79
+ )
80
+
81
+ qa_prompt = ChatPromptTemplate.from_messages(
82
+ [
83
+ ("system", qa_system_prompt),
84
+ MessagesPlaceholder("chat_history"),
85
+ ("human", "{input}"),
86
+ ("system", "{context}")
87
+ ]
88
+ )
89
+
90
+ question_answer_chain = create_stuff_documents_chain(llm, qa_prompt)
91
+
92
+ # Define a retrieval chain
93
+ def create_custom_retrieval_chain(retriever, llm_chain):
94
+ def invoke(inputs):
95
+ query = inputs["input"]
96
+ retrieved_docs = retriever.get_relevant_documents(query)
97
+ formatted_response = format_retrieved_products(retrieved_docs)
98
+ return {"answer": formatted_response}
99
+
100
+ return invoke
101
+
102
+ rag_chain = create_custom_retrieval_chain(retriever, question_answer_chain)
103
+
104
+ # Streamlit App Interface
105
+ def run_streamlit_chatbot():
106
+ st.title("Wine Sales Assistant")
107
+
108
+ chat_history = []
109
+
110
+ # User input area
111
+ user_query = st.text_input("Posez une question au chatbot (e.g., je recherche un vin blanc fruité):")
112
+
113
+ if user_query:
114
+ result = rag_chain({"input": user_query, "chat_history": chat_history})
115
+
116
+ # Display chatbot response
117
+ st.write("### Chatbot's Recommendations:")
118
+ st.write(result["answer"])
119
+
120
+ # Display recommendations in a pop-up like fashion
121
+ with st.expander("Voir les recommandations"):
122
+ st.write(result["answer"])
123
+
124
+ # Main function to run the Streamlit app
125
+ if __name__ == "__main__":
126
+ run_streamlit_chatbot()