ybbwcwaps
AI Video
3cc4a06
raw
history blame
2.45 kB
import torch
import numpy as np
import torch.utils
import torch.utils.data
from torch.utils.data.sampler import WeightedRandomSampler
import torch.distributed as dist
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
from .datasets import RealFakeDataset
def get_bal_sampler(dataset):
    """Build a class-balancing sampler from the labels of a composite dataset.

    ``dataset`` must expose a ``datasets`` attribute (e.g. a ``ConcatDataset``).
    Each child must have a ``targets`` sequence of non-negative integer class
    labels.

    Every sample is weighted by the reciprocal of its class's count, so each
    class is drawn with roughly equal probability. The sampler produces one
    draw per collected label, sampling with replacement (the
    ``WeightedRandomSampler`` default).

    Returns:
        WeightedRandomSampler over the concatenated labels.
    """
    # Flatten the labels of every child dataset, preserving their order.
    labels = [label for child in dataset.datasets for label in child.targets]

    class_counts = np.bincount(labels)
    # Inverse class frequency: rarer classes receive proportionally larger weights.
    class_weights = 1. / torch.tensor(class_counts, dtype=torch.float)
    per_sample_weights = class_weights[labels]

    return WeightedRandomSampler(
        weights=per_sample_weights,
        num_samples=len(per_sample_weights),
    )
def create_train_val_dataloader(opt, clip_model, transform, k_split: float, num_workers: int = 16):
    """Build train and validation DataLoaders from a random split of ``RealFakeDataset``.

    Args:
        opt: Options object. This function reads ``opt.isTrain``,
            ``opt.class_bal``, ``opt.serial_batches`` and ``opt.batch_size``.
            It also forwards ``opt`` to ``RealFakeDataset``.
        clip_model: Forwarded unchanged to ``RealFakeDataset``.
        transform: Forwarded unchanged to ``RealFakeDataset``.
        k_split: Fraction of samples assigned to the training subset. Must lie
            in [0, 1]. The training size is ``int(len(dataset) * k_split)`` and
            the remainder goes to validation.
        num_workers: Number of worker processes for each DataLoader. Defaults
            to 16.

    Returns:
        Tuple ``(train_loader, val_loader)``.

    Raises:
        ValueError: If ``k_split`` is outside [0, 1]. Without this check a
            value above 1 yields a negative validation length, which
            ``random_split`` silently turns into a truncated/empty split.
    """
    if not 0.0 <= k_split <= 1.0:
        raise ValueError(f"k_split must be in [0, 1], got {k_split!r}")

    # Shuffle training data only in training mode, without class balancing,
    # and when serial batching is not requested.
    shuffle = not opt.serial_batches if (opt.isTrain and not opt.class_bal) else False
    dataset = RealFakeDataset(opt, clip_model, transform)

    # Split into training and validation sets.
    dataset_size = len(dataset)
    train_size = int(dataset_size * k_split)
    val_size = dataset_size - train_size
    # NOTE(review): uses the global torch RNG, so the split is reproducible
    # only if the caller seeds torch beforehand.
    train_dataset, val_dataset = torch.utils.data.random_split(dataset, [train_size, val_size])

    # NOTE(review): when opt.class_bal is set, no balancing sampler is applied
    # here. get_bal_sampler reads `.datasets`, which a random_split Subset
    # does not expose, so confirm whether balancing is wanted for this path.
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=opt.batch_size,
        # Use the computed policy so training batches are re-ordered each epoch.
        shuffle=shuffle,
        num_workers=num_workers,
    )
    # Validation order is kept fixed so evaluation is deterministic across epochs.
    val_loader = torch.utils.data.DataLoader(
        val_dataset,
        batch_size=opt.batch_size,
        shuffle=False,
        num_workers=num_workers,
    )
    return train_loader, val_loader
def create_test_dataloader(opt, clip_model, transform):
    """Build a single DataLoader over the full ``RealFakeDataset``.

    Args:
        opt: Options object. This function reads ``opt.serial_batches``,
            ``opt.isTrain``, ``opt.class_bal`` and ``opt.batch_size``. It also
            forwards ``opt`` to ``RealFakeDataset``.
        clip_model: Forwarded unchanged to ``RealFakeDataset``.
        transform: Forwarded unchanged to ``RealFakeDataset``.

    Returns:
        A DataLoader with 16 workers. When ``opt.class_bal`` is set it uses the
        sampler from ``get_bal_sampler``. Otherwise it shuffles only in training
        mode with serial batching off.
    """
    # Shuffling is only enabled in training mode without class balancing.
    # When class balancing is on, the sampler controls ordering instead, and
    # DataLoader rejects shuffle=True together with a sampler.
    if opt.isTrain and not opt.class_bal:
        shuffle = not opt.serial_batches
    else:
        shuffle = False

    dataset = RealFakeDataset(opt, clip_model, transform)

    if opt.class_bal:
        sampler = get_bal_sampler(dataset)
    else:
        sampler = None

    loader_kwargs = dict(
        batch_size=opt.batch_size,
        shuffle=shuffle,
        sampler=sampler,
        num_workers=16,
    )
    return torch.utils.data.DataLoader(dataset, **loader_kwargs)