# Spaces: Runtime error  (status banner captured alongside this source; not code)
import tensorflow as tf | |
from tensorflow.keras import layers | |
def selective_kernel_feature_fusion(
    multi_scale_feature_1, multi_scale_feature_2, multi_scale_feature_3
):
    """Fuse three feature maps with learned per-channel attention weights.

    The three inputs are summed and globally average-pooled into one value per
    channel. That vector is compressed by a 1x1 ReLU convolution, then expanded
    by three separate 1x1 softmax convolutions into one attention vector per
    input. Each input is rescaled by its attention vector and the three
    rescaled maps are summed.

    Args:
        multi_scale_feature_1: Feature tensor whose last dimension is taken as
            the channel count. Presumably channels-last
            ``(batch, height, width, channels)`` since it is fed to
            ``GlobalAveragePooling2D`` with default settings — TODO confirm
            with callers. The channel dimension must be statically known.
        multi_scale_feature_2: Tensor with the same shape as
            ``multi_scale_feature_1`` (required by ``layers.Add``).
        multi_scale_feature_3: Tensor with the same shape as
            ``multi_scale_feature_1`` (required by ``layers.Add``).

    Returns:
        The attention-weighted sum of the three inputs, with the same shape
        as each input.
    """
    channels = multi_scale_feature_1.shape[-1]
    combined_feature = layers.Add()(
        [multi_scale_feature_1, multi_scale_feature_2, multi_scale_feature_3]
    )
    gap = layers.GlobalAveragePooling2D()(combined_feature)
    # Reshape with a Keras layer, not ``tf.reshape``: Keras 3 (the default
    # ``tf.keras`` from TF 2.16 on) raises a ValueError when a raw TensorFlow
    # op receives a symbolic KerasTensor. The layer form also works on
    # tf.keras 2.x and adds no weights.
    channel_wise_statistics = layers.Reshape((1, 1, channels))(gap)
    # Bottleneck to channels // 8 filters, but never to zero: Conv2D rejects
    # filters=0, which plain floor division produced for channels < 8.
    compact_feature_representation = layers.Conv2D(
        filters=max(1, channels // 8), kernel_size=(1, 1), activation="relu"
    )(channel_wise_statistics)
    # One attention vector per input branch. The "softmax" activation
    # normalizes over the last (channel) axis of each descriptor separately.
    # NOTE(review): the MIRNet paper's SKFF applies the softmax across the
    # three branches per channel instead. This behaviour is kept unchanged so
    # that existing trained weights produce identical outputs; verify intent
    # before changing it.
    feature_descriptor_1 = layers.Conv2D(
        channels, kernel_size=(1, 1), activation="softmax"
    )(compact_feature_representation)
    feature_descriptor_2 = layers.Conv2D(
        channels, kernel_size=(1, 1), activation="softmax"
    )(compact_feature_representation)
    feature_descriptor_3 = layers.Conv2D(
        channels, kernel_size=(1, 1), activation="softmax"
    )(compact_feature_representation)
    # The (batch, 1, 1, channels) descriptors broadcast over the spatial
    # dimensions of each input.
    feature_1 = multi_scale_feature_1 * feature_descriptor_1
    feature_2 = multi_scale_feature_2 * feature_descriptor_2
    feature_3 = multi_scale_feature_3 * feature_descriptor_3
    aggregated_feature = layers.Add()([feature_1, feature_2, feature_3])
    return aggregated_feature