import tensorflow as tf


class AugmentationFactory:
    """Random augmentations applied jointly to an (input, enhanced) image pair.

    Each transform draws its random parameters once and applies the same
    transform to both images, so the two images stay spatially aligned.
    """

    def __init__(self, image_size) -> None:
        # Side length of the square patch produced by ``random_crop``.
        self.image_size = image_size

    def random_crop(self, input_image, enhanced_image):
        """Crop the same random ``image_size`` x ``image_size`` window from both images.

        Axis 0 of each image is treated as height and axis 1 as width.

        NOTE(review): only ``input_image``'s shape is inspected. This assumes
        ``enhanced_image`` has the same spatial size; that is not checked,
        so confirm it with the callers. Both spatial dimensions must be at
        least ``image_size``; otherwise the offset range is empty and
        ``tf.random.uniform`` raises.

        Returns:
            Tuple ``(input_image_cropped, enhanced_image_cropped)``.
        """
        size = self.image_size
        input_image_shape = tf.shape(input_image)[:2]
        # Offsets lie in [0, dim - size], so the window always fits.
        # The width offset is drawn before the height offset, matching the
        # original op creation order so that seeded pipelines reproduce.
        low_w = tf.random.uniform(
            shape=(), maxval=input_image_shape[1] - size + 1, dtype=tf.int32
        )
        low_h = tf.random.uniform(
            shape=(), maxval=input_image_shape[0] - size + 1, dtype=tf.int32
        )
        input_image_cropped = input_image[
            low_h : low_h + size, low_w : low_w + size
        ]
        enhanced_image_cropped = enhanced_image[
            low_h : low_h + size, low_w : low_w + size
        ]
        return input_image_cropped, enhanced_image_cropped

    def random_horizontal_flip(self, input_image, enhanced_image):
        """With probability 0.5, flip both images left-to-right; otherwise return them unchanged."""
        return tf.cond(
            tf.random.uniform(shape=(), maxval=1) < 0.5,
            lambda: (input_image, enhanced_image),
            lambda: (
                tf.image.flip_left_right(input_image),
                tf.image.flip_left_right(enhanced_image),
            ),
        )

    def random_vertical_flip(self, input_image, enhanced_image):
        """With probability 0.5, flip both images upside down; otherwise return them unchanged."""
        return tf.cond(
            tf.random.uniform(shape=(), maxval=1) < 0.5,
            lambda: (input_image, enhanced_image),
            lambda: (
                tf.image.flip_up_down(input_image),
                tf.image.flip_up_down(enhanced_image),
            ),
        )

    @staticmethod
    def random_rotate(input_image, enhanced_image):
        """Rotate both images by the same random number (0-3) of 90-degree turns.

        This was declared without ``self``, so calling it on an instance
        raised ``TypeError``. As a static method it can be called both on an
        instance and on the class, as before.

        Returns:
            Tuple ``(rotated_input_image, rotated_enhanced_image)``.
        """
        condition = tf.random.uniform(shape=(), maxval=4, dtype=tf.int32)
        return (
            tf.image.rot90(input_image, condition),
            tf.image.rot90(enhanced_image, condition),
        )