Gabolozano's picture
Update app.py
8a42819 verified
raw
history blame
3.35 kB
import os
from functools import lru_cache

import cv2
import gradio as gr
import numpy as np
from PIL import Image
from transformers import (
    DetrConfig,
    DetrForObjectDetection,
    DetrForSegmentation,
    DetrImageProcessor,
    pipeline,
)
@lru_cache(maxsize=2)
def _build_detection_pipeline(model_name):
    """Load the DETR checkpoint *model_name* and wrap it in an object-detection pipeline.

    Results are memoised (at most two models stay resident), so repeated
    requests for the same checkpoint do not reload the weights.
    """
    config = DetrConfig.from_pretrained(model_name)
    if "DetrForSegmentation" in (config.architectures or []):
        # Panoptic checkpoints are DetrForSegmentation models whose detector is
        # the ``.detr`` submodule; loading them directly into
        # DetrForObjectDetection does not match their weight names. Load the
        # full segmentation model and keep only its detection part.
        model = DetrForSegmentation.from_pretrained(model_name, config=config).detr
    else:
        model = DetrForObjectDetection.from_pretrained(model_name, config=config)
    image_processor = DetrImageProcessor.from_pretrained(model_name)
    return pipeline(task='object-detection', model=model, image_processor=image_processor)


def load_model(model_name, threshold=None):
    """Return an object-detection pipeline for the DETR checkpoint *model_name*.

    Args:
        model_name: Hugging Face model id, e.g. ``"facebook/detr-resnet-101"``.
        threshold: accepted for backward compatibility only. The score cutoff
            is a call-time argument of the pipeline
            (``pipe(image, threshold=...)``), so it does not influence the
            pipeline that is built here.

    Returns:
        A ``transformers`` object-detection pipeline. The same object is
        returned for repeated calls with the same *model_name* while it is
        still cached.
    """
    return _build_detection_pipeline(model_name)
# Build the default detector (DETR ResNet-101) when the module is imported.
# get_pipeline_prediction later rebinds this global to whichever model the
# user selects.
od_pipe = load_model(
    model_name="facebook/detr-resnet-101",
    threshold=0.25,
)
def draw_detections(image, detections, model_name):
    """Return a copy of *image* with the detections drawn on top.

    Args:
        image: the PIL image the detections were computed on.
        detections: iterable of dicts. A ``'box'`` entry (a dict with
            ``xmin``/``ymin``/``xmax``/``ymax``) is drawn as a 2-px green
            rectangle; a ``'mask'`` entry is blended in as a 50% overlay of a
            random colour.
        model_name: not used; kept so existing callers keep working.

    Returns:
        A new RGB ``PIL.Image.Image``; *image* itself is not modified.
    """
    # Draw on an RGB array directly: pure green is (0, 255, 0) in both RGB and
    # BGR order, so no colour-space round trip is needed. convert() also turns
    # grayscale/RGBA inputs into the 3-channel layout drawn on below.
    canvas = np.array(image.convert("RGB"))
    height, width = canvas.shape[:2]
    for detection in detections:
        if 'mask' in detection:
            # NOTE(review): the object-detection pipeline's documented output is
            # score/label/box, so this branch may never run. It accepts a 2-D
            # mask given as bools/0-1 floats or as a 0-255 image (PIL or array).
            mask = np.squeeze(np.asarray(detection['mask'], dtype=np.float32))
            if mask.ndim == 2 and mask.size:
                if mask.max() > 1.0:
                    mask = mask / 255.0
                # Nearest-neighbour keeps the resized mask binary.
                mask = cv2.resize(mask, (width, height), interpolation=cv2.INTER_NEAREST)
                region = mask >= 0.5
                color = np.random.randint(0, 256, size=3)
                canvas[region] = (canvas[region] * 0.5 + color * 0.5).astype(np.uint8)
        if 'box' in detection:
            box = detection['box']
            x_min, y_min, x_max, y_max = (
                int(box[key]) for key in ('xmin', 'ymin', 'xmax', 'ymax')
            )
            cv2.rectangle(canvas, (x_min, y_min), (x_max, y_max), (0, 255, 0), 2)
    return Image.fromarray(canvas)
def get_pipeline_prediction(model_name, threshold, pil_image):
    """Detect objects in *pil_image* and return the three UI outputs.

    Args:
        model_name: Hugging Face checkpoint id passed to ``load_model``.
        threshold: minimum confidence score; forwarded to the pipeline call so
            lower-scoring detections are dropped.
        pil_image: the uploaded image, either a PIL image or an array-like
            (e.g. the numpy array a ``gr.Image`` component produces).

    Returns:
        ``(annotated_image, detections, description)`` on success. On any
        failure, ``(pil_image, {"error": message}, "Failed to process image")``
        so the UI shows the error instead of crashing.
    """
    global od_pipe  # keep the module-level pipeline pointing at the last model used
    try:
        if pil_image is None:
            raise ValueError("No image was provided.")
        # Loaded inside the try so download/loading failures reach the UI too.
        od_pipe = load_model(model_name, threshold)
        if not isinstance(pil_image, Image.Image):
            # Let Pillow infer the mode from the array shape (2-D -> L,
            # HxWx3 -> RGB, HxWx4 -> RGBA) instead of forcing 'RGB'.
            pil_image = Image.fromarray(np.asarray(pil_image).astype('uint8'))
        pil_image = pil_image.convert("RGB")
        # The object-detection pipeline takes the score cutoff as a call
        # argument; without it the pipeline's own default cutoff applies.
        result = od_pipe(pil_image, threshold=threshold)
        processed_image = draw_detections(pil_image, result, model_name)
        description = f'Model used: {model_name}, Detection Threshold: {threshold}'
        return processed_image, result, description
    except Exception as e:
        # UI boundary: report any failure in the outputs rather than raising.
        return pil_image, {"error": str(e)}, "Failed to process image"
# Checkpoints offered in the model picker.
MODEL_CHOICES = [
    "facebook/detr-resnet-50",
    "facebook/detr-resnet-50-panoptic",
    "facebook/detr-resnet-101",
    "facebook/detr-resnet-101-panoptic",
]

with gr.Blocks() as demo:
    with gr.Row():
        # Left column: image upload plus the model / threshold controls.
        with gr.Column():
            gr.Markdown("## Object Detection")
            inp_image = gr.Image(label="Upload your image here")
            model_dropdown = gr.Dropdown(
                choices=MODEL_CHOICES,
                value="facebook/detr-resnet-101",
                label="Select Model",
            )
            threshold_slider = gr.Slider(
                minimum=0.0,
                maximum=1.0,
                step=0.01,
                value=0.25,
                label="Detection Threshold",
            )
            run_button = gr.Button("Detect Objects")
        # Right column: one tab per value returned by get_pipeline_prediction.
        with gr.Column():
            with gr.Tab("Annotated Image"):
                output_image = gr.Image()
            with gr.Tab("Detection Results"):
                output_data = gr.JSON()
            with gr.Tab("Description"):
                description_output = gr.Textbox()
    run_button.click(
        fn=get_pipeline_prediction,
        inputs=[model_dropdown, threshold_slider, inp_image],
        outputs=[output_image, output_data, description_output],
    )

demo.launch()