# CLIPnCROP / app.py
# vishnun's picture
# Update app.py
# 3f7bb9f
import gradio as gr
import numpy as np
from PIL import Image
from transformers import CLIPProcessor, CLIPModel, YolosImageProcessor, YolosForObjectDetection
import torch
# Input widgets, listed in the order they are handed to the Interface below:
# the picture to search, the free-text description, and the detector cut-off.
i1 = gr.Image(label="Input image", type="pil")
i2 = gr.Textbox(label="Description for section to extracted")
i3 = gr.Number(label="Threshold percentage score", value=0.96)

# Output widgets: the selected crop and the score reported alongside it.
o1 = gr.Image(label="Extracted Crop part", type="pil")
o2 = gr.Textbox(label="Similarity score")
# Checkpoint identifiers are named once so that each processor/model pair is
# always loaded from the same checkpoint.
_DETECTOR_CHECKPOINT = "hustvl/yolos-tiny"
_CLIP_CHECKPOINT = "openai/clip-vit-base-patch16"

# YOLOS object detector: image pre/post-processor plus the detection model.
feature_extractor = YolosImageProcessor.from_pretrained(_DETECTOR_CHECKPOINT)
dmodel = YolosForObjectDetection.from_pretrained(_DETECTOR_CHECKPOINT)

# CLIP model and its combined text/image processor, used to rank the crops.
model = CLIPModel.from_pretrained(_CLIP_CHECKPOINT)
processor = CLIPProcessor.from_pretrained(_CLIP_CHECKPOINT)
def extract_image(image, text, prob, num=1):
    """Crop the detected region of *image* that best matches *text*.

    The image is run through the module-level object detector
    (``feature_extractor`` / ``dmodel``); every detection whose best class
    score exceeds *prob* is cropped out, and the crops are ranked against
    *text* with the module-level CLIP ``model`` / ``processor``.

    Args:
        image: PIL image to search (``image.size`` and ``image.crop`` are
            used).
        text: Description of the region to extract.
        prob: Detection score threshold; only boxes scoring strictly above
            it are considered.
        num: Retained for backward compatibility. Only the single
            best-scoring crop is ever returned, so it has no effect on the
            result.

    Returns:
        A ``(crop, score)`` tuple: the best-matching crop as a PIL image and
        its score as a NumPy scalar. The score is a softmax over the
        candidate crops, so it is relative to the other candidates (a lone
        candidate always scores 1.0).

    Raises:
        gr.Error: If no image was supplied, or if no usable detection
            exceeds *prob*.
    """
    if image is None:
        raise gr.Error("Please provide an input image.")

    # Inference only: avoid building autograd graphs for either model.
    with torch.no_grad():
        det_inputs = feature_extractor(images=image, return_tensors="pt")
        det_outputs = dmodel(**det_inputs)

    width, height = image.size
    # post_process_object_detection drops the trailing "no object" class,
    # keeps boxes scoring above `threshold`, and rescales them to the
    # (height, width) given in target_sizes.
    detections = feature_extractor.post_process_object_detection(
        det_outputs,
        threshold=prob,
        target_sizes=torch.tensor([[height, width]]),
    )[0]

    crops = []
    for box in detections["boxes"].tolist():
        # Clamp to the image: predicted corners can fall slightly outside
        # it, and a negative coordinate must not wrap around.
        xmin = min(max(int(box[0]), 0), width)
        ymin = min(max(int(box[1]), 0), height)
        xmax = min(max(int(box[2]), 0), width)
        ymax = min(max(int(box[3]), 0), height)
        if xmax <= xmin or ymax <= ymin:
            continue  # degenerate box, nothing to crop
        crops.append(image.crop((xmin, ymin, xmax, ymax)))

    if not crops:
        raise gr.Error(
            "No objects were detected above the given threshold; "
            "try lowering it."
        )

    with torch.no_grad():
        clip_inputs = processor(
            text=[text], images=crops, return_tensors="pt", padding=True
        )
        clip_outputs = model(**clip_inputs)

    # One text prompt was supplied, so the last (only) row of logits_per_text
    # holds one entry per crop; softmax turns them into relative scores.
    crop_scores = clip_outputs.logits_per_text.softmax(-1)[-1].numpy()
    best = int(np.argmax(crop_scores))
    return crops[best], crop_scores[best]
# Static text shown on the demo page.
title = "ClipnCrop"
# The detector loaded by this file is YOLOS (hustvl/yolos-tiny), so the
# description names it instead of DETR.
description = (
    "<p style= 'color:white'>Extract sections of images from your image by "
    "using OpenAI's CLIP and YOLOS implemented on HuggingFace Transformers, "
    "if the similarity score is not so much, then please consider the "
    "prediction to be void.</p>"
)
# Each example row supplies one value per input widget:
# image path, description, threshold.
examples = [
    ['ex3.jpg', 'black bag', 0.96],
    ['ex2.jpg', 'man in red dress', 0.85],
]
article = (
    "<p style= 'color:white; text-align:center;'>"
    "<a href='https://github.com/Vishnunkumar/clipcrop' target='_blank'>"
    "clipcrop</a></p>"
)

# Wire the widgets to extract_image and start the web app.
gr_app = gr.Interface(
    fn=extract_image,
    inputs=[i1, i2, i3],
    outputs=[o1, o2],
    title=title,
    description=description,
    article=article,
    examples=examples,
)
gr_app.launch()