from segment_anything.utils.transforms import ResizeLongestSide from groundingdino.util.inference import load_image, load_model, predict from torchvision.ops import box_convert import numpy as np import torch from segment_anything import sam_model_registry from segment_anything.modeling import Sam import os import torchvision.transforms as T2 import groundingdino.datasets.transforms as T from PIL import Image def init_segmentation(device='cpu') -> Sam: # 1) first cd into the segment_anything and pip install -e . # to get the model stary in the root foler folder and run the download_model.sh # 2) chmod +x download_model.sh && ./download_model.sh # the largest model: https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth # this is the smallest model if os.path.exists('sam-hq/sam_hq_vit_b.pth'): sam_checkpoint = "sam-hq/sam_hq_vit_b.pth" model_type = "vit_b" else: sam_checkpoint = "sam-hq/sam_hq_vit_tiny.pth" model_type = "vit_tiny" print(f'SAM device: {device}, model_type: {model_type}') sam = sam_model_registry[model_type](checkpoint=sam_checkpoint) sam.to(device=device) return sam def find_ground(image_source, image, model, segmentor, device='cpu', TEXT_PROMPT="ground", BOX_TRESHOLD=0.35, TEXT_TRESHOLD=0.25): boxes, logits, _ = predict( model=model, image=image, caption=TEXT_PROMPT, box_threshold=BOX_TRESHOLD, text_threshold=TEXT_TRESHOLD, device=device ) if len(boxes) == 0: return None # only want box corresponding to max logit max_logit_idx = torch.argmax(logits) box = boxes[max_logit_idx].unsqueeze(0) _, h, w = image_source.shape box = box * torch.tensor([w, h, w, h], device=device) xyxy = box_convert(boxes=box, in_fmt="cxcywh", out_fmt="xyxy") image = image.unsqueeze(0) org_shape = image.shape[-2:] resize_transform = ResizeLongestSide(segmentor.image_encoder.img_size) batched_input = [] images = resize_transform.apply_image_torch(image*1.0)# .permute(2, 0, 1).contiguous() for image, boxes in zip(images, xyxy): transformed_boxes = resize_transform.apply_boxes_torch(boxes, org_shape) # Bx4 batched_input.append({'image': image, 'boxes': transformed_boxes, 'original_size':org_shape}) seg_out = segmentor(batched_input, multimask_output=False) mask_per_image = seg_out[0]['masks'] return mask_per_image[0,0,:,:].cpu().numpy() def load_image2(image:np.ndarray, device) -> tuple[torch.Tensor, torch.Tensor]: transform = T.Compose( [ # T.RandomResize([800], max_size=1333), T.ToTensor(), T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), ] ) transform2 = T2.ToTensor() image_source = Image.fromarray(image).convert("RGB") image = transform2(image_source).to(device) image_transformed, _ = transform(image_source, None) return image, image_transformed if __name__ == '__main__': import pandas as pd from matplotlib import pyplot as plt from tqdm import tqdm from cubercnn import data from detectron2.data.catalog import MetadataCatalog from priors import get_config_and_filter_settings import supervision as sv def init_dataset(): ''' dataloader stuff. currently not used anywhere, because I'm not sure what the difference between the omni3d dataset and load omni3D json functions are. this is a 3rd alternative to this. The train script calls something similar to this.''' cfg, filter_settings = get_config_and_filter_settings() dataset_names = ['SUNRGBD_train','SUNRGBD_val','SUNRGBD_test'] dataset_paths_to_json = ['datasets/Omni3D/'+dataset_name+'.json' for dataset_name in dataset_names] # for dataset_name in dataset_names: # simple_register(dataset_name, filter_settings, filter_empty=True) # Get Image and annotations datasets = data.Omni3D(dataset_paths_to_json, filter_settings=filter_settings) data.register_and_store_model_metadata(datasets, cfg.OUTPUT_DIR, filter_settings) thing_classes = MetadataCatalog.get('omni3d_model').thing_classes dataset_id_to_contiguous_id = MetadataCatalog.get('omni3d_model').thing_dataset_id_to_contiguous_id infos = datasets.dataset['info'] dataset_id_to_unknown_cats = {} possible_categories = set(i for i in range(cfg.MODEL.ROI_HEADS.NUM_CLASSES + 1)) dataset_id_to_src = {} for info in infos: dataset_id = info['id'] known_category_training_ids = set() if not dataset_id in dataset_id_to_src: dataset_id_to_src[dataset_id] = info['source'] for id in info['known_category_ids']: if id in dataset_id_to_contiguous_id: known_category_training_ids.add(dataset_id_to_contiguous_id[id]) # determine and store the unknown categories. unknown_categories = possible_categories - known_category_training_ids dataset_id_to_unknown_cats[dataset_id] = unknown_categories return datasets def load_image(image_path: str, device) -> tuple[torch.Tensor, torch.Tensor]: transform = T.Compose( [ # T.RandomResize([800], max_size=1333), T.ToTensor(), T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), ] ) transform2 = T2.ToTensor() image_source = Image.open(image_path).convert("RGB") image = transform2(image_source).to(device) image_transformed, _ = transform(image_source, None) return image, image_transformed.to(device) def annotate(image_source: np.ndarray, boxes: torch.Tensor, logits: torch.Tensor, phrases: list[str]) -> np.ndarray: """ This function annotates an image with bounding boxes and labels. Parameters: image_source (np.ndarray): The source image to be annotated. boxes (torch.Tensor): A tensor containing bounding box coordinates. logits (torch.Tensor): A tensor containing confidence scores for each bounding box. phrases (List[str]): A list of labels for each bounding box. Returns: np.ndarray: The annotated image. """ h, w, _ = image_source.shape boxes = boxes * torch.Tensor([w, h, w, h]) xyxy = box_convert(boxes=boxes, in_fmt="cxcywh", out_fmt="xyxy").numpy() detections = sv.Detections(xyxy=xyxy) labels = [ f"{phrase} {logit:.2f}" for phrase, logit in zip(phrases, logits) ] box_annotator = sv.BoxAnnotator() # annotated_frame = cv2.cvtColor(image_source, cv2.COLOR_RGB2BGR) annotated_frame = image_source.copy() annotated_frame = box_annotator.annotate(scene=annotated_frame, detections=detections, labels=labels) return annotated_frame # datasets = init_dataset() device = 'cuda' if torch.cuda.is_available() else 'cpu' # model.to(device) segmentor = init_segmentation(device=device) os.makedirs('datasets/ground_maps', exist_ok=True) model = load_model("GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py", "GroundingDINO/weights/groundingdino_swint_ogc.pth", device=device) TEXT_PROMPT = "ground" BOX_TRESHOLD = 0.35 TEXT_TRESHOLD = 0.25 noground = 0 no_ground_idx = [] # **** to annotate full dataset **** # for img_id, img_info in tqdm(datasets.imgs.items()): # file_path = img_info['file_path'] # width = img_info['width'] # height = img_info['height'] # **** to annotate full dataset **** # **** to annotate demo images **** for img_id in tqdm(os.listdir('datasets/coco_examples')): file_path = 'coco_examples/'+img_id image_source, image = load_image('datasets/'+file_path, device=device) # **** to annotate demo images **** boxes, logits, phrases = predict( model=model, image=image, caption=TEXT_PROMPT, box_threshold=BOX_TRESHOLD, text_threshold=TEXT_TRESHOLD, device=device ) if len(boxes) == 0: print(f"No ground found for {img_id}") noground += 1 # save a ground map that is all zeros no_ground_idx.append(img_id) continue # only want box corresponding to max logit max_logit_idx = torch.argmax(logits) logit = logits[max_logit_idx].unsqueeze(0) box = boxes[max_logit_idx].unsqueeze(0) phrase = [phrases[max_logit_idx]] _, h, w = image_source.shape box = box * torch.tensor([w, h, w, h], device=device) xyxy = box_convert(boxes=box, in_fmt="cxcywh", out_fmt="xyxy") image = image.unsqueeze(0) org_shape = image.shape[-2:] resize_transform = ResizeLongestSide(segmentor.image_encoder.img_size) batched_input = [] images = resize_transform.apply_image_torch(image*1.0)# .permute(2, 0, 1).contiguous() for image, boxes in zip(images, xyxy): transformed_boxes = resize_transform.apply_boxes_torch(boxes, org_shape) # Bx4 batched_input.append({'image': image, 'boxes': transformed_boxes, 'original_size':org_shape}) seg_out = segmentor(batched_input, multimask_output=False) mask_per_image = seg_out[0]['masks'] nnz = torch.count_nonzero(mask_per_image, dim=(-2, -1)) indices = torch.nonzero(nnz <= 1000).flatten() if len(indices) > 0: noground += 1 # save a ground map that is all zeros no_ground_idx.append(img_id) np.savez_compressed(f'datasets/ground_maps/{img_id}.npz', mask=mask_per_image.cpu()[0,0,:,:].numpy()) print(f"Could not find ground for {noground} images") df = pd.DataFrame(no_ground_idx, columns=['img_id']) df.to_csv('datasets/no_ground_idx.csv', index=False)