import math |
|
import time |
|
|
|
import torch |
|
from torch import nn |
|
from torch.nn import functional as F |
|
|
|
from .resample import downsample2, upsample2 |
|
from .utils import capture_init |
|
|
EPS = 1e-8 |
|
class Chomp1d(nn.Module): |
|
"""To ensure the output length is the same as the input. |
|
""" |
|
def __init__(self, chomp_size): |
|
super(Chomp1d, self).__init__() |
|
self.chomp_size = chomp_size |
|
|
|
def forward(self, x): |
|
""" |
|
Args: |
|
x: [M, H, Kpad] |
|
Returns: |
|
[M, H, K] |
|
""" |
|
return x[:, :, :-self.chomp_size].contiguous() |
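
# In TemporalConvNet below, a causal block pads both sides by (P - 1) * dilation
# and Chomp1d(padding) then drops the trailing samples, so every output step
# depends only on current and past inputs.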
|
|
|
def chose_norm(norm_type, channel_size): |
|
"""The input of normlization will be (M, C, K), where M is batch size, |
|
C is channel size and K is sequence length. |
|
""" |
|
if norm_type == "gLN": |
|
return GlobalLayerNorm(channel_size) |
|
elif norm_type == "cLN": |
|
return ChannelwiseLayerNorm(channel_size) |
|
    else:
        # Any other norm_type (e.g. "BN") falls back to Batch Normalization,
        # which pools statistics over the batch and time dimensions.
        return nn.BatchNorm1d(channel_size)
|
|
|
class ChannelwiseLayerNorm(nn.Module): |
|
"""Channel-wise Layer Normalization (cLN)""" |
|
def __init__(self, channel_size): |
|
super(ChannelwiseLayerNorm, self).__init__() |
|
self.gamma = nn.Parameter(torch.Tensor(1, channel_size, 1)) |
|
        self.beta = nn.Parameter(torch.Tensor(1, channel_size, 1))
|
self.reset_parameters() |
|
|
|
def reset_parameters(self): |
|
self.gamma.data.fill_(1) |
|
self.beta.data.zero_() |
|
|
|
def forward(self, y): |
|
""" |
|
Args: |
|
y: [M, N, K], M is batch size, N is channel size, K is length |
|
Returns: |
|
cLN_y: [M, N, K] |
|
""" |
|
mean = torch.mean(y, dim=1, keepdim=True) |
|
var = torch.var(y, dim=1, keepdim=True, unbiased=False) |
|
cLN_y = self.gamma * (y - mean) / torch.pow(var + EPS, 0.5) + self.beta |
|
return cLN_y |
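
# Note: cLN (above) normalizes over channels one time step at a time, so the
# statistics at step k never depend on future samples; gLN (further below)
# pools over the whole utterance and is therefore only usable non-causally.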
|
|
|
class DepthwiseSeparableConv(nn.Module): |
|
def __init__(self, in_channels, out_channels, kernel_size, |
|
stride, padding, dilation, norm_type="gLN", causal=False): |
|
super(DepthwiseSeparableConv, self).__init__() |
|
|
|
|
|
depthwise_conv = nn.Conv1d(in_channels, in_channels, kernel_size, |
|
stride=stride, padding=padding, |
|
dilation=dilation, groups=in_channels, |
|
bias=False) |
|
if causal: |
|
chomp = Chomp1d(padding) |
|
prelu = nn.PReLU() |
|
norm = chose_norm(norm_type, in_channels) |
|
|
|
pointwise_conv = nn.Conv1d(in_channels, out_channels, 1, bias=False) |
|
|
|
if causal: |
|
self.net = nn.Sequential(depthwise_conv, chomp, prelu, norm, |
|
pointwise_conv) |
|
else: |
|
self.net = nn.Sequential(depthwise_conv, prelu, norm, |
|
pointwise_conv) |
|
|
|
def forward(self, x): |
|
""" |
|
Args: |
|
x: [M, H, K] |
|
Returns: |
|
result: [M, B, K] |
|
""" |
|
return self.net(x) |
|
|
|
class TemporalBlock(nn.Module): |
|
def __init__(self, in_channels, out_channels, kernel_size, |
|
stride, padding, dilation, norm_type="gLN", causal=False): |
|
super(TemporalBlock, self).__init__() |
|
|
|
conv1x1 = nn.Conv1d(in_channels, out_channels, 1, bias=False) |
|
prelu = nn.PReLU() |
|
norm = chose_norm(norm_type, out_channels) |
|
|
|
dsconv = DepthwiseSeparableConv(out_channels, in_channels, kernel_size, |
|
stride, padding, dilation, norm_type, |
|
causal) |
|
|
|
self.net = nn.Sequential(conv1x1, prelu, norm, dsconv) |
|
|
|
def forward(self, x): |
|
""" |
|
Args: |
|
x: [M, B, K] |
|
Returns: |
|
[M, B, K] |
|
""" |
|
residual = x |
|
out = self.net(x) |
|
|
|
return out + residual |
|
|
|
|
|
class GlobalLayerNorm(nn.Module): |
|
"""Global Layer Normalization (gLN)""" |
|
def __init__(self, channel_size): |
|
super(GlobalLayerNorm, self).__init__() |
|
self.gamma = nn.Parameter(torch.Tensor(1, channel_size, 1)) |
|
        self.beta = nn.Parameter(torch.Tensor(1, channel_size, 1))
|
self.reset_parameters() |
|
|
|
def reset_parameters(self): |
|
self.gamma.data.fill_(1) |
|
self.beta.data.zero_() |
|
|
|
def forward(self, y): |
|
""" |
|
Args: |
|
y: [M, N, K], M is batch size, N is channel size, K is length |
|
Returns: |
|
gLN_y: [M, N, K] |
|
""" |
|
|
|
mean = y.mean(dim=1, keepdim=True).mean(dim=2, keepdim=True) |
|
var = (torch.pow(y-mean, 2)).mean(dim=1, keepdim=True).mean(dim=2, keepdim=True) |
|
gLN_y = self.gamma * (y - mean) / torch.pow(var + EPS, 0.5) + self.beta |
|
return gLN_y |
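
# The two chained means above are equivalent to y.mean(dim=(1, 2), keepdim=True).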
|
|
|
class TemporalConvNet(nn.Module): |
|
    def __init__(self, N=768, B=256, H=512, P=3, X=8, R=4, C=1, norm_type="gLN",
                 causal=True, mask_nonlinear='relu'):
|
""" |
|
Args: |
|
N: Number of filters in autoencoder |
|
B: Number of channels in bottleneck 1 × 1-conv block |
|
H: Number of channels in convolutional blocks |
|
P: Kernel size in convolutional blocks |
|
X: Number of convolutional blocks in each repeat |
|
R: Number of repeats |
|
C: Number of speakers |
|
norm_type: BN, gLN, cLN |
|
causal: causal or non-causal |
|
            mask_nonlinear: which non-linear function to use to generate the mask
|
""" |
|
super(TemporalConvNet, self).__init__() |
|
|
|
self.C = C |
|
self.mask_nonlinear = mask_nonlinear |
|
layer_norm = ChannelwiseLayerNorm(N) |
|
|
|
bottleneck_conv1x1 = nn.Conv1d(N, B, 1, bias=False) |
|
|
|
repeats = [] |
|
for r in range(R): |
|
blocks = [] |
|
for x in range(X): |
|
dilation = 2**x |
|
padding = (P - 1) * dilation if causal else (P - 1) * dilation // 2 |
|
blocks += [TemporalBlock(B, H, P, stride=1, |
|
padding=padding, |
|
dilation=dilation, |
|
norm_type=norm_type, |
|
causal=causal)] |
|
repeats += [nn.Sequential(*blocks)] |
|
temporal_conv_net = nn.Sequential(*repeats) |
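        # With the defaults (P=3, X=8, R=4), the dilations 1, 2, ..., 2**(X - 1)
        # stack to a receptive field of 1 + R * (P - 1) * (2**X - 1) = 2041 samples.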
|
|
|
mask_conv1x1 = nn.Conv1d(B, C*N, 1, bias=False) |
|
|
|
self.network = nn.Sequential(layer_norm, |
|
bottleneck_conv1x1, |
|
temporal_conv_net, |
|
mask_conv1x1) |
|
|
|
def forward(self, mixture_w): |
|
""" |
|
Keep this API same with TasNet |
|
Args: |
|
mixture_w: [M, N, K], M is batch size |
|
returns: |
|
est_mask: [M, C, N, K] |
|
""" |
|
M, N, K = mixture_w.size() |
|
score = self.network(mixture_w) |
|
score = score.view(M, self.C, N, K) |
|
if self.mask_nonlinear == 'softmax': |
|
est_mask = F.softmax(score, dim=1) |
|
est_mask = est_mask.squeeze(1) |
|
elif self.mask_nonlinear == 'relu': |
|
est_mask = F.relu(score) |
|
est_mask = est_mask.squeeze(1) |
|
else: |
|
raise ValueError("Unsupported mask non-linear function") |
|
return est_mask |
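
# Usage sketch (not part of the original file; shapes assume the defaults above):
#     tcn = TemporalConvNet(N=768, B=256, H=512, P=3, X=8, R=4, C=1)
#     mixture_w = torch.randn(2, 768, 100)     # [M, N, K]
#     est_mask = tcn(mixture_w)                # [2, 768, 100] after the C == 1 squeeze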
|
|
def rescale_conv(conv, reference): |
|
std = conv.weight.std().detach() |
|
scale = (std / reference)**0.5 |
|
conv.weight.data /= scale |
|
if conv.bias is not None: |
|
conv.bias.data /= scale |
|
|
|
|
|
def rescale_module(module, reference): |
|
for sub in module.modules(): |
|
if isinstance(sub, (nn.Conv1d, nn.ConvTranspose1d)): |
|
rescale_conv(sub, reference) |
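
# Example: with reference=0.1, a conv whose weights have std 0.4 is divided by
# (0.4 / 0.1) ** 0.5 = 2.0, moving its initial output scale toward the reference
# (see the rescaling trick in https://arxiv.org/abs/1911.13254).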
|
|
|
|
|
class Demucs(nn.Module): |
|
""" |
|
Demucs speech enhancement model. |
|
Args: |
|
- chin (int): number of input channels. |
|
- chout (int): number of output channels. |
|
- hidden (int): number of initial hidden channels. |
|
- depth (int): number of layers. |
|
- kernel_size (int): kernel size for each layer. |
|
- stride (int): stride for each layer. |
|
    - causal (bool): whether to run in a causal configuration. The original
        Demucs uses a BiLSTM instead of an LSTM when false; in this variant the
        recurrent bottleneck is replaced by the TemporalConvNet separator.
|
- resample (int): amount of resampling to apply to the input/output. |
|
Can be one of 1, 2 or 4. |
|
- growth (float): number of channels is multiplied by this for every layer. |
|
- max_hidden (int): maximum number of channels. Can be useful to |
|
control the size/speed of the model. |
|
- normalize (bool): if true, normalize the input. |
|
- glu (bool): if true uses GLU instead of ReLU in 1x1 convolutions. |
|
- rescale (float): controls custom weight initialization. |
|
See https://arxiv.org/abs/1911.13254. |
|
- floor (float): stability flooring when normalizing. |
""" |
|
@capture_init |
|
def __init__(self, |
|
chin=1, |
|
chout=1, |
|
hidden=48, |
|
depth=5, |
|
kernel_size=8, |
|
stride=4, |
|
causal=True, |
|
resample=4, |
|
growth=2, |
|
max_hidden=10_000, |
|
normalize=True, |
|
glu=True, |
|
rescale=0.1, |
|
floor=1e-3): |
|
|
|
super().__init__() |
|
if resample not in [1, 2, 4]: |
|
raise ValueError("Resample should be 1, 2 or 4.") |
|
|
|
self.chin = chin |
|
self.chout = chout |
|
self.hidden = hidden |
|
self.depth = depth |
|
self.kernel_size = kernel_size |
|
self.stride = stride |
|
self.causal = causal |
|
self.floor = floor |
|
self.resample = resample |
|
self.normalize = normalize |
|
|
|
self.encoder = nn.ModuleList() |
|
self.decoder = nn.ModuleList() |
|
activation = nn.GLU(1) if glu else nn.ReLU() |
|
ch_scale = 2 if glu else 1 |
|
|
|
for index in range(depth): |
|
encode = [] |
|
encode += [ |
|
nn.Conv1d(chin, hidden, kernel_size, stride), |
|
nn.ReLU(), |
|
nn.Conv1d(hidden, hidden * ch_scale, 1), activation, |
|
] |
|
self.encoder.append(nn.Sequential(*encode)) |
|
|
|
decode = [] |
|
decode += [ |
|
nn.Conv1d(hidden, ch_scale * hidden, 1), activation, |
|
nn.ConvTranspose1d(hidden, chout, kernel_size, stride), |
|
] |
|
if index > 0: |
|
decode.append(nn.ReLU()) |
|
self.decoder.insert(0, nn.Sequential(*decode)) |
|
chout = hidden |
|
chin = hidden |
|
hidden = min(int(growth * hidden), max_hidden) |
|
|
|
self.separator = TemporalConvNet(N=chout) |
|
|
|
if rescale: |
|
rescale_module(self, reference=rescale) |
|
|
|
def valid_length(self, length): |
|
""" |
|
Return the nearest valid length to use with the model so that |
|
        there are no time steps left over in the convolutions, e.g. for all
        layers, (size of the input - kernel_size) % stride == 0.
|
|
|
If the mixture has a valid length, the estimated sources |
|
will have exactly the same length. |
|
""" |
|
length = math.ceil(length * self.resample) |
|
for idx in range(self.depth): |
|
length = math.ceil((length - self.kernel_size) / self.stride) + 1 |
|
length = max(length, 1) |
|
for idx in range(self.depth): |
|
length = (length - 1) * self.stride + self.kernel_size |
|
length = int(math.ceil(length / self.resample)) |
|
return int(length) |
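
    # Worked example with the defaults (kernel_size=8, stride=4, depth=5,
    # resample=4): valid_length(1) == 597, the shortest usable frame.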
|
|
|
@property |
|
def total_stride(self): |
|
return self.stride ** self.depth // self.resample |
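
    # With the defaults, total_stride = 4 ** 5 // 4 = 256 samples, i.e. 16 ms of
    # fresh audio per streaming step at 16 kHz.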
|
|
|
def forward(self, mix): |
|
if mix.dim() == 2: |
|
mix = mix.unsqueeze(1) |
|
|
|
if self.normalize: |
|
mono = mix.mean(dim=1, keepdim=True) |
|
std = mono.std(dim=-1, keepdim=True) |
|
mix = mix / (self.floor + std) |
|
else: |
|
std = 1 |
|
length = mix.shape[-1] |
|
x = mix |
|
x = F.pad(x, (0, self.valid_length(length) - length)) |
|
if self.resample == 2: |
|
x = upsample2(x) |
|
elif self.resample == 4: |
|
x = upsample2(x) |
|
x = upsample2(x) |
|
skips = [] |
|
for encode in self.encoder: |
|
x = encode(x) |
|
skips.append(x) |
|
x = self.separator(x) |
|
for decode in self.decoder: |
|
skip = skips.pop(-1) |
|
x = x + skip[..., :x.shape[-1]] |
|
x = decode(x) |
|
if self.resample == 2: |
|
x = downsample2(x) |
|
elif self.resample == 4: |
|
x = downsample2(x) |
|
x = downsample2(x) |
|
|
|
x = x[..., :length] |
|
return std * x |
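
# Offline usage sketch (not part of the original file; assumes 16 kHz mono audio):
#     model = Demucs()
#     noisy = torch.randn(1, 1, 16000)      # [batch, channels, time]
#     with torch.no_grad():
#         clean = model(noisy)              # same shape as the input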
|
|
|
|
|
def fast_conv(conv, x): |
|
""" |
|
    Faster convolution evaluation if either the kernel size is 1
    or the sequence length equals the kernel size (a single output step).
|
""" |
|
batch, chin, length = x.shape |
|
chout, chin, kernel = conv.weight.shape |
|
assert batch == 1 |
|
if kernel == 1: |
|
x = x.view(chin, length) |
|
        out = torch.addmm(conv.bias.view(-1, 1),
                          conv.weight.view(chout, chin), x)
|
elif length == kernel: |
|
x = x.view(chin * kernel, 1) |
|
        out = torch.addmm(conv.bias.view(-1, 1),
                          conv.weight.view(chout, chin * kernel), x)
|
else: |
|
out = conv(x) |
|
return out.view(batch, chout, -1) |
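
# For kernel == 1, Conv1d reduces to a single matrix multiply:
#     conv(x) == conv.weight.view(chout, chin) @ x.view(chin, length) + conv.bias.view(-1, 1)
# torch.addmm fuses the bias add with the matmul, which is cheaper for batch == 1.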
|
|
|
|
|
class DemucsStreamer: |
|
""" |
|
Streaming implementation for Demucs. It supports being fed with any amount |
|
of audio at a time. You will get back as much audio as possible at that |
|
point. |
|
|
|
Args: |
|
- demucs (Demucs): Demucs model. |
|
- dry (float): amount of dry (e.g. input) signal to keep. 0 is maximum |
|
            noise removal, 1 just returns the input signal. Small values > 0
            help limit distortions.
|
- num_frames (int): number of frames to process at once. Higher values |
|
will increase overall latency but improve the real time factor. |
|
- resample_lookahead (int): extra lookahead used for the resampling. |
|
- resample_buffer (int): size of the buffer of previous inputs/outputs |
|
kept for resampling. |
|
""" |
|
def __init__(self, demucs, |
|
dry=0, |
|
num_frames=1, |
|
resample_lookahead=64, |
|
resample_buffer=256): |
|
device = next(iter(demucs.parameters())).device |
|
self.demucs = demucs |
|
self.conv_state = None |
|
self.dry = dry |
|
self.resample_lookahead = resample_lookahead |
|
self.resample_buffer = resample_buffer |
|
self.frame_length = demucs.valid_length(1) + demucs.total_stride * (num_frames - 1) |
|
self.total_length = self.frame_length + self.resample_lookahead |
|
self.stride = demucs.total_stride * num_frames |
|
self.resample_in = torch.zeros(demucs.chin, resample_buffer, device=device) |
|
self.resample_out = torch.zeros(demucs.chin, resample_buffer, device=device) |
|
|
|
self.frames = 0 |
|
self.total_time = 0 |
|
self.variance = 0 |
|
self.pending = torch.zeros(demucs.chin, 0, device=device) |
|
|
|
bias = demucs.decoder[0][2].bias |
|
weight = demucs.decoder[0][2].weight |
|
chin, chout, kernel = weight.shape |
|
self._bias = bias.view(-1, 1).repeat(1, kernel).view(-1, 1) |
|
self._weight = weight.permute(1, 2, 0).contiguous() |
|
|
|
def reset_time_per_frame(self): |
|
self.total_time = 0 |
|
self.frames = 0 |
|
|
|
@property |
|
def time_per_frame(self): |
|
return self.total_time / self.frames |
|
|
|
def flush(self): |
|
""" |
|
        Flush remaining audio by padding it with zeros. Call this
|
when you have no more input and want to get back the last chunk of audio. |
|
""" |
|
pending_length = self.pending.shape[1] |
|
padding = torch.zeros(self.demucs.chin, self.total_length, device=self.pending.device) |
|
out = self.feed(padding) |
|
return out[:, :pending_length] |
|
|
|
def feed(self, wav): |
|
""" |
|
        Apply the model to the given mix using true streaming evaluation.
        Normalization and resampling are both done online.
|
""" |
|
begin = time.time() |
|
demucs = self.demucs |
|
resample_buffer = self.resample_buffer |
|
stride = self.stride |
|
resample = demucs.resample |
|
|
|
if wav.dim() != 2: |
|
raise ValueError("input wav should be two dimensional.") |
|
chin, _ = wav.shape |
|
if chin != demucs.chin: |
|
raise ValueError(f"Expected {demucs.chin} channels, got {chin}") |
|
|
|
self.pending = torch.cat([self.pending, wav], dim=1) |
|
outs = [] |
|
while self.pending.shape[1] >= self.total_length: |
|
self.frames += 1 |
|
frame = self.pending[:, :self.total_length] |
|
dry_signal = frame[:, :stride] |
|
if demucs.normalize: |
|
mono = frame.mean(0) |
|
variance = (mono**2).mean() |
|
self.variance = variance / self.frames + (1 - 1 / self.frames) * self.variance |
|
frame = frame / (demucs.floor + math.sqrt(self.variance)) |
|
frame = torch.cat([self.resample_in, frame], dim=-1) |
|
self.resample_in[:] = frame[:, stride - resample_buffer:stride] |
|
|
|
if resample == 4: |
|
frame = upsample2(upsample2(frame)) |
|
elif resample == 2: |
|
frame = upsample2(frame) |
|
frame = frame[:, resample * resample_buffer:] |
|
frame = frame[:, :resample * self.frame_length] |
|
|
|
out, extra = self._separate_frame(frame) |
|
padded_out = torch.cat([self.resample_out, out, extra], 1) |
|
self.resample_out[:] = out[:, -resample_buffer:] |
|
if resample == 4: |
|
out = downsample2(downsample2(padded_out)) |
|
elif resample == 2: |
|
out = downsample2(padded_out) |
|
else: |
|
out = padded_out |
|
|
|
out = out[:, resample_buffer // resample:] |
|
out = out[:, :stride] |
|
|
|
if demucs.normalize: |
|
out *= math.sqrt(self.variance) |
|
out = self.dry * dry_signal + (1 - self.dry) * out |
|
outs.append(out) |
|
self.pending = self.pending[:, stride:] |
|
|
|
self.total_time += time.time() - begin |
|
if outs: |
|
out = torch.cat(outs, 1) |
|
else: |
|
out = torch.zeros(chin, 0, device=wav.device) |
|
return out |
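
    # Streaming usage sketch (not part of the original file; assumes chin == 1):
    #     streamer = DemucsStreamer(model)
    #     chunks_out = [streamer.feed(chunk) for chunk in chunks]  # chunk: [1, T]
    #     chunks_out.append(streamer.flush())  # drain what is still pending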
|
|
|
def _separate_frame(self, frame): |
|
demucs = self.demucs |
|
skips = [] |
|
next_state = [] |
|
first = self.conv_state is None |
|
stride = self.stride * demucs.resample |
|
x = frame[None] |
|
for idx, encode in enumerate(demucs.encoder): |
|
stride //= demucs.stride |
|
length = x.shape[2] |
|
if idx == demucs.depth - 1: |
                # This is slightly faster for the last conv.
x = fast_conv(encode[0], x) |
|
x = encode[1](x) |
|
x = fast_conv(encode[2], x) |
|
x = encode[3](x) |
|
else: |
|
if not first: |
|
prev = self.conv_state.pop(0) |
|
prev = prev[..., stride:] |
|
tgt = (length - demucs.kernel_size) // demucs.stride + 1 |
|
missing = tgt - prev.shape[-1] |
|
offset = length - demucs.kernel_size - demucs.stride * (missing - 1) |
|
x = x[..., offset:] |
|
x = encode[1](encode[0](x)) |
|
x = fast_conv(encode[2], x) |
|
x = encode[3](x) |
|
if not first: |
|
x = torch.cat([prev, x], -1) |
|
next_state.append(x) |
|
skips.append(x) |
|
|
|
        # The recurrent bottleneck of the original Demucs is replaced in this
        # variant by the causal TCN separator. Note that the TCN keeps no state
        # across calls, so its left context is limited to the current frame.
        x = demucs.separator(x)
|
extra = None |
|
for idx, decode in enumerate(demucs.decoder): |
|
skip = skips.pop(-1) |
|
x += skip[..., :x.shape[-1]] |
|
x = fast_conv(decode[0], x) |
|
x = decode[1](x) |
|
|
|
if extra is not None: |
|
skip = skip[..., x.shape[-1]:] |
|
extra += skip[..., :extra.shape[-1]] |
|
extra = decode[2](decode[1](decode[0](extra))) |
|
x = decode[2](x) |
|
next_state.append(x[..., -demucs.stride:] - decode[2].bias.view(-1, 1)) |
|
if extra is None: |
|
extra = x[..., -demucs.stride:] |
|
else: |
|
extra[..., :demucs.stride] += next_state[-1] |
|
x = x[..., :-demucs.stride] |
|
|
|
if not first: |
|
prev = self.conv_state.pop(0) |
|
x[..., :demucs.stride] += prev |
|
if idx != demucs.depth - 1: |
|
x = decode[3](x) |
|
extra = decode[3](extra) |
|
self.conv_state = next_state |
|
return x[0], extra[0] |
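
    # Note: self.conv_state caches, per layer, the left context that the next
    # frame will reuse, so each call only computes the newest stride of samples.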
|
|
|
|
|
def test(): |
|
import argparse |
|
parser = argparse.ArgumentParser( |
|
"denoiser.demucs", |
|
description="Benchmark the streaming Demucs implementation, " |
|
"as well as checking the delta with the offline implementation.") |
|
parser.add_argument("--resample", default=4, type=int) |
|
parser.add_argument("--hidden", default=48, type=int) |
|
parser.add_argument("--device", default="cpu") |
|
parser.add_argument("-t", "--num_threads", type=int) |
|
parser.add_argument("-f", "--num_frames", type=int, default=1) |
|
args = parser.parse_args() |
|
if args.num_threads: |
|
torch.set_num_threads(args.num_threads) |
|
sr = 16_000 |
|
sr_ms = sr / 1000 |
|
demucs = Demucs(hidden=args.hidden, resample=args.resample).to(args.device) |
|
x = torch.randn(1, sr * 4).to(args.device) |
|
out = demucs(x[None])[0] |
|
streamer = DemucsStreamer(demucs, num_frames=args.num_frames) |
|
out_rt = [] |
|
frame_size = streamer.total_length |
|
with torch.no_grad(): |
|
while x.shape[1] > 0: |
|
out_rt.append(streamer.feed(x[:, :frame_size])) |
|
x = x[:, frame_size:] |
|
frame_size = streamer.demucs.total_stride |
|
out_rt.append(streamer.flush()) |
|
out_rt = torch.cat(out_rt, 1) |
|
print(f"total lag: {streamer.total_length / sr_ms:.1f}ms, ", end='') |
|
print(f"stride: {streamer.stride / sr_ms:.1f}ms, ", end='') |
|
print(f"time per frame: {1000 * streamer.time_per_frame:.1f}ms, ", end='') |
|
print(f"delta: {torch.norm(out - out_rt) / torch.norm(out):.2%}, ", end='') |
|
print(f"RTF: {((1000 * streamer.time_per_frame) / (streamer.stride / sr_ms)):.1f}") |
|
|
|
|
|
if __name__ == "__main__": |
|
test() |
|
|