# Object_detection / app_v1.py
# Author: Gabolozano
# Commit 425f707 (verified): "Rename app.py to app_v1.py" (file size 3.12 kB)
import functools
import os

import cv2
import gradio as gr
import numpy as np
from PIL import Image
from transformers import pipeline, DetrForObjectDetection, DetrConfig, DetrImageProcessor
@functools.lru_cache(maxsize=4)
def _load_detr_components(model_name):
    """Load and memoise the DETR model and image processor for *model_name*.

    The cache is bounded (the UI offers four checkpoints), so switching
    between models is fast without keeping an unbounded number of models
    in memory.
    """
    model = DetrForObjectDetection.from_pretrained(model_name)
    image_processor = DetrImageProcessor.from_pretrained(model_name)
    return model, image_processor


def load_model(model_name, threshold):
    """Build an object-detection pipeline for a DETR checkpoint.

    Args:
        model_name: Hugging Face model id of a DETR checkpoint.
        threshold: Minimum score a detection needs in order to be returned.
            It becomes the pipeline's default ``threshold``; callers may
            still override it per call.

    Returns:
        A ``transformers`` object-detection pipeline.
    """
    model, image_processor = _load_detr_components(model_name)
    # The object-detection pipeline applies ``threshold`` while
    # post-processing detections. An extra attribute on DetrConfig is not
    # consulted, so the value must be given to the pipeline itself.
    return pipeline(
        task='object-detection',
        model=model,
        image_processor=image_processor,
        threshold=threshold,
    )
# Build the pipeline for the default checkpoint once, at import time, so a
# model exists before the first request. ``od_pipe`` is module-level state
# that the request handler may rebind.
_DEFAULT_MODEL_NAME = "facebook/detr-resnet-101"
_DEFAULT_THRESHOLD = 0.25
od_pipe = load_model(_DEFAULT_MODEL_NAME, _DEFAULT_THRESHOLD)
def draw_detections(image, detections):
    """Return a copy of *image* with every detection's box and label drawn.

    Args:
        image: A PIL image, or anything ``np.array`` can convert, holding
            RGB pixel data (assumed; TODO confirm for non-RGB uploads).
        detections: Iterable of dicts like those returned by the transformers
            object-detection pipeline. Each dict has ``'score'``, ``'label'``
            and a ``'box'`` dict with ``'xmin'``, ``'ymin'``, ``'xmax'`` and
            ``'ymax'`` keys.

    Returns:
        A new ``PIL.Image.Image`` with the annotations. *image* itself is not
        modified.
    """
    font = cv2.FONT_HERSHEY_SIMPLEX
    font_scale = 1.5
    thickness = 4
    # Convert to OpenCV's conventional BGR channel order for drawing; the
    # result is converted back to RGB before building the PIL image.
    canvas = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
    for detection in detections:
        box = detection['box']
        # cv2 drawing functions require integer pixel coordinates.
        x_min, y_min = int(box['xmin']), int(box['ymin'])
        x_max, y_max = int(box['xmax']), int(box['ymax'])
        cv2.rectangle(canvas, (x_min, y_min), (x_max, y_max), (0, 255, 0), 2)

        label_text = f"{detection['label']} {detection['score']:.2f}"
        (_, text_height), _ = cv2.getTextSize(label_text, font, font_scale, thickness)
        # putText anchors the text at its bottom-left corner. Draw the label
        # just above the box, or just inside it when the box touches the top
        # edge; otherwise the text would be clipped off the canvas.
        text_y = y_min - 10
        if text_y - text_height < 0:
            text_y = y_min + text_height + 10
        cv2.putText(canvas, label_text, (x_min, text_y), font, font_scale, (255, 255, 255), thickness)
    return Image.fromarray(cv2.cvtColor(canvas, cv2.COLOR_BGR2RGB))
# Model id currently loaded into ``od_pipe`` by get_pipeline_prediction.
# ``None`` means "unknown", so the first request always loads its model.
_loaded_model_name = None


def get_pipeline_prediction(model_name, threshold, pil_image):
    """Detect objects in an uploaded image and return the annotated result.

    This is the Gradio callback behind the "Detect Objects" button.

    Args:
        model_name: Hugging Face model id selected in the dropdown.
        threshold: Minimum detection score to keep, taken from the slider.
        pil_image: The uploaded image (PIL image or array-like), or ``None``
            when nothing has been uploaded.

    Returns:
        ``(annotated_image, detections, description)``. On failure, returns
        the input image, an ``{"error": ...}`` dict and a failure message,
        so the UI shows the problem instead of the request crashing.
    """
    global od_pipe, _loaded_model_name
    if pil_image is None:
        # Presumably what Gradio passes when no image has been uploaded.
        return None, {"error": "No image provided"}, "Failed to process image"
    try:
        # Loading a model is expensive, so reload only when the selection
        # changes. The threshold is applied per call below.
        if model_name != _loaded_model_name:
            od_pipe = load_model(model_name, threshold)
            _loaded_model_name = model_name
        if not isinstance(pil_image, Image.Image):
            # convert('RGB') also normalises grayscale or RGBA arrays.
            pil_image = Image.fromarray(np.asarray(pil_image).astype('uint8')).convert('RGB')
        # Pass the threshold on every call, so the slider value applies even
        # when the already-loaded model is reused.
        result = od_pipe(pil_image, threshold=threshold)
        processed_image = draw_detections(pil_image, result)
        description = f'Model used: {model_name}, Detection Threshold: {threshold}'
        return processed_image, result, description
    except Exception as e:
        # UI boundary: report the error in the outputs rather than failing
        # the request.
        return pil_image, {"error": str(e)}, "Failed to process image"
# Checkpoints offered in the model dropdown.
# NOTE(review): the "-panoptic" checkpoints are published for DETR's
# segmentation model; confirm that loading them through
# DetrForObjectDetection yields trained detection weights.
_MODEL_CHOICES = [
    "facebook/detr-resnet-50",
    "facebook/detr-resnet-50-panoptic",
    "facebook/detr-resnet-101",
    "facebook/detr-resnet-101-panoptic",
]

# Two-column layout: controls on the left, results (in tabs) on the right.
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            gr.Markdown("## Object Detection")
            inp_image = gr.Image(label="Upload your image here")
            model_dropdown = gr.Dropdown(
                choices=_MODEL_CHOICES,
                value="facebook/detr-resnet-101",
                label="Select Model",
            )
            threshold_slider = gr.Slider(
                minimum=0.0,
                maximum=1.0,
                step=0.01,
                value=0.25,
                label="Detection Threshold",
            )
            run_button = gr.Button("Detect Objects")
        with gr.Column():
            with gr.Tab("Annotated Image"):
                output_image = gr.Image()
            with gr.Tab("Detection Results"):
                output_data = gr.JSON()
            with gr.Tab("Description"):
                description_output = gr.Textbox()

    # Inputs and outputs are listed in the order get_pipeline_prediction
    # takes its arguments and returns its results.
    run_button.click(
        fn=get_pipeline_prediction,
        inputs=[model_dropdown, threshold_slider, inp_image],
        outputs=[output_image, output_data, description_output],
    )

# Start the Gradio app.
demo.launch()