BuboGPT / bubogpt /datasets /builders /multimodal_base_dataset_builder.py
ikuinen99's picture
update
e4bd7f9
raw
history blame
2.48 kB
import logging
import torch.distributed as dist
import bubogpt.common.utils as utils
from bubogpt.common.dist_utils import is_dist_avail_and_initialized, is_main_process
from bubogpt.common.registry import registry
from bubogpt.datasets.builders import load_dataset_config
from bubogpt.processors.base_processor import BaseProcessor
class MultimodalBaseDatasetBuilder:
    """Base builder that wires a dataset config to per-modality processors.

    The builder stores a dataset config, derives its modalities from
    ``config.data_type`` and keeps a ``processors`` mapping shaped like
    ``{modality: {"train": processor, "eval": processor}}``.

    NOTE(review): ``self.build()`` and ``cls.DATASET_CONFIG_DICT`` are used
    below but are not defined in this class; presumably concrete subclasses
    supply them -- confirm against the subclasses.
    """

    train_dataset_cls = None
    eval_dataset_cls = None

    def __init__(self, cfg=None):
        """Resolve the dataset config and install default processors.

        Args:
            cfg: ``None`` to load the config found at
                ``default_config_path()``; a string, which is handed to
                ``load_dataset_config``; or any other object, which is kept
                as the config unchanged (the case when called from
                ``task.build_dataset()``).
        """
        super().__init__()

        if cfg is None or isinstance(cfg, str):
            # Both cases go through the loader; only the source of the
            # path differs.
            config_path = self.default_config_path() if cfg is None else cfg
            self.config = load_dataset_config(config_path)
        else:
            self.config = cfg

        # e.g. "audio_image" -> ["audio", "image"]
        self.data_type = self.config.data_type.split("_")

        # "text" is always added on top of the configured modalities.
        default_processors = {}
        for modality in list(self.data_type) + ["text"]:
            default_processors[modality] = {
                "train": BaseProcessor(),
                "eval": BaseProcessor(),
            }
        self.processors = default_processors

    def build_datasets(self):
        """Download the data if needed, then build and return the datasets.

        ``_download_data()`` runs only where ``is_main_process()`` is true.
        When distributed mode is available and initialized, every process
        then waits at a barrier, so the annotations and image/video files
        are expected to be in place before building starts.

        Returns:
            Whatever ``self.build()`` returns (per the original comment, a
            mapping keyed by 'train' / 'val' / 'test').
        """
        if is_main_process():
            self._download_data()

        if is_dist_avail_and_initialized():
            dist.barrier()

        logging.info("Building datasets...")
        return self.build()

    def build_processors(self):
        """Replace default processors with the ones described in the config.

        For every modality (plus "text") the config is queried for a
        ``"<modality>_processor"`` entry. Modalities without such an entry
        keep their current processors. When the entry exists, both the
        "train" and "eval" slots are overwritten; a split that is missing
        from the entry ends up as ``None``.
        """
        for modality in list(self.data_type) + ["text"]:
            proc_cfg = self.config.get("{}_processor".format(modality))
            if proc_cfg is None:
                continue

            # Read both split configs first, then build them in order.
            split_cfgs = {split: proc_cfg.get(split) for split in ("train", "eval")}
            for split, split_cfg in split_cfgs.items():
                self.processors[modality][split] = self._build_proc_from_cfg(split_cfg)

    @staticmethod
    def _build_proc_from_cfg(cfg):
        """Instantiate a processor from its config.

        Args:
            cfg: Processor config exposing a ``name`` attribute, or ``None``.

        Returns:
            ``None`` when ``cfg`` is ``None``; otherwise the result of
            ``from_config(cfg)`` on the class the registry returns for
            ``cfg.name``.
        """
        if cfg is None:
            return None

        processor_cls = registry.get_processor_class(cfg.name)
        return processor_cls.from_config(cfg)

    @classmethod
    def default_config_path(cls, type="default"):
        """Return the absolute path of a config listed in ``DATASET_CONFIG_DICT``.

        Args:
            type: Key into ``cls.DATASET_CONFIG_DICT``.
        """
        relative_path = cls.DATASET_CONFIG_DICT[type]
        return utils.get_abs_path(relative_path)

    def _download_data(self):
        """Download hook; does nothing in this base class."""
        return None