File size: 4,145 Bytes
2ada650
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
"""
 Copyright (c) 2022, salesforce.com, inc.
 All rights reserved.
 SPDX-License-Identifier: BSD-3-Clause
 For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""

from minigpt4.datasets.builders.base_dataset_builder import BaseDatasetBuilder

from minigpt4.common.registry import registry
from minigpt4.datasets.datasets.aok_vqa_datasets import AOKVQADataset
from minigpt4.datasets.datasets.aok_vqa_reasoning_datasets import AOKVQAReasoningDataset
#, AOKVQGDataset, AOKVQAEvalDataset
from minigpt4.datasets.datasets.coco_vqa_datasets import COCOVQADataset, COCOVQGDataset, COCOVQAEvalDataset
# from minigpt4.datasets.datasets.vg_vqa_datasets import VGVQADataset
from minigpt4.datasets.datasets.gqa_datasets import GQADataset, GQAEvalDataset
from minigpt4.datasets.datasets.doc_dataset import SingleSlideVQADataset, OCRVQADataset



@registry.register_builder("coco_vqa")
class COCOVQABuilder(BaseDatasetBuilder):
    """Dataset builder registered under the name ``coco_vqa``.

    Pairs ``COCOVQADataset`` (training) with ``COCOVQAEvalDataset``
    (evaluation) and lists the config files available for this dataset.
    """

    # Config file paths, keyed by config type.
    DATASET_CONFIG_DICT = {
        "default": "configs/datasets/coco/defaults_vqa.yaml",
        "eval": "configs/datasets/coco/eval_vqa.yaml",
    }

    # Dataset classes used for the evaluation and training splits.
    eval_dataset_cls = COCOVQAEvalDataset
    train_dataset_cls = COCOVQADataset
    

# @registry.register_builder("vg_vqa")
# class VGVQABuilder(BaseDatasetBuilder):
#     train_dataset_cls = VGVQADataset
#     DATASET_CONFIG_DICT = {"default": "configs/datasets/vg/defaults_vqa.yaml"}


@registry.register_builder("ok_vqa")
class OKVQABuilder(COCOVQABuilder):
    """Dataset builder registered under the name ``ok_vqa``.

    Inherits its dataset classes from ``COCOVQABuilder`` and only replaces
    the config mapping, which here contains a single ``default`` entry.
    """

    DATASET_CONFIG_DICT = {"default": "configs/datasets/okvqa/defaults.yaml"}


@registry.register_builder("aok_vqa")
class AOKVQABuilder(BaseDatasetBuilder):
    """Dataset builder registered under the name ``aok_vqa``.

    Only a training dataset class is declared; no evaluation dataset class
    is set on this builder.
    """

    DATASET_CONFIG_DICT = {
        "default": "configs/datasets/aokvqa/defaults.yaml",
    }

    train_dataset_cls = AOKVQADataset
    # eval_dataset_cls = AOKVQAEvalDataset  (disabled; its import is commented out too)

@registry.register_builder("aok_vqa_reasoning")
class AOKVQAReasoningBuilder(BaseDatasetBuilder):
    """Dataset builder registered under the name ``aok_vqa_reasoning``.

    Uses ``AOKVQAReasoningDataset`` for training; no evaluation dataset
    class is declared.

    This class was previously also named ``AOKVQABuilder``, which rebound
    the module-level name and hid the ``aok_vqa`` builder class from
    by-name imports. The registry key is unchanged, so lookups through the
    registry are unaffected by the rename.
    """

    train_dataset_cls = AOKVQAReasoningDataset
    # eval_dataset_cls = AOKVQAEvalDataset

    DATASET_CONFIG_DICT = {
        "default": "configs/datasets/aokvqa_reasoning/defaults.yaml",
    }


@registry.register_builder("gqa")
class GQABuilder(BaseDatasetBuilder):
    """Dataset builder registered under the name ``gqa``.

    Declares ``GQADataset`` as the training dataset class; the evaluation
    dataset class is currently disabled.
    """

    # NOTE: the "default" entry deliberately points at the balanced_val
    # config file; the alternative entries below are kept disabled.
    DATASET_CONFIG_DICT = {
        # "default": "configs/datasets/gqa/defaults.yaml",
        # "balanced_val": "configs/datasets/gqa/balanced_val.yaml",
        # "balanced_testdev": "configs/datasets/gqa/balanced_testdev.yaml",
        "default": "configs/datasets/gqa/balanced_val.yaml",
    }

    train_dataset_cls = GQADataset
    # eval_dataset_cls = GQAEvalDataset



@registry.register_builder("coco_vqg")
class COCOVQGBuilder(BaseDatasetBuilder):
    """Dataset builder registered under the name ``coco_vqg``.

    Declares ``COCOVQGDataset`` as the training dataset class and a single
    ``default`` config file.
    """

    DATASET_CONFIG_DICT = {"default": "configs/datasets/coco/defaults_vqg.yaml"}

    train_dataset_cls = COCOVQGDataset


@registry.register_builder("ok_vqg")
class OKVQGBuilder(COCOVQGBuilder):
    """Dataset builder registered under the name ``ok_vqg``.

    Reuses the training dataset class inherited from ``COCOVQGBuilder`` and
    substitutes its own ``default`` config file.
    """

    DATASET_CONFIG_DICT = {"default": "configs/datasets/okvqa/defaults_vqg.yaml"}


# @registry.register_builder("aok_vqg")
# class AOKVQGBuilder(BaseDatasetBuilder):
#     train_dataset_cls = AOKVQGDataset

#     DATASET_CONFIG_DICT = {"default": "configs/datasets/aokvqa/defaults_vqg.yaml"}


class DocumentVQABuilder(BaseDatasetBuilder):
    """Shared base for the document-VQA builders defined in this module.

    Subclasses supply ``train_dataset_cls`` and ``DATASET_CONFIG_DICT``;
    this class provides a ``build`` that creates only the ``train`` split.
    """

    def _download_ann(self):
        """Do nothing; this builder performs no annotation download."""
        return None

    def _download_vis(self):
        """Do nothing; this builder performs no image download."""
        return None

    def build(self):
        """Instantiate the training dataset.

        Returns:
            A dict with the single key ``"train"`` mapped to an instance of
            ``self.train_dataset_cls``, constructed with the train-split
            processors and the ``image_path`` / ``ann_path`` values read
            from ``self.config.build_info``.
        """
        # Processors have to exist before they are looked up below.
        self.build_processors()
        info = self.config.build_info

        train_split = "train"
        train_set = self.train_dataset_cls(
            vis_processor=self.vis_processors[train_split],
            text_processor=self.text_processors[train_split],
            vis_root=info.image_path,
            ann_path=info.ann_path,
        )

        return {train_split: train_set}


@registry.register_builder("sslidevqa")
class SingleSlideVQABuilder(DocumentVQABuilder):
    """Dataset builder registered under the name ``sslidevqa``.

    Relies on ``DocumentVQABuilder.build`` and uses
    ``SingleSlideVQADataset`` as the training dataset class.
    """

    DATASET_CONFIG_DICT = {
        "default": "configs/datasets/doc/sslidevqa.yaml",
    }
    train_dataset_cls = SingleSlideVQADataset


@registry.register_builder("ocrvqa")
class OCRVQABuilder(DocumentVQABuilder):
    """Dataset builder registered under the name ``ocrvqa``.

    Relies on ``DocumentVQABuilder.build`` and uses ``OCRVQADataset`` as
    the training dataset class.
    """

    DATASET_CONFIG_DICT = {
        "default": "configs/datasets/doc/ocrvqa.yaml",
    }
    train_dataset_cls = OCRVQADataset