import os
import gradio as gr
from transformers import pipeline, DetrForObjectDetection, DetrConfig, DetrImageProcessor
import numpy as np
import cv2
from PIL import Image
# Initialize the model
config = DetrConfig.from_pretrained("facebook/detr-resnet-50")
model = DetrForObjectDetection.from_pretrained("facebook/detr-resnet-50", config=config)
image_processor = DetrImageProcessor.from_pretrained("facebook/detr-resnet-50")
# Initialize the pipeline
od_pipe = pipeline(task='object-detection', model=model, image_processor=image_processor)
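# For reference, the object-detection pipeline returns a list of dicts, one per detection,
# shaped like:
#   {'score': 0.98, 'label': 'cat', 'box': {'xmin': 12, 'ymin': 30, 'xmax': 320, 'ymax': 240}}
# (values above are illustrative only); draw_detections below consumes exactly this structure.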
def draw_detections(image, detections):
    # Convert PIL image to a numpy array
    np_image = np.array(image)
    # Convert RGB to BGR for OpenCV
    np_image = cv2.cvtColor(np_image, cv2.COLOR_RGB2BGR)
    for detection in detections:
        score = detection['score']
        label = detection['label']
        box = detection['box']
        x_min = box['xmin']
        y_min = box['ymin']
        x_max = box['xmax']
        y_max = box['ymax']
        # Draw rectangles and label with a larger font size
        cv2.rectangle(np_image, (x_min, y_min), (x_max, y_max), (0, 255, 0), 2)
        label_text = f'{label} {score:.2f}'
        label_size = cv2.getTextSize(label_text, cv2.FONT_HERSHEY_SIMPLEX, 0.8, 2)[0]
        label_x = x_min
        label_y = y_min - label_size[1] if y_min - label_size[1] > 10 else y_min + label_size[1]
        cv2.putText(np_image, label_text, (label_x, label_y), cv2.FONT_HERSHEY_SIMPLEX, 0.8, (255, 255, 255), 2)
    # Convert BGR back to RGB for displaying
    final_image = cv2.cvtColor(np_image, cv2.COLOR_BGR2RGB)
    final_pil_image = Image.fromarray(final_image)
    return final_pil_image
def get_pipeline_prediction(pil_image):
    try:
        pipeline_output = od_pipe(pil_image)
        processed_image = draw_detections(pil_image, pipeline_output)
        return processed_image, pipeline_output
    except Exception as e:
        print(f"An error occurred: {str(e)}")
        return pil_image, {"error": str(e)}
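# Note: the object-detection pipeline also accepts a confidence threshold, e.g.
# od_pipe(pil_image, threshold=0.5), to keep more or fewer low-confidence boxes;
# the 0.5 value here is illustrative.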
# Set up the Gradio interface; the two outputs are rendered as separate panels
# (annotated image and raw detection results)
demo = gr.Interface(
    fn=get_pipeline_prediction,
    inputs=gr.Image(label="Input image", type="pil"),
    outputs=[
        gr.Image(type="pil", label="Annotated Image"),
        gr.JSON(label="Detected Objects")
    ]
)
demo.launch()