import tensorflow as tf
from tensorflow.keras import layers


def selective_kernel_feature_fusion(
    multi_scale_feature_1, multi_scale_feature_2, multi_scale_feature_3
):
    """Fuse three feature maps using per-branch channel attention.

    The three inputs are summed, globally average-pooled, squeezed through
    a 1x1 bottleneck convolution, and expanded into one descriptor per
    branch. Each branch is rescaled by its descriptor (broadcast over the
    pooled axes) and the rescaled branches are summed.

    Args:
        multi_scale_feature_1: First Keras tensor. Presumably 4-D
            ``(batch, height, width, channels)`` since it is fed to
            ``GlobalAveragePooling2D`` -- confirm the data format with callers.
        multi_scale_feature_2: Second tensor; must be addable to the first.
        multi_scale_feature_3: Third tensor; must be addable to the first.

    Returns:
        The sum of the three descriptor-weighted branches.

    Raises:
        ValueError: If the last dimension of ``multi_scale_feature_1`` is
            not statically known.
    """
    branches = [
        multi_scale_feature_1,
        multi_scale_feature_2,
        multi_scale_feature_3,
    ]

    channels = multi_scale_feature_1.shape[-1]
    if channels is None:
        raise ValueError(
            "selective_kernel_feature_fusion requires a statically known "
            "channel dimension on its inputs."
        )

    combined_feature = layers.Add()(branches)
    gap = layers.GlobalAveragePooling2D()(combined_feature)

    # Use a Keras layer rather than a raw tf.reshape so the graph consists
    # only of Keras layers (serialisable, and not tied to TF-op wrapping of
    # symbolic tensors). The batch axis is preserved implicitly.
    channel_wise_statistics = layers.Reshape((1, 1, channels))(gap)

    # Squeeze by a factor of 8, but never down to zero filters when the
    # inputs have fewer than 8 channels.
    compact_feature_representation = layers.Conv2D(
        filters=max(1, channels // 8), kernel_size=(1, 1), activation="relu"
    )(channel_wise_statistics)

    # One independent 1x1 conv per branch, created in branch order.
    # NOTE(review): activation="softmax" normalises over the channel axis
    # within each descriptor, not across the three branches -- confirm this
    # is the intended attention normalisation before changing it, since
    # altering it would change the behaviour of already-trained weights.
    feature_descriptors = [
        layers.Conv2D(channels, kernel_size=(1, 1), activation="softmax")(
            compact_feature_representation
        )
        for _ in branches
    ]

    # Each (batch, 1, 1, channels) descriptor broadcasts over its branch.
    weighted_features = [
        branch * descriptor
        for branch, descriptor in zip(branches, feature_descriptors)
    ]

    aggregated_feature = layers.Add()(weighted_features)
    return aggregated_feature