|
import ast
import json
from collections import Counter
from io import BytesIO
# NOTE(review): likely an accidental IDE auto-import; the name `title` is not
# referenced in the visible code — confirm before removing.
from turtle import title

# NOTE(review): `spaces` is kept ahead of the other third-party imports, as it
# was before regrouping — confirm whether import order matters for your setup.
import spaces
import gradio as gr
import numpy as np
import requests
from PIL import Image
from transformers import pipeline
|
# Zero-shot image classifier shared by every request; the checkpoint is
# loaded once at import time.
pipe = pipeline("zero-shot-image-classification", model="patrickjohncyh/fashion-clip")

# Configuration files, resolved relative to the current working directory.
color_file_path = 'color_config.json'
attributes_file_path = 'attributes_config.json'

# JSON is UTF-8 by specification, so decode it explicitly rather than relying
# on the platform's default text encoding.
with open(color_file_path, 'r', encoding='utf-8') as file:
    color_data = json.load(file)

with open(attributes_file_path, 'r', encoding='utf-8') as file:
    attributes_data = json.load(file)

# Main colour name -> list of sub-colour names (used by get_colour).
COLOURS_DICT = color_data['color_mapping']

# Category -> attribute name -> mapping whose keys are the candidate values
# (used by get_predicted_attributes).
ATTRIBUTES_DICT = attributes_data['attribute_mapping']
|
|
|
|
|
def _parse_image_urls(raw):
    """Normalise the raw ``input`` value of ``shot`` into a list of image URLs.

    Accepts an existing list/tuple, the text of a Python literal
    (e.g. ``"['https://a.jpg', 'https://b.jpg']"`` or ``"'https://a.jpg'"``),
    or plain comma-separated URLs as typed into the UI text box.
    """
    if isinstance(raw, (list, tuple)):
        return list(raw)

    text = str(raw).strip()
    try:
        parsed = ast.literal_eval(text)
    except (ValueError, SyntaxError):
        # Not a Python literal: treat it as comma-separated URLs.
        return [part.strip() for part in text.split(",") if part.strip()]

    if isinstance(parsed, (list, tuple)):
        return list(parsed)
    # A single literal (e.g. one quoted URL) becomes a one-element list so the
    # downstream `responses[0][0]` indexing sees one result list per image.
    return [parsed]


def shot(input, category):
    """Predict colours and attributes for the given images.

    Args:
        input: Image URLs, either as a list/tuple, as the text of a Python
            list literal, or as comma-separated text. (The name shadows the
            builtin but is kept so existing callers are unaffected.)
        category: Clothing category; appended to every candidate label and
            used to look up the attribute set.

    Returns:
        dict: ``{"colors": {"main", "sub", "score"}, "attributes": [...]}``
        where ``score`` is a percentage rounded to two decimals, or ``None``
        when ``get_colour`` could not determine a main colour.

    Raises:
        ValueError: If no image URL could be extracted from ``input``.
    """
    # Parse once and reuse the same list for both predictions.
    image_urls = _parse_image_urls(input)
    if not image_urls:
        raise ValueError("No image URLs were provided.")

    subColour, mainColour, score = get_colour(image_urls, category)
    common_result = get_predicted_attributes(image_urls, category)

    return {
        "colors": {
            "main": mainColour,
            "sub": subColour,
            # get_colour returns a None score when no main colour matched;
            # multiplying None would raise a TypeError.
            "score": None if score is None else round(score * 100, 2),
        },
        "attributes": common_result,
    }
|
|
|
|
|
|
|
@spaces.GPU
def get_colour(image_urls, category):
    """Classify the main colour and then the sub-colour of the first image.

    The classifier is run twice: first over every main colour in
    ``COLOURS_DICT``, then over the sub-colours listed under the winning main
    colour. Every candidate label is suffixed with ``" clothing: <category>"``.

    Args:
        image_urls: Images passed straight to ``pipe``. Assumes a list, so
            that ``pipe`` returns one ranked result list per image — TODO
            confirm against callers.
        category: Clothing category appended to each candidate label.

    Returns:
        tuple: ``(subColour, mainColour, score)`` taken from the top result of
        the first image, or ``(None, None, None)`` when the predicted main
        colour is not a key of ``COLOURS_DICT``.
    """
    suffix = " clothing: " + category

    main_labels = [colour + suffix for colour in COLOURS_DICT]
    responses = pipe(image_urls, candidate_labels=main_labels)
    # Only the best label of the first image is considered.
    mainColour = responses[0][0]['label'].split(" clothing:")[0]

    if mainColour not in COLOURS_DICT:
        return None, None, None

    # Build a new list instead of editing COLOURS_DICT's list in place:
    # mutating the shared config would make the suffix pile up on every call.
    sub_labels = [shade + suffix for shade in COLOURS_DICT[mainColour]]

    responses = pipe(image_urls, candidate_labels=sub_labels)
    best = responses[0][0]
    subColour = best['label'].split(" clothing:")[0]

    return subColour, mainColour, best['score']
|
|
|
@spaces.GPU
def get_predicted_attributes(image_urls, category):
    """Predict the most common value of each attribute configured for a category.

    For every attribute under ``ATTRIBUTES_DICT[category]`` that has candidate
    values, the classifier is run with labels of the form
    ``"<attribute>: <value>, clothing: <category>"``. The top label of each
    image is collected and the most frequent one is kept. For the
    ``"details"`` attribute the second-ranked label of each image is also
    counted and the two most frequent labels are kept.

    Args:
        image_urls: Images passed straight to ``pipe``. Assumes a list, so
            that ``pipe`` returns one ranked result list per image — TODO
            confirm against callers.
        category: Clothing category used to look up attributes and appended
            to each candidate label.

    Returns:
        list[str]: One entry per attribute with candidate values, each being
        the winning ``"<attribute>: <value>"`` label(s) joined by ``", "``.
        Empty when the category is unknown.
    """
    category_attributes = ATTRIBUTES_DICT.get(category, {})

    common_result = []
    for attribute, value_map in category_attributes.items():
        values = list(value_map)
        if not values:
            continue

        # Turn config keys into more natural wording for the text prompt.
        attribute = (
            attribute.replace("colartype", "collar")
            .replace("sleevelength", "sleeve length")
            .replace("fabricstyle", "fabric")
        )
        labels = [f"{attribute}: {value}, clothing: {category}" for value in values]

        # No `device` argument here: no such name is defined in this module,
        # and get_colour calls the same pipeline without one.
        responses = pipe(image_urls, candidate_labels=labels)
        result = [response[0]['label'].split(", clothing:")[0] for response in responses]

        if attribute == "details":
            # Also count each image's runner-up; skip images that only had a
            # single candidate label to rank.
            result += [
                response[1]['label'].split(", clothing:")[0]
                for response in responses
                if len(response) > 1
            ]
            winners = Counter(result).most_common(2)
        else:
            winners = Counter(result).most_common(1)

        common_result.append(", ".join(label for label, _count in winners))

    return common_result
|
|
|
|
|
|
|
|
|
|
|
# Gradio UI: two text inputs are passed to shot(); the dict it returns is
# rendered as text.
iface = gr.Interface(
    fn=shot,
    inputs=[
        gr.Textbox(label="Image URLs (starting with http/https) comma separated"),
        gr.Textbox(label="Category"),
    ],
    outputs="text",
    # Describes the inputs this interface actually has: image URLs and a
    # category (there is no upload widget and no label list).
    description="Provide one or more image URLs (starting with http/https) and the product category to get the predicted colours and attributes.",
    title="Full product flow",
)

# Start the web server when the script is loaded.
iface.launch()
|
|