geekyrakshit commited on
Commit
2673600
·
1 Parent(s): dac55f5

added dataloader

Browse files
enhance_me/__init__.py ADDED
File without changes
enhance_me/augmentation.py ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import tensorflow as tf
2
+
3
+
4
+ class AugmentationFactory:
5
+ def __init__(self, image_size) -> None:
6
+ self.image_size = image_size
7
+
8
+ def random_crop(self, input_image, enhanced_image):
9
+ input_image_shape = tf.shape(input_image)[:2]
10
+ low_w = tf.random.uniform(
11
+ shape=(), maxval=input_image_shape[1] - self.image_size + 1, dtype=tf.int32
12
+ )
13
+ low_h = tf.random.uniform(
14
+ shape=(), maxval=input_image_shape[0] - self.image_size + 1, dtype=tf.int32
15
+ )
16
+ enhanced_w = low_w
17
+ enhanced_h = low_h
18
+ input_image_cropped = input_image[
19
+ low_h : low_h + self.image_size, low_w : low_w + self.image_size
20
+ ]
21
+ enhanced_image_cropped = enhanced_image[
22
+ enhanced_h : enhanced_h + self.image_size,
23
+ enhanced_w : enhanced_w + self.image_size,
24
+ ]
25
+ return input_image_cropped, enhanced_image_cropped
26
+
27
+ def random_horizontal_flip(sefl, input_image, enhanced_image):
28
+ return tf.cond(
29
+ tf.random.uniform(shape=(), maxval=1) < 0.5,
30
+ lambda: (input_image, enhanced_image),
31
+ lambda: (
32
+ tf.image.flip_left_right(input_image),
33
+ tf.image.flip_left_right(enhanced_image),
34
+ ),
35
+ )
36
+
37
+ def random_vertical_flip(self, input_image, enhanced_image):
38
+ return tf.cond(
39
+ tf.random.uniform(shape=(), maxval=1) < 0.5,
40
+ lambda: (input_image, enhanced_image),
41
+ lambda: (
42
+ tf.image.flip_up_down(input_image),
43
+ tf.image.flip_up_down(enhanced_image),
44
+ ),
45
+ )
46
+
47
+ def random_rotate(input_image, enhanced_image):
48
+ condition = tf.random.uniform(shape=(), maxval=4, dtype=tf.int32)
49
+ return tf.image.rot90(input_image, condition), tf.image.rot90(
50
+ enhanced_image, condition
51
+ )
enhance_me/commons.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ import tensorflow as tf
2
+
3
+
4
+ def read_image(image_path):
5
+ image = tf.io.read_file(image_path)
6
+ image = tf.image.decode_png(image, channels=3)
7
+ image.set_shape([None, None, 3])
8
+ image = tf.cast(image, dtype=tf.float32) / 255.0
9
+ return image
enhance_me/mirnet/__init__.py ADDED
File without changes
enhance_me/mirnet/dataloader.py ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import tensorflow as tf
2
+ from typing import List
3
+
4
+ from ..commons import read_image
5
+ from ..augmentation import AugmentationFactory
6
+
7
+
8
+ class LowLightDataset:
9
+ def __init__(self, image_size: int = 256) -> None:
10
+ self.augmentation_factory = AugmentationFactory(image_size=image_size)
11
+
12
+ def load_data(self, low_light_image_path, enhanced_image_path):
13
+ low_light_image = read_image(low_light_image_path)
14
+ enhanced_image = read_image(enhanced_image_path)
15
+ low_light_image, enhanced_image = self.augmentation_factory.random_crop(
16
+ low_light_image, enhanced_image
17
+ )
18
+ return low_light_image, enhanced_image
19
+
20
+ def get_dataset(
21
+ self,
22
+ low_light_images: List[str],
23
+ enhanced_images: List[str],
24
+ batch_size: int = 16,
25
+ ):
26
+ dataset = tf.data.Dataset.from_tensor_slices(
27
+ (low_light_images, enhanced_images)
28
+ )
29
+ dataset = dataset.map(self.load_data, num_parallel_calls=tf.data.AUTOTUNE)
30
+ dataset = dataset.batch(batch_size, drop_remainder=True)
31
+ return dataset
notebooks/.gitkeep ADDED
File without changes