leesuo215 commited on
Commit
8e8fbf9
1 Parent(s): 9525e96

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +171 -4
main.py CHANGED
@@ -1,7 +1,174 @@
1
- from flask import Flask
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
 
3
  app = Flask(__name__)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4
 
5
- @app.route("/hello")
6
- def hello():
7
- return {"xao chin": "xin chao"}
 
1
+ from flask import Flask, request, jsonify
2
+ from dotenv import load_dotenv
3
+ import os
4
+ import pymongo
5
+ import google.generativeai as genai
6
+ from flask_cors import CORS
7
+ from tqdm import tqdm
8
+
9
+ # Load environment variables from .env file
10
+ load_dotenv()
11
+
12
+ # Access the key
13
+ MONGODB_URI = os.getenv('MONGODB_URI')
14
+ EMBEDDING_MODEL = os.getenv('EMBEDDING_MODEL') or 'keepitreal/vietnamese-sbert'
15
+ DB_NAME = os.getenv('DB_NAME')
16
+ DB_COLLECTION = os.getenv('DB_COLLECTION')
17
+ GEMINI_KEY = os.getenv('GEMINI_KEY')
18
+ genai.configure(api_key=GEMINI_KEY)
19
+ model = genai.GenerativeModel('gemini-1.5-pro')
20
+
21
+ client = pymongo.MongoClient(MONGODB_URI)
22
+ db = client[DB_NAME]
23
+ collection = db[DB_COLLECTION]
24
 
25
  app = Flask(__name__)
26
+ CORS(app)
27
+
28
+ from sentence_transformers import SentenceTransformer
29
+ embedding_model = SentenceTransformer(EMBEDDING_MODEL)
30
+
31
+ def vector_search(user_query, collection, limit=4):
32
+ """
33
+ Perform a vector search in the MongoDB collection based on the user query.
34
+
35
+ Args:
36
+ user_query (str): The user's query string.
37
+ collection (MongoCollection): The MongoDB collection to search.
38
+
39
+ Returns:
40
+ list: A list of matching documents.
41
+ """
42
+
43
+ # Generate embedding for the user query
44
+ query_embedding = get_embedding(user_query)
45
+
46
+ if query_embedding is None:
47
+ return "Invalid query or embedding generation failed."
48
+
49
+ # Define the vector search pipeline
50
+ vector_search_stage = {
51
+ "$vectorSearch": {
52
+ "index": "vector_index",
53
+ "queryVector": query_embedding,
54
+ "path": "embedding",
55
+ "numCandidates": 150,
56
+ "limit": limit,
57
+ }
58
+ }
59
+
60
+ unset_stage = {
61
+ "$unset": "embedding"
62
+ }
63
+
64
+ project_stage = {
65
+ "$project": {
66
+ "_id": 0,
67
+ "title": 1,
68
+ "details": 1,
69
+ "price": 1,
70
+ "promotion_price": 1,
71
+ "size_options": 1,
72
+ "gender_options": 1,
73
+ "quantity": 1,
74
+ "stock": 1,
75
+ "is_shoes": 1,
76
+ "is_sandals": 1,
77
+ }
78
+ }
79
+
80
+ pipeline = [vector_search_stage, unset_stage, project_stage]
81
+
82
+ # Execute the search
83
+ results = collection.aggregate(pipeline)
84
+
85
+ return list(results)
86
+
87
+ def get_search_result(query, collection):
88
+ get_knowledge = vector_search(query, collection, 10)
89
+ search_result = ""
90
+ i = 0
91
+ for result in get_knowledge:
92
+ # print(result)
93
+ i += 1
94
+ if result.get('price'):
95
+ search_result += f"\n\nSản phẩm {i+1}: {result.get('title')}, Giá: {result.get('price')}"
96
+
97
+ if result.get('promotion_price'):
98
+ search_result += f", Giá ưu đãi: {result.get('promotion_price')}"
99
+
100
+ if result.get('stock'):
101
+ search_result += f", Trạng thái: {result.get('stock')}"
102
+
103
+ if result.get('is_shoes') == True:
104
+ search_result += f", Loại: Giày"
105
+
106
+ if result.get('is_sandals') == True:
107
+ search_result += f", Loại: Dép"
108
+
109
+ if result.get('size_options'):
110
+ search_result += f", Size: {result.get('size_options')}"
111
+
112
+ if result.get('gender_options'):
113
+ search_result += f", Dành cho: {result.get('gender_options')}"
114
+
115
+ if result.get('details'):
116
+ search_result += f", Chi tiết sản phẩm: {result.get('details')}"
117
+
118
+ return search_result
119
+
120
+ def get_embedding(text):
121
+ if not text.strip():
122
+ print("Attempted to get embedding for empty text.")
123
+ return []
124
+
125
+ embedding = embedding_model.encode(text)
126
+ return embedding.tolist()
127
+
128
+
129
+ def process_query(query):
130
+ return query.lower()
131
+
132
+ @app.route('/api/search', methods=['POST'])
133
+ def handle_query():
134
+ data = request.get_json()
135
+ query = process_query(data.get('question'))
136
+
137
+ if not query:
138
+ return jsonify({'error': 'No query provided'}), 400
139
+
140
+ # Retrieve data from vector database
141
+
142
+ source_information = get_search_result(query, collection).replace('<br>', '\n')
143
+ combined_information = f"Hãy trở thành chuyên gia tư vấn bán hàng cho một website bán giày dép ThuThaoShoes. Câu hỏi của khách hàng: {query}\nTrả lời câu hỏi dựa vào các thông tin sản phẩm dưới đây: {source_information}."
144
+
145
+ response = model.generate_content(combined_information)
146
+
147
+ return jsonify({
148
+ 'content': response.text
149
+ })
150
+
151
+
152
+ @app.route('/api/embedding', methods=['GET'])
153
+ def get_embedding_api():
154
+
155
+ # Lấy tất cả các tài liệu từ collection
156
+ documents = list(collection.find({}))
157
+
158
+ for doc in tqdm(documents, desc="Processing documents"):
159
+ product_specs = doc.get('title', '')
160
+ product_cat = doc.get('category', '')
161
+ print(product_specs + ' ' + product_cat)
162
+ embedding = get_embedding(product_specs + ' Danh mục: ' + product_cat)
163
+
164
+ if embedding is not None:
165
+ # Cập nhật tài liệu với embedding mới
166
+ collection.update_one(
167
+ {'_id': doc['_id']},
168
+ {'$set': {'embedding': embedding}}
169
+ )
170
+
171
+ return jsonify({'message': 'Embedding cập nhật thành công cho tất cả các tài liệu.'})
172
 
173
+ if __name__ == '__main__':
174
+ app.run(debug=True)