File size: 674 Bytes
08eb57c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 |
import torch
import torch.nn as nn
import albumentations as A
from albumentations.pytorch import ToTensorV2
def get_transforms(means, stds):
    """Build the albumentations train/test augmentation pipelines.

    Args:
        means: Per-channel mean forwarded to ``A.Normalize``. Presumably the
            dataset mean on a 0-1 scale, since ``Normalize`` scales it by its
            default ``max_pixel_value`` (255) — confirm with the caller.
        stds: Per-channel standard deviation forwarded to ``A.Normalize``.

    Returns:
        A ``(train_transforms, test_transforms)`` tuple of ``A.Compose``
        pipelines. Both normalize and convert to a tensor via ``ToTensorV2``.
        The train pipeline additionally pads to at least 36x36, takes a random
        32x32 crop, applies a random horizontal flip, and applies ``Cutout``.
        The test pipeline does no resizing, so inputs are presumably already
        32x32 — verify against the dataset.
    """
    train_transforms = A.Compose(
        [
            A.Normalize(mean=means, std=stds, always_apply=True),
            # Pad and then randomly crop back to 32x32. For 32x32 inputs this
            # shifts the image by up to 2 px in each direction.
            A.PadIfNeeded(min_height=36, min_width=36, always_apply=True),
            A.RandomCrop(height=32, width=32, always_apply=True),
            A.HorizontalFlip(),
            # Cutout runs AFTER Normalize, so pixels are already standardized.
            # If `means` are the dataset means, the mean colour is 0 here.
            # Filling with the raw `means` values would paint the holes at
            # roughly mean + means * std instead of the mean colour.
            # NOTE(review): A.Cutout is deprecated in favour of A.CoarseDropout
            # in newer albumentations releases.
            A.Cutout(fill_value=0),
            ToTensorV2(),
        ]
    )
    test_transforms = A.Compose(
        [
            A.Normalize(mean=means, std=stds, always_apply=True),
            ToTensorV2(),
        ]
    )
    return train_transforms, test_transforms