import random

import torchvision.transforms.functional as F
from torchvision import transforms


class RandomCropPair:
    """Apply an identical random crop to both images of a pair."""

    def __init__(self, size):
        self.size = size

    def __call__(self, img1, img2):
        # Sample the crop parameters once so the same region is cut from both images.
        i, j, h, w = transforms.RandomCrop.get_params(img1, self.size)
        img1 = F.crop(img1, i, j, h, w)
        img2 = F.crop(img2, i, j, h, w)
        return img1, img2


class ResizePair:
    """Resize both images of a pair to the same target size."""

    def __init__(self, size):
        self.size = size

    def __call__(self, img1, img2):
        # antialias=True is used to avoid torchvision warning
        img1 = F.resize(img1, self.size, antialias=True)
        img2 = F.resize(img2, self.size, antialias=True)
        return img1, img2


class RandomHorizontalFlipPair:
    """Horizontally flip both images of a pair with probability p."""

    def __init__(self, p=0.5):
        self.p = p

    def __call__(self, img1, img2):
        # A single random draw keeps the flip decision consistent across the pair.
        if random.random() < self.p:
            img1 = F.hflip(img1)
            img2 = F.hflip(img2)
        return img1, img2


class RandomVerticalFlipPair:
    """Vertically flip both images of a pair with probability p."""

    def __init__(self, p=0.5):
        self.p = p

    def __call__(self, img1, img2):
        if random.random() < self.p:
            img1 = F.vflip(img1)
            img2 = F.vflip(img2)
        return img1, img2


def get_transforms(transforms_config):
    """Build a list of paired transforms from a config of {'type', 'params'} entries."""
    transform_list = []
    for transform in transforms_config:
        transform_type = transform['type']
        params = transform['params']
        if transform_type == 'RandomCrop':
            transform_list.append(RandomCropPair(**params))
        elif transform_type == 'Resize':
            transform_list.append(ResizePair(**params))
        elif transform_type == 'RandomHorizontalFlip':
            transform_list.append(RandomHorizontalFlipPair(**params))
        elif transform_type == 'RandomVerticalFlip':
            transform_list.append(RandomVerticalFlipPair(**params))
        else:
            raise ValueError(f"Unsupported transform type: {transform_type}")
    return transform_list
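

# --- Minimal usage sketch (not part of the original file) ---------------------
# The config below is an assumed example of the {'type': ..., 'params': ...}
# schema that get_transforms() expects; the sizes, probabilities, and the
# `input_img` / `target_img` names are illustrative only. Any pair of PIL
# images that must receive identical spatial augmentations would work.
if __name__ == "__main__":
    from PIL import Image

    example_config = [
        {'type': 'Resize', 'params': {'size': (256, 256)}},
        {'type': 'RandomCrop', 'params': {'size': (224, 224)}},
        {'type': 'RandomHorizontalFlip', 'params': {'p': 0.5}},
    ]
    pair_transforms = get_transforms(example_config)

    # Two dummy images standing in for an (input, target) pair.
    input_img = Image.new('RGB', (320, 320))
    target_img = Image.new('L', (320, 320))

    # Each paired transform is applied to both images with shared parameters,
    # so random crop/flip decisions stay aligned between input and target.
    for t in pair_transforms:
        input_img, target_img = t(input_img, target_img)

    print(input_img.size, target_img.size)  # both should be (224, 224)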