Plachta committed on
Commit
8e9f709
·
verified ·
1 Parent(s): 0cbc48e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +5 -6
app.py CHANGED
@@ -56,7 +56,9 @@ def predict(input, topk):
56
  t_image = img_resize.apply_image_torch(image)
57
  t_orig_size = t_image.shape[-2:]
58
  # pad to 1024x1024
 
59
  t_image = torch.nn.functional.pad(t_image, (0, 1024 - t_image.shape[-1], 0, 1024 - t_image.shape[-2]))
 
60
 
61
  # get box prompt
62
  valid_boxes = []
@@ -69,7 +71,7 @@ def predict(input, topk):
69
  t_boxes = np.array(valid_boxes)
70
  t_boxes = img_resize.apply_boxes(t_boxes, orig_size)
71
  box_torch = torch.as_tensor(t_boxes, dtype=torch.float, device=device)
72
- batched_inputs = [{"image": t_image[0], "boxes": box_torch}]
73
  with torch.no_grad():
74
  outputs = sam.infer(batched_inputs, multimask_output=False)
75
  # visualize and post on tensorboard
@@ -87,7 +89,7 @@ def predict(input, topk):
87
  pred_logits = outputs.logits[i].detach().cpu().numpy()
88
  top_ind = pred_logits[:, 0].argsort()[-topk:][::-1]
89
  pred_grasp = outputs.pred_boxes[i].detach().cpu().numpy()[top_ind]
90
- coded_grasp = GraspCoder(1024, 1024, None, grasp_annos_reformat=pred_grasp)
91
  _ = coded_grasp.decode()
92
  decoded_grasp = copy.deepcopy(coded_grasp.grasp_annos)
93
 
@@ -125,7 +127,4 @@ if __name__ == "__main__":
125
  btn.click(predict,
126
  inputs=[prompter, top_k],
127
  outputs=[image_output])
128
- app.launch()
129
-
130
-
131
-
 
56
  t_image = img_resize.apply_image_torch(image)
57
  t_orig_size = t_image.shape[-2:]
58
  # pad to 1024x1024
59
+ pixel_mask = torch.ones(1, t_orig_size[0], t_orig_size[1], device=device)
60
  t_image = torch.nn.functional.pad(t_image, (0, 1024 - t_image.shape[-1], 0, 1024 - t_image.shape[-2]))
61
+ pixel_mask = torch.nn.functional.pad(pixel_mask, (0, 1024 - t_orig_size[1], 0, 1024 - t_orig_size[0]))
62
 
63
  # get box prompt
64
  valid_boxes = []
 
71
  t_boxes = np.array(valid_boxes)
72
  t_boxes = img_resize.apply_boxes(t_boxes, orig_size)
73
  box_torch = torch.as_tensor(t_boxes, dtype=torch.float, device=device)
74
+ batched_inputs = [{"image": t_image[0], "boxes": box_torch, "pixel_mask": pixel_mask}]
75
  with torch.no_grad():
76
  outputs = sam.infer(batched_inputs, multimask_output=False)
77
  # visualize and post on tensorboard
 
89
  pred_logits = outputs.logits[i].detach().cpu().numpy()
90
  top_ind = pred_logits[:, 0].argsort()[-topk:][::-1]
91
  pred_grasp = outputs.pred_boxes[i].detach().cpu().numpy()[top_ind]
92
+ coded_grasp = GraspCoder(t_orig_size[0], t_orig_size[1], None, grasp_annos_reformat=pred_grasp)
93
  _ = coded_grasp.decode()
94
  decoded_grasp = copy.deepcopy(coded_grasp.grasp_annos)
95
 
 
127
  btn.click(predict,
128
  inputs=[prompter, top_k],
129
  outputs=[image_output])
130
+ app.launch()