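"""Gradio app that checks whether a query object (poster) appears in a target image.

Pipeline: OWL-ViT image-guided detection proposes candidate regions, then
SuperPoint + SuperGlue feature matching verifies each cropped region against
the query; a region with enough RANSAC inliers counts as a hit.
"""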
import gradio as gr
import torch
import numpy as np
from transformers import OwlViTProcessor, OwlViTForObjectDetection
from PIL import Image, ImageDraw
import cv2
import tempfile
import matplotlib
matplotlib.use('Agg')  # headless rendering; figures are serialized to buffers below
import matplotlib.pyplot as plt
import matplotlib.cm as cm
from io import BytesIO
from SuperGluePretrainedNetwork.models.matching import Matching
from SuperGluePretrainedNetwork.models.utils import read_image
# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Load models
model = OwlViTForObjectDetection.from_pretrained("google/owlvit-base-patch32").to(device)
processor = OwlViTProcessor.from_pretrained("google/owlvit-base-patch32")
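# SuperPoint (keypoint detection) + SuperGlue (matching) with outdoor weights;
# keypoint_threshold and match_threshold control how permissive matching is.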
matching = Matching({
'superpoint': {'nms_radius': 4, 'keypoint_threshold': 0.005, 'max_keypoints': 1024},
'superglue': {'weights': 'outdoor', 'sinkhorn_iterations': 20, 'match_threshold': 0.2}
}).eval().to(device)
# Utility functions
def save_array_to_temp_image(arr):
    """Save a BGR numpy array to a temporary PNG file and return the file path."""
    rgb_arr = cv2.cvtColor(arr, cv2.COLOR_BGR2RGB)
    img = Image.fromarray(rgb_arr)
    temp_file = tempfile.NamedTemporaryFile(delete=False, suffix='.png')
    temp_file_name = temp_file.name
    temp_file.close()
    img.save(temp_file_name)
    return temp_file_name
def unified_matching_plot2(image0, image1, kpts0, kpts1, mkpts0, mkpts1, color, text,
                           path=None, show_keypoints=False, fast_viz=False,
                           opencv_display=False, opencv_title='matches', small_text=[]):
    """Plot matched keypoints side by side and return the figure as an RGB array.

    The path/fast_viz/opencv_* arguments are kept for call compatibility but unused.
    """
    # Resize both images to a common height and scale keypoints by the same factors,
    # so the drawn points stay aligned with the resized images.
    height = min(image0.shape[0], image1.shape[0])
    scale0 = height / image0.shape[0]
    scale1 = height / image1.shape[0]
    image0_resized = cv2.resize(image0, (int(image0.shape[1] * scale0), height))
    image1_resized = cv2.resize(image1, (int(image1.shape[1] * scale1), height))
    offset = image0_resized.shape[1]  # x-offset of the right-hand image
    fig, ax = plt.subplots(figsize=(20, 20))
    plt.imshow(np.hstack([image0_resized, image1_resized]), aspect='auto')
    if show_keypoints:
        plt.scatter(kpts0[:, 0] * scale0, kpts0[:, 1] * scale0, color='r', s=1)
        plt.scatter(kpts1[:, 0] * scale1 + offset, kpts1[:, 1] * scale1, color='r', s=1)
    plt.plot([mkpts0[:, 0] * scale0, mkpts1[:, 0] * scale1 + offset],
             [mkpts0[:, 1] * scale0, mkpts1[:, 1] * scale1], 'r', lw=0.5)
    plt.scatter(mkpts0[:, 0] * scale0, mkpts0[:, 1] * scale0, s=2, marker='o', color='b')
    plt.scatter(mkpts1[:, 0] * scale1 + offset, mkpts1[:, 1] * scale1, s=2, marker='o', color='g')
    plt.suptitle('\n'.join(text), fontsize=20, fontweight='bold')
    plt.axis('off')
    plt.tight_layout()
    # Render to a buffer instead of calling plt.show(), which has no effect on a server.
    buf = BytesIO()
    plt.savefig(buf, format='png')
    buf.seek(0)
    img_arr = np.frombuffer(buf.getvalue(), dtype=np.uint8)
    buf.close()
    img = cv2.imdecode(img_arr, 1)
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    plt.close(fig)
    return img
def stitch_images(images):
"""Stitches a list of images vertically."""
if not images:
return Image.new('RGB', (100, 100), color='gray')
max_width = max([img.width for img in images])
total_height = sum(img.height for img in images)
composite = Image.new('RGB', (max_width, total_height))
y_offset = 0
for img in images:
composite.paste(img, (0, y_offset))
y_offset += img.height
return composite
# Main functions
def detect_and_crop(target_image, query_image, threshold=0.5, nms_threshold=0.3):
    """Run OWL-ViT image-guided detection; return (cropped regions, annotated image)."""
    target_sizes = torch.Tensor([target_image.size[::-1]])
    inputs = processor(images=target_image, query_images=query_image, return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = model.image_guided_detection(**inputs)
    # PIL gives an RGB array; convert to BGR so save_array_to_temp_image can convert it back.
    img = cv2.cvtColor(np.array(target_image), cv2.COLOR_RGB2BGR)
    outputs.logits = outputs.logits.cpu()
    outputs.target_pred_boxes = outputs.target_pred_boxes.cpu()
    results = processor.post_process_image_guided_detection(
        outputs=outputs, threshold=threshold, nms_threshold=nms_threshold, target_sizes=target_sizes)
    boxes, scores = results[0]["boxes"], results[0]["scores"]
    if len(boxes) == 0:
        return [], None
    filtered_boxes = []
    for box in boxes:
        x1, y1, x2, y2 = [int(i) for i in box.tolist()]
        cropped_img = img[y1:y2, x1:x2]
        if cropped_img.size != 0:
            filtered_boxes.append(cropped_img)
    # Draw boxes on a copy so the caller's image is left untouched.
    annotated = target_image.copy()
    draw = ImageDraw.Draw(annotated)
    for box in boxes:
        draw.rectangle(box.tolist(), outline="red", width=3)
    return filtered_boxes, annotated
def image_matching_no_pyramid(query_img, target_img, visualize=True):
    """Match query and target with SuperPoint + SuperGlue; return match statistics."""
    # PIL images are RGB; save_array_to_temp_image expects BGR.
    temp_query = save_array_to_temp_image(cv2.cvtColor(np.array(query_img), cv2.COLOR_RGB2BGR))
    temp_target = save_array_to_temp_image(cv2.cvtColor(np.array(target_img), cv2.COLOR_RGB2BGR))
    image1, inp1, scales1 = read_image(temp_target, device, [640 * 2], 0, True)
    image0, inp0, scales0 = read_image(temp_query, device, [640 * 2], 0, True)
    if image0 is None or image1 is None:
        return None
    with torch.no_grad():
        pred = matching({'image0': inp0, 'image1': inp1})
    pred = {k: v[0] for k, v in pred.items()}
    kpts0, kpts1 = pred['keypoints0'], pred['keypoints1']
    matches, conf = pred['matches0'], pred['matching_scores0']
    valid = matches > -1
    mkpts0 = kpts0[valid]
    mkpts1 = kpts1[matches[valid]]
    mconf = conf[valid]
    color = cm.jet(mconf.cpu().numpy())[:len(mkpts0)]
    valid_count = int(valid.sum())
    mkpts0_np = mkpts0.cpu().numpy()
    mkpts1_np = mkpts1.cpu().numpy()
    # findHomography needs at least four correspondences; treat failure as zero inliers.
    inliers = None
    if len(mkpts0_np) >= 4:
        try:
            H, inliers = cv2.findHomography(mkpts0_np, mkpts1_np, cv2.RANSAC, 5.0)
        except cv2.error:
            inliers = None
    num_inliers = int(np.sum(inliers)) if inliers is not None else 0
    if visualize:
        visualized_img = unified_matching_plot2(
            image0, image1, kpts0.cpu().numpy(), kpts1.cpu().numpy(),
            mkpts0_np, mkpts1_np, color, ['Matches'], show_keypoints=True)
    else:
        visualized_img = None
    return {
        'valid': [valid_count],
        'inliers': [num_inliers],
        'visualized_image': [visualized_img]
    }
def check_object_in_image(query_image, target_image, threshold=50, scale_factor=[0.33, 0.66, 1]):
    """Detect candidate regions with OWL-ViT, then verify each crop with SuperGlue.

    `scale_factor` is accepted for interface compatibility but is not used here.
    """
    images_to_return = []
    cropped_images, bbox_image = detect_and_crop(target_image, query_image)
    temp_files = [save_array_to_temp_image(i) for i in cropped_images]
    crop_results = [image_matching_no_pyramid(query_image, Image.open(i), visualize=True) for i in temp_files]
    cropped_visuals = []
    cropped_inliers = []
    for result in crop_results:
        if result:
            for img in result['visualized_image']:
                cropped_visuals.append(Image.fromarray(img))
            for inliers_ in result['inliers']:
                cropped_inliers.append(inliers_)
    # Stitched match visualizations are built for debugging but not returned below.
    images_to_return.append(stitch_images(cropped_visuals))
    # The object counts as present if any crop clears the inlier threshold.
    is_present = any(value >= threshold for value in cropped_inliers)
    return {
        'is_present': is_present,
        'image_with_boxes': bbox_image,
        'object_detection_inliers': [int(i) for i in cropped_inliers],
    }
def interface(poster_source, media_source, threshold, scale_factor):
    result = check_object_in_image(poster_source, media_source, threshold, scale_factor)
    return result['is_present'], result['image_with_boxes']
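# NOTE: the scale-factor checkboxes below are collected from the UI but are not
# used by the current matching pipeline.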
iface = gr.Interface(
fn=interface,
inputs=[
gr.Image(type="pil", label="Upload a Query Image (Poster)"),
gr.Image(type="pil", label="Upload a Target Image (Media)"),
        gr.Slider(minimum=0, maximum=100, step=1, value=50, label="Inlier Threshold"),
gr.CheckboxGroup(choices=["0.33", "0.66", "1.0"], value=["0.33", "0.66", "1.0"], label="Scale Factors"),
],
outputs=[
gr.Label(label="Object Presence"),
gr.Image(type="pil", label="Detected Bounding Boxes"),
],
title="Object Detection in Images",
    description="""
    This application checks whether an object from a query image (poster) appears in a target image (media).
    Steps:
    1. Upload a Query Image (Poster)
    2. Upload a Target Image (Media)
    3. Set the Inlier Threshold
    4. Set Scale Factors (currently informational only)
    5. View Results
    """
)
if __name__ == "__main__":
iface.launch()