from torch.utils.data import ConcatDataset

from .mvtec import MVTEC_CLS_NAMES, MVTecDataset, MVTEC_ROOT
from .visa import VISA_CLS_NAMES, VisaDataset, VISA_ROOT
from .mpdd import MPDD_CLS_NAMES, MPDDDataset, MPDD_ROOT
from .btad import BTAD_CLS_NAMES, BTADDataset, BTAD_ROOT
from .sdd import SDD_CLS_NAMES, SDDDataset, SDD_ROOT
from .dagm import DAGM_CLS_NAMES, DAGMDataset, DAGM_ROOT
from .dtd import DTD_CLS_NAMES, DTDDataset, DTD_ROOT
from .isic import ISIC_CLS_NAMES, ISICDataset, ISIC_ROOT
from .colondb import ColonDB_CLS_NAMES, ColonDBDataset, ColonDB_ROOT
from .clinicdb import ClinicDB_CLS_NAMES, ClinicDBDataset, ClinicDB_ROOT
from .tn3k import TN3K_CLS_NAMES, TN3KDataset, TN3K_ROOT
from .headct import HEADCT_CLS_NAMES, HEADCTDataset, HEADCT_ROOT
from .brain_mri import BrainMRI_CLS_NAMES, BrainMRIDataset, BrainMRI_ROOT
from .br35h import Br35h_CLS_NAMES, Br35hDataset, Br35h_ROOT

# Maps each dataset key to a (class-name list, Dataset class, dataset root path) tuple.
dataset_dict = {
    'br35h': (Br35h_CLS_NAMES, Br35hDataset, Br35h_ROOT),
    'brain_mri': (BrainMRI_CLS_NAMES, BrainMRIDataset, BrainMRI_ROOT),
    'btad': (BTAD_CLS_NAMES, BTADDataset, BTAD_ROOT),
    'clinicdb': (ClinicDB_CLS_NAMES, ClinicDBDataset, ClinicDB_ROOT),
    'colondb': (ColonDB_CLS_NAMES, ColonDBDataset, ColonDB_ROOT),
    'dagm': (DAGM_CLS_NAMES, DAGMDataset, DAGM_ROOT),
    'dtd': (DTD_CLS_NAMES, DTDDataset, DTD_ROOT),
    'headct': (HEADCT_CLS_NAMES, HEADCTDataset, HEADCT_ROOT),
    'isic': (ISIC_CLS_NAMES, ISICDataset, ISIC_ROOT),
    'mpdd': (MPDD_CLS_NAMES, MPDDDataset, MPDD_ROOT),
    'mvtec': (MVTEC_CLS_NAMES, MVTecDataset, MVTEC_ROOT),
    'sdd': (SDD_CLS_NAMES, SDDDataset, SDD_ROOT),
    'tn3k': (TN3K_CLS_NAMES, TN3KDataset, TN3K_ROOT),
    'visa': (VISA_CLS_NAMES, VisaDataset, VISA_ROOT),
}
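
# Sketch of how an additional dataset could be registered. The module and names
# below ('my_dataset', MyDatasetDataset, etc.) are illustrative placeholders and
# do not correspond to an existing module in this package:
#
#   from .my_dataset import MYDATASET_CLS_NAMES, MyDatasetDataset, MYDATASET_ROOT
#   dataset_dict['my_dataset'] = (MYDATASET_CLS_NAMES, MyDatasetDataset, MYDATASET_ROOT)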


def get_data(dataset_type_list, transform, target_transform, training):
    """Build the dataset(s) for the requested dataset type(s).

    Returns a tuple of (class names, dataset instance, dataset root). When several
    dataset types are requested, the datasets are wrapped in a ConcatDataset and the
    class names / roots are returned as lists.
    """
    if not isinstance(dataset_type_list, list):
        dataset_type_list = [dataset_type_list]

    dataset_cls_names_list = []
    dataset_instance_list = []
    dataset_root_list = []
    for dataset_type in dataset_type_list:
        if dataset_type in dataset_dict:
            dataset_cls_names, dataset_cls, dataset_root = dataset_dict[dataset_type]
            dataset_instance = dataset_cls(
                clsnames=dataset_cls_names,
                transform=transform,
                target_transform=target_transform,
                training=training,
            )
            dataset_cls_names_list.append(dataset_cls_names)
            dataset_instance_list.append(dataset_instance)
            dataset_root_list.append(dataset_root)
        else:
            raise NotImplementedError(
                f'Only {list(dataset_dict.keys())} are supported, but got {dataset_type!r}'
            )
    if len(dataset_type_list) > 1:
        dataset_instance = ConcatDataset(dataset_instance_list)
        dataset_cls_names = dataset_cls_names_list
        dataset_root = dataset_root_list
    else:
        dataset_instance = dataset_instance_list[0]
        dataset_cls_names = dataset_cls_names_list[0]
        dataset_root = dataset_root_list[0]
    return dataset_cls_names, dataset_instance, dataset_root
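

# Usage sketch. The `transform` / `target_transform` callables are assumed to be
# provided by the caller (e.g. torchvision-style transforms); they are not defined here:
#
#   cls_names, train_data, root = get_data('mvtec', transform, target_transform, training=True)
#   cls_names, test_data, root = get_data(['mvtec', 'visa'], transform, target_transform, training=False)
#
# When more than one dataset type is requested, the returned dataset is a
# torch ConcatDataset and `cls_names` / `root` are lists, one entry per dataset type.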