import torch
import numpy as np
import torch.utils
import torch.utils.data
from torch.utils.data.sampler import WeightedRandomSampler
import torch.distributed as dist
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
from .datasets import RealFakeDataset
def get_bal_sampler(dataset):
    """Build a class-balanced sampler from a concatenated dataset of labelled sub-datasets."""
    # Collect the integer class labels from every sub-dataset.
    targets = []
    for d in dataset.datasets:
        targets.extend(d.targets)

    # Weight each sample by the inverse frequency of its class so that
    # all classes are drawn with equal probability.
    ratio = np.bincount(targets)
    w = 1. / torch.tensor(ratio, dtype=torch.float)
    sample_weights = w[targets]
    sampler = WeightedRandomSampler(weights=sample_weights,
                                    num_samples=len(sample_weights))
    return sampler
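
# Usage sketch (not part of the original module): get_bal_sampler assumes a
# torch.utils.data.ConcatDataset whose sub-datasets expose integer `targets`,
# for example:
#
#   combined = torch.utils.data.ConcatDataset([real_ds, fake_ds])
#   loader = torch.utils.data.DataLoader(combined,
#                                        batch_size=32,
#                                        sampler=get_bal_sampler(combined),
#                                        num_workers=16)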
def create_train_val_dataloader(opt, clip_model, transform, k_split: float):
    # Shuffle training samples unless serial batches or class balancing are requested.
    shuffle = not opt.serial_batches if (opt.isTrain and not opt.class_bal) else False
    dataset = RealFakeDataset(opt, clip_model, transform)

    # Split into training and validation sets.
    dataset_size = len(dataset)
    train_size = int(dataset_size * k_split)
    val_size = dataset_size - train_size
    train_dataset, val_dataset = torch.utils.data.random_split(dataset, [train_size, val_size])

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=opt.batch_size,
                                               shuffle=shuffle,
                                               num_workers=16)
    # The validation loader is never shuffled.
    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=opt.batch_size,
                                             shuffle=False,
                                             num_workers=16)
    return train_loader, val_loader
def create_test_dataloader(opt, clip_model, transform):
    # Shuffle only during training and when class balancing is not handled by a sampler.
    shuffle = not opt.serial_batches if (opt.isTrain and not opt.class_bal) else False
    dataset = RealFakeDataset(opt, clip_model, transform)
    # With class balancing enabled, a weighted sampler replaces shuffling
    # (shuffle is already False in that case, as DataLoader requires).
    sampler = get_bal_sampler(dataset) if opt.class_bal else None
    data_loader = torch.utils.data.DataLoader(dataset,
                                              batch_size=opt.batch_size,
                                              shuffle=shuffle,
                                              sampler=sampler,
                                              num_workers=16)
    return data_loader
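
# Example usage (sketch; assumes `opt` provides isTrain, serial_batches, class_bal
# and batch_size, and that `clip_model` and `transform` come from the caller's
# CLIP setup; none of these names are defined in this module):
#
#   train_loader, val_loader = create_train_val_dataloader(opt, clip_model, transform, k_split=0.9)
#   test_loader = create_test_dataloader(opt, clip_model, transform)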