import torch
import numpy as np
import torch.utils
import torch.utils.data
from torch.utils.data.sampler import WeightedRandomSampler
import torch.distributed as dist
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize

from .datasets import RealFakeDataset

# Worker count shared by every DataLoader built in this module (was a magic
# number repeated at each call site).
NUM_WORKERS = 16


def get_bal_sampler(dataset):
    """Build a class-balanced WeightedRandomSampler for a ConcatDataset.

    Each sample is weighted by the inverse frequency of its class so that
    every class is drawn with (approximately) equal probability.

    Args:
        dataset: a dataset exposing ``.datasets`` (e.g. a ConcatDataset),
            where each sub-dataset exposes a ``.targets`` sequence of
            integer class labels.

    Returns:
        torch.utils.data.WeightedRandomSampler over all samples.
    """
    targets = []
    for d in dataset.datasets:
        targets.extend(d.targets)

    # Inverse-frequency weight per class, then gather one weight per sample.
    ratio = np.bincount(targets)
    w = 1. / torch.tensor(ratio, dtype=torch.float)
    sample_weights = w[targets]
    sampler = WeightedRandomSampler(weights=sample_weights,
                                    num_samples=len(sample_weights))
    return sampler


def create_train_val_dataloader(opt, clip_model, transform, k_split: float):
    """Create train/validation DataLoaders from a single RealFakeDataset.

    The dataset is randomly split into a training subset of size
    ``int(len(dataset) * k_split)`` and a validation subset with the rest.

    Args:
        opt: options object; reads ``serial_batches``, ``isTrain``,
            ``class_bal`` and ``batch_size``.
        clip_model: model handle forwarded to RealFakeDataset.
        transform: image transform forwarded to RealFakeDataset.
        k_split: fraction (0..1) of samples assigned to the train split.

    Returns:
        (train_loader, val_loader) tuple of DataLoaders.
    """
    # Shuffle training batches unless serial batches are requested; a
    # class-balanced run would rely on a sampler instead of shuffling.
    shuffle = not opt.serial_batches if (opt.isTrain and not opt.class_bal) else False
    dataset = RealFakeDataset(opt, clip_model, transform)

    # Split into train and validation subsets.
    dataset_size = len(dataset)
    train_size = int(dataset_size * k_split)
    val_size = dataset_size - train_size
    train_dataset, val_dataset = torch.utils.data.random_split(
        dataset, [train_size, val_size])

    # BUG FIX: `shuffle` was computed but never used — the train loader was
    # hard-coded to shuffle=False, so training data was never shuffled.
    # The validation loader deliberately stays unshuffled.
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=opt.batch_size,
                                               shuffle=shuffle,
                                               num_workers=NUM_WORKERS)
    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=opt.batch_size,
                                             shuffle=False,
                                             num_workers=NUM_WORKERS)
    return train_loader, val_loader


def create_test_dataloader(opt, clip_model, transform):
    """Create a DataLoader over the full RealFakeDataset for evaluation.

    When ``opt.class_bal`` is set, a class-balanced sampler is used (which
    forces ``shuffle`` to stay False, as required by DataLoader).

    Args:
        opt: options object; reads ``serial_batches``, ``isTrain``,
            ``class_bal`` and ``batch_size``.
        clip_model: model handle forwarded to RealFakeDataset.
        transform: image transform forwarded to RealFakeDataset.

    Returns:
        A single torch DataLoader over the dataset.
    """
    shuffle = not opt.serial_batches if (opt.isTrain and not opt.class_bal) else False
    dataset = RealFakeDataset(opt, clip_model, transform)
    sampler = get_bal_sampler(dataset) if opt.class_bal else None

    data_loader = torch.utils.data.DataLoader(dataset,
                                              batch_size=opt.batch_size,
                                              shuffle=shuffle,
                                              sampler=sampler,
                                              num_workers=NUM_WORKERS)
    return data_loader