Gabolozano committed on
Commit
8a42819
1 Parent(s): baa648a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +9 -26
app.py CHANGED
@@ -4,51 +4,34 @@ from transformers import pipeline, DetrForObjectDetection, DetrConfig, DetrImage
4
  import numpy as np
5
  import cv2
6
  from PIL import Image
7
- import warnings
8
- import logging
9
- # To suppress all warnings entries
10
- warnings.filterwarnings('ignore')
11
-
12
- # To ignore specific loggings from the Transformers library
13
- logging.getLogger("transformers").setLevel(logging.ERROR)
14
-
15
- def model_is_panoptic(model_name):
16
- return "panoptic" in model_name
17
 
18
  def load_model(model_name, threshold):
19
  config = DetrConfig.from_pretrained(model_name, threshold=threshold)
20
  model = DetrForObjectDetection.from_pretrained(model_name, config=config)
21
  image_processor = DetrImageProcessor.from_pretrained(model_name)
22
  return pipeline(task='object-detection', model=model, image_processor=image_processor)
23
- # Initial model with default threshold
24
- od_pipe = load_model("facebook/detr-resnet-101", 0.25)
25
 
26
  def draw_detections(image, detections, model_name):
27
  np_image = np.array(image)
28
  np_image = cv2.cvtColor(np_image, cv2.COLOR_RGB2BGR)
29
  for detection in detections:
30
- if model_is_panoptic(model_name):
31
- # Handle segmentations for panoptic models
32
  mask = detection['mask']
33
  color = np.random.randint(0, 255, size=3)
34
  mask = np.round(mask * 255).astype(np.uint8)
35
  mask = cv2.resize(mask, (image.width, image.height))
36
  mask_image = np.stack([mask]*3, axis=-1)
37
  np_image[mask == 255] = np_image[mask == 255] * 0.5 + color * 0.5
38
- else:
39
- # Handle bounding boxes for standard models
40
- score = detection['score']
41
- label = detection['label']
42
  box = detection['box']
43
- x_min, y_min = box['xmin'], box['ymin']
44
- x_max, y_max = box['xmax'], box['ymax']
45
  cv2.rectangle(np_image, (x_min, y_min), (x_max, y_max), (0, 255, 0), 2)
46
- label_text = f'{label} {score:.2f}'
47
- cv2.putText(np_image, label_text, (x_min, y_min - 10), cv2.FONT_HERSHEY_SIMPLEX, 1.5, (255, 255, 255), 4)
48
  final_image = cv2.cvtColor(np_image, cv2.COLOR_BGR2RGB)
49
- final_pil_image = Image.fromarray(final_image)
50
- return final_pil_image
51
-
52
  def get_pipeline_prediction(model_name, threshold, pil_image):
53
  global od_pipe
54
  od_pipe = load_model(model_name, threshold)
@@ -56,6 +39,7 @@ def get_pipeline_prediction(model_name, threshold, pil_image):
56
  if not isinstance(pil_image, Image.Image):
57
  pil_image = Image.fromarray(np.array(pil_image).astype('uint8'), 'RGB')
58
  result = od_pipe(pil_image)
 
59
  processed_image = draw_detections(pil_image, result, model_name)
60
  description = f'Model used: {model_name}, Detection Threshold: {threshold}'
61
  return processed_image, result, description
@@ -77,5 +61,4 @@ with gr.Blocks() as demo:
77
  with gr.Tab("Description"):
78
  description_output = gr.Textbox()
79
  run_button.click(get_pipeline_prediction, inputs=[model_dropdown, threshold_slider, inp_image], outputs=[output_image, output_data, description_output])
80
-
81
  demo.launch()
 
4
  import numpy as np
5
  import cv2
6
  from PIL import Image
 
 
 
 
 
 
 
 
 
 
7
 
8
  def load_model(model_name, threshold):
9
  config = DetrConfig.from_pretrained(model_name, threshold=threshold)
10
  model = DetrForObjectDetection.from_pretrained(model_name, config=config)
11
  image_processor = DetrImageProcessor.from_pretrained(model_name)
12
  return pipeline(task='object-detection', model=model, image_processor=image_processor)
13
+ od_pipe = load_model("facebook/detr-resnet-101", 0.25) # Default model
 
14
 
15
  def draw_detections(image, detections, model_name):
16
  np_image = np.array(image)
17
  np_image = cv2.cvtColor(np_image, cv2.COLOR_RGB2BGR)
18
  for detection in detections:
19
+ if 'mask' in detection:
20
+ # Interpret and visualize segmentation mask
21
  mask = detection['mask']
22
  color = np.random.randint(0, 255, size=3)
23
  mask = np.round(mask * 255).astype(np.uint8)
24
  mask = cv2.resize(mask, (image.width, image.height))
25
  mask_image = np.stack([mask]*3, axis=-1)
26
  np_image[mask == 255] = np_image[mask == 255] * 0.5 + color * 0.5
27
+ if 'box' in detection:
28
+ # Visualize bounding box
 
 
29
  box = detection['box']
30
+ x_min, y_min, x_max, y_max = [int(b) for b in [box['xmin'], box['ymin'], box['xmax'], box['ymax']]]
 
31
  cv2.rectangle(np_image, (x_min, y_min), (x_max, y_max), (0, 255, 0), 2)
 
 
32
  final_image = cv2.cvtColor(np_image, cv2.COLOR_BGR2RGB)
33
+ return Image.fromarray(final_image)
34
+
 
35
  def get_pipeline_prediction(model_name, threshold, pil_image):
36
  global od_pipe
37
  od_pipe = load_model(model_name, threshold)
 
39
  if not isinstance(pil_image, Image.Image):
40
  pil_image = Image.fromarray(np.array(pil_image).astype('uint8'), 'RGB')
41
  result = od_pipe(pil_image)
42
+ print("Detection Output:", result) # Debug: Check the output structure
43
  processed_image = draw_detections(pil_image, result, model_name)
44
  description = f'Model used: {model_name}, Detection Threshold: {threshold}'
45
  return processed_image, result, description
 
61
  with gr.Tab("Description"):
62
  description_output = gr.Textbox()
63
  run_button.click(get_pipeline_prediction, inputs=[model_dropdown, threshold_slider, inp_image], outputs=[output_image, output_data, description_output])
 
64
  demo.launch()