S12 / augment /augment.py
Sijuade's picture
Create augment/augment.py
08eb57c
raw
history blame
674 Bytes
import torch
import torch.nn as nn
import albumentations as A
from albumentations.pytorch import ToTensorV2
def get_transforms(means, stds):
    """Build the albumentations pipelines used for training and testing.

    Args:
        means: Mean value(s) forwarded to ``A.Normalize``; also reused as
            the ``A.Cutout`` fill value in the training pipeline.
        stds: Standard deviation value(s) forwarded to ``A.Normalize``.

    Returns:
        A ``(train_transforms, test_transforms)`` tuple of ``A.Compose``
        objects.
    """
    # Training: normalize, pad up to at least 36x36, take a random 32x32
    # crop, then apply flip and cutout with their default settings before
    # converting to a tensor.
    # NOTE(review): Cutout runs after Normalize, so ``fill_value=means`` is
    # written into already-normalized pixels -- confirm this is intended.
    train_steps = [
        A.Normalize(mean=means, std=stds, always_apply=True),
        A.PadIfNeeded(min_height=36, min_width=36, always_apply=True),
        A.RandomCrop(height=32, width=32, always_apply=True),
        A.HorizontalFlip(),
        A.Cutout(fill_value=means),
        ToTensorV2(),
    ]

    # Testing: normalization and tensor conversion only, no augmentation.
    test_steps = [
        A.Normalize(mean=means, std=stds, always_apply=True),
        ToTensorV2(),
    ]

    train_pipeline = A.Compose(train_steps)
    test_pipeline = A.Compose(test_steps)
    return train_pipeline, test_pipeline