import PIL
import numpy as np
import torch
import torch.nn as nn
import torchvision
from yacs.config import CfgNode as CN
from PIL import ImageDraw
from segment_anything import build_sam, SamPredictor
from segment_anything.utils.amg import remove_small_regions
from PIL import ImageDraw, ImageFont
import groundingdino.util.transforms as T
from constants.constant import DARKER_COLOR_MAP, LIGHTER_COLOR_MAP, COLORS
from groundingdino import build_groundingdino
from groundingdino.util.predict import predict
from groundingdino.util.utils import clean_state_dict

# Font file used for every label drawn by this module.
_LABEL_FONT_FILE = 'InputSans-Regular.ttf'

# Colour for detections that match no known entity. This is the RGB value of
# the colour name "grey"; a tuple is used (instead of the name) because the
# mask stage divides the colour by 255 and therefore needs numeric channels.
_UNMATCHED_COLOR = (128, 128, 128)


def _load_label_font(reference_size):
    """Load the label font scaled to ``reference_size`` (30 pt per 512 px, at least 1)."""
    return ImageFont.truetype(_LABEL_FONT_FILE, max(1, int(reference_size / 512.0 * 30)))


def _text_background_box(draw, xy, text, font):
    """Return the (x0, y0, x1, y1) box covering ``text`` drawn at ``xy``."""
    if hasattr(font, "getbbox"):
        return draw.textbbox(xy, text, font)
    # Fallback for Pillow versions whose fonts have no ``getbbox``.
    w, h = draw.textsize(text, font)
    x0, y0 = xy
    return (x0, y0, w + x0, y0 + h)


def _draw_labelled_box(draw, box, label, color, font, width=10):
    """Draw an outlined box with a filled label tag at its top-left corner."""
    x0, y0, x1, y1 = box
    text = str(label)
    draw.rectangle([x0, y0, x1, y1], outline=color, width=width)
    draw.rectangle(_text_background_box(draw, (x0, y0), text, font), fill=color)
    draw.text((x0, y0), text, fill="white", font=font)


def _normalize_prompt(prompt):
    """Lower-case and strip ``prompt`` and make sure it ends with a period."""
    prompt = prompt.lower().strip()
    if not prompt.endswith("."):
        prompt = prompt + "."
    return prompt


def _to_grounding_tensor(image):
    """Resize, tensorise and normalise a PIL image for GroundingDINO."""
    transform = T.Compose([
        T.RandomResize([800], max_size=1333),
        T.ToTensor(),
        T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])
    tensor, _ = transform(image, None)  # 3, h, w
    return tensor


def _boxes_to_pixel_xyxy(boxes, width, height):
    """Convert normalised (cx, cy, w, h) boxes to pixel (x0, y0, x1, y1) boxes."""
    # from 0..1 to 0..W, 0..H
    scaled = boxes * torch.Tensor([width, height, width, height])
    # from xywh to xyxy
    scaled[:, :2] = scaled[:, :2] - scaled[:, 2:] / 2
    scaled[:, 2:] = scaled[:, 2:] + scaled[:, :2]
    return scaled


def _find_entity_index(label, entities):
    """Return the index of the first entity named ``label`` (case-insensitive), else None."""
    wanted = label.lower()
    for idx, entity in enumerate(entities):
        if entity[0].lower() == wanted:
            return idx
    return None


def load_groundingdino_model(model_config_path, model_checkpoint_path):
    """Build a GroundingDINO model from a config file and load its checkpoint.

    Args:
        model_config_path: Path to the YAML config handed to ``CfgNode.load_cfg``.
        model_checkpoint_path: Path to a checkpoint whose ``"model"`` entry
            holds the state dict.

    Returns:
        The model in eval mode. Weights are loaded onto the CPU; the caller is
        responsible for moving the model to another device.
    """
    import gc

    # Close the config file deterministically instead of leaking the handle.
    with open(model_config_path, "r") as config_file:
        args = CN.load_cfg(config_file)
    model = build_groundingdino(args)
    # NOTE(security): torch.load unpickles the file; only load trusted checkpoints.
    checkpoint = torch.load(model_checkpoint_path, map_location="cpu")
    # strict=False tolerates key mismatches; the result is printed so they stay visible.
    load_res = model.load_state_dict(clean_state_dict(checkpoint["model"]), strict=False)
    print('loading GroundingDINO:', load_res)
    # Drop the checkpoint reference so the collection below can reclaim it.
    del checkpoint
    gc.collect()
    model.eval()
    return model


class GroundingModule(nn.Module):
    """Text-prompted detection (GroundingDINO) followed by box-prompted segmentation (SAM)."""

    def __init__(self, device='cpu',
                 sam_checkpoint="./checkpoints/sam_vit_h_4b8939.pth",
                 groundingdino_checkpoint="./checkpoints/groundingdino_swint_ogc.pth",
                 groundingdino_config_file="./eval_configs/GroundingDINO_SwinT_OGC.yaml"):
        """Load both models onto ``device`` in eval mode.

        Args:
            device: Device both models are moved to.
            sam_checkpoint: Path to the SAM checkpoint.
            groundingdino_checkpoint: Path to the GroundingDINO checkpoint.
            groundingdino_config_file: Path to the GroundingDINO YAML config.
        """
        super().__init__()
        self.device = device

        self.grounding_model = load_groundingdino_model(
            groundingdino_config_file, groundingdino_checkpoint).to(device)
        self.grounding_model.eval()

        sam = build_sam(checkpoint=sam_checkpoint).to(device)
        sam.eval()
        self.sam_predictor = SamPredictor(sam)

    def _ground_boxes(self, original_image, prompt, box_threshold, text_threshold):
        """Run GroundingDINO on the image and return ``(boxes, logits, phrases)``."""
        prompt = _normalize_prompt(prompt)
        image_tensor = _to_grounding_tensor(original_image)
        print('==> Box grounding with "{}"...'.format(prompt))
        with torch.cuda.amp.autocast(enabled=True):
            return predict(self.grounding_model, image_tensor, prompt,
                           box_threshold, text_threshold, device=self.device)

    def _segment_boxes(self, image_np, boxes_xyxy):
        """Run SAM on pixel-space boxes and return hole-filled masks as a [N, 1, h, w] tensor."""
        print('==> Mask grounding...')
        self.sam_predictor.set_image(image_np)
        transformed_boxes = self.sam_predictor.transform.apply_boxes_torch(
            boxes_xyxy, image_np.shape[:2]).to(self.device)
        with torch.cuda.amp.autocast(enabled=True):
            masks, _, _ = self.sam_predictor.predict_torch(
                point_coords=None,
                point_labels=None,
                boxes=transformed_boxes,
                multimask_output=False,
            )
        # Fill small holes (mode="holes", area threshold 400) in every mask.
        fine_masks = [
            remove_small_regions(mask[0], 400, mode="holes")[0]
            for mask in masks.to('cpu').numpy()  # masks: [num_masks, 1, h, w]
        ]
        return torch.from_numpy(np.stack(fine_masks, axis=0)[:, np.newaxis])

    @torch.no_grad()
    def prompt2mask(self, original_image, prompt, state, box_threshold=0.35, text_threshold=0.25, num_boxes=10):
        """Ground ``prompt`` in the image and render boxes, labels and masks.

        Args:
            original_image: PIL image; it is copied, never modified.
                Presumably RGB, since it is blended with RGB overlays - confirm.
            prompt: Text prompt for GroundingDINO.
            state: Mapping whose ``'entity'`` list holds sequences with the
                entity name at index 0 and its colour at index 3. The colour
                is presumably an RGB tuple (it is divided by 255) - confirm.
            box_threshold: Box threshold forwarded to ``predict``.
            text_threshold: Text threshold forwarded to ``predict``.
            num_boxes: Maximum number of detections whose masks are rendered.

        Returns:
            ``(draw_img, local_results)``: the annotated full image, and one
            image per entry of ``state['entity']`` annotated only with the
            detections matched to that entity.
        """
        image_np = np.array(original_image, dtype=np.uint8)
        boxes, logits, phrases = self._ground_boxes(original_image, prompt, box_threshold, text_threshold)
        print(phrases)
        W, H = original_image.size

        draw_img = original_image.copy()
        draw = ImageDraw.Draw(draw_img)
        font = _load_label_font(H)  # loaded once, not once per box
        entities = state['entity']
        local_results = [original_image.copy() for _ in entities]
        local2entity = {}  # detection index -> entity index, matched detections only
        color_boxes = []
        boxes_xyxy = _boxes_to_pixel_xyxy(boxes, W, H)

        for obj_ind, (box, label) in enumerate(zip(boxes_xyxy, phrases)):
            entity_idx = _find_entity_index(label, entities)
            if entity_idx is None:
                print('Color not found', label)
                color = _UNMATCHED_COLOR  # In grey mode.
            else:
                local2entity[obj_ind] = entity_idx
                color = entities[entity_idx][3]
            color_boxes.append(color)
            print(color_boxes)

            pixel_box = tuple(int(v) for v in box)
            _draw_labelled_box(draw, pixel_box, label, color, font)
            if entity_idx is not None:
                local_draw = ImageDraw.Draw(local_results[entity_idx])
                _draw_labelled_box(local_draw, pixel_box, label, color, font)

        if boxes.size(0) > 0:
            masks = self._segment_boxes(image_np, boxes_xyxy)
            num_obj = min(len(logits), num_boxes)
            # Skip the overlay entirely when no mask is to be rendered.
            if num_obj > 0:
                mask_h, mask_w = masks.shape[-2:]
                full_img = np.zeros((mask_h, mask_w, 3))
                for obj_ind in range(num_obj):
                    m = masks[obj_ind][0]
                    color_mask = np.array(color_boxes[obj_ind]) / 255.0
                    full_img[m != 0] = color_mask

                    # Unmatched detections have no per-entity image to update.
                    entity_idx = local2entity.get(obj_ind)
                    if entity_idx is None:
                        continue
                    local_image = np.zeros((mask_h, mask_w, 3))
                    local_image[m != 0] = color_mask
                    local_image = PIL.Image.fromarray((local_image * 255).astype(np.uint8))
                    local_results[entity_idx] = PIL.Image.blend(local_results[entity_idx], local_image, 0.5)

                full_img = PIL.Image.fromarray((full_img * 255).astype(np.uint8))
                draw_img = PIL.Image.blend(draw_img, full_img, 0.5)
        return draw_img, local_results

    def draw(self, original_image, entity_state, item=None):
        """Render stored grounding results for one item, or for all items.

        Args:
            original_image: PIL image; it is copied, never modified.
            entity_state: Mapping with ``'match_state'`` (item -> entity key)
                and ``'grounding'['local']`` (entity key -> dict with
                ``'color'``, ``'entity_positions'``, ``'text_positions'`` and
                ``'masks'``).
            item: Key of ``entity_state['match_state']`` to draw; ``None``
                draws every key.

        Returns:
            A new PIL image with boxes, labels and half-blended mask colours.
        """
        original_image = original_image.copy()
        W, H = original_image.width, original_image.height
        short_side = min(H, W)
        font = _load_label_font(short_side)
        # Clamp so the outline does not vanish on very small images.
        outline_width = max(1, int(short_side / 512.0 * 10))
        local_image = np.zeros((H, W, 3))
        local_mask = np.zeros((H, W), dtype=bool)
        canvas = ImageDraw.Draw(original_image)

        def draw_item(name):
            """Draw boxes and labels for ``name`` and accumulate its masks."""
            entity = entity_state['match_state'][name]
            ei = entity_state['grounding']['local'][entity]
            color = ei['color']
            dark = DARKER_COLOR_MAP[color]
            text = str(name)
            for x0, y0, x1, y1 in ei['entity_positions']:
                canvas.rectangle([x0, y0, x1, y1], outline=dark, width=outline_width)
            for x0, y0 in ei['text_positions']:
                canvas.rectangle(_text_background_box(canvas, (x0, y0), text, font), fill=dark)
                canvas.text((x0, y0), text, fill="white", font=font)
            mask_fill = np.array(LIGHTER_COLOR_MAP[color]) / 255.0
            for m in ei['masks']:
                covered = np.asarray(m) != 0
                local_image[covered] = mask_fill
                # In-place update, so no ``nonlocal`` rebinding is needed.
                np.logical_or(local_mask, covered, out=local_mask)

        names = entity_state['match_state'] if item is None else [item]
        for name in names:
            draw_item(name)

        # Outside the masks keep the annotated image, so that blending only
        # tints the masked regions.
        uncovered = ~local_mask
        local_image[uncovered] = (np.array(original_image) / 255.0)[uncovered]
        local_image = PIL.Image.fromarray((local_image * 255).astype(np.uint8))
        return PIL.Image.blend(original_image, local_image, 0.5)

    @torch.no_grad()
    def prompt2mask2(self, original_image, prompt, state, box_threshold=0.25, text_threshold=0.2, iou_threshold=0.5, num_boxes=10):
        """Ground ``prompt`` in the image and collect per-phrase boxes and masks.

        Args:
            original_image: PIL image; it is copied, never modified.
            prompt: Text prompt for GroundingDINO.
            state: Currently unused.
            box_threshold: Box threshold forwarded to ``predict``.
            text_threshold: Text threshold forwarded to ``predict``.
            iou_threshold: Currently unused; NMS over the detected boxes
                (``torchvision.ops.nms``) is disabled.
            num_boxes: Currently unused.

        Returns:
            ``(draw_img, entity_dict)``: the annotated image, and a dict
            mapping each detected phrase to ``'color'`` (an entry of
            ``COLORS``), ``'text_positions'`` ``[(x0, y0), ...]``,
            ``'entity_positions'`` ``[(x0, y0, x1, y1), ...]`` and ``'masks'``
            (one 2-D tensor per detection of that phrase).
        """
        image_np = np.array(original_image, dtype=np.uint8)
        boxes, logits, phrases = self._ground_boxes(original_image, prompt, box_threshold, text_threshold)
        print('==> Box grounding results {}...'.format(phrases))
        W, H = original_image.size

        draw_img = original_image.copy()
        draw = ImageDraw.Draw(draw_img)
        font = _load_label_font(min(H, W))  # loaded once, not once per box
        entity_dict = {}
        boxes_xyxy = _boxes_to_pixel_xyxy(boxes, W, H)

        for box, label in zip(boxes_xyxy, phrases):
            if label not in entity_dict:
                entity_dict[label] = {
                    # NOTE(review): the modulus len(COLORS) - 1 never selects
                    # the last entry of COLORS - confirm this is intentional.
                    'color': COLORS[len(entity_dict) % (len(COLORS) - 1)],
                    'text_positions': [],
                    'entity_positions': [],
                    'masks': [],
                }
            entry = entity_dict[label]
            x0, y0, x1, y1 = (int(v) for v in box)
            _draw_labelled_box(draw, (x0, y0, x1, y1), label, DARKER_COLOR_MAP[entry['color']], font)
            entry['text_positions'].append((x0, y0))
            entry['entity_positions'].append((x0, y0, x1, y1))

        if boxes.size(0) > 0:
            masks = self._segment_boxes(image_np, boxes_xyxy)
            mask_h, mask_w = masks.shape[-2:]
            full_img = np.zeros((mask_h, mask_w, 3))
            for obj_ind, label in enumerate(phrases):
                m = masks[obj_ind][0]
                entry = entity_dict[label]
                full_img[m != 0] = np.array(LIGHTER_COLOR_MAP[entry['color']]) / 255.0
                entry['masks'].append(m)
            full_img = PIL.Image.fromarray((full_img * 255).astype(np.uint8))
            draw_img = PIL.Image.blend(draw_img, full_img, 0.5)

        print('==> Entity list: {}'.format(list(entity_dict.keys())))
        return draw_img, entity_dict