Spaces:
Running
on
Zero
Running
on
Zero
File size: 5,736 Bytes
28c256d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 |
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import collections
import os.path as osp
import random
from typing import Dict, List
import mmengine
from mmengine.dataset import BaseDataset
from mmdet.registry import DATASETS
@DATASETS.register_module()
class RefCocoDataset(BaseDataset):
"""RefCOCO dataset.
The `Refcoco` and `Refcoco+` dataset is based on
`ReferItGame: Referring to Objects in Photographs of Natural Scenes
<http://tamaraberg.com/papers/referit.pdf>`_.
The `Refcocog` dataset is based on
`Generation and Comprehension of Unambiguous Object Descriptions
<https://arxiv.org/abs/1511.02283>`_.
Args:
ann_file (str): Annotation file path.
data_root (str): The root directory for ``data_prefix`` and
``ann_file``. Defaults to ''.
data_prefix (str): Prefix for training data.
split_file (str): Split file path.
split (str): Split name. Defaults to 'train'.
text_mode (str): Text mode. Defaults to 'random'.
**kwargs: Other keyword arguments in :class:`BaseDataset`.
"""
def __init__(self,
data_root: str,
ann_file: str,
split_file: str,
data_prefix: Dict,
split: str = 'train',
text_mode: str = 'random',
**kwargs):
self.split_file = split_file
self.split = split
assert text_mode in ['original', 'random', 'concat', 'select_first']
self.text_mode = text_mode
super().__init__(
data_root=data_root,
data_prefix=data_prefix,
ann_file=ann_file,
**kwargs,
)
def _join_prefix(self):
if not mmengine.is_abs(self.split_file) and self.split_file:
self.split_file = osp.join(self.data_root, self.split_file)
return super()._join_prefix()
def _init_refs(self):
"""Initialize the refs for RefCOCO."""
anns, imgs = {}, {}
for ann in self.instances['annotations']:
anns[ann['id']] = ann
for img in self.instances['images']:
imgs[img['id']] = img
refs, ref_to_ann = {}, {}
for ref in self.splits:
# ids
ref_id = ref['ref_id']
ann_id = ref['ann_id']
# add mapping related to ref
refs[ref_id] = ref
ref_to_ann[ref_id] = anns[ann_id]
self.refs = refs
self.ref_to_ann = ref_to_ann
def load_data_list(self) -> List[dict]:
"""Load data list."""
self.splits = mmengine.load(self.split_file, file_format='pkl')
self.instances = mmengine.load(self.ann_file, file_format='json')
self._init_refs()
img_prefix = self.data_prefix['img_path']
ref_ids = [
ref['ref_id'] for ref in self.splits if ref['split'] == self.split
]
full_anno = []
for ref_id in ref_ids:
ref = self.refs[ref_id]
ann = self.ref_to_ann[ref_id]
ann.update(ref)
full_anno.append(ann)
image_id_list = []
final_anno = {}
for anno in full_anno:
image_id_list.append(anno['image_id'])
final_anno[anno['ann_id']] = anno
annotations = [value for key, value in final_anno.items()]
coco_train_id = []
image_annot = {}
for i in range(len(self.instances['images'])):
coco_train_id.append(self.instances['images'][i]['id'])
image_annot[self.instances['images'][i]
['id']] = self.instances['images'][i]
images = []
for image_id in list(set(image_id_list)):
images += [image_annot[image_id]]
data_list = []
grounding_dict = collections.defaultdict(list)
for anno in annotations:
image_id = int(anno['image_id'])
grounding_dict[image_id].append(anno)
join_path = mmengine.fileio.get_file_backend(img_prefix).join_path
for image in images:
img_id = image['id']
instances = []
sentences = []
for grounding_anno in grounding_dict[img_id]:
texts = [x['raw'].lower() for x in grounding_anno['sentences']]
# random select one text
if self.text_mode == 'random':
idx = random.randint(0, len(texts) - 1)
text = [texts[idx]]
# concat all texts
elif self.text_mode == 'concat':
text = [''.join(texts)]
# select the first text
elif self.text_mode == 'select_first':
text = [texts[0]]
# use all texts
elif self.text_mode == 'original':
text = texts
else:
raise ValueError(f'Invalid text mode "{self.text_mode}".')
ins = [{
'mask': grounding_anno['segmentation'],
'ignore_flag': 0
}] * len(text)
instances.extend(ins)
sentences.extend(text)
data_info = {
'img_path': join_path(img_prefix, image['file_name']),
'img_id': img_id,
'instances': instances,
'text': sentences
}
data_list.append(data_info)
if len(data_list) == 0:
raise ValueError(f'No sample in split "{self.split}".')
return data_list
|