# enhance_me/mirnet/models/dual_attention.py
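"""Dual Attention Unit (DAU) building blocks for MIRNet, implemented with
the Keras functional API."""
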
import tensorflow as tf
from tensorflow.keras import layers


def spatial_attention_block(input_tensor):
    """Spatial attention: pool across channels, then gate the input with a
    sigmoid spatial mask."""
    # Channel-wise average pooling -> (batch, height, width, 1)
    average_pooling = tf.reduce_mean(input_tensor, axis=-1)
    average_pooling = tf.expand_dims(average_pooling, axis=-1)
    # Channel-wise max pooling -> (batch, height, width, 1)
    max_pooling = tf.reduce_max(input_tensor, axis=-1)
    max_pooling = tf.expand_dims(max_pooling, axis=-1)
    concatenated = layers.Concatenate(axis=-1)([average_pooling, max_pooling])
    # A 1x1 convolution collapses both maps into a single attention map
    feature_map = layers.Conv2D(1, kernel_size=(1, 1))(concatenated)
    feature_map = tf.nn.sigmoid(feature_map)
    return input_tensor * feature_map


def channel_attention_block(input_tensor):
    """Channel attention: squeeze spatial dimensions with global average
    pooling, then gate each channel with a sigmoid excitation."""
    channels = list(input_tensor.shape)[-1]
    average_pooling = layers.GlobalAveragePooling2D()(input_tensor)
    # Restore the (batch, 1, 1, channels) layout so 1x1 convolutions apply
    feature_descriptor = tf.reshape(average_pooling, shape=(-1, 1, 1, channels))
    # Squeeze: reduce the channel dimension by a factor of 8
    feature_activations = layers.Conv2D(
        filters=channels // 8, kernel_size=(1, 1), activation="relu"
    )(feature_descriptor)
    # Excite: expand back to `channels` and produce per-channel gates in [0, 1]
    feature_activations = layers.Conv2D(
        filters=channels, kernel_size=(1, 1), activation="sigmoid"
    )(feature_activations)
    return input_tensor * feature_activations


def dual_attention_unit_block(input_tensor):
    """Dual Attention Unit: two 3x3 convolutions followed by parallel channel
    and spatial attention branches, a 1x1 fusion convolution, and a residual
    connection back to the input."""
    channels = list(input_tensor.shape)[-1]
    feature_map = layers.Conv2D(
        channels, kernel_size=(3, 3), padding="same", activation="relu"
    )(input_tensor)
    feature_map = layers.Conv2D(channels, kernel_size=(3, 3), padding="same")(
        feature_map
    )
    # Both attention branches operate on the shared feature map
    channel_attention = channel_attention_block(feature_map)
    spatial_attention = spatial_attention_block(feature_map)
    concatenation = layers.Concatenate(axis=-1)([channel_attention, spatial_attention])
    # Fuse the two branches back to the original channel count
    concatenation = layers.Conv2D(channels, kernel_size=(1, 1))(concatenation)
    # Residual connection around the whole unit
    return layers.Add()([input_tensor, concatenation])
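

# A minimal usage sketch (illustrative, not part of the original module):
# wrap the dual attention unit in a Keras functional model. The input
# resolution and channel count below are arbitrary example values.
if __name__ == "__main__":
    inputs = tf.keras.Input(shape=(128, 128, 64))
    outputs = dual_attention_unit_block(inputs)
    model = tf.keras.Model(inputs, outputs)
    model.summary()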