Spaces:
Runtime error
Runtime error
Create app.py
Browse files
app.py
ADDED
@@ -0,0 +1,127 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from fastapi import FastAPI, Query
|
2 |
+
from pydantic import BaseModel
|
3 |
+
from sentence_transformers import SentenceTransformer, util
|
4 |
+
from datasets import load_dataset
|
5 |
+
from typing import List
|
6 |
+
import numpy as np
|
7 |
+
import base64
|
8 |
+
from PIL import Image
|
9 |
+
from io import BytesIO
|
10 |
+
|
11 |
+
app = FastAPI()
|
12 |
+
|
13 |
+
@app.get("/")
|
14 |
+
def root():
|
15 |
+
return {"message": "Welcome to the Product Search API!"}
|
16 |
+
def encode_image_to_base64(image):
|
17 |
+
"""
|
18 |
+
Converts a PIL Image or an image-like object to a Base64-encoded string.
|
19 |
+
"""
|
20 |
+
if isinstance(image, Image.Image):
|
21 |
+
buffer = BytesIO()
|
22 |
+
image.save(buffer, format="PNG")
|
23 |
+
return base64.b64encode(buffer.getvalue()).decode("utf-8")
|
24 |
+
return None
|
25 |
+
# Initialize FastAPI
|
26 |
+
|
27 |
+
# Load Dataset
|
28 |
+
dataset = load_dataset("ashraq/fashion-product-images-small", split="train")
|
29 |
+
|
30 |
+
# Define fields for embedding
|
31 |
+
fields_for_embedding = [
|
32 |
+
"productDisplayName",
|
33 |
+
"usage",
|
34 |
+
"season",
|
35 |
+
"baseColour",
|
36 |
+
"articleType",
|
37 |
+
"subCategory",
|
38 |
+
"masterCategory",
|
39 |
+
"gender",
|
40 |
+
]
|
41 |
+
|
42 |
+
# Prepare Data
|
43 |
+
data = []
|
44 |
+
for item in dataset:
|
45 |
+
data.append({
|
46 |
+
"productDisplayName": item["productDisplayName"],
|
47 |
+
"usage": item["usage"],
|
48 |
+
"season": item["season"],
|
49 |
+
"baseColour": item["baseColour"],
|
50 |
+
"articleType": item["articleType"],
|
51 |
+
"subCategory": item["subCategory"],
|
52 |
+
"masterCategory": item["masterCategory"],
|
53 |
+
"gender": item["gender"],
|
54 |
+
"year": item["year"],
|
55 |
+
"image": item["image"],
|
56 |
+
})
|
57 |
+
|
58 |
+
# Load Sentence Transformer Model
|
59 |
+
model = SentenceTransformer("sentence-transformers/multi-qa-MiniLM-L6-cos-v1")
|
60 |
+
|
61 |
+
# Generate Embeddings
|
62 |
+
def create_combined_text(item):
|
63 |
+
return " ".join([str(item[field]) for field in fields_for_embedding if item[field]])
|
64 |
+
|
65 |
+
texts = [create_combined_text(item) for item in data]
|
66 |
+
embeddings = model.encode(texts, convert_to_tensor=True)
|
67 |
+
|
68 |
+
# Response Model
|
69 |
+
class ProductResponse(BaseModel):
|
70 |
+
productDisplayName: str
|
71 |
+
usage: str
|
72 |
+
season: str
|
73 |
+
baseColour: str
|
74 |
+
articleType: str
|
75 |
+
subCategory: str
|
76 |
+
masterCategory: str
|
77 |
+
gender: str
|
78 |
+
year: int
|
79 |
+
image: str # Base64 encoded string
|
80 |
+
|
81 |
+
|
82 |
+
|
83 |
+
@app.get("/products")
|
84 |
+
def search_products(
|
85 |
+
query: str = Query("", title="Search Query", description="Search term for products"),
|
86 |
+
page: int = Query(1, ge=1, title="Page Number"),
|
87 |
+
items_per_page: int = Query(10, ge=1, le=100, title="Items Per Page"),
|
88 |
+
):
|
89 |
+
# Perform Search
|
90 |
+
if query:
|
91 |
+
query_embedding = model.encode(query, convert_to_tensor=True)
|
92 |
+
scores = util.cos_sim(query_embedding, embeddings).squeeze().tolist()
|
93 |
+
ranked_indices = np.argsort(scores)[::-1]
|
94 |
+
else:
|
95 |
+
ranked_indices = np.arange(len(data))
|
96 |
+
|
97 |
+
# Pagination
|
98 |
+
total_items = len(ranked_indices)
|
99 |
+
total_pages = (total_items + items_per_page - 1) // items_per_page
|
100 |
+
start_idx = (page - 1) * items_per_page
|
101 |
+
end_idx = start_idx + items_per_page
|
102 |
+
paginated_indices = ranked_indices[start_idx:end_idx]
|
103 |
+
|
104 |
+
# Prepare Response
|
105 |
+
results = []
|
106 |
+
for idx in paginated_indices:
|
107 |
+
item = data[idx]
|
108 |
+
results.append({
|
109 |
+
"productDisplayName": item["productDisplayName"],
|
110 |
+
"usage": item["usage"],
|
111 |
+
"season": item["season"],
|
112 |
+
"baseColour": item["baseColour"],
|
113 |
+
"articleType": item["articleType"],
|
114 |
+
"subCategory": item["subCategory"],
|
115 |
+
"masterCategory": item["masterCategory"],
|
116 |
+
"gender": item["gender"],
|
117 |
+
"year": item["year"],
|
118 |
+
"image": encode_image_to_base64(item["image"]),
|
119 |
+
})
|
120 |
+
|
121 |
+
# Construct the API response
|
122 |
+
return {
|
123 |
+
"status": 200,
|
124 |
+
"data": results,
|
125 |
+
"totalpages": total_pages,
|
126 |
+
"currentpage": page
|
127 |
+
}
|