geekyrakshit commited on
Commit
d7c1491
1 Parent(s): 1150923

updated unpaired dataset

Browse files
Files changed (1) hide show
  1. enhance_me/zero_dce/dataloader.py +2 -7
enhance_me/zero_dce/dataloader.py CHANGED
@@ -1,6 +1,7 @@
1
  import tensorflow as tf
2
  from typing import List
3
 
 
4
  from ..augmentation import UnpairedAugmentationFactory
5
 
6
 
@@ -17,15 +18,9 @@ class UnpairedLowLightDataset:
17
  self.apply_random_vertical_flip = apply_random_vertical_flip
18
  self.apply_random_rotation = apply_random_rotation
19
 
20
- def _load_data(self, image_path):
21
- image = tf.io.read_file(image_path)
22
- image = tf.image.decode_png(image, channels=3)
23
- image = tf.cast(image, dtype=tf.float32) / 255.0
24
- return image
25
-
26
  def _get_dataset(self, images: List[str], batch_size: int, is_train: bool):
27
  dataset = tf.data.Dataset.from_tensor_slices((images))
28
- dataset = dataset.map(self._load_data, num_parallel_calls=tf.data.AUTOTUNE)
29
  dataset = dataset.map(
30
  self.augmentation_factory.random_crop, num_parallel_calls=tf.data.AUTOTUNE
31
  )
 
1
  import tensorflow as tf
2
  from typing import List
3
 
4
+ from ..commons import read_image
5
  from ..augmentation import UnpairedAugmentationFactory
6
 
7
 
 
18
  self.apply_random_vertical_flip = apply_random_vertical_flip
19
  self.apply_random_rotation = apply_random_rotation
20
 
 
 
 
 
 
 
21
  def _get_dataset(self, images: List[str], batch_size: int, is_train: bool):
22
  dataset = tf.data.Dataset.from_tensor_slices((images))
23
+ dataset = dataset.map(read_image, num_parallel_calls=tf.data.AUTOTUNE)
24
  dataset = dataset.map(
25
  self.augmentation_factory.random_crop, num_parallel_calls=tf.data.AUTOTUNE
26
  )