isLinXu commited on
Commit
0141f34
·
1 Parent(s): 4baf486
Files changed (1) hide show
  1. app.py +60 -20
app.py CHANGED
@@ -238,30 +238,62 @@ class AsyncPredictor:
238
 
239
 
240
  detectron2_model_list = {
 
 
 
 
 
 
241
  "COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x":{
242
  "config_file": "configs/COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml",
243
  "ckpts": "detectron2://COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x/137849600/model_final_f10217.pkl"
244
  },
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
245
  }
246
 
247
 
248
- # def dtectron2_instance_inference(image, config_file, ckpts, device):
249
- # cfg = get_cfg()
250
- # cfg.merge_from_file(config_file)
251
- # cfg.MODEL.WEIGHTS = ckpts
252
- # cfg.MODEL.DEVICE = "cpu"
253
- # cfg.output = "output_img.jpg"
254
- # visualization_demo = VisualizationDemo(cfg, device=device)
255
- # if image:
256
- # intput_path = "intput_img.jpg"
257
- # image.save(intput_path)
258
- # image = read_image(intput_path, format="BGR")
259
- # predictions, vis_output = visualization_demo.run_on_image(image)
260
- # output_image = PIL.Image.fromarray(vis_output.get_image())
261
- # # print("predictions: ", predictions)
262
- # return output_image
263
-
264
- def dtectron2_instance_inference(image, input_model_name, device):
265
  cfg = get_cfg()
266
  config_file = detectron2_model_list[input_model_name]["config_file"]
267
  ckpts = detectron2_model_list[input_model_name]["ckpts"]
@@ -269,6 +301,9 @@ def dtectron2_instance_inference(image, input_model_name, device):
269
  cfg.MODEL.WEIGHTS = ckpts
270
  cfg.MODEL.DEVICE = "cpu"
271
  cfg.output = "output_img.jpg"
 
 
 
272
  visualization_demo = VisualizationDemo(cfg, device=device)
273
  if image:
274
  intput_path = "intput_img.jpg"
@@ -280,15 +315,20 @@ def dtectron2_instance_inference(image, input_model_name, device):
280
  return output_image
281
 
282
  def download_test_img():
 
 
 
 
283
  # Images
284
  torch.hub.download_url_to_file(
285
  'https://user-images.githubusercontent.com/59380685/268517006-d8d4d3b3-964a-4f4d-8458-18c7eb75a4f2.jpg',
286
  '000000502136.jpg')
287
-
288
 
289
  if __name__ == '__main__':
290
  input_image = gr.inputs.Image(type='pil', label='Input Image')
291
  input_model_name = gr.inputs.Dropdown(list(detectron2_model_list.keys()), label="Model Name", default="COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x")
 
292
  input_device = gr.inputs.Dropdown(["cpu", "cuda"], label="Devices", default="cpu")
293
  output_image = gr.outputs.Image(type='pil', label='Output Image')
294
  output_predictions = gr.outputs.Textbox(type='text', label='Output Predictions')
@@ -301,8 +341,8 @@ if __name__ == '__main__':
301
  "<p style='text-align: center'><a href='https://github.com/facebookresearch/detectron2'>gradio build by gatilin</a></a></p>"
302
  download_test_img()
303
 
304
- examples = [["000000502136.jpg", "COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x", "cpu"]]
305
  gr.Interface(fn=dtectron2_instance_inference,
306
- inputs=[input_image, input_model_name, input_device],
307
  outputs=output_image,examples=examples,
308
  title=title, description=description, article=article).launch()
 
238
 
239
 
240
  detectron2_model_list = {
241
+ # Cityscapes
242
+ "Cityscapes/mask_rcnn_R_50_FPN":{
243
+ "config_file": "configs/Cityscapes/mask_rcnn_R_50_FPN.yaml",
244
+ "ckpts": "detectron2://Cityscapes/"
245
+ },
246
+ # COCO-Detection
247
  "COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x":{
248
  "config_file": "configs/COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml",
249
  "ckpts": "detectron2://COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x/137849600/model_final_f10217.pkl"
250
  },
251
+ "COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_1x":{
252
+ "config_file": "configs/COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_1x.yaml",
253
+ "ckpts": "detectron2://COCO-InstanceSegmentation/"
254
+ },
255
+ "COCO-InstanceSegmentation/mask_rcnn_R_50_DC5_1x":{
256
+ "config_file": "configs/COCO-InstanceSegmentation/mask_rcnn_R_50_DC5_1x.yaml",
257
+ "ckpts": "detectron2://COCO-InstanceSegmentation/"
258
+ },
259
+ "COCO-InstanceSegmentation/mask_rcnn_R_50_DC5_3x": {
260
+ "config_file": "configs/COCO-InstanceSegmentation/mask_rcnn_R_50_DC5_3x.yaml",
261
+ "ckpts": "detectron2://COCO-InstanceSegmentation/"
262
+ },
263
+ "COCO-InstanceSegmentation/mask_rcnn_R_50_C4_1x": {
264
+ "config_file": "configs/COCO-InstanceSegmentation/mask_rcnn_R_50_C4_1x.yaml",
265
+ "ckpts": "detectron2://COCO-InstanceSegmentation/"
266
+ },
267
+ "COCO-InstanceSegmentation/mask_rcnn_R_50_C4_3x": {
268
+ "config_file": "configs/COCO-InstanceSegmentation/mask_rcnn_R_50_C4_3x.yaml",
269
+ "ckpts": "detectron2://COCO-InstanceSegmentation/"
270
+ },
271
+ "COCO-InstanceSegmentation/mask_rcnn_R_101_C4_3x": {
272
+ "config_file": "configs/COCO-InstanceSegmentation/mask_rcnn_R_101_C4_3x.yaml",
273
+ "ckpts": "detectron2://COCO-InstanceSegmentation/"
274
+ },
275
+ "COCO-InstanceSegmentation/mask_rcnn_R_101_DC5_3x": {
276
+ "config_file": "configs/COCO-InstanceSegmentation/mask_rcnn_R_101_DC5_3x.yaml",
277
+ "ckpts": "detectron2://COCO-InstanceSegmentation/"
278
+ },
279
+ "COCO-InstanceSegmentation/mask_rcnn_R_101_FPN_3x": {
280
+ "config_file": "configs/COCO-InstanceSegmentation/mask_rcnn_R_101_FPN_3x.yaml",
281
+ "ckpts": "detectron2://COCO-InstanceSegmentation/"
282
+ },
283
+ "COCO-InstanceSegmentation/mask_rcnn_X_101_32x8d_FPN_3x": {
284
+ "config_file": "configs/COCO-InstanceSegmentation/mask_rcnn_X_101_32x8d_FPN_3x.yaml",
285
+ "ckpts": "detectron2://COCO-InstanceSegmentation/"
286
+ },
287
+ # COCO-Detection
288
+ "COCO-Detection/mask_rcnn_X_101_32x8d_FPN_3x": {
289
+ "config_file": "configs/COCO-Detection/mask_rcnn_X_101_32x8d_FPN_3x.yaml",
290
+ "ckpts": "detectron2://COCO-Detection/"
291
+ },
292
  }
293
 
294
 
295
+
296
+ def dtectron2_instance_inference(image, input_model_name, confidence_threshold, device):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
297
  cfg = get_cfg()
298
  config_file = detectron2_model_list[input_model_name]["config_file"]
299
  ckpts = detectron2_model_list[input_model_name]["ckpts"]
 
301
  cfg.MODEL.WEIGHTS = ckpts
302
  cfg.MODEL.DEVICE = "cpu"
303
  cfg.output = "output_img.jpg"
304
+ cfg.MODEL.RETINANET.SCORE_THRESH_TEST = confidence_threshold
305
+ cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = confidence_threshold
306
+ cfg.MODEL.PANOPTIC_FPN.COMBINE.INSTANCES_CONFIDENCE_THRESH = confidence_threshold
307
  visualization_demo = VisualizationDemo(cfg, device=device)
308
  if image:
309
  intput_path = "intput_img.jpg"
 
315
  return output_image
316
 
317
  def download_test_img():
318
+ import shutil
319
+ torch.hub.download_url_to_file(
320
+ 'https://github.com/isLinXu/issues/files/12643351/configs.zip',
321
+ 'configs.zip')
322
  # Images
323
  torch.hub.download_url_to_file(
324
  'https://user-images.githubusercontent.com/59380685/268517006-d8d4d3b3-964a-4f4d-8458-18c7eb75a4f2.jpg',
325
  '000000502136.jpg')
326
+ shutil.unpack_archive('configs.zip', 'configs', 'zip')
327
 
328
  if __name__ == '__main__':
329
  input_image = gr.inputs.Image(type='pil', label='Input Image')
330
  input_model_name = gr.inputs.Dropdown(list(detectron2_model_list.keys()), label="Model Name", default="COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x")
331
+ input_prediction_threshold = gr.inputs.Slider(minimum=0.0, maximum=1.0, step=0.01, default=0.25, label="Confidence Threshold")
332
  input_device = gr.inputs.Dropdown(["cpu", "cuda"], label="Devices", default="cpu")
333
  output_image = gr.outputs.Image(type='pil', label='Output Image')
334
  output_predictions = gr.outputs.Textbox(type='text', label='Output Predictions')
 
341
  "<p style='text-align: center'><a href='https://github.com/facebookresearch/detectron2'>gradio build by gatilin</a></a></p>"
342
  download_test_img()
343
 
344
+ examples = [["000000502136.jpg", "COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x", 0.25, "cpu"]]
345
  gr.Interface(fn=dtectron2_instance_inference,
346
+ inputs=[input_image, input_model_name, input_prediction_threshold, input_device],
347
  outputs=output_image,examples=examples,
348
  title=title, description=description, article=article).launch()