File size: 3,039 Bytes
2673600
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c23af12
2673600
 
 
 
056188b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
import tensorflow as tf


class AugmentationFactory:
    """Random augmentations applied identically to an (input, enhanced) image pair.

    Every method draws its random decision once and applies the same
    transformation to both images, so the pair stays spatially aligned.
    """

    def __init__(self, image_size: int) -> None:
        """Store ``image_size``, the side length of the square crop window."""
        self.image_size = image_size

    def random_crop(self, input_image, enhanced_image):
        """Crop the same random ``image_size`` x ``image_size`` window from both images.

        The leading two axes of each image are treated as (height, width).
        The crop offsets are derived from ``input_image``'s shape only.

        NOTE(review): assumes ``enhanced_image`` has the same height and width
        as ``input_image`` and that both are at least ``image_size`` along each
        of those axes -- confirm against the data pipeline.

        Returns:
            A tuple ``(input_image_cropped, enhanced_image_cropped)``.
        """
        input_image_shape = tf.shape(input_image)[:2]
        # maxval is exclusive, hence the +1 so the last valid offset can be drawn.
        low_w = tf.random.uniform(
            shape=(), maxval=input_image_shape[1] - self.image_size + 1, dtype=tf.int32
        )
        low_h = tf.random.uniform(
            shape=(), maxval=input_image_shape[0] - self.image_size + 1, dtype=tf.int32
        )
        high_h = low_h + self.image_size
        high_w = low_w + self.image_size
        # The same window is used for both images to keep them aligned.
        input_image_cropped = input_image[low_h:high_h, low_w:high_w]
        enhanced_image_cropped = enhanced_image[low_h:high_h, low_w:high_w]
        return input_image_cropped, enhanced_image_cropped

    def random_horizontal_flip(self, input_image, enhanced_image):
        """Flip both images left-right together, or leave both unchanged.

        A single uniform sample decides the outcome: below 0.5 the pair is
        returned as-is, otherwise both images are flipped.

        Returns:
            A tuple ``(input_image, enhanced_image)`` after the shared decision.
        """
        return tf.cond(
            tf.random.uniform(shape=(), maxval=1) < 0.5,
            lambda: (input_image, enhanced_image),
            lambda: (
                tf.image.flip_left_right(input_image),
                tf.image.flip_left_right(enhanced_image),
            ),
        )

    def random_vertical_flip(self, input_image, enhanced_image):
        """Flip both images up-down together, or leave both unchanged.

        A single uniform sample decides the outcome: below 0.5 the pair is
        returned as-is, otherwise both images are flipped.

        Returns:
            A tuple ``(input_image, enhanced_image)`` after the shared decision.
        """
        return tf.cond(
            tf.random.uniform(shape=(), maxval=1) < 0.5,
            lambda: (input_image, enhanced_image),
            lambda: (
                tf.image.flip_up_down(input_image),
                tf.image.flip_up_down(enhanced_image),
            ),
        )

    def random_rotate(self, input_image, enhanced_image):
        """Rotate both images by the same random number of quarter turns.

        The number of 90-degree rotations is an int32 drawn with ``maxval=4``
        (exclusive) and passed to ``tf.image.rot90`` for both images.

        Returns:
            A tuple ``(input_image, enhanced_image)`` after rotation.
        """
        condition = tf.random.uniform(shape=(), maxval=4, dtype=tf.int32)
        return tf.image.rot90(input_image, condition), tf.image.rot90(
            enhanced_image, condition
        )


class UnpairedAugmentationFactory:
    """Random augmentations for a single image (no paired target)."""

    def __init__(self, image_size) -> None:
        """Store ``image_size``, the side length of the square crop window."""
        self.image_size = image_size

    def random_crop(self, image):
        """Return a random ``image_size`` x ``image_size`` window of ``image``.

        The leading two axes of ``image`` are treated as (height, width).

        NOTE(review): presumably ``image`` is at least ``image_size`` along
        both of those axes -- confirm against the caller.
        """
        size = self.image_size
        spatial_dims = tf.shape(image)[:2]
        # Sampling order (width offset, then height offset) is kept as-is.
        # maxval is exclusive, so +1 makes the last valid offset reachable.
        left = tf.random.uniform(
            shape=(), maxval=spatial_dims[1] - size + 1, dtype=tf.int32
        )
        top = tf.random.uniform(
            shape=(), maxval=spatial_dims[0] - size + 1, dtype=tf.int32
        )
        return image[top : top + size, left : left + size]

    def random_horizontal_flip(self, image):
        """Return ``image`` unchanged or flipped left-right.

        One uniform sample decides: below 0.5 keeps the image, otherwise it
        is flipped.
        """

        def keep():
            return image

        def flip():
            return tf.image.flip_left_right(image)

        keep_original = tf.random.uniform(shape=(), maxval=1) < 0.5
        return tf.cond(keep_original, keep, flip)

    def random_vertical_flip(self, image):
        """Return ``image`` unchanged or flipped up-down.

        One uniform sample decides: below 0.5 keeps the image, otherwise it
        is flipped.
        """

        def keep():
            return image

        def flip():
            return tf.image.flip_up_down(image)

        keep_original = tf.random.uniform(shape=(), maxval=1) < 0.5
        return tf.cond(keep_original, keep, flip)

    def random_rotate(self, image):
        """Rotate ``image`` by a random number of quarter turns.

        The count is an int32 drawn with ``maxval=4`` (exclusive) and handed
        to ``tf.image.rot90``.
        """
        quarter_turns = tf.random.uniform(shape=(), maxval=4, dtype=tf.int32)
        return tf.image.rot90(image, quarter_turns)