user-agent committed on
Commit 39ca3cd
1 Parent(s): fffe6fa

Update app.py

Files changed (1)
  1. app.py +234 -1
app.py CHANGED
@@ -13,6 +13,11 @@ from openai import OpenAI
 from collections import Counter
 from transformers import pipeline
 
+import urllib.request
+from transformers import YolosImageProcessor, YolosForObjectDetection
+import torch
+import matplotlib.pyplot as plt
+from torchvision.transforms import ToTensor, ToPILImage
 
 
 client = OpenAI()
@@ -42,6 +47,8 @@ def shot(input, category):
     common_result = get_predicted_attributes(ast.literal_eval(str(input)),category)
     openai_parsed_response = get_openAI_tags(ast.literal_eval(str(input)))
     face_embeddings = get_face_embeddings(ast.literal_eval(str(input)))
+    cropped_images = get_cropped_images(ast.literal_eval(str(input)),category)
+    print(cropped_images)
     return {
         "colors":{
             "main":mainColour,
@@ -191,7 +198,233 @@ def get_face_embeddings(image_urls):
 
     return results
 
-
+# new
+ACCURACY_THRESHOLD = 0.86
+
+def open_image_from_url(url):
+    # Fetch the image from the URL
+    response = requests.get(url, stream=True)
+    response.raise_for_status()  # Check if the request was successful
+
+    # Open the image using PIL
+    image = Image.open(BytesIO(response.content))
+
+    return image
+
+# Add the main data to the session state
+main = [['Product Id', 'Sku', 'Color', 'Images', 'Status', 'Category', 'Text']]
+
+# This is the order of the categories list. DO NOT CHANGE. Just for visualization purposes
+cats = ['shirt, blouse', 'top, t-shirt, sweatshirt', 'sweater', 'cardigan', 'jacket', 'vest', 'pants', 'shorts', 'skirt', 'coat', 'dress', 'jumpsuit', 'cape', 'glasses', 'hat', 'headband, head covering, hair accessory', 'tie', 'glove', 'watch', 'belt', 'leg warmer', 'tights, stockings', 'sock', 'shoe', 'bag, wallet', 'scarf', 'umbrella', 'hood', 'collar', 'lapel', 'epaulette', 'sleeve', 'pocket', 'neckline', 'buckle', 'zipper', 'applique', 'bead', 'bow', 'flower', 'fringe', 'ribbon', 'rivet', 'ruffle', 'sequin', 'tassel']
+
+filter = ['dress', 'jumpsuit', 'cape', 'glasses', 'hat', 'headband, head covering, hair accessory', 'tie', 'glove', 'watch', 'belt', 'leg warmer', 'tights, stockings', 'sock', 'shoe', 'scarf', 'umbrella', 'hood', 'collar', 'lapel', 'epaulette', 'sleeve', 'pocket', 'neckline', 'buckle', 'zipper', 'applique', 'bead', 'bow', 'flower', 'fringe', 'ribbon', 'rivet', 'ruffle', 'sequin', 'tassel']
+
+# 0 for full body, 1 for upper body, 2 for lower body, 3 for over body (jacket, coat, etc.), 4 for accessories
+yolo_mapping = {
+    'shirt, blouse': 3,
+    'top, t-shirt, sweatshirt': 1,
+    'sweater': 1,
+    'cardigan': 1,
+    'jacket': 3,
+    'vest': 1,
+    'pants': 2,
+    'shorts': 2,
+    'skirt': 2,
+    'coat': 3,
+    'dress': 0,
+    'jumpsuit': 0,
+    'bag, wallet': 4
+}
+
+# First line full body, second line upper body, third line lower body, fourth line over body, fifth line accessories
+label_mapping = [
+    ['women-dress-mini', 'women-dress-dress', 'women-dress-maxi', 'women-dress-midi', 'women-playsuitsjumpsuits-playsuit', 'women-playsuitsjumpsuits-jumpsuit', 'women-coords-coords', 'women-swimwear-onepieces', 'women-swimwear-bikinisets'],
+    ['women-sweatersknits-cardigan', 'women-top-waistcoat', 'women-top-blouse', 'women-sweatersknits-blouse', 'women-sweatersknits-sweater', 'women-top-top', 'women-loungewear-hoodie', 'women-top-camistanks', 'women-top-tshirt', 'women-top-croptop', 'women-loungewear-sweatshirt', 'women-top-body'],
+    ['women-loungewear-joggers', 'women-bottom-trousers', 'women-bottom-leggings', 'women-bottom-jeans', 'women-bottom-shorts', 'women-bottom-skirt', 'women-loungewear-activewear', 'women-bottom-joggers'],
+    ['women-top-shirt', 'women-outwear-coatjacket', 'women-outwear-blazer', 'women-outwear-coatjacket', 'women-outwear-kimonos'],
+    ['women-accessories-bags']
+]
+
+MODEL_NAME = "valentinafeve/yolos-fashionpedia"
+
+feature_extractor = YolosImageProcessor.from_pretrained('hustvl/yolos-small')
+model = YolosForObjectDetection.from_pretrained(MODEL_NAME).to(device)
+
+def get_category_index(category):
+    # Find the index of the label_mapping group that contains the category
+    for i, labels in enumerate(label_mapping):
+        if category in labels:
+            break
+    return i
+
+def get_yolo_index(category):
+    # Find the yolo_mapping index for the category
+    return yolo_mapping[category]
+
+def fix_channels(t):
+    """
+    Some images may have 4 channels (transparent images) or just 1 channel (black-and-white images). To give every image exactly 3 channels, drop the fourth channel of transparent images and stack the single channel of black-and-white images three times.
+    :param t: Tensor-like image
+    :return: Tensor-like image with three channels
+    """
+    if len(t.shape) == 2:
+        return ToPILImage()(torch.stack([t for i in (0, 0, 0)]))
+    if t.shape[0] == 4:
+        return ToPILImage()(t[:3])
+    if t.shape[0] == 1:
+        return ToPILImage()(torch.stack([t[0] for i in (0, 0, 0)]))
+    return ToPILImage()(t)
+
+def idx_to_text(i):
+    return cats[i]
+
+# Random colors used for visualization
+COLORS = [[0.000, 0.447, 0.741], [0.850, 0.325, 0.098], [0.929, 0.694, 0.125],
+          [0.494, 0.184, 0.556], [0.466, 0.674, 0.188], [0.301, 0.745, 0.933]]
+
+# For output bounding box post-processing
+def box_cxcywh_to_xyxy(x):
+    x_c, y_c, w, h = x.unbind(1)
+    b = [(x_c - 0.5 * w), (y_c - 0.5 * h),
+         (x_c + 0.5 * w), (y_c + 0.5 * h)]
+    return torch.stack(b, dim=1)
+
+def rescale_bboxes(out_bbox, size):
+    img_w, img_h = size
+    b = box_cxcywh_to_xyxy(out_bbox)
+    b = b * torch.tensor([img_w, img_h, img_w, img_h], dtype=torch.float32)
+
+    return b
+
+def plot_results(pil_img, prob, boxes):
+    plt.figure(figsize=(16,10))
+    plt.imshow(pil_img)
+    ax = plt.gca()
+    colors = COLORS * 100
+    i = 0
+
+    crops = []
+    crop_classes = []
+    for p, (xmin, ymin, xmax, ymax), c in zip(prob, boxes.tolist(), colors):
+        cl = p.argmax()
+
+        # Save each box as an image
+        box_img = pil_img.crop((xmin, ymin, xmax, ymax))
+        crops.append(box_img)
+        crop_classes.append(idx_to_text(cl))
+
+        ax.add_patch(plt.Rectangle((xmin, ymin), xmax - xmin, ymax - ymin,
+                                   fill=False, color=c, linewidth=3))
+
+        ax.text(xmin, ymin, idx_to_text(cl), fontsize=10,
+                bbox=dict(facecolor=c, alpha=0.8))
+
+        i += 1
+
+    # Remove white padding all around the image
+    plt.axis('off')
+    plt.subplots_adjust(left=0, right=1, top=1, bottom=0)
+    output_img = plt.gcf()
+    plt.close()
+
+    return output_img, crops, crop_classes
+
+
+def visualize_predictions(image, outputs, threshold=0.8):
+    # Keep only predictions with confidence above the threshold
+    probas = outputs.logits.softmax(-1)[0, :, :-1]
+    keep = probas.max(-1).values > threshold
+
+    # Convert predicted boxes from [0; 1] to image scales
+    bboxes_scaled = rescale_bboxes(outputs.pred_boxes[0, keep].cpu(), image.size)
+
+    # Get filtered probabilities and boxes based on the filter list
+    filter_set = set(filter)
+    filtered_probas_boxes = [
+        (proba, box) for proba, box in zip(probas[keep], bboxes_scaled)
+        if idx_to_text(proba.argmax()) not in filter_set
+    ]
+
+    # If a jumpsuit or dress is detected, drop it when other garments are also detected
+    contains_jumpsuit_or_dress = any(idx_to_text(proba.argmax()) in ["jumpsuit", "dress"] for proba, _ in filtered_probas_boxes)
+    if contains_jumpsuit_or_dress and len(filtered_probas_boxes) > 1:
+        filtered_probas_boxes = [
+            (proba, box) for proba, box in filtered_probas_boxes
+            if idx_to_text(proba.argmax()) not in ["jumpsuit", "dress"]
+        ]
+
+    # Remove duplicates: only keep one box per class
+    unique_classes = set()
+    unique_filtered_probas_boxes = []
+    for proba, box in filtered_probas_boxes:
+        class_text = idx_to_text(proba.argmax())
+        if class_text not in unique_classes:
+            unique_classes.add(class_text)
+            unique_filtered_probas_boxes.append((proba, box))
+
+    # If there are remaining filtered probabilities, plot results
+    output_img = None
+    crops = None
+    crop_classes = None
+    if unique_filtered_probas_boxes:
+        final_probas, final_boxes = zip(*unique_filtered_probas_boxes)
+        output_img, crops, crop_classes = plot_results(image, list(final_probas), torch.stack(final_boxes))
+
+    # Return the classes of the detected objects
+    return [proba.argmax().item() for proba, _ in unique_filtered_probas_boxes], output_img, crops, crop_classes
+
+@spaces.GPU
+def get_objects(image, threshold=0.8):
+    class_counts = {}
+    image = fix_channels(ToTensor()(image))
+    image = image.resize((600, 800))
+
+    inputs = feature_extractor(images=image, return_tensors="pt")
+    outputs = model(**inputs.to(device))
+
+    detected_classes, output_img, crops, crop_classes = visualize_predictions(image, outputs, threshold=threshold)
+    for cl in detected_classes:
+        class_name = idx_to_text(cl)
+        if class_name not in class_counts:
+            class_counts[class_name] = 0
+        class_counts[class_name] += 1
+
+    if crop_classes is not None:
+        crop_classes = [get_yolo_index(c) for c in crop_classes]
+
+    return class_counts, output_img, crops, crop_classes
+
+
+
+def get_cropped_images(images,category):
+    cropped_list = []
+    resultsPerCategory = {}
+    for num, image in enumerate(images):
+        image = open_image_from_url(image)
+        class_counts, output_img, cropped_images, cropped_classes = get_objects(image, 0.37)
+        print(cropped_images)
+        if not class_counts:
+            continue
+
+        # Build the inverse category: every label from the other mapping groups except the current category's group (group 0, full body, is also excluded)
+        inverse_category = [label for i, labels in enumerate(label_mapping) for label in labels if i != get_category_index(category) and i != 0]
+
+        # If the category is a cardigan, we don't recommend category indices 1 and 3
+        if category == 'women-sweatersknits-cardigan':
+            inverse_category = [label for i, labels in enumerate(label_mapping) for label in labels if i != get_category_index(category) and i != 1 and i != 3]
+
+        for i, image in enumerate(cropped_images):
+            cropped_category = cropped_classes[i]
+            print(cropped_category, cropped_classes[i], get_category_index(category))
+
+            specific_category = label_mapping[cropped_category]
+
+            if cropped_category == get_category_index(category):
+                continue
+
+            cropped_list.append(image)
+
+    return cropped_list
 
 
 
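
For context on how the pieces added in this commit fit together, here is a minimal, illustrative sketch (not part of the commit) of how the new get_cropped_images pipeline could be exercised from app.py. It assumes the pre-existing names in app.py that the diff relies on (requests, PIL.Image, BytesIO, spaces, device); the image URLs and the category slug below are hypothetical placeholders.

# Illustrative only, not committed code. Assumes app.py already defines/imports
# requests, PIL.Image, BytesIO, spaces and device, as in the unchanged parts of the file.
if __name__ == "__main__":
    # Hypothetical product images and category slug
    example_images = [
        "https://example.com/product-front.jpg",
        "https://example.com/product-back.jpg",
    ]
    example_category = "women-top-tshirt"

    # Runs YOLOS detection on each image, skips boxes whose yolo group matches the
    # product's own category group, and collects the remaining crops as PIL images.
    crops = get_cropped_images(example_images, example_category)
    print(f"kept {len(crops)} cropped garment images")

    # Each crop can then be saved or passed on for further tagging.
    for i, crop in enumerate(crops):
        crop.save(f"crop_{i}.png")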