Sijuade commited on
Commit
08eb57c
1 Parent(s): cb1f74d

Create augment/augment.py

Browse files
Files changed (1) hide show
  1. augment/augment.py +27 -0
augment/augment.py ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import albumentations as A
4
+ from albumentations.pytorch import ToTensorV2
5
+
6
+
7
+
8
+ def get_transforms(means, stds):
9
+ train_transforms = A.Compose(
10
+ [
11
+ A.Normalize(mean=means, std=stds, always_apply=True),
12
+ A.PadIfNeeded(min_height=36, min_width=36, always_apply=True),
13
+ A.RandomCrop(height=32, width=32, always_apply=True),
14
+ A.HorizontalFlip(),
15
+ A.Cutout (fill_value=means),
16
+ ToTensorV2(),
17
+ ]
18
+ )
19
+
20
+ test_transforms = A.Compose(
21
+ [
22
+ A.Normalize(mean=means, std=stds, always_apply=True),
23
+ ToTensorV2(),
24
+ ]
25
+ )
26
+
27
+ return(train_transforms, test_transforms)