Spaces:
Runtime error
Runtime error
File size: 3,039 Bytes
2673600 c23af12 2673600 056188b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 |
import tensorflow as tf
class AugmentationFactory:
    """Random augmentations applied jointly to an (input, enhanced) image pair.

    Each method draws its random parameters once and applies them to both
    images, so the two returned images receive the same transformation.
    """

    def __init__(self, image_size) -> None:
        """Remember the side length of the square window used by random_crop."""
        self.image_size = image_size

    def random_crop(self, input_image, enhanced_image):
        """Cut the same randomly placed square window out of both images.

        The window origin is sampled from the shape of ``input_image`` only
        and then reused for ``enhanced_image``.
        NOTE(review): this presumably relies on both images sharing the same
        height/width and being at least ``image_size`` along axes 0 and 1 --
        confirm against the caller.

        Returns:
            Tuple ``(input_image_cropped, enhanced_image_cropped)``.
        """
        input_image_shape = tf.shape(input_image)[:2]
        # Integer sampling excludes maxval, so "+ 1" makes the largest valid
        # offset (dimension - image_size) reachable.
        low_w = tf.random.uniform(
            shape=(), maxval=input_image_shape[1] - self.image_size + 1, dtype=tf.int32
        )
        low_h = tf.random.uniform(
            shape=(), maxval=input_image_shape[0] - self.image_size + 1, dtype=tf.int32
        )
        # One shared origin for both images keeps the crops aligned.
        input_image_cropped = input_image[
            low_h : low_h + self.image_size, low_w : low_w + self.image_size
        ]
        enhanced_image_cropped = enhanced_image[
            low_h : low_h + self.image_size, low_w : low_w + self.image_size
        ]
        return input_image_cropped, enhanced_image_cropped

    def random_horizontal_flip(self, input_image, enhanced_image):
        """Return the pair unchanged or both flipped left-right.

        A single draw below 0.5 keeps the pair as is; otherwise both images
        are flipped together.
        """
        return tf.cond(
            tf.random.uniform(shape=(), maxval=1) < 0.5,
            lambda: (input_image, enhanced_image),
            lambda: (
                tf.image.flip_left_right(input_image),
                tf.image.flip_left_right(enhanced_image),
            ),
        )

    def random_vertical_flip(self, input_image, enhanced_image):
        """Return the pair unchanged or both flipped up-down.

        A single draw below 0.5 keeps the pair as is; otherwise both images
        are flipped together.
        """
        return tf.cond(
            tf.random.uniform(shape=(), maxval=1) < 0.5,
            lambda: (input_image, enhanced_image),
            lambda: (
                tf.image.flip_up_down(input_image),
                tf.image.flip_up_down(enhanced_image),
            ),
        )

    def random_rotate(self, input_image, enhanced_image):
        """Rotate both images by the same random number (0-3) of quarter turns.

        Returns:
            Tuple ``(rotated_input_image, rotated_enhanced_image)``.
        """
        condition = tf.random.uniform(shape=(), maxval=4, dtype=tf.int32)
        return tf.image.rot90(input_image, condition), tf.image.rot90(
            enhanced_image, condition
        )
class UnpairedAugmentationFactory:
    """Random crop, flip and rotation helpers that operate on one image."""

    def __init__(self, image_size) -> None:
        # Side length of the square window produced by random_crop.
        self.image_size = image_size

    def _random_offset(self, extent):
        """Sample an int32 scalar offset so a window of image_size fits in extent."""
        return tf.random.uniform(
            shape=(), maxval=extent - self.image_size + 1, dtype=tf.int32
        )

    def random_crop(self, image):
        """Return a randomly positioned image_size x image_size window of image."""
        height_width = tf.shape(image)[:2]
        # The horizontal offset is sampled before the vertical one.
        left = self._random_offset(height_width[1])
        top = self._random_offset(height_width[0])
        bottom = top + self.image_size
        right = left + self.image_size
        return image[top:bottom, left:right]

    def random_horizontal_flip(self, image):
        """Return image as is when the draw is below 0.5, else flipped left-right."""
        keep_original = tf.random.uniform(shape=(), maxval=1) < 0.5

        def unchanged():
            return image

        def mirrored():
            return tf.image.flip_left_right(image)

        return tf.cond(keep_original, true_fn=unchanged, false_fn=mirrored)

    def random_vertical_flip(self, image):
        """Return image as is when the draw is below 0.5, else flipped up-down."""
        keep_original = tf.random.uniform(shape=(), maxval=1) < 0.5

        def unchanged():
            return image

        def upside_down():
            return tf.image.flip_up_down(image)

        return tf.cond(keep_original, true_fn=unchanged, false_fn=upside_down)

    def random_rotate(self, image):
        """Rotate image by a random number (0-3) of quarter turns."""
        quarter_turns = tf.random.uniform(shape=(), maxval=4, dtype=tf.int32)
        return tf.image.rot90(image, k=quarter_turns)
|