geekyrakshit commited on
Commit
295bcab
β€’
1 Parent(s): 0f84baa

updated zero-dce model

Browse files
enhance_me/zero_dce/{models/dce_net.py β†’ dce_net.py} RENAMED
File without changes
enhance_me/zero_dce/models/__init__.py DELETED
File without changes
enhance_me/zero_dce/{models/zero_dce.py β†’ zero_dce.py} RENAMED
@@ -1,19 +1,33 @@
 
 
 
 
 
1
  import tensorflow as tf
 
2
  from tensorflow.keras import optimizers, Model
 
3
 
4
  from .dce_net import build_dce_net
5
- from ..dataloader import UnpairedLowLightDataset
6
- from ..losses import (
7
  color_constancy_loss,
8
  exposure_loss,
9
  illumination_smoothness_loss,
10
  SpatialConsistencyLoss,
11
  )
 
12
 
13
 
14
  class ZeroDCE(Model):
15
- def __init__(self, **kwargs):
16
  super(ZeroDCE, self).__init__(**kwargs)
 
 
 
 
 
 
17
  self.dce_model = build_dce_net()
18
 
19
  def compile(self, learning_rate, **kwargs):
@@ -94,3 +108,64 @@ class ZeroDCE(Model):
94
  skip_mismatch=skip_mismatch,
95
  options=options,
96
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import numpy as np
3
+ from PIL import Image
4
+ from datetime import datetime
5
+
6
  import tensorflow as tf
7
+ from tensorflow import keras
8
  from tensorflow.keras import optimizers, Model
9
+ from wandb.keras import WandbCallback
10
 
11
  from .dce_net import build_dce_net
12
+ from .dataloader import UnpairedLowLightDataset
13
+ from .losses import (
14
  color_constancy_loss,
15
  exposure_loss,
16
  illumination_smoothness_loss,
17
  SpatialConsistencyLoss,
18
  )
19
+ from ..commons import download_lol_dataset, init_wandb
20
 
21
 
22
  class ZeroDCE(Model):
23
+ def __init__(self, experiment_name=None, wandb_api_key=None, **kwargs):
24
  super(ZeroDCE, self).__init__(**kwargs)
25
+ self.experiment_name = experiment_name
26
+ if wandb_api_key is not None:
27
+ init_wandb("mirnet", experiment_name, wandb_api_key)
28
+ self.using_wandb = True
29
+ else:
30
+ self.using_wandb = False
31
  self.dce_model = build_dce_net()
32
 
33
  def compile(self, learning_rate, **kwargs):
 
108
  skip_mismatch=skip_mismatch,
109
  options=options,
110
  )
111
+
112
+ def build_datasets(
113
+ self,
114
+ image_size: int = 256,
115
+ dataset_label: str = "lol",
116
+ apply_random_horizontal_flip: bool = True,
117
+ apply_random_vertical_flip: bool = True,
118
+ apply_random_rotation: bool = True,
119
+ val_split: float = 0.2,
120
+ batch_size: int = 16,
121
+ ) -> None:
122
+ if dataset_label == "lol":
123
+ (self.low_images, _), (self.test_low_images, _) = download_lol_dataset()
124
+ data_loader = UnpairedLowLightDataset(
125
+ image_size,
126
+ apply_random_horizontal_flip,
127
+ apply_random_vertical_flip,
128
+ apply_random_rotation,
129
+ )
130
+ self.train_dataset, self.val_dataset = data_loader.get_datasets(
131
+ self.low_images, val_split, batch_size
132
+ )
133
+
134
+ def train(self, epochs: int):
135
+ log_dir = os.path.join(
136
+ self.experiment_name,
137
+ "logs",
138
+ datetime.now().strftime("%Y%m%d-%H%M%S"),
139
+ )
140
+ tensorboard_callback = keras.callbacks.TensorBoard(log_dir, histogram_freq=1)
141
+ model_checkpoint_callback = keras.callbacks.ModelCheckpoint(
142
+ os.path.join(self.experiment_name, "weights.h5"),
143
+ save_best_only=True,
144
+ save_weights_only=True,
145
+ )
146
+ callbacks = [
147
+ tensorboard_callback,
148
+ model_checkpoint_callback
149
+ ]
150
+ if self.using_wandb:
151
+ callbacks += [WandbCallback()]
152
+ history = self.model.fit(
153
+ self.train_dataset,
154
+ validation_data=self.val_dataset,
155
+ epochs=epochs,
156
+ callbacks=callbacks,
157
+ )
158
+ return history
159
+
160
+ def infer(self, original_image):
161
+ image = keras.preprocessing.image.img_to_array(original_image)
162
+ image = image.astype("float32") / 255.0
163
+ image = np.expand_dims(image, axis=0)
164
+ output_image = self.call(image)
165
+ output_image = tf.cast((output_image[0, :, :, :] * 255), dtype=np.uint8)
166
+ output_image = Image.fromarray(output_image.numpy())
167
+ return output_image
168
+
169
+ def infer_from_file(self, original_image_file: str):
170
+ original_image = Image.open(original_image_file)
171
+ return self.infer(original_image)