geekyrakshit commited on
Commit
0f84baa
1 Parent(s): 659a217

added zero-dce model

Browse files
enhance_me/zero_dce/dataloader.py CHANGED
@@ -17,7 +17,7 @@ class UnpairedLowLightDataset:
17
  self.apply_random_vertical_flip = apply_random_vertical_flip
18
  self.apply_random_rotation = apply_random_rotation
19
 
20
- def load_data(self, image_path):
21
  image = tf.io.read_file(image_path)
22
  image = tf.image.decode_png(image, channels=3)
23
  image = image / 255.0
@@ -25,7 +25,7 @@ class UnpairedLowLightDataset:
25
 
26
  def _get_dataset(self, images: List[str], batch_size: int, is_train: bool):
27
  dataset = tf.data.Dataset.from_tensor_slices((images))
28
- dataset = dataset.map(self.load_data, num_parallel_calls=tf.data.AUTOTUNE)
29
  dataset = dataset.map(
30
  self.augmentation_factory.random_crop, num_parallel_calls=tf.data.AUTOTUNE
31
  )
 
17
  self.apply_random_vertical_flip = apply_random_vertical_flip
18
  self.apply_random_rotation = apply_random_rotation
19
 
20
+ def _load_data(self, image_path):
21
  image = tf.io.read_file(image_path)
22
  image = tf.image.decode_png(image, channels=3)
23
  image = image / 255.0
 
25
 
26
  def _get_dataset(self, images: List[str], batch_size: int, is_train: bool):
27
  dataset = tf.data.Dataset.from_tensor_slices((images))
28
+ dataset = dataset.map(self._load_data, num_parallel_calls=tf.data.AUTOTUNE)
29
  dataset = dataset.map(
30
  self.augmentation_factory.random_crop, num_parallel_calls=tf.data.AUTOTUNE
31
  )
enhance_me/zero_dce/losses/__init__.py CHANGED
@@ -5,7 +5,11 @@ from .spatial_constancy import SpatialConsistencyLoss
5
 
6
  def color_constancy_loss(x):
7
  mean_rgb = tf.reduce_mean(x, axis=(1, 2), keepdims=True)
8
- mean_r, mean_g, mean_b = mean_rgb[:, :, :, 0], mean_rgb[:, :, :, 1], mean_rgb[:, :, :, 2]
 
 
 
 
9
  diff_rg = tf.square(mean_r - mean_g)
10
  diff_rb = tf.square(mean_r - mean_b)
11
  diff_gb = tf.square(mean_b - mean_g)
 
5
 
6
  def color_constancy_loss(x):
7
  mean_rgb = tf.reduce_mean(x, axis=(1, 2), keepdims=True)
8
+ mean_r, mean_g, mean_b = (
9
+ mean_rgb[:, :, :, 0],
10
+ mean_rgb[:, :, :, 1],
11
+ mean_rgb[:, :, :, 2],
12
+ )
13
  diff_rg = tf.square(mean_r - mean_g)
14
  diff_rb = tf.square(mean_r - mean_b)
15
  diff_gb = tf.square(mean_b - mean_g)
enhance_me/zero_dce/models/zero_dce.py ADDED
@@ -0,0 +1,96 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import tensorflow as tf
2
+ from tensorflow.keras import optimizers, Model
3
+
4
+ from .dce_net import build_dce_net
5
+ from ..dataloader import UnpairedLowLightDataset
6
+ from ..losses import (
7
+ color_constancy_loss,
8
+ exposure_loss,
9
+ illumination_smoothness_loss,
10
+ SpatialConsistencyLoss,
11
+ )
12
+
13
+
14
+ class ZeroDCE(Model):
15
+ def __init__(self, **kwargs):
16
+ super(ZeroDCE, self).__init__(**kwargs)
17
+ self.dce_model = build_dce_net()
18
+
19
+ def compile(self, learning_rate, **kwargs):
20
+ super(ZeroDCE, self).compile(**kwargs)
21
+ self.optimizer = optimizers.Adam(learning_rate=learning_rate)
22
+ self.spatial_constancy_loss = SpatialConsistencyLoss(reduction="none")
23
+
24
+ def get_enhanced_image(self, data, output):
25
+ r1 = output[:, :, :, :3]
26
+ r2 = output[:, :, :, 3:6]
27
+ r3 = output[:, :, :, 6:9]
28
+ r4 = output[:, :, :, 9:12]
29
+ r5 = output[:, :, :, 12:15]
30
+ r6 = output[:, :, :, 15:18]
31
+ r7 = output[:, :, :, 18:21]
32
+ r8 = output[:, :, :, 21:24]
33
+ x = data + r1 * (tf.square(data) - data)
34
+ x = x + r2 * (tf.square(x) - x)
35
+ x = x + r3 * (tf.square(x) - x)
36
+ enhanced_image = x + r4 * (tf.square(x) - x)
37
+ x = enhanced_image + r5 * (tf.square(enhanced_image) - enhanced_image)
38
+ x = x + r6 * (tf.square(x) - x)
39
+ x = x + r7 * (tf.square(x) - x)
40
+ enhanced_image = x + r8 * (tf.square(x) - x)
41
+ return enhanced_image
42
+
43
+ def call(self, data):
44
+ dce_net_output = self.dce_model(data)
45
+ return self.get_enhanced_image(data, dce_net_output)
46
+
47
+ def compute_losses(self, data, output):
48
+ enhanced_image = self.get_enhanced_image(data, output)
49
+ loss_illumination = 200 * illumination_smoothness_loss(output)
50
+ loss_spatial_constancy = tf.reduce_mean(
51
+ self.spatial_constancy_loss(enhanced_image, data)
52
+ )
53
+ loss_color_constancy = 5 * tf.reduce_mean(color_constancy_loss(enhanced_image))
54
+ loss_exposure = 10 * tf.reduce_mean(exposure_loss(enhanced_image))
55
+ total_loss = (
56
+ loss_illumination
57
+ + loss_spatial_constancy
58
+ + loss_color_constancy
59
+ + loss_exposure
60
+ )
61
+ return {
62
+ "total_loss": total_loss,
63
+ "illumination_smoothness_loss": loss_illumination,
64
+ "spatial_constancy_loss": loss_spatial_constancy,
65
+ "color_constancy_loss": loss_color_constancy,
66
+ "exposure_loss": loss_exposure,
67
+ }
68
+
69
+ def train_step(self, data):
70
+ with tf.GradientTape() as tape:
71
+ output = self.dce_model(data)
72
+ losses = self.compute_losses(data, output)
73
+ gradients = tape.gradient(
74
+ losses["total_loss"], self.dce_model.trainable_weights
75
+ )
76
+ self.optimizer.apply_gradients(zip(gradients, self.dce_model.trainable_weights))
77
+ return losses
78
+
79
+ def test_step(self, data):
80
+ output = self.dce_model(data)
81
+ return self.compute_losses(data, output)
82
+
83
+ def save_weights(self, filepath, overwrite=True, save_format=None, options=None):
84
+ """While saving the weights, we simply save the weights of the DCE-Net"""
85
+ self.dce_model.save_weights(
86
+ filepath, overwrite=overwrite, save_format=save_format, options=options
87
+ )
88
+
89
+ def load_weights(self, filepath, by_name=False, skip_mismatch=False, options=None):
90
+ """While loading the weights, we simply load the weights of the DCE-Net"""
91
+ self.dce_model.load_weights(
92
+ filepath=filepath,
93
+ by_name=by_name,
94
+ skip_mismatch=skip_mismatch,
95
+ options=options,
96
+ )