Spaces:
Runtime error
Runtime error
geekyrakshit
commited on
Commit
•
6fd61b9
1
Parent(s):
865788c
added mirnet model + charbonnier loss
Browse files- enhance_me/commons.py +4 -0
- enhance_me/mirnet/losses.py +13 -0
- enhance_me/mirnet/models/__init__.py +1 -0
- enhance_me/mirnet/models/dual_attention.py +41 -0
- enhance_me/mirnet/models/mirnet_model.py +14 -0
- enhance_me/mirnet/models/recursive_residual_blocks.py +79 -0
- enhance_me/mirnet/models/skff.py +30 -0
enhance_me/commons.py
CHANGED
@@ -7,3 +7,7 @@ def read_image(image_path):
|
|
7 |
image.set_shape([None, None, 3])
|
8 |
image = tf.cast(image, dtype=tf.float32) / 255.0
|
9 |
return image
|
|
|
|
|
|
|
|
|
|
7 |
image.set_shape([None, None, 3])
|
8 |
image = tf.cast(image, dtype=tf.float32) / 255.0
|
9 |
return image
|
10 |
+
|
11 |
+
|
12 |
+
def peak_signal_noise_ratio(y_true, y_pred):
|
13 |
+
return tf.image.psnr(y_pred, y_true, max_val=255.0)
|
enhance_me/mirnet/losses.py
ADDED
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import tensorflow as tf
|
2 |
+
from tensorflow.keras import losses
|
3 |
+
|
4 |
+
|
5 |
+
class CharbonnierLoss(losses.Loss):
|
6 |
+
def __init__(self, epsilon: float = 1e-3, *args, **kwargs):
|
7 |
+
super().__init__(*args, **kwargs)
|
8 |
+
self.epsilon = epsilon
|
9 |
+
|
10 |
+
def call(self, y_true, y_pred):
|
11 |
+
return tf.reduce_mean(
|
12 |
+
tf.sqrt(tf.square(y_true - y_pred) + tf.square(self.epsilon))
|
13 |
+
)
|
enhance_me/mirnet/models/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
from .mirnet_model import build_mirnet_model
|
enhance_me/mirnet/models/dual_attention.py
ADDED
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import tensorflow as tf
|
2 |
+
from tensorflow.keras import layers
|
3 |
+
|
4 |
+
|
5 |
+
def spatial_attention_block(input_tensor):
|
6 |
+
average_pooling = tf.reduce_max(input_tensor, axis=-1)
|
7 |
+
average_pooling = tf.expand_dims(average_pooling, axis=-1)
|
8 |
+
max_pooling = tf.reduce_mean(input_tensor, axis=-1)
|
9 |
+
max_pooling = tf.expand_dims(max_pooling, axis=-1)
|
10 |
+
concatenated = layers.Concatenate(axis=-1)([average_pooling, max_pooling])
|
11 |
+
feature_map = layers.Conv2D(1, kernel_size=(1, 1))(concatenated)
|
12 |
+
feature_map = tf.nn.sigmoid(feature_map)
|
13 |
+
return input_tensor * feature_map
|
14 |
+
|
15 |
+
|
16 |
+
def channel_attention_block(input_tensor):
|
17 |
+
channels = list(input_tensor.shape)[-1]
|
18 |
+
average_pooling = layers.GlobalAveragePooling2D()(input_tensor)
|
19 |
+
feature_descriptor = tf.reshape(average_pooling, shape=(-1, 1, 1, channels))
|
20 |
+
feature_activations = layers.Conv2D(
|
21 |
+
filters=channels // 8, kernel_size=(1, 1), activation="relu"
|
22 |
+
)(feature_descriptor)
|
23 |
+
feature_activations = layers.Conv2D(
|
24 |
+
filters=channels, kernel_size=(1, 1), activation="sigmoid"
|
25 |
+
)(feature_activations)
|
26 |
+
return input_tensor * feature_activations
|
27 |
+
|
28 |
+
|
29 |
+
def dual_attention_unit_block(input_tensor):
|
30 |
+
channels = list(input_tensor.shape)[-1]
|
31 |
+
feature_map = layers.Conv2D(
|
32 |
+
channels, kernel_size=(3, 3), padding="same", activation="relu"
|
33 |
+
)(input_tensor)
|
34 |
+
feature_map = layers.Conv2D(channels, kernel_size=(3, 3), padding="same")(
|
35 |
+
feature_map
|
36 |
+
)
|
37 |
+
channel_attention = channel_attention_block(feature_map)
|
38 |
+
spatial_attention = spatial_attention_block(feature_map)
|
39 |
+
concatenation = layers.Concatenate(axis=-1)([channel_attention, spatial_attention])
|
40 |
+
concatenation = layers.Conv2D(channels, kernel_size=(1, 1))(concatenation)
|
41 |
+
return layers.Add()([input_tensor, concatenation])
|
enhance_me/mirnet/models/mirnet_model.py
ADDED
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import tensorflow as tf
|
2 |
+
from tensorflow.keras import layers, Input, Model
|
3 |
+
|
4 |
+
from .recursive_residual_blocks import recursive_residual_group
|
5 |
+
|
6 |
+
|
7 |
+
def mirnet_model(num_rrg, num_mrb, channels):
|
8 |
+
input_tensor = Input(shape=[None, None, 3])
|
9 |
+
x1 = layers.Conv2D(channels, kernel_size=(3, 3), padding="same")(input_tensor)
|
10 |
+
for _ in range(num_rrg):
|
11 |
+
x1 = recursive_residual_group(x1, num_mrb, channels)
|
12 |
+
conv = layers.Conv2D(3, kernel_size=(3, 3), padding="same")(x1)
|
13 |
+
output_tensor = layers.Add()([input_tensor, conv])
|
14 |
+
return Model(input_tensor, output_tensor)
|
enhance_me/mirnet/models/recursive_residual_blocks.py
ADDED
@@ -0,0 +1,79 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import tensorflow as tf
|
2 |
+
from tensorflow.keras import layers
|
3 |
+
|
4 |
+
from .skff import selective_kernel_feature_fusion
|
5 |
+
from .dual_attention import dual_attention_unit_block
|
6 |
+
|
7 |
+
|
8 |
+
def down_sampling_module(input_tensor):
|
9 |
+
channels = list(input_tensor.shape)[-1]
|
10 |
+
main_branch = layers.Conv2D(channels, kernel_size=(1, 1), activation="relu")(
|
11 |
+
input_tensor
|
12 |
+
)
|
13 |
+
main_branch = layers.Conv2D(
|
14 |
+
channels, kernel_size=(3, 3), padding="same", activation="relu"
|
15 |
+
)(main_branch)
|
16 |
+
main_branch = layers.MaxPooling2D()(main_branch)
|
17 |
+
main_branch = layers.Conv2D(channels * 2, kernel_size=(1, 1))(main_branch)
|
18 |
+
skip_branch = layers.MaxPooling2D()(input_tensor)
|
19 |
+
skip_branch = layers.Conv2D(channels * 2, kernel_size=(1, 1))(skip_branch)
|
20 |
+
return layers.Add()([skip_branch, main_branch])
|
21 |
+
|
22 |
+
|
23 |
+
def up_sampling_module(input_tensor):
|
24 |
+
channels = list(input_tensor.shape)[-1]
|
25 |
+
main_branch = layers.Conv2D(channels, kernel_size=(1, 1), activation="relu")(
|
26 |
+
input_tensor
|
27 |
+
)
|
28 |
+
main_branch = layers.Conv2D(
|
29 |
+
channels, kernel_size=(3, 3), padding="same", activation="relu"
|
30 |
+
)(main_branch)
|
31 |
+
main_branch = layers.UpSampling2D()(main_branch)
|
32 |
+
main_branch = layers.Conv2D(channels // 2, kernel_size=(1, 1))(main_branch)
|
33 |
+
skip_branch = layers.UpSampling2D()(input_tensor)
|
34 |
+
skip_branch = layers.Conv2D(channels // 2, kernel_size=(1, 1))(skip_branch)
|
35 |
+
return layers.Add()([skip_branch, main_branch])
|
36 |
+
|
37 |
+
|
38 |
+
# MRB Block
|
39 |
+
def multi_scale_residual_block(input_tensor, channels):
|
40 |
+
# features
|
41 |
+
level1 = input_tensor
|
42 |
+
level2 = down_sampling_module(input_tensor)
|
43 |
+
level3 = down_sampling_module(level2)
|
44 |
+
# DAU
|
45 |
+
level1_dau = dual_attention_unit_block(level1)
|
46 |
+
level2_dau = dual_attention_unit_block(level2)
|
47 |
+
level3_dau = dual_attention_unit_block(level3)
|
48 |
+
# SKFF
|
49 |
+
level1_skff = selective_kernel_feature_fusion(
|
50 |
+
level1_dau,
|
51 |
+
up_sampling_module(level2_dau),
|
52 |
+
up_sampling_module(up_sampling_module(level3_dau)),
|
53 |
+
)
|
54 |
+
level2_skff = selective_kernel_feature_fusion(
|
55 |
+
down_sampling_module(level1_dau), level2_dau, up_sampling_module(level3_dau)
|
56 |
+
)
|
57 |
+
level3_skff = selective_kernel_feature_fusion(
|
58 |
+
down_sampling_module(down_sampling_module(level1_dau)),
|
59 |
+
down_sampling_module(level2_dau),
|
60 |
+
level3_dau,
|
61 |
+
)
|
62 |
+
# DAU 2
|
63 |
+
level1_dau_2 = dual_attention_unit_block(level1_skff)
|
64 |
+
level2_dau_2 = up_sampling_module((dual_attention_unit_block(level2_skff)))
|
65 |
+
level3_dau_2 = up_sampling_module(
|
66 |
+
up_sampling_module(dual_attention_unit_block(level3_skff))
|
67 |
+
)
|
68 |
+
# SKFF 2
|
69 |
+
skff_ = selective_kernel_feature_fusion(level1_dau_2, level3_dau_2, level3_dau_2)
|
70 |
+
conv = layers.Conv2D(channels, kernel_size=(3, 3), padding="same")(skff_)
|
71 |
+
return layers.Add()([input_tensor, conv])
|
72 |
+
|
73 |
+
|
74 |
+
def recursive_residual_group(input_tensor, num_mrb, channels):
|
75 |
+
conv1 = layers.Conv2D(channels, kernel_size=(3, 3), padding="same")(input_tensor)
|
76 |
+
for _ in range(num_mrb):
|
77 |
+
conv1 = multi_scale_residual_block(conv1, channels)
|
78 |
+
conv2 = layers.Conv2D(channels, kernel_size=(3, 3), padding="same")(conv1)
|
79 |
+
return layers.Add()([conv2, input_tensor])
|
enhance_me/mirnet/models/skff.py
ADDED
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import tensorflow as tf
|
2 |
+
from tensorflow.keras import layers
|
3 |
+
|
4 |
+
|
5 |
+
def selective_kernel_feature_fusion(
|
6 |
+
multi_scale_feature_1, multi_scale_feature_2, multi_scale_feature_3
|
7 |
+
):
|
8 |
+
channels = list(multi_scale_feature_1.shape)[-1]
|
9 |
+
combined_feature = layers.Add()(
|
10 |
+
[multi_scale_feature_1, multi_scale_feature_2, multi_scale_feature_3]
|
11 |
+
)
|
12 |
+
gap = layers.GlobalAveragePooling2D()(combined_feature)
|
13 |
+
channel_wise_statistics = tf.reshape(gap, shape=(-1, 1, 1, channels))
|
14 |
+
compact_feature_representation = layers.Conv2D(
|
15 |
+
filters=channels // 8, kernel_size=(1, 1), activation="relu"
|
16 |
+
)(channel_wise_statistics)
|
17 |
+
feature_descriptor_1 = layers.Conv2D(
|
18 |
+
channels, kernel_size=(1, 1), activation="softmax"
|
19 |
+
)(compact_feature_representation)
|
20 |
+
feature_descriptor_2 = layers.Conv2D(
|
21 |
+
channels, kernel_size=(1, 1), activation="softmax"
|
22 |
+
)(compact_feature_representation)
|
23 |
+
feature_descriptor_3 = layers.Conv2D(
|
24 |
+
channels, kernel_size=(1, 1), activation="softmax"
|
25 |
+
)(compact_feature_representation)
|
26 |
+
feature_1 = multi_scale_feature_1 * feature_descriptor_1
|
27 |
+
feature_2 = multi_scale_feature_2 * feature_descriptor_2
|
28 |
+
feature_3 = multi_scale_feature_3 * feature_descriptor_3
|
29 |
+
aggregated_feature = layers.Add()([feature_1, feature_2, feature_3])
|
30 |
+
return aggregated_feature
|