|
import os

import gradio as gr
import numpy as np
import torch
from datasets import load_dataset
from PIL import Image, ImageDraw
from transformers import AutoFeatureExtractor, AutoModel
|
|
|
|
|
|
|
print('Load model for computing embeddings of the candidate images')

# ViT backbone used by get_neighbors() to embed the query sketch.
# NOTE(review): presumably the dataset's 'embeddings' column was produced with
# this same checkpoint — confirm, otherwise distances are meaningless.
model_ckpt = "google/vit-base-patch16-224"
extractor, model = (
    AutoFeatureExtractor.from_pretrained(model_ckpt),
    AutoModel.from_pretrained(model_ckpt),
)

# Width of the backbone's hidden states (not referenced in the visible code).
hidden_dim = model.config.hidden_size

# Candidate catalog images with an 'embeddings' column; the FAISS index over
# that column backs the nearest-neighbour query in get_neighbors().
dataset_with_embeddings = load_dataset(
    "tonyassi/vogue-runway-top15-512px-nobg-embeddings2",
    split="train",
)
dataset_with_embeddings.add_faiss_index(column='embeddings')
|
|
|
|
|
def get_neighbors(query_image, top_k=10):
    """Embed ``query_image`` with the ViT backbone and find its nearest dataset rows.

    Args:
        query_image: Image accepted by ``extractor`` (``search`` passes an RGB
            PIL image).
        top_k: Number of nearest neighbours to retrieve from the FAISS index.

    Returns:
        ``(scores, retrieved_examples)`` as returned by
        ``dataset_with_embeddings.get_nearest_examples``: the neighbour scores
        and a mapping from column name to the list of values for the retrieved
        rows (``search`` reads its ``"image"`` and ``"label"`` columns).
    """
    inputs = extractor(query_image, return_tensors="pt")

    # Pure inference: disable autograd so no computation graph is built.
    with torch.no_grad():
        outputs = model(**inputs)

    # First token of the last hidden layer (the [CLS] position for ViT),
    # one query image, so squeeze() drops the size-1 batch axis.
    qi_embedding = outputs.last_hidden_state[:, 0].cpu().numpy().squeeze()

    scores, retrieved_examples = dataset_with_embeddings.get_nearest_examples(
        'embeddings', qi_embedding, k=top_k
    )
    return scores, retrieved_examples
|
|
|
|
|
|
|
def search(image_dict):
    """Gradio callback: find catalog images similar to the editor's composite.

    Args:
        image_dict: Value of the ``gr.ImageEditor`` input. Its ``'composite'``
            entry is opened as an image file (the editor is configured with
            ``type='filepath'``).

    Returns:
        ``(result, query_image)`` where ``result`` is a list of
        ``(image, label_name)`` pairs for the gallery, and ``query_image`` is
        the RGB image that was searched.

    Raises:
        gr.Error: If the editor did not provide a composite image.
    """
    composite = image_dict.get('composite') if image_dict else None
    if composite is None:
        # Surface a readable message in the UI instead of a PIL AttributeError.
        raise gr.Error('Please draw or upload an image before searching.')

    with Image.open(composite) as img:
        # convert() returns a new, fully-loaded image, so closing the source
        # file afterwards is safe.
        query_image = img.convert(mode='RGB')

    _scores, retrieved_examples = get_neighbors(query_image)

    # ClassLabel names, indexed by the integer ids stored in the "label" column.
    label_names = dataset_with_embeddings.features["label"].names

    result = []
    for image, label_id in zip(retrieved_examples["image"], retrieved_examples["label"]):
        print('id', label_id)
        label = label_names[label_id]
        print('label', label)
        result.append((image, label))

    return result, query_image
|
|
|
# Resolve the editor's default background next to this script rather than the
# process's current working directory, so the app also works when launched
# from a different directory.
_BACKGROUND_IMAGE = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'model2.png')

iface = gr.Interface(
    fn=search,
    title='Sketch to Fashion Collection',
    description="""

Tony Assi

""",
    inputs=gr.ImageEditor(
        label='Sketchpad',
        type='filepath',
        # Initial editor value: model2.png as the background, nothing drawn yet.
        value={'background': _BACKGROUND_IMAGE, 'layers': None, 'composite': None},
        sources=['upload'],
        transforms=[],
    ),
    # Order matches search()'s return value: (gallery items, query image).
    outputs=[gr.Gallery(label='Similar', object_fit='contain', height=900), gr.Image()],
    theme=gr.themes.Base(primary_hue="teal", secondary_hue="teal", neutral_hue="slate"),
)

# Start the server only when executed as a script, not on import.
if __name__ == "__main__":
    iface.launch()