Gabolozano committed on
Commit
6ed6f8e
1 Parent(s): e3205c6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +26 -45
app.py CHANGED
@@ -5,82 +5,63 @@ import numpy as np
5
  import cv2
6
  from PIL import Image
7
 
8
- # Initialize the model
9
- config = DetrConfig.from_pretrained("facebook/detr-resnet-50")
10
- model = DetrForObjectDetection.from_pretrained("facebook/detr-resnet-50", config=config)
11
- image_processor = DetrImageProcessor.from_pretrained("facebook/detr-resnet-50")
12
 
13
  def load_model(threshold):
14
- # Reinitialize the model with the desired detection threshold
15
  config = DetrConfig.from_pretrained("facebook/detr-resnet-50", threshold=threshold)
 
16
  model = DetrForObjectDetection.from_pretrained("facebook/detr-resnet-50", config=config)
17
- image_processor = DetrImageProcessor.from_pretrained("facebook/detr-resnet-50")
18
- return pipeline(task='object-detection', model=model, image_processor=image_processor)
19
 
20
- od_pipe = load_model(0.5) # Default threshold
 
21
 
22
  def draw_detections(image, detections):
23
- # Convert PIL image to a numpy array
24
  np_image = np.array(image)
25
-
26
- # Convert RGB to BGR for OpenCV
27
  np_image = cv2.cvtColor(np_image, cv2.COLOR_RGB2BGR)
28
 
29
- # Draw detections
30
  for detection in detections:
31
  score = detection['score']
32
  label = detection['label']
33
  box = detection['box']
34
- x_min = box['xmin']
35
- y_min = box['ymin']
36
- x_max = box['xmax']
37
- y_max = box['ymax']
38
-
39
- # Increase font size for better visibility
40
  cv2.rectangle(np_image, (x_min, y_min), (x_max, y_max), (0, 255, 0), 2)
41
- label_text = f'{label} {score:.2f}'
42
- cv2.putText(np_image, label_text, (x_min, y_min - 10), cv2.FONT_HERSHEY_SIMPLEX, 1.5, (255, 255, 255), 4)
43
 
44
- # Convert BGR to RGB for displaying
45
  final_image = cv2.cvtColor(np_image, cv2.COLOR_BGR2RGB)
46
- final_pil_image = Image.fromarray(final_image)
47
- return final_pil_image
48
 
49
  def get_pipeline_prediction(threshold, pil_image):
50
  global od_pipe
 
51
  try:
52
- # Check if the model threshold needs adjusting
53
- if od_pipe.config.threshold != threshold:
54
- od_pipe = load_model(threshold)
55
- print("Model reloaded with new threshold:", threshold)
56
-
57
- # Ensure input is a PIL image
58
  if not isinstance(pil_image, Image.Image):
59
  pil_image = Image.fromarray(np.array(pil_image).astype('uint8'), 'RGB')
60
-
61
- # Run detection and return annotated image and results
62
- pipeline_output = od_pipe(pil_image)
63
- processed_image = draw_detections(pil_image, pipeline_output)
64
- return processed_image, pipeline_output
65
  except Exception as e:
66
- error_message = f"An error occurred: {str(e)}"
67
- print(error_message)
68
- return pil_image, {"error": error_message}
69
 
70
- # Gradio interface
71
  with gr.Blocks() as demo:
72
  with gr.Row():
73
  with gr.Column():
74
- inp_image = gr.Image(label="Input image")
75
- slider = gr.Slider(minimum=0, maximum=1, step=0.05, label="Detection Sensitivity", value=0.5)
76
- gr.Markdown("Adjust the slider to change detection sensitivity.")
77
- btn_run = gr.Button('Run Detection')
78
  with gr.Column():
79
  with gr.Tab("Annotated Image"):
80
- out_image = gr.Image()
81
  with gr.Tab("Detection Results"):
82
- out_json = gr.JSON()
83
 
84
- btn_run.click(get_pipeline_prediction, inputs=[slider, inp_image], outputs=[out_image, out_json])
85
 
86
  demo.launch()
 
5
  import cv2
6
  from PIL import Image
7
 
8
+ # Pre-load the base configuration and models (without setting a threshold yet)
9
+ base_config = DetrConfig.from_pretrained("facebook/detr-resnet-50")
10
+ base_model = DetrForObjectDetection.from_pretrained("facebook/detr-resnet-50", config=base_config)
11
+ base_processor = DetrImageProcessor.from_pretrained("facebook/detr-resnet-50")
12
 
def load_model(threshold):
    """Build an object-detection pipeline that uses ``threshold`` by default.

    The DETR weights and image processor are loaded once at module import
    (``base_model`` / ``base_processor``) and reused here, so calling this
    function does not re-read the checkpoint.

    Args:
        threshold: Minimum detection score, expected in [0, 1]. It is handed
            to the pipeline as its default post-processing threshold;
            detections scoring below it are dropped.

    Returns:
        A ``transformers`` object-detection pipeline.
    """
    # The score threshold is a pipeline post-processing parameter, not a DETR
    # config field: passing it to DetrConfig.from_pretrained has no effect on
    # the detections that are returned.
    return pipeline(
        task='object-detection',
        model=base_model,
        image_processor=base_processor,
        threshold=threshold,
    )
20
 
21
+ # Initialize the pipeline with a default threshold
22
+ od_pipe = load_model(0.25) # Set a default threshold here
23
 
def draw_detections(image, detections):
    """Return a copy of ``image`` annotated with boxes and labels.

    Args:
        image: A PIL image in any mode; it is converted to RGB before drawing.
        detections: Iterable of dicts with ``'score'``, ``'label'`` and
            ``'box'`` keys, where ``'box'`` holds ``'xmin'``, ``'ymin'``,
            ``'xmax'`` and ``'ymax'`` pixel coordinates.

    Returns:
        A new RGB ``PIL.Image.Image`` with the detections drawn on it.
    """
    # Force three channels so RGBA / grayscale uploads do not break drawing.
    # np.array() copies, so the caller's image is never modified.
    canvas = np.array(image.convert("RGB"))

    font = cv2.FONT_HERSHEY_SIMPLEX
    font_scale = 1
    thickness = 2

    for detection in detections:
        score = detection['score']
        label = detection['label']
        box = detection['box']
        x_min, y_min = int(box['xmin']), int(box['ymin'])
        x_max, y_max = int(box['xmax']), int(box['ymax'])

        # Green and white are identical in RGB and BGR, so we can draw on the
        # RGB array directly without a colour-space round trip.
        cv2.rectangle(canvas, (x_min, y_min), (x_max, y_max), (0, 255, 0), 2)

        text = f"{label} {score:.2f}"
        (_, text_height), _ = cv2.getTextSize(text, font, font_scale, thickness)
        # putText anchors at the text's bottom-left corner; keep the label
        # inside the image when the box touches the top edge.
        text_y = max(y_min - 10, text_height)
        cv2.putText(canvas, text, (x_min, text_y), font, font_scale, (255, 255, 255), thickness)

    return Image.fromarray(canvas)
 
39
 
def get_pipeline_prediction(threshold, pil_image):
    """Run object detection on an image and return the annotated result.

    Args:
        threshold: Minimum detection score (slider value, expected in [0, 1]).
        pil_image: The input image as a PIL image or an array-like of pixel
            values; ``None`` when the user has not provided an image.

    Returns:
        A ``(image, data)`` tuple. On success ``image`` is the annotated PIL
        image and ``data`` is the list of detections returned by the
        pipeline. On failure ``image`` is the input as received and ``data``
        is ``{"error": <message>}``.
    """
    if pil_image is None:
        return None, {"error": "No image provided."}

    try:
        if not isinstance(pil_image, Image.Image):
            # Let Pillow infer the mode from the array shape instead of
            # forcing 'RGB', which is wrong for RGBA or grayscale arrays.
            pil_image = Image.fromarray(np.asarray(pil_image).astype('uint8'))
        pil_image = pil_image.convert("RGB")

        # The threshold is a call-time parameter of the detection pipeline,
        # so the already-loaded module-level pipeline is reused rather than
        # rebuilding the model on every request.
        result = od_pipe(pil_image, threshold=threshold)
        processed_image = draw_detections(pil_image, result)
        return processed_image, result
    except Exception as e:
        # UI boundary: report the failure in the JSON panel instead of crashing.
        return pil_image, {"error": str(e)}
 
 
51
 
 
52
  with gr.Blocks() as demo:
53
  with gr.Row():
54
  with gr.Column():
55
+ gr.Markdown("## Object Detection")
56
+ inp_image = gr.Image(label="Upload your image here")
57
+ threshold_slider = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, value=0.25, label="Detection Threshold")
58
+ run_button = gr.Button("Detect Objects")
59
  with gr.Column():
60
  with gr.Tab("Annotated Image"):
61
+ output_image = gr.Image()
62
  with gr.Tab("Detection Results"):
63
+ output_data = gr.JSON()
64
 
65
+ run_button.click(get_pipeline_prediction, inputs=[threshold_slider, inp_image], outputs=[output_image, output_data])
66
 
67
  demo.launch()