mpc001's picture
Upload 125 files
09481f3
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# Copyright 2020 Johns Hopkins University (Shinji Watanabe)
# Northwestern Polytechnical University (Pengcheng Guo)
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
"""ConvolutionModule definition."""
import torch
from torch import nn
class ConvolutionModule(nn.Module):
"""ConvolutionModule in Conformer model.
:param int channels: channels of cnn
:param int kernel_size: kernerl size of cnn
"""
def __init__(self, channels, kernel_size, bias=True):
"""Construct an ConvolutionModule object."""
super(ConvolutionModule, self).__init__()
# kernerl_size should be a odd number for 'SAME' padding
assert (kernel_size - 1) % 2 == 0
self.pointwise_cov1 = nn.Conv1d(
channels, 2 * channels, kernel_size=1, stride=1, padding=0, bias=bias,
)
self.depthwise_conv = nn.Conv1d(
channels,
channels,
kernel_size,
stride=1,
padding=(kernel_size - 1) // 2,
groups=channels,
bias=bias,
)
self.norm = nn.BatchNorm1d(channels)
self.pointwise_cov2 = nn.Conv1d(
channels, channels, kernel_size=1, stride=1, padding=0, bias=bias,
)
self.activation = Swish()
def forward(self, x):
"""Compute covolution module.
:param torch.Tensor x: (batch, time, size)
:return torch.Tensor: convoluted `value` (batch, time, d_model)
"""
# exchange the temporal dimension and the feature dimension
x = x.transpose(1, 2)
# GLU mechanism
x = self.pointwise_cov1(x) # (batch, 2*channel, dim)
x = nn.functional.glu(x, dim=1) # (batch, channel, dim)
# 1D Depthwise Conv
x = self.depthwise_conv(x)
x = self.activation(self.norm(x))
x = self.pointwise_cov2(x)
return x.transpose(1, 2)
class Swish(nn.Module):
"""Construct an Swish object."""
def forward(self, x):
"""Return Swich activation function."""
return x * torch.sigmoid(x)