# Spaces: Runtime error
import os, sys | |
from libs import * | |
from .layers import * | |
from .modules import * | |
from .bblocks import * | |
from .backbones import * | |
class LightX3ECG(nn.Module):
    """Three-branch classifier that fuses per-channel features with learned gates.

    Channels 0, 1 and 2 of the input (dim 1) are each processed by their own
    ``LightSEResNet18`` backbone. The three feature vectors are concatenated
    and passed through a small MLP that emits one sigmoid gate per branch;
    the gated features are summed and fed to a linear classifier.

    NOTE(review): given the class name, each channel is presumably one ECG
    lead -- confirm against the data pipeline.

    Args:
        base_channels: Width argument forwarded to every backbone. The heads
            are sized for ``base_channels * 8`` features per branch.
        num_classes: Number of outputs produced by the classifier.
    """

    def __init__(self, base_channels=64, num_classes=1):
        super().__init__()

        # Per-branch feature width that the attention MLP and the classifier
        # are sized for. Assumes LightSEResNet18(base_channels) emits
        # (batch, base_channels * 8, 1) -- TODO confirm against the backbone.
        feature_dim = base_channels * 8

        # One independent backbone per input channel (weights are not shared).
        self.backbone_0 = LightSEResNet18(base_channels)
        self.backbone_1 = LightSEResNet18(base_channels)
        self.backbone_2 = LightSEResNet18(base_channels)

        # Maps the three concatenated feature vectors to three gate logits,
        # one per branch.
        self.lw_attention = nn.Sequential(
            nn.Linear(feature_dim * 3, feature_dim),
            nn.BatchNorm1d(feature_dim),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(feature_dim, 3),
        )
        self.classifier = nn.Sequential(
            nn.Dropout(0.2),
            nn.Linear(feature_dim, num_classes),
        )

    def forward(self, input, return_attention_scores=False):
        """Run the three branches, gate their features and classify.

        Args:
            input: Tensor indexed as ``input[:, i, :]`` for ``i`` in 0..2,
                i.e. at least three entries along dim 1. Presumably
                ``(batch, 3, length)`` -- TODO confirm. Entries beyond index 2
                along dim 1 are not read.
            return_attention_scores: When true, also return the per-branch
                gate values.

        Returns:
            The classifier output, or the tuple
            ``(output, attention_scores)`` when ``return_attention_scores``
            is true. ``attention_scores`` holds three sigmoid gates per
            sample; they are independent and are not normalized to sum to 1.
        """
        backbones = (self.backbone_0, self.backbone_1, self.backbone_2)

        # Slice out one channel, restore the channel axis the backbone is fed
        # with, then drop dim 2 of the backbone output.
        features = [
            backbone(input[:, index, :].unsqueeze(1)).squeeze(2)
            for index, backbone in enumerate(backbones)
        ]

        attention_scores = torch.sigmoid(
            self.lw_attention(torch.cat(features, dim=1))
        )

        # Stack branches on a new dim 1, scale each branch by its gate
        # (broadcast over the feature axis), then sum the branches.
        merged_features = torch.sum(
            torch.stack(features, dim=1) * attention_scores.unsqueeze(-1),
            dim=1,
        )
        output = self.classifier(merged_features)

        if return_attention_scores:
            return output, attention_scores
        return output