File size: 1,944 Bytes
117183e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
import random
import torchvision.transforms.functional as F
from torchvision import transforms

class RandomCropPair:
    """Crop the same randomly chosen window out of two images.

    The crop window is sampled once, from the first image, and the identical
    window is then cut out of both images so they stay aligned.

    Args:
        size: Desired crop size. An ``int`` produces a square
            ``(size, size)`` crop, a one-element sequence is treated the
            same way, and a two-element sequence is read as
            ``(height, width)``.

    Raises:
        ValueError: If ``size`` is a sequence whose length is not 1 or 2.
        TypeError: If ``size`` is neither an ``int`` nor an iterable.
    """

    def __init__(self, size):
        # RandomCrop.get_params unpacks its output size into two values, so a
        # bare int (common in config files) must be expanded up front.
        if isinstance(size, int):
            size = (size, size)
        else:
            size = tuple(size)
            if len(size) == 1:
                size = (size[0], size[0])
            elif len(size) != 2:
                raise ValueError(
                    "size must be an int or a sequence of 1 or 2 values, "
                    f"got {len(size)} values"
                )
        self.size = size

    def __call__(self, img1, img2):
        """Return ``(img1, img2)`` cropped to one shared random window.

        NOTE(review): the window is derived from ``img1`` only; ``img2`` is
        presumably the same spatial size - this is not checked here.
        """
        top, left, height, width = transforms.RandomCrop.get_params(img1, self.size)
        img1 = F.crop(img1, top, left, height, width)
        img2 = F.crop(img2, top, left, height, width)
        return img1, img2

class ResizePair:
    """Resize both images of a pair to the same target size.

    Args:
        size: Target size, forwarded unchanged to ``F.resize`` for each image.
    """

    def __init__(self, size):
        self.size = size

    def __call__(self, img1, img2):
        """Return ``(img1, img2)`` with each image resized to ``self.size``."""
        # antialias=True is passed explicitly to avoid the torchvision warning.
        return tuple(
            F.resize(image, self.size, antialias=True) for image in (img1, img2)
        )

class RandomHorizontalFlipPair:
    """Horizontally flip both images of a pair together, with probability ``p``.

    Args:
        p: Probability of flipping. A single random draw decides for both
            images, so they are either both flipped or both left alone.
    """

    def __init__(self, p=0.5):
        self.p = p

    def __call__(self, img1, img2):
        """Return the pair, either both flipped horizontally or both untouched."""
        should_flip = random.random() < self.p
        if not should_flip:
            return img1, img2
        return F.hflip(img1), F.hflip(img2)

class RandomVerticalFlipPair:
    """Vertically flip both images of a pair together, with probability ``p``.

    Args:
        p: Probability of flipping. A single random draw decides for both
            images, so they are either both flipped or both left alone.
    """

    def __init__(self, p=0.5):
        self.p = p

    def __call__(self, img1, img2):
        """Return the pair, either both flipped vertically or both untouched."""
        should_flip = random.random() < self.p
        if not should_flip:
            return img1, img2
        return F.vflip(img1), F.vflip(img2)

def get_transforms(transforms_config):
    """Build the list of paired transforms described by a configuration.

    Args:
        transforms_config: Iterable of mappings, one per transform. Each
            mapping has a ``'type'`` key (``'RandomCrop'``, ``'Resize'``,
            ``'RandomHorizontalFlip'`` or ``'RandomVerticalFlip'``) and an
            optional ``'params'`` mapping of keyword arguments for that
            transform's constructor. A missing or ``None`` ``'params'``
            means "use the constructor defaults". ``None`` is treated like
            an empty configuration.

    Returns:
        list: One transform instance per entry, in configuration order.

    Raises:
        KeyError: If an entry has no ``'type'`` key.
        ValueError: If an entry names an unsupported transform type.
    """
    transform_list = []
    # A config section with no entries may arrive as None; treat it as empty.
    for transform in transforms_config or ():
        transform_type = transform['type']
        # 'params' is optional: the flip transforms have all-default
        # arguments, and an empty config value is typically loaded as None.
        params = transform.get('params') or {}

        if transform_type == 'RandomCrop':
            pair_cls = RandomCropPair
        elif transform_type == 'Resize':
            pair_cls = ResizePair
        elif transform_type == 'RandomHorizontalFlip':
            pair_cls = RandomHorizontalFlipPair
        elif transform_type == 'RandomVerticalFlip':
            pair_cls = RandomVerticalFlipPair
        else:
            raise ValueError(f"Unsupported transform type: {transform_type}")

        transform_list.append(pair_cls(**params))

    return transform_list