|
import torch |
|
import numpy as np |
|
import torch.utils |
|
import torch.utils.data |
|
from torch.utils.data.sampler import WeightedRandomSampler |
|
import torch.distributed as dist |
|
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize |
|
|
|
from .datasets import RealFakeDataset |
|
|
|
|
|
|
|
def get_bal_sampler(dataset):
    """Build a class-balanced :class:`WeightedRandomSampler` for *dataset*.

    Each sample is weighted by the inverse frequency of its class label, so
    that minority classes are drawn as often as majority classes.

    Args:
        dataset: a concatenated dataset exposing ``.datasets``, where every
            sub-dataset exposes a ``.targets`` sequence of integer labels.
            (Matches ``torch.utils.data.ConcatDataset`` wrapping datasets
            that follow the torchvision ``targets`` convention.)

    Returns:
        torch.utils.data.WeightedRandomSampler drawing ``len(targets)``
        samples per epoch according to the inverse-frequency weights.
    """
    # Gather every label across all concatenated sub-datasets.
    labels = []
    for sub in dataset.datasets:
        labels += list(sub.targets)

    # Per-class weight = 1 / count, then expand to one weight per sample.
    class_counts = np.bincount(labels)
    class_weights = 1. / torch.tensor(class_counts, dtype=torch.float)
    per_sample_weights = class_weights[labels]

    return WeightedRandomSampler(weights=per_sample_weights,
                                 num_samples=len(per_sample_weights))
|
|
|
|
|
def create_train_val_dataloader(opt, clip_model, transform, k_split: float,
                                num_workers: int = 16):
    """Split one ``RealFakeDataset`` into train/val loaders.

    Args:
        opt: option namespace; this function reads ``opt.batch_size``.
        clip_model: feature extractor forwarded to ``RealFakeDataset``.
        transform: image transform forwarded to ``RealFakeDataset``.
        k_split: fraction of the dataset assigned to training, in [0, 1].
        num_workers: worker processes per DataLoader (default 16, matching
            the previous hard-coded value).

    Returns:
        (train_loader, val_loader) tuple of ``torch.utils.data.DataLoader``.

    Raises:
        ValueError: if ``k_split`` is outside [0, 1].
    """
    # Fail early with a clear message instead of an obscure random_split error.
    if not 0.0 <= k_split <= 1.0:
        raise ValueError(f"k_split must be in [0, 1], got {k_split}")

    dataset = RealFakeDataset(opt, clip_model, transform)

    dataset_size = len(dataset)
    train_size = int(dataset_size * k_split)
    val_size = dataset_size - train_size

    train_dataset, val_dataset = torch.utils.data.random_split(
        dataset, [train_size, val_size])

    # NOTE(review): neither loader shuffles (the original computed a
    # `shuffle` flag from opt but never used it, so it was removed as dead
    # code) — confirm that an unshuffled train loader is intended.
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=opt.batch_size,
                                               shuffle=False,
                                               num_workers=num_workers)
    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=opt.batch_size,
                                             shuffle=False,
                                             num_workers=num_workers)

    return train_loader, val_loader
|
|
|
|
|
def create_test_dataloader(opt, clip_model, transform):
    """Build a single DataLoader over a ``RealFakeDataset``.

    Args:
        opt: option namespace; reads ``opt.isTrain``, ``opt.class_bal``,
            ``opt.serial_batches`` and ``opt.batch_size``.
        clip_model: feature extractor forwarded to ``RealFakeDataset``.
        transform: image transform forwarded to ``RealFakeDataset``.

    Returns:
        torch.utils.data.DataLoader, optionally using a class-balanced
        sampler when ``opt.class_bal`` is set.
    """
    # Shuffle only in (non-balanced) training mode; when a sampler is used,
    # shuffle must stay False — DataLoader forbids combining the two.
    use_shuffle = False
    if opt.isTrain and not opt.class_bal:
        use_shuffle = not opt.serial_batches

    dataset = RealFakeDataset(opt, clip_model, transform)

    balanced_sampler = None
    if opt.class_bal:
        balanced_sampler = get_bal_sampler(dataset)

    return torch.utils.data.DataLoader(dataset,
                                       batch_size=opt.batch_size,
                                       shuffle=use_shuffle,
                                       sampler=balanced_sampler,
                                       num_workers=16)
|
|