Spaces:
Runtime error
Runtime error
import tensorflow as tf | |
from tensorflow.keras import layers | |
from .skff import selective_kernel_feature_fusion | |
from .dual_attention import dual_attention_unit_block | |
def down_sampling_module(input_tensor):
    """Build a two-path residual block that pools and doubles the channels.

    Both paths start from ``input_tensor`` and are summed at the end:

    * main path: 1x1 conv (ReLU) -> 3x3 "same" conv (ReLU) -> max-pool
      -> 1x1 conv projecting to twice the input channel count;
    * skip path: max-pool -> 1x1 conv projecting to twice the input
      channel count.

    ``MaxPooling2D`` is used with its Keras defaults (2x2 window), so the
    spatial size is reduced by the pooling step.

    Args:
        input_tensor: Keras tensor; its last shape entry is read as the
            channel count (channels-last layout assumed -- confirm with
            callers).

    Returns:
        The output of ``Add`` applied to ``[skip path, main path]``.
    """
    in_channels = list(input_tensor.shape)[-1]

    # Main path. Layers are created and applied in the same order as before
    # so that automatically generated layer names stay unchanged.
    pointwise = layers.Conv2D(in_channels, kernel_size=(1, 1), activation="relu")
    features = pointwise(input_tensor)
    spatial = layers.Conv2D(
        in_channels, kernel_size=(3, 3), padding="same", activation="relu"
    )
    features = spatial(features)
    features = layers.MaxPooling2D()(features)
    expand = layers.Conv2D(in_channels * 2, kernel_size=(1, 1))
    features = expand(features)

    # Skip path: pool the raw input, then match the main path's channel count.
    shortcut = layers.MaxPooling2D()(input_tensor)
    shortcut_expand = layers.Conv2D(in_channels * 2, kernel_size=(1, 1))
    shortcut = shortcut_expand(shortcut)

    return layers.Add()([shortcut, features])
def up_sampling_module(input_tensor):
    """Build a two-path residual block that upsamples and halves the channels.

    Both paths start from ``input_tensor`` and are summed at the end:

    * main path: 1x1 conv (ReLU) -> 3x3 "same" conv (ReLU) -> upsample
      -> 1x1 conv projecting to ``channels // 2`` filters;
    * skip path: upsample -> 1x1 conv projecting to ``channels // 2``
      filters.

    ``UpSampling2D`` is used with its Keras defaults (2x2 factor).

    Args:
        input_tensor: Keras tensor; its last shape entry is read as the
            channel count (channels-last layout assumed -- confirm with
            callers).

    Returns:
        The output of ``Add`` applied to ``[skip path, main path]``.
    """
    in_channels = list(input_tensor.shape)[-1]

    # Main path. Layers are created and applied in the same order as before
    # so that automatically generated layer names stay unchanged.
    pointwise = layers.Conv2D(in_channels, kernel_size=(1, 1), activation="relu")
    features = pointwise(input_tensor)
    spatial = layers.Conv2D(
        in_channels, kernel_size=(3, 3), padding="same", activation="relu"
    )
    features = spatial(features)
    features = layers.UpSampling2D()(features)
    reduce = layers.Conv2D(in_channels // 2, kernel_size=(1, 1))
    features = reduce(features)

    # Skip path: upsample the raw input, then match the main path's channels.
    shortcut = layers.UpSampling2D()(input_tensor)
    shortcut_reduce = layers.Conv2D(in_channels // 2, kernel_size=(1, 1))
    shortcut = shortcut_reduce(shortcut)

    return layers.Add()([shortcut, features])
# MRB Block | |
def multi_scale_residual_block(input_tensor, channels):
    """Build a multi-scale residual block (MRB) over three resolution levels.

    The input is taken as level 1; levels 2 and 3 are produced by applying
    ``down_sampling_module`` once and twice. Each level goes through a dual
    attention unit (DAU), the three levels are cross-fused with selective
    kernel feature fusion (SKFF) at every resolution, passed through a second
    DAU, brought back to level-1 resolution, fused once more, and finally
    added to the input as a residual.

    Args:
        input_tensor: Keras tensor fed to the block.
        channels: Number of filters of the final 3x3 convolution. Because the
            result is added to ``input_tensor``, this presumably has to equal
            the input's channel count -- confirm with callers.

    Returns:
        The output of ``Add`` applied to ``[input_tensor, conv]``.
    """
    # Features at three scales.
    level1 = input_tensor
    level2 = down_sampling_module(input_tensor)
    level3 = down_sampling_module(level2)

    # First DAU, one per scale.
    level1_dau = dual_attention_unit_block(level1)
    level2_dau = dual_attention_unit_block(level2)
    level3_dau = dual_attention_unit_block(level3)

    # First SKFF: for each target scale, resample the other two streams to
    # that scale and fuse all three.
    level1_skff = selective_kernel_feature_fusion(
        level1_dau,
        up_sampling_module(level2_dau),
        up_sampling_module(up_sampling_module(level3_dau)),
    )
    level2_skff = selective_kernel_feature_fusion(
        down_sampling_module(level1_dau),
        level2_dau,
        up_sampling_module(level3_dau),
    )
    level3_skff = selective_kernel_feature_fusion(
        down_sampling_module(down_sampling_module(level1_dau)),
        down_sampling_module(level2_dau),
        level3_dau,
    )

    # Second DAU; levels 2 and 3 are upsampled back to the level-1 scale.
    level1_dau_2 = dual_attention_unit_block(level1_skff)
    level2_dau_2 = up_sampling_module(dual_attention_unit_block(level2_skff))
    level3_dau_2 = up_sampling_module(
        up_sampling_module(dual_attention_unit_block(level3_skff))
    )

    # Second SKFF: fuse all three streams. The previous code passed
    # ``level3_dau_2`` twice and never used ``level2_dau_2``, silently
    # dropping the middle-resolution stream from the output.
    # NOTE(review): connecting level 2 changes the model graph; weights saved
    # from the old (disconnected) graph may no longer load -- confirm before
    # reusing existing checkpoints.
    skff_ = selective_kernel_feature_fusion(level1_dau_2, level2_dau_2, level3_dau_2)
    conv = layers.Conv2D(channels, kernel_size=(3, 3), padding="same")(skff_)
    return layers.Add()([input_tensor, conv])
def recursive_residual_group(input_tensor, num_mrb, channels):
    """Build a residual group: conv -> ``num_mrb`` chained MRBs -> conv -> add.

    Args:
        input_tensor: Keras tensor fed to the group; it is also the residual
            added to the group's output.
        num_mrb: How many ``multi_scale_residual_block`` calls are chained.
        channels: Filter count for both 3x3 convolutions, also forwarded to
            every ``multi_scale_residual_block``.

    Returns:
        The output of ``Add`` applied to ``[final conv output, input_tensor]``.
    """
    entry_conv = layers.Conv2D(channels, kernel_size=(3, 3), padding="same")
    features = entry_conv(input_tensor)

    # Each block consumes the previous block's output.
    for _block_index in range(num_mrb):
        features = multi_scale_residual_block(features, channels)

    exit_conv = layers.Conv2D(channels, kernel_size=(3, 3), padding="same")
    refined = exit_conv(features)
    return layers.Add()([refined, input_tensor])