|
import gradio as gr |
|
import numpy as np |
|
from PIL import Image |
|
from transformers import CLIPProcessor, CLIPModel, DetrFeatureExtractor, DetrForObjectDetection |
|
import torch |
|
|
|
# Hugging Face Hub checkpoints; the weights are fetched when this module is imported.
DETR_CHECKPOINT = 'facebook/detr-resnet-50'
CLIP_CHECKPOINT = "openai/clip-vit-base-patch32"

# DETR pair: extract_image uses it to propose bounding boxes that become candidate crops.
feature_extractor = DetrFeatureExtractor.from_pretrained(DETR_CHECKPOINT)
dmodel = DetrForObjectDetection.from_pretrained(DETR_CHECKPOINT)

# CLIP pair: extract_image uses it to score each candidate crop against the text query.
model = CLIPModel.from_pretrained(CLIP_CHECKPOINT)
processor = CLIPProcessor.from_pretrained(CLIP_CHECKPOINT)

# Gradio I/O components: (PIL image, text query) in, (best crop, its score) out.
i1, i2 = gr.inputs.Image(type="pil"), gr.inputs.Textbox()
o1, o2 = gr.outputs.Image(type="pil"), gr.outputs.Textbox()
|
|
|
def extract_image(image, text, num=1, threshold=0.96):
    """Return the region of *image* that best matches the caption *text*.

    DETR proposes object boxes; every box whose best (non "no-object") class
    probability exceeds *threshold* is cropped out of the image, and CLIP
    scores each crop against *text*. The highest-scoring crop is returned.

    Args:
        image: PIL image, as delivered by ``gr.inputs.Image(type="pil")``.
        text: Free-text description of the region to extract.
        num: Kept for backward compatibility. Only the single best crop is
            ever returned, so this value does not change the result.
        threshold: Minimum DETR class probability for a box to become a
            candidate crop (previously hard-coded to 0.96).

    Returns:
        ``(crop, score)``. ``crop`` is a PIL image. ``score`` is the softmax
        probability of that crop taken across all candidate crops. If DETR
        yields no usable box, the whole image is the only candidate, so its
        score is trivially 1.0.
    """
    crops = _detect_crops(image, threshold)
    if not crops:
        # Previously this crashed (empty image batch / fi[0] IndexError);
        # fall back to treating the full image as the single candidate.
        crops = [image]

    probs = _score_crops(crops, text)
    best = int(np.argmax(probs))
    return crops[best], probs[best]


def _detect_crops(image, threshold):
    """Run DETR on *image* and return PIL crops of detections above *threshold*."""
    inputs = feature_extractor(images=image, return_tensors="pt")
    with torch.no_grad():  # inference only; avoid building an autograd graph
        outputs = dmodel(**inputs)

    # Drop the last class column (DETR's "no object" class) before taking
    # each query's best class probability.
    probas = outputs.logits.softmax(-1)[0, :, :-1]
    keep = probas.max(-1).values > threshold

    # post_process wants target sizes as (height, width); PIL size is (width, height).
    target_sizes = torch.tensor(image.size[::-1]).unsqueeze(0)
    boxes = feature_extractor.post_process(outputs, target_sizes)[0]["boxes"][keep]

    width, height = image.size
    im_arr = np.array(image)  # convert once, not once per box
    crops = []
    for xmin, ymin, xmax, ymax in boxes.numpy():
        # Predicted boxes may extend past the image border; clamp so a negative
        # coordinate is not interpreted as a from-the-end slice index.
        x0, x1 = max(int(xmin), 0), min(int(xmax), width)
        y0, y1 = max(int(ymin), 0), min(int(ymax), height)
        if x1 <= x0 or y1 <= y0:
            continue  # degenerate box: nothing to crop
        crops.append(Image.fromarray(im_arr[y0:y1, x0:x1]))
    return crops


def _score_crops(crops, text):
    """Return CLIP's per-crop match probability for *text* as a NumPy array."""
    inputs = processor(text=[text], images=crops, return_tensors="pt", padding=True)
    with torch.no_grad():
        output = model(**inputs)
    # logits_per_text has one row per caption; with a single caption, row 0
    # holds its similarity to every crop, normalized across crops.
    return output.logits_per_text.softmax(-1)[0].numpy()
|
|
|
title = "ClipnCrop"

# User-facing blurb shown under the title.
description = "Extract sections of images from your image by using OpenAI's CLIP and Facebook's DETR implemented on HuggingFace Transformers"

# Each example is [image path, text query], matching the (image, text) inputs.
# NOTE(review): paths are relative; presumably the files ship next to this script — verify.
examples = [['ex3.jpg', 'black bag'], ['ex2.jpg', 'man in red dress']]

# HTML rendered below the interface; the original left the <p> tag unclosed.
article = "<p style='text-align: center'></p>"

gr.Interface(
    fn=extract_image,
    inputs=[i1, i2],
    outputs=[o1, o2],
    title=title,
    description=description,
    article=article,
    examples=examples,
    enable_queue=True,
).launch()