# CLIPnCROP / app.py
# vishnun's picture
# Update app.py
# 5e78ed3
# raw
# history blame
# 2.67 kB
import gradio as gr
import numpy as np
from PIL import Image
from transformers import CLIPProcessor, CLIPModel, DetrFeatureExtractor, DetrForObjectDetection
import torch
# Hub checkpoint identifiers; each one is shared by a model and its
# matching pre-processor.
DETR_CHECKPOINT = 'facebook/detr-resnet-50'
CLIP_CHECKPOINT = "openai/clip-vit-base-patch32"

# DETR object detector and its feature extractor, used by extract_image to
# propose candidate regions of the input image.
feature_extractor = DetrFeatureExtractor.from_pretrained(DETR_CHECKPOINT)
dmodel = DetrForObjectDetection.from_pretrained(DETR_CHECKPOINT)

# CLIP model and its processor, used by extract_image to score the cropped
# regions against the text query.
model = CLIPModel.from_pretrained(CLIP_CHECKPOINT)
processor = CLIPProcessor.from_pretrained(CLIP_CHECKPOINT)

# Gradio components: (image, text) go in, (image, text) come out.
i1 = gr.inputs.Image(type="pil")
i2 = gr.inputs.Textbox()
o1 = gr.outputs.Image(type="pil")
o2 = gr.outputs.Textbox()
def _crop_regions(image, boxes):
    """Crop every (xmin, ymin, xmax, ymax) pixel box out of a PIL image.

    Boxes are clamped to the image bounds first, and boxes with no area left
    after clamping are skipped, so every returned crop is a non-empty image.
    """
    pixels = np.array(image)  # converted once instead of once per box
    height, width = pixels.shape[:2]
    crops = []
    for box in boxes:
        # Predicted boxes are not guaranteed to lie inside the image. A
        # negative start index would be read by NumPy as an offset from the
        # far edge and silently yield a wrong (or empty) region.
        xmin = min(max(int(box[0]), 0), width)
        ymin = min(max(int(box[1]), 0), height)
        xmax = min(max(int(box[2]), 0), width)
        ymax = min(max(int(box[3]), 0), height)
        if xmax <= xmin or ymax <= ymin:
            continue  # zero-area region: Image.fromarray cannot build it
        crops.append(Image.fromarray(pixels[ymin:ymax, xmin:xmax]))
    return crops


def extract_image(image, text, num=1, threshold=0.96):
    """Return the detected region of ``image`` that best matches ``text``.

    DETR proposes object boxes, every box whose best class probability
    exceeds ``threshold`` is cropped out, and CLIP scores the crops against
    ``text``. The crop with the highest score wins.

    Args:
        image: PIL image to search.
        text: Free-form description of the region to extract.
        num: Kept for backward compatibility. The previous implementation
            picked the top ``num`` crops and then returned only the best of
            them, so the value never influenced the result for ``num >= 1``.
        threshold: Minimum DETR class probability (the "no object" class
            excluded) for a detection to be kept.

    Returns:
        A ``(crop, score)`` tuple, where ``score`` is the soft-max of the CLIP
        text-to-image logits taken over the candidate crops. If no usable
        detection passes the threshold, the unmodified input image and an
        explanatory message are returned instead of raising.
    """
    # Inference only: skip autograd bookkeeping for both models.
    with torch.no_grad():
        inputs = feature_extractor(images=image, return_tensors="pt")
        outputs = dmodel(**inputs)
        # The last logit is DETR's "no object" class; drop it before taking
        # the per-query maximum.
        probas = outputs.logits.softmax(-1)[0, :, :-1]
        keep = probas.max(-1).values > threshold
        # PIL reports (width, height); post_process is given (height, width).
        target_sizes = torch.tensor(image.size[::-1]).unsqueeze(0)
        outs = feature_extractor.post_process(outputs, target_sizes)
    boxes = outs[0]['boxes'][keep].detach().numpy()

    crops = _crop_regions(image, boxes)
    if not crops:
        # Nothing to rank: CLIP cannot be called with an empty image list.
        return image, "No object detected above the confidence threshold"

    with torch.no_grad():
        clip_inputs = processor(text=[text], images=crops, return_tensors="pt", padding=True)
        clip_outputs = model(**clip_inputs)
    # Single text query -> one row; soft-max runs across the candidate crops.
    probs = clip_outputs.logits_per_text.softmax(-1)[0].detach().numpy()

    best = int(np.argmax(probs))
    return crops[best], probs[best]
# Static texts shown on the demo page.
title = "ClipnCrop"
description = "Extract sections of images from your image by using OpenAI's CLIP and Facebooks Detr implemented on HuggingFace Transformers"
article = "<p style='text-align: center'>"

# Each example row is [image file, text query], matching the order of the
# two input components.
examples=[['ex3.jpg', 'black bag'],['ex2.jpg', 'man in red dress']]

# Build the interface around extract_image and start serving it right away.
gr.Interface(
    fn=extract_image,
    inputs=[i1, i2],
    outputs=[o1, o2],
    title=title,
    description=description,
    article=article,
    examples=examples,
    enable_queue=True,
).launch()