import copy
import logging
import os
import os.path as osp
from os.path import join

import torch
from torch.utils.data import ConcatDataset, DataLoader

from utils.optimizer import create_optimizer
from utils.scheduler import create_scheduler

logger = logging.getLogger(__name__)


def get_media_types(datasources):
    """Get the media type of every dataloader or dataset.

    Args:
        datasources (List): Dataloaders and/or datasets. Each ``DataLoader``
            is unwrapped to its ``.dataset`` before the media type is read.

    Returns:
        List: One media type per entry of ``datasources``, in input order.
        For a ``ConcatDataset`` the ``media_type`` of its first sub-dataset
        is used. An empty input yields an empty list.
    """
    media_types = []
    for source in datasources:
        # Unwrap each element on its own (rather than deciding from the first
        # element only) so mixed lists work and an empty input returns [].
        dataset = source.dataset if isinstance(source, DataLoader) else source
        if isinstance(dataset, ConcatDataset):
            # Only the first sub-dataset is inspected.
            # NOTE(review): assumes all sub-datasets share one media type.
            dataset = dataset.datasets[0]
        media_types.append(dataset.media_type)
    return media_types