File size: 3,216 Bytes
2ada650
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
"""
 Copyright (c) 2022, salesforce.com, inc.
 All rights reserved.
 SPDX-License-Identifier: BSD-3-Clause
 For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""

from minigpt4.datasets.builders.base_dataset_builder import load_dataset_config
from minigpt4.datasets.builders.image_text_pair_builder import (
    LaionBuilder,
    RefVisualGenomeBuilder,
    OpenImageBuilder,
    LocNaCOCOBuilder,
    LlavaDetailBuilder,
    LlavaReasonBuilder,
    NavR2RBuilder,
    PaintPTCOCOBuilder,
    PaintRLCOCOBuilder,
    PaintRLSCOCOBuilder,
    PaintPixelCOCO32Builder,
    PaintPixelCOCO64Builder,
    PaintLanRLOpaqueCOCOBuilder,
    SegRefCOCO32Builder,
    SegRefCOCOG32Builder,
    SegRefCOCOP32Builder,
    SegRefCOCO64Builder,
    SegRefCOCOG64Builder,
    SegRefCOCOP64Builder,
    CMDVideoBuilder,
    WebVidBuilder,
    VideoChatGPTBuilder,
)
from minigpt4.datasets.builders.vqa_builder import (
    COCOVQABuilder,
    OKVQABuilder,
#     AOKVQABuilder,
    COCOVQGBuilder,
#     OKVQGBuilder,
#     AOKVQGBuilder,
    SingleSlideVQABuilder,
    OCRVQABuilder
)
from minigpt4.common.registry import registry

# Public builder classes re-exported by this package. Each name must also be
# imported into this module, and each should appear exactly once.
__all__ = [
    "LaionBuilder",
    "RefVisualGenomeBuilder",
    "OpenImageBuilder",
    "SingleSlideVQABuilder",
    "COCOVQABuilder",
    "OKVQABuilder",
    "COCOVQGBuilder",
    "OCRVQABuilder",
    "LocNaCOCOBuilder",
    "LlavaDetailBuilder",
    "LlavaReasonBuilder",
    "NavR2RBuilder",
    "PaintPTCOCOBuilder",
    "PaintRLCOCOBuilder",
    "PaintRLSCOCOBuilder",
    "PaintLanRLOpaqueCOCOBuilder",
    "PaintPixelCOCO32Builder",
    "PaintPixelCOCO64Builder",
    "SegRefCOCO32Builder",
    "SegRefCOCOG32Builder",
    "SegRefCOCOP32Builder",
    "SegRefCOCO64Builder",
    "SegRefCOCOG64Builder",
    "SegRefCOCOP64Builder",
    "CMDVideoBuilder",
    "WebVidBuilder",
    "VideoChatGPTBuilder",
]


def load_dataset(name, cfg_path=None, vis_path=None, data_type=None):
    """
    Build the datasets for the registered builder called ``name``.

    Example

    >>> dataset = load_dataset("coco_caption", cfg_path=None)
    >>> splits = dataset.keys()
    >>> print([len(dataset[split]) for split in splits])

    Args:
        name: Builder name as registered with ``registry``.
        cfg_path: Optional path to a dataset config. When given it is loaded
            with ``load_dataset_config`` and passed to the builder; otherwise
            the builder is constructed with ``None``.
        vis_path: Optional storage location. When given, it is assigned to
            the ``storage`` attribute of ``builder.config.build_info``'s
            entry for ``data_type``.
        data_type: Key of ``builder.config.build_info`` that ``vis_path``
            applies to. Defaults to ``builder.config.data_type``. Only used
            when ``vis_path`` is not None.

    Returns:
        The result of ``builder.build_datasets()`` (presumably a mapping of
        split name to dataset, per the example above -- confirm against the
        builder implementation).

    Raises:
        SystemExit: (exit code 1) if ``name`` is not a registered builder.
        AssertionError: if ``data_type`` is not a key of ``build_info``.
    """
    import sys

    cfg = None if cfg_path is None else load_dataset_config(cfg_path)

    # Resolve the class before constructing it so an unknown name (lookup
    # yields None) is not confused with a TypeError raised by the builder's
    # own constructor, which should propagate with its real traceback.
    builder_cls = registry.get_builder_class(name)
    if builder_cls is None:
        print(
            f"Dataset {name} not found. Available datasets:\n"
            + ", ".join(str(k) for k in dataset_zoo.get_names()),
            file=sys.stderr,
        )
        # sys.exit rather than the site-provided exit(), which is not
        # guaranteed to exist (e.g. under `python -S`).
        sys.exit(1)
    builder = builder_cls(cfg)

    if vis_path is not None:
        if data_type is None:
            # use default data type in the config
            data_type = builder.config.data_type

        assert (
            data_type in builder.config.build_info
        ), f"Invalid data_type {data_type} for {name}."

        builder.config.build_info.get(data_type).storage = vis_path

    return builder.build_datasets()


class DatasetZoo:
    """Snapshot of the builders registered in ``registry``.

    ``dataset_zoo`` maps each registered builder name (in sorted order) to
    the list of keys of that builder class's ``DATASET_CONFIG_DICT``.
    """

    def __init__(self) -> None:
        registered = registry.mapping["builder_name_mapping"]
        zoo = {}
        for builder_name in sorted(registered):
            zoo[builder_name] = list(registered[builder_name].DATASET_CONFIG_DICT)
        self.dataset_zoo = zoo

    def get_names(self):
        """Return the registered builder names, in sorted order."""
        return [*self.dataset_zoo]


# Module-level index, built once at import time from whatever builders are
# registered by then.
dataset_zoo = DatasetZoo()