Spaces:
Running
on
T4
Running
on
T4
# Copyright (c) Tencent Inc. All rights reserved. | |
import copy | |
import json | |
import logging | |
from typing import Callable, List, Union | |
from mmengine.logging import print_log | |
from mmengine.dataset.base_dataset import ( | |
BaseDataset, Compose, force_full_init) | |
from mmyolo.registry import DATASETS | |
class MultiModalDataset:
    """Dataset wrapper that attaches class texts to every sample.

    Wraps another dataset (either passed in directly or built from a config
    dict through ``DATASETS.build``), optionally loads a JSON file of class
    texts, and runs its own ``pipeline`` on the data info of each sample.

    Args:
        dataset (BaseDataset or dict): The dataset to wrap, or a config dict
            handed to ``DATASETS.build``.
        class_text_path (str, optional): Path to a JSON file. Its parsed
            content is stored as ``class_texts`` and added to every sample
            under the ``'texts'`` key. Defaults to None (no texts added).
        test_mode (bool): Consulted by ``__getitem__``: when False, the
            wrapper stores itself in the sample under the ``'dataset'`` key
            even if the wrapped dataset reports ``test_mode=True``.
            Defaults to True.
        pipeline (list[dict or callable]): Transforms handed to ``Compose``.
            Defaults to an empty list.
        lazy_init (bool): If True, ``full_init`` is not called by the
            constructor. Defaults to False.

    Raises:
        TypeError: If ``dataset`` is neither a dict nor a ``BaseDataset``.
    """

    # NOTE: the ``[]`` default for ``pipeline`` is shared between calls, but
    # this class only forwards it to ``Compose`` and never mutates it itself.
    def __init__(self,
                 dataset: Union[BaseDataset, dict],
                 class_text_path: Union[str, None] = None,
                 test_mode: bool = True,
                 pipeline: List[Union[dict, Callable]] = [],
                 lazy_init: bool = False) -> None:
        self.dataset: BaseDataset
        if isinstance(dataset, dict):
            self.dataset = DATASETS.build(dataset)
        elif isinstance(dataset, BaseDataset):
            self.dataset = dataset
        else:
            raise TypeError(
                'dataset must be a dict or a BaseDataset, '
                f'but got {dataset}')

        if class_text_path is not None:
            # Context manager so the file handle is closed as soon as the
            # JSON has been parsed instead of leaking until garbage
            # collection.
            with open(class_text_path, 'r') as text_file:
                self.class_texts = json.load(text_file)
            # NOTE: the number of class texts is intentionally not checked
            # against ``self.dataset.metainfo['classes']``.
        else:
            self.class_texts = None

        self.test_mode = test_mode
        self._metainfo = self.dataset.metainfo
        self.pipeline = Compose(pipeline)

        self._fully_initialized = False
        if not lazy_init:
            self.full_init()

    # NOTE(review): the wrapped dataset's ``metainfo`` is read as an
    # attribute in ``__init__``, while this wrapper exposes ``metainfo`` as a
    # plain method. Kept as-is to preserve the interface; confirm whether
    # callers expect a property here.
    def metainfo(self) -> dict:
        """Return a deep copy of the wrapped dataset's meta information."""
        return copy.deepcopy(self._metainfo)

    def full_init(self) -> None:
        """Fully initialize the wrapped dataset and cache its length.

        Does nothing if the wrapper has already been fully initialized.
        """
        if self._fully_initialized:
            return

        self.dataset.full_init()
        self._ori_len = len(self.dataset)
        self._fully_initialized = True

    def get_data_info(self, idx: int) -> dict:
        """Get the data info of sample ``idx`` from the wrapped dataset.

        Args:
            idx (int): Index of the sample.

        Returns:
            dict: The wrapped dataset's data info, with ``class_texts`` added
            under the ``'texts'`` key when class texts were loaded.
        """
        data_info = self.dataset.get_data_info(idx)
        if self.class_texts is not None:
            data_info.update({'texts': self.class_texts})
        return data_info

    def __getitem__(self, idx):
        """Return sample ``idx`` after running it through ``pipeline``.

        Triggers ``full_init`` (with a warning) when the wrapper was created
        lazily and has not been initialized yet.
        """
        if not self._fully_initialized:
            print_log(
                'Please call `full_init` method manually to '
                'accelerate the speed.',
                logger='current',
                level=logging.WARNING)
            self.full_init()

        data_info = self.get_data_info(idx)

        # Expose the wrapper to the pipeline when either the wrapped dataset
        # is in training mode or the wrapper itself is not in test mode. A
        # wrapped dataset without a ``test_mode`` attribute is treated as
        # being in test mode.
        wrapped_in_train = not getattr(self.dataset, 'test_mode', True)
        if wrapped_in_train or not self.test_mode:
            data_info['dataset'] = self

        return self.pipeline(data_info)

    def __len__(self) -> int:
        """Return the number of samples in the wrapped dataset."""
        # ``_ori_len`` only exists after ``full_init``; initialize on demand
        # so ``len()`` also works on a lazily constructed wrapper.
        if not self._fully_initialized:
            self.full_init()
        return self._ori_len
class MultiModalMixedDataset(MultiModalDataset):
    """Multi-modal dataset wrapper for mixing detection and caption data.

    Works like ``MultiModalDataset`` and additionally marks every sample
    with an ``is_detection`` flag derived from ``dataset_type``.

    Args:
        dataset (BaseDataset or dict): Forwarded to ``MultiModalDataset``.
        class_text_path (str, optional): Forwarded to ``MultiModalDataset``.
            Defaults to None.
        dataset_type (str): Dataset type, 'detection' or 'caption'. Only the
            exact value ``'detection'`` yields ``is_detection == 1``; any
            other value yields 0. Defaults to 'detection'.
        test_mode (bool): Forwarded to ``MultiModalDataset``.
            Defaults to True.
        pipeline (list[dict or callable]): Forwarded to
            ``MultiModalDataset``. Defaults to an empty list.
        lazy_init (bool): Forwarded to ``MultiModalDataset``.
            Defaults to False.
    """

    def __init__(self,
                 dataset: Union[BaseDataset, dict],
                 class_text_path: str = None,
                 dataset_type: str = 'detection',
                 test_mode: bool = True,
                 pipeline: List[Union[dict, Callable]] = [],
                 lazy_init: bool = False) -> None:
        self.dataset_type = dataset_type
        super().__init__(
            dataset=dataset,
            class_text_path=class_text_path,
            test_mode=test_mode,
            pipeline=pipeline,
            lazy_init=lazy_init)

    def get_data_info(self, idx: int) -> dict:
        """Get the data info of sample ``idx`` plus the detection flag.

        Args:
            idx (int): Index of the sample.

        Returns:
            dict: The parent's data info (including ``'texts'`` when class
            texts were loaded) with ``'is_detection'`` set to 1 for a
            detection dataset and 0 otherwise.
        """
        info = super().get_data_info(idx)
        if self.dataset_type == 'detection':
            info['is_detection'] = 1
        else:
            info['is_detection'] = 0
        return info