# Copyright (c) OpenMMLab. All rights reserved.
import copy
import logging
import os.path as osp
import pickle
from typing import List, Union

import h5py
import tqdm
from mmdet.datasets.api_wrappers import COCO
from mmdet.datasets.base_det_dataset import BaseDetDataset
from mmdet.registry import DATASETS
from mmengine.fileio import get_local_path
from mmengine.logging import print_log


@DATASETS.register_module()
class MASADataset(BaseDetDataset):
    """COCO-format detection dataset with a single class-agnostic category.

    Annotations are loaded from a COCO json (``ann_file``). Optionally,
    :meth:`get_ann_info` can instead read an image's annotations from an HDF5
    file (``anno_hdf5_path``), where each image's annotation list is stored
    as pickled bytes under a key derived from its file name (``.jpg``
    replaced by ``.pkl``).

    Args:
        anno_hdf5_path (str, optional): Path to an HDF5 file holding pickled
            per-image annotation lists. Defaults to None.
        img_prefix (str, optional): Stored on the instance; not read by any
            method in this class. Defaults to None.
        *args, **kwargs: Forwarded to :class:`BaseDetDataset`.
    """

    METAINFO = {
        # The trailing comma matters: ``("object")`` is just the string
        # "object", whose len() is 6 and for which ``name in classes`` is a
        # substring test rather than a class-name match.
        "classes": ("object",),
        # palette is a list of color tuples, which is used for visualization.
        "palette": [(220, 20, 60)],
    }
    COCOAPI = COCO
    # ann_id is unique in coco dataset.
    ANN_ID_UNIQUE = True

    def __init__(self, anno_hdf5_path=None, img_prefix=None, *args, **kwargs):
        # Set our attributes before the base initializer runs, since it may
        # already load/parse the dataset.
        self.anno_hdf5_path = anno_hdf5_path
        self.img_prefix = img_prefix
        super().__init__(*args, **kwargs)

    def read_dicts_from_hdf5(self, hdf5_file_path, pkl_file_path):
        """Load a pickled object (a list of annotation dicts) from HDF5.

        Args:
            hdf5_file_path (str): Path to the HDF5 file.
            pkl_file_path (str): Name of the dataset inside the HDF5 file.

        Returns:
            The unpickled contents of that dataset, presumably a list of
            annotation dicts.
        """
        with h5py.File(hdf5_file_path, "r") as hf:
            # Read the whole dataset value into memory.
            binary_data = hf[pkl_file_path][()]
            # SECURITY: pickle.loads can execute arbitrary code embedded in
            # the payload; only use HDF5 files from a trusted source.
            return pickle.loads(binary_data)

    def get_ann_info(self, img_info):
        """Get the raw annotations of one image.

        Args:
            img_info (dict): Image info. Must contain ``"file_name"`` when
                ``anno_hdf5_path`` is set, otherwise ``"id"``.

        Returns:
            list[dict] | None: The image's annotation dicts, or None when
            reading them from the HDF5 file fails (the failure is logged).
        """
        if self.anno_hdf5_path is not None:
            pkl_key = img_info["file_name"].replace(".jpg", ".pkl")
            try:
                return self.read_dicts_from_hdf5(self.anno_hdf5_path, pkl_key)
            except Exception as err:  # best-effort: a bad entry is not fatal
                print_log(
                    f"Failed to read annotations '{pkl_key}' from "
                    f"'{self.anno_hdf5_path}': {err!r}",
                    logger="current",
                    level=logging.WARNING,
                )
                return None
        # NOTE(review): load_data_list() deletes ``self.coco`` once the data
        # list is built, so this branch raises AttributeError after
        # initialization; confirm whether it is still reachable.
        img_id = img_info["id"]
        ann_ids = self.coco.get_ann_ids(img_ids=[img_id], cat_ids=self.cat_ids)
        return self.coco.load_anns(ann_ids)

    def __getitem__(self, idx: int) -> dict:
        """Get the idx-th image and data information of dataset after
        ``self.pipeline``, and ``full_init`` will be called if the dataset has
        not been fully initialized.

        During training phase, if ``self.pipeline`` get ``None``,
        ``self._rand_another`` will be called until a valid image is fetched or
        the maximum limit of refetech is reached. Exceptions raised while
        preparing a sample are treated the same way as a ``None`` result.

        Args:
            idx (int): The index of self.data_list.

        Returns:
            dict: The idx-th image and data information of dataset after
            ``self.pipeline``.

        Raises:
            Exception: In test mode if the pipeline returns ``None``, or in
                training mode if no valid sample is found within
                ``max_refetch + 1`` attempts.
        """
        # Performing full initialization by calling `__getitem__` will consume
        # extra memory. If a dataset is not fully initialized by setting
        # `lazy_init=True` and then fed into the dataloader. Different workers
        # will simultaneously read and parse the annotation. It will cost more
        # time and memory, although this may work. Therefore, it is recommended
        # to manually call `full_init` before dataset fed into dataloader to
        # ensure all workers use shared RAM from master process.
        if not self._fully_initialized:
            print_log(
                "Please call `full_init()` method manually to accelerate "
                "the speed.",
                logger="current",
                level=logging.WARNING,
            )
            self.full_init()

        if self.test_mode:
            data = self.prepare_data(idx)
            if data is None:
                raise Exception(
                    "Test time pipline should not get `None` " "data_sample"
                )
            return data

        for _ in range(self.max_refetch + 1):
            try:
                data = self.prepare_data(idx)
            except Exception as err:
                # Deliberately best-effort: a failing sample (e.g. a corrupt
                # image) is skipped and another one is drawn below.
                print_log(
                    f"Failed to prepare sample {idx}: {err!r}; refetching.",
                    logger="current",
                    level=logging.DEBUG,
                )
                data = None
            # Broken images or random augmentations may cause the returned data
            # to be None
            if data is None:
                idx = self._rand_another()
                continue
            return data

        raise Exception(
            f"Cannot find valid image after {self.max_refetch}! "
            "Please check your image path and pipeline"
        )

    def load_data_list(self) -> List[dict]:
        """Load annotations from an annotation file named as ``self.ann_file``.

        Side effects: sets ``self.cat_ids``, ``self.cat2label`` and
        ``self.cat_img_map``; the temporary ``self.coco`` API object is
        deleted before returning.

        Returns:
            List[dict]: One parsed data info per image, as produced by
            :meth:`parse_data_info`.
        """  # noqa: E501
        with get_local_path(
            self.ann_file, backend_args=self.backend_args
        ) as local_path:
            self.coco = self.COCOAPI(local_path)
        # The order of returned `cat_ids` will not
        # change with the order of the `classes`
        self.cat_ids = self.coco.get_cat_ids(cat_names=self.metainfo["classes"])
        self.cat2label = {cat_id: i for i, cat_id in enumerate(self.cat_ids)}
        self.cat_img_map = copy.deepcopy(self.coco.cat_img_map)

        img_ids = self.coco.get_img_ids()
        data_list = []
        total_ann_ids = []
        print_log("Loading data list...", logger="current")
        for img_id in tqdm.tqdm(img_ids):
            raw_img_info = self.coco.load_imgs([img_id])[0]
            raw_img_info["img_id"] = img_id

            # All annotations of the image; category filtering happens in
            # parse_data_info.
            ann_ids = self.coco.get_ann_ids(img_ids=[img_id])
            raw_ann_info = self.coco.load_anns(ann_ids)
            total_ann_ids.extend(ann_ids)

            parsed_data_info = self.parse_data_info(
                {"raw_ann_info": raw_ann_info, "raw_img_info": raw_img_info}
            )
            data_list.append(parsed_data_info)
        if self.ANN_ID_UNIQUE:
            assert len(set(total_ann_ids)) == len(
                total_ann_ids
            ), f"Annotation ids in '{self.ann_file}' are not unique!"

        # Release the COCO index; everything needed is now in data_list.
        del self.coco

        return data_list

    def parse_data_info(self, raw_data_info: dict) -> Union[dict, List[dict]]:
        """Parse raw annotation to target format.

        Annotations are dropped when flagged ``ignore``, when the box does not
        overlap the image, when ``area <= 0`` or the box is thinner than one
        pixel, or when the category is not one of ``self.cat_ids``.
        Annotations without ``category_id`` are treated as category 1.

        Args:
            raw_data_info (dict): Raw data information load from ``ann_file``,
                with keys ``"raw_img_info"`` and ``"raw_ann_info"``.

        Returns:
            Union[dict, List[dict]]: Parsed annotation.
        """
        img_info = raw_data_info["raw_img_info"]
        ann_info = raw_data_info["raw_ann_info"]

        data_info = {}

        # TODO: need to change data_prefix['img'] to data_prefix['img_path']
        img_path = osp.join(self.data_prefix["img"], img_info["file_name"])
        if self.data_prefix.get("seg", None):
            seg_map_path = osp.join(
                self.data_prefix["seg"],
                img_info["file_name"].rsplit(".", 1)[0] + self.seg_map_suffix,
            )
        else:
            seg_map_path = None
        data_info["img_path"] = img_path
        data_info["img_id"] = img_info["img_id"]
        data_info["seg_map_path"] = seg_map_path
        data_info["height"] = img_info["height"]
        data_info["width"] = img_info["width"]

        if self.return_classes:
            data_info["text"] = self.metainfo["classes"]
            data_info["caption_prompt"] = self.caption_prompt
            data_info["custom_entities"] = True

        instances = []
        for ann in ann_info:
            if ann.get("ignore", False):
                continue
            x1, y1, w, h = ann["bbox"]
            # Overlap of the xywh box with the image; zero means fully outside.
            inter_w = max(0, min(x1 + w, img_info["width"]) - max(x1, 0))
            inter_h = max(0, min(y1 + h, img_info["height"]) - max(y1, 0))
            if inter_w * inter_h == 0:
                continue
            if ann["area"] <= 0 or w < 1 or h < 1:
                continue
            # Default without mutating the caller's annotation dict.
            category_id = ann.get("category_id", 1)
            # cat2label's keys are exactly self.cat_ids (O(1) lookup).
            if category_id not in self.cat2label:
                continue

            instance = {
                "ignore_flag": 1 if ann.get("iscrowd", False) else 0,
                # Convert xywh to xyxy.
                "bbox": [x1, y1, x1 + w, y1 + h],
                "bbox_label": self.cat2label[category_id],
            }
            if ann.get("segmentation", None):
                instance["mask"] = ann["segmentation"]
            if "instance_id" in ann:
                instance["instance_id"] = ann["instance_id"]
            else:
                instance["instance_id"] = ann["id"]
            instances.append(instance)
        data_info["instances"] = instances
        return data_info

    def filter_data(self) -> List[dict]:
        """Filter annotations according to filter_cfg.

        In test mode, or without ``filter_cfg``, the data list is returned
        unchanged. Otherwise images smaller than ``min_size`` on either side
        are dropped, and, with ``filter_empty_gt``, so are images that have
        no annotation of a selected category.

        Returns:
            List[dict]: Filtered results.
        """
        if self.test_mode:
            return self.data_list

        if self.filter_cfg is None:
            return self.data_list

        filter_empty_gt = self.filter_cfg.get("filter_empty_gt", False)
        min_size = self.filter_cfg.get("min_size", 0)

        # Despite the name, this is the id of every loaded image:
        # parse_data_info keeps images even when they have no instances.
        ids_with_ann = {data_info["img_id"] for data_info in self.data_list}
        # obtain images that contain annotations of the required categories
        ids_in_cat = set()
        for class_id in self.cat_ids:
            ids_in_cat.update(self.cat_img_map[class_id])
        # merge the image id sets of the two conditions and use the merged set
        # to filter out images if self.filter_empty_gt=True
        ids_in_cat &= ids_with_ann

        valid_data_infos = []
        for data_info in self.data_list:
            if filter_empty_gt and data_info["img_id"] not in ids_in_cat:
                continue
            if min(data_info["width"], data_info["height"]) >= min_size:
                valid_data_infos.append(data_info)

        return valid_data_infos