# NOTE: the six lines below are Hugging Face file-viewer residue, not source code.
# DeepLearning101's picture
# Upload 17 files
# 109bb65
# raw
# history blame
# 24.5 kB
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# author: adefossez
import math
import time
import torch
from torch import nn
from torch.nn import functional as F
from .resample import downsample2, upsample2
from .utils import capture_init
# class BLSTM(nn.Module):
# def __init__(self, dim, layers=2, bi=True):
# super().__init__()
# klass = nn.LSTM
# self.lstm = klass(bidirectional=bi, num_layers=layers, hidden_size=dim, input_size=dim)
# self.linear = None
# if bi:
# self.linear = nn.Linear(2 * dim, dim)
# def forward(self, x, hidden=None):
# x, hidden = self.lstm(x, hidden)
# if self.linear:
# x = self.linear(x)
# return x, hidden
# Small constant added to the variance before taking the square root in the
# layer-normalization modules of this file, so the division never hits zero.
EPS = 1e-8
class Chomp1d(nn.Module):
    """Drop the last ``chomp_size`` time steps of a ``[M, H, Kpad]`` tensor.

    Placed after a padded convolution so that the output length is the same
    as the input length.
    """

    def __init__(self, chomp_size):
        super().__init__()
        self.chomp_size = chomp_size

    def forward(self, x):
        """
        Args:
            x: [M, H, Kpad]
        Returns:
            [M, H, K] with K = Kpad - chomp_size, as a contiguous tensor
        """
        if self.chomp_size == 0:
            # `x[:, :, :-0]` is the same as `x[:, :, :0]`, i.e. an EMPTY
            # tensor. With nothing to trim, the input passes through as is.
            return x.contiguous()
        return x[:, :, :-self.chomp_size].contiguous()
def chose_norm(norm_type, channel_size):
    """Build the normalization layer named by ``norm_type``.

    The layer will be fed tensors of shape (M, C, K): M is the batch size,
    C the channel size and K the sequence length.

    Args:
        norm_type: "gLN" for GlobalLayerNorm, "cLN" for ChannelwiseLayerNorm;
            every other value (e.g. "BN") selects nn.BatchNorm1d.
        channel_size: number of channels C.
    Returns:
        The freshly constructed normalization module.
    """
    if norm_type == "cLN":
        return ChannelwiseLayerNorm(channel_size)
    if norm_type == "gLN":
        return GlobalLayerNorm(channel_size)
    # Fallback, including norm_type == "BN". For an (M, C, K) input,
    # nn.BatchNorm1d(C) accumulates its statistics along M and K, which is
    # the intended behaviour here.
    return nn.BatchNorm1d(channel_size)
class ChannelwiseLayerNorm(nn.Module):
    """Channel-wise Layer Normalization (cLN).

    Statistics are computed over the channel axis separately for every batch
    item and time step, followed by a learned per-channel gain and bias.
    """

    def __init__(self, channel_size):
        super(ChannelwiseLayerNorm, self).__init__()
        param_shape = (1, channel_size, 1)  # [1, N, 1]
        self.gamma = nn.Parameter(torch.Tensor(*param_shape))
        self.beta = nn.Parameter(torch.Tensor(*param_shape))
        self.reset_parameters()

    def reset_parameters(self):
        """Reset the gain to one and the bias to zero."""
        self.beta.data.zero_()
        self.gamma.data.fill_(1)

    def forward(self, y):
        """
        Args:
            y: [M, N, K], M is batch size, N is channel size, K is length
        Returns:
            cLN_y: [M, N, K]
        """
        mu = torch.mean(y, dim=1, keepdim=True)  # [M, 1, K]
        sigma2 = torch.var(y, dim=1, keepdim=True, unbiased=False)  # [M, 1, K]
        centered = y - mu
        return self.gamma * centered / torch.pow(sigma2 + EPS, 0.5) + self.beta
class DepthwiseSeparableConv(nn.Module):
    """Depthwise convolution, PReLU, normalization, then a 1x1 pointwise convolution.

    When ``causal`` is true a Chomp1d(padding) is inserted right after the
    depthwise convolution to drop the trailing ``padding`` time steps.
    """

    def __init__(self, in_channels, out_channels, kernel_size,
                 stride, padding, dilation, norm_type="gLN", causal=False):
        super(DepthwiseSeparableConv, self).__init__()
        # Layers are created in the order they run, so that the indices
        # inside the Sequential stay the same as before.
        layers = [
            # `groups=in_channels` makes the convolution depthwise
            # [M, H, K] -> [M, H, K]
            nn.Conv1d(in_channels, in_channels, kernel_size,
                      stride=stride, padding=padding,
                      dilation=dilation, groups=in_channels,
                      bias=False),
        ]
        if causal:
            layers.append(Chomp1d(padding))
        layers.append(nn.PReLU())
        layers.append(chose_norm(norm_type, in_channels))
        # [M, H, K] -> [M, B, K]
        layers.append(nn.Conv1d(in_channels, out_channels, 1, bias=False))
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """
        Args:
            x: [M, H, K]
        Returns:
            result: [M, B, K]
        """
        result = self.net(x)
        return result
class TemporalBlock(nn.Module):
    """Residual block: 1x1 conv, PReLU, norm, then a depthwise separable conv.

    The 1x1 convolution maps ``in_channels`` to ``out_channels`` and the
    depthwise separable convolution maps back to ``in_channels`` so that the
    block output can be summed with its input.
    """

    def __init__(self, in_channels, out_channels, kernel_size,
                 stride, padding, dilation, norm_type="gLN", causal=False):
        super(TemporalBlock, self).__init__()
        self.net = nn.Sequential(
            # [M, B, K] -> [M, H, K]
            nn.Conv1d(in_channels, out_channels, 1, bias=False),
            nn.PReLU(),
            chose_norm(norm_type, out_channels),
            # [M, H, K] -> [M, B, K]
            DepthwiseSeparableConv(out_channels, in_channels, kernel_size,
                                   stride, padding, dilation, norm_type,
                                   causal),
        )

    def forward(self, x):
        """
        Args:
            x: [M, B, K]
        Returns:
            [M, B, K]
        """
        # TODO: when P = 3 here works fine, but when P = 2 maybe need to pad?
        # The residual sum is returned as is: it looked better without a
        # final F.relu than with one.
        return self.net(x) + x
class GlobalLayerNorm(nn.Module):
    """Global Layer Normalization (gLN).

    Statistics are computed jointly over the channel and time axes of each
    batch item, followed by a learned per-channel gain and bias.
    """

    def __init__(self, channel_size):
        super(GlobalLayerNorm, self).__init__()
        param_shape = (1, channel_size, 1)  # [1, N, 1]
        self.gamma = nn.Parameter(torch.Tensor(*param_shape))
        self.beta = nn.Parameter(torch.Tensor(*param_shape))
        self.reset_parameters()

    def reset_parameters(self):
        """Reset the gain to one and the bias to zero."""
        self.beta.data.zero_()
        self.gamma.data.fill_(1)

    @staticmethod
    def _mean_over_channels_and_time(t):
        # Two chained reductions (channels first, then time) -> [M, 1, 1]
        # TODO: in torch 1.0, torch.mean() support dim list
        return t.mean(dim=1, keepdim=True).mean(dim=2, keepdim=True)

    def forward(self, y):
        """
        Args:
            y: [M, N, K], M is batch size, N is channel size, K is length
        Returns:
            gLN_y: [M, N, K]
        """
        mu = self._mean_over_channels_and_time(y)  # [M, 1, 1]
        sigma2 = self._mean_over_channels_and_time(torch.pow(y - mu, 2))
        return self.gamma * (y - mu) / torch.pow(sigma2 + EPS, 0.5) + self.beta
class TemporalConvNet(nn.Module):
    """Stack of dilated TemporalBlocks that turns features into mask(s)."""

    def __init__(self, N=768, B=256, H=512, P=3, X=8, R=4, C=1, norm_type="gLN", causal=1,
                 mask_nonlinear='relu'):
        """
        Args:
            N: Number of filters in autoencoder
            B: Number of channels in bottleneck 1 × 1-conv block
            H: Number of channels in convolutional blocks
            P: Kernel size in convolutional blocks
            X: Number of convolutional blocks in each repeat
            R: Number of repeats
            C: Number of speakers
            norm_type: BN, gLN, cLN
            causal: causal or non-causal
            mask_nonlinear: use which non-linear function to generate mask
        """
        super(TemporalConvNet, self).__init__()
        # Hyper-parameter
        self.C = C
        self.mask_nonlinear = mask_nonlinear

        # Components, created in the order in which they are applied.
        # [M, N, K] -> [M, N, K]
        input_norm = ChannelwiseLayerNorm(N)
        # [M, N, K] -> [M, B, K]
        bottleneck = nn.Conv1d(N, B, 1, bias=False)

        # [M, B, K] -> [M, B, K]
        stacks = []
        for _ in range(R):
            stack = []
            for level in range(X):
                # The dilation doubles with every block inside a repeat.
                dilation = 2 ** level
                if causal:
                    padding = (P - 1) * dilation
                else:
                    padding = (P - 1) * dilation // 2
                stack.append(TemporalBlock(B, H, P, stride=1,
                                           padding=padding,
                                           dilation=dilation,
                                           norm_type=norm_type,
                                           causal=causal))
            stacks.append(nn.Sequential(*stack))
        tcn = nn.Sequential(*stacks)

        # [M, B, K] -> [M, C*N, K]
        mask_head = nn.Conv1d(B, C * N, 1, bias=False)

        # Put together
        self.network = nn.Sequential(input_norm, bottleneck, tcn, mask_head)

    def forward(self, mixture_w):
        """
        Keep this API same with TasNet
        Args:
            mixture_w: [M, N, K], M is batch size
        returns:
            est_mask: [M, C, N, K]; the speaker axis is squeezed away
                when C == 1, giving [M, N, K]
        Raises:
            ValueError: if ``mask_nonlinear`` is neither 'softmax' nor 'relu'
                (checked after the network has been evaluated).
        """
        M, N, K = mixture_w.size()
        # [M, N, K] -> [M, C*N, K] -> [M, C, N, K]
        score = self.network(mixture_w).view(M, self.C, N, K)
        if self.mask_nonlinear == 'relu':
            est_mask = F.relu(score)
        elif self.mask_nonlinear == 'softmax':
            est_mask = F.softmax(score, dim=1)
        else:
            raise ValueError("Unsupported mask non-linear function")
        return est_mask.squeeze(1)
def rescale_conv(conv, reference):
    """Rescale a convolution in place relative to a reference std.

    Weight and bias (when present) are both divided by
    ``sqrt(std(weight) / reference)``.

    Args:
        conv: convolution module exposing ``weight`` and ``bias``.
        reference: reference standard deviation.
    """
    factor = (conv.weight.std().detach() / reference) ** 0.5
    tensors = [conv.weight]
    if conv.bias is not None:
        tensors.append(conv.bias)
    for tensor in tensors:
        tensor.data /= factor
def rescale_module(module, reference):
    """Apply ``rescale_conv`` to every 1-D (transposed) convolution in ``module``.

    Args:
        module: module whose sub-modules (itself included) are visited.
        reference: reference standard deviation handed to ``rescale_conv``.
    """
    conv_types = (nn.Conv1d, nn.ConvTranspose1d)
    convs = (layer for layer in module.modules() if isinstance(layer, conv_types))
    for layer in convs:
        rescale_conv(layer, reference)
class Demucs(nn.Module):
    """
    Demucs speech enhancement model, with a TemporalConvNet separator applied
    between the encoder and the decoder.

    Args:
        - chin (int): number of input channels.
        - chout (int): number of output channels.
        - hidden (int): number of initial hidden channels.
        - depth (int): number of layers.
        - kernel_size (int): kernel size for each layer.
        - stride (int): stride for each layer.
        - causal (bool): stored as ``self.causal``. In this file it is not
            forwarded to the separator, which is built with its own defaults.
        - resample (int): amount of resampling to apply to the input/output.
            Can be one of 1, 2 or 4.
        - growth (float): number of channels is multiplied by this for every layer.
        - max_hidden (int): maximum number of channels. Can be useful to
            control the size/speed of the model.
        - normalize (bool): if true, normalize the input.
        - glu (bool): if true uses GLU instead of ReLU in 1x1 convolutions.
        - rescale (float): controls custom weight initialization.
            See https://arxiv.org/abs/1911.13254.
        - floor (float): stability flooring when normalizing.

    Raises:
        ValueError: if ``resample`` is not 1, 2 or 4.
    """
    @capture_init
    def __init__(self,
                 chin=1,
                 chout=1,
                 hidden=48,
                 depth=5,
                 kernel_size=8,
                 stride=4,
                 causal=True,
                 resample=4,
                 growth=2,
                 max_hidden=10_000,
                 normalize=True,
                 glu=True,
                 rescale=0.1,
                 floor=1e-3):

        super().__init__()
        if resample not in (1, 2, 4):
            raise ValueError("Resample should be 1, 2 or 4.")

        # Keep the constructor values; `chin`, `chout` and `hidden` are
        # reused as running channel counts by the loop below.
        self.chin = chin
        self.chout = chout
        self.hidden = hidden
        self.depth = depth
        self.kernel_size = kernel_size
        self.stride = stride
        self.causal = causal
        self.floor = floor
        self.resample = resample
        self.normalize = normalize

        self.encoder = nn.ModuleList()
        self.decoder = nn.ModuleList()
        # One activation instance is shared by every encoder/decoder stage.
        activation = nn.GLU(1) if glu else nn.ReLU()
        ch_scale = 2 if glu else 1

        for index in range(depth):
            self.encoder.append(nn.Sequential(
                nn.Conv1d(chin, hidden, kernel_size, stride),
                nn.ReLU(),
                nn.Conv1d(hidden, hidden * ch_scale, 1),
                activation,
            ))

            decode = [
                nn.Conv1d(hidden, ch_scale * hidden, 1),
                activation,
                nn.ConvTranspose1d(hidden, chout, kernel_size, stride),
            ]
            # Every decoder stage but the outermost one ends with a ReLU.
            if index > 0:
                decode.append(nn.ReLU())
            # Inserting at the front puts the innermost stage first.
            self.decoder.insert(0, nn.Sequential(*decode))

            chin = chout = hidden
            hidden = min(int(growth * hidden), max_hidden)

        # The separator takes the place of the former LSTM bottleneck
        # (`self.lstm = BLSTM(chin, bi=not causal)`).
        self.separator = TemporalConvNet(N=chout)
        if rescale:
            rescale_module(self, reference=rescale)

    def valid_length(self, length):
        """
        Return the nearest valid length to use with the model so that
        there is no time steps left over in a convolutions, e.g. for all
        layers, size of the input - kernel_size % stride = 0.

        If the mixture has a valid length, the estimated sources
        will have exactly the same length.
        """
        length = math.ceil(length * self.resample)
        # Walk down the encoder: length after each strided convolution.
        for _ in range(self.depth):
            length = max(math.ceil((length - self.kernel_size) / self.stride) + 1, 1)
        # Walk back up the decoder: length after each transposed convolution.
        for _ in range(self.depth):
            length = (length - 1) * self.stride + self.kernel_size
        return int(math.ceil(length / self.resample))

    @property
    def total_stride(self):
        """Stride of the whole encoder, divided by the resampling factor."""
        return self.stride ** self.depth // self.resample

    def forward(self, mix):
        """Enhance ``mix`` and return a tensor with the same number of time steps.

        A 2-D input gets a channel axis inserted at position 1.
        """
        if mix.dim() == 2:
            mix = mix.unsqueeze(1)

        std = 1
        if self.normalize:
            mono = mix.mean(dim=1, keepdim=True)
            std = mono.std(dim=-1, keepdim=True)
            mix = mix / (self.floor + std)

        length = mix.shape[-1]
        # Right-pad with zeros up to a length the conv stack handles exactly.
        x = F.pad(mix, (0, self.valid_length(length) - length))

        # resample == 2 -> one 2x pass, resample == 4 -> two, otherwise none.
        resample_passes = {2: 1, 4: 2}.get(self.resample, 0)
        for _ in range(resample_passes):
            x = upsample2(x)

        skips = []
        for encode in self.encoder:
            x = encode(x)
            skips.append(x)

        x = self.separator(x)

        for decode in self.decoder:
            skip = skips.pop()
            # The skip connection is cropped to the current length.
            x = x + skip[..., :x.shape[-1]]
            x = decode(x)

        for _ in range(resample_passes):
            x = downsample2(x)

        # Remove the padding and undo the input normalization.
        return std * x[..., :length]
def _matmul_bias(bias, weight, x):
    """Return ``weight @ x`` plus the per-row ``bias`` when there is one."""
    if bias is None:
        return torch.mm(weight, x)
    return torch.addmm(bias.view(-1, 1), weight, x)


def fast_conv(conv, x):
    """
    Faster convolution evaluation if either the kernel size is 1
    or the sequence is exactly one kernel long (single output step).

    Both shortcuts replace the convolution by one matrix product, and every
    other case falls back to calling ``conv`` directly.

    Args:
        conv: 1-D convolution module; its ``weight`` is read as
            [chout, chin, kernel] and ``bias`` may be None.
        x: input of shape [1, chin, length].
    Returns:
        Tensor of shape [1, chout, -1].
    """
    batch, chin, length = x.shape
    chout, chin, kernel = conv.weight.shape
    assert batch == 1
    if kernel == 1:
        # 1x1 convolution: a single [chout, chin] @ [chin, length] product.
        # `reshape` (rather than `view`) also accepts non-contiguous input.
        out = _matmul_bias(conv.bias,
                           conv.weight.view(chout, chin),
                           x.reshape(chin, length))
    elif length == kernel:
        # One output step: flatten channels and taps into one vector.
        out = _matmul_bias(conv.bias,
                           conv.weight.view(chout, chin * kernel),
                           x.reshape(chin * kernel, 1))
    else:
        out = conv(x)
    return out.view(batch, chout, -1)
class DemucsStreamer:
    """
    Streaming implementation for Demucs. It supports being fed with any amount
    of audio at a time. You will get back as much audio as possible at that
    point.

    Args:
        - demucs (Demucs): Demucs model.
        - dry (float): amount of dry (e.g. input) signal to keep. 0 is maximum
            noise removal, 1 just returns the input signal. Small values > 0
            allows to limit distortions.
        - num_frames (int): number of frames to process at once. Higher values
            will increase overall latency but improve the real time factor.
        - resample_lookahead (int): extra lookahead used for the resampling.
        - resample_buffer (int): size of the buffer of previous inputs/outputs
            kept for resampling.
    """
    def __init__(self, demucs,
                 dry=0,
                 num_frames=1,
                 resample_lookahead=64,
                 resample_buffer=256):
        # All internal buffers live on the same device as the model.
        device = next(iter(demucs.parameters())).device
        self.demucs = demucs
        self.lstm_state = None
        self.conv_state = None
        self.dry = dry
        self.resample_lookahead = resample_lookahead
        self.resample_buffer = resample_buffer
        self.frame_length = demucs.valid_length(1) + demucs.total_stride * (num_frames - 1)
        self.total_length = self.frame_length + self.resample_lookahead
        self.stride = demucs.total_stride * num_frames
        self.resample_in = torch.zeros(demucs.chin, resample_buffer, device=device)
        self.resample_out = torch.zeros(demucs.chin, resample_buffer, device=device)

        self.frames = 0
        self.total_time = 0
        self.variance = 0
        # Audio received but not yet consumed by a full frame.
        self.pending = torch.zeros(demucs.chin, 0, device=device)

        # Reshaped copies of the first decoder stage's transposed convolution.
        # NOTE(review): not read anywhere else in this class.
        bias = demucs.decoder[0][2].bias
        weight = demucs.decoder[0][2].weight
        chin, chout, kernel = weight.shape
        self._bias = bias.view(-1, 1).repeat(1, kernel).view(-1, 1)
        self._weight = weight.permute(1, 2, 0).contiguous()

    def reset_time_per_frame(self):
        """Reset the counters behind ``time_per_frame``."""
        self.total_time = 0
        self.frames = 0

    @property
    def time_per_frame(self):
        """Average time spent in ``feed`` per processed frame.

        Raises ZeroDivisionError while no frame has been processed.
        """
        return self.total_time / self.frames

    def flush(self):
        """
        Flush remaining audio by padding it with zero. Call this
        when you have no more input and want to get back the last chunk of audio.
        """
        pending_length = self.pending.shape[1]
        padding = torch.zeros(self.demucs.chin, self.total_length, device=self.pending.device)
        out = self.feed(padding)
        # Only the part that corresponds to real (non padded) input is kept.
        return out[:, :pending_length]

    def feed(self, wav):
        """
        Apply the model to mix using true real time evaluation.
        Normalization is done online as is the resampling.

        Args:
            wav: tensor of shape [chin, time].
        Returns:
            Tensor of shape [chin, T] with all the audio that could be
            produced so far; T is 0 when not enough input is pending.
        Raises:
            ValueError: if ``wav`` is not two dimensional or its channel
                count differs from ``demucs.chin``.
        """
        begin = time.time()
        demucs = self.demucs
        resample_buffer = self.resample_buffer
        stride = self.stride
        resample = demucs.resample

        if wav.dim() != 2:
            raise ValueError("input wav should be two dimensional.")
        chin, _ = wav.shape
        if chin != demucs.chin:
            raise ValueError(f"Expected {demucs.chin} channels, got {chin}")

        self.pending = torch.cat([self.pending, wav], dim=1)
        outs = []
        while self.pending.shape[1] >= self.total_length:
            self.frames += 1
            frame = self.pending[:, :self.total_length]
            dry_signal = frame[:, :stride]
            if demucs.normalize:
                mono = frame.mean(0)
                variance = (mono**2).mean()
                # Running average of the per-frame variance.
                self.variance = variance / self.frames + (1 - 1 / self.frames) * self.variance
                frame = frame / (demucs.floor + math.sqrt(self.variance))
            frame = torch.cat([self.resample_in, frame], dim=-1)
            self.resample_in[:] = frame[:, stride - resample_buffer:stride]

            if resample == 4:
                frame = upsample2(upsample2(frame))
            elif resample == 2:
                frame = upsample2(frame)
            frame = frame[:, resample * resample_buffer:]  # remove pre sampling buffer
            frame = frame[:, :resample * self.frame_length]  # remove extra samples after window

            out, extra = self._separate_frame(frame)
            padded_out = torch.cat([self.resample_out, out, extra], 1)
            self.resample_out[:] = out[:, -resample_buffer:]
            if resample == 4:
                out = downsample2(downsample2(padded_out))
            elif resample == 2:
                out = downsample2(padded_out)
            else:
                out = padded_out

            out = out[:, resample_buffer // resample:]
            out = out[:, :stride]

            if demucs.normalize:
                out *= math.sqrt(self.variance)
            out = self.dry * dry_signal + (1 - self.dry) * out
            outs.append(out)
            self.pending = self.pending[:, stride:]

        self.total_time += time.time() - begin
        if outs:
            out = torch.cat(outs, 1)
        else:
            out = torch.zeros(chin, 0, device=wav.device)
        return out

    def _separate_frame(self, frame):
        """Run one frame through the model, reusing cached convolution outputs.

        Returns the pair ``(out, extra)``; ``extra`` holds extra samples to
        the right of ``out`` that ``feed`` uses as padding for resampling.
        """
        demucs = self.demucs
        skips = []
        next_state = []
        first = self.conv_state is None
        stride = self.stride * demucs.resample
        x = frame[None]
        for idx, encode in enumerate(demucs.encoder):
            stride //= demucs.stride
            length = x.shape[2]
            if idx == demucs.depth - 1:
                # This is slightly faster for the last conv
                x = fast_conv(encode[0], x)
                x = encode[1](x)
                x = fast_conv(encode[2], x)
                x = encode[3](x)
            else:
                if not first:
                    # Reuse the outputs computed for the previous frame and
                    # only evaluate the convolution on the missing positions.
                    prev = self.conv_state.pop(0)
                    prev = prev[..., stride:]
                    tgt = (length - demucs.kernel_size) // demucs.stride + 1
                    missing = tgt - prev.shape[-1]
                    offset = length - demucs.kernel_size - demucs.stride * (missing - 1)
                    x = x[..., offset:]
                x = encode[1](encode[0](x))
                x = fast_conv(encode[2], x)
                x = encode[3](x)
                if not first:
                    x = torch.cat([prev, x], -1)
                next_state.append(x)
            skips.append(x)

        lstm = getattr(demucs, "lstm", None)
        if lstm is not None:
            # Model with a recurrent bottleneck: the state is carried from
            # one frame to the next.
            x = x.permute(2, 0, 1)
            x, self.lstm_state = lstm(x, self.lstm_state)
            x = x.permute(1, 2, 0)
        else:
            # The Demucs defined in this file has no `lstm` attribute; its
            # bottleneck is `separator`, applied as in `Demucs.forward`.
            # NOTE(review): no separator context is carried across frames
            # here, so the streamed output presumably differs from the
            # offline one -- confirm before relying on it.
            x = demucs.separator(x)

        # In the following, x contains only correct samples, i.e. the one
        # for which each time position is covered by two window of the upper layer.
        # extra contains extra samples to the right, and is used only as a
        # better padding for the online resampling.
        extra = None
        for idx, decode in enumerate(demucs.decoder):
            skip = skips.pop(-1)
            x += skip[..., :x.shape[-1]]
            x = fast_conv(decode[0], x)
            x = decode[1](x)

            if extra is not None:
                skip = skip[..., x.shape[-1]:]
                extra += skip[..., :extra.shape[-1]]
                extra = decode[2](decode[1](decode[0](extra)))
            x = decode[2](x)
            # The bias is subtracted so that it is not counted twice when
            # this tail is added to the next frame's output.
            next_state.append(x[..., -demucs.stride:] - decode[2].bias.view(-1, 1))
            if extra is None:
                extra = x[..., -demucs.stride:]
            else:
                extra[..., :demucs.stride] += next_state[-1]
            x = x[..., :-demucs.stride]

            if not first:
                prev = self.conv_state.pop(0)
                x[..., :demucs.stride] += prev
            if idx != demucs.depth - 1:
                x = decode[3](x)
                extra = decode[3](extra)
        self.conv_state = next_state
        return x[0], extra[0]
def test():
    """Benchmark the streaming implementation against the offline model.

    Parses command line options, runs 4 seconds of random audio at 16 kHz
    through both ``Demucs`` and ``DemucsStreamer``, then prints the lag,
    stride, time per frame, relative delta and real time factor.
    """
    import argparse
    cli = argparse.ArgumentParser(
        "denoiser.demucs",
        description="Benchmark the streaming Demucs implementation, "
                    "as well as checking the delta with the offline implementation.")
    cli.add_argument("--resample", default=4, type=int)
    cli.add_argument("--hidden", default=48, type=int)
    cli.add_argument("--device", default="cpu")
    cli.add_argument("-t", "--num_threads", type=int)
    cli.add_argument("-f", "--num_frames", type=int, default=1)
    opts = cli.parse_args()
    if opts.num_threads:
        torch.set_num_threads(opts.num_threads)

    sr = 16_000
    sr_ms = sr / 1000
    model = Demucs(hidden=opts.hidden, resample=opts.resample).to(opts.device)
    signal = torch.randn(1, sr * 4).to(opts.device)
    # Offline reference computed on the whole signal at once.
    out = model(signal[None])[0]

    streamer = DemucsStreamer(model, num_frames=opts.num_frames)
    chunks = []
    # The first chunk must fill a whole frame; later ones advance by one stride.
    chunk_size = streamer.total_length
    with torch.no_grad():
        while signal.shape[1] > 0:
            chunks.append(streamer.feed(signal[:, :chunk_size]))
            signal = signal[:, chunk_size:]
            chunk_size = streamer.demucs.total_stride
    chunks.append(streamer.flush())
    out_rt = torch.cat(chunks, 1)

    frame_ms = 1000 * streamer.time_per_frame
    stride_ms = streamer.stride / sr_ms
    print(f"total lag: {streamer.total_length / sr_ms:.1f}ms, ", end='')
    print(f"stride: {stride_ms:.1f}ms, ", end='')
    print(f"time per frame: {frame_ms:.1f}ms, ", end='')
    print(f"delta: {torch.norm(out - out_rt) / torch.norm(out):.2%}, ", end='')
    print(f"RTF: {(frame_ms / stride_ms):.1f}")
# Running this file as a script launches the streaming benchmark.
if __name__ == "__main__":
    test()