user-agent committed
Commit 39ca3cd
Parent(s): fffe6fa
Update app.py

app.py CHANGED
@@ -13,6 +13,11 @@ from openai import OpenAI
 from collections import Counter
 from transformers import pipeline
 
+import urllib.request
+from transformers import YolosImageProcessor, YolosForObjectDetection
+import torch
+import matplotlib.pyplot as plt
+from torchvision.transforms import ToTensor, ToPILImage
 
 
 client = OpenAI()
@@ -42,6 +47,8 @@ def shot(input, category):
     common_result = get_predicted_attributes(ast.literal_eval(str(input)),category)
     openai_parsed_response = get_openAI_tags(ast.literal_eval(str(input)))
     face_embeddings = get_face_embeddings(ast.literal_eval(str(input)))
+    cropped_images = get_cropped_images(ast.literal_eval(str(input)),category)
+    print(cropped_images)
     return {
         "colors":{
             "main":mainColour,
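For context, input here is presumably a stringified list of image URLs, so ast.literal_eval(str(input)) turns it back into a Python list before it is passed to the new get_cropped_images helper added further down, which is expected to return a list of cropped PIL images. A minimal sketch of that round trip (the URLs are hypothetical placeholders; the category name comes from the new label_mapping):

    import ast

    raw_input = "['https://example.com/front.jpg', 'https://example.com/back.jpg']"  # hypothetical URLs
    image_urls = ast.literal_eval(str(raw_input))  # back to a plain list of URL strings
    # crops = get_cropped_images(image_urls, 'women-top-tshirt')  # expected: list of cropped PIL images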
@@ -191,7 +198,233 @@ def get_face_embeddings(image_urls):
 
     return results
 
-
+# new
+ACCURACY_THRESHOLD = 0.86
+
+def open_image_from_url(url):
+    # Fetch the image from the URL
+    response = requests.get(url, stream=True)
+    response.raise_for_status()  # Check if the request was successful
+
+    # Open the image using PIL
+    image = Image.open(BytesIO(response.content))
+
+    return image
+
+# Add the main data to the session state
+main = [['Product Id', 'Sku', 'Color', 'Images', 'Status', 'Category', 'Text']]
+
+# This is the order of the categories list. DO NOT CHANGE. Just for visualization purposes
+cats = ['shirt, blouse', 'top, t-shirt, sweatshirt', 'sweater', 'cardigan', 'jacket', 'vest', 'pants', 'shorts', 'skirt', 'coat', 'dress', 'jumpsuit', 'cape', 'glasses', 'hat', 'headband, head covering, hair accessory', 'tie', 'glove', 'watch', 'belt', 'leg warmer', 'tights, stockings', 'sock', 'shoe', 'bag, wallet', 'scarf', 'umbrella', 'hood', 'collar', 'lapel', 'epaulette', 'sleeve', 'pocket', 'neckline', 'buckle', 'zipper', 'applique', 'bead', 'bow', 'flower', 'fringe', 'ribbon', 'rivet', 'ruffle', 'sequin', 'tassel']
+
+filter = ['dress', 'jumpsuit', 'cape', 'glasses', 'hat', 'headband, head covering, hair accessory', 'tie', 'glove', 'watch', 'belt', 'leg warmer', 'tights, stockings', 'sock', 'shoe', 'scarf', 'umbrella', 'hood', 'collar', 'lapel', 'epaulette', 'sleeve', 'pocket', 'neckline', 'buckle', 'zipper', 'applique', 'bead', 'bow', 'flower', 'fringe', 'ribbon', 'rivet', 'ruffle', 'sequin', 'tassel']
+
+# 0 for full body, 1 for upper body, 2 for lower body, 3 for over body (jacket, coat, etc), 4 for accessories
+yolo_mapping = {
+    'shirt, blouse': 3,
+    'top, t-shirt, sweatshirt' : 1,
+    'sweater': 1,
+    'cardigan': 1,
+    'jacket': 3,
+    'vest': 1,
+    'pants': 2,
+    'shorts': 2,
+    'skirt': 2,
+    'coat': 3,
+    'dress': 0,
+    'jumpsuit': 0,
+    'bag, wallet': 4
+}
+
+# First line full body, second line upper body, third line lower body, fourth line over body, fifth line accessories
+label_mapping = [
+    ['women-dress-mini', 'women-dress-dress', 'women-dress-maxi', 'women-dress-midi', 'women-playsuitsjumpsuits-playsuit', 'women-playsuitsjumpsuits-jumpsuit', 'women-coords-coords', 'women-swimwear-onepieces', 'women-swimwear-bikinisets'],
+    ['women-sweatersknits-cardigan', 'women-top-waistcoat', 'women-top-blouse', 'women-sweatersknits-blouse', 'women-sweatersknits-sweater', 'women-top-top', 'women-loungewear-hoodie', 'women-top-camistanks', 'women-top-tshirt', 'women-top-croptop', 'women-loungewear-sweatshirt', 'women-top-body'],
+    ['women-loungewear-joggers', 'women-bottom-trousers', 'women-bottom-leggings', 'women-bottom-jeans', 'women-bottom-shorts', 'women-bottom-skirt', 'women-loungewear-activewear', 'women-bottom-joggers'],
+    ['women-top-shirt', 'women-outwear-coatjacket', 'women-outwear-blazer', 'women-outwear-coatjacket', 'women-outwear-kimonos'],
+    ['women-accessories-bags']
+]
+
+MODEL_NAME = "valentinafeve/yolos-fashionpedia"
+
+feature_extractor = YolosImageProcessor.from_pretrained('hustvl/yolos-small')
+model = YolosForObjectDetection.from_pretrained(MODEL_NAME).to(device)
+
+def get_category_index(category):
+    # Find index of label mapping
+    for i, labels in enumerate(label_mapping):
+        if category in labels:
+            break
+    return i
+
+def get_yolo_index(category):
+    # Find index of yolo mapping
+    return yolo_mapping[category]
+
+def fix_channels(t):
+    """
+    Some images have 4 channels (transparent images) or just 1 channel (black and white images). To keep every image at 3 channels, drop the fourth channel of transparent images and stack the single channel of black and white images.
+    :param t: Tensor-like image
+    :return: Tensor-like image with three channels
+    """
+    if len(t.shape) == 2:
+        return ToPILImage()(torch.stack([t for i in (0, 0, 0)]))
+    if t.shape[0] == 4:
+        return ToPILImage()(t[:3])
+    if t.shape[0] == 1:
+        return ToPILImage()(torch.stack([t[0] for i in (0, 0, 0)]))
+    return ToPILImage()(t)
+
+def idx_to_text(i):
+    return cats[i]
+
+# Random colors used for visualization
+COLORS = [[0.000, 0.447, 0.741], [0.850, 0.325, 0.098], [0.929, 0.694, 0.125],
+          [0.494, 0.184, 0.556], [0.466, 0.674, 0.188], [0.301, 0.745, 0.933]]
+
+# for output bounding box post-processing
+def box_cxcywh_to_xyxy(x):
+    x_c, y_c, w, h = x.unbind(1)
+    b = [(x_c - 0.5 * w), (y_c - 0.5 * h),
+         (x_c + 0.5 * w), (y_c + 0.5 * h)]
+    return torch.stack(b, dim=1)
+
+def rescale_bboxes(out_bbox, size):
+    img_w, img_h = size
+    b = box_cxcywh_to_xyxy(out_bbox)
+    b = b * torch.tensor([img_w, img_h, img_w, img_h], dtype=torch.float32)
+
+    return b
+
+def plot_results(pil_img, prob, boxes):
+    plt.figure(figsize=(16,10))
+    plt.imshow(pil_img)
+    ax = plt.gca()
+    colors = COLORS * 100
+    i = 0
+
+    crops = []
+    crop_classes = []
+    for p, (xmin, ymin, xmax, ymax), c in zip(prob, boxes.tolist(), colors):
+        cl = p.argmax()
+
+        # Save each box as an image
+        box_img = pil_img.crop((xmin, ymin, xmax, ymax))
+        crops.append(box_img)
+        crop_classes.append(idx_to_text(cl))
+
+        ax.add_patch(plt.Rectangle((xmin, ymin), xmax - xmin, ymax - ymin,
+                                   fill=False, color=c, linewidth=3))
+
+        ax.text(xmin, ymin, idx_to_text(cl), fontsize=10,
+                bbox=dict(facecolor=c, alpha=0.8))
+
+        i += 1
+
+    # Remove white padding all around the image
+    plt.axis('off')
+    plt.subplots_adjust(left=0, right=1, top=1, bottom=0)
+    output_img = plt.gcf()
+    plt.close()
+
+    return output_img, crops, crop_classes
+
+
+def visualize_predictions(image, outputs, threshold=0.8):
+    # Keep only predictions with confidence >= threshold
+    probas = outputs.logits.softmax(-1)[0, :, :-1]
+    keep = probas.max(-1).values > threshold
+
+    # Convert predicted boxes from [0; 1] to image scales
+    bboxes_scaled = rescale_bboxes(outputs.pred_boxes[0, keep].cpu(), image.size)
+
+    # Get filtered probabilities and boxes based on the filter list
+    filter_set = set(filter)
+    filtered_probas_boxes = [
+        (proba, box) for proba, box in zip(probas[keep], bboxes_scaled)
+        if idx_to_text(proba.argmax()) not in filter_set
+    ]
+
+    # If a jumpsuit or dress is detected, drop it when other clothes are also detected
+    contains_jumpsuit_or_dress = any(idx_to_text(proba.argmax()) in ["jumpsuit", "dress"] for proba, _ in filtered_probas_boxes)
+    if contains_jumpsuit_or_dress and len(filtered_probas_boxes) > 1:
+        filtered_probas_boxes = [
+            (proba, box) for proba, box in filtered_probas_boxes
+            if idx_to_text(proba.argmax()) not in ["jumpsuit", "dress"]
+        ]
+
+    # Remove duplicates: only keep one box per class
+    unique_classes = set()
+    unique_filtered_probas_boxes = []
+    for proba, box in filtered_probas_boxes:
+        class_text = idx_to_text(proba.argmax())
+        if class_text not in unique_classes:
+            unique_classes.add(class_text)
+            unique_filtered_probas_boxes.append((proba, box))
+
+    # If there are remaining filtered probabilities, plot results
+    output_img = None
+    crops = None
+    crop_classes = None
+    if unique_filtered_probas_boxes:
+        final_probas, final_boxes = zip(*unique_filtered_probas_boxes)
+        output_img, crops, crop_classes = plot_results(image, list(final_probas), torch.stack(final_boxes))
+
+    # Return the classes of the detected objects
+    return [proba.argmax().item() for proba, _ in unique_filtered_probas_boxes], output_img, crops, crop_classes
+
+@spaces.GPU
+def get_objects(image, threshold=0.8):
+    class_counts = {}
+    image = fix_channels(ToTensor()(image))
+    image = image.resize((600, 800))
+
+    inputs = feature_extractor(images=image, return_tensors="pt")
+    outputs = model(**inputs.to(device))
+
+    detected_classes, output_img, crops, crop_classes = visualize_predictions(image, outputs, threshold=threshold)
+    for cl in detected_classes:
+        class_name = idx_to_text(cl)
+        if class_name not in class_counts:
+            class_counts[class_name] = 0
+        class_counts[class_name] += 1
+
+    if crop_classes is not None:
+        crop_classes = [get_yolo_index(c) for c in crop_classes]
+
+    return class_counts, output_img, crops, crop_classes
+
+
+
+def get_cropped_images(images,category):
+    cropped_list = []
+    resultsPerCategory = {}
+    for num, image in enumerate(images):
+        image = open_image_from_url(image)
+        class_counts, output_img, cropped_images, cropped_classes = get_objects(image, 0.37)
+        print(cropped_images)
+        if not class_counts:
+            continue
+
+        # Inverse category: every label from the other mapping groups, excluding the current category's group and the full-body group (index 0)
+        inverse_category = [label for i, labels in enumerate(label_mapping) for label in labels if i != get_category_index(category) and i != 0]
+
+        # If category is a cardigan, we don't recommend category indices 1 and 3
+        if category == 'women-sweatersknits-cardigan':
+            inverse_category = [label for i, labels in enumerate(label_mapping) for label in labels if i != get_category_index(category) and i != 1 and i != 3]
+
+        for i, image in enumerate(cropped_images):
+            cropped_category = cropped_classes[i]
+            print(cropped_category, cropped_classes[i], get_category_index(category))
+
+            specific_category = label_mapping[cropped_category]
+
+            if cropped_category == get_category_index(category):
+                continue
+
+            cropped_list.append(image)
+
+    return cropped_list
 
 
 