liuyizhang committed
Commit c419c35
1 Parent(s): e5f7fa3

update app.py

Files changed (1)
  1. app.py +19 -20
app.py CHANGED
@@ -116,18 +116,16 @@ def load_image(image_path):
     image, _ = transform(image_pil, None)  # 3, h, w
     return image_pil, image
 
-
 def load_model(model_config_path, model_checkpoint_path, device):
     args = SLConfig.fromfile(model_config_path)
     args.device = device
     model = build_model(args)
-    checkpoint = torch.load(model_checkpoint_path, map_location="cpu")
+    checkpoint = torch.load(model_checkpoint_path, map_location=device)  # was "cpu"
     load_res = model.load_state_dict(clean_state_dict(checkpoint["model"]), strict=False)
     print(load_res)
     _ = model.eval()
     return model
 
-
 def get_grounding_output(model, image, caption, box_threshold, text_threshold, with_logits=True, device="cpu"):
     caption = caption.lower()
     caption = caption.strip()
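
Note on the hunk above: with map_location=device, torch.load deserializes checkpoint tensors directly onto the target device instead of staging them on CPU and copying them later. A minimal illustration (the checkpoint path is a placeholder):

import torch

ckpt_path = "groundingdino_swint_ogc.pth"  # placeholder path

# Before: tensors land on CPU no matter where the model will run.
checkpoint_cpu = torch.load(ckpt_path, map_location="cpu")

# After: tensors are mapped onto the chosen device at load time.
device = "cuda" if torch.cuda.is_available() else "cpu"
checkpoint_dev = torch.load(ckpt_path, map_location=device)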
@@ -172,14 +170,12 @@ def show_mask(mask, ax, random_color=False):
     mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
     ax.imshow(mask_image)
 
-
 def show_box(box, ax, label):
     x0, y0 = box[0], box[1]
     w, h = box[2] - box[0], box[3] - box[1]
     ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor='green', facecolor=(0,0,0,0), lw=2))
     ax.text(x0, y0, label)
 
-
 config_file = 'GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py'
 ckpt_repo_id = "ShilongLiu/GroundingDINO"
 ckpt_filenmae = "groundingdino_swint_ogc.pth"
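
load_model_hf is referenced in the next hunk but never defined in this diff. In Grounded-SAM-style demos it typically downloads the checkpoint from the Hugging Face Hub and hands it to load_model; a sketch of such a helper, assuming huggingface_hub (the body is an assumption, not this repo's exact code):

from huggingface_hub import hf_hub_download

def load_model_hf(model_config_path, repo_id, filename, device="cpu"):
    # Fetch the checkpoint from the Hugging Face Hub (cached after first call).
    cache_file = hf_hub_download(repo_id=repo_id, filename=filename)
    # Delegate to the load_model defined earlier in app.py.
    return load_model(model_config_path, cache_file, device)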
@@ -189,6 +185,19 @@ device = "cuda"
 
 device = get_device()
 
+# initialize groundingdino model
+groundingdino_model = load_model_hf(config_file, ckpt_repo_id, ckpt_filenmae)
+
+# initialize SAM
+sam_predictor = SamPredictor(build_sam(checkpoint=sam_checkpoint))
+
+# initialize stable-diffusion-inpainting
+sd_pipe = StableDiffusionInpaintPipeline.from_pretrained(
+    "runwayml/stable-diffusion-inpainting",
+    # torch_dtype=torch.float16
+)
+sd_pipe = sd_pipe.to(device)
+
 def run_grounded_sam(image_path, text_prompt, task_type, inpaint_prompt, box_threshold, text_threshold):
     assert text_prompt, 'text_prompt is not found!'
 
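This hunk is the heart of the commit: GroundingDINO, SAM, and the Stable Diffusion inpainting pipeline are now built once at module import rather than inside run_grounded_sam, so each request skips three expensive model loads. For illustration, an equivalent lazy variant of the same load-once pattern (the function name is hypothetical, not this app's code):

import functools

from diffusers import StableDiffusionInpaintPipeline

@functools.lru_cache(maxsize=None)
def get_inpaint_pipe(device: str):
    # The first call pays the construction cost; later calls return the
    # cached pipeline. Module-level globals, as in this commit, achieve
    # the same thing eagerly at import time.
    pipe = StableDiffusionInpaintPipeline.from_pretrained(
        "runwayml/stable-diffusion-inpainting",
    )
    return pipe.to(device)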
@@ -196,24 +205,20 @@ def run_grounded_sam(image_path, text_prompt, task_type, inpaint_prompt, box_threshold, text_threshold):
     os.makedirs(output_dir, exist_ok=True)
     # load image
     image_pil, image = load_image(image_path.convert("RGB"))
-    # load model
-    model = load_model_hf(config_file, ckpt_repo_id, ckpt_filenmae)
 
     # visualize raw image
     image_pil.save(os.path.join(output_dir, "raw_image.jpg"))
 
     # run grounding dino model
     boxes_filt, pred_phrases = get_grounding_output(
-        model, image, text_prompt, box_threshold, text_threshold, device=device
+        groundingdino_model, image, text_prompt, box_threshold, text_threshold, device=device
     )
 
     size = image_pil.size
 
     if task_type == 'segment' or task_type == 'inpainting':
-        # initialize SAM
-        predictor = SamPredictor(build_sam(checkpoint=sam_checkpoint))
         image = np.array(image_path)
-        predictor.set_image(image)
+        sam_predictor.set_image(image)
 
         H, W = size[1], size[0]
         for i in range(boxes_filt.size(0)):
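
Between this hunk and the next, the loop converts GroundingDINO's normalized (cx, cy, w, h) boxes into absolute (x0, y0, x1, y1) pixel boxes, the format SAM's box prompt expects; only the final step is visible in the next hunk. A standalone sketch of the full conversion with made-up numbers:

import torch

W, H = 640, 480  # image width and height
box = torch.tensor([0.5, 0.5, 0.25, 0.5])  # normalized (cx, cy, w, h)

box = box * torch.tensor([W, H, W, H])  # scale to pixel units
box[:2] -= box[2:] / 2                  # center -> top-left corner
box[2:] += box[:2]                      # width/height -> bottom-right corner
print(box)  # tensor([240., 120., 400., 360.])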
@@ -222,9 +227,9 @@ def run_grounded_sam(image_path, text_prompt, task_type, inpaint_prompt, box_threshold, text_threshold):
             boxes_filt[i][2:] += boxes_filt[i][:2]
 
         boxes_filt = boxes_filt.cpu()
-        transformed_boxes = predictor.transform.apply_boxes_torch(boxes_filt, image.shape[:2])
+        transformed_boxes = sam_predictor.transform.apply_boxes_torch(boxes_filt, image.shape[:2])
 
-        masks, _, _ = predictor.predict_torch(
+        masks, _, _ = sam_predictor.predict_torch(
             point_coords = None,
             point_labels = None,
             boxes = transformed_boxes,
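
The predict_torch call is cut off at the hunk boundary. In segment-anything's API the usual batched-box invocation looks like the sketch below; multimask_output=False is an assumption about the elided trailing arguments, and the checkpoint filename stands in for whatever sam_checkpoint points at:

import numpy as np
import torch
from segment_anything import SamPredictor, build_sam

image = np.zeros((480, 640, 3), dtype=np.uint8)        # placeholder RGB image
boxes_filt = torch.tensor([[240., 120., 400., 360.]])  # one xyxy box in pixels

sam_predictor = SamPredictor(build_sam(checkpoint="sam_vit_h_4b8939.pth"))
sam_predictor.set_image(image)

# Map pixel boxes into the resized frame SAM was fed, then predict in batch.
transformed_boxes = sam_predictor.transform.apply_boxes_torch(boxes_filt, image.shape[:2])
masks, scores, logits = sam_predictor.predict_torch(
    point_coords=None,        # box prompts only, no point prompts
    point_labels=None,
    boxes=transformed_boxes,  # (N, 4) boxes in transformed coordinates
    multimask_output=False,   # one mask per box
)
# masks: (N, 1, H, W) boolean tensor, one mask per input box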
@@ -266,14 +271,8 @@ def run_grounded_sam(image_path, text_prompt, task_type, inpaint_prompt, box_threshold, text_threshold):
         mask = masks[0][0].cpu().numpy()  # simply choose the first mask, which will be refined in a future release
         mask_pil = Image.fromarray(mask)
         image_pil = Image.fromarray(image)
-
-        pipe = StableDiffusionInpaintPipeline.from_pretrained(
-            "runwayml/stable-diffusion-inpainting",
-            # torch_dtype=torch.float16
-        )
-        pipe = pipe.to(device)
 
-        image = pipe(prompt=inpaint_prompt, image=image_pil, mask_image=mask_pil).images[0]
+        image = sd_pipe(prompt=inpaint_prompt, image=image_pil, mask_image=mask_pil).images[0]
         image_path = os.path.join(output_dir, "grounded_sam_inpainting_output.jpg")
         image.save(image_path)
         image_result = cv2.cvtColor(cv2.imread(image_path), cv2.COLOR_BGR2RGB)
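
For reference, the call moved by this hunk is the standard diffusers inpainting API. One caveat: runwayml/stable-diffusion-inpainting operates at 512x512, so image and mask are typically resized before the call. A minimal end-to-end sketch (file names and prompt are placeholders):

import torch
from PIL import Image
from diffusers import StableDiffusionInpaintPipeline

device = "cuda" if torch.cuda.is_available() else "cpu"
pipe = StableDiffusionInpaintPipeline.from_pretrained(
    "runwayml/stable-diffusion-inpainting",
    # torch_dtype=torch.float16,  # worth enabling on GPU to halve memory
).to(device)

init_image = Image.open("raw_image.jpg").convert("RGB").resize((512, 512))
mask_image = Image.open("mask.png").convert("L").resize((512, 512))  # white = inpaint

result = pipe(prompt="a corgi on a bench", image=init_image, mask_image=mask_image).images[0]
result.save("grounded_sam_inpainting_output.jpg")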
 