# Implementation of the dual-path RNN (DPRNN) proposed in
# Luo et al., "Dual-Path RNN: Efficient Long Sequence Modeling
# for Time-Domain Single-Channel Speech Separation", ICASSP 2020.
#
# The code is based on:
# https://github.com/yluo42/TAC/blob/master/utility/models.py
#
import torch
import torch.nn as nn

EPS = torch.finfo(torch.get_default_dtype()).eps


class SingleRNN(nn.Module):
"""Container module for a single RNN layer.
args:
rnn_type: string, select from 'RNN', 'LSTM' and 'GRU'.
input_size: int, dimension of the input feature. The input should have shape
(batch, seq_len, input_size).
hidden_size: int, dimension of the hidden state.
dropout: float, dropout ratio. Default is 0.
bidirectional: bool, whether the RNN layers are bidirectional. Default is False.
"""

    def __init__(
self, rnn_type, input_size, hidden_size, dropout=0, bidirectional=False
):
super().__init__()
rnn_type = rnn_type.upper()
assert rnn_type in [
"RNN",
"LSTM",
"GRU",
], f"Only support 'RNN', 'LSTM' and 'GRU', current type: {rnn_type}"
self.rnn_type = rnn_type
self.input_size = input_size
self.hidden_size = hidden_size
self.num_direction = int(bidirectional) + 1
self.rnn = getattr(nn, rnn_type)(
input_size,
hidden_size,
            1,  # each SingleRNN wraps exactly one RNN layer
batch_first=True,
bidirectional=bidirectional,
)
self.dropout = nn.Dropout(p=dropout)
# linear projection layer
self.proj = nn.Linear(hidden_size * self.num_direction, input_size)

    def forward(self, input):
        # input shape: (batch, seq_len, input_size)
        output = input
rnn_output, _ = self.rnn(output)
rnn_output = self.dropout(rnn_output)
rnn_output = self.proj(
rnn_output.contiguous().view(-1, rnn_output.shape[2])
).view(output.shape)
return rnn_output
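
# Minimal usage sketch for SingleRNN (illustrative values): the hidden state
# is projected back to input_size, so the feature dimension is unchanged.
#
#     rnn = SingleRNN("LSTM", input_size=64, hidden_size=128, bidirectional=True)
#     x = torch.randn(4, 100, 64)   # batch, seq_len, input_size
#     y = rnn(x)                    # -> (4, 100, 64)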


# dual-path RNN
class DPRNN(nn.Module):
"""Deep dual-path RNN.
args:
rnn_type: string, select from 'RNN', 'LSTM' and 'GRU'.
input_size: int, dimension of the input feature. The input should have shape
(batch, seq_len, input_size).
hidden_size: int, dimension of the hidden state.
        output_size: int, dimension of the output feature.
dropout: float, dropout ratio. Default is 0.
num_layers: int, number of stacked RNN layers. Default is 1.
bidirectional: bool, whether the RNN layers are bidirectional. Default is True.
"""

    def __init__(
self,
rnn_type,
input_size,
hidden_size,
output_size,
dropout=0,
num_layers=1,
bidirectional=True,
):
super().__init__()
self.input_size = input_size
self.output_size = output_size
self.hidden_size = hidden_size
# dual-path RNN
self.row_rnn = nn.ModuleList([])
self.col_rnn = nn.ModuleList([])
self.row_norm = nn.ModuleList([])
self.col_norm = nn.ModuleList([])
for i in range(num_layers):
self.row_rnn.append(
SingleRNN(
rnn_type, input_size, hidden_size, dropout, bidirectional=True
)
) # intra-segment RNN is always noncausal
self.col_rnn.append(
SingleRNN(
rnn_type,
input_size,
hidden_size,
dropout,
bidirectional=bidirectional,
)
)
            self.row_norm.append(nn.GroupNorm(1, input_size, eps=1e-8))
            # GroupNorm with a single group acts as a (noncausal) global layer
            # norm. For a causal setting, change the inter-segment norm to a
            # causal normalization accordingly.
            self.col_norm.append(nn.GroupNorm(1, input_size, eps=1e-8))
# output layer
self.output = nn.Sequential(nn.PReLU(), nn.Conv2d(input_size, output_size, 1))

    def forward(self, input):
        # input shape: (batch, N, dim1, dim2)
        # apply the intra-segment RNN along dim1 first, then the
        # inter-segment RNN along dim2
        # output shape: (batch, output_size, dim1, dim2)
        batch_size, _, dim1, dim2 = input.shape
        output = input
for i in range(len(self.row_rnn)):
row_input = (
output.permute(0, 3, 2, 1)
.contiguous()
.view(batch_size * dim2, dim1, -1)
) # B*dim2, dim1, N
row_output = self.row_rnn[i](row_input) # B*dim2, dim1, H
row_output = (
row_output.view(batch_size, dim2, dim1, -1)
.permute(0, 3, 2, 1)
.contiguous()
) # B, N, dim1, dim2
row_output = self.row_norm[i](row_output)
output = output + row_output
col_input = (
output.permute(0, 2, 3, 1)
.contiguous()
.view(batch_size * dim1, dim2, -1)
) # B*dim1, dim2, N
col_output = self.col_rnn[i](col_input) # B*dim1, dim2, H
col_output = (
col_output.view(batch_size, dim1, dim2, -1)
.permute(0, 3, 1, 2)
.contiguous()
) # B, N, dim1, dim2
col_output = self.col_norm[i](col_output)
output = output + col_output
output = self.output(output) # B, output_size, dim1, dim2
return output
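
# Minimal usage sketch for DPRNN (illustrative values): the model maps a
# segmented feature map (B, N, dim1, dim2) to (B, output_size, dim1, dim2),
# alternating intra-segment (dim1) and inter-segment (dim2) RNNs.
#
#     dprnn = DPRNN("LSTM", input_size=64, hidden_size=128, output_size=64,
#                   num_layers=6)
#     x = torch.randn(2, 64, 100, 50)   # B, N, dim1, dim2
#     y = dprnn(x)                      # -> (2, 64, 100, 50)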


def _pad_segment(input, segment_size):
    # input is the features: (B, N, T)
    # zero-pad the time dimension so the signal can be split into
    # half-overlapping segments of length segment_size
batch_size, dim, seq_len = input.shape
segment_stride = segment_size // 2
rest = segment_size - (segment_stride + seq_len % segment_size) % segment_size
if rest > 0:
        pad = input.new_zeros(batch_size, dim, rest)
        input = torch.cat([input, pad], 2)
    pad_aux = input.new_zeros(batch_size, dim, segment_stride)
    input = torch.cat([pad_aux, input, pad_aux], 2)
return input, rest
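
# Worked example of the padding arithmetic (illustrative values): for
# T = 200 and segment_size = 100 (hop segment_stride = 50),
# rest = 100 - (50 + 200 % 100) % 100 = 50, so the padded length is
# 200 + 50 + 2 * 50 = 350 samples.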


def split_feature(input, segment_size):
    # split the features into half-overlapping chunks of length segment_size
    # input is the features: (B, N, T)
    # output is the segmented features: (B, N, segment_size, n_segments)
input, rest = _pad_segment(input, segment_size)
batch_size, dim, seq_len = input.shape
segment_stride = segment_size // 2
segments1 = (
input[:, :, :-segment_stride]
.contiguous()
.view(batch_size, dim, -1, segment_size)
)
segments2 = (
input[:, :, segment_stride:]
.contiguous()
.view(batch_size, dim, -1, segment_size)
)
segments = (
torch.cat([segments1, segments2], 3)
.view(batch_size, dim, -1, segment_size)
.transpose(2, 3)
)
return segments.contiguous(), rest
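
# Continuing the illustrative example above: splitting a (2, 64, 200) input
# with segment_size = 100 yields 6 half-overlapping segments over the
# 350-sample padded signal.
#
#     x = torch.randn(2, 64, 200)             # B, N, T
#     segments, rest = split_feature(x, 100)  # -> (2, 64, 100, 6), rest = 50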


def merge_feature(input, rest):
    # merge the split features back into a full utterance via overlap-add
    # input is the segmented features: (B, N, segment_size, n_segments)
batch_size, dim, segment_size, _ = input.shape
segment_stride = segment_size // 2
input = (
input.transpose(2, 3).contiguous().view(batch_size, dim, -1, segment_size * 2)
) # B, N, K, L
input1 = (
input[:, :, :, :segment_size]
.contiguous()
.view(batch_size, dim, -1)[:, :, segment_stride:]
)
input2 = (
input[:, :, :, segment_size:]
.contiguous()
.view(batch_size, dim, -1)[:, :, :-segment_stride]
)
output = input1 + input2
if rest > 0:
output = output[:, :, :-rest]
return output.contiguous() # B, N, T
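

if __name__ == "__main__":
    # Round-trip sanity check (a minimal sketch with assumed shapes): with
    # 50% overlap every sample is covered by exactly two segments, so
    # split_feature followed by merge_feature returns 2x the input.
    x = torch.randn(2, 64, 200)
    segments, rest = split_feature(x, segment_size=100)
    merged = merge_feature(segments, rest)
    assert merged.shape == x.shape
    assert torch.allclose(merged, 2 * x)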