import torch
import torch.nn as nn
import albumentations as A
from albumentations.pytorch import ToTensorV2


def get_transforms(means, stds):
    """Build the albumentations pipelines for training and evaluation.

    Both pipelines normalize the image with ``means``/``stds`` and finish
    with ``ToTensorV2``. The training pipeline additionally pads the image
    to at least 36x36, takes a random 32x32 crop, randomly flips it
    horizontally and applies ``Cutout``.

    Args:
        means: Per-channel mean forwarded to ``A.Normalize(mean=...)``.
        stds: Per-channel standard deviation forwarded to
            ``A.Normalize(std=...)``.

    Returns:
        A ``(train_transforms, test_transforms)`` tuple of ``A.Compose``
        objects.
    """
    train_transforms = A.Compose(
        [
            A.Normalize(mean=means, std=stds, always_apply=True),
            A.PadIfNeeded(min_height=36, min_width=36, always_apply=True),
            A.RandomCrop(height=32, width=32, always_apply=True),
            A.HorizontalFlip(),
            # Cutout runs after Normalize, so the image is already in
            # normalized space where the dataset mean corresponds to 0.
            # Filling with the raw `means` would paint the holes with a
            # value offset from the mean; 0 is the mean-coloured fill.
            A.Cutout(fill_value=0),
            ToTensorV2(),
        ]
    )

    test_transforms = A.Compose(
        [
            A.Normalize(mean=means, std=stds, always_apply=True),
            ToTensorV2(),
        ]
    )

    return train_transforms, test_transforms