""" Copyright (c) 2022, salesforce.com, inc. All rights reserved. SPDX-License-Identifier: BSD-3-Clause For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause """ from minigpt4.datasets.builders.base_dataset_builder import load_dataset_config from minigpt4.datasets.builders.image_text_pair_builder import ( LaionBuilder, RefVisualGenomeBuilder, OpenImageBuilder, LocNaCOCOBuilder, LlavaDetailBuilder, LlavaReasonBuilder, NavR2RBuilder, PaintPTCOCOBuilder, PaintRLCOCOBuilder, PaintRLSCOCOBuilder, PaintPixelCOCO32Builder, PaintPixelCOCO64Builder, PaintLanRLOpaqueCOCOBuilder, SegRefCOCO32Builder, SegRefCOCOG32Builder, SegRefCOCOP32Builder, SegRefCOCO64Builder, SegRefCOCOG64Builder, SegRefCOCOP64Builder, CMDVideoBuilder, WebVidBuilder, VideoChatGPTBuilder, ) from minigpt4.datasets.builders.vqa_builder import ( COCOVQABuilder, OKVQABuilder, # AOKVQABuilder, COCOVQGBuilder, # OKVQGBuilder, # AOKVQGBuilder, SingleSlideVQABuilder, OCRVQABuilder ) from minigpt4.common.registry import registry __all__ = [ "LaionBuilder", "RefVisualGenomeBuilder", "OpenImageBuilder", "SingleSlideVQABuilder", "COCOVQABuilder", "COCOVQGBuilder", "SingleSlideVQABuilder", "OCRVQABuilder", "LocNaCOCOBuilder", "LlavaDetailBuilder", "NavR2RBuilder", "PaintPTCOCOBuilder", "PaintRLCOCOBuilder", "PaintRLSCOCOBuilder", "PaintLanRLOpaqueCOCOBuilder", "PaintPixelCOCO32Builder", "PaintPixelCOCO64Builder", "SegRefCOCO32Builder", "SegRefCOCOG32Builder", "SegRefCOCOP32Builder", "SegRefCOCO64Builder", "SegRefCOCOG64Builder", "SegRefCOCOP64Builder", "CMDVideoBuilder", "WebVidBuilder", "VideoChatGPTBuilder", ] def load_dataset(name, cfg_path=None, vis_path=None, data_type=None): """ Example >>> dataset = load_dataset("coco_caption", cfg=None) >>> splits = dataset.keys() >>> print([len(dataset[split]) for split in splits]) """ if cfg_path is None: cfg = None else: cfg = load_dataset_config(cfg_path) try: builder = registry.get_builder_class(name)(cfg) except TypeError: print( f"Dataset {name} not found. Available datasets:\n" + ", ".join([str(k) for k in dataset_zoo.get_names()]) ) exit(1) if vis_path is not None: if data_type is None: # use default data type in the config data_type = builder.config.data_type assert ( data_type in builder.config.build_info ), f"Invalid data_type {data_type} for {name}." builder.config.build_info.get(data_type).storage = vis_path dataset = builder.build_datasets() return dataset class DatasetZoo: def __init__(self) -> None: self.dataset_zoo = { k: list(v.DATASET_CONFIG_DICT.keys()) for k, v in sorted(registry.mapping["builder_name_mapping"].items()) } def get_names(self): return list(self.dataset_zoo.keys()) dataset_zoo = DatasetZoo()