File size: 3,182 Bytes
a25563f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
from .mvtec import MVTEC_CLS_NAMES, MVTecDataset, MVTEC_ROOT
from .visa import VISA_CLS_NAMES, VisaDataset, VISA_ROOT
from .mpdd import MPDD_CLS_NAMES, MPDDDataset, MPDD_ROOT
from .btad import BTAD_CLS_NAMES, BTADDataset, BTAD_ROOT
from .sdd import SDD_CLS_NAMES, SDDDataset, SDD_ROOT
from .dagm import DAGM_CLS_NAMES, DAGMDataset, DAGM_ROOT
from .dtd import DTD_CLS_NAMES,DTDDataset,DTD_ROOT
from .isic import ISIC_CLS_NAMES,ISICDataset,ISIC_ROOT
from .colondb import ColonDB_CLS_NAMES, ColonDBDataset, ColonDB_ROOT
from .clinicdb import ClinicDB_CLS_NAMES, ClinicDBDataset, ClinicDB_ROOT
from .tn3k import TN3K_CLS_NAMES, TN3KDataset, TN3K_ROOT
from .headct import HEADCT_CLS_NAMES,HEADCTDataset,HEADCT_ROOT
from .brain_mri import BrainMRI_CLS_NAMES,BrainMRIDataset,BrainMRI_ROOT
from .br35h import Br35h_CLS_NAMES,Br35hDataset,Br35h_ROOT
from torch.utils.data import ConcatDataset

# Registry of supported datasets, keyed by lowercase dataset name.
# Each value is a triple:
#   (class-name list, Dataset class, dataset root directory)
# consumed by get_data() below. Entries are kept in alphabetical order.
dataset_dict = {
    'br35h': (Br35h_CLS_NAMES, Br35hDataset, Br35h_ROOT),
    'brain_mri': (BrainMRI_CLS_NAMES, BrainMRIDataset, BrainMRI_ROOT),
    'btad': (BTAD_CLS_NAMES, BTADDataset, BTAD_ROOT),
    'clinicdb': (ClinicDB_CLS_NAMES, ClinicDBDataset, ClinicDB_ROOT),
    'colondb': (ColonDB_CLS_NAMES, ColonDBDataset, ColonDB_ROOT),
    'dagm': (DAGM_CLS_NAMES, DAGMDataset, DAGM_ROOT),
    'dtd': (DTD_CLS_NAMES, DTDDataset, DTD_ROOT),
    'headct': (HEADCT_CLS_NAMES, HEADCTDataset, HEADCT_ROOT),
    'isic': (ISIC_CLS_NAMES, ISICDataset, ISIC_ROOT),
    'mpdd': (MPDD_CLS_NAMES, MPDDDataset, MPDD_ROOT),
    'mvtec': (MVTEC_CLS_NAMES, MVTecDataset, MVTEC_ROOT),
    'sdd': (SDD_CLS_NAMES, SDDDataset, SDD_ROOT),
    'tn3k': (TN3K_CLS_NAMES, TN3KDataset, TN3K_ROOT),
    'visa': (VISA_CLS_NAMES, VisaDataset, VISA_ROOT),
}

def get_data(dataset_type_list, transform, target_transform, training):
    """Instantiate the dataset(s) named in ``dataset_type_list``.

    Parameters
    ----------
    dataset_type_list : str | list[str]
        A single key, or a list of keys, from ``dataset_dict``.
    transform, target_transform :
        Image / target transforms forwarded to each dataset constructor.
    training : bool
        Forwarded to each dataset constructor — presumably selects the
        train vs. test split (TODO confirm against the dataset classes).

    Returns
    -------
    tuple
        ``(cls_names, dataset, root)``. For one dataset type these are the
        registry's plain values; for several, ``dataset`` is a
        ``ConcatDataset`` and the other two are parallel lists.

    Raises
    ------
    ValueError
        If ``dataset_type_list`` is empty.
    NotImplementedError
        If any requested dataset type is not in ``dataset_dict``.
    """
    if not isinstance(dataset_type_list, list):
        dataset_type_list = [dataset_type_list]
    # An empty request would otherwise surface as an opaque IndexError below.
    if not dataset_type_list:
        raise ValueError('dataset_type_list must contain at least one dataset type')

    dataset_cls_names_list = []
    dataset_instance_list = []
    dataset_root_list = []
    for dataset_type in dataset_type_list:
        # Membership test instead of a truthy .get(..., '') check: a falsy
        # registry value would previously have been misreported as unsupported.
        if dataset_type not in dataset_dict:
            # Message now travels with the exception instead of a lone print().
            raise NotImplementedError(
                f'Only support {list(dataset_dict.keys())}, but entered {dataset_type}...'
            )
        dataset_cls_names, dataset_cls, dataset_root = dataset_dict[dataset_type]
        dataset_instance = dataset_cls(
            clsnames=dataset_cls_names,
            transform=transform,
            target_transform=target_transform,
            training=training,
        )

        dataset_cls_names_list.append(dataset_cls_names)
        dataset_instance_list.append(dataset_instance)
        dataset_root_list.append(dataset_root)

    if len(dataset_type_list) > 1:
        # Multiple datasets: concatenate instances, return parallel lists
        # for the class names and roots (same contract as the original).
        dataset_instance = ConcatDataset(dataset_instance_list)
        dataset_cls_names = dataset_cls_names_list
        dataset_root = dataset_root_list
    else:
        dataset_instance = dataset_instance_list[0]
        dataset_cls_names = dataset_cls_names_list[0]
        dataset_root = dataset_root_list[0]

    return dataset_cls_names, dataset_instance, dataset_root