jiuface committed on
Commit
917a5a6
·
1 Parent(s): 4c32826
Files changed (1) hide show
  1. app.py +28 -19
app.py CHANGED
@@ -9,6 +9,7 @@ from io import BytesIO
9
  import PIL.Image
10
  import requests
11
  import cv2
 
12
 
13
  from utils.florence import load_florence_model, run_florence_inference, \
14
  FLORENCE_OPEN_VOCABULARY_DETECTION_TASK
@@ -44,7 +45,7 @@ def process_image(image_input, image_url, task_prompt, text_prompt=None, dilate=
44
  response.raise_for_status()
45
  image_input = PIL.Image.open(BytesIO(response.content))
46
  print("fetch image success")
47
-
48
  _, result = run_florence_inference(
49
  model=FLORENCE_MODEL,
50
  processor=FLORENCE_PROCESSOR,
@@ -53,67 +54,75 @@ def process_image(image_input, image_url, task_prompt, text_prompt=None, dilate=
53
  task=task_prompt,
54
  text=text_prompt
55
  )
 
56
  detections = sv.Detections.from_lmm(
57
  lmm=sv.LMM.FLORENCE_2,
58
  result=result,
59
  resolution_wh=image_input.size
60
  )
 
61
  images = []
62
  if return_rectangles:
63
- # 创建黑色背景的图片
64
- mask_image = np.zeros((image_input.size.height, image_input.size.width), dtype=np.uint8)
65
- bboxes = detections.get('bboxes', [])
 
66
  for bbox in bboxes:
67
  x1, y1, x2, y2 = map(int, bbox)
68
- # mask_image 上绘制白色的矩形
69
- cv2.rectangle(mask_image, (x1, y1), (x2, y2), 255, thickness=cv2.FILLED)
70
- images = [mask_image]
 
 
 
71
  else:
72
- # sam
73
  detections = run_sam_inference(SAM_IMAGE_MODEL, image_input, detections)
74
  if len(detections) == 0:
75
  gr.Info("No objects detected.")
76
  return None
77
- kernel_size = dilate
78
  print("mask generated:", len(detections.mask))
 
79
  kernel = np.ones((kernel_size, kernel_size), np.uint8)
 
80
  for i in range(len(detections.mask)):
81
  mask = detections.mask[i].astype(np.uint8) * 255
82
  if dilate > 0:
83
  mask = cv2.dilate(mask, kernel, iterations=1)
84
  images.append(mask)
 
85
  if merge_masks:
86
-
87
  merged_mask = np.zeros_like(images[0], dtype=np.uint8)
88
  for mask in images:
89
  merged_mask = cv2.bitwise_or(merged_mask, mask)
90
- images = [merged_mask] + images
91
 
92
- return images
93
 
94
 
95
  with gr.Blocks() as demo:
96
  with gr.Row():
97
  with gr.Column():
98
  image = gr.Image(type='pil', label='Upload image')
99
- image_url = gr.Textbox( label='Image url', placeholder='Enter text prompts (Optional)')
100
  task_prompt = gr.Dropdown(
101
  ['<OD>', '<CAPTION_TO_PHRASE_GROUNDING>', '<DENSE_REGION_CAPTION>', '<REGION_PROPOSAL>', '<OCR_WITH_REGION>', '<REFERRING_EXPRESSION_SEGMENTATION>', '<REGION_TO_SEGMENTATION>', '<OPEN_VOCABULARY_DETECTION>', '<REGION_TO_CATEGORY>', '<REGION_TO_DESCRIPTION>'], value="<CAPTION_TO_PHRASE_GROUNDING>", label="Task Prompt", info="task prompts"
102
  )
103
  dilate = gr.Slider(label="dilate mask", minimum=0, maximum=50, value=10, step=1)
104
  merge_masks = gr.Checkbox(label="Merge masks", value=False)
105
- return_rectangles = gr.Checkbox(label="Return rectangle masks", value=False)
106
  text_prompt = gr.Textbox(label='Text prompt', placeholder='Enter text prompts')
107
  submit_button = gr.Button(value='Submit', variant='primary')
108
  with gr.Column():
109
  image_gallery = gr.Gallery(label="Generated images", show_label=False, elem_id="gallery", columns=[3], rows=[1], object_fit="contain", height="auto")
 
 
110
  print(image, image_url, task_prompt, text_prompt, image_gallery)
111
  submit_button.click(
112
- fn = process_image,
113
- inputs = [image, image_url, task_prompt, text_prompt, dilate, merge_masks, return_rectangles],
114
- outputs = [image_gallery,],
115
  show_api=False
116
  )
117
 
118
-
119
- demo.launch(debug=True, show_error=True)
 
9
  import PIL.Image
10
  import requests
11
  import cv2
12
+ import json
13
 
14
  from utils.florence import load_florence_model, run_florence_inference, \
15
  FLORENCE_OPEN_VOCABULARY_DETECTION_TASK
 
45
  response.raise_for_status()
46
  image_input = PIL.Image.open(BytesIO(response.content))
47
  print("fetch image success")
48
+ # start to parse prompt
49
  _, result = run_florence_inference(
50
  model=FLORENCE_MODEL,
51
  processor=FLORENCE_PROCESSOR,
 
54
  task=task_prompt,
55
  text=text_prompt
56
  )
57
+ # start to detect
58
  detections = sv.Detections.from_lmm(
59
  lmm=sv.LMM.FLORENCE_2,
60
  result=result,
61
  resolution_wh=image_input.size
62
  )
63
+ json_result = json.dumps({"bbox": detections.xyxy, "data": detections.data})
64
  images = []
65
  if return_rectangles:
66
+ # create rectangular masks from the detected bounding boxes
67
+ (image_width, image_height) = image_input.size
68
+ bboxes = detections.xyxy
69
+ merge_mask_image = np.zeros((image_height, image_width), dtype=np.uint8)
70
  for bbox in bboxes:
71
  x1, y1, x2, y2 = map(int, bbox)
72
+ cv2.rectangle(merge_mask_image, (x1, y1), (x2, y2), 255, thickness=cv2.FILLED)
73
+ clip_mask = np.zeros((image_height, image_width), dtype=np.uint8)
74
+ cv2.rectangle(clip_mask, (x1, y1), (x2, y2), 255, thickness=cv2.FILLED)
75
+ images.append(clip_mask)
76
+ if merge_masks:
77
+ images = [merge_mask_image] + images
78
  else:
79
+ # use SAM to generate segmentation masks
80
  detections = run_sam_inference(SAM_IMAGE_MODEL, image_input, detections)
81
  if len(detections) == 0:
82
  gr.Info("No objects detected.")
83
  return None
 
84
  print("mask generated:", len(detections.mask))
85
+ kernel_size = dilate
86
  kernel = np.ones((kernel_size, kernel_size), np.uint8)
87
+
88
  for i in range(len(detections.mask)):
89
  mask = detections.mask[i].astype(np.uint8) * 255
90
  if dilate > 0:
91
  mask = cv2.dilate(mask, kernel, iterations=1)
92
  images.append(mask)
93
+
94
  if merge_masks:
 
95
  merged_mask = np.zeros_like(images[0], dtype=np.uint8)
96
  for mask in images:
97
  merged_mask = cv2.bitwise_or(merged_mask, mask)
98
+ images = [merged_mask]
99
 
100
+ return [images, json_result]
101
 
102
 
103
# Gradio UI: inputs (image / URL / prompts / mask options) on the left,
# generated mask gallery plus the detection JSON on the right.
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            image = gr.Image(type='pil', label='Upload image')
            # Optional alternative to uploading: fetch the image from a URL.
            # (Placeholder fixed — it previously said "Enter text prompts",
            # copy-pasted from the text-prompt field.)
            image_url = gr.Textbox(label='Image url', placeholder='Enter image URL (Optional)')
            # Florence-2 task token; drives which kind of result the model returns.
            task_prompt = gr.Dropdown(
                ['<OD>', '<CAPTION_TO_PHRASE_GROUNDING>', '<DENSE_REGION_CAPTION>', '<REGION_PROPOSAL>', '<OCR_WITH_REGION>', '<REFERRING_EXPRESSION_SEGMENTATION>', '<REGION_TO_SEGMENTATION>', '<OPEN_VOCABULARY_DETECTION>', '<REGION_TO_CATEGORY>', '<REGION_TO_DESCRIPTION>'],
                value="<CAPTION_TO_PHRASE_GROUNDING>", label="Task Prompt", info="task prompts"
            )
            # Kernel size for the cv2.dilate pass applied to each mask (0 = no dilation).
            dilate = gr.Slider(label="dilate mask", minimum=0, maximum=50, value=10, step=1)
            # OR all masks together into one extra image when checked.
            merge_masks = gr.Checkbox(label="Merge masks", value=False)
            # When checked, return filled bounding-box rectangles instead of SAM masks.
            return_rectangles = gr.Checkbox(label="Return Rectangles", value=False)
            text_prompt = gr.Textbox(label='Text prompt', placeholder='Enter text prompts')
            submit_button = gr.Button(value='Submit', variant='primary')
        with gr.Column():
            image_gallery = gr.Gallery(label="Generated images", show_label=False, elem_id="gallery", columns=[3], rows=[1], object_fit="contain", height="auto")
            json_result = gr.Code(label="JSON Result", language="json")

    # NOTE(review): removed a leftover debug print() that dumped the Gradio
    # component objects to stdout at startup.
    submit_button.click(
        fn=process_image,
        inputs=[image, image_url, task_prompt, text_prompt, dilate, merge_masks, return_rectangles],
        outputs=[image_gallery, json_result],
        show_api=False
    )

demo.launch(debug=True, show_error=True)