geekyrakshit commited on
Commit
4cf5013
1 Parent(s): d7c1491

added resize option

Browse files
enhance_me/zero_dce/dataloader.py CHANGED
@@ -9,20 +9,31 @@ class UnpairedLowLightDataset:
9
  def __init__(
10
  self,
11
  image_size: int = 256,
 
12
  apply_random_horizontal_flip: bool = True,
13
  apply_random_vertical_flip: bool = True,
14
  apply_random_rotation: bool = True,
15
  ) -> None:
16
  self.augmentation_factory = UnpairedAugmentationFactory(image_size=image_size)
 
 
17
  self.apply_random_horizontal_flip = apply_random_horizontal_flip
18
  self.apply_random_vertical_flip = apply_random_vertical_flip
19
  self.apply_random_rotation = apply_random_rotation
20
 
 
 
 
21
  def _get_dataset(self, images: List[str], batch_size: int, is_train: bool):
22
  dataset = tf.data.Dataset.from_tensor_slices((images))
23
  dataset = dataset.map(read_image, num_parallel_calls=tf.data.AUTOTUNE)
24
- dataset = dataset.map(
25
- self.augmentation_factory.random_crop, num_parallel_calls=tf.data.AUTOTUNE
 
 
 
 
 
26
  )
27
  if is_train:
28
  dataset = (
 
9
  def __init__(
10
  self,
11
  image_size: int = 256,
12
+ apply_resize: bool = False,
13
  apply_random_horizontal_flip: bool = True,
14
  apply_random_vertical_flip: bool = True,
15
  apply_random_rotation: bool = True,
16
  ) -> None:
17
  self.augmentation_factory = UnpairedAugmentationFactory(image_size=image_size)
18
+ self.image_size = image_size
19
+ self.apply_resize = apply_resize
20
  self.apply_random_horizontal_flip = apply_random_horizontal_flip
21
  self.apply_random_vertical_flip = apply_random_vertical_flip
22
  self.apply_random_rotation = apply_random_rotation
23
 
24
+ def _resize(self, image):
25
+ return tf.image.resize(image, (self.image_size, self.image_size))
26
+
27
  def _get_dataset(self, images: List[str], batch_size: int, is_train: bool):
28
  dataset = tf.data.Dataset.from_tensor_slices((images))
29
  dataset = dataset.map(read_image, num_parallel_calls=tf.data.AUTOTUNE)
30
+ dataset = (
31
+ dataset.map(
32
+ self.augmentation_factory.random_crop,
33
+ num_parallel_calls=tf.data.AUTOTUNE,
34
+ )
35
+ if not self.apply_resize
36
+ else dataset.map(self._resize, num_parallel_calls=tf.data.AUTOTUNE)
37
  )
38
  if is_train:
39
  dataset = (
enhance_me/zero_dce/zero_dce.py CHANGED
@@ -113,6 +113,7 @@ class ZeroDCE(Model):
113
  self,
114
  image_size: int = 256,
115
  dataset_label: str = "lol",
 
116
  apply_random_horizontal_flip: bool = True,
117
  apply_random_vertical_flip: bool = True,
118
  apply_random_rotation: bool = True,
@@ -123,6 +124,7 @@ class ZeroDCE(Model):
123
  (self.low_images, _), (self.test_low_images, _) = download_lol_dataset()
124
  data_loader = UnpairedLowLightDataset(
125
  image_size,
 
126
  apply_random_horizontal_flip,
127
  apply_random_vertical_flip,
128
  apply_random_rotation,
@@ -130,7 +132,7 @@ class ZeroDCE(Model):
130
  self.train_dataset, self.val_dataset = data_loader.get_datasets(
131
  self.low_images, val_split, batch_size
132
  )
133
-
134
  def train(self, epochs: int):
135
  log_dir = os.path.join(
136
  self.experiment_name,
@@ -148,7 +150,7 @@ class ZeroDCE(Model):
148
  callbacks=callbacks,
149
  )
150
  return history
151
-
152
  def infer(self, original_image):
153
  image = keras.preprocessing.image.img_to_array(original_image)
154
  image = image.astype("float32") / 255.0
@@ -157,7 +159,7 @@ class ZeroDCE(Model):
157
  output_image = tf.cast((output_image[0, :, :, :] * 255), dtype=np.uint8)
158
  output_image = Image.fromarray(output_image.numpy())
159
  return output_image
160
-
161
  def infer_from_file(self, original_image_file: str):
162
  original_image = Image.open(original_image_file)
163
  return self.infer(original_image)
 
113
  self,
114
  image_size: int = 256,
115
  dataset_label: str = "lol",
116
+ apply_resize: bool = False,
117
  apply_random_horizontal_flip: bool = True,
118
  apply_random_vertical_flip: bool = True,
119
  apply_random_rotation: bool = True,
 
124
  (self.low_images, _), (self.test_low_images, _) = download_lol_dataset()
125
  data_loader = UnpairedLowLightDataset(
126
  image_size,
127
+ apply_resize,
128
  apply_random_horizontal_flip,
129
  apply_random_vertical_flip,
130
  apply_random_rotation,
 
132
  self.train_dataset, self.val_dataset = data_loader.get_datasets(
133
  self.low_images, val_split, batch_size
134
  )
135
+
136
  def train(self, epochs: int):
137
  log_dir = os.path.join(
138
  self.experiment_name,
 
150
  callbacks=callbacks,
151
  )
152
  return history
153
+
154
  def infer(self, original_image):
155
  image = keras.preprocessing.image.img_to_array(original_image)
156
  image = image.astype("float32") / 255.0
 
159
  output_image = tf.cast((output_image[0, :, :, :] * 255), dtype=np.uint8)
160
  output_image = Image.fromarray(output_image.numpy())
161
  return output_image
162
+
163
  def infer_from_file(self, original_image_file: str):
164
  original_image = Image.open(original_image_file)
165
  return self.infer(original_image)
notebooks/enhance_me_train.ipynb CHANGED
@@ -190,6 +190,7 @@
190
  "experiment_name = \"lol_dataset_256\" # @param {type:\"string\"}\n",
191
  "image_size = 256 # @param {type:\"integer\"}\n",
192
  "dataset_label = \"lol\" # @param [\"lol\"]\n",
 
193
  "apply_random_horizontal_flip = True # @param {type:\"boolean\"}\n",
194
  "apply_random_vertical_flip = True # @param {type:\"boolean\"}\n",
195
  "apply_random_rotation = True # @param {type:\"boolean\"}\n",
@@ -223,6 +224,7 @@
223
  "zero_dce.build_datasets(\n",
224
  " image_size=image_size,\n",
225
  " dataset_label=dataset_label,\n",
 
226
  " apply_random_horizontal_flip=apply_random_horizontal_flip,\n",
227
  " apply_random_vertical_flip=apply_random_vertical_flip,\n",
228
  " apply_random_rotation=apply_random_rotation,\n",
 
190
  "experiment_name = \"lol_dataset_256\" # @param {type:\"string\"}\n",
191
  "image_size = 256 # @param {type:\"integer\"}\n",
192
  "dataset_label = \"lol\" # @param [\"lol\"]\n",
193
+ "apply_resize = False # @param {type:\"boolean\"}\n",
194
  "apply_random_horizontal_flip = True # @param {type:\"boolean\"}\n",
195
  "apply_random_vertical_flip = True # @param {type:\"boolean\"}\n",
196
  "apply_random_rotation = True # @param {type:\"boolean\"}\n",
 
224
  "zero_dce.build_datasets(\n",
225
  " image_size=image_size,\n",
226
  " dataset_label=dataset_label,\n",
227
+ " apply_resize=apply_resize,\n",
228
  " apply_random_horizontal_flip=apply_random_horizontal_flip,\n",
229
  " apply_random_vertical_flip=apply_random_vertical_flip,\n",
230
  " apply_random_rotation=apply_random_rotation,\n",