"""Gradio demo: draw boxes on an image and visualize predicted masks and grasps."""
import copy
import sys

import cv2
import gradio as gr
import numpy as np
import torch

# Project-local packages are resolved relative to the working directory, so
# the path has to be extended before they are imported.
sys.path.append("./")

from gradio_image_prompter import ImagePrompter  # noqa: E402
from models import sam_model_registry  # noqa: E402
from models.grasp_mods import add_inference_method  # noqa: E402
from models.grasp_mods import modify_forward  # noqa: E402
from models.utils.transforms import ResizeLongestSide  # noqa: E402
from structures.grasp_box import GraspCoder  # noqa: E402

# Side length of the square input fed to the model; the image is resized so
# its longest side matches this value and then zero-padded to a square.
INPUT_SIZE = 1024
img_resize = ResizeLongestSide(INPUT_SIZE)

device = "cuda" if torch.cuda.is_available() else "cpu"
model_type = "vit_b"

# Per-channel normalization constants, shaped (3, 1, 1) so they broadcast
# against channel-first (C, H, W) images.
mean = np.array([103.53, 116.28, 123.675])[:, np.newaxis, np.newaxis]
std = np.array([57.375, 57.12, 58.395])[:, np.newaxis, np.newaxis]

sam = sam_model_registry[model_type]()
sam.to(device=device)
sam.forward = modify_forward(sam)
sam.infer = add_inference_method(sam)

pretrained_model_path = "./epoch_39_step_415131.pth"
if pretrained_model_path:
    # NOTE(security): torch.load unpickles the file; only load checkpoints
    # that come from a trusted source.
    sd = torch.load(pretrained_model_path, map_location='cpu')
    # Strip the "module." prefix from the keys before loading.
    prefix = "module."
    new_sd = {
        (k[len(prefix):] if k.startswith(prefix) else k): v
        for k, v in sd.items()
    }
    sam.load_state_dict(new_sd)
sam.eval()

_APP_DESCRIPTION = """
# GraspAnything
Upload an image and draw a box around the object you want to grasp. Set top k to be the number of grasps you want to predict for each object.
"""


def _box_prompts(points):
    """Return ``[x1, y1, x2, y2]`` for every prompt entry labelled as a box.

    Each entry is unpacked as ``(x1, y1, type1, x2, y2, type2)``; only
    entries with ``type1 == 2`` and ``type2 == 3`` are kept.
    NOTE(review): presumably that label pair is what the prompter emits for
    a dragged box -- confirm against gradio_image_prompter.
    """
    return [
        [x1, y1, x2, y2]
        for x1, y1, type1, x2, y2, type2 in points
        if type1 == 2 and type2 == 3
    ]


def _prepare_image(hwc_image):
    """Normalize, resize and pad an H x W x 3 image for the model.

    Returns:
        A ``(padded, resized_hw)`` pair: the ``(1, 3, INPUT_SIZE,
        INPUT_SIZE)`` float tensor on ``device`` and the ``(height, width)``
        of the resized image before padding.
    """
    chw = hwc_image.transpose(2, 0, 1)
    normalized = (chw - mean) / std
    batch = torch.tensor(normalized).float().to(device).unsqueeze(0)
    resized = img_resize.apply_image_torch(batch)
    resized_hw = resized.shape[-2:]
    # Pad on the right and bottom only, so the content stays at the origin.
    padded = torch.nn.functional.pad(
        resized,
        (0, INPUT_SIZE - resized.shape[-1], 0, INPUT_SIZE - resized.shape[-2]),
    )
    return padded, resized_hw


def _to_display_image(chw_tensor):
    """Undo the normalization and return a contiguous H x W x 3 uint8 array."""
    restored = chw_tensor.cpu().numpy() * std + mean
    restored = restored.transpose(1, 2, 0)
    # Clip before the uint8 cast: casting first wraps out-of-range values
    # and turns the clip into a no-op.
    restored = np.clip(restored, 0, 255).astype(np.uint8)
    # OpenCV drawing functions need a C-contiguous buffer.
    return np.ascontiguousarray(restored)


def _draw_grasp(canvas, grasp):
    """Draw one decoded grasp, given as four (x, y) corners, onto ``canvas``."""
    # Plain Python ints are accepted by cv2 regardless of OpenCV/NumPy version.
    p0, p1, p2, p3 = (
        tuple(pt) for pt in grasp[:8].astype(int).reshape(4, 2).tolist()
    )
    cv2.line(canvas, p0, p1, (255, 0, 0), 1)
    cv2.line(canvas, p2, p3, (255, 0, 0), 1)
    cv2.line(canvas, p1, p2, (0, 0, 255), 2)
    cv2.line(canvas, p3, p0, (0, 0, 255), 2)


def predict(input, topk):
    """Predict masks and grasps for the boxes drawn on an image.

    Args:
        input: Prompter payload with an ``"image"`` entry (H x W x 3 array)
            and a ``"points"`` entry (iterable of 6-element prompt records).
        topk: Number of highest-scoring grasps to draw per box.

    Returns:
        An image at the original resolution with the predicted masks and
        grasps drawn on it. If no box prompt is present the input image is
        returned unchanged; if no image was provided, ``None``.
    """
    if input is None or input["image"] is None:
        return None

    np_image = input["image"]
    orig_size = np_image.shape[:2]
    # The slider may deliver a float; slicing needs an int of at least one.
    topk = max(int(topk), 1)

    valid_boxes = _box_prompts(input["points"])
    if not valid_boxes:
        # Nothing to predict: hand back the untouched H x W x 3 image.
        return np_image

    t_image, t_orig_size = _prepare_image(np_image)

    # Map the box prompts into the resized image's coordinate frame.
    t_boxes = img_resize.apply_boxes(np.array(valid_boxes), orig_size)
    box_torch = torch.as_tensor(t_boxes, dtype=torch.float, device=device)
    batched_inputs = [{"image": t_image[0], "boxes": box_torch}]

    with torch.no_grad():
        outputs = sam.infer(batched_inputs, multimask_output=False)

    recovered_img = _to_display_image(batched_inputs[0]['image'])
    mask_color = np.array([0, 255, 0])[None, None, :]

    per_object = zip(outputs.pred_masks, outputs.logits, outputs.pred_boxes)
    for mask_logits, logits, boxes in per_object:
        # Binarize the mask and replicate it across the three channels.
        pred_mask = mask_logits.detach().sigmoid().cpu().numpy() > 0.5
        pred_mask = pred_mask.transpose(1, 2, 0).repeat(3, axis=2)

        # Keep the topk grasps ranked by the first logit column, best first.
        pred_logits = logits.detach().cpu().numpy()
        top_ind = pred_logits[:, 0].argsort()[::-1][:topk]
        pred_grasp = boxes.detach().cpu().numpy()[top_ind]
        coded_grasp = GraspCoder(
            INPUT_SIZE, INPUT_SIZE, None, grasp_annos_reformat=pred_grasp
        )
        # The decoded grasps are read from grasp_annos afterwards; the return
        # value of decode() itself is not used.
        coded_grasp.decode()
        decoded_grasp = copy.deepcopy(coded_grasp.grasp_annos)

        # Blend the mask colour into the image at 50% opacity.
        recovered_img[pred_mask] = (
            recovered_img[pred_mask] * 0.5
            + (pred_mask * mask_color)[pred_mask] * 0.5
        )

        for grasp in decoded_grasp:
            _draw_grasp(recovered_img, grasp)

    # Drop the padding, then scale back to the original resolution.
    recovered_img = recovered_img[:t_orig_size[0], :t_orig_size[1]]
    recovered_img = cv2.resize(recovered_img, (orig_size[1], orig_size[0]))
    return recovered_img


def main():
    """Build the Gradio interface and launch it."""
    app = gr.Blocks(title="GraspAnything")
    with app:
        gr.Markdown(_APP_DESCRIPTION)
        with gr.Column():
            prompter = ImagePrompter(show_label=False)
            top_k = gr.Slider(minimum=1, maximum=20, step=1, value=3, label="Top K Grasps")
        with gr.Column():
            image_output = gr.Image()
        btn = gr.Button("Generate!")
        btn.click(predict, inputs=[prompter, top_k], outputs=[image_output])
    app.launch()


if __name__ == "__main__":
    main()