File size: 3,115 Bytes
d4c3acc
 
246f207
ceb95cf
 
 
c21a752
b518471
 
 
 
 
648d371
b518471
 
648d371
ceb95cf
 
 
 
 
 
d5aac08
aca9f11
6ed6f8e
 
ceb95cf
793cc29
 
ceb95cf
 
793cc29
 
6564a3a
b518471
b842d10
b518471
cfd90f9
e3205c6
 
6ed6f8e
 
b518471
 
cfd90f9
b518471
cfd90f9
b3747be
 
5769b69
6ed6f8e
 
b518471
6ed6f8e
 
5769b69
 
6ed6f8e
5769b69
6ed6f8e
b518471
 
e3205c6
b518471
e14364d
aca9f11
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
import os
import gradio as gr
from transformers import pipeline, DetrForObjectDetection, DetrConfig, DetrImageProcessor
import numpy as np
import cv2
from PIL import Image

def load_model(model_name, threshold):
    """Build an object-detection pipeline for a DETR checkpoint.

    Args:
        model_name: Hub id or local path of the DETR checkpoint to load.
        threshold: Minimum score a detection needs in order to be returned.
            It is registered as the pipeline's default post-processing
            threshold; a ``threshold`` passed when calling the pipeline
            still overrides it. ``None`` keeps the pipeline's own default.

    Returns:
        A ``transformers`` object-detection pipeline wrapping the loaded
        model and its image processor.
    """
    # ``threshold`` is not a DetrConfig field, so handing it to
    # ``DetrConfig.from_pretrained`` is silently dropped and never influences
    # which detections are kept. It has to be given to the pipeline instead.
    config = DetrConfig.from_pretrained(model_name)
    model = DetrForObjectDetection.from_pretrained(model_name, config=config)
    image_processor = DetrImageProcessor.from_pretrained(model_name)

    pipeline_kwargs = {} if threshold is None else {"threshold": threshold}
    return pipeline(
        task='object-detection',
        model=model,
        image_processor=image_processor,
        **pipeline_kwargs,
    )

# Build the default detector once at module load: DETR with a ResNet-101
# backbone and a 0.25 score threshold.
od_pipe = load_model(model_name="facebook/detr-resnet-101", threshold=0.25)

def draw_detections(image, detections):
    """Draw bounding boxes and "label score" captions onto a copy of *image*.

    Args:
        image: RGB image (a PIL image or anything ``np.array`` accepts).
        detections: Iterable of dicts, each with ``score``, ``label`` and a
            ``box`` dict holding ``xmin``, ``ymin``, ``xmax`` and ``ymax``.

    Returns:
        A new PIL image carrying the annotations; the input is not modified.
    """
    font = cv2.FONT_HERSHEY_SIMPLEX
    font_scale = 1.5
    text_thickness = 4

    # np.array() copies the pixels, so drawing never touches the caller's image.
    canvas = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)

    for detection in detections:
        box = detection['box']
        # OpenCV drawing functions only accept integer pixel coordinates.
        x_min, y_min = int(box['xmin']), int(box['ymin'])
        x_max, y_max = int(box['xmax']), int(box['ymax'])
        cv2.rectangle(canvas, (x_min, y_min), (x_max, y_max), (0, 255, 0), 2)

        label_text = f"{detection['label']} {detection['score']:.2f}"
        (_, text_height), _ = cv2.getTextSize(
            label_text, font, font_scale, text_thickness
        )
        # The caption normally sits 10 px above the box. For boxes touching
        # the top edge that would place it off-canvas, so clamp the baseline
        # far enough down for the glyphs to remain visible.
        text_y = max(y_min - 10, text_height)
        cv2.putText(
            canvas,
            label_text,
            (x_min, text_y),
            font,
            font_scale,
            (255, 255, 255),
            text_thickness,
        )

    return Image.fromarray(cv2.cvtColor(canvas, cv2.COLOR_BGR2RGB))

# Checkpoint currently held in ``od_pipe`` as far as this function knows.
# ``None`` until the first request, so that request always (re)loads once.
_active_model_name = None


def get_pipeline_prediction(model_name, threshold, pil_image):
    """Run object detection on an uploaded image and annotate the result.

    Args:
        model_name: Checkpoint to run; the shared pipeline is rebuilt only
            when this differs from the checkpoint used by the previous call.
        threshold: Minimum score a detection needs in order to be reported.
        pil_image: Uploaded image, either a PIL image or array-like pixels.

    Returns:
        A ``(image, detections, description)`` tuple. On failure the image is
        returned unannotated, ``detections`` is ``{"error": <message>}`` and
        the description is ``"Failed to process image"``.
    """
    global od_pipe, _active_model_name

    if pil_image is None:
        # Nothing uploaded yet: report it without loading or running a model.
        return None, {"error": "No image provided"}, "Failed to process image"

    try:
        # Loading happens inside the try block so download / initialisation
        # failures reach the UI through the same error path as inference ones.
        if model_name != _active_model_name:
            od_pipe = load_model(model_name, threshold)
            _active_model_name = model_name

        if not isinstance(pil_image, Image.Image):
            # Let PIL infer the mode from the array shape, then normalise to
            # RGB so grayscale or RGBA uploads are handled as well.
            pixels = np.asarray(pil_image).astype('uint8')
            pil_image = Image.fromarray(pixels).convert('RGB')

        # The threshold is applied per call, so moving the slider takes
        # effect immediately without rebuilding the pipeline.
        result = od_pipe(pil_image, threshold=threshold)
        processed_image = draw_detections(pil_image, result)
        description = f'Model used: {model_name}, Detection Threshold: {threshold}'
        return processed_image, result, description
    except Exception as e:
        # UI boundary: surface the failure in the JSON tab instead of crashing.
        return pil_image, {"error": str(e)}, "Failed to process image"

# Two-column layout: upload and controls on the left, tabbed results on the right.
with gr.Blocks() as demo:
    with gr.Row():
        # Left column: image upload plus model and threshold controls.
        with gr.Column():
            gr.Markdown("## Object Detection")
            inp_image = gr.Image(label="Upload your image here")
            model_dropdown = gr.Dropdown(
                choices=[
                    "facebook/detr-resnet-50",
                    "facebook/detr-resnet-50-panoptic",
                    "facebook/detr-resnet-101",
                    "facebook/detr-resnet-101-panoptic",
                ],
                value="facebook/detr-resnet-101",
                label="Select Model",
            )
            threshold_slider = gr.Slider(
                minimum=0.0,
                maximum=1.0,
                step=0.01,
                value=0.25,
                label="Detection Threshold",
            )
            run_button = gr.Button("Detect Objects")
        # Right column: one tab per output of get_pipeline_prediction.
        with gr.Column():
            with gr.Tab("Annotated Image"):
                output_image = gr.Image()
            with gr.Tab("Detection Results"):
                output_data = gr.JSON()
            with gr.Tab("Description"):
                description_output = gr.Textbox()

    # The input order must match get_pipeline_prediction's parameter order.
    run_button.click(
        fn=get_pipeline_prediction,
        inputs=[model_dropdown, threshold_slider, inp_image],
        outputs=[output_image, output_data, description_output],
    )

demo.launch()