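# NOTE(review): the lines below appear to be residue from a web file viewer,
# not part of the module; kept as comments so the file stays importable.
# Agusbs98's picture
# Upload 37 files
# bb18256
# raw
# history blame
# No virus
# 1.64 kB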
import os, sys
from libs import *
from .layers import *
from .modules import *
class LightSEResBlock(nn.Module):
    """Residual 1-D block built from two ``DSConv1d`` layers and a ``LightSEModule``.

    Main path (``self.convs``)::

        conv_1 -> BatchNorm1d -> ReLU -> Dropout(0.3)
               -> conv_2 -> BatchNorm1d -> LightSEModule

    The result is added to a shortcut (``self.identity``) and passed through
    a final ReLU.

    Args:
        in_channels: Number of channels of the incoming tensor.
        downsample: When true, ``conv_1`` and the shortcut both use stride 2
            and the block outputs ``in_channels * 2`` channels; the shortcut
            is then a kernel-size-1 ``DSConv1d`` followed by ``BatchNorm1d``.
            When false, the channel count is kept, ``conv_1`` uses stride 1
            and the shortcut is ``nn.Identity``.

    Attributes:
        out_channels: Channel count produced by the block.

    NOTE(review): ``DSConv1d`` and ``LightSEModule`` come from sibling
    modules not visible here; presumably a depthwise-separable convolution
    and a squeeze-and-excitation gate -- confirm against their definitions.
    """

    def __init__(self, in_channels, downsample = False):
        super().__init__()

        # Only the first convolution and the shortcut depend on `downsample`.
        if downsample:
            first_stride = 2
            self.out_channels = in_channels*2
        else:
            first_stride = 1
            self.out_channels = in_channels
        width = self.out_channels

        # Submodules are created in the same order as before (conv_1,
        # shortcut, conv_2, ...) so that seeded parameter initialisation and
        # the module registration order are unchanged.
        self.conv_1 = DSConv1d(
            in_channels, width,
            kernel_size = 7, padding = 3, stride = first_stride,
        )

        if downsample:
            # Projection shortcut: matches the main path's stride and width.
            self.identity = nn.Sequential(
                DSConv1d(
                    in_channels, width,
                    kernel_size = 1, padding = 0, stride = 2,
                ),
                nn.BatchNorm1d(width),
            )
        else:
            self.identity = nn.Identity()

        self.conv_2 = DSConv1d(
            width, width,
            kernel_size = 7, padding = 3, stride = 1,
        )

        # conv_1 / conv_2 stay reachable both as direct attributes and as
        # members of `convs`, exactly as in the original layout.
        main_path = [
            self.conv_1,
            nn.BatchNorm1d(width),
            nn.ReLU(),
            nn.Dropout(0.3),
            self.conv_2,
            nn.BatchNorm1d(width),
            LightSEModule(width),
        ]
        self.convs = nn.Sequential(*main_path)
        self.act_fn = nn.ReLU()

    def forward(self, input):
        """Return ``ReLU(convs(input) + identity(input))``.

        Args:
            input: Tensor fed to both the main path and the shortcut
                (presumably shaped ``(batch, in_channels, length)`` --
                TODO confirm against callers).

        Returns:
            The activated sum of the main path and the shortcut.
        """
        # Evaluate the main path first, then the shortcut, as before.
        residual = self.convs(input)
        shortcut = self.identity(input)
        return self.act_fn(residual + shortcut)