TheComputerMan commited on
Commit
baf679b
·
1 Parent(s): 0c5616f

Upload Convolution.py

Browse files
Files changed (1) hide show
  1. Convolution.py +55 -0
Convolution.py ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2020 Johns Hopkins University (Shinji Watanabe)
2
+ # Northwestern Polytechnical University (Pengcheng Guo)
3
+ # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
4
+ # Adapted by Florian Lux 2021
5
+
6
+
7
+ from torch import nn
8
+
9
+
10
+ class ConvolutionModule(nn.Module):
11
+ """
12
+ ConvolutionModule in Conformer model.
13
+
14
+ Args:
15
+ channels (int): The number of channels of conv layers.
16
+ kernel_size (int): Kernel size of conv layers.
17
+
18
+ """
19
+
20
+ def __init__(self, channels, kernel_size, activation=nn.ReLU(), bias=True):
21
+ super(ConvolutionModule, self).__init__()
22
+ # kernel_size should be an odd number for 'SAME' padding
23
+ assert (kernel_size - 1) % 2 == 0
24
+
25
+ self.pointwise_conv1 = nn.Conv1d(channels, 2 * channels, kernel_size=1, stride=1, padding=0, bias=bias, )
26
+ self.depthwise_conv = nn.Conv1d(channels, channels, kernel_size, stride=1, padding=(kernel_size - 1) // 2, groups=channels, bias=bias, )
27
+ self.norm = nn.GroupNorm(num_groups=32, num_channels=channels)
28
+ self.pointwise_conv2 = nn.Conv1d(channels, channels, kernel_size=1, stride=1, padding=0, bias=bias, )
29
+ self.activation = activation
30
+
31
+ def forward(self, x):
32
+ """
33
+ Compute convolution module.
34
+
35
+ Args:
36
+ x (torch.Tensor): Input tensor (#batch, time, channels).
37
+
38
+ Returns:
39
+ torch.Tensor: Output tensor (#batch, time, channels).
40
+
41
+ """
42
+ # exchange the temporal dimension and the feature dimension
43
+ x = x.transpose(1, 2)
44
+
45
+ # GLU mechanism
46
+ x = self.pointwise_conv1(x) # (batch, 2*channel, dim)
47
+ x = nn.functional.glu(x, dim=1) # (batch, channel, dim)
48
+
49
+ # 1D Depthwise Conv
50
+ x = self.depthwise_conv(x)
51
+ x = self.activation(self.norm(x))
52
+
53
+ x = self.pointwise_conv2(x)
54
+
55
+ return x.transpose(1, 2)