Spaces:
Runtime error
Runtime error
Commit
·
0f84baa
1
Parent(s):
659a217
added zero-dce model
Browse files
enhance_me/zero_dce/dataloader.py
CHANGED
@@ -17,7 +17,7 @@ class UnpairedLowLightDataset:
|
|
17 |
self.apply_random_vertical_flip = apply_random_vertical_flip
|
18 |
self.apply_random_rotation = apply_random_rotation
|
19 |
|
20 |
-
def
|
21 |
image = tf.io.read_file(image_path)
|
22 |
image = tf.image.decode_png(image, channels=3)
|
23 |
image = image / 255.0
|
@@ -25,7 +25,7 @@ class UnpairedLowLightDataset:
|
|
25 |
|
26 |
def _get_dataset(self, images: List[str], batch_size: int, is_train: bool):
|
27 |
dataset = tf.data.Dataset.from_tensor_slices((images))
|
28 |
-
dataset = dataset.map(self.
|
29 |
dataset = dataset.map(
|
30 |
self.augmentation_factory.random_crop, num_parallel_calls=tf.data.AUTOTUNE
|
31 |
)
|
|
|
17 |
self.apply_random_vertical_flip = apply_random_vertical_flip
|
18 |
self.apply_random_rotation = apply_random_rotation
|
19 |
|
20 |
+
def _load_data(self, image_path):
|
21 |
image = tf.io.read_file(image_path)
|
22 |
image = tf.image.decode_png(image, channels=3)
|
23 |
image = image / 255.0
|
|
|
25 |
|
26 |
def _get_dataset(self, images: List[str], batch_size: int, is_train: bool):
|
27 |
dataset = tf.data.Dataset.from_tensor_slices((images))
|
28 |
+
dataset = dataset.map(self._load_data, num_parallel_calls=tf.data.AUTOTUNE)
|
29 |
dataset = dataset.map(
|
30 |
self.augmentation_factory.random_crop, num_parallel_calls=tf.data.AUTOTUNE
|
31 |
)
|
enhance_me/zero_dce/losses/__init__.py
CHANGED
@@ -5,7 +5,11 @@ from .spatial_constancy import SpatialConsistencyLoss
|
|
5 |
|
6 |
def color_constancy_loss(x):
|
7 |
mean_rgb = tf.reduce_mean(x, axis=(1, 2), keepdims=True)
|
8 |
-
mean_r, mean_g, mean_b =
|
|
|
|
|
|
|
|
|
9 |
diff_rg = tf.square(mean_r - mean_g)
|
10 |
diff_rb = tf.square(mean_r - mean_b)
|
11 |
diff_gb = tf.square(mean_b - mean_g)
|
|
|
5 |
|
6 |
def color_constancy_loss(x):
|
7 |
mean_rgb = tf.reduce_mean(x, axis=(1, 2), keepdims=True)
|
8 |
+
mean_r, mean_g, mean_b = (
|
9 |
+
mean_rgb[:, :, :, 0],
|
10 |
+
mean_rgb[:, :, :, 1],
|
11 |
+
mean_rgb[:, :, :, 2],
|
12 |
+
)
|
13 |
diff_rg = tf.square(mean_r - mean_g)
|
14 |
diff_rb = tf.square(mean_r - mean_b)
|
15 |
diff_gb = tf.square(mean_b - mean_g)
|
enhance_me/zero_dce/models/zero_dce.py
ADDED
@@ -0,0 +1,96 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import tensorflow as tf
|
2 |
+
from tensorflow.keras import optimizers, Model
|
3 |
+
|
4 |
+
from .dce_net import build_dce_net
|
5 |
+
from ..dataloader import UnpairedLowLightDataset
|
6 |
+
from ..losses import (
|
7 |
+
color_constancy_loss,
|
8 |
+
exposure_loss,
|
9 |
+
illumination_smoothness_loss,
|
10 |
+
SpatialConsistencyLoss,
|
11 |
+
)
|
12 |
+
|
13 |
+
|
14 |
+
class ZeroDCE(Model):
|
15 |
+
def __init__(self, **kwargs):
|
16 |
+
super(ZeroDCE, self).__init__(**kwargs)
|
17 |
+
self.dce_model = build_dce_net()
|
18 |
+
|
19 |
+
def compile(self, learning_rate, **kwargs):
|
20 |
+
super(ZeroDCE, self).compile(**kwargs)
|
21 |
+
self.optimizer = optimizers.Adam(learning_rate=learning_rate)
|
22 |
+
self.spatial_constancy_loss = SpatialConsistencyLoss(reduction="none")
|
23 |
+
|
24 |
+
def get_enhanced_image(self, data, output):
|
25 |
+
r1 = output[:, :, :, :3]
|
26 |
+
r2 = output[:, :, :, 3:6]
|
27 |
+
r3 = output[:, :, :, 6:9]
|
28 |
+
r4 = output[:, :, :, 9:12]
|
29 |
+
r5 = output[:, :, :, 12:15]
|
30 |
+
r6 = output[:, :, :, 15:18]
|
31 |
+
r7 = output[:, :, :, 18:21]
|
32 |
+
r8 = output[:, :, :, 21:24]
|
33 |
+
x = data + r1 * (tf.square(data) - data)
|
34 |
+
x = x + r2 * (tf.square(x) - x)
|
35 |
+
x = x + r3 * (tf.square(x) - x)
|
36 |
+
enhanced_image = x + r4 * (tf.square(x) - x)
|
37 |
+
x = enhanced_image + r5 * (tf.square(enhanced_image) - enhanced_image)
|
38 |
+
x = x + r6 * (tf.square(x) - x)
|
39 |
+
x = x + r7 * (tf.square(x) - x)
|
40 |
+
enhanced_image = x + r8 * (tf.square(x) - x)
|
41 |
+
return enhanced_image
|
42 |
+
|
43 |
+
def call(self, data):
|
44 |
+
dce_net_output = self.dce_model(data)
|
45 |
+
return self.get_enhanced_image(data, dce_net_output)
|
46 |
+
|
47 |
+
def compute_losses(self, data, output):
|
48 |
+
enhanced_image = self.get_enhanced_image(data, output)
|
49 |
+
loss_illumination = 200 * illumination_smoothness_loss(output)
|
50 |
+
loss_spatial_constancy = tf.reduce_mean(
|
51 |
+
self.spatial_constancy_loss(enhanced_image, data)
|
52 |
+
)
|
53 |
+
loss_color_constancy = 5 * tf.reduce_mean(color_constancy_loss(enhanced_image))
|
54 |
+
loss_exposure = 10 * tf.reduce_mean(exposure_loss(enhanced_image))
|
55 |
+
total_loss = (
|
56 |
+
loss_illumination
|
57 |
+
+ loss_spatial_constancy
|
58 |
+
+ loss_color_constancy
|
59 |
+
+ loss_exposure
|
60 |
+
)
|
61 |
+
return {
|
62 |
+
"total_loss": total_loss,
|
63 |
+
"illumination_smoothness_loss": loss_illumination,
|
64 |
+
"spatial_constancy_loss": loss_spatial_constancy,
|
65 |
+
"color_constancy_loss": loss_color_constancy,
|
66 |
+
"exposure_loss": loss_exposure,
|
67 |
+
}
|
68 |
+
|
69 |
+
def train_step(self, data):
|
70 |
+
with tf.GradientTape() as tape:
|
71 |
+
output = self.dce_model(data)
|
72 |
+
losses = self.compute_losses(data, output)
|
73 |
+
gradients = tape.gradient(
|
74 |
+
losses["total_loss"], self.dce_model.trainable_weights
|
75 |
+
)
|
76 |
+
self.optimizer.apply_gradients(zip(gradients, self.dce_model.trainable_weights))
|
77 |
+
return losses
|
78 |
+
|
79 |
+
def test_step(self, data):
|
80 |
+
output = self.dce_model(data)
|
81 |
+
return self.compute_losses(data, output)
|
82 |
+
|
83 |
+
def save_weights(self, filepath, overwrite=True, save_format=None, options=None):
|
84 |
+
"""While saving the weights, we simply save the weights of the DCE-Net"""
|
85 |
+
self.dce_model.save_weights(
|
86 |
+
filepath, overwrite=overwrite, save_format=save_format, options=options
|
87 |
+
)
|
88 |
+
|
89 |
+
def load_weights(self, filepath, by_name=False, skip_mismatch=False, options=None):
|
90 |
+
"""While loading the weights, we simply load the weights of the DCE-Net"""
|
91 |
+
self.dce_model.load_weights(
|
92 |
+
filepath=filepath,
|
93 |
+
by_name=by_name,
|
94 |
+
skip_mismatch=skip_mismatch,
|
95 |
+
options=options,
|
96 |
+
)
|