# enhance-me / enhance_me/augmentation.py
# Source: Hugging Face file view — last commit 2673600 ("added dataloader")
# by geekyrakshit; file size 1.88 kB; flagged "No virus" by the hub scanner.
import tensorflow as tf
class AugmentationFactory:
    """Paired random augmentations for (input, enhanced) image pairs.

    Every augmentation method takes an ``input_image`` together with its
    matching ``enhanced_image`` and applies the *same* random transform to
    both, so the pair stays spatially aligned. Each method returns a 2-tuple
    ``(input_image, enhanced_image)``, so the methods can be chained (for
    example as successive ``tf.data.Dataset.map`` functions).

    Args:
        image_size: Side length, in pixels, of the square patch produced by
            :meth:`random_crop`.
    """

    def __init__(self, image_size) -> None:
        self.image_size = image_size

    def random_crop(self, input_image, enhanced_image):
        """Crop the same random ``image_size`` x ``image_size`` patch from both images.

        The top-left offset is sampled from the first two dimensions of
        ``input_image`` (height, width) and reused for ``enhanced_image``.

        NOTE(review): assumes both images are (H, W, ...) tensors with the
        same spatial size and at least ``image_size`` pixels on each side.
        For a smaller image the ``maxval`` below is <= 0, and the integer
        ``tf.random.uniform`` call is expected to fail. Confirm with callers.

        Returns:
            ``(input_image_cropped, enhanced_image_cropped)``.
        """
        spatial_shape = tf.shape(input_image)[:2]
        # maxval is exclusive, hence "+ 1" so that an offset placing the patch
        # flush against the right/bottom edge is still reachable.
        # Width is drawn before height to keep the original sampling order.
        offset_w = tf.random.uniform(
            shape=(), maxval=spatial_shape[1] - self.image_size + 1, dtype=tf.int32
        )
        offset_h = tf.random.uniform(
            shape=(), maxval=spatial_shape[0] - self.image_size + 1, dtype=tf.int32
        )
        input_image_cropped = input_image[
            offset_h : offset_h + self.image_size,
            offset_w : offset_w + self.image_size,
        ]
        enhanced_image_cropped = enhanced_image[
            offset_h : offset_h + self.image_size,
            offset_w : offset_w + self.image_size,
        ]
        return input_image_cropped, enhanced_image_cropped

    def random_horizontal_flip(self, input_image, enhanced_image):
        """Flip both images left-right together with probability 0.5.

        Returns:
            ``(input_image, enhanced_image)``, either both unchanged or both
            flipped with ``tf.image.flip_left_right``.
        """
        return tf.cond(
            tf.random.uniform(shape=(), maxval=1) < 0.5,
            lambda: (input_image, enhanced_image),
            lambda: (
                tf.image.flip_left_right(input_image),
                tf.image.flip_left_right(enhanced_image),
            ),
        )

    def random_vertical_flip(self, input_image, enhanced_image):
        """Flip both images up-down together with probability 0.5.

        Returns:
            ``(input_image, enhanced_image)``, either both unchanged or both
            flipped with ``tf.image.flip_up_down``.
        """
        return tf.cond(
            tf.random.uniform(shape=(), maxval=1) < 0.5,
            lambda: (input_image, enhanced_image),
            lambda: (
                tf.image.flip_up_down(input_image),
                tf.image.flip_up_down(enhanced_image),
            ),
        )

    # A staticmethod (the original had no ``self`` at all). Calling it on an
    # instance no longer passes the instance in as ``input_image``, and the
    # unbound ``AugmentationFactory.random_rotate(a, b)`` form keeps working.
    @staticmethod
    def random_rotate(input_image, enhanced_image):
        """Rotate both images by the same random multiple of 90 degrees.

        The number of quarter turns ``k`` is drawn uniformly from {0, 1, 2, 3}
        and passed to ``tf.image.rot90`` for both images.

        Returns:
            ``(rotated_input_image, rotated_enhanced_image)``.
        """
        quarter_turns = tf.random.uniform(shape=(), maxval=4, dtype=tf.int32)
        return (
            tf.image.rot90(input_image, quarter_turns),
            tf.image.rot90(enhanced_image, quarter_turns),
        )