# Web-viewer residue (not part of the module): Spaces status "Sleeping"; File size: 7,760 Bytes; b34d1d6.
from abc import ABC
import logging
from typing import Sequence, Union, Optional, Tuple
from mmengine.dataset import ConcatDataset, RepeatDataset, ClassBalancedDataset
from mmengine.logging import print_log
from mmengine.registry import DATASETS
from mmengine.dataset.base_dataset import BaseDataset
from mmdet.structures import TrackDataSample
from seg.models.utils import NO_OBJ
@DATASETS.register_module()
class ConcatOVDataset(ConcatDataset, ABC):
    """Concatenate several datasets and merge their label spaces.

    On construction the class names of every sub-dataset are normalised and
    merged into one list of thing classes followed by one list of stuff
    classes. For each sub-dataset a ``mapper`` dict is stored in the metainfo
    that translates the sub-dataset's local label ids into ids of the merged
    list; ``__getitem__`` applies that mapper in place to the ground-truth
    labels of every returned sample.

    A sub-dataset is treated as class agnostic when its metainfo has no
    ``'classes'`` entry or when its ``data_tag`` is ``'sam'``; its mapper
    sends the label ids ``0 .. 999`` to ``-1``.

    Args:
        datasets: Dataset instances or config dicts to concatenate.
        lazy_init: Forwarded to the parent class and written into every
            config dict (and into the nested ``'dataset'`` config of dicts
            that carry a ``'times'`` key).
        data_tag: Optional tag per sub-dataset. When given it must have the
            same length as ``datasets``; the tag of the originating
            sub-dataset is attached to ``data_samples`` in ``__getitem__``.
    """

    _fully_initialized: bool = False

    # Number of local label ids that a class-agnostic dataset maps to -1.
    _NUM_AGNOSTIC_LABELS: int = 1000

    def __init__(self,
                 datasets: Sequence[Union[BaseDataset, dict]],
                 lazy_init: bool = False,
                 data_tag: Optional[Tuple[str, ...]] = None,
                 ):
        for dataset in datasets:
            if not isinstance(dataset, dict):
                continue
            dataset.update(lazy_init=lazy_init)
            # Configs carrying 'times' nest the wrapped dataset under
            # 'dataset'. Only a config dict can take the flag; an already
            # built dataset object (or a missing key) is left alone instead
            # of raising AttributeError / KeyError here.
            wrapped = dataset.get('dataset')
            if 'times' in dataset and isinstance(wrapped, dict):
                wrapped.update(lazy_init=lazy_init)

        super().__init__(
            datasets,
            lazy_init=lazy_init,
            ignore_keys=['classes', 'thing_classes', 'stuff_classes', 'palette'])

        self.data_tag = data_tag
        if self.data_tag is not None:
            assert len(self.data_tag) == len(datasets), \
                'data_tag must provide exactly one tag per dataset'

        cls_names = [self._display_name(dataset) for dataset in self.datasets]

        # Merged vocabularies plus, per dataset, {local index -> merged index}.
        thing_classes = []
        thing_mapper = []
        stuff_classes = []
        stuff_mapper = []
        for idx, dataset in enumerate(self.datasets):
            # Read the metainfo once instead of once per key lookup.
            metainfo = dataset.metainfo
            tagged_agnostic = (self.data_tag is not None
                               and self.data_tag[idx] in ['sam'])
            if 'classes' not in metainfo or tagged_agnostic:
                # class agnostic dataset
                thing_mapper.append({})
                stuff_mapper.append({})
                continue

            # Without an explicit thing/stuff split every class is a thing.
            local_things = metainfo.get('thing_classes', metainfo['classes'])
            local_stuff = metainfo.get('stuff_classes', [])
            thing_mapper.append(self._merge_classes(local_things, thing_classes))
            stuff_mapper.append(self._merge_classes(local_stuff, stuff_classes))

        classes = [*thing_classes, *stuff_classes]
        mapper = []
        meta_cls_names = []
        for dataset_idx, (things_map, stuff_map) in enumerate(
                zip(thing_mapper, stuff_mapper)):
            if not things_map and not stuff_map:
                # class agnostic dataset
                label_map = {
                    label: -1 for label in range(self._NUM_AGNOSTIC_LABELS)
                }
            else:
                label_map = dict(things_map)
                num_local_things = len(things_map)
                for local_idx, merged_idx in stuff_map.items():
                    assert merged_idx < len(stuff_classes)
                    # Stuff ids follow the thing ids, both locally and in
                    # the merged `classes` list.
                    label_map[local_idx + num_local_things] = \
                        merged_idx + len(thing_classes)
                assert len(label_map) == len(things_map) + len(stuff_map)
                # Only class-aware datasets contribute to the combined name.
                meta_cls_names.append(cls_names[dataset_idx])
            label_map[NO_OBJ] = NO_OBJ
            mapper.append(label_map)

        joined_names = "_".join(meta_cls_names)
        if len(meta_cls_names) > 1:
            joined_names = "Concat_" + joined_names
        self.dataset_name = joined_names

        self._metainfo.update({
            'classes': classes,
            'thing_classes': thing_classes,
            'stuff_classes': stuff_classes,
            'mapper': mapper,
            'dataset_names': meta_cls_names
        })

        print_log(
            f"------------{self.dataset_name}------------",
            logger='current',
            level=logging.INFO
        )
        for idx, dataset in enumerate(self.datasets):
            dataset_type = cls_names[idx]
            times = dataset.times if isinstance(dataset, RepeatDataset) else 1
            # NOTE(review): len() may force full initialisation of a lazily
            # built dataset -- confirm against the mmengine version in use.
            print_log(
                f"|---dataset#{idx + 1} --> name: {dataset_type}; length: {len(dataset)}; repeat times: {times}",
                logger='current',
                level=logging.INFO
            )
        print_log(
            f"------num_things : {len(thing_classes)}; num_stuff : {len(stuff_classes)}------",
            logger='current',
            level=logging.INFO
        )

    @staticmethod
    def _display_name(dataset) -> str:
        """Return ``dataset_name`` if present, else the class name.

        Repeat / class-balanced wrappers are looked through so the name of
        the wrapped dataset is reported.
        """
        if isinstance(dataset, (RepeatDataset, ClassBalancedDataset)):
            dataset = dataset.dataset
        if hasattr(dataset, 'dataset_name'):
            return dataset.dataset_name
        return dataset.__class__.__name__

    @staticmethod
    def _normalize_class_name(name: str) -> str:
        """Lower-case ``name`` and turn '_or_' and '/' into ',' and '_' into ' '."""
        name = name.replace('_or_', ',')
        name = name.replace('/', ',')
        name = name.replace('_', ' ')
        return name.lower()

    @classmethod
    def _merge_classes(cls, local_classes, merged_classes) -> dict:
        """Merge ``local_classes`` into ``merged_classes`` (extended in place).

        A local class is matched to the first merged class that shares at
        least one comma-separated alternative with it; otherwise it is
        appended to ``merged_classes``.

        Returns:
            dict mapping each index of ``local_classes`` to its index in
            ``merged_classes``.
        """
        local_to_merged = {}
        for local_idx, raw_name in enumerate(local_classes):
            name = cls._normalize_class_name(raw_name)
            alternatives = set(name.split(','))
            for merged_idx, merged_name in enumerate(merged_classes):
                if alternatives.intersection(merged_name.split(',')):
                    local_to_merged[local_idx] = merged_idx
                    break
            else:
                merged_classes.append(name)
                local_to_merged[local_idx] = len(merged_classes) - 1
        return local_to_merged

    @staticmethod
    def _remap_labels(sample, label_map: dict) -> None:
        """Rewrite the ground-truth labels of ``sample`` in place via ``label_map``."""
        if 'gt_sem_seg' in sample:
            sample.gt_sem_seg.sem_seg.apply_(label_map.__getitem__)
        if 'gt_instances' in sample:
            sample.gt_instances.labels.apply_(label_map.__getitem__)

    def get_dataset_source(self, idx: int) -> int:
        """Return the index of the sub-dataset that global sample ``idx`` belongs to."""
        dataset_idx, _ = self._get_ori_dataset_idx(idx)
        return dataset_idx

    def __getitem__(self, idx):
        """Fetch sample ``idx`` with its labels remapped to the merged label space.

        The sample is taken from the owning sub-dataset, its ``gt_sem_seg``
        and ``gt_instances`` labels (when present) are translated in place
        through that sub-dataset's mapper, and ``data_tag`` is attached to
        ``data_samples`` when tags were configured.
        """
        if not self._fully_initialized:
            print_log(
                'Please call `full_init` method manually to '
                'accelerate the speed.',
                logger='current',
                level=logging.WARNING)
            self.full_init()

        dataset_idx, sample_idx = self._get_ori_dataset_idx(idx)
        results = self.datasets[dataset_idx][sample_idx]
        # Read-only lookup on the stored metainfo; there is no need to go
        # through the public `metainfo` accessor for every sample.
        label_map = self._metainfo['mapper'][dataset_idx]
        data_samples = results['data_samples']

        if isinstance(data_samples, TrackDataSample):
            for det_sample in data_samples:
                self._remap_labels(det_sample, label_map)
        else:
            self._remap_labels(data_samples, label_map)

        if self.data_tag is not None:
            data_samples.data_tag = self.data_tag[dataset_idx]
        return results
|