Spaces:
Running
on
A10G
Running
on
A10G
from abc import ABC | |
import logging | |
from typing import Sequence, Union, Optional, Tuple | |
from mmengine.dataset import ConcatDataset, RepeatDataset, ClassBalancedDataset | |
from mmengine.logging import print_log | |
from mmengine.registry import DATASETS | |
from mmengine.dataset.base_dataset import BaseDataset | |
from mmdet.structures import TrackDataSample | |
from seg.models.utils import NO_OBJ | |
class ConcatOVDataset(ConcatDataset, ABC):
    """Concatenate datasets and merge their labels into one shared vocabulary.

    Each sub-dataset's ``thing_classes`` (falling back to ``classes``) and
    ``stuff_classes`` are normalized and merged into global ``thing_classes``
    and ``stuff_classes`` lists. Two class names are merged when their
    comma-separated synonym sets share at least one synonym. Normalization
    turns ``'_or_'`` and ``'/'`` into synonym separators, turns the remaining
    underscores into spaces, and lower-cases the name.

    ``__getitem__`` rewrites each sample's ``gt_sem_seg`` / ``gt_instances``
    labels in place, from the sub-dataset's local ids to the merged ids.

    The merged ``metainfo`` holds:

    - ``classes``: ``[*thing_classes, *stuff_classes]``.
    - ``thing_classes`` / ``stuff_classes``: the merged vocabularies.
    - ``mapper``: one dict per sub-dataset, mapping local label id to global
      label id.
        - Local stuff ids follow the local thing ids.
        - Global stuff ids are offset by ``len(thing_classes)``.
        - Class-agnostic sub-datasets map every id in
          ``range(_NUM_AGNOSTIC_LABELS)`` to ``-1``.
        - ``NO_OBJ`` always maps to itself.
    - ``dataset_names``: the display name of every sub-dataset.

    Args:
        datasets: Sub-datasets or their config dicts. Config dicts are
            updated *in place* with ``lazy_init``, as is the nested
            ``'dataset'`` config of any config that has a ``'times'`` key.
        lazy_init: Forwarded to :class:`ConcatDataset` and to every config.
        data_tag: Optional tag per entry of ``datasets``. Sub-datasets whose
            tag is in ``_CLASS_AGNOSTIC_TAGS`` are treated as class agnostic.
            When given, the tag is set as ``data_tag`` on each returned
            ``results['data_samples']``.
    """

    _fully_initialized: bool = False

    # Data tags whose sub-datasets are treated as class agnostic even when
    # their metainfo declares classes.
    _CLASS_AGNOSTIC_TAGS: Tuple[str, ...] = ('sam',)
    # Number of local ids mapped to -1 for class-agnostic sub-datasets.
    # Presumably an upper bound on the label ids they emit -- TODO confirm.
    _NUM_AGNOSTIC_LABELS: int = 1000

    def __init__(self,
                 datasets: Sequence[Union[BaseDataset, dict]],
                 lazy_init: bool = False,
                 data_tag: Optional[Tuple[str, ...]] = None,
                 ):
        # Build every sub-dataset with the same laziness as the
        # concatenation. A config with 'times' (presumably a RepeatDataset
        # config) wraps another config under 'dataset'.
        for dataset in datasets:
            if isinstance(dataset, dict):
                dataset.update(lazy_init=lazy_init)
                if 'times' in dataset:
                    dataset['dataset'].update(lazy_init=lazy_init)
        # Class lists legitimately differ between sub-datasets: they are
        # merged below instead of being required to match.
        super().__init__(datasets, lazy_init=lazy_init,
                         ignore_keys=['classes', 'thing_classes', 'stuff_classes', 'palette'])

        self.data_tag = data_tag
        if self.data_tag is not None:
            assert len(self.data_tag) == len(datasets), (
                f'Expected one data_tag per dataset, got {len(self.data_tag)} '
                f'tags for {len(datasets)} datasets.')

        cls_names = [self._source_name(dataset) for dataset in self.datasets]

        thing_classes = []
        stuff_classes = []
        thing_mapper = []
        stuff_mapper = []
        # Synonym -> index of the first merged class containing it.
        thing_index = {}
        stuff_index = {}
        for idx, dataset in enumerate(self.datasets):
            # ``metainfo`` may return a fresh copy per access; read it once.
            metainfo = dataset.metainfo
            tagged_agnostic = (self.data_tag is not None
                               and self.data_tag[idx] in self._CLASS_AGNOSTIC_TAGS)
            if 'classes' not in metainfo or tagged_agnostic:
                # Class-agnostic dataset: empty mappers mark it for the
                # all-to-(-1) mapper built below.
                thing_mapper.append({})
                stuff_mapper.append({})
                continue
            local_things = (metainfo['thing_classes']
                            if 'thing_classes' in metainfo else metainfo['classes'])
            local_stuff = metainfo['stuff_classes'] if 'stuff_classes' in metainfo else []
            thing_mapper.append(
                self._merge_vocabulary(local_things, thing_classes, thing_index))
            stuff_mapper.append(
                self._merge_vocabulary(local_stuff, stuff_classes, stuff_index))

        num_thing = len(thing_classes)
        mapper = []
        for local_thing, local_stuff in zip(thing_mapper, stuff_mapper):
            if not local_thing and not local_stuff:
                # Class-agnostic dataset: every local label becomes -1.
                _mapper = {label: -1 for label in range(self._NUM_AGNOSTIC_LABELS)}
            else:
                _mapper = dict(local_thing)
                # Local stuff ids come right after the local thing ids.
                # Global stuff ids come after all merged thing classes.
                local_num_thing = len(local_thing)
                for local_idx, global_idx in local_stuff.items():
                    assert global_idx < len(stuff_classes)
                    _mapper[local_idx + local_num_thing] = global_idx + num_thing
                assert len(_mapper) == len(local_thing) + len(local_stuff)
            _mapper[NO_OBJ] = NO_OBJ
            mapper.append(_mapper)

        self.dataset_name = '_'.join(cls_names)
        if len(cls_names) > 1:
            self.dataset_name = 'Concat_' + self.dataset_name

        self._metainfo.update({
            'classes': [*thing_classes, *stuff_classes],
            'thing_classes': thing_classes,
            'stuff_classes': stuff_classes,
            'mapper': mapper,
            'dataset_names': list(cls_names),
        })

        print_log(
            f"------------{self.dataset_name}------------",
            logger='current',
            level=logging.INFO
        )
        for idx, dataset in enumerate(self.datasets):
            dataset_type = cls_names[idx]
            times = dataset.times if isinstance(dataset, RepeatDataset) else 1
            # In mmengine, ``len()`` on a dataset that is not fully
            # initialized triggers ``full_init`` (``force_full_init``).
            # Do not load every annotation file under ``lazy_init`` just to
            # log a length.
            if getattr(dataset, '_fully_initialized', True):
                length = len(dataset)
            else:
                length = 'N/A (lazy_init)'
            print_log(
                f"|---dataset#{idx + 1} --> name: {dataset_type}; length: {length}; repeat times: {times}",
                logger='current',
                level=logging.INFO
            )
        print_log(
            f"------num_things : {len(thing_classes)}; num_stuff : {len(stuff_classes)}------",
            logger='current',
            level=logging.INFO
        )

    @staticmethod
    def _source_name(dataset) -> str:
        """Return the display name of a sub-dataset.

        ``RepeatDataset`` / ``ClassBalancedDataset`` wrappers are named after
        the dataset they wrap. The ``dataset_name`` attribute is preferred,
        with the class name as the fallback.
        """
        if isinstance(dataset, (RepeatDataset, ClassBalancedDataset)):
            dataset = dataset.dataset
        return getattr(dataset, 'dataset_name', dataset.__class__.__name__)

    @staticmethod
    def _normalize_class_name(name: str) -> str:
        """Normalize a raw class name into lower-case comma-separated synonyms.

        For example, ``'Traffic_Light/Signal'`` becomes
        ``'traffic light,signal'``.
        """
        # Order matters: '_or_' must be split before '_' becomes a space.
        name = name.replace('_or_', ',')
        name = name.replace('/', ',')
        name = name.replace('_', ' ')
        return name.lower()

    @classmethod
    def _merge_vocabulary(cls, names, vocabulary: list, token_index: dict) -> dict:
        """Merge ``names`` into ``vocabulary`` and return the local->global map.

        A name maps to the first vocabulary entry that shares a synonym with
        it. Otherwise the name is appended as a new entry. ``vocabulary`` and
        ``token_index`` (synonym -> first entry containing it) are updated in
        place so later datasets reuse them.
        """
        local_to_global = {}
        for local_idx, raw_name in enumerate(names):
            name = cls._normalize_class_name(raw_name)
            synonyms = name.split(',')
            hits = [token_index[s] for s in synonyms if s in token_index]
            if hits:
                # Smallest index == first entry in vocabulary order that
                # shares any synonym.
                local_to_global[local_idx] = min(hits)
                continue
            vocabulary.append(name)
            new_idx = len(vocabulary) - 1
            for synonym in synonyms:
                token_index.setdefault(synonym, new_idx)
            local_to_global[local_idx] = new_idx
        return local_to_global

    def get_dataset_source(self, idx: int) -> int:
        """Return the index of the sub-dataset that global sample ``idx`` comes from."""
        dataset_idx, _ = self._get_ori_dataset_idx(idx)
        return dataset_idx

    def __getitem__(self, idx):
        """Return sample ``idx`` with its labels remapped to the merged vocabulary.

        Labels in ``gt_sem_seg.sem_seg`` and ``gt_instances.labels`` are
        rewritten in place. For a ``TrackDataSample`` this is done for every
        sample it iterates over. Labels missing from the sub-dataset's mapper
        raise ``KeyError``.
        """
        if not self._fully_initialized:
            print_log(
                'Please call `full_init` method manually to '
                'accelerate the speed.',
                logger='current',
                level=logging.WARNING)
            self.full_init()
        dataset_idx, sample_idx = self._get_ori_dataset_idx(idx)
        results = self.datasets[dataset_idx][sample_idx]
        # Read ``_metainfo`` directly: the public ``metainfo`` property may
        # deep-copy every class list and mapper on each access.
        remap = self._metainfo['mapper'][dataset_idx].__getitem__

        data_samples = results['data_samples']
        if isinstance(data_samples, TrackDataSample):
            samples = data_samples
        else:
            samples = (data_samples,)
        for sample in samples:
            # NOTE(review): ``Tensor.apply_`` calls ``remap`` once per
            # element (CPU tensors only), which may be slow for large masks.
            if 'gt_sem_seg' in sample:
                sample.gt_sem_seg.sem_seg.apply_(remap)
            if 'gt_instances' in sample:
                sample.gt_instances.labels.apply_(remap)
        if self.data_tag is not None:
            data_samples.data_tag = self.data_tag[dataset_idx]
        return results