geekyrakshit committed
Commit 6fd61b9
1 Parent(s): 865788c

added mirnet model + charbonnier loss

enhance_me/commons.py CHANGED
@@ -7,3 +7,7 @@ def read_image(image_path):
     image.set_shape([None, None, 3])
     image = tf.cast(image, dtype=tf.float32) / 255.0
     return image
+
+
+def peak_signal_noise_ratio(y_true, y_pred):
+    return tf.image.psnr(y_pred, y_true, max_val=255.0)
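Note that read_image scales pixels to [0, 1], while the new metric calls tf.image.psnr with max_val=255.0, i.e. it assumes its arguments are in the 0-255 range. A minimal sanity check under that assumption (the tensors below are illustrative, not from the commit):

    import tensorflow as tf

    from enhance_me.commons import peak_signal_noise_ratio

    # Illustrative pair of 0-255 RGB images; tf.image.psnr returns
    # one PSNR value (in dB) per image in the batch.
    y_true = tf.random.uniform((1, 64, 64, 3), maxval=255.0)
    noise = tf.random.normal(y_true.shape, stddev=5.0)
    y_pred = tf.clip_by_value(y_true + noise, 0.0, 255.0)
    print(peak_signal_noise_ratio(y_true, y_pred))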
enhance_me/mirnet/losses.py ADDED
@@ -0,0 +1,13 @@
+import tensorflow as tf
+from tensorflow.keras import losses
+
+
+class CharbonnierLoss(losses.Loss):
+    def __init__(self, epsilon: float = 1e-3, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.epsilon = epsilon
+
+    def call(self, y_true, y_pred):
+        return tf.reduce_mean(
+            tf.sqrt(tf.square(y_true - y_pred) + tf.square(self.epsilon))
+        )
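CharbonnierLoss computes mean(sqrt((y_true - y_pred)^2 + epsilon^2)), a smooth, everywhere-differentiable approximation of the L1 penalty that is common in image-restoration work. A sketch of how it could be combined with the new PSNR metric at compile time (the toy model, optimizer, and learning rate are assumptions, not part of this commit):

    import tensorflow as tf

    from enhance_me.commons import peak_signal_noise_ratio
    from enhance_me.mirnet.losses import CharbonnierLoss

    # Hypothetical wiring: any Keras image-to-image model can take the loss.
    model = tf.keras.Sequential(
        [tf.keras.layers.Conv2D(3, kernel_size=3, padding="same")]
    )
    model.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate=1e-4),
        loss=CharbonnierLoss(epsilon=1e-3),
        metrics=[peak_signal_noise_ratio],
    )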
enhance_me/mirnet/models/__init__.py ADDED
@@ -0,0 +1 @@
+from .mirnet_model import build_mirnet_model
enhance_me/mirnet/models/dual_attention.py ADDED
@@ -0,0 +1,41 @@
+import tensorflow as tf
+from tensorflow.keras import layers
+
+
+def spatial_attention_block(input_tensor):
+    average_pooling = tf.reduce_mean(input_tensor, axis=-1)
+    average_pooling = tf.expand_dims(average_pooling, axis=-1)
+    max_pooling = tf.reduce_max(input_tensor, axis=-1)
+    max_pooling = tf.expand_dims(max_pooling, axis=-1)
+    concatenated = layers.Concatenate(axis=-1)([average_pooling, max_pooling])
+    feature_map = layers.Conv2D(1, kernel_size=(1, 1))(concatenated)
+    feature_map = tf.nn.sigmoid(feature_map)
+    return input_tensor * feature_map
+
+
+def channel_attention_block(input_tensor):
+    channels = list(input_tensor.shape)[-1]
+    average_pooling = layers.GlobalAveragePooling2D()(input_tensor)
+    feature_descriptor = tf.reshape(average_pooling, shape=(-1, 1, 1, channels))
+    feature_activations = layers.Conv2D(
+        filters=channels // 8, kernel_size=(1, 1), activation="relu"
+    )(feature_descriptor)
+    feature_activations = layers.Conv2D(
+        filters=channels, kernel_size=(1, 1), activation="sigmoid"
+    )(feature_activations)
+    return input_tensor * feature_activations
+
+
+def dual_attention_unit_block(input_tensor):
+    channels = list(input_tensor.shape)[-1]
+    feature_map = layers.Conv2D(
+        channels, kernel_size=(3, 3), padding="same", activation="relu"
+    )(input_tensor)
+    feature_map = layers.Conv2D(channels, kernel_size=(3, 3), padding="same")(
+        feature_map
+    )
+    channel_attention = channel_attention_block(feature_map)
+    spatial_attention = spatial_attention_block(feature_map)
+    concatenation = layers.Concatenate(axis=-1)([channel_attention, spatial_attention])
+    concatenation = layers.Conv2D(channels, kernel_size=(1, 1))(concatenation)
+    return layers.Add()([input_tensor, concatenation])
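The dual attention unit runs channel and spatial attention in parallel on the same convolved feature map, fuses them with a 1x1 convolution, and adds the result back to the input, so the block preserves tensor shape. A quick functional-API shape check (the 128x128x32 input is illustrative):

    import tensorflow as tf

    from enhance_me.mirnet.models.dual_attention import dual_attention_unit_block

    x = tf.keras.Input(shape=(128, 128, 32))
    y = dual_attention_unit_block(x)
    print(y.shape)  # (None, 128, 128, 32): same shape as the input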
enhance_me/mirnet/models/mirnet_model.py ADDED
@@ -0,0 +1,14 @@
+import tensorflow as tf
+from tensorflow.keras import layers, Input, Model
+
+from .recursive_residual_blocks import recursive_residual_group
+
+
+def build_mirnet_model(num_rrg, num_mrb, channels):
+    input_tensor = Input(shape=[None, None, 3])
+    x1 = layers.Conv2D(channels, kernel_size=(3, 3), padding="same")(input_tensor)
+    for _ in range(num_rrg):
+        x1 = recursive_residual_group(x1, num_mrb, channels)
+    conv = layers.Conv2D(3, kernel_size=(3, 3), padding="same")(x1)
+    output_tensor = layers.Add()([input_tensor, conv])
+    return Model(input_tensor, output_tensor)
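Because the network is fully convolutional (Input(shape=[None, None, 3])), the built model accepts RGB images of any resolution. A sketch of constructing it through the package export; num_rrg=3, num_mrb=2, channels=64 are the settings from the MIRNet paper, used here as illustrative defaults rather than values fixed by this commit:

    from enhance_me.mirnet.models import build_mirnet_model

    # 3 recursive residual groups, 2 multi-scale residual blocks per
    # group, 64 base channels.
    model = build_mirnet_model(num_rrg=3, num_mrb=2, channels=64)
    model.summary()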
enhance_me/mirnet/models/recursive_residual_blocks.py ADDED
@@ -0,0 +1,79 @@
+import tensorflow as tf
+from tensorflow.keras import layers
+
+from .skff import selective_kernel_feature_fusion
+from .dual_attention import dual_attention_unit_block
+
+
+def down_sampling_module(input_tensor):
+    channels = list(input_tensor.shape)[-1]
+    main_branch = layers.Conv2D(channels, kernel_size=(1, 1), activation="relu")(
+        input_tensor
+    )
+    main_branch = layers.Conv2D(
+        channels, kernel_size=(3, 3), padding="same", activation="relu"
+    )(main_branch)
+    main_branch = layers.MaxPooling2D()(main_branch)
+    main_branch = layers.Conv2D(channels * 2, kernel_size=(1, 1))(main_branch)
+    skip_branch = layers.MaxPooling2D()(input_tensor)
+    skip_branch = layers.Conv2D(channels * 2, kernel_size=(1, 1))(skip_branch)
+    return layers.Add()([skip_branch, main_branch])
+
+
+def up_sampling_module(input_tensor):
+    channels = list(input_tensor.shape)[-1]
+    main_branch = layers.Conv2D(channels, kernel_size=(1, 1), activation="relu")(
+        input_tensor
+    )
+    main_branch = layers.Conv2D(
+        channels, kernel_size=(3, 3), padding="same", activation="relu"
+    )(main_branch)
+    main_branch = layers.UpSampling2D()(main_branch)
+    main_branch = layers.Conv2D(channels // 2, kernel_size=(1, 1))(main_branch)
+    skip_branch = layers.UpSampling2D()(input_tensor)
+    skip_branch = layers.Conv2D(channels // 2, kernel_size=(1, 1))(skip_branch)
+    return layers.Add()([skip_branch, main_branch])
+
+
+# MRB Block
+def multi_scale_residual_block(input_tensor, channels):
+    # features
+    level1 = input_tensor
+    level2 = down_sampling_module(input_tensor)
+    level3 = down_sampling_module(level2)
+    # DAU
+    level1_dau = dual_attention_unit_block(level1)
+    level2_dau = dual_attention_unit_block(level2)
+    level3_dau = dual_attention_unit_block(level3)
+    # SKFF
+    level1_skff = selective_kernel_feature_fusion(
+        level1_dau,
+        up_sampling_module(level2_dau),
+        up_sampling_module(up_sampling_module(level3_dau)),
+    )
+    level2_skff = selective_kernel_feature_fusion(
+        down_sampling_module(level1_dau), level2_dau, up_sampling_module(level3_dau)
+    )
+    level3_skff = selective_kernel_feature_fusion(
+        down_sampling_module(down_sampling_module(level1_dau)),
+        down_sampling_module(level2_dau),
+        level3_dau,
+    )
+    # DAU 2
+    level1_dau_2 = dual_attention_unit_block(level1_skff)
+    level2_dau_2 = up_sampling_module(dual_attention_unit_block(level2_skff))
+    level3_dau_2 = up_sampling_module(
+        up_sampling_module(dual_attention_unit_block(level3_skff))
+    )
+    # SKFF 2
+    skff_ = selective_kernel_feature_fusion(level1_dau_2, level2_dau_2, level3_dau_2)
+    conv = layers.Conv2D(channels, kernel_size=(3, 3), padding="same")(skff_)
+    return layers.Add()([input_tensor, conv])
+
+
+def recursive_residual_group(input_tensor, num_mrb, channels):
+    conv1 = layers.Conv2D(channels, kernel_size=(3, 3), padding="same")(input_tensor)
+    for _ in range(num_mrb):
+        conv1 = multi_scale_residual_block(conv1, channels)
+    conv2 = layers.Conv2D(channels, kernel_size=(3, 3), padding="same")(conv1)
+    return layers.Add()([conv2, input_tensor])
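down_sampling_module halves spatial resolution while doubling channel depth, and up_sampling_module does the inverse; this is what lets the three streams of the multi-scale residual block be resampled to a common shape before each SKFF fusion. A quick round-trip check (the 128x128x64 input is illustrative):

    import tensorflow as tf

    from enhance_me.mirnet.models.recursive_residual_blocks import (
        down_sampling_module,
        up_sampling_module,
    )

    x = tf.keras.Input(shape=(128, 128, 64))
    down = down_sampling_module(x)  # (None, 64, 64, 128)
    up = up_sampling_module(down)   # (None, 128, 128, 64): back to x's shape
    print(down.shape, up.shape)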
enhance_me/mirnet/models/skff.py ADDED
@@ -0,0 +1,30 @@
+import tensorflow as tf
+from tensorflow.keras import layers
+
+
+def selective_kernel_feature_fusion(
+    multi_scale_feature_1, multi_scale_feature_2, multi_scale_feature_3
+):
+    channels = list(multi_scale_feature_1.shape)[-1]
+    combined_feature = layers.Add()(
+        [multi_scale_feature_1, multi_scale_feature_2, multi_scale_feature_3]
+    )
+    gap = layers.GlobalAveragePooling2D()(combined_feature)
+    channel_wise_statistics = tf.reshape(gap, shape=(-1, 1, 1, channels))
+    compact_feature_representation = layers.Conv2D(
+        filters=channels // 8, kernel_size=(1, 1), activation="relu"
+    )(channel_wise_statistics)
+    feature_descriptor_1 = layers.Conv2D(
+        channels, kernel_size=(1, 1), activation="softmax"
+    )(compact_feature_representation)
+    feature_descriptor_2 = layers.Conv2D(
+        channels, kernel_size=(1, 1), activation="softmax"
+    )(compact_feature_representation)
+    feature_descriptor_3 = layers.Conv2D(
+        channels, kernel_size=(1, 1), activation="softmax"
+    )(compact_feature_representation)
+    feature_1 = multi_scale_feature_1 * feature_descriptor_1
+    feature_2 = multi_scale_feature_2 * feature_descriptor_2
+    feature_3 = multi_scale_feature_3 * feature_descriptor_3
+    aggregated_feature = layers.Add()([feature_1, feature_2, feature_3])
+    return aggregated_feature
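selective_kernel_feature_fusion expects its three inputs to share one shape (hence the resampling inside the MRB) and returns a fused map of that same shape, reweighting each stream channel-wise before summing. A minimal check with three illustrative same-shape streams:

    import tensorflow as tf

    from enhance_me.mirnet.models.skff import selective_kernel_feature_fusion

    a = tf.keras.Input(shape=(64, 64, 32))
    b = tf.keras.Input(shape=(64, 64, 32))
    c = tf.keras.Input(shape=(64, 64, 32))
    fused = selective_kernel_feature_fusion(a, b, c)
    print(fused.shape)  # (None, 64, 64, 32)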