import tensorflow as tf
from tensorflow.keras import layers


def _static_channel_count(tensor):
    """Return the statically known size of the last axis of ``tensor``.

    Raises:
        ValueError: if the size of the last axis is not known statically,
            since it is needed as a ``filters`` count and a reshape size.
    """
    channels = tensor.shape[-1]
    if channels is None:
        raise ValueError(
            "The last (channel) dimension of the input must be statically "
            f"known, got shape {tensor.shape}."
        )
    return channels


def spatial_attention_block(input_tensor):
    """Rescale ``input_tensor`` with a learned per-position attention mask.

    The input is reduced over its last axis with both a max and a mean, the
    two single-channel maps are concatenated, passed through a 1x1
    convolution with one filter and a sigmoid, and the result is multiplied
    with the input.

    Args:
        input_tensor: feature tensor whose last axis is treated as channels.
            Assumed to be 4D channels-last, as required by ``Conv2D`` with
            its default data format -- TODO confirm against callers.

    Returns:
        ``input_tensor`` multiplied by the sigmoid attention map.

    Note:
        A new ``Conv2D`` layer is created on every call.
    """
    # keepdims=True leaves a size-1 last axis, replacing the separate
    # expand_dims call that followed each reduction.
    max_pooling = tf.reduce_max(input_tensor, axis=-1, keepdims=True)
    average_pooling = tf.reduce_mean(input_tensor, axis=-1, keepdims=True)

    # The max map stays first and the mean map second: this is the order the
    # convolution has always received, so its kernel layout is unchanged.
    concatenated = layers.Concatenate(axis=-1)([max_pooling, average_pooling])

    feature_map = layers.Conv2D(1, kernel_size=(1, 1))(concatenated)
    feature_map = tf.nn.sigmoid(feature_map)
    return input_tensor * feature_map


def channel_attention_block(input_tensor):
    """Rescale ``input_tensor`` with learned per-channel attention weights.

    The input is globally average-pooled, reshaped to ``(-1, 1, 1, channels)``
    and passed through two 1x1 convolutions: a ReLU layer with
    ``channels // 8`` filters (at least one), then a sigmoid layer with
    ``channels`` filters. The result is multiplied with the input.

    Args:
        input_tensor: feature tensor whose last axis is treated as channels.
            Assumed to be 4D channels-last -- TODO confirm against callers.

    Returns:
        ``input_tensor`` multiplied by the per-channel sigmoid activations.

    Raises:
        ValueError: if the channel dimension is not statically known.

    Note:
        New ``Conv2D`` layers are created on every call.
    """
    channels = _static_channel_count(input_tensor)

    average_pooling = layers.GlobalAveragePooling2D()(input_tensor)
    feature_descriptor = tf.reshape(average_pooling, shape=(-1, 1, 1, channels))

    # With fewer than 8 channels the integer division yields 0, which is not
    # a usable filter count; fall back to a single-filter bottleneck.
    squeezed_channels = max(channels // 8, 1)

    feature_activations = layers.Conv2D(
        filters=squeezed_channels, kernel_size=(1, 1), activation="relu"
    )(feature_descriptor)
    feature_activations = layers.Conv2D(
        filters=channels, kernel_size=(1, 1), activation="sigmoid"
    )(feature_activations)
    return input_tensor * feature_activations


def dual_attention_unit_block(input_tensor):
    """Apply channel and spatial attention in parallel with a residual add.

    Two 3x3 ``same``-padded convolutions (the first with ReLU) produce a
    feature map. That map is sent through ``channel_attention_block`` and
    ``spatial_attention_block`` independently; the two outputs are
    concatenated on the last axis, projected back to ``channels`` filters by
    a 1x1 convolution, and added to ``input_tensor``.

    Args:
        input_tensor: feature tensor whose last axis is treated as channels.
            Assumed to be 4D channels-last -- TODO confirm against callers.

    Returns:
        The sum of ``input_tensor`` and the fused attention features.

    Raises:
        ValueError: if the channel dimension is not statically known.
    """
    channels = _static_channel_count(input_tensor)

    feature_map = layers.Conv2D(
        channels, kernel_size=(3, 3), padding="same", activation="relu"
    )(input_tensor)
    feature_map = layers.Conv2D(channels, kernel_size=(3, 3), padding="same")(
        feature_map
    )

    # Both attention branches read the same convolved features.
    channel_attention = channel_attention_block(feature_map)
    spatial_attention = spatial_attention_block(feature_map)

    concatenation = layers.Concatenate(axis=-1)([channel_attention, spatial_attention])
    # The 1x1 convolution brings the doubled channel count back to
    # ``channels`` so the result can be added to the input.
    concatenation = layers.Conv2D(channels, kernel_size=(1, 1))(concatenation)
    return layers.Add()([input_tensor, concatenation])