# Spaces: Running
from flask import Flask, request, jsonify | |
from dotenv import load_dotenv | |
import os | |
import pymongo | |
import google.generativeai as genai | |
from flask_cors import CORS | |
from tqdm import tqdm | |
# Load environment variables from a local .env file (if present) into os.environ.
load_dotenv()

# Connection settings and credentials, all read from the environment.
# MONGODB_URI may be unset: pymongo then falls back to its default host
# (localhost:27017).
MONGODB_URI = os.getenv('MONGODB_URI')
# sentence-transformers model used to embed both queries and product text.
EMBEDDING_MODEL = os.getenv('EMBEDDING_MODEL') or 'keepitreal/vietnamese-sbert'
DB_NAME = os.getenv('DB_NAME')
DB_COLLECTION = os.getenv('DB_COLLECTION')
GEMINI_KEY = os.getenv('GEMINI_KEY')

# Fail fast with an actionable message. Without these settings, creating the
# database/collection handles raises an opaque pymongo TypeError/InvalidName.
_missing_settings = [
    name
    for name, value in (('DB_NAME', DB_NAME), ('DB_COLLECTION', DB_COLLECTION))
    if not value
]
if _missing_settings:
    raise RuntimeError(
        'Missing required environment variable(s): ' + ', '.join(_missing_settings)
    )
# Authenticate the Gemini SDK and build the model that writes the answers.
# GEMINI_MODEL overrides the model name, the same way EMBEDDING_MODEL is
# configured; without it the original 'gemini-1.5-pro' is used.
genai.configure(api_key=GEMINI_KEY)
model = genai.GenerativeModel(os.getenv('GEMINI_MODEL') or 'gemini-1.5-pro')

# MongoDB handles for the product collection used by the vector search.
client = pymongo.MongoClient(MONGODB_URI)
db = client[DB_NAME]
collection = db[DB_COLLECTION]

# Flask app. CORS(app) with no options allows cross-origin requests from any origin.
app = Flask(__name__)
CORS(app)

# The embedding model is instantiated at import time, so its weights are loaded
# once per process rather than per request.
from sentence_transformers import SentenceTransformer
embedding_model = SentenceTransformer(EMBEDDING_MODEL)
def vector_search(user_query, collection, limit=4, *, index="vector_index", num_candidates=150):
    """
    Run an Atlas ``$vectorSearch`` over ``collection`` for ``user_query``.

    Args:
        user_query (str): Free-text query to embed and search with.
        collection (pymongo.collection.Collection): Collection holding the
            product documents and their ``embedding`` vectors.
        limit (int): Maximum number of documents to return.
        index (str): Name of the Atlas Vector Search index on ``embedding``.
        num_candidates (int): Candidate pool size for the approximate search.
            It is raised to ``limit`` when smaller, because Atlas rejects
            ``numCandidates < limit``.

    Returns:
        list[dict]: Matching documents restricted to the projected product
        fields (no ``_id``, no ``embedding``). Empty when no embedding could
        be produced for the query.
    """
    query_embedding = get_embedding(user_query)
    # get_embedding yields None/[] when there is nothing to embed. An empty
    # vector would make Atlas fail, and returning a non-list would break
    # callers that iterate the result, so report "no matches" instead.
    if not query_embedding:
        return []

    pipeline = [
        {
            "$vectorSearch": {
                "index": index,
                "queryVector": query_embedding,
                "path": "embedding",
                "numCandidates": max(num_candidates, limit),
                "limit": limit,
            }
        },
        # Inclusion projection: only these fields survive, which also drops
        # _id and the large embedding vector.
        {
            "$project": {
                "_id": 0,
                "title": 1,
                "details": 1,
                "price": 1,
                "promotion_price": 1,
                "size_options": 1,
                "gender_options": 1,
                "quantity": 1,
                "stock": 1,
                "is_shoes": 1,
                "is_sandals": 1,
            }
        },
    ]
    return list(collection.aggregate(pipeline))
def _format_product(number, product):
    """Render one vector-search hit as a numbered line of prompt context.

    Optional attributes are appended only when the document has a truthy
    value for them, so missing fields never produce empty labels.
    """
    parts = [f"\n\nSản phẩm {number}: {product.get('title')}"]
    if product.get('price'):
        parts.append(f", Giá: {product.get('price')}")
    if product.get('promotion_price'):
        parts.append(f", Giá ưu đãi: {product.get('promotion_price')}")
    if product.get('stock'):
        parts.append(f", Trạng thái: {product.get('stock')}")
    # Compare with == True so only True (or 1) marks the product type;
    # arbitrary truthy values such as strings do not.
    if product.get('is_shoes') == True:  # noqa: E712
        parts.append(", Loại: Giày")
    if product.get('is_sandals') == True:  # noqa: E712
        parts.append(", Loại: Dép")
    if product.get('size_options'):
        parts.append(f", Size: {product.get('size_options')}")
    if product.get('gender_options'):
        parts.append(f", Dành cho: {product.get('gender_options')}")
    if product.get('details'):
        parts.append(f", Chi tiết sản phẩm: {product.get('details')}")
    return "".join(parts)


def get_search_result(query, collection):
    """
    Build the product-context text that is fed to the LLM for ``query``.

    Runs ``vector_search`` for the top 10 hits and renders each one as a
    numbered block "Sản phẩm N: <title>, Giá: ..., ..." (N starts at 1).

    Args:
        query (str): The normalized customer question.
        collection: MongoDB collection passed through to ``vector_search``.

    Returns:
        str: The concatenated product descriptions, or "" when nothing matched.
    """
    knowledge = vector_search(query, collection, 10)
    # vector_search may return an error string instead of a list of documents.
    # Iterating that would yield single characters, so treat it as "no hits".
    if isinstance(knowledge, str):
        return ""
    products = [hit for hit in knowledge if isinstance(hit, dict)]
    return "".join(
        _format_product(number, product)
        for number, product in enumerate(products, start=1)
    )
def get_embedding(text):
    """
    Embed ``text`` with the module-level sentence-transformers model.

    Args:
        text (str | None): Text to embed.

    Returns:
        list[float] | None: The embedding converted to a plain Python list
        (storable in MongoDB / usable as a ``queryVector``), or None when
        ``text`` is None, not a string, or blank. Callers can therefore
        detect "nothing to embed" with ``is None``.
    """
    if not isinstance(text, str) or not text.strip():
        print("Attempted to get embedding for empty text.")
        return None
    embedding = embedding_model.encode(text)
    return embedding.tolist()
def process_query(query):
    """
    Normalize a raw customer question before it is embedded and searched.

    Args:
        query: The ``question`` value from the request body. It may be None
            or a non-string when the client omits or mistypes it.

    Returns:
        str: The question stripped of surrounding whitespace and lower-cased,
        or "" when ``query`` is not a string. Callers' ``if not query`` check
        then rejects it instead of the call crashing.
    """
    if not isinstance(query, str):
        return ""
    return query.strip().lower()
def handle_query():
    """
    Answer a customer question with Gemini, grounded in vector-search results.

    Expects a JSON object body such as ``{"question": "..."}``.

    Returns:
        - 200 with ``{"content": <answer>}`` on success.
        - 400 with ``{"error": "No query provided"}`` when the body is not a
          JSON object or ``question`` is missing, non-string or blank.
        - 502 with ``{"error": ...}`` when the Gemini call fails or its
          response has no text.

    NOTE(review): no ``@app.route`` decorator is visible for this view in this
    file, so Flask will not serve it as written. Confirm where it is registered.
    """
    # silent=True: a missing or invalid JSON body yields None instead of
    # raising, so it falls into the 400 branch below.
    data = request.get_json(silent=True)
    question = data.get('question') if isinstance(data, dict) else None
    if not isinstance(question, str) or not question.strip():
        return jsonify({'error': 'No query provided'}), 400
    query = process_query(question)

    # Retrieve product context from the vector database.
    source_information = get_search_result(query, collection).replace('<br>', '\n')
    combined_information = (
        "Hãy trở thành chuyên gia tư vấn bán hàng cho một website bán giày dép ThuThaoShoes. "
        f"Câu hỏi của khách hàng: {query}\n"
        f"Trả lời câu hỏi dựa vào các thông tin sản phẩm dưới đây: {source_information}."
    )
    try:
        response = model.generate_content(combined_information)
        # .text raises when the response carries no text part (e.g. blocked).
        answer = response.text
    except Exception:  # request boundary: log the SDK/network failure, answer cleanly
        app.logger.exception("Gemini generation failed")
        return jsonify({'error': 'Failed to generate a response'}), 502
    return jsonify({'content': answer})
def get_embedding_api():
    """
    Recompute and store the ``embedding`` field for every product document.

    For each document in ``collection``, embeds "<title> Danh mục: <category>"
    and writes the vector back with ``update_one``. Documents for which no
    embedding is produced are left unchanged.

    Returns:
        Flask JSON response with a status ``message`` and the number of
        documents ``updated``.

    NOTE(review): no ``@app.route`` is visible for this view, and each call
    rewrites the whole collection. If it is exposed over HTTP it should be
    restricted (auth / admin-only). Confirm with the deployment.
    """
    # Fetch only the fields used below (plus _id). This skips the existing,
    # large embedding arrays. The list is materialized before any update, so
    # writes never happen under an open cursor.
    documents = list(collection.find({}, {'title': 1, 'category': 1}))
    updated = 0
    for doc in tqdm(documents, desc="Processing documents"):
        # `or ''` also covers fields that exist but are null.
        title = doc.get('title') or ''
        category = doc.get('category') or ''
        # f-string rather than `+`, so non-string field values cannot raise.
        embedding = get_embedding(f"{title} Danh mục: {category}")
        # Skip None/empty results instead of storing an unusable vector.
        if embedding:
            collection.update_one(
                {'_id': doc['_id']},
                {'$set': {'embedding': embedding}},
            )
            updated += 1
    return jsonify({
        'message': 'Embedding cập nhật thành công cho tất cả các tài liệu.',
        'updated': updated,
    })
if __name__ == '__main__':
    # Do not hard-code debug=True: the interactive Werkzeug debugger lets a
    # browser execute code in the server process. When `debug` is not passed,
    # Flask.run() honours the FLASK_DEBUG environment variable, so set
    # FLASK_DEBUG=1 for local development.
    app.run()