Spaces:
Running
on
Zero
Running
on
Zero
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import os.path as osp | |
import xml.etree.ElementTree as ET | |
from typing import List, Optional, Union | |
import mmcv | |
from mmengine.fileio import get, get_local_path, list_from_file | |
from mmdet.registry import DATASETS | |
from .base_det_dataset import BaseDetDataset | |
# NOTE(review): `DATASETS` is imported at the top of this file but never used;
# a `@DATASETS.register_module()` decorator may have been dropped from this
# class -- confirm before relying on registry-based construction.
class XMLDataset(BaseDetDataset):
    """XML dataset for detection.

    Every entry of ``ann_file`` is treated as an image id. The image is
    looked up at ``<sub_data_root>/<img_subdir>/<img_id>.jpg`` and its
    annotation at ``<sub_data_root>/<ann_subdir>/<img_id>.xml``.

    Args:
        img_subdir (str): Subdir where images are stored. Default: JPEGImages.
        ann_subdir (str): Subdir where annotations are. Default: Annotations.
        backend_args (dict, optional): Arguments to instantiate the
            corresponding backend. Defaults to None. Not a named parameter
            here; it reaches the base class through ``**kwargs``.
    """

    def __init__(self,
                 img_subdir: str = 'JPEGImages',
                 ann_subdir: str = 'Annotations',
                 **kwargs) -> None:
        # These must be set before the base initialiser runs, because
        # `load_data_list` reads them.
        # NOTE(review): presumably the base `__init__` may trigger loading
        # -- confirm against BaseDetDataset.
        self.img_subdir = img_subdir
        self.ann_subdir = ann_subdir
        super().__init__(**kwargs)

    @property
    def sub_data_root(self) -> str:
        """Return the sub data root.

        Read from ``data_prefix['sub_data_root']``; an empty string is
        returned when the key is absent.
        """
        return self.data_prefix.get('sub_data_root', '')

    def load_data_list(self) -> List[dict]:
        """Load annotation from XML style ann_file.

        Also builds ``self.cat2label``, mapping each class name in
        ``metainfo['classes']`` to its index.

        Returns:
            list[dict]: Annotation info from XML file.
        """
        assert self._metainfo.get('classes', None) is not None, \
            '`classes` in `XMLDataset` can not be None.'
        self.cat2label = {
            cat: i
            for i, cat in enumerate(self._metainfo['classes'])
        }

        data_list = []
        img_ids = list_from_file(self.ann_file, backend_args=self.backend_args)
        for img_id in img_ids:
            raw_img_info = {
                'img_id': img_id,
                # `file_name` is relative to `sub_data_root`; the root is
                # prepended later in `parse_data_info`.
                'file_name': osp.join(self.img_subdir, f'{img_id}.jpg'),
                'xml_path': osp.join(self.sub_data_root, self.ann_subdir,
                                     f'{img_id}.xml'),
            }
            data_list.append(self.parse_data_info(raw_img_info))
        return data_list

    @property
    def bbox_min_size(self) -> Optional[int]:
        """Return the minimum size of bounding boxes in the images.

        Taken from ``filter_cfg['bbox_min_size']``; ``None`` when there is
        no ``filter_cfg`` or the key is not set.
        """
        if self.filter_cfg is not None:
            return self.filter_cfg.get('bbox_min_size', None)
        return None

    def parse_data_info(self, img_info: dict) -> Union[dict, List[dict]]:
        """Parse raw annotation to target format.

        Args:
            img_info (dict): Raw image information, usually it includes
                `img_id`, `file_name`, and `xml_path`.

        Returns:
            Union[dict, List[dict]]: Parsed annotation with the keys
            `img_path`, `img_id`, `xml_path`, `height`, `width` and
            `instances`.
        """
        data_info = {}
        img_path = osp.join(self.sub_data_root, img_info['file_name'])
        data_info['img_path'] = img_path
        data_info['img_id'] = img_info['img_id']
        data_info['xml_path'] = img_info['xml_path']

        # deal with xml file
        with get_local_path(
                img_info['xml_path'],
                backend_args=self.backend_args) as local_path:
            raw_ann_info = ET.parse(local_path)
        root = raw_ann_info.getroot()
        size = root.find('size')
        if size is not None:
            width = int(size.find('width').text)
            height = int(size.find('height').text)
        else:
            # No <size> element in the XML: decode the image itself to get
            # its dimensions.
            img_bytes = get(img_path, backend_args=self.backend_args)
            img = mmcv.imfrombytes(img_bytes, backend='cv2')
            height, width = img.shape[:2]
            del img, img_bytes

        data_info['height'] = height
        data_info['width'] = width

        data_info['instances'] = self._parse_instance_info(
            raw_ann_info, minus_one=True)

        return data_info

    def _parse_instance_info(self,
                             raw_ann_info: ET.ElementTree,
                             minus_one: bool = True) -> List[dict]:
        """Parse instance information.

        Objects whose name is not in ``metainfo['classes']`` are skipped.
        An instance gets ``ignore_flag = 1`` when it is marked difficult or
        when either side of its box is smaller than ``bbox_min_size``.

        Args:
            raw_ann_info (ElementTree): ElementTree object.
            minus_one (bool): Whether to subtract 1 from the coordinates.
                Defaults to True.

        Returns:
            List[dict]: List of instances, each with `ignore_flag`, `bbox`
            (``[xmin, ymin, xmax, ymax]``) and `bbox_label`.
        """
        bbox_min_size = self.bbox_min_size
        instances = []
        for obj in raw_ann_info.findall('object'):
            name = obj.find('name').text
            if name not in self._metainfo['classes']:
                continue
            difficult = obj.find('difficult')
            difficult = 0 if difficult is None else int(difficult.text)
            bnd_box = obj.find('bndbox')
            # Coordinates may be written as floats, so go through float()
            # before truncating to int.
            bbox = [
                int(float(bnd_box.find(tag).text))
                for tag in ('xmin', 'ymin', 'xmax', 'ymax')
            ]

            # VOC needs to subtract 1 from the coordinates
            if minus_one:
                bbox = [x - 1 for x in bbox]

            ignore = False
            if bbox_min_size is not None:
                # Size-based ignoring is only allowed outside test mode.
                assert not self.test_mode
                w = bbox[2] - bbox[0]
                h = bbox[3] - bbox[1]
                if w < bbox_min_size or h < bbox_min_size:
                    ignore = True

            instances.append({
                'ignore_flag': 1 if difficult or ignore else 0,
                'bbox': bbox,
                'bbox_label': self.cat2label[name],
            })
        return instances

    def filter_data(self) -> List[dict]:
        """Filter annotations according to filter_cfg.

        In test mode the data list is returned unfiltered. Otherwise an
        image is dropped when ``filter_empty_gt`` is set and it has no
        instances, or when its shorter side is below ``min_size``.

        Returns:
            List[dict]: Filtered results.
        """
        if self.test_mode:
            return self.data_list

        filter_cfg = self.filter_cfg if self.filter_cfg is not None else {}
        filter_empty_gt = filter_cfg.get('filter_empty_gt', False)
        min_size = filter_cfg.get('min_size', 0)

        valid_data_infos = []
        for data_info in self.data_list:
            width = data_info['width']
            height = data_info['height']
            if filter_empty_gt and not data_info['instances']:
                continue
            if min(width, height) >= min_size:
                valid_data_infos.append(data_info)

        return valid_data_infos