Spaces:
Runtime error
Runtime error
import os, sys | |
from libs import * | |
from .layers import * | |
from .modules import * | |
class LightSEResBlock(nn.Module):
    """Residual 1-D conv block with a lightweight SE module on the residual branch.

    Residual branch: conv_1 -> BatchNorm1d -> ReLU -> Dropout -> conv_2 ->
    BatchNorm1d -> LightSEModule.
    Shortcut branch: ``nn.Identity`` or, when ``downsample`` is set, a
    kernel-1, stride-2 ``DSConv1d`` + ``BatchNorm1d`` projection so that its
    channel count and stride match the residual branch.
    The two branches are summed and passed through a final ReLU.

    ``DSConv1d`` and ``LightSEModule`` come from this package's ``.layers`` /
    ``.modules``. They are presumably a depthwise-separable Conv1d taking
    Conv1d-style ``kernel_size``/``padding``/``stride`` and a
    squeeze-and-excitation style channel gate — TODO confirm. Inputs are
    assumed to be ``(batch, in_channels, length)`` — TODO confirm.

    Args:
        in_channels: Number of input channels.
        downsample: If True, output ``2 * in_channels`` channels and apply
            stride 2 to both branches. Otherwise keep the channel count and
            use stride 1.
        kernel_size: Kernel size of the two main convolutions. Must be a
            positive odd integer so that "same" padding (``kernel_size // 2``)
            keeps the residual and shortcut outputs the same length.
            Defaults to 7 (the previously hard-coded value).
        dropout: Dropout probability between the two convolutions.
            Defaults to 0.3 (the previously hard-coded value).

    Raises:
        ValueError: If ``kernel_size`` is not a positive odd integer.
    """

    def __init__(self,
                 in_channels,
                 downsample=False,
                 *,
                 kernel_size=7,
                 dropout=0.3,
                 ):
        super().__init__()
        # Validate before building any submodule. With an even kernel,
        # padding = kernel_size // 2 gives a residual length that differs
        # from the kernel-1 shortcut, and the sum in forward() would fail.
        if kernel_size < 1 or kernel_size % 2 == 0:
            raise ValueError(
                f"kernel_size must be a positive odd integer, got {kernel_size}"
            )
        padding = kernel_size // 2
        stride = 2 if downsample else 1
        self.out_channels = in_channels * 2 if downsample else in_channels

        # Construction order (conv_1, identity, conv_2, then the Sequential
        # contents) matches the original so seeded initialization is unchanged.
        self.conv_1 = DSConv1d(
            in_channels, self.out_channels,
            kernel_size=kernel_size, padding=padding, stride=stride,
        )
        if downsample:
            # Kernel-1, stride-2 projection: same output length as conv_1
            # for odd kernels with "same" padding.
            self.identity = nn.Sequential(
                DSConv1d(
                    in_channels, self.out_channels,
                    kernel_size=1, padding=0, stride=2,
                ),
                nn.BatchNorm1d(self.out_channels),
            )
        else:
            self.identity = nn.Identity()
        self.conv_2 = DSConv1d(
            self.out_channels, self.out_channels,
            kernel_size=kernel_size, padding=padding, stride=1,
        )
        # conv_1 / conv_2 stay registered as attributes as well as inside
        # `convs`. Removing either registration would change state_dict keys
        # and break loading of existing checkpoints.
        self.convs = nn.Sequential(
            self.conv_1,
            nn.BatchNorm1d(self.out_channels),
            nn.ReLU(),
            nn.Dropout(dropout),
            self.conv_2,
            nn.BatchNorm1d(self.out_channels),
            LightSEModule(self.out_channels),
        )
        self.act_fn = nn.ReLU()

    def forward(self, input):
        """Return ``ReLU(convs(input) + identity(input))``.

        ``input`` shadows the builtin. The name is kept so that existing
        keyword callers (``block(input=x)``) keep working.
        """
        output = self.convs(input) + self.identity(input)
        return self.act_fn(output)