import tensorflow as tf
from tensorflow.keras import layers

from .skff import selective_kernel_feature_fusion
from .dual_attention import dual_attention_unit_block


def down_sampling_module(input_tensor):
    """Downsample a feature map 2x spatially while doubling its channel count.

    Two parallel branches are built and summed:

    * main branch: 1x1 conv + ReLU -> 3x3 conv + ReLU -> max-pool ->
      1x1 conv to ``2 * channels``;
    * skip branch: max-pool -> 1x1 conv to ``2 * channels``.

    Args:
        input_tensor: Keras tensor whose last axis is read as the channel
            count (presumably NHWC, matching the Conv2D default — confirm
            against callers).

    Returns:
        Tensor with ``2 * channels`` channels and spatial dims reduced by the
        default 2x2 ``MaxPooling2D`` (floor-halved for odd sizes).
    """
    channels = list(input_tensor.shape)[-1]

    main_branch = layers.Conv2D(channels, kernel_size=(1, 1), activation="relu")(
        input_tensor
    )
    main_branch = layers.Conv2D(
        channels, kernel_size=(3, 3), padding="same", activation="relu"
    )(main_branch)
    main_branch = layers.MaxPooling2D()(main_branch)
    main_branch = layers.Conv2D(channels * 2, kernel_size=(1, 1))(main_branch)

    skip_branch = layers.MaxPooling2D()(input_tensor)
    skip_branch = layers.Conv2D(channels * 2, kernel_size=(1, 1))(skip_branch)

    return layers.Add()([skip_branch, main_branch])


def up_sampling_module(input_tensor):
    """Upsample a feature map 2x spatially while halving its channel count.

    Mirror image of ``down_sampling_module``: a conv main branch and a light
    skip branch, each ending in ``UpSampling2D`` plus a 1x1 projection to
    ``channels // 2``, are summed.

    Args:
        input_tensor: Keras tensor whose last axis is read as the channel
            count (presumably NHWC — confirm against callers).

    Returns:
        Tensor with spatial dims doubled by the default 2x ``UpSampling2D``
        and ``channels // 2`` channels.
    """
    channels = list(input_tensor.shape)[-1]

    main_branch = layers.Conv2D(channels, kernel_size=(1, 1), activation="relu")(
        input_tensor
    )
    main_branch = layers.Conv2D(
        channels, kernel_size=(3, 3), padding="same", activation="relu"
    )(main_branch)
    main_branch = layers.UpSampling2D()(main_branch)
    main_branch = layers.Conv2D(channels // 2, kernel_size=(1, 1))(main_branch)

    skip_branch = layers.UpSampling2D()(input_tensor)
    skip_branch = layers.Conv2D(channels // 2, kernel_size=(1, 1))(skip_branch)

    return layers.Add()([skip_branch, main_branch])


# MRB Block
def multi_scale_residual_block(input_tensor, channels):
    """Multi-scale residual block (MRB).

    Builds three resolution streams from ``input_tensor``:

    * level 1: full resolution, C channels;
    * level 2: 1/2 resolution, 2C channels;
    * level 3: 1/4 resolution, 4C channels.

    The block then:

    1. applies a dual-attention unit per stream;
    2. exchanges information across scales with selective-kernel feature
       fusion (SKFF) at every level;
    3. applies a second dual-attention unit per stream;
    4. brings all streams back to full resolution and fuses them once more;
    5. projects with a 3x3 conv and adds the input back (residual).

    Args:
        input_tensor: Keras tensor; its last axis must equal ``channels``
            for the final residual ``Add`` to be valid. Spatial dims should
            be divisible by 4 so the down/up-sampling round trips reproduce
            the same shape (max-pooling floors odd sizes).
        channels: Output channel count of the final 3x3 conv.

    Returns:
        Tensor with the same shape as ``input_tensor``.

    Note:
        ``selective_kernel_feature_fusion`` and ``dual_attention_unit_block``
        are defined in sibling modules; they are assumed here to preserve the
        shape of their inputs (confirm in ``skff`` / ``dual_attention``).
    """
    # features
    level1 = input_tensor
    level2 = down_sampling_module(input_tensor)
    level3 = down_sampling_module(level2)

    # DAU
    level1_dau = dual_attention_unit_block(level1)
    level2_dau = dual_attention_unit_block(level2)
    level3_dau = dual_attention_unit_block(level3)

    # SKFF: every level receives all three streams resampled to its own scale.
    level1_skff = selective_kernel_feature_fusion(
        level1_dau,
        up_sampling_module(level2_dau),
        up_sampling_module(up_sampling_module(level3_dau)),
    )
    level2_skff = selective_kernel_feature_fusion(
        down_sampling_module(level1_dau), level2_dau, up_sampling_module(level3_dau)
    )
    level3_skff = selective_kernel_feature_fusion(
        down_sampling_module(down_sampling_module(level1_dau)),
        down_sampling_module(level2_dau),
        level3_dau,
    )

    # DAU 2: after this step all three streams are back at full resolution
    # with C channels.
    level1_dau_2 = dual_attention_unit_block(level1_skff)
    level2_dau_2 = up_sampling_module(dual_attention_unit_block(level2_skff))
    level3_dau_2 = up_sampling_module(
        up_sampling_module(dual_attention_unit_block(level3_skff))
    )

    # SKFF 2: fuse one tensor from each scale. Previously level3_dau_2 was
    # passed twice and level2_dau_2 was dropped, which left the middle-scale
    # stream disconnected from the output. Note this changes which layers are
    # reachable from the output, so weights saved with the old wiring may not
    # load into this graph.
    skff_ = selective_kernel_feature_fusion(level1_dau_2, level2_dau_2, level3_dau_2)
    conv = layers.Conv2D(channels, kernel_size=(3, 3), padding="same")(skff_)
    return layers.Add()([input_tensor, conv])


def recursive_residual_group(input_tensor, num_mrb, channels):
    """Stack ``num_mrb`` multi-scale residual blocks between two 3x3 convs.

    The group computes ``input + conv(MRB^num_mrb(conv(input)))``.

    Args:
        input_tensor: Keras tensor; its last axis must equal ``channels``
            for the outer residual ``Add`` to be valid.
        num_mrb: Number of ``multi_scale_residual_block`` applications.
        channels: Channel count used by both convs and every MRB.

    Returns:
        Tensor with the same shape as ``input_tensor``.
    """
    conv1 = layers.Conv2D(channels, kernel_size=(3, 3), padding="same")(input_tensor)
    for _ in range(num_mrb):
        conv1 = multi_scale_residual_block(conv1, channels)
    conv2 = layers.Conv2D(channels, kernel_size=(3, 3), padding="same")(conv1)
    return layers.Add()([conv2, input_tensor])