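# NOTE(review): the lines "Spaces:", "Runtime error", "Runtime error" were
# captured above this code (apparently page-header residue from the hosting
# page, which reported a runtime error). Kept as a comment so the module parses.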
import tensorflow as tf | |
from tensorflow.keras import layers | |
def spatial_attention_block(input_tensor):
    """Gate ``input_tensor`` with a single-channel spatial attention map.

    The input is reduced along its last axis with both a max and a mean. The
    two single-channel maps are concatenated, mixed by a 1x1 Conv2D with one
    filter, passed through a sigmoid, and multiplied back onto the input.

    Args:
        input_tensor: feature tensor reduced along its last axis; presumably
            (batch, height, width, channels) since it feeds a Conv2D —
            TODO confirm.

    Returns:
        ``input_tensor`` multiplied element-wise by the attention map.
    """
    # NOTE(review): raw tf.* ops applied to symbolic Keras tensors are
    # rejected by Keras 3 (TF >= 2.16). If that is the reported runtime
    # error, wrap these ops in Lambda layers or use keras.ops — confirm.
    max_pooling = tf.reduce_max(input_tensor, axis=-1)
    max_pooling = tf.expand_dims(max_pooling, axis=-1)
    average_pooling = tf.reduce_mean(input_tensor, axis=-1)
    average_pooling = tf.expand_dims(average_pooling, axis=-1)
    # Channel order is [max, mean]. Keep it so previously trained weights of
    # the 1x1 convolution below still line up with the same input channels.
    concatenated = layers.Concatenate(axis=-1)([max_pooling, average_pooling])
    feature_map = layers.Conv2D(1, kernel_size=(1, 1))(concatenated)
    feature_map = tf.nn.sigmoid(feature_map)
    return input_tensor * feature_map
def channel_attention_block(input_tensor, *, reduction_ratio=8):
    """Scale each channel of ``input_tensor`` by a learned weight in (0, 1).

    Global average pooling yields one descriptor value per channel. Two 1x1
    convolutions then turn that descriptor into per-channel weights: a
    narrower ReLU bottleneck followed by a full-width sigmoid. The weights are
    multiplied back onto the input.

    Args:
        input_tensor: feature tensor whose last axis is channels; presumably
            (batch, height, width, channels) — TODO confirm.
        reduction_ratio: divisor applied to the channel count to get the
            bottleneck width. Defaults to 8, the previously hard-coded value.

    Returns:
        ``input_tensor`` multiplied by the per-channel weights.

    Raises:
        ValueError: if ``reduction_ratio`` is not positive.
    """
    if reduction_ratio <= 0:
        raise ValueError("reduction_ratio must be positive")
    channels = list(input_tensor.shape)[-1]
    # Clamp to at least one filter. With fewer channels than reduction_ratio,
    # integer division would request an invalid zero-filter Conv2D.
    bottleneck_filters = max(1, channels // reduction_ratio)
    average_pooling = layers.GlobalAveragePooling2D()(input_tensor)
    # Restore two singleton spatial axes so the 1x1 convolutions can consume it.
    # NOTE(review): tf.reshape on a symbolic Keras tensor fails under Keras 3
    # (TF >= 2.16) — confirm the installed version.
    feature_descriptor = tf.reshape(average_pooling, shape=(-1, 1, 1, channels))
    feature_activations = layers.Conv2D(
        filters=bottleneck_filters, kernel_size=(1, 1), activation="relu"
    )(feature_descriptor)
    feature_activations = layers.Conv2D(
        filters=channels, kernel_size=(1, 1), activation="sigmoid"
    )(feature_activations)
    return input_tensor * feature_activations
def dual_attention_unit_block(input_tensor):
    """Apply a residual dual-attention unit to ``input_tensor``.

    Two 3x3 same-padded convolutions (ReLU on the first only) keep the
    channel count. Their output is sent through ``channel_attention_block``
    and ``spatial_attention_block``. Both attended results are concatenated,
    projected back to the input channel count by a 1x1 convolution, and added
    to ``input_tensor`` as a residual.

    Args:
        input_tensor: feature tensor whose last axis is channels.

    Returns:
        Tensor with the same shape as ``input_tensor``.
    """
    num_channels = tuple(input_tensor.shape)[-1]

    # Layers are created and called in a fixed order. That order determines
    # Keras auto-names and weight ordering, so it must not be rearranged.
    hidden = layers.Conv2D(
        num_channels, kernel_size=(3, 3), padding="same", activation="relu"
    )(input_tensor)
    hidden = layers.Conv2D(num_channels, kernel_size=(3, 3), padding="same")(hidden)

    # Channel branch first, then spatial branch (same order as before).
    branches = [channel_attention_block(hidden), spatial_attention_block(hidden)]
    fused = layers.Concatenate(axis=-1)(branches)
    projected = layers.Conv2D(num_channels, kernel_size=(1, 1))(fused)

    return layers.Add()([input_tensor, projected])