pllava-13b-demo / tasks /shared_utils.py
cathyxl
added
f239efc
raw
history blame
849 Bytes
import copy
import logging
import os
import os.path as osp
from os.path import join
import torch
from torch.utils.data import ConcatDataset, DataLoader
from utils.optimizer import create_optimizer
from utils.scheduler import create_scheduler
# Module-level logger, named after this module via ``__name__``.
logger = logging.getLogger(name=__name__)
def get_media_types(datasources):
    """Get the media type of every dataloader or dataset in ``datasources``.

    Each element is classified on its own, so DataLoaders and plain datasets
    may be mixed in the same input.

    Args:
        datasources (Iterable): DataLoaders and/or datasets. A DataLoader is
            unwrapped to its ``.dataset``; a ConcatDataset reports the
            ``media_type`` of its first constituent dataset.

    Returns: List. One media type per element of ``datasources``, in input
        order. An empty input yields an empty list.
    """
    media_types = []
    for source in datasources:
        # A DataLoader wraps the dataset that carries ``media_type``.
        dataset = source.dataset if isinstance(source, DataLoader) else source
        if isinstance(dataset, ConcatDataset):
            # Only the first constituent is inspected; this presumes every
            # dataset concatenated together shares one media type.
            dataset = dataset.datasets[0]
        media_types.append(dataset.media_type)
    return media_types