geekyrakshit committed on
Commit 659a217
1 Parent(s): 8b5218a

added loss functions

enhance_me/zero_dce/losses/__init__.py ADDED
@@ -0,0 +1,32 @@
+ import tensorflow as tf
+
+ from .spatial_constancy import SpatialConsistencyLoss
+
+
+ def color_constancy_loss(x):
+     mean_rgb = tf.reduce_mean(x, axis=(1, 2), keepdims=True)
+     mean_r, mean_g, mean_b = mean_rgb[:, :, :, 0], mean_rgb[:, :, :, 1], mean_rgb[:, :, :, 2]
+     diff_rg = tf.square(mean_r - mean_g)
+     diff_rb = tf.square(mean_r - mean_b)
+     diff_gb = tf.square(mean_b - mean_g)
+     return tf.sqrt(tf.square(diff_rg) + tf.square(diff_rb) + tf.square(diff_gb))
+
+
+ def exposure_loss(x, mean_val=0.6):
+     x = tf.reduce_mean(x, axis=3, keepdims=True)
+     mean = tf.nn.avg_pool2d(x, ksize=16, strides=16, padding="VALID")
+     return tf.reduce_mean(tf.square(mean - mean_val))
+
+
+ def illumination_smoothness_loss(x):
+     batch_size = tf.shape(x)[0]
+     h_x = tf.shape(x)[1]
+     w_x = tf.shape(x)[2]
+     count_h = (tf.shape(x)[2] - 1) * tf.shape(x)[3]
+     count_w = tf.shape(x)[2] * (tf.shape(x)[3] - 1)
+     h_tv = tf.reduce_sum(tf.square((x[:, 1:, :, :] - x[:, : h_x - 1, :, :])))
+     w_tv = tf.reduce_sum(tf.square((x[:, :, 1:, :] - x[:, :, : w_x - 1, :])))
+     batch_size = tf.cast(batch_size, dtype=tf.float32)
+     count_h = tf.cast(count_h, dtype=tf.float32)
+     count_w = tf.cast(count_w, dtype=tf.float32)
+     return 2 * (h_tv / count_h + w_tv / count_w) / batch_size
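
Usage note (not part of the commit): the three functions above are the no-reference Zero-DCE losses (color constancy, exposure control, and illumination smoothness). The sketch below shows one way they might be combined into a single training objective, assuming the enhance_me package is importable; the tensor shapes and the loss weights are illustrative assumptions, not values taken from this repository.

import tensorflow as tf

from enhance_me.zero_dce.losses import (
    color_constancy_loss,
    exposure_loss,
    illumination_smoothness_loss,
)

# Illustrative NHWC batch: enhanced RGB images and the curve-parameter maps
# produced by the enhancement network. Shapes and weights are assumptions.
enhanced_images = tf.random.uniform((4, 256, 256, 3))
curve_parameter_maps = tf.random.uniform((4, 256, 256, 3))

loss_color = tf.reduce_mean(color_constancy_loss(enhanced_images))
loss_exposure = exposure_loss(enhanced_images, mean_val=0.6)
loss_smoothness = illumination_smoothness_loss(curve_parameter_maps)

# Hypothetical weighted sum, in the spirit of the weighted objective used by Zero-DCE.
total_loss = 5.0 * loss_color + 10.0 * loss_exposure + 200.0 * loss_smoothness
print(float(total_loss))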
enhance_me/zero_dce/losses/spatial_constancy.py ADDED
@@ -0,0 +1,63 @@
+ import tensorflow as tf
+ from tensorflow.keras import losses
+
+
+ class SpatialConsistencyLoss(losses.Loss):
+     def __init__(self, **kwargs):
+         super(SpatialConsistencyLoss, self).__init__(reduction="none")
+
+         self.left_kernel = tf.constant(
+             [[[[0, 0, 0]], [[-1, 1, 0]], [[0, 0, 0]]]], dtype=tf.float32
+         )
+         self.right_kernel = tf.constant(
+             [[[[0, 0, 0]], [[0, 1, -1]], [[0, 0, 0]]]], dtype=tf.float32
+         )
+         self.up_kernel = tf.constant(
+             [[[[0, -1, 0]], [[0, 1, 0]], [[0, 0, 0]]]], dtype=tf.float32
+         )
+         self.down_kernel = tf.constant(
+             [[[[0, 0, 0]], [[0, 1, 0]], [[0, -1, 0]]]], dtype=tf.float32
+         )
+
+     def call(self, y_true, y_pred):
+
+         original_mean = tf.reduce_mean(y_true, 3, keepdims=True)
+         enhanced_mean = tf.reduce_mean(y_pred, 3, keepdims=True)
+         original_pool = tf.nn.avg_pool2d(
+             original_mean, ksize=4, strides=4, padding="VALID"
+         )
+         enhanced_pool = tf.nn.avg_pool2d(
+             enhanced_mean, ksize=4, strides=4, padding="VALID"
+         )
+
+         d_original_left = tf.nn.conv2d(
+             original_pool, self.left_kernel, strides=[1, 1, 1, 1], padding="SAME"
+         )
+         d_original_right = tf.nn.conv2d(
+             original_pool, self.right_kernel, strides=[1, 1, 1, 1], padding="SAME"
+         )
+         d_original_up = tf.nn.conv2d(
+             original_pool, self.up_kernel, strides=[1, 1, 1, 1], padding="SAME"
+         )
+         d_original_down = tf.nn.conv2d(
+             original_pool, self.down_kernel, strides=[1, 1, 1, 1], padding="SAME"
+         )
+
+         d_enhanced_left = tf.nn.conv2d(
+             enhanced_pool, self.left_kernel, strides=[1, 1, 1, 1], padding="SAME"
+         )
+         d_enhanced_right = tf.nn.conv2d(
+             enhanced_pool, self.right_kernel, strides=[1, 1, 1, 1], padding="SAME"
+         )
+         d_enhanced_up = tf.nn.conv2d(
+             enhanced_pool, self.up_kernel, strides=[1, 1, 1, 1], padding="SAME"
+         )
+         d_enhanced_down = tf.nn.conv2d(
+             enhanced_pool, self.down_kernel, strides=[1, 1, 1, 1], padding="SAME"
+         )
+
+         d_left = tf.square(d_original_left - d_enhanced_left)
+         d_right = tf.square(d_original_right - d_enhanced_right)
+         d_up = tf.square(d_original_up - d_enhanced_up)
+         d_down = tf.square(d_original_down - d_enhanced_down)
+         return d_left + d_right + d_up + d_down
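
Usage note (not part of the commit): SpatialConsistencyLoss is constructed with reduction="none", so calling it returns a spatial map of squared gradient differences (one value per 4x4 pooled region) rather than a scalar, and the caller is expected to reduce it. A minimal sketch follows, assuming the enhance_me package is importable; the image shapes are illustrative.

import tensorflow as tf

from enhance_me.zero_dce.losses import SpatialConsistencyLoss

# Illustrative NHWC batches in [0, 1]; the loss compares local gradients of the
# pooled luminance of the two images, so swapping the arguments gives the same value.
low_light_images = tf.random.uniform((4, 256, 256, 3))
enhanced_images = tf.random.uniform((4, 256, 256, 3))

spatial_constancy_loss = SpatialConsistencyLoss()
per_region = spatial_constancy_loss(low_light_images, enhanced_images)
print(per_region.shape)  # one value per pooled region, e.g. (4, 64, 64, 1) here

# Reduce explicitly before adding it to the total training objective.
loss_spatial_constancy = tf.reduce_mean(per_region)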