Spaces:
Runtime error
Runtime error
import logging | |
import torch.distributed as dist | |
import bubogpt.common.utils as utils | |
from bubogpt.common.dist_utils import is_dist_avail_and_initialized, is_main_process | |
from bubogpt.common.registry import registry | |
from bubogpt.datasets.builders import load_dataset_config | |
from bubogpt.processors.base_processor import BaseProcessor | |
class MultimodalBaseDatasetBuilder:
    """Base builder that loads a dataset config and prepares per-modality processors.

    The config's ``data_type`` string is split on ``"_"`` to obtain the list of
    modalities; ``"text"`` is always appended. For every modality a pair of
    ``train`` / ``eval`` processors is kept in ``self.processors``.

    This class calls ``self.build()`` and reads ``DATASET_CONFIG_DICT`` but
    defines neither here; they are presumably supplied by subclasses
    (NOTE(review): confirm against the concrete builders).
    """

    # Dataset classes for the train / eval splits; unset in this base class.
    train_dataset_cls, eval_dataset_cls = None, None

    def __init__(self, cfg=None):
        """Load the dataset config and install default processors.

        Args:
            cfg: ``None`` to load the config found at ``default_config_path()``,
                a ``str`` that is passed to ``load_dataset_config``, or an
                already-loaded config object that is used as-is.
        """
        super().__init__()

        if cfg is None:
            # help to create datasets from default config.
            self.config = load_dataset_config(self.default_config_path())
        elif isinstance(cfg, str):
            self.config = load_dataset_config(cfg)
        else:
            # when called from task.build_dataset()
            self.config = cfg

        # It will be a list like ["audio", "image"], etc.
        self.data_type = self.config.data_type.split("_")

        # "text" is added manually here; every modality starts with
        # BaseProcessor instances until build_processors() replaces them.
        self.processors = {
            modal: {"train": BaseProcessor(), "eval": BaseProcessor()}
            for modal in [*self.data_type, "text"]
        }

    def build_datasets(self):
        """Download the data on the main process, synchronise, then build.

        Returns:
            Whatever ``self.build()`` returns (per the original comment, a
            mapping such as ``dataset['train'/'val'/'test']``).
        """
        # download, split, etc...
        # only called on 1 GPU/TPU in distributed
        if is_main_process():
            self._download_data()

        # Make the other processes wait until the main process is done.
        if is_dist_avail_and_initialized():
            dist.barrier()

        # at this point, all the annotations and image/videos should be all
        # downloaded to the specified locations.
        logging.info("Building datasets...")
        datasets = self.build()  # dataset['train'/'val'/'test']

        return datasets

    def build_processors(self):
        """Replace the default processors with the ones described in the config.

        For each modality the config key ``"<modal>_processor"`` is looked up.
        When present, its ``train`` and ``eval`` entries are each turned into a
        processor; an entry that is missing yields ``None`` for that split.
        Modalities without such a key keep their current processors.
        """
        for modal in [*self.data_type, "text"]:
            proc_cfg = self.config.get(f"{modal}_processor")
            if proc_cfg is None:
                continue

            train_cfg = proc_cfg.get("train")
            eval_cfg = proc_cfg.get("eval")

            self.processors[modal]["train"] = self._build_proc_from_cfg(train_cfg)
            self.processors[modal]["eval"] = self._build_proc_from_cfg(eval_cfg)

    @staticmethod
    def _build_proc_from_cfg(cfg):
        """Instantiate the processor registered under ``cfg.name``.

        Declared as a static method so that ``self._build_proc_from_cfg(cfg)``
        passes ``cfg`` (and not the instance) as the only argument.

        Returns:
            The processor built via ``from_config(cfg)``, or ``None`` when
            ``cfg`` is ``None``.
        """
        if cfg is None:
            return None
        return registry.get_processor_class(cfg.name).from_config(cfg)

    @classmethod
    def default_config_path(cls, type="default"):
        """Return the absolute path of the config registered under ``type``.

        The path is read from ``cls.DATASET_CONFIG_DICT[type]`` and resolved
        with ``utils.get_abs_path``. Callable on the class or on an instance.
        """
        return utils.get_abs_path(cls.DATASET_CONFIG_DICT[type])

    def _download_data(self):
        """Hook invoked on the main process before building; no-op here."""
        pass