File size: 3,205 Bytes
5381499
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
from typing import Tuple
import torch
from torch import nn
import torch.nn.functional as F


class _Conv1DLayer(nn.Module):
    """One feature-extraction stage: Conv1d (no bias) -> LayerNorm over channels -> GELU.

    The convolution uses no padding, so each stage shortens the time axis to
    ``floor((length - kernel_size) / stride) + 1`` frames.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        stride: int,
    ):
        """
        Args:
            in_channels (int): Number of channels of the input.
            out_channels (int): Number of channels produced by the convolution.
            kernel_size (int): Size of the convolution kernel, in frames.
            stride (int): Stride of the convolution, in frames.
        """
        super().__init__()

        self.kernel_size = kernel_size
        self.stride = stride

        self.conv = nn.Conv1d(
            in_channels=in_channels,
            out_channels=out_channels,
            stride=stride,
            kernel_size=kernel_size,
            bias=False,
        )

        self.layernorm = nn.LayerNorm(out_channels)

    def forward(
        self, x: torch.Tensor, length: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
            x (Tensor): The input. Shape: (batch, in_channels, in_frames)
            length (Tensor): The valid length of each sample. Shape: (batch)

        Returns:
            x (Tensor): The output. Shape: (batch, out_channels, out_frames)
            length (Tensor): The valid length of each sample. Shape: (batch)
        """
        x = self.conv(x)
        # nn.LayerNorm normalizes the last dimension, so move channels last and
        # back again (out-of-place transposes; no in-place metadata mutation).
        x = self.layernorm(x.transpose(1, 2)).transpose(1, 2)
        x = F.gelu(x)

        # Explicit floor rounding: on older PyTorch releases `//` on tensors
        # truncated toward zero, which turns a slightly-too-short input
        # (kernel_size - stride < length < kernel_size) into 1 frame instead of 0.
        length = (
            torch.div(length - self.kernel_size, self.stride, rounding_mode="floor")
            + 1
        )
        # Inputs shorter than the kernel produce no frames; prevent negative lengths.
        length = length.clamp(min=0)
        return x, length


class FeatureExtractor(nn.Module):
    """Stack of ``_Conv1DLayer`` stages that turns raw waveforms into frame features.

    Layer ``i`` maps ``num_channels[i - 1]`` channels (1 for the first layer,
    which sees the raw waveform) to ``num_channels[i]`` channels, using
    ``kernel_sizes[i]`` and ``strides[i]``.
    """

    def __init__(self, config):
        """
        Args:
            config: Object exposing the per-layer sequences ``num_channels``,
                ``kernel_sizes`` and ``strides``, which must all have the same
                non-zero length.

        Raises:
            ValueError: If the three sequences differ in length or are empty.
        """
        super().__init__()

        num_channels = config.num_channels
        kernel_sizes = config.kernel_sizes
        strides = config.strides

        # Raise instead of assert: asserts are stripped when running with -O.
        if not len(num_channels) == len(kernel_sizes) == len(strides):
            raise ValueError("The number of layers must be the same for all parameters")
        if len(num_channels) == 0:
            raise ValueError("At least one convolutional layer must be configured")

        self.conv_layers = nn.ModuleList()
        in_channels = 1  # the waveform enters as a single input channel
        for out_channels, kernel_size, stride in zip(num_channels, kernel_sizes, strides):
            self.conv_layers.append(
                _Conv1DLayer(
                    in_channels=in_channels,
                    out_channels=out_channels,
                    kernel_size=kernel_size,
                    stride=stride,
                )
            )
            # Each layer consumes the previous layer's output channels.
            in_channels = out_channels

    def forward(
        self, waveforms: torch.Tensor, wavelength: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Extracts features from the waveform.

        Args:
            waveforms (Tensor): The waveform to extract features from. Shape: (batch, wavelength)
            wavelength (Tensor): The valid length of each waveform. Shape: (batch)

        Returns:
            features (Tensor): The extracted features. Shape: (batch, num_frames, num_channels)
            num_frames (Tensor): The valid length of each feature. Shape: (batch)
        """
        # (batch, wavelength) -> (batch, 1, wavelength): Conv1d needs a channel axis.
        features = waveforms.unsqueeze(1)
        num_frames = wavelength

        for conv_layer in self.conv_layers:
            features, num_frames = conv_layer(features, num_frames)

        # (batch, num_channels, num_frames) -> (batch, num_frames, num_channels)
        features = features.transpose(1, 2)
        return features, num_frames