from torch import nn
from torch.nn import functional as F


class SpatialAttention(nn.Module):
    def __init__(self, in_channels):
        super(SpatialAttention, self).__init__()
        # 1x1 convolution projects the feature map to a single-channel attention map
        self.conv1 = nn.Conv2d(in_channels, 1, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        # Calculate attention scores and normalize them over all spatial
        # positions, so the scores for each sample sum to 1 across the H*W map
        b, _, h, w = x.shape
        attention_scores = self.conv1(x)
        attention_scores = F.softmax(attention_scores.view(b, 1, -1), dim=-1).view(b, 1, h, w)
        # Apply attention to the input features; the single-channel map
        # broadcasts across all input channels
        attended_features = x * attention_scores
        return attended_features
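
For reference, a quick shape check on SpatialAttention (the tensor sizes below are hypothetical, chosen only for illustration): the module returns a tensor with the same shape as its input, since the single-channel attention map is broadcast across channels.

import torch

# Hypothetical sizes for illustration: batch 2, 64 channels, 32x32 maps
x = torch.randn(2, 64, 32, 32)
attn = SpatialAttention(in_channels=64)
out = attn(x)
print(out.shape)  # torch.Size([2, 64, 32, 32]); input shape is preserved
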
class DCT_Attention_Fusion_Conv(nn.Module):
    def __init__(self, channels):
        super(DCT_Attention_Fusion_Conv, self).__init__()
        self.rgb_attention = SpatialAttention(channels)
        self.depth_attention = SpatialAttention(channels)
        self.rgb_pooling = nn.AdaptiveAvgPool2d(1)
        self.depth_pooling = nn.AdaptiveAvgPool2d(1)

    def forward(self, rgb_features, DCT_features):
        # Spatial attention for both modalities
        rgb_attended_features = self.rgb_attention(rgb_features)
        depth_attended_features = self.depth_attention(DCT_features)
        # Adaptive pooling reduces each modality to a per-channel global
        # descriptor of shape (B, C, 1, 1)
        rgb_pooled = self.rgb_pooling(rgb_attended_features)
        depth_pooled = self.depth_pooling(depth_attended_features)
        # Upsample the pooled descriptors back to the original spatial size;
        # bilinear interpolation of a 1x1 map broadcasts each per-channel
        # average across all positions
        rgb_upsampled = F.interpolate(rgb_pooled, size=rgb_features.size()[2:], mode='bilinear', align_corners=False)
        depth_upsampled = F.interpolate(depth_pooled, size=DCT_features.size()[2:], mode='bilinear', align_corners=False)
        # Fuse the two modalities by element-wise addition followed by ReLU
        fused_features = F.relu(rgb_upsampled + depth_upsampled)
        return fused_features
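
A minimal usage sketch, assuming both modalities are feature maps with the same channel count and the same spatial size (the element-wise sum in forward requires matching shapes); the sizes below are hypothetical:

import torch

# Hypothetical RGB and DCT feature maps with matching shapes
rgb = torch.randn(2, 64, 32, 32)
dct = torch.randn(2, 64, 32, 32)
fusion = DCT_Attention_Fusion_Conv(channels=64)
fused = fusion(rgb, dct)
print(fused.shape)  # torch.Size([2, 64, 32, 32])

Note the design consequence of pooling before upsampling: each modality is collapsed to a (B, C, 1, 1) global descriptor, so the upsampled maps are spatially constant and the fused output carries only per-channel global statistics of the attended features, not their spatial layout.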