MohamedAshraf701 commited on
Commit
01319d7
·
verified ·
1 Parent(s): ea2eb76

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +127 -0
app.py ADDED
@@ -0,0 +1,127 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI, Query
2
+ from pydantic import BaseModel
3
+ from sentence_transformers import SentenceTransformer, util
4
+ from datasets import load_dataset
5
+ from typing import List
6
+ import numpy as np
7
+ import base64
8
+ from PIL import Image
9
+ from io import BytesIO
10
+
11
+ app = FastAPI()
12
+
13
+ @app.get("/")
14
+ def root():
15
+ return {"message": "Welcome to the Product Search API!"}
16
+ def encode_image_to_base64(image):
17
+ """
18
+ Converts a PIL Image or an image-like object to a Base64-encoded string.
19
+ """
20
+ if isinstance(image, Image.Image):
21
+ buffer = BytesIO()
22
+ image.save(buffer, format="PNG")
23
+ return base64.b64encode(buffer.getvalue()).decode("utf-8")
24
+ return None
25
+ # Initialize FastAPI
26
+
27
+ # Load Dataset
28
+ dataset = load_dataset("ashraq/fashion-product-images-small", split="train")
29
+
30
+ # Define fields for embedding
31
+ fields_for_embedding = [
32
+ "productDisplayName",
33
+ "usage",
34
+ "season",
35
+ "baseColour",
36
+ "articleType",
37
+ "subCategory",
38
+ "masterCategory",
39
+ "gender",
40
+ ]
41
+
42
+ # Prepare Data
43
+ data = []
44
+ for item in dataset:
45
+ data.append({
46
+ "productDisplayName": item["productDisplayName"],
47
+ "usage": item["usage"],
48
+ "season": item["season"],
49
+ "baseColour": item["baseColour"],
50
+ "articleType": item["articleType"],
51
+ "subCategory": item["subCategory"],
52
+ "masterCategory": item["masterCategory"],
53
+ "gender": item["gender"],
54
+ "year": item["year"],
55
+ "image": item["image"],
56
+ })
57
+
58
+ # Load Sentence Transformer Model
59
+ model = SentenceTransformer("sentence-transformers/multi-qa-MiniLM-L6-cos-v1")
60
+
61
+ # Generate Embeddings
62
+ def create_combined_text(item):
63
+ return " ".join([str(item[field]) for field in fields_for_embedding if item[field]])
64
+
65
+ texts = [create_combined_text(item) for item in data]
66
+ embeddings = model.encode(texts, convert_to_tensor=True)
67
+
68
+ # Response Model
69
+ class ProductResponse(BaseModel):
70
+ productDisplayName: str
71
+ usage: str
72
+ season: str
73
+ baseColour: str
74
+ articleType: str
75
+ subCategory: str
76
+ masterCategory: str
77
+ gender: str
78
+ year: int
79
+ image: str # Base64 encoded string
80
+
81
+
82
+
83
+ @app.get("/products")
84
+ def search_products(
85
+ query: str = Query("", title="Search Query", description="Search term for products"),
86
+ page: int = Query(1, ge=1, title="Page Number"),
87
+ items_per_page: int = Query(10, ge=1, le=100, title="Items Per Page"),
88
+ ):
89
+ # Perform Search
90
+ if query:
91
+ query_embedding = model.encode(query, convert_to_tensor=True)
92
+ scores = util.cos_sim(query_embedding, embeddings).squeeze().tolist()
93
+ ranked_indices = np.argsort(scores)[::-1]
94
+ else:
95
+ ranked_indices = np.arange(len(data))
96
+
97
+ # Pagination
98
+ total_items = len(ranked_indices)
99
+ total_pages = (total_items + items_per_page - 1) // items_per_page
100
+ start_idx = (page - 1) * items_per_page
101
+ end_idx = start_idx + items_per_page
102
+ paginated_indices = ranked_indices[start_idx:end_idx]
103
+
104
+ # Prepare Response
105
+ results = []
106
+ for idx in paginated_indices:
107
+ item = data[idx]
108
+ results.append({
109
+ "productDisplayName": item["productDisplayName"],
110
+ "usage": item["usage"],
111
+ "season": item["season"],
112
+ "baseColour": item["baseColour"],
113
+ "articleType": item["articleType"],
114
+ "subCategory": item["subCategory"],
115
+ "masterCategory": item["masterCategory"],
116
+ "gender": item["gender"],
117
+ "year": item["year"],
118
+ "image": encode_image_to_base64(item["image"]),
119
+ })
120
+
121
+ # Construct the API response
122
+ return {
123
+ "status": 200,
124
+ "data": results,
125
+ "totalpages": total_pages,
126
+ "currentpage": page
127
+ }