Spaces:
Runtime error
Runtime error
File size: 1,641 Bytes
bb18256 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 |
import os, sys
from libs import *
from .layers import *
from .modules import *
class LightSEResBlock(nn.Module):
    """1-D residual block: two depthwise-separable convs followed by LightSEModule.

    Computes ``ReLU(convs(x) + identity(x))``, where ``convs`` is
    conv_1 -> BatchNorm1d -> ReLU -> Dropout -> conv_2 -> BatchNorm1d -> LightSEModule.
    ``DSConv1d`` and ``LightSEModule`` are defined elsewhere in the project.
    LightSEModule is presumably a lightweight squeeze-and-excitation channel gate
    (name-based assumption; confirm against its definition).

    Args:
        in_channels: Number of input channels.
        downsample: If True, the output has ``2 * in_channels`` channels, and
            stride 2 is used both in conv_1 and in a kernel-1 DSConv1d +
            BatchNorm1d projection shortcut, so the two branches match in shape.
            If False, the channel count is unchanged, stride is 1 and the
            shortcut is ``nn.Identity``.
        kernel_size: Kernel size of conv_1 and conv_2 (keyword-only). Must be a
            positive odd integer. Padding is ``kernel_size // 2``. Defaults to 7,
            the previously hard-coded value.
        dropout: Dropout probability applied between the two convolutions
            (keyword-only). Defaults to 0.3, the previously hard-coded value.

    Raises:
        ValueError: If ``kernel_size`` is not a positive odd integer.
    """

    def __init__(
        self,
        in_channels,
        downsample=False,
        *,
        kernel_size=7,
        dropout=0.3,
    ):
        super().__init__()
        # With an odd kernel and padding k // 2, the main branch's output length
        # equals that of the kernel-1 / padding-0 shortcut at the same stride
        # (assuming DSConv1d follows nn.Conv1d length arithmetic). An even kernel
        # would misalign the two branches and break the residual sum.
        if kernel_size < 1 or kernel_size % 2 == 0:
            raise ValueError(
                f"kernel_size must be a positive odd integer, got {kernel_size}"
            )
        padding = kernel_size // 2
        stride = 2 if downsample else 1
        self.out_channels = in_channels * 2 if downsample else in_channels

        # Submodules are created in the same order as before (conv_1, shortcut,
        # conv_2, then the rest). Default-argument instances therefore consume
        # the RNG for weight init in the same sequence and keep the same state_dict layout.
        self.conv_1 = DSConv1d(
            in_channels, self.out_channels,
            kernel_size=kernel_size, padding=padding, stride=stride,
        )
        if downsample:
            # Projection shortcut: match the main branch's channel count and stride.
            self.identity = nn.Sequential(
                DSConv1d(
                    in_channels, self.out_channels,
                    kernel_size=1, padding=0, stride=2,
                ),
                nn.BatchNorm1d(self.out_channels),
            )
        else:
            self.identity = nn.Identity()
        self.conv_2 = DSConv1d(
            self.out_channels, self.out_channels,
            kernel_size=kernel_size, padding=padding, stride=1,
        )
        # conv_1 / conv_2 are registered both as attributes and inside `convs`
        # (they are the same objects). This is intentionally left as-is so the
        # module tree and checkpoint keys stay unchanged.
        self.convs = nn.Sequential(
            self.conv_1,
            nn.BatchNorm1d(self.out_channels),
            nn.ReLU(),
            nn.Dropout(dropout),
            self.conv_2,
            nn.BatchNorm1d(self.out_channels),
            LightSEModule(self.out_channels),
        )
        self.act_fn = nn.ReLU()

    def forward(self, input):
        """Apply the residual block.

        Args:
            input: Tensor fed to both branches. It is presumably shaped
                (batch, in_channels, length), given the Conv1d-style and
                BatchNorm1d layers; confirm with callers.

        Returns:
            ReLU of the sum of the main branch and the shortcut branch.
        """
        # The main branch is evaluated before the shortcut, matching the previous order.
        output = self.convs(input) + self.identity(input)
        return self.act_fn(output)