Spaces:
Sleeping
Sleeping
File size: 5,219 Bytes
8e8fbf9 cc39a93 8e8fbf9 cc39a93 8e8fbf9 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 |
from flask import Flask, request, jsonify
from dotenv import load_dotenv
import os
import pymongo
import google.generativeai as genai
from flask_cors import CORS
from tqdm import tqdm
# Load environment variables from a local .env file so os.getenv can see them
load_dotenv()
# Connection settings and API keys, all read from the environment.
# os.getenv returns None for anything that is not set.
MONGODB_URI = os.getenv('MONGODB_URI')
# Falls back to 'keepitreal/vietnamese-sbert' when EMBEDDING_MODEL is unset or empty
EMBEDDING_MODEL = os.getenv('EMBEDDING_MODEL') or 'keepitreal/vietnamese-sbert'
DB_NAME = os.getenv('DB_NAME')
DB_COLLECTION = os.getenv('DB_COLLECTION')
GEMINI_KEY = os.getenv('GEMINI_KEY')
# Gemini client used to generate the final answer in the search endpoint
genai.configure(api_key=GEMINI_KEY)
model = genai.GenerativeModel('gemini-1.5-pro')
# MongoDB handle; `collection` is the collection that vector search and the
# embedding refresh both operate on
client = pymongo.MongoClient(MONGODB_URI)
db = client[DB_NAME]
collection = db[DB_COLLECTION]
# Flask application with CORS enabled via flask_cors defaults
app = Flask(__name__)
CORS(app)
# Sentence-transformers model used by get_embedding; instantiated once at import time
from sentence_transformers import SentenceTransformer
embedding_model = SentenceTransformer(EMBEDDING_MODEL)
def vector_search(user_query, collection, limit=4):
    """
    Perform a vector search in the MongoDB collection based on the user query.

    Args:
        user_query (str): The user's query string.
        collection (MongoCollection): The MongoDB collection to search.
        limit (int): Maximum number of documents to return. Defaults to 4.

    Returns:
        list: A list of matching documents, restricted to the projected
        product fields. Empty when no embedding could be produced for the
        query.
    """
    # Generate embedding for the user query
    query_embedding = get_embedding(user_query)

    # get_embedding reports failure with an empty list, so test for
    # falsiness rather than None. Return an empty list (not a message
    # string) so callers can always iterate over the result.
    if not query_embedding:
        return []

    # Define the vector search pipeline
    vector_search_stage = {
        "$vectorSearch": {
            "index": "vector_index",
            "queryVector": query_embedding,
            "path": "embedding",
            # $vectorSearch rejects a limit larger than numCandidates
            "numCandidates": max(150, limit),
            "limit": limit,
        }
    }

    unset_stage = {
        "$unset": "embedding"
    }

    project_stage = {
        "$project": {
            "_id": 0,
            "title": 1,
            "details": 1,
            "price": 1,
            "promotion_price": 1,
            "size_options": 1,
            "gender_options": 1,
            "quantity": 1,
            "stock": 1,
            "is_shoes": 1,
            "is_sandals": 1,
        }
    }

    pipeline = [vector_search_stage, unset_stage, project_stage]

    # Execute the search
    return list(collection.aggregate(pipeline))
def get_search_result(query, collection):
    """Render the products most relevant to *query* as one text block.

    Runs a vector search for up to 10 products and describes each priced
    product on its own paragraph, starting with its position in the search
    results (counted from 1), its title and its price, followed by whichever
    optional attributes are present.

    Args:
        query (str): The user's query string.
        collection (MongoCollection): The MongoDB collection to search.

    Returns:
        str: The concatenated product descriptions; empty when nothing
        matched or no match carries a price.
    """
    products = vector_search(query, collection, 10)

    parts = []
    for position, product in enumerate(products, start=1):
        price = product.get('price')
        if not price:
            # The entry header carries the title and price. Without it the
            # remaining attributes would read as belonging to the previous
            # product, so an unpriced product is left out entirely.
            continue

        parts.append(f"\n\nSản phẩm {position}: {product.get('title')}, Giá: {price}")

        if product.get('promotion_price'):
            parts.append(f", Giá ưu đãi: {product.get('promotion_price')}")
        if product.get('stock'):
            parts.append(f", Trạng thái: {product.get('stock')}")
        # Explicit comparison with True: only a real flag (True / 1) counts,
        # not any truthy value such as a non-empty string.
        if product.get('is_shoes') == True:  # noqa: E712
            parts.append(", Loại: Giày")
        if product.get('is_sandals') == True:  # noqa: E712
            parts.append(", Loại: Dép")
        if product.get('size_options'):
            parts.append(f", Size: {product.get('size_options')}")
        if product.get('gender_options'):
            parts.append(f", Dành cho: {product.get('gender_options')}")
        if product.get('details'):
            parts.append(f", Chi tiết sản phẩm: {product.get('details')}")

    return "".join(parts)
def get_embedding(text):
    """Encode *text* with the module-level embedding model.

    Args:
        text (str | None): The text to embed.

    Returns:
        list: The model's encoding converted with ``.tolist()``, or an empty
        list when *text* is None, empty or whitespace-only.
    """
    # `not text` also covers None, which would otherwise fail on .strip()
    if not text or not text.strip():
        print("Attempted to get embedding for empty text.")
        return []

    return embedding_model.encode(text).tolist()
def process_query(query):
    """Normalize a raw query string by lower-casing it.

    Args:
        query (str | None): The raw query text.

    Returns:
        str: The lower-cased query, or an empty string when *query* is None
        or empty, so callers can reject it with a simple falsiness check.
    """
    if not query:
        return ""
    return query.lower()
@app.route('/api/search', methods=['POST'])
def handle_query():
    """Answer a customer question using products found by vector search.

    Expects a JSON body with a ``question`` string. The question is
    normalized, matching products are retrieved from the collection, and
    both are combined into a prompt for the generative model.

    Returns:
        A JSON response ``{'content': <model answer text>}``, or
        ``{'error': 'No query provided'}`` with status 400 when the body is
        missing, is not a JSON object, or has no usable ``question``.
    """
    # silent=True yields None for a missing or malformed JSON body instead
    # of raising, so every bad request ends in the same 400 response below.
    data = request.get_json(silent=True)
    question = data.get('question') if isinstance(data, dict) else None

    # Validate before normalizing: a missing or non-string question must be
    # rejected here rather than fail inside process_query.
    if not isinstance(question, str) or not question.strip():
        return jsonify({'error': 'No query provided'}), 400

    query = process_query(question)

    # Retrieve data from vector database; any HTML <br> markers in the
    # product text become real newlines in the prompt.
    source_information = get_search_result(query, collection).replace('<br>', '\n')

    combined_information = (
        f"Hãy trở thành chuyên gia tư vấn bán hàng cho một website bán giày dép ThuThaoShoes. "
        f"Câu hỏi của khách hàng: {query}\n"
        f"Trả lời câu hỏi dựa vào các thông tin sản phẩm dưới đây: {source_information}."
    )

    response = model.generate_content(combined_information)

    return jsonify({
        'content': response.text
    })
@app.route('/api/embedding', methods=['GET'])
def get_embedding_api():
    """Recompute and store the embedding of every document in the collection.

    For each document the embedded text is its title followed by
    ``' Danh mục: '`` and its category. The result is written to the
    document's ``embedding`` field.

    Returns:
        A JSON response with a confirmation message once all documents have
        been processed.
    """
    # Only title and category feed the embedding, so project just those
    # (plus the implicit _id) instead of loading every field, including any
    # previously stored embedding vector.
    documents = list(collection.find({}, {'title': 1, 'category': 1}))

    for doc in tqdm(documents, desc="Processing documents"):
        # `or ''` also covers fields stored as null, which a .get default
        # does not replace.
        product_specs = doc.get('title') or ''
        product_cat = doc.get('category') or ''
        print(f"{product_specs} {product_cat}")

        embedding = get_embedding(f"{product_specs} Danh mục: {product_cat}")

        # get_embedding reports failure with an empty list; never overwrite
        # a document's embedding with an empty vector.
        if embedding:
            # Update the document with the new embedding
            collection.update_one(
                {'_id': doc['_id']},
                {'$set': {'embedding': embedding}}
            )

    return jsonify({'message': 'Embedding cập nhật thành công cho tất cả các tài liệu.'})
# Start the Flask server only when this file is executed directly,
# not when it is imported as a module.
if __name__ == '__main__':
    # NOTE(review): debug=True turns on Flask's debug mode (interactive
    # debugger and auto-reloader) — confirm this entry point is never used
    # for a publicly reachable deployment.
    app.run(debug=True)
|