geekyrakshit commited on
Commit
865788c
1 Parent(s): 2673600

updated low light dataloader

Browse files
Files changed (1) hide show
  1. enhance_me/mirnet/dataloader.py +63 -2
enhance_me/mirnet/dataloader.py CHANGED
@@ -6,8 +6,17 @@ from ..augmentation import AugmentationFactory
6
 
7
 
8
  class LowLightDataset:
9
- def __init__(self, image_size: int = 256) -> None:
 
 
 
 
 
 
10
  self.augmentation_factory = AugmentationFactory(image_size=image_size)
 
 
 
11
 
12
  def load_data(self, low_light_image_path, enhanced_image_path):
13
  low_light_image = read_image(low_light_image_path)
@@ -17,15 +26,67 @@ class LowLightDataset:
17
  )
18
  return low_light_image, enhanced_image
19
 
20
- def get_dataset(
21
  self,
22
  low_light_images: List[str],
23
  enhanced_images: List[str],
24
  batch_size: int = 16,
 
25
  ):
26
  dataset = tf.data.Dataset.from_tensor_slices(
27
  (low_light_images, enhanced_images)
28
  )
29
  dataset = dataset.map(self.load_data, num_parallel_calls=tf.data.AUTOTUNE)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
30
  dataset = dataset.batch(batch_size, drop_remainder=True)
31
  return dataset
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
 
7
 
8
  class LowLightDataset:
9
+ def __init__(
10
+ self,
11
+ image_size: int = 256,
12
+ apply_random_horizontal_flip: bool = True,
13
+ apply_random_vertical_flip: bool = True,
14
+ apply_random_rotation: bool = True,
15
+ ) -> None:
16
  self.augmentation_factory = AugmentationFactory(image_size=image_size)
17
+ self.apply_random_horizontal_flip = apply_random_horizontal_flip
18
+ self.apply_random_vertical_flip = apply_random_vertical_flip
19
+ self.apply_random_rotation = apply_random_rotation
20
 
21
  def load_data(self, low_light_image_path, enhanced_image_path):
22
  low_light_image = read_image(low_light_image_path)
 
26
  )
27
  return low_light_image, enhanced_image
28
 
29
+ def _get_dataset(
30
  self,
31
  low_light_images: List[str],
32
  enhanced_images: List[str],
33
  batch_size: int = 16,
34
+ is_train: bool = True,
35
  ):
36
  dataset = tf.data.Dataset.from_tensor_slices(
37
  (low_light_images, enhanced_images)
38
  )
39
  dataset = dataset.map(self.load_data, num_parallel_calls=tf.data.AUTOTUNE)
40
+ dataset = dataset.map(
41
+ self.augmentation_factory.random_crop, num_parallel_calls=tf.data.AUTOTUNE
42
+ )
43
+ if is_train:
44
+ dataset = (
45
+ dataset.map(
46
+ self.augmentation_factory.random_horizontal_flip,
47
+ num_parallel_calls=tf.data.AUTOTUNE,
48
+ )
49
+ if self.apply_random_horizontal_flip
50
+ else dataset
51
+ )
52
+ dataset = (
53
+ dataset.map(
54
+ self.augmentation_factory.random_vertical_flip,
55
+ num_parallel_calls=tf.data.AUTOTUNE,
56
+ )
57
+ if self.apply_random_vertical_flip
58
+ else dataset
59
+ )
60
+ dataset = (
61
+ dataset.map(
62
+ self.augmentation_factory.random_rotate,
63
+ num_parallel_calls=tf.data.AUTOTUNE,
64
+ )
65
+ if self.apply_random_rotation
66
+ else dataset
67
+ )
68
  dataset = dataset.batch(batch_size, drop_remainder=True)
69
  return dataset
70
+
71
+ def get_datasets(
72
+ self,
73
+ low_light_images: List[str],
74
+ enhanced_images: List[str],
75
+ val_split: float = 0.2,
76
+ batch_size: int = 16,
77
+ ):
78
+ assert len(low_light_images) == len(enhanced_images)
79
+ split_index = int(len(low_light_images) * (1 - val_split))
80
+ train_low_light_images = low_light_images[:split_index]
81
+ train_enhanced_images = enhanced_images[:split_index]
82
+ val_low_light_images = low_light_images[split_index:]
83
+ val_enhanced_images = enhanced_images[split_index:]
84
+ print(f"Number of train data points: {len(train_low_light_images)}")
85
+ print(f"Number of validation data points: {len(val_low_light_images)}")
86
+ train_dataset = self._get_dataset(
87
+ train_low_light_images, train_enhanced_images, batch_size, is_train=True
88
+ )
89
+ val_dataset = self._get_dataset(
90
+ val_low_light_images, val_enhanced_images, batch_size, is_train=False
91
+ )
92
+ return train_dataset, val_dataset