DeepLearning101
commited on
Commit
·
109bb65
1
Parent(s):
2917403
Upload 17 files
Browse files- denoiser/__init__.py +5 -0
- denoiser/audio.py +89 -0
- denoiser/augment.py +191 -0
- denoiser/conv_demucs.py +661 -0
- denoiser/data.py +99 -0
- denoiser/demucs.py +449 -0
- denoiser/distrib.py +100 -0
- denoiser/dsp.py +64 -0
- denoiser/enhance.py +138 -0
- denoiser/evaluate.py +136 -0
- denoiser/executor.py +79 -0
- denoiser/live.py +161 -0
- denoiser/pretrained.py +72 -0
- denoiser/resample.py +75 -0
- denoiser/solver.py +233 -0
- denoiser/stft_loss.py +144 -0
- denoiser/utils.py +165 -0
denoiser/__init__.py
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
denoiser/audio.py
ADDED
@@ -0,0 +1,89 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
# author: adefossez
|
7 |
+
|
8 |
+
import json
|
9 |
+
from pathlib import Path
|
10 |
+
import math
|
11 |
+
import os
|
12 |
+
import sys
|
13 |
+
|
14 |
+
import torchaudio
|
15 |
+
from torch.nn import functional as F
|
16 |
+
|
17 |
+
|
18 |
+
def find_audio_files(path, exts=[".wav"], progress=True):
|
19 |
+
audio_files = []
|
20 |
+
for root, folders, files in os.walk(path, followlinks=True):
|
21 |
+
for file in files:
|
22 |
+
file = Path(root) / file
|
23 |
+
if file.suffix.lower() in exts:
|
24 |
+
audio_files.append(str(file.resolve()))
|
25 |
+
meta = []
|
26 |
+
for idx, file in enumerate(audio_files):
|
27 |
+
siginfo, _ = torchaudio.info(file)
|
28 |
+
length = siginfo.length // siginfo.channels
|
29 |
+
meta.append((file, length))
|
30 |
+
if progress:
|
31 |
+
print(format((1 + idx) / len(audio_files), " 3.1%"), end='\r', file=sys.stderr)
|
32 |
+
meta.sort()
|
33 |
+
return meta
|
34 |
+
|
35 |
+
|
36 |
+
class Audioset:
|
37 |
+
def __init__(self, files=None, length=None, stride=None,
|
38 |
+
pad=True, with_path=False, sample_rate=None):
|
39 |
+
"""
|
40 |
+
files should be a list [(file, length)]
|
41 |
+
"""
|
42 |
+
self.files = files
|
43 |
+
self.num_examples = []
|
44 |
+
self.length = length
|
45 |
+
self.stride = stride or length
|
46 |
+
self.with_path = with_path
|
47 |
+
self.sample_rate = sample_rate
|
48 |
+
for file, file_length in self.files:
|
49 |
+
if length is None:
|
50 |
+
examples = 1
|
51 |
+
elif file_length < length:
|
52 |
+
examples = 1 if pad else 0
|
53 |
+
elif pad:
|
54 |
+
examples = int(math.ceil((file_length - self.length) / self.stride) + 1)
|
55 |
+
else:
|
56 |
+
examples = (file_length - self.length) // self.stride + 1
|
57 |
+
self.num_examples.append(examples)
|
58 |
+
|
59 |
+
def __len__(self):
|
60 |
+
return sum(self.num_examples)
|
61 |
+
|
62 |
+
def __getitem__(self, index):
|
63 |
+
for (file, _), examples in zip(self.files, self.num_examples):
|
64 |
+
if index >= examples:
|
65 |
+
index -= examples
|
66 |
+
continue
|
67 |
+
num_frames = 0
|
68 |
+
offset = 0
|
69 |
+
if self.length is not None:
|
70 |
+
offset = self.stride * index
|
71 |
+
num_frames = self.length
|
72 |
+
out, sr = torchaudio.load(str(file), offset=offset, num_frames=num_frames)
|
73 |
+
if self.sample_rate is not None:
|
74 |
+
if sr != self.sample_rate:
|
75 |
+
raise RuntimeError(f"Expected {file} to have sample rate of "
|
76 |
+
f"{self.sample_rate}, but got {sr}")
|
77 |
+
if num_frames:
|
78 |
+
out = F.pad(out, (0, num_frames - out.shape[-1]))
|
79 |
+
if self.with_path:
|
80 |
+
return out, file
|
81 |
+
else:
|
82 |
+
return out
|
83 |
+
|
84 |
+
|
85 |
+
if __name__ == "__main__":
|
86 |
+
meta = []
|
87 |
+
for path in sys.argv[1:]:
|
88 |
+
meta += find_audio_files(path)
|
89 |
+
json.dump(meta, sys.stdout, indent=4)
|
denoiser/augment.py
ADDED
@@ -0,0 +1,191 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
# author: adefossez
|
7 |
+
|
8 |
+
import random
|
9 |
+
import torch as th
|
10 |
+
from torch import nn
|
11 |
+
from torch.nn import functional as F
|
12 |
+
|
13 |
+
from . import dsp
|
14 |
+
|
15 |
+
|
16 |
+
class Remix(nn.Module):
|
17 |
+
"""Remix.
|
18 |
+
Mixes different noises with clean speech within a given batch
|
19 |
+
"""
|
20 |
+
|
21 |
+
def forward(self, sources):
|
22 |
+
noise, clean = sources
|
23 |
+
bs, *other = noise.shape
|
24 |
+
device = noise.device
|
25 |
+
perm = th.argsort(th.rand(bs, device=device), dim=0)
|
26 |
+
return th.stack([noise[perm], clean])
|
27 |
+
|
28 |
+
|
29 |
+
class RevEcho(nn.Module):
|
30 |
+
"""
|
31 |
+
Hacky Reverb but runs on GPU without slowing down training.
|
32 |
+
This reverb adds a succession of attenuated echos of the input
|
33 |
+
signal to itself. Intuitively, the delay of the first echo will happen
|
34 |
+
after roughly 2x the radius of the room and is controlled by `first_delay`.
|
35 |
+
Then RevEcho keeps adding echos with the same delay and further attenuation
|
36 |
+
until the amplitude ratio between the last and first echo is 1e-3.
|
37 |
+
The attenuation factor and the number of echos to adds is controlled
|
38 |
+
by RT60 (measured in seconds). RT60 is the average time to get to -60dB
|
39 |
+
(remember volume is measured over the squared amplitude so this matches
|
40 |
+
the 1e-3 ratio).
|
41 |
+
|
42 |
+
At each call to RevEcho, `first_delay`, `initial` and `RT60` are
|
43 |
+
sampled from their range. Then, to prevent this reverb from being too regular,
|
44 |
+
the delay time is resampled uniformly within `first_delay +- 10%`,
|
45 |
+
as controlled by the `jitter` parameter. Finally, for a denser reverb,
|
46 |
+
multiple trains of echos are added with different jitter noises.
|
47 |
+
|
48 |
+
Args:
|
49 |
+
- initial: amplitude of the first echo as a fraction
|
50 |
+
of the input signal. For each sample, actually sampled from
|
51 |
+
`[0, initial]`. Larger values means louder reverb. Physically,
|
52 |
+
this would depend on the absorption of the room walls.
|
53 |
+
- rt60: range of values to sample the RT60 in seconds, i.e.
|
54 |
+
after RT60 seconds, the echo amplitude is 1e-3 of the first echo.
|
55 |
+
The default values follow the recommendations of
|
56 |
+
https://arxiv.org/ftp/arxiv/papers/2001/2001.08662.pdf, Section 2.4.
|
57 |
+
Physically this would also be related to the absorption of the
|
58 |
+
room walls and there is likely a relation between `RT60` and
|
59 |
+
`initial`, which we ignore here.
|
60 |
+
- first_delay: range of values to sample the first echo delay in seconds.
|
61 |
+
The default values are equivalent to sampling a room of 3 to 10 meters.
|
62 |
+
- repeat: how many train of echos with differents jitters to add.
|
63 |
+
Higher values means a denser reverb.
|
64 |
+
- jitter: jitter used to make each repetition of the reverb echo train
|
65 |
+
slightly different. For instance a jitter of 0.1 means
|
66 |
+
the delay between two echos will be in the range `first_delay +- 10%`,
|
67 |
+
with the jittering noise being resampled after each single echo.
|
68 |
+
- keep_clean: fraction of the reverb of the clean speech to add back
|
69 |
+
to the ground truth. 0 = dereverberation, 1 = no dereverberation.
|
70 |
+
- sample_rate: sample rate of the input signals.
|
71 |
+
"""
|
72 |
+
|
73 |
+
def __init__(self, proba=0.5, initial=0.3, rt60=(0.3, 1.3), first_delay=(0.01, 0.03),
|
74 |
+
repeat=3, jitter=0.1, keep_clean=0.1, sample_rate=16000):
|
75 |
+
super().__init__()
|
76 |
+
self.proba = proba
|
77 |
+
self.initial = initial
|
78 |
+
self.rt60 = rt60
|
79 |
+
self.first_delay = first_delay
|
80 |
+
self.repeat = repeat
|
81 |
+
self.jitter = jitter
|
82 |
+
self.keep_clean = keep_clean
|
83 |
+
self.sample_rate = sample_rate
|
84 |
+
|
85 |
+
def _reverb(self, source, initial, first_delay, rt60):
|
86 |
+
"""
|
87 |
+
Return the reverb for a single source.
|
88 |
+
"""
|
89 |
+
length = source.shape[-1]
|
90 |
+
reverb = th.zeros_like(source)
|
91 |
+
for _ in range(self.repeat):
|
92 |
+
frac = 1 # what fraction of the first echo amplitude is still here
|
93 |
+
echo = initial * source
|
94 |
+
while frac > 1e-3:
|
95 |
+
# First jitter noise for the delay
|
96 |
+
jitter = 1 + self.jitter * random.uniform(-1, 1)
|
97 |
+
delay = min(
|
98 |
+
1 + int(jitter * first_delay * self.sample_rate),
|
99 |
+
length)
|
100 |
+
# Delay the echo in time by padding with zero on the left
|
101 |
+
echo = F.pad(echo[:, :, :-delay], (delay, 0))
|
102 |
+
reverb += echo
|
103 |
+
|
104 |
+
# Second jitter noise for the attenuation
|
105 |
+
jitter = 1 + self.jitter * random.uniform(-1, 1)
|
106 |
+
# we want, with `d` the attenuation, d**(rt60 / first_ms) = 1e-3
|
107 |
+
# i.e. log10(d) = -3 * first_ms / rt60, so that
|
108 |
+
attenuation = 10**(-3 * jitter * first_delay / rt60)
|
109 |
+
echo *= attenuation
|
110 |
+
frac *= attenuation
|
111 |
+
return reverb
|
112 |
+
|
113 |
+
def forward(self, wav):
|
114 |
+
if random.random() >= self.proba:
|
115 |
+
return wav
|
116 |
+
noise, clean = wav
|
117 |
+
# Sample characteristics for the reverb
|
118 |
+
initial = random.random() * self.initial
|
119 |
+
first_delay = random.uniform(*self.first_delay)
|
120 |
+
rt60 = random.uniform(*self.rt60)
|
121 |
+
|
122 |
+
reverb_noise = self._reverb(noise, initial, first_delay, rt60)
|
123 |
+
# Reverb for the noise is always added back to the noise
|
124 |
+
noise += reverb_noise
|
125 |
+
reverb_clean = self._reverb(clean, initial, first_delay, rt60)
|
126 |
+
# Split clean reverb among the clean speech and noise
|
127 |
+
clean += self.keep_clean * reverb_clean
|
128 |
+
noise += (1 - self.keep_clean) * reverb_clean
|
129 |
+
|
130 |
+
return th.stack([noise, clean])
|
131 |
+
|
132 |
+
|
133 |
+
class BandMask(nn.Module):
|
134 |
+
"""BandMask.
|
135 |
+
Maskes bands of frequencies. Similar to Park, Daniel S., et al.
|
136 |
+
"Specaugment: A simple data augmentation method for automatic speech recognition."
|
137 |
+
(https://arxiv.org/pdf/1904.08779.pdf) but over the waveform.
|
138 |
+
"""
|
139 |
+
|
140 |
+
def __init__(self, maxwidth=0.2, bands=120, sample_rate=16_000):
|
141 |
+
"""__init__.
|
142 |
+
|
143 |
+
:param maxwidth: the maximum width to remove
|
144 |
+
:param bands: number of bands
|
145 |
+
:param sample_rate: signal sample rate
|
146 |
+
"""
|
147 |
+
super().__init__()
|
148 |
+
self.maxwidth = maxwidth
|
149 |
+
self.bands = bands
|
150 |
+
self.sample_rate = sample_rate
|
151 |
+
|
152 |
+
def forward(self, wav):
|
153 |
+
bands = self.bands
|
154 |
+
bandwidth = int(abs(self.maxwidth) * bands)
|
155 |
+
mels = dsp.mel_frequencies(bands, 40, self.sample_rate/2) / self.sample_rate
|
156 |
+
low = random.randrange(bands)
|
157 |
+
high = random.randrange(low, min(bands, low + bandwidth))
|
158 |
+
filters = dsp.LowPassFilters([mels[low], mels[high]]).to(wav.device)
|
159 |
+
low, midlow = filters(wav)
|
160 |
+
# band pass filtering
|
161 |
+
out = wav - midlow + low
|
162 |
+
return out
|
163 |
+
|
164 |
+
|
165 |
+
class Shift(nn.Module):
|
166 |
+
"""Shift."""
|
167 |
+
|
168 |
+
def __init__(self, shift=8192, same=False):
|
169 |
+
"""__init__.
|
170 |
+
|
171 |
+
:param shift: randomly shifts the signals up to a given factor
|
172 |
+
:param same: shifts both clean and noisy files by the same factor
|
173 |
+
"""
|
174 |
+
super().__init__()
|
175 |
+
self.shift = shift
|
176 |
+
self.same = same
|
177 |
+
|
178 |
+
def forward(self, wav):
|
179 |
+
sources, batch, channels, length = wav.shape
|
180 |
+
length = length - self.shift
|
181 |
+
if self.shift > 0:
|
182 |
+
if not self.training:
|
183 |
+
wav = wav[..., :length]
|
184 |
+
else:
|
185 |
+
offsets = th.randint(
|
186 |
+
self.shift,
|
187 |
+
[1 if self.same else sources, batch, 1, 1], device=wav.device)
|
188 |
+
offsets = offsets.expand(sources, -1, channels, -1)
|
189 |
+
indexes = th.arange(length, device=wav.device)
|
190 |
+
wav = wav.gather(3, indexes + offsets)
|
191 |
+
return wav
|
denoiser/conv_demucs.py
ADDED
@@ -0,0 +1,661 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
# author: adefossez
|
7 |
+
|
8 |
+
import math
|
9 |
+
import time
|
10 |
+
|
11 |
+
import torch
|
12 |
+
from torch import nn
|
13 |
+
from torch.nn import functional as F
|
14 |
+
|
15 |
+
from .resample import downsample2, upsample2
|
16 |
+
from .utils import capture_init
|
17 |
+
|
18 |
+
|
19 |
+
# class BLSTM(nn.Module):
|
20 |
+
# def __init__(self, dim, layers=2, bi=True):
|
21 |
+
# super().__init__()
|
22 |
+
# klass = nn.LSTM
|
23 |
+
# self.lstm = klass(bidirectional=bi, num_layers=layers, hidden_size=dim, input_size=dim)
|
24 |
+
# self.linear = None
|
25 |
+
# if bi:
|
26 |
+
# self.linear = nn.Linear(2 * dim, dim)
|
27 |
+
|
28 |
+
# def forward(self, x, hidden=None):
|
29 |
+
# x, hidden = self.lstm(x, hidden)
|
30 |
+
# if self.linear:
|
31 |
+
# x = self.linear(x)
|
32 |
+
# return x, hidden
|
33 |
+
|
34 |
+
EPS = 1e-8
|
35 |
+
class Chomp1d(nn.Module):
|
36 |
+
"""To ensure the output length is the same as the input.
|
37 |
+
"""
|
38 |
+
def __init__(self, chomp_size):
|
39 |
+
super(Chomp1d, self).__init__()
|
40 |
+
self.chomp_size = chomp_size
|
41 |
+
|
42 |
+
def forward(self, x):
|
43 |
+
"""
|
44 |
+
Args:
|
45 |
+
x: [M, H, Kpad]
|
46 |
+
Returns:
|
47 |
+
[M, H, K]
|
48 |
+
"""
|
49 |
+
return x[:, :, :-self.chomp_size].contiguous()
|
50 |
+
|
51 |
+
def chose_norm(norm_type, channel_size):
|
52 |
+
"""The input of normlization will be (M, C, K), where M is batch size,
|
53 |
+
C is channel size and K is sequence length.
|
54 |
+
"""
|
55 |
+
if norm_type == "gLN":
|
56 |
+
return GlobalLayerNorm(channel_size)
|
57 |
+
elif norm_type == "cLN":
|
58 |
+
return ChannelwiseLayerNorm(channel_size)
|
59 |
+
else: # norm_type == "BN":
|
60 |
+
# Given input (M, C, K), nn.BatchNorm1d(C) will accumulate statics
|
61 |
+
# along M and K, so this BN usage is right.
|
62 |
+
return nn.BatchNorm1d(channel_size)
|
63 |
+
|
64 |
+
class ChannelwiseLayerNorm(nn.Module):
|
65 |
+
"""Channel-wise Layer Normalization (cLN)"""
|
66 |
+
def __init__(self, channel_size):
|
67 |
+
super(ChannelwiseLayerNorm, self).__init__()
|
68 |
+
self.gamma = nn.Parameter(torch.Tensor(1, channel_size, 1)) # [1, N, 1]
|
69 |
+
self.beta = nn.Parameter(torch.Tensor(1, channel_size,1 )) # [1, N, 1]
|
70 |
+
self.reset_parameters()
|
71 |
+
|
72 |
+
def reset_parameters(self):
|
73 |
+
self.gamma.data.fill_(1)
|
74 |
+
self.beta.data.zero_()
|
75 |
+
|
76 |
+
def forward(self, y):
|
77 |
+
"""
|
78 |
+
Args:
|
79 |
+
y: [M, N, K], M is batch size, N is channel size, K is length
|
80 |
+
Returns:
|
81 |
+
cLN_y: [M, N, K]
|
82 |
+
"""
|
83 |
+
mean = torch.mean(y, dim=1, keepdim=True) # [M, 1, K]
|
84 |
+
var = torch.var(y, dim=1, keepdim=True, unbiased=False) # [M, 1, K]
|
85 |
+
cLN_y = self.gamma * (y - mean) / torch.pow(var + EPS, 0.5) + self.beta
|
86 |
+
return cLN_y
|
87 |
+
|
88 |
+
class DepthwiseSeparableConv(nn.Module):
|
89 |
+
def __init__(self, in_channels, out_channels, kernel_size,
|
90 |
+
stride, padding, dilation, norm_type="gLN", causal=False):
|
91 |
+
super(DepthwiseSeparableConv, self).__init__()
|
92 |
+
# Use `groups` option to implement depthwise convolution
|
93 |
+
# [M, H, K] -> [M, H, K]
|
94 |
+
depthwise_conv = nn.Conv1d(in_channels, in_channels, kernel_size,
|
95 |
+
stride=stride, padding=padding,
|
96 |
+
dilation=dilation, groups=in_channels,
|
97 |
+
bias=False)
|
98 |
+
if causal:
|
99 |
+
chomp = Chomp1d(padding)
|
100 |
+
prelu = nn.PReLU()
|
101 |
+
norm = chose_norm(norm_type, in_channels)
|
102 |
+
# [M, H, K] -> [M, B, K]
|
103 |
+
pointwise_conv = nn.Conv1d(in_channels, out_channels, 1, bias=False)
|
104 |
+
# Put together
|
105 |
+
if causal:
|
106 |
+
self.net = nn.Sequential(depthwise_conv, chomp, prelu, norm,
|
107 |
+
pointwise_conv)
|
108 |
+
else:
|
109 |
+
self.net = nn.Sequential(depthwise_conv, prelu, norm,
|
110 |
+
pointwise_conv)
|
111 |
+
|
112 |
+
def forward(self, x):
|
113 |
+
"""
|
114 |
+
Args:
|
115 |
+
x: [M, H, K]
|
116 |
+
Returns:
|
117 |
+
result: [M, B, K]
|
118 |
+
"""
|
119 |
+
return self.net(x)
|
120 |
+
|
121 |
+
class TemporalBlock(nn.Module):
|
122 |
+
def __init__(self, in_channels, out_channels, kernel_size,
|
123 |
+
stride, padding, dilation, norm_type="gLN", causal=False):
|
124 |
+
super(TemporalBlock, self).__init__()
|
125 |
+
# [M, B, K] -> [M, H, K]
|
126 |
+
conv1x1 = nn.Conv1d(in_channels, out_channels, 1, bias=False)
|
127 |
+
prelu = nn.PReLU()
|
128 |
+
norm = chose_norm(norm_type, out_channels)
|
129 |
+
# [M, H, K] -> [M, B, K]
|
130 |
+
dsconv = DepthwiseSeparableConv(out_channels, in_channels, kernel_size,
|
131 |
+
stride, padding, dilation, norm_type,
|
132 |
+
causal)
|
133 |
+
# Put together
|
134 |
+
self.net = nn.Sequential(conv1x1, prelu, norm, dsconv)
|
135 |
+
|
136 |
+
def forward(self, x):
|
137 |
+
"""
|
138 |
+
Args:
|
139 |
+
x: [M, B, K]
|
140 |
+
Returns:
|
141 |
+
[M, B, K]
|
142 |
+
"""
|
143 |
+
residual = x
|
144 |
+
out = self.net(x)
|
145 |
+
# TODO: when P = 3 here works fine, but when P = 2 maybe need to pad?
|
146 |
+
return out + residual # look like w/o F.relu is better than w/ F.relu
|
147 |
+
# return F.relu(out + residual)
|
148 |
+
|
149 |
+
class GlobalLayerNorm(nn.Module):
|
150 |
+
"""Global Layer Normalization (gLN)"""
|
151 |
+
def __init__(self, channel_size):
|
152 |
+
super(GlobalLayerNorm, self).__init__()
|
153 |
+
self.gamma = nn.Parameter(torch.Tensor(1, channel_size, 1)) # [1, N, 1]
|
154 |
+
self.beta = nn.Parameter(torch.Tensor(1, channel_size,1 )) # [1, N, 1]
|
155 |
+
self.reset_parameters()
|
156 |
+
|
157 |
+
def reset_parameters(self):
|
158 |
+
self.gamma.data.fill_(1)
|
159 |
+
self.beta.data.zero_()
|
160 |
+
|
161 |
+
def forward(self, y):
|
162 |
+
"""
|
163 |
+
Args:
|
164 |
+
y: [M, N, K], M is batch size, N is channel size, K is length
|
165 |
+
Returns:
|
166 |
+
gLN_y: [M, N, K]
|
167 |
+
"""
|
168 |
+
# TODO: in torch 1.0, torch.mean() support dim list
|
169 |
+
mean = y.mean(dim=1, keepdim=True).mean(dim=2, keepdim=True) #[M, 1, 1]
|
170 |
+
var = (torch.pow(y-mean, 2)).mean(dim=1, keepdim=True).mean(dim=2, keepdim=True)
|
171 |
+
gLN_y = self.gamma * (y - mean) / torch.pow(var + EPS, 0.5) + self.beta
|
172 |
+
return gLN_y
|
173 |
+
|
174 |
+
class TemporalConvNet(nn.Module):
|
175 |
+
def __init__(self, N=768, B=256, H=512, P=3, X=8, R=4, C=1, norm_type="gLN", causal=1,
|
176 |
+
mask_nonlinear='relu'):
|
177 |
+
"""
|
178 |
+
Args:
|
179 |
+
N: Number of filters in autoencoder
|
180 |
+
B: Number of channels in bottleneck 1 × 1-conv block
|
181 |
+
H: Number of channels in convolutional blocks
|
182 |
+
P: Kernel size in convolutional blocks
|
183 |
+
X: Number of convolutional blocks in each repeat
|
184 |
+
R: Number of repeats
|
185 |
+
C: Number of speakers
|
186 |
+
norm_type: BN, gLN, cLN
|
187 |
+
causal: causal or non-causal
|
188 |
+
mask_nonlinear: use which non-linear function to generate mask
|
189 |
+
"""
|
190 |
+
super(TemporalConvNet, self).__init__()
|
191 |
+
# Hyper-parameter
|
192 |
+
self.C = C
|
193 |
+
self.mask_nonlinear = mask_nonlinear
|
194 |
+
# Components
|
195 |
+
# [M, N, K] -> [M, N, K]
|
196 |
+
layer_norm = ChannelwiseLayerNorm(N)
|
197 |
+
# [M, N, K] -> [M, B, K]
|
198 |
+
bottleneck_conv1x1 = nn.Conv1d(N, B, 1, bias=False)
|
199 |
+
# [M, B, K] -> [M, B, K]
|
200 |
+
repeats = []
|
201 |
+
for r in range(R):
|
202 |
+
blocks = []
|
203 |
+
for x in range(X):
|
204 |
+
dilation = 2**x
|
205 |
+
padding = (P - 1) * dilation if causal else (P - 1) * dilation // 2
|
206 |
+
blocks += [TemporalBlock(B, H, P, stride=1,
|
207 |
+
padding=padding,
|
208 |
+
dilation=dilation,
|
209 |
+
norm_type=norm_type,
|
210 |
+
causal=causal)]
|
211 |
+
repeats += [nn.Sequential(*blocks)]
|
212 |
+
temporal_conv_net = nn.Sequential(*repeats)
|
213 |
+
# [M, B, K] -> [M, C*N, K]
|
214 |
+
mask_conv1x1 = nn.Conv1d(B, C*N, 1, bias=False)
|
215 |
+
# Put together
|
216 |
+
self.network = nn.Sequential(layer_norm,
|
217 |
+
bottleneck_conv1x1,
|
218 |
+
temporal_conv_net,
|
219 |
+
mask_conv1x1)
|
220 |
+
|
221 |
+
def forward(self, mixture_w):
|
222 |
+
"""
|
223 |
+
Keep this API same with TasNet
|
224 |
+
Args:
|
225 |
+
mixture_w: [M, N, K], M is batch size
|
226 |
+
returns:
|
227 |
+
est_mask: [M, C, N, K]
|
228 |
+
"""
|
229 |
+
M, N, K = mixture_w.size()
|
230 |
+
score = self.network(mixture_w) # [M, N, K] -> [M, C*N, K]
|
231 |
+
score = score.view(M, self.C, N, K) # [M, C*N, K] -> [M, C, N, K]
|
232 |
+
if self.mask_nonlinear == 'softmax':
|
233 |
+
est_mask = F.softmax(score, dim=1)
|
234 |
+
est_mask = est_mask.squeeze(1)
|
235 |
+
elif self.mask_nonlinear == 'relu':
|
236 |
+
est_mask = F.relu(score)
|
237 |
+
est_mask = est_mask.squeeze(1)
|
238 |
+
else:
|
239 |
+
raise ValueError("Unsupported mask non-linear function")
|
240 |
+
return est_mask
|
241 |
+
|
242 |
+
|
243 |
+
|
244 |
+
def rescale_conv(conv, reference):
|
245 |
+
std = conv.weight.std().detach()
|
246 |
+
scale = (std / reference)**0.5
|
247 |
+
conv.weight.data /= scale
|
248 |
+
if conv.bias is not None:
|
249 |
+
conv.bias.data /= scale
|
250 |
+
|
251 |
+
|
252 |
+
def rescale_module(module, reference):
|
253 |
+
for sub in module.modules():
|
254 |
+
if isinstance(sub, (nn.Conv1d, nn.ConvTranspose1d)):
|
255 |
+
rescale_conv(sub, reference)
|
256 |
+
|
257 |
+
|
258 |
+
class Demucs(nn.Module):
|
259 |
+
"""
|
260 |
+
Demucs speech enhancement model.
|
261 |
+
Args:
|
262 |
+
- chin (int): number of input channels.
|
263 |
+
- chout (int): number of output channels.
|
264 |
+
- hidden (int): number of initial hidden channels.
|
265 |
+
- depth (int): number of layers.
|
266 |
+
- kernel_size (int): kernel size for each layer.
|
267 |
+
- stride (int): stride for each layer.
|
268 |
+
- causal (bool): if false, uses BiLSTM instead of LSTM.
|
269 |
+
- resample (int): amount of resampling to apply to the input/output.
|
270 |
+
Can be one of 1, 2 or 4.
|
271 |
+
- growth (float): number of channels is multiplied by this for every layer.
|
272 |
+
- max_hidden (int): maximum number of channels. Can be useful to
|
273 |
+
control the size/speed of the model.
|
274 |
+
- normalize (bool): if true, normalize the input.
|
275 |
+
- glu (bool): if true uses GLU instead of ReLU in 1x1 convolutions.
|
276 |
+
- rescale (float): controls custom weight initialization.
|
277 |
+
See https://arxiv.org/abs/1911.13254.
|
278 |
+
- floor (float): stability flooring when normalizing.
|
279 |
+
|
280 |
+
"""
|
281 |
+
@capture_init
|
282 |
+
def __init__(self,
|
283 |
+
chin=1,
|
284 |
+
chout=1,
|
285 |
+
hidden=48,
|
286 |
+
depth=5,
|
287 |
+
kernel_size=8,
|
288 |
+
stride=4,
|
289 |
+
causal=True,
|
290 |
+
resample=4,
|
291 |
+
growth=2,
|
292 |
+
max_hidden=10_000,
|
293 |
+
normalize=True,
|
294 |
+
glu=True,
|
295 |
+
rescale=0.1,
|
296 |
+
floor=1e-3):
|
297 |
+
|
298 |
+
super().__init__()
|
299 |
+
if resample not in [1, 2, 4]:
|
300 |
+
raise ValueError("Resample should be 1, 2 or 4.")
|
301 |
+
|
302 |
+
self.chin = chin
|
303 |
+
self.chout = chout
|
304 |
+
self.hidden = hidden
|
305 |
+
self.depth = depth
|
306 |
+
self.kernel_size = kernel_size
|
307 |
+
self.stride = stride
|
308 |
+
self.causal = causal
|
309 |
+
self.floor = floor
|
310 |
+
self.resample = resample
|
311 |
+
self.normalize = normalize
|
312 |
+
|
313 |
+
self.encoder = nn.ModuleList()
|
314 |
+
self.decoder = nn.ModuleList()
|
315 |
+
activation = nn.GLU(1) if glu else nn.ReLU()
|
316 |
+
ch_scale = 2 if glu else 1
|
317 |
+
|
318 |
+
for index in range(depth):
|
319 |
+
encode = []
|
320 |
+
encode += [
|
321 |
+
nn.Conv1d(chin, hidden, kernel_size, stride),
|
322 |
+
nn.ReLU(),
|
323 |
+
nn.Conv1d(hidden, hidden * ch_scale, 1), activation,
|
324 |
+
]
|
325 |
+
self.encoder.append(nn.Sequential(*encode))
|
326 |
+
|
327 |
+
decode = []
|
328 |
+
decode += [
|
329 |
+
nn.Conv1d(hidden, ch_scale * hidden, 1), activation,
|
330 |
+
nn.ConvTranspose1d(hidden, chout, kernel_size, stride),
|
331 |
+
]
|
332 |
+
if index > 0:
|
333 |
+
decode.append(nn.ReLU())
|
334 |
+
self.decoder.insert(0, nn.Sequential(*decode))
|
335 |
+
chout = hidden
|
336 |
+
chin = hidden
|
337 |
+
hidden = min(int(growth * hidden), max_hidden)
|
338 |
+
# import pdb; pdb.set_trace()
|
339 |
+
self.separator = TemporalConvNet(N=chout)
|
340 |
+
# self.lstm = BLSTM(chin, bi=not causal)
|
341 |
+
if rescale:
|
342 |
+
rescale_module(self, reference=rescale)
|
343 |
+
|
344 |
+
def valid_length(self, length):
|
345 |
+
"""
|
346 |
+
Return the nearest valid length to use with the model so that
|
347 |
+
there is no time steps left over in a convolutions, e.g. for all
|
348 |
+
layers, size of the input - kernel_size % stride = 0.
|
349 |
+
|
350 |
+
If the mixture has a valid length, the estimated sources
|
351 |
+
will have exactly the same length.
|
352 |
+
"""
|
353 |
+
length = math.ceil(length * self.resample)
|
354 |
+
for idx in range(self.depth):
|
355 |
+
length = math.ceil((length - self.kernel_size) / self.stride) + 1
|
356 |
+
length = max(length, 1)
|
357 |
+
for idx in range(self.depth):
|
358 |
+
length = (length - 1) * self.stride + self.kernel_size
|
359 |
+
length = int(math.ceil(length / self.resample))
|
360 |
+
return int(length)
|
361 |
+
|
362 |
+
@property
|
363 |
+
def total_stride(self):
|
364 |
+
return self.stride ** self.depth // self.resample
|
365 |
+
|
366 |
+
def forward(self, mix):
|
367 |
+
if mix.dim() == 2:
|
368 |
+
mix = mix.unsqueeze(1)
|
369 |
+
|
370 |
+
if self.normalize:
|
371 |
+
mono = mix.mean(dim=1, keepdim=True)
|
372 |
+
std = mono.std(dim=-1, keepdim=True)
|
373 |
+
mix = mix / (self.floor + std)
|
374 |
+
else:
|
375 |
+
std = 1
|
376 |
+
length = mix.shape[-1]
|
377 |
+
x = mix
|
378 |
+
x = F.pad(x, (0, self.valid_length(length) - length))
|
379 |
+
if self.resample == 2:
|
380 |
+
x = upsample2(x)
|
381 |
+
elif self.resample == 4:
|
382 |
+
x = upsample2(x)
|
383 |
+
x = upsample2(x)
|
384 |
+
skips = []
|
385 |
+
for encode in self.encoder:
|
386 |
+
x = encode(x)
|
387 |
+
skips.append(x)
|
388 |
+
x = self.separator(x)
|
389 |
+
# x = x.permute(2, 0, 1)
|
390 |
+
# x, _ = self.lstm(x)
|
391 |
+
# x = x.permute(1, 2, 0)
|
392 |
+
# import pdb; pdb.set_trace()
|
393 |
+
for decode in self.decoder:
|
394 |
+
skip = skips.pop(-1)
|
395 |
+
x = x + skip[..., :x.shape[-1]]
|
396 |
+
x = decode(x)
|
397 |
+
if self.resample == 2:
|
398 |
+
x = downsample2(x)
|
399 |
+
elif self.resample == 4:
|
400 |
+
x = downsample2(x)
|
401 |
+
x = downsample2(x)
|
402 |
+
|
403 |
+
x = x[..., :length]
|
404 |
+
return std * x
|
405 |
+
|
406 |
+
|
407 |
+
def fast_conv(conv, x):
|
408 |
+
"""
|
409 |
+
Faster convolution evaluation if either kernel size is 1
|
410 |
+
or length of sequence is 1.
|
411 |
+
"""
|
412 |
+
batch, chin, length = x.shape
|
413 |
+
chout, chin, kernel = conv.weight.shape
|
414 |
+
assert batch == 1
|
415 |
+
if kernel == 1:
|
416 |
+
x = x.view(chin, length)
|
417 |
+
out = th.addmm(conv.bias.view(-1, 1),
|
418 |
+
conv.weight.view(chout, chin), x)
|
419 |
+
elif length == kernel:
|
420 |
+
x = x.view(chin * kernel, 1)
|
421 |
+
out = th.addmm(conv.bias.view(-1, 1),
|
422 |
+
conv.weight.view(chout, chin * kernel), x)
|
423 |
+
else:
|
424 |
+
out = conv(x)
|
425 |
+
return out.view(batch, chout, -1)
|
426 |
+
|
427 |
+
|
428 |
+
class DemucsStreamer:
|
429 |
+
"""
|
430 |
+
Streaming implementation for Demucs. It supports being fed with any amount
|
431 |
+
of audio at a time. You will get back as much audio as possible at that
|
432 |
+
point.
|
433 |
+
|
434 |
+
Args:
|
435 |
+
- demucs (Demucs): Demucs model.
|
436 |
+
- dry (float): amount of dry (e.g. input) signal to keep. 0 is maximum
|
437 |
+
noise removal, 1 just returns the input signal. Small values > 0
|
438 |
+
allows to limit distortions.
|
439 |
+
- num_frames (int): number of frames to process at once. Higher values
|
440 |
+
will increase overall latency but improve the real time factor.
|
441 |
+
- resample_lookahead (int): extra lookahead used for the resampling.
|
442 |
+
- resample_buffer (int): size of the buffer of previous inputs/outputs
|
443 |
+
kept for resampling.
|
444 |
+
"""
|
445 |
+
def __init__(self, demucs,
|
446 |
+
dry=0,
|
447 |
+
num_frames=1,
|
448 |
+
resample_lookahead=64,
|
449 |
+
resample_buffer=256):
|
450 |
+
device = next(iter(demucs.parameters())).device
|
451 |
+
self.demucs = demucs
|
452 |
+
self.lstm_state = None
|
453 |
+
self.conv_state = None
|
454 |
+
self.dry = dry
|
455 |
+
self.resample_lookahead = resample_lookahead
|
456 |
+
self.resample_buffer = resample_buffer
|
457 |
+
self.frame_length = demucs.valid_length(1) + demucs.total_stride * (num_frames - 1)
|
458 |
+
self.total_length = self.frame_length + self.resample_lookahead
|
459 |
+
self.stride = demucs.total_stride * num_frames
|
460 |
+
self.resample_in = torch.zeros(demucs.chin, resample_buffer, device=device)
|
461 |
+
self.resample_out = torch.zeros(demucs.chin, resample_buffer, device=device)
|
462 |
+
|
463 |
+
self.frames = 0
|
464 |
+
self.total_time = 0
|
465 |
+
self.variance = 0
|
466 |
+
self.pending = torch.zeros(demucs.chin, 0, device=device)
|
467 |
+
|
468 |
+
bias = demucs.decoder[0][2].bias
|
469 |
+
weight = demucs.decoder[0][2].weight
|
470 |
+
chin, chout, kernel = weight.shape
|
471 |
+
self._bias = bias.view(-1, 1).repeat(1, kernel).view(-1, 1)
|
472 |
+
self._weight = weight.permute(1, 2, 0).contiguous()
|
473 |
+
|
474 |
+
def reset_time_per_frame(self):
|
475 |
+
self.total_time = 0
|
476 |
+
self.frames = 0
|
477 |
+
|
478 |
+
@property
|
479 |
+
def time_per_frame(self):
|
480 |
+
return self.total_time / self.frames
|
481 |
+
|
482 |
+
def flush(self):
|
483 |
+
"""
|
484 |
+
Flush remaining audio by padding it with zero. Call this
|
485 |
+
when you have no more input and want to get back the last chunk of audio.
|
486 |
+
"""
|
487 |
+
pending_length = self.pending.shape[1]
|
488 |
+
padding = torch.zeros(self.demucs.chin, self.total_length, device=self.pending.device)
|
489 |
+
out = self.feed(padding)
|
490 |
+
return out[:, :pending_length]
|
491 |
+
|
492 |
+
def feed(self, wav):
|
493 |
+
"""
|
494 |
+
Apply the model to mix using true real time evaluation.
|
495 |
+
Normalization is done online as is the resampling.
|
496 |
+
"""
|
497 |
+
begin = time.time()
|
498 |
+
demucs = self.demucs
|
499 |
+
resample_buffer = self.resample_buffer
|
500 |
+
stride = self.stride
|
501 |
+
resample = demucs.resample
|
502 |
+
|
503 |
+
if wav.dim() != 2:
|
504 |
+
raise ValueError("input wav should be two dimensional.")
|
505 |
+
chin, _ = wav.shape
|
506 |
+
if chin != demucs.chin:
|
507 |
+
raise ValueError(f"Expected {demucs.chin} channels, got {chin}")
|
508 |
+
|
509 |
+
self.pending = torch.cat([self.pending, wav], dim=1)
|
510 |
+
outs = []
|
511 |
+
while self.pending.shape[1] >= self.total_length:
|
512 |
+
self.frames += 1
|
513 |
+
frame = self.pending[:, :self.total_length]
|
514 |
+
dry_signal = frame[:, :stride]
|
515 |
+
if demucs.normalize:
|
516 |
+
mono = frame.mean(0)
|
517 |
+
variance = (mono**2).mean()
|
518 |
+
self.variance = variance / self.frames + (1 - 1 / self.frames) * self.variance
|
519 |
+
frame = frame / (demucs.floor + math.sqrt(self.variance))
|
520 |
+
frame = torch.cat([self.resample_in, frame], dim=-1)
|
521 |
+
self.resample_in[:] = frame[:, stride - resample_buffer:stride]
|
522 |
+
|
523 |
+
if resample == 4:
|
524 |
+
frame = upsample2(upsample2(frame))
|
525 |
+
elif resample == 2:
|
526 |
+
frame = upsample2(frame)
|
527 |
+
frame = frame[:, resample * resample_buffer:] # remove pre sampling buffer
|
528 |
+
frame = frame[:, :resample * self.frame_length] # remove extra samples after window
|
529 |
+
|
530 |
+
out, extra = self._separate_frame(frame)
|
531 |
+
padded_out = torch.cat([self.resample_out, out, extra], 1)
|
532 |
+
self.resample_out[:] = out[:, -resample_buffer:]
|
533 |
+
if resample == 4:
|
534 |
+
out = downsample2(downsample2(padded_out))
|
535 |
+
elif resample == 2:
|
536 |
+
out = downsample2(padded_out)
|
537 |
+
else:
|
538 |
+
out = padded_out
|
539 |
+
|
540 |
+
out = out[:, resample_buffer // resample:]
|
541 |
+
out = out[:, :stride]
|
542 |
+
|
543 |
+
if demucs.normalize:
|
544 |
+
out *= math.sqrt(self.variance)
|
545 |
+
out = self.dry * dry_signal + (1 - self.dry) * out
|
546 |
+
outs.append(out)
|
547 |
+
self.pending = self.pending[:, stride:]
|
548 |
+
|
549 |
+
self.total_time += time.time() - begin
|
550 |
+
if outs:
|
551 |
+
out = torch.cat(outs, 1)
|
552 |
+
else:
|
553 |
+
out = torch.zeros(chin, 0, device=wav.device)
|
554 |
+
return out
|
555 |
+
|
556 |
+
def _separate_frame(self, frame):
|
557 |
+
demucs = self.demucs
|
558 |
+
skips = []
|
559 |
+
next_state = []
|
560 |
+
first = self.conv_state is None
|
561 |
+
stride = self.stride * demucs.resample
|
562 |
+
x = frame[None]
|
563 |
+
for idx, encode in enumerate(demucs.encoder):
|
564 |
+
stride //= demucs.stride
|
565 |
+
length = x.shape[2]
|
566 |
+
if idx == demucs.depth - 1:
|
567 |
+
# This is sligthly faster for the last conv
|
568 |
+
x = fast_conv(encode[0], x)
|
569 |
+
x = encode[1](x)
|
570 |
+
x = fast_conv(encode[2], x)
|
571 |
+
x = encode[3](x)
|
572 |
+
else:
|
573 |
+
if not first:
|
574 |
+
prev = self.conv_state.pop(0)
|
575 |
+
prev = prev[..., stride:]
|
576 |
+
tgt = (length - demucs.kernel_size) // demucs.stride + 1
|
577 |
+
missing = tgt - prev.shape[-1]
|
578 |
+
offset = length - demucs.kernel_size - demucs.stride * (missing - 1)
|
579 |
+
x = x[..., offset:]
|
580 |
+
x = encode[1](encode[0](x))
|
581 |
+
x = fast_conv(encode[2], x)
|
582 |
+
x = encode[3](x)
|
583 |
+
if not first:
|
584 |
+
x = torch.cat([prev, x], -1)
|
585 |
+
next_state.append(x)
|
586 |
+
skips.append(x)
|
587 |
+
|
588 |
+
x = x.permute(2, 0, 1)
|
589 |
+
x, self.lstm_state = demucs.lstm(x, self.lstm_state)
|
590 |
+
x = x.permute(1, 2, 0)
|
591 |
+
# In the following, x contains only correct samples, i.e. the one
|
592 |
+
# for which each time position is covered by two window of the upper layer.
|
593 |
+
# extra contains extra samples to the right, and is used only as a
|
594 |
+
# better padding for the online resampling.
|
595 |
+
extra = None
|
596 |
+
for idx, decode in enumerate(demucs.decoder):
|
597 |
+
skip = skips.pop(-1)
|
598 |
+
x += skip[..., :x.shape[-1]]
|
599 |
+
x = fast_conv(decode[0], x)
|
600 |
+
x = decode[1](x)
|
601 |
+
|
602 |
+
if extra is not None:
|
603 |
+
skip = skip[..., x.shape[-1]:]
|
604 |
+
extra += skip[..., :extra.shape[-1]]
|
605 |
+
extra = decode[2](decode[1](decode[0](extra)))
|
606 |
+
x = decode[2](x)
|
607 |
+
next_state.append(x[..., -demucs.stride:] - decode[2].bias.view(-1, 1))
|
608 |
+
if extra is None:
|
609 |
+
extra = x[..., -demucs.stride:]
|
610 |
+
else:
|
611 |
+
extra[..., :demucs.stride] += next_state[-1]
|
612 |
+
x = x[..., :-demucs.stride]
|
613 |
+
|
614 |
+
if not first:
|
615 |
+
prev = self.conv_state.pop(0)
|
616 |
+
x[..., :demucs.stride] += prev
|
617 |
+
if idx != demucs.depth - 1:
|
618 |
+
x = decode[3](x)
|
619 |
+
extra = decode[3](extra)
|
620 |
+
self.conv_state = next_state
|
621 |
+
return x[0], extra[0]
|
622 |
+
|
623 |
+
|
624 |
+
def test():
|
625 |
+
import argparse
|
626 |
+
parser = argparse.ArgumentParser(
|
627 |
+
"denoiser.demucs",
|
628 |
+
description="Benchmark the streaming Demucs implementation, "
|
629 |
+
"as well as checking the delta with the offline implementation.")
|
630 |
+
parser.add_argument("--resample", default=4, type=int)
|
631 |
+
parser.add_argument("--hidden", default=48, type=int)
|
632 |
+
parser.add_argument("--device", default="cpu")
|
633 |
+
parser.add_argument("-t", "--num_threads", type=int)
|
634 |
+
parser.add_argument("-f", "--num_frames", type=int, default=1)
|
635 |
+
args = parser.parse_args()
|
636 |
+
if args.num_threads:
|
637 |
+
torch.set_num_threads(args.num_threads)
|
638 |
+
sr = 16_000
|
639 |
+
sr_ms = sr / 1000
|
640 |
+
demucs = Demucs(hidden=args.hidden, resample=args.resample).to(args.device)
|
641 |
+
x = torch.randn(1, sr * 4).to(args.device)
|
642 |
+
out = demucs(x[None])[0]
|
643 |
+
streamer = DemucsStreamer(demucs, num_frames=args.num_frames)
|
644 |
+
out_rt = []
|
645 |
+
frame_size = streamer.total_length
|
646 |
+
with torch.no_grad():
|
647 |
+
while x.shape[1] > 0:
|
648 |
+
out_rt.append(streamer.feed(x[:, :frame_size]))
|
649 |
+
x = x[:, frame_size:]
|
650 |
+
frame_size = streamer.demucs.total_stride
|
651 |
+
out_rt.append(streamer.flush())
|
652 |
+
out_rt = torch.cat(out_rt, 1)
|
653 |
+
print(f"total lag: {streamer.total_length / sr_ms:.1f}ms, ", end='')
|
654 |
+
print(f"stride: {streamer.stride / sr_ms:.1f}ms, ", end='')
|
655 |
+
print(f"time per frame: {1000 * streamer.time_per_frame:.1f}ms, ", end='')
|
656 |
+
print(f"delta: {torch.norm(out - out_rt) / torch.norm(out):.2%}, ", end='')
|
657 |
+
print(f"RTF: {((1000 * streamer.time_per_frame) / (streamer.stride / sr_ms)):.1f}")
|
658 |
+
|
659 |
+
|
660 |
+
if __name__ == "__main__":
|
661 |
+
test()
|
denoiser/data.py
ADDED
@@ -0,0 +1,99 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
# author: adefossez and adiyoss
|
7 |
+
|
8 |
+
import json
|
9 |
+
import logging
|
10 |
+
import os
|
11 |
+
import re
|
12 |
+
|
13 |
+
from .audio import Audioset
|
14 |
+
|
15 |
+
logger = logging.getLogger(__name__)
|
16 |
+
|
17 |
+
|
18 |
+
def match_dns(noisy, clean):
|
19 |
+
"""match_dns.
|
20 |
+
Match noisy and clean DNS dataset filenames.
|
21 |
+
|
22 |
+
:param noisy: list of the noisy filenames
|
23 |
+
:param clean: list of the clean filenames
|
24 |
+
"""
|
25 |
+
logger.debug("Matching noisy and clean for dns dataset")
|
26 |
+
noisydict = {}
|
27 |
+
extra_noisy = []
|
28 |
+
for path, size in noisy:
|
29 |
+
match = re.search(r'fileid_(\d+)\.wav$', path)
|
30 |
+
if match is None:
|
31 |
+
# maybe we are mixing some other dataset in
|
32 |
+
extra_noisy.append((path, size))
|
33 |
+
else:
|
34 |
+
noisydict[match.group(1)] = (path, size)
|
35 |
+
noisy[:] = []
|
36 |
+
extra_clean = []
|
37 |
+
copied = list(clean)
|
38 |
+
clean[:] = []
|
39 |
+
for path, size in copied:
|
40 |
+
match = re.search(r'fileid_(\d+)\.wav$', path)
|
41 |
+
if match is None:
|
42 |
+
extra_clean.append((path, size))
|
43 |
+
else:
|
44 |
+
noisy.append(noisydict[match.group(1)])
|
45 |
+
clean.append((path, size))
|
46 |
+
extra_noisy.sort()
|
47 |
+
extra_clean.sort()
|
48 |
+
clean += extra_clean
|
49 |
+
noisy += extra_noisy
|
50 |
+
|
51 |
+
|
52 |
+
def match_files(noisy, clean, matching="sort"):
|
53 |
+
"""match_files.
|
54 |
+
Sort files to match noisy and clean filenames.
|
55 |
+
:param noisy: list of the noisy filenames
|
56 |
+
:param clean: list of the clean filenames
|
57 |
+
:param matching: the matching function, at this point only sort is supported
|
58 |
+
"""
|
59 |
+
if matching == "dns":
|
60 |
+
# dns dataset filenames don't match when sorted, we have to manually match them
|
61 |
+
match_dns(noisy, clean)
|
62 |
+
elif matching == "sort":
|
63 |
+
noisy.sort()
|
64 |
+
clean.sort()
|
65 |
+
else:
|
66 |
+
raise ValueError(f"Invalid value for matching {matching}")
|
67 |
+
|
68 |
+
|
69 |
+
class NoisyCleanSet:
|
70 |
+
def __init__(self, json_dir, matching="sort", length=None, stride=None,
|
71 |
+
pad=True, sample_rate=None):
|
72 |
+
"""__init__.
|
73 |
+
|
74 |
+
:param json_dir: directory containing both clean.json and noisy.json
|
75 |
+
:param matching: matching function for the files
|
76 |
+
:param length: maximum sequence length
|
77 |
+
:param stride: the stride used for splitting audio sequences
|
78 |
+
:param pad: pad the end of the sequence with zeros
|
79 |
+
:param sample_rate: the signals sampling rate
|
80 |
+
"""
|
81 |
+
noisy_json = os.path.join(json_dir, 'noisy.json')
|
82 |
+
clean_json = os.path.join(json_dir, 'clean.json')
|
83 |
+
with open(noisy_json, 'r') as f:
|
84 |
+
noisy = json.load(f)
|
85 |
+
with open(clean_json, 'r') as f:
|
86 |
+
clean = json.load(f)
|
87 |
+
|
88 |
+
match_files(noisy, clean, matching)
|
89 |
+
kw = {'length': length, 'stride': stride, 'pad': pad, 'sample_rate': sample_rate}
|
90 |
+
self.clean_set = Audioset(clean, **kw)
|
91 |
+
self.noisy_set = Audioset(noisy, **kw)
|
92 |
+
|
93 |
+
assert len(self.clean_set) == len(self.noisy_set)
|
94 |
+
|
95 |
+
def __getitem__(self, index):
|
96 |
+
return self.noisy_set[index], self.clean_set[index]
|
97 |
+
|
98 |
+
def __len__(self):
|
99 |
+
return len(self.noisy_set)
|
denoiser/demucs.py
ADDED
@@ -0,0 +1,449 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
# author: adefossez
|
7 |
+
|
8 |
+
import math
|
9 |
+
import time
|
10 |
+
|
11 |
+
import torch as th
|
12 |
+
from torch import nn
|
13 |
+
from torch.nn import functional as F
|
14 |
+
|
15 |
+
from .resample import downsample2, upsample2
|
16 |
+
from .utils import capture_init
|
17 |
+
|
18 |
+
|
19 |
+
class BLSTM(nn.Module):
|
20 |
+
def __init__(self, dim, layers=2, bi=True):
|
21 |
+
super().__init__()
|
22 |
+
klass = nn.LSTM
|
23 |
+
self.lstm = klass(bidirectional=bi, num_layers=layers, hidden_size=dim, input_size=dim)
|
24 |
+
self.linear = None
|
25 |
+
if bi:
|
26 |
+
self.linear = nn.Linear(2 * dim, dim)
|
27 |
+
|
28 |
+
def forward(self, x, hidden=None):
|
29 |
+
x, hidden = self.lstm(x, hidden)
|
30 |
+
if self.linear:
|
31 |
+
x = self.linear(x)
|
32 |
+
return x, hidden
|
33 |
+
|
34 |
+
|
35 |
+
def rescale_conv(conv, reference):
|
36 |
+
std = conv.weight.std().detach()
|
37 |
+
scale = (std / reference)**0.5
|
38 |
+
conv.weight.data /= scale
|
39 |
+
if conv.bias is not None:
|
40 |
+
conv.bias.data /= scale
|
41 |
+
|
42 |
+
|
43 |
+
def rescale_module(module, reference):
|
44 |
+
for sub in module.modules():
|
45 |
+
if isinstance(sub, (nn.Conv1d, nn.ConvTranspose1d)):
|
46 |
+
rescale_conv(sub, reference)
|
47 |
+
|
48 |
+
|
49 |
+
class Demucs(nn.Module):
|
50 |
+
"""
|
51 |
+
Demucs speech enhancement model.
|
52 |
+
Args:
|
53 |
+
- chin (int): number of input channels.
|
54 |
+
- chout (int): number of output channels.
|
55 |
+
- hidden (int): number of initial hidden channels.
|
56 |
+
- depth (int): number of layers.
|
57 |
+
- kernel_size (int): kernel size for each layer.
|
58 |
+
- stride (int): stride for each layer.
|
59 |
+
- causal (bool): if false, uses BiLSTM instead of LSTM.
|
60 |
+
- resample (int): amount of resampling to apply to the input/output.
|
61 |
+
Can be one of 1, 2 or 4.
|
62 |
+
- growth (float): number of channels is multiplied by this for every layer.
|
63 |
+
- max_hidden (int): maximum number of channels. Can be useful to
|
64 |
+
control the size/speed of the model.
|
65 |
+
- normalize (bool): if true, normalize the input.
|
66 |
+
- glu (bool): if true uses GLU instead of ReLU in 1x1 convolutions.
|
67 |
+
- rescale (float): controls custom weight initialization.
|
68 |
+
See https://arxiv.org/abs/1911.13254.
|
69 |
+
- floor (float): stability flooring when normalizing.
|
70 |
+
|
71 |
+
"""
|
72 |
+
@capture_init
|
73 |
+
def __init__(self,
|
74 |
+
chin=1,
|
75 |
+
chout=1,
|
76 |
+
hidden=48,
|
77 |
+
depth=5,
|
78 |
+
kernel_size=8,
|
79 |
+
stride=4,
|
80 |
+
causal=True,
|
81 |
+
resample=4,
|
82 |
+
growth=2,
|
83 |
+
max_hidden=10_000,
|
84 |
+
normalize=True,
|
85 |
+
glu=True,
|
86 |
+
rescale=0.1,
|
87 |
+
floor=1e-3):
|
88 |
+
|
89 |
+
super().__init__()
|
90 |
+
if resample not in [1, 2, 4]:
|
91 |
+
raise ValueError("Resample should be 1, 2 or 4.")
|
92 |
+
|
93 |
+
self.chin = chin
|
94 |
+
self.chout = chout
|
95 |
+
self.hidden = hidden
|
96 |
+
self.depth = depth
|
97 |
+
self.kernel_size = kernel_size
|
98 |
+
self.stride = stride
|
99 |
+
self.causal = causal
|
100 |
+
self.floor = floor
|
101 |
+
self.resample = resample
|
102 |
+
self.normalize = normalize
|
103 |
+
|
104 |
+
self.encoder = nn.ModuleList()
|
105 |
+
self.decoder = nn.ModuleList()
|
106 |
+
activation = nn.GLU(1) if glu else nn.ReLU()
|
107 |
+
ch_scale = 2 if glu else 1
|
108 |
+
|
109 |
+
for index in range(depth):
|
110 |
+
encode = []
|
111 |
+
encode += [
|
112 |
+
nn.Conv1d(chin, hidden, kernel_size, stride),
|
113 |
+
nn.ReLU(),
|
114 |
+
nn.Conv1d(hidden, hidden * ch_scale, 1), activation,
|
115 |
+
]
|
116 |
+
self.encoder.append(nn.Sequential(*encode))
|
117 |
+
|
118 |
+
decode = []
|
119 |
+
decode += [
|
120 |
+
nn.Conv1d(hidden, ch_scale * hidden, 1), activation,
|
121 |
+
nn.ConvTranspose1d(hidden, chout, kernel_size, stride),
|
122 |
+
]
|
123 |
+
if index > 0:
|
124 |
+
decode.append(nn.ReLU())
|
125 |
+
self.decoder.insert(0, nn.Sequential(*decode))
|
126 |
+
chout = hidden
|
127 |
+
chin = hidden
|
128 |
+
hidden = min(int(growth * hidden), max_hidden)
|
129 |
+
|
130 |
+
self.lstm = BLSTM(chin, bi=not causal)
|
131 |
+
if rescale:
|
132 |
+
rescale_module(self, reference=rescale)
|
133 |
+
|
134 |
+
def valid_length(self, length):
|
135 |
+
"""
|
136 |
+
Return the nearest valid length to use with the model so that
|
137 |
+
there is no time steps left over in a convolutions, e.g. for all
|
138 |
+
layers, size of the input - kernel_size % stride = 0.
|
139 |
+
|
140 |
+
If the mixture has a valid length, the estimated sources
|
141 |
+
will have exactly the same length.
|
142 |
+
"""
|
143 |
+
length = math.ceil(length * self.resample)
|
144 |
+
for idx in range(self.depth):
|
145 |
+
length = math.ceil((length - self.kernel_size) / self.stride) + 1
|
146 |
+
length = max(length, 1)
|
147 |
+
for idx in range(self.depth):
|
148 |
+
length = (length - 1) * self.stride + self.kernel_size
|
149 |
+
length = int(math.ceil(length / self.resample))
|
150 |
+
return int(length)
|
151 |
+
|
152 |
+
@property
|
153 |
+
def total_stride(self):
|
154 |
+
return self.stride ** self.depth // self.resample
|
155 |
+
|
156 |
+
def forward(self, mix):
|
157 |
+
if mix.dim() == 2:
|
158 |
+
mix = mix.unsqueeze(1)
|
159 |
+
|
160 |
+
if self.normalize:
|
161 |
+
mono = mix.mean(dim=1, keepdim=True)
|
162 |
+
std = mono.std(dim=-1, keepdim=True)
|
163 |
+
mix = mix / (self.floor + std)
|
164 |
+
else:
|
165 |
+
std = 1
|
166 |
+
length = mix.shape[-1]
|
167 |
+
x = mix
|
168 |
+
x = F.pad(x, (0, self.valid_length(length) - length))
|
169 |
+
if self.resample == 2:
|
170 |
+
x = upsample2(x)
|
171 |
+
elif self.resample == 4:
|
172 |
+
x = upsample2(x)
|
173 |
+
x = upsample2(x)
|
174 |
+
skips = []
|
175 |
+
for encode in self.encoder:
|
176 |
+
x = encode(x)
|
177 |
+
skips.append(x)
|
178 |
+
x = x.permute(2, 0, 1)
|
179 |
+
x, _ = self.lstm(x)
|
180 |
+
x = x.permute(1, 2, 0)
|
181 |
+
for decode in self.decoder:
|
182 |
+
skip = skips.pop(-1)
|
183 |
+
x = x + skip[..., :x.shape[-1]]
|
184 |
+
x = decode(x)
|
185 |
+
if self.resample == 2:
|
186 |
+
x = downsample2(x)
|
187 |
+
elif self.resample == 4:
|
188 |
+
x = downsample2(x)
|
189 |
+
x = downsample2(x)
|
190 |
+
|
191 |
+
x = x[..., :length]
|
192 |
+
return std * x
|
193 |
+
|
194 |
+
|
195 |
+
def fast_conv(conv, x):
|
196 |
+
"""
|
197 |
+
Faster convolution evaluation if either kernel size is 1
|
198 |
+
or length of sequence is 1.
|
199 |
+
"""
|
200 |
+
batch, chin, length = x.shape
|
201 |
+
chout, chin, kernel = conv.weight.shape
|
202 |
+
assert batch == 1
|
203 |
+
if kernel == 1:
|
204 |
+
x = x.view(chin, length)
|
205 |
+
out = th.addmm(conv.bias.view(-1, 1),
|
206 |
+
conv.weight.view(chout, chin), x)
|
207 |
+
elif length == kernel:
|
208 |
+
x = x.view(chin * kernel, 1)
|
209 |
+
out = th.addmm(conv.bias.view(-1, 1),
|
210 |
+
conv.weight.view(chout, chin * kernel), x)
|
211 |
+
else:
|
212 |
+
out = conv(x)
|
213 |
+
return out.view(batch, chout, -1)
|
214 |
+
|
215 |
+
|
216 |
+
class DemucsStreamer:
|
217 |
+
"""
|
218 |
+
Streaming implementation for Demucs. It supports being fed with any amount
|
219 |
+
of audio at a time. You will get back as much audio as possible at that
|
220 |
+
point.
|
221 |
+
|
222 |
+
Args:
|
223 |
+
- demucs (Demucs): Demucs model.
|
224 |
+
- dry (float): amount of dry (e.g. input) signal to keep. 0 is maximum
|
225 |
+
noise removal, 1 just returns the input signal. Small values > 0
|
226 |
+
allows to limit distortions.
|
227 |
+
- num_frames (int): number of frames to process at once. Higher values
|
228 |
+
will increase overall latency but improve the real time factor.
|
229 |
+
- resample_lookahead (int): extra lookahead used for the resampling.
|
230 |
+
- resample_buffer (int): size of the buffer of previous inputs/outputs
|
231 |
+
kept for resampling.
|
232 |
+
"""
|
233 |
+
def __init__(self, demucs,
|
234 |
+
dry=0,
|
235 |
+
num_frames=1,
|
236 |
+
resample_lookahead=64,
|
237 |
+
resample_buffer=256):
|
238 |
+
device = next(iter(demucs.parameters())).device
|
239 |
+
self.demucs = demucs
|
240 |
+
self.lstm_state = None
|
241 |
+
self.conv_state = None
|
242 |
+
self.dry = dry
|
243 |
+
self.resample_lookahead = resample_lookahead
|
244 |
+
self.resample_buffer = resample_buffer
|
245 |
+
self.frame_length = demucs.valid_length(1) + demucs.total_stride * (num_frames - 1)
|
246 |
+
self.total_length = self.frame_length + self.resample_lookahead
|
247 |
+
self.stride = demucs.total_stride * num_frames
|
248 |
+
self.resample_in = th.zeros(demucs.chin, resample_buffer, device=device)
|
249 |
+
self.resample_out = th.zeros(demucs.chin, resample_buffer, device=device)
|
250 |
+
|
251 |
+
self.frames = 0
|
252 |
+
self.total_time = 0
|
253 |
+
self.variance = 0
|
254 |
+
self.pending = th.zeros(demucs.chin, 0, device=device)
|
255 |
+
|
256 |
+
bias = demucs.decoder[0][2].bias
|
257 |
+
weight = demucs.decoder[0][2].weight
|
258 |
+
chin, chout, kernel = weight.shape
|
259 |
+
self._bias = bias.view(-1, 1).repeat(1, kernel).view(-1, 1)
|
260 |
+
self._weight = weight.permute(1, 2, 0).contiguous()
|
261 |
+
|
262 |
+
def reset_time_per_frame(self):
|
263 |
+
self.total_time = 0
|
264 |
+
self.frames = 0
|
265 |
+
|
266 |
+
@property
|
267 |
+
def time_per_frame(self):
|
268 |
+
return self.total_time / self.frames
|
269 |
+
|
270 |
+
def flush(self):
|
271 |
+
"""
|
272 |
+
Flush remaining audio by padding it with zero. Call this
|
273 |
+
when you have no more input and want to get back the last chunk of audio.
|
274 |
+
"""
|
275 |
+
pending_length = self.pending.shape[1]
|
276 |
+
padding = th.zeros(self.demucs.chin, self.total_length, device=self.pending.device)
|
277 |
+
out = self.feed(padding)
|
278 |
+
return out[:, :pending_length]
|
279 |
+
|
280 |
+
def feed(self, wav):
|
281 |
+
"""
|
282 |
+
Apply the model to mix using true real time evaluation.
|
283 |
+
Normalization is done online as is the resampling.
|
284 |
+
"""
|
285 |
+
begin = time.time()
|
286 |
+
demucs = self.demucs
|
287 |
+
resample_buffer = self.resample_buffer
|
288 |
+
stride = self.stride
|
289 |
+
resample = demucs.resample
|
290 |
+
|
291 |
+
if wav.dim() != 2:
|
292 |
+
raise ValueError("input wav should be two dimensional.")
|
293 |
+
chin, _ = wav.shape
|
294 |
+
if chin != demucs.chin:
|
295 |
+
raise ValueError(f"Expected {demucs.chin} channels, got {chin}")
|
296 |
+
|
297 |
+
self.pending = th.cat([self.pending, wav], dim=1)
|
298 |
+
outs = []
|
299 |
+
while self.pending.shape[1] >= self.total_length:
|
300 |
+
self.frames += 1
|
301 |
+
frame = self.pending[:, :self.total_length]
|
302 |
+
dry_signal = frame[:, :stride]
|
303 |
+
if demucs.normalize:
|
304 |
+
mono = frame.mean(0)
|
305 |
+
variance = (mono**2).mean()
|
306 |
+
self.variance = variance / self.frames + (1 - 1 / self.frames) * self.variance
|
307 |
+
frame = frame / (demucs.floor + math.sqrt(self.variance))
|
308 |
+
frame = th.cat([self.resample_in, frame], dim=-1)
|
309 |
+
self.resample_in[:] = frame[:, stride - resample_buffer:stride]
|
310 |
+
|
311 |
+
if resample == 4:
|
312 |
+
frame = upsample2(upsample2(frame))
|
313 |
+
elif resample == 2:
|
314 |
+
frame = upsample2(frame)
|
315 |
+
frame = frame[:, resample * resample_buffer:] # remove pre sampling buffer
|
316 |
+
frame = frame[:, :resample * self.frame_length] # remove extra samples after window
|
317 |
+
|
318 |
+
out, extra = self._separate_frame(frame)
|
319 |
+
padded_out = th.cat([self.resample_out, out, extra], 1)
|
320 |
+
self.resample_out[:] = out[:, -resample_buffer:]
|
321 |
+
if resample == 4:
|
322 |
+
out = downsample2(downsample2(padded_out))
|
323 |
+
elif resample == 2:
|
324 |
+
out = downsample2(padded_out)
|
325 |
+
else:
|
326 |
+
out = padded_out
|
327 |
+
|
328 |
+
out = out[:, resample_buffer // resample:]
|
329 |
+
out = out[:, :stride]
|
330 |
+
|
331 |
+
if demucs.normalize:
|
332 |
+
out *= math.sqrt(self.variance)
|
333 |
+
out = self.dry * dry_signal + (1 - self.dry) * out
|
334 |
+
outs.append(out)
|
335 |
+
self.pending = self.pending[:, stride:]
|
336 |
+
|
337 |
+
self.total_time += time.time() - begin
|
338 |
+
if outs:
|
339 |
+
out = th.cat(outs, 1)
|
340 |
+
else:
|
341 |
+
out = th.zeros(chin, 0, device=wav.device)
|
342 |
+
return out
|
343 |
+
|
344 |
+
def _separate_frame(self, frame):
|
345 |
+
demucs = self.demucs
|
346 |
+
skips = []
|
347 |
+
next_state = []
|
348 |
+
first = self.conv_state is None
|
349 |
+
stride = self.stride * demucs.resample
|
350 |
+
x = frame[None]
|
351 |
+
for idx, encode in enumerate(demucs.encoder):
|
352 |
+
stride //= demucs.stride
|
353 |
+
length = x.shape[2]
|
354 |
+
if idx == demucs.depth - 1:
|
355 |
+
# This is sligthly faster for the last conv
|
356 |
+
x = fast_conv(encode[0], x)
|
357 |
+
x = encode[1](x)
|
358 |
+
x = fast_conv(encode[2], x)
|
359 |
+
x = encode[3](x)
|
360 |
+
else:
|
361 |
+
if not first:
|
362 |
+
prev = self.conv_state.pop(0)
|
363 |
+
prev = prev[..., stride:]
|
364 |
+
tgt = (length - demucs.kernel_size) // demucs.stride + 1
|
365 |
+
missing = tgt - prev.shape[-1]
|
366 |
+
offset = length - demucs.kernel_size - demucs.stride * (missing - 1)
|
367 |
+
x = x[..., offset:]
|
368 |
+
x = encode[1](encode[0](x))
|
369 |
+
x = fast_conv(encode[2], x)
|
370 |
+
x = encode[3](x)
|
371 |
+
if not first:
|
372 |
+
x = th.cat([prev, x], -1)
|
373 |
+
next_state.append(x)
|
374 |
+
skips.append(x)
|
375 |
+
|
376 |
+
x = x.permute(2, 0, 1)
|
377 |
+
x, self.lstm_state = demucs.lstm(x, self.lstm_state)
|
378 |
+
x = x.permute(1, 2, 0)
|
379 |
+
# In the following, x contains only correct samples, i.e. the one
|
380 |
+
# for which each time position is covered by two window of the upper layer.
|
381 |
+
# extra contains extra samples to the right, and is used only as a
|
382 |
+
# better padding for the online resampling.
|
383 |
+
extra = None
|
384 |
+
for idx, decode in enumerate(demucs.decoder):
|
385 |
+
skip = skips.pop(-1)
|
386 |
+
x += skip[..., :x.shape[-1]]
|
387 |
+
x = fast_conv(decode[0], x)
|
388 |
+
x = decode[1](x)
|
389 |
+
|
390 |
+
if extra is not None:
|
391 |
+
skip = skip[..., x.shape[-1]:]
|
392 |
+
extra += skip[..., :extra.shape[-1]]
|
393 |
+
extra = decode[2](decode[1](decode[0](extra)))
|
394 |
+
x = decode[2](x)
|
395 |
+
next_state.append(x[..., -demucs.stride:] - decode[2].bias.view(-1, 1))
|
396 |
+
if extra is None:
|
397 |
+
extra = x[..., -demucs.stride:]
|
398 |
+
else:
|
399 |
+
extra[..., :demucs.stride] += next_state[-1]
|
400 |
+
x = x[..., :-demucs.stride]
|
401 |
+
|
402 |
+
if not first:
|
403 |
+
prev = self.conv_state.pop(0)
|
404 |
+
x[..., :demucs.stride] += prev
|
405 |
+
if idx != demucs.depth - 1:
|
406 |
+
x = decode[3](x)
|
407 |
+
extra = decode[3](extra)
|
408 |
+
self.conv_state = next_state
|
409 |
+
return x[0], extra[0]
|
410 |
+
|
411 |
+
|
412 |
+
def test():
|
413 |
+
import argparse
|
414 |
+
parser = argparse.ArgumentParser(
|
415 |
+
"denoiser.demucs",
|
416 |
+
description="Benchmark the streaming Demucs implementation, "
|
417 |
+
"as well as checking the delta with the offline implementation.")
|
418 |
+
parser.add_argument("--resample", default=4, type=int)
|
419 |
+
parser.add_argument("--hidden", default=48, type=int)
|
420 |
+
parser.add_argument("--device", default="cpu")
|
421 |
+
parser.add_argument("-t", "--num_threads", type=int)
|
422 |
+
parser.add_argument("-f", "--num_frames", type=int, default=1)
|
423 |
+
args = parser.parse_args()
|
424 |
+
if args.num_threads:
|
425 |
+
th.set_num_threads(args.num_threads)
|
426 |
+
sr = 16_000
|
427 |
+
sr_ms = sr / 1000
|
428 |
+
demucs = Demucs(hidden=args.hidden, resample=args.resample).to(args.device)
|
429 |
+
x = th.randn(1, sr * 4).to(args.device)
|
430 |
+
out = demucs(x[None])[0]
|
431 |
+
streamer = DemucsStreamer(demucs, num_frames=args.num_frames)
|
432 |
+
out_rt = []
|
433 |
+
frame_size = streamer.total_length
|
434 |
+
with th.no_grad():
|
435 |
+
while x.shape[1] > 0:
|
436 |
+
out_rt.append(streamer.feed(x[:, :frame_size]))
|
437 |
+
x = x[:, frame_size:]
|
438 |
+
frame_size = streamer.demucs.total_stride
|
439 |
+
out_rt.append(streamer.flush())
|
440 |
+
out_rt = th.cat(out_rt, 1)
|
441 |
+
print(f"total lag: {streamer.total_length / sr_ms:.1f}ms, ", end='')
|
442 |
+
print(f"stride: {streamer.stride / sr_ms:.1f}ms, ", end='')
|
443 |
+
print(f"time per frame: {1000 * streamer.time_per_frame:.1f}ms, ", end='')
|
444 |
+
print(f"delta: {th.norm(out - out_rt) / th.norm(out):.2%}, ", end='')
|
445 |
+
print(f"RTF: {((1000 * streamer.time_per_frame) / (streamer.stride / sr_ms)):.1f}")
|
446 |
+
|
447 |
+
|
448 |
+
if __name__ == "__main__":
|
449 |
+
test()
|
denoiser/distrib.py
ADDED
@@ -0,0 +1,100 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
# author: adefossez
|
7 |
+
|
8 |
+
import logging
|
9 |
+
import os
|
10 |
+
|
11 |
+
import torch
|
12 |
+
from torch.utils.data.distributed import DistributedSampler
|
13 |
+
from torch.utils.data import DataLoader, Subset
|
14 |
+
from torch.nn.parallel.distributed import DistributedDataParallel
|
15 |
+
|
16 |
+
logger = logging.getLogger(__name__)
|
17 |
+
rank = 0
|
18 |
+
world_size = 1
|
19 |
+
|
20 |
+
|
21 |
+
def init(args):
|
22 |
+
"""init.
|
23 |
+
|
24 |
+
Initialize DDP using the given rendezvous file.
|
25 |
+
"""
|
26 |
+
global rank, world_size
|
27 |
+
if args.ddp:
|
28 |
+
assert args.rank is not None and args.world_size is not None
|
29 |
+
rank = args.rank
|
30 |
+
world_size = args.world_size
|
31 |
+
if world_size == 1:
|
32 |
+
return
|
33 |
+
torch.cuda.set_device(rank)
|
34 |
+
torch.distributed.init_process_group(
|
35 |
+
backend=args.ddp_backend,
|
36 |
+
init_method='file://' + os.path.abspath(args.rendezvous_file),
|
37 |
+
world_size=world_size,
|
38 |
+
rank=rank)
|
39 |
+
logger.debug("Distributed rendezvous went well, rank %d/%d", rank, world_size)
|
40 |
+
|
41 |
+
|
42 |
+
def average(metrics, count=1.):
|
43 |
+
"""average.
|
44 |
+
|
45 |
+
Average all the relevant metrices across processes
|
46 |
+
`metrics`should be a 1D float32 fector. Returns the average of `metrics`
|
47 |
+
over all hosts. You can use `count` to control the weight of each worker.
|
48 |
+
"""
|
49 |
+
if world_size == 1:
|
50 |
+
return metrics
|
51 |
+
tensor = torch.tensor(list(metrics) + [1], device='cuda', dtype=torch.float32)
|
52 |
+
tensor *= count
|
53 |
+
torch.distributed.all_reduce(tensor, op=torch.distributed.ReduceOp.SUM)
|
54 |
+
return (tensor[:-1] / tensor[-1]).cpu().numpy().tolist()
|
55 |
+
|
56 |
+
|
57 |
+
def wrap(model):
|
58 |
+
"""wrap.
|
59 |
+
|
60 |
+
Wrap a model with DDP if distributed training is enabled.
|
61 |
+
"""
|
62 |
+
if world_size == 1:
|
63 |
+
return model
|
64 |
+
else:
|
65 |
+
return DistributedDataParallel(
|
66 |
+
model,
|
67 |
+
device_ids=[torch.cuda.current_device()],
|
68 |
+
output_device=torch.cuda.current_device())
|
69 |
+
|
70 |
+
|
71 |
+
def barrier():
|
72 |
+
if world_size > 1:
|
73 |
+
torch.distributed.barrier()
|
74 |
+
|
75 |
+
|
76 |
+
def loader(dataset, *args, shuffle=False, klass=DataLoader, **kwargs):
|
77 |
+
"""loader.
|
78 |
+
|
79 |
+
Create a dataloader properly in case of distributed training.
|
80 |
+
If a gradient is going to be computed you must set `shuffle=True`.
|
81 |
+
|
82 |
+
:param dataset: the dataset to be parallelized
|
83 |
+
:param args: relevant args for the loader
|
84 |
+
:param shuffle: shuffle examples
|
85 |
+
:param klass: loader class
|
86 |
+
:param kwargs: relevant args
|
87 |
+
"""
|
88 |
+
|
89 |
+
if world_size == 1:
|
90 |
+
return klass(dataset, *args, shuffle=shuffle, **kwargs)
|
91 |
+
|
92 |
+
if shuffle:
|
93 |
+
# train means we will compute backward, we use DistributedSampler
|
94 |
+
sampler = DistributedSampler(dataset)
|
95 |
+
# We ignore shuffle, DistributedSampler already shuffles
|
96 |
+
return klass(dataset, *args, **kwargs, sampler=sampler)
|
97 |
+
else:
|
98 |
+
# We make a manual shard, as DistributedSampler otherwise replicate some examples
|
99 |
+
dataset = Subset(dataset, list(range(rank, len(dataset), world_size)))
|
100 |
+
return klass(dataset, *args, shuffle=shuffle)
|
denoiser/dsp.py
ADDED
@@ -0,0 +1,64 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
# author: adefossez
|
7 |
+
|
8 |
+
import numpy as np
|
9 |
+
import torch
|
10 |
+
from torch.nn import functional as F
|
11 |
+
|
12 |
+
|
13 |
+
def hz_to_mel(f):
|
14 |
+
return 2595 * np.log10(1 + f / 700)
|
15 |
+
|
16 |
+
|
17 |
+
def mel_to_hz(m):
|
18 |
+
return 700 * (10**(m / 2595) - 1)
|
19 |
+
|
20 |
+
|
21 |
+
def mel_frequencies(n_mels, fmin, fmax):
|
22 |
+
low = hz_to_mel(fmin)
|
23 |
+
high = hz_to_mel(fmax)
|
24 |
+
mels = np.linspace(low, high, n_mels)
|
25 |
+
return mel_to_hz(mels)
|
26 |
+
|
27 |
+
|
28 |
+
class LowPassFilters(torch.nn.Module):
|
29 |
+
"""
|
30 |
+
Bank of low pass filters.
|
31 |
+
|
32 |
+
Args:
|
33 |
+
cutoffs (list[float]): list of cutoff frequencies, in [0, 1] expressed as `f/f_s` where
|
34 |
+
f_s is the samplerate.
|
35 |
+
width (int): width of the filters (i.e. kernel_size=2 * width + 1).
|
36 |
+
Default to `2 / min(cutoffs)`. Longer filters will have better attenuation
|
37 |
+
but more side effects.
|
38 |
+
Shape:
|
39 |
+
- Input: `(*, T)`
|
40 |
+
- Output: `(F, *, T` with `F` the len of `cutoffs`.
|
41 |
+
"""
|
42 |
+
|
43 |
+
def __init__(self, cutoffs: list, width: int = None):
|
44 |
+
super().__init__()
|
45 |
+
self.cutoffs = cutoffs
|
46 |
+
if width is None:
|
47 |
+
width = int(2 / min(cutoffs))
|
48 |
+
self.width = width
|
49 |
+
window = torch.hamming_window(2 * width + 1, periodic=False)
|
50 |
+
t = np.arange(-width, width + 1, dtype=np.float32)
|
51 |
+
filters = []
|
52 |
+
for cutoff in cutoffs:
|
53 |
+
sinc = torch.from_numpy(np.sinc(2 * cutoff * t))
|
54 |
+
filters.append(2 * cutoff * sinc * window)
|
55 |
+
self.register_buffer("filters", torch.stack(filters).unsqueeze(1))
|
56 |
+
|
57 |
+
def forward(self, input):
|
58 |
+
*others, t = input.shape
|
59 |
+
input = input.view(-1, 1, t)
|
60 |
+
out = F.conv1d(input, self.filters, padding=self.width)
|
61 |
+
return out.permute(1, 0, 2).reshape(-1, *others, t)
|
62 |
+
|
63 |
+
def __repr__(self):
|
64 |
+
return "LossPassFilters(width={},cutoffs={})".format(self.width, self.cutoffs)
|
denoiser/enhance.py
ADDED
@@ -0,0 +1,138 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
# author: adiyoss
|
7 |
+
|
8 |
+
import argparse
|
9 |
+
import json
|
10 |
+
import logging
|
11 |
+
import os
|
12 |
+
import sys
|
13 |
+
|
14 |
+
import torch
|
15 |
+
import torchaudio
|
16 |
+
|
17 |
+
from .audio import Audioset, find_audio_files
|
18 |
+
from . import distrib, pretrained
|
19 |
+
from .demucs import DemucsStreamer
|
20 |
+
|
21 |
+
from .utils import LogProgress
|
22 |
+
|
23 |
+
logger = logging.getLogger(__name__)
|
24 |
+
|
25 |
+
|
26 |
+
def add_flags(parser):
|
27 |
+
"""
|
28 |
+
Add the flags for the argument parser that are related to model loading and evaluation"
|
29 |
+
"""
|
30 |
+
pretrained.add_model_flags(parser)
|
31 |
+
parser.add_argument('--device', default="cpu")
|
32 |
+
parser.add_argument('--dry', type=float, default=0,
|
33 |
+
help='dry/wet knob coefficient. 0 is only input signal, 1 only denoised.')
|
34 |
+
parser.add_argument('--sample_rate', default=16_000, type=int, help='sample rate')
|
35 |
+
parser.add_argument('--num_workers', type=int, default=10)
|
36 |
+
parser.add_argument('--streaming', action="store_true",
|
37 |
+
help="true streaming evaluation for Demucs")
|
38 |
+
|
39 |
+
|
40 |
+
parser = argparse.ArgumentParser(
|
41 |
+
'denoiser.enhance',
|
42 |
+
description="Speech enhancement using Demucs - Generate enhanced files")
|
43 |
+
add_flags(parser)
|
44 |
+
parser.add_argument("--out_dir", type=str, default="enhanced",
|
45 |
+
help="directory putting enhanced wav files")
|
46 |
+
parser.add_argument("--batch_size", default=1, type=int, help="batch size")
|
47 |
+
parser.add_argument('-v', '--verbose', action='store_const', const=logging.DEBUG,
|
48 |
+
default=logging.INFO, help="more loggging")
|
49 |
+
|
50 |
+
group = parser.add_mutually_exclusive_group()
|
51 |
+
group.add_argument("--noisy_dir", type=str, default=None,
|
52 |
+
help="directory including noisy wav files")
|
53 |
+
group.add_argument("--noisy_json", type=str, default=None,
|
54 |
+
help="json file including noisy wav files")
|
55 |
+
|
56 |
+
|
57 |
+
def get_estimate(model, noisy, args):
|
58 |
+
torch.set_num_threads(1)
|
59 |
+
if args.streaming:
|
60 |
+
streamer = DemucsStreamer(model, dry=args.dry)
|
61 |
+
with torch.no_grad():
|
62 |
+
estimate = torch.cat([
|
63 |
+
streamer.feed(noisy[0]),
|
64 |
+
streamer.flush()], dim=1)[None]
|
65 |
+
else:
|
66 |
+
with torch.no_grad():
|
67 |
+
estimate = model(noisy)
|
68 |
+
estimate = (1 - args.dry) * estimate + args.dry * noisy
|
69 |
+
return estimate
|
70 |
+
|
71 |
+
|
72 |
+
def save_wavs(estimates, noisy_sigs, filenames, out_dir, sr=16_000):
|
73 |
+
# Write result
|
74 |
+
for estimate, noisy, filename in zip(estimates, noisy_sigs, filenames):
|
75 |
+
filename = os.path.join(out_dir, os.path.basename(filename).rsplit(".", 1)[0])
|
76 |
+
write(noisy, filename + "_noisy.wav", sr=sr)
|
77 |
+
write(estimate, filename + "_enhanced.wav", sr=sr)
|
78 |
+
|
79 |
+
|
80 |
+
def write(wav, filename, sr=16_000):
|
81 |
+
# Normalize audio if it prevents clipping
|
82 |
+
wav = wav / max(wav.abs().max().item(), 1)
|
83 |
+
torchaudio.save(filename, wav.cpu(), sr)
|
84 |
+
|
85 |
+
|
86 |
+
def get_dataset(args):
|
87 |
+
if hasattr(args, 'dset'):
|
88 |
+
paths = args.dset
|
89 |
+
else:
|
90 |
+
paths = args
|
91 |
+
if paths.noisy_json:
|
92 |
+
with open(paths.noisy_json) as f:
|
93 |
+
files = json.load(f)
|
94 |
+
elif paths.noisy_dir:
|
95 |
+
files = find_audio_files(paths.noisy_dir)
|
96 |
+
else:
|
97 |
+
logger.warning(
|
98 |
+
"Small sample set was not provided by either noisy_dir or noisy_json. "
|
99 |
+
"Skipping enhancement.")
|
100 |
+
return None
|
101 |
+
return Audioset(files, with_path=True, sample_rate=args.sample_rate)
|
102 |
+
|
103 |
+
|
104 |
+
def enhance(args, model=None, local_out_dir=None):
|
105 |
+
# Load model
|
106 |
+
if not model:
|
107 |
+
model = pretrained.get_model(args).to(args.device)
|
108 |
+
model.eval()
|
109 |
+
if local_out_dir:
|
110 |
+
out_dir = local_out_dir
|
111 |
+
else:
|
112 |
+
out_dir = args.out_dir
|
113 |
+
|
114 |
+
dset = get_dataset(args)
|
115 |
+
if dset is None:
|
116 |
+
return
|
117 |
+
loader = distrib.loader(dset, batch_size=1)
|
118 |
+
|
119 |
+
if distrib.rank == 0:
|
120 |
+
os.makedirs(out_dir, exist_ok=True)
|
121 |
+
distrib.barrier()
|
122 |
+
|
123 |
+
with torch.no_grad():
|
124 |
+
iterator = LogProgress(logger, loader, name="Generate enhanced files")
|
125 |
+
for data in iterator:
|
126 |
+
# Get batch data
|
127 |
+
noisy_signals, filenames = data
|
128 |
+
noisy_signals = noisy_signals.to(args.device)
|
129 |
+
# Forward
|
130 |
+
estimate = get_estimate(model, noisy_signals, args)
|
131 |
+
save_wavs(estimate, noisy_signals, filenames, out_dir, sr=args.sample_rate)
|
132 |
+
|
133 |
+
|
134 |
+
if __name__ == "__main__":
|
135 |
+
args = parser.parse_args()
|
136 |
+
logging.basicConfig(stream=sys.stderr, level=args.verbose)
|
137 |
+
logger.debug(args)
|
138 |
+
enhance(args, local_out_dir=args.out_dir)
|
denoiser/evaluate.py
ADDED
@@ -0,0 +1,136 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
# author: adiyoss
|
7 |
+
|
8 |
+
import argparse
|
9 |
+
from concurrent.futures import ProcessPoolExecutor
|
10 |
+
import json
|
11 |
+
import logging
|
12 |
+
import sys
|
13 |
+
|
14 |
+
from pesq import pesq
|
15 |
+
from pystoi import stoi
|
16 |
+
import torch
|
17 |
+
|
18 |
+
from .data import NoisyCleanSet
|
19 |
+
from .enhance import add_flags, get_estimate
|
20 |
+
from . import distrib, pretrained
|
21 |
+
from .utils import bold, LogProgress
|
22 |
+
|
23 |
+
logger = logging.getLogger(__name__)
|
24 |
+
|
25 |
+
parser = argparse.ArgumentParser(
|
26 |
+
'denoiser.evaluate',
|
27 |
+
description='Speech enhancement using Demucs - Evaluate model performance')
|
28 |
+
add_flags(parser)
|
29 |
+
parser.add_argument('--data_dir', help='directory including noisy.json and clean.json files')
|
30 |
+
parser.add_argument('--matching', default="sort", help='set this to dns for the dns dataset.')
|
31 |
+
parser.add_argument('--no_pesq', action="store_false", dest="pesq", default=True,
|
32 |
+
help="Don't compute PESQ.")
|
33 |
+
parser.add_argument('-v', '--verbose', action='store_const', const=logging.DEBUG,
|
34 |
+
default=logging.INFO, help="More loggging")
|
35 |
+
|
36 |
+
|
37 |
+
def evaluate(args, model=None, data_loader=None):
|
38 |
+
total_pesq = 0
|
39 |
+
total_stoi = 0
|
40 |
+
total_cnt = 0
|
41 |
+
updates = 5
|
42 |
+
|
43 |
+
# Load model
|
44 |
+
if not model:
|
45 |
+
model = pretrained.get_model(args).to(args.device)
|
46 |
+
model.eval()
|
47 |
+
|
48 |
+
# Load data
|
49 |
+
if data_loader is None:
|
50 |
+
dataset = NoisyCleanSet(args.data_dir, matching=args.matching, sample_rate=args.sample_rate)
|
51 |
+
data_loader = distrib.loader(dataset, batch_size=1, num_workers=2)
|
52 |
+
pendings = []
|
53 |
+
with ProcessPoolExecutor(args.num_workers) as pool:
|
54 |
+
with torch.no_grad():
|
55 |
+
iterator = LogProgress(logger, data_loader, name="Eval estimates")
|
56 |
+
for i, data in enumerate(iterator):
|
57 |
+
# Get batch data
|
58 |
+
noisy, clean = [x.to(args.device) for x in data]
|
59 |
+
# If device is CPU, we do parallel evaluation in each CPU worker.
|
60 |
+
if args.device == 'cpu':
|
61 |
+
pendings.append(
|
62 |
+
pool.submit(_estimate_and_run_metrics, clean, model, noisy, args))
|
63 |
+
else:
|
64 |
+
estimate = get_estimate(model, noisy, args)
|
65 |
+
estimate = estimate.cpu()
|
66 |
+
clean = clean.cpu()
|
67 |
+
pendings.append(
|
68 |
+
pool.submit(_run_metrics, clean, estimate, args))
|
69 |
+
total_cnt += clean.shape[0]
|
70 |
+
|
71 |
+
for pending in LogProgress(logger, pendings, updates, name="Eval metrics"):
|
72 |
+
pesq_i, stoi_i = pending.result()
|
73 |
+
total_pesq += pesq_i
|
74 |
+
total_stoi += stoi_i
|
75 |
+
|
76 |
+
metrics = [total_pesq, total_stoi]
|
77 |
+
pesq, stoi = distrib.average([m/total_cnt for m in metrics], total_cnt)
|
78 |
+
logger.info(bold(f'Test set performance:PESQ={pesq}, STOI={stoi}.'))
|
79 |
+
return pesq, stoi
|
80 |
+
|
81 |
+
|
82 |
+
def _estimate_and_run_metrics(clean, model, noisy, args):
|
83 |
+
estimate = get_estimate(model, noisy, args)
|
84 |
+
return _run_metrics(clean, estimate, args)
|
85 |
+
|
86 |
+
|
87 |
+
def _run_metrics(clean, estimate, args):
|
88 |
+
estimate = estimate.numpy()[:, 0]
|
89 |
+
clean = clean.numpy()[:, 0]
|
90 |
+
if args.pesq:
|
91 |
+
pesq_i = get_pesq(clean, estimate, sr=args.sample_rate)
|
92 |
+
else:
|
93 |
+
pesq_i = 0
|
94 |
+
stoi_i = get_stoi(clean, estimate, sr=args.sample_rate)
|
95 |
+
return pesq_i, stoi_i
|
96 |
+
|
97 |
+
|
98 |
+
def get_pesq(ref_sig, out_sig, sr):
|
99 |
+
"""Calculate PESQ.
|
100 |
+
Args:
|
101 |
+
ref_sig: numpy.ndarray, [B, T]
|
102 |
+
out_sig: numpy.ndarray, [B, T]
|
103 |
+
Returns:
|
104 |
+
PESQ
|
105 |
+
"""
|
106 |
+
pesq_val = 0
|
107 |
+
for i in range(len(ref_sig)):
|
108 |
+
pesq_val += pesq(sr, ref_sig[i], out_sig[i], 'wb')
|
109 |
+
return pesq_val
|
110 |
+
|
111 |
+
|
112 |
+
def get_stoi(ref_sig, out_sig, sr):
|
113 |
+
"""Calculate STOI.
|
114 |
+
Args:
|
115 |
+
ref_sig: numpy.ndarray, [B, T]
|
116 |
+
out_sig: numpy.ndarray, [B, T]
|
117 |
+
Returns:
|
118 |
+
STOI
|
119 |
+
"""
|
120 |
+
stoi_val = 0
|
121 |
+
for i in range(len(ref_sig)):
|
122 |
+
stoi_val += stoi(ref_sig[i], out_sig[i], sr, extended=False)
|
123 |
+
return stoi_val
|
124 |
+
|
125 |
+
|
126 |
+
def main():
|
127 |
+
args = parser.parse_args()
|
128 |
+
logging.basicConfig(stream=sys.stderr, level=args.verbose)
|
129 |
+
logger.debug(args)
|
130 |
+
pesq, stoi = evaluate(args)
|
131 |
+
json.dump({'pesq': pesq, 'stoi': stoi}, sys.stdout)
|
132 |
+
sys.stdout.write('\n')
|
133 |
+
|
134 |
+
|
135 |
+
if __name__ == '__main__':
|
136 |
+
main()
|
denoiser/executor.py
ADDED
@@ -0,0 +1,79 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
# author: adefossez
|
7 |
+
"""
|
8 |
+
Start multiple process locally for DDP.
|
9 |
+
"""
|
10 |
+
|
11 |
+
import logging
|
12 |
+
import subprocess as sp
|
13 |
+
import sys
|
14 |
+
|
15 |
+
from hydra import utils
|
16 |
+
|
17 |
+
logger = logging.getLogger(__name__)
|
18 |
+
|
19 |
+
|
20 |
+
class ChildrenManager:
|
21 |
+
def __init__(self):
|
22 |
+
self.children = []
|
23 |
+
self.failed = False
|
24 |
+
|
25 |
+
def add(self, child):
|
26 |
+
child.rank = len(self.children)
|
27 |
+
self.children.append(child)
|
28 |
+
|
29 |
+
def __enter__(self):
|
30 |
+
return self
|
31 |
+
|
32 |
+
def __exit__(self, exc_type, exc_value, traceback):
|
33 |
+
if exc_value is not None:
|
34 |
+
logger.error("An exception happened while starting workers %r", exc_value)
|
35 |
+
self.failed = True
|
36 |
+
try:
|
37 |
+
while self.children and not self.failed:
|
38 |
+
for child in list(self.children):
|
39 |
+
try:
|
40 |
+
exitcode = child.wait(0.1)
|
41 |
+
except sp.TimeoutExpired:
|
42 |
+
continue
|
43 |
+
else:
|
44 |
+
self.children.remove(child)
|
45 |
+
if exitcode:
|
46 |
+
logger.error(f"Worker {child.rank} died, killing all workers")
|
47 |
+
self.failed = True
|
48 |
+
except KeyboardInterrupt:
|
49 |
+
logger.error("Received keyboard interrupt, trying to kill all workers.")
|
50 |
+
self.failed = True
|
51 |
+
for child in self.children:
|
52 |
+
child.terminate()
|
53 |
+
if not self.failed:
|
54 |
+
logger.info("All workers completed successfully")
|
55 |
+
|
56 |
+
|
57 |
+
def start_ddp_workers():
|
58 |
+
import torch as th
|
59 |
+
|
60 |
+
world_size = th.cuda.device_count()
|
61 |
+
if not world_size:
|
62 |
+
logger.error(
|
63 |
+
"DDP is only available on GPU. Make sure GPUs are properly configured with cuda.")
|
64 |
+
sys.exit(1)
|
65 |
+
logger.info(f"Starting {world_size} worker processes for DDP.")
|
66 |
+
with ChildrenManager() as manager:
|
67 |
+
for rank in range(world_size):
|
68 |
+
kwargs = {}
|
69 |
+
argv = list(sys.argv)
|
70 |
+
argv += [f"world_size={world_size}", f"rank={rank}"]
|
71 |
+
if rank > 0:
|
72 |
+
kwargs['stdin'] = sp.DEVNULL
|
73 |
+
kwargs['stdout'] = sp.DEVNULL
|
74 |
+
kwargs['stderr'] = sp.DEVNULL
|
75 |
+
log = utils.HydraConfig().hydra.job_logging.handlers.file.filename
|
76 |
+
log += f".{rank}"
|
77 |
+
argv.append("hydra.job_logging.handlers.file.filename=" + log)
|
78 |
+
manager.add(sp.Popen([sys.executable] + argv, cwd=utils.get_original_cwd(), **kwargs))
|
79 |
+
sys.exit(int(manager.failed))
|
denoiser/live.py
ADDED
@@ -0,0 +1,161 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
# author: adefossez
|
7 |
+
|
8 |
+
import argparse
|
9 |
+
import sys
|
10 |
+
|
11 |
+
import sounddevice as sd
|
12 |
+
import torch
|
13 |
+
|
14 |
+
from .demucs import DemucsStreamer
|
15 |
+
from .pretrained import add_model_flags, get_model
|
16 |
+
from .utils import bold
|
17 |
+
|
18 |
+
|
19 |
+
def get_parser():
|
20 |
+
parser = argparse.ArgumentParser(
|
21 |
+
"denoiser.live",
|
22 |
+
description="Performs live speech enhancement, reading audio from "
|
23 |
+
"the default mic (or interface specified by --in) and "
|
24 |
+
"writing the enhanced version to 'Soundflower (2ch)' "
|
25 |
+
"(or the interface specified by --out)."
|
26 |
+
)
|
27 |
+
parser.add_argument(
|
28 |
+
"-i", "--in", dest="in_",
|
29 |
+
help="name or index of input interface.")
|
30 |
+
parser.add_argument(
|
31 |
+
"-o", "--out", default="Soundflower (2ch)",
|
32 |
+
help="name or index of output interface.")
|
33 |
+
add_model_flags(parser)
|
34 |
+
parser.add_argument(
|
35 |
+
"--sample_rate", type=int, default=16_000,
|
36 |
+
help="Sample rate")
|
37 |
+
parser.add_argument(
|
38 |
+
"--no_compressor", action="store_false", dest="compressor",
|
39 |
+
help="Deactivate compressor on output, might lead to clipping.")
|
40 |
+
parser.add_argument(
|
41 |
+
"--device", default="cpu")
|
42 |
+
parser.add_argument(
|
43 |
+
"--dry", type=float, default=0.04,
|
44 |
+
help="Dry/wet knob, between 0 and 1. 0=maximum noise removal "
|
45 |
+
"but it might cause distortions. Default is 0.04")
|
46 |
+
parser.add_argument(
|
47 |
+
"-t", "--num_threads", type=int,
|
48 |
+
help="Number of threads. If you have DDR3 RAM, setting -t 1 can "
|
49 |
+
"improve performance.")
|
50 |
+
parser.add_argument(
|
51 |
+
"-f", "--num_frames", type=int, default=1,
|
52 |
+
help="Number of frames to process at once. Larger values increase "
|
53 |
+
"the overall lag, but will improve speed.")
|
54 |
+
return parser
|
55 |
+
|
56 |
+
|
57 |
+
def parse_audio_device(device):
|
58 |
+
if device is None:
|
59 |
+
return device
|
60 |
+
try:
|
61 |
+
return int(device)
|
62 |
+
except ValueError:
|
63 |
+
return device
|
64 |
+
|
65 |
+
|
66 |
+
def query_devices(device, kind):
|
67 |
+
try:
|
68 |
+
caps = sd.query_devices(device, kind=kind)
|
69 |
+
except ValueError:
|
70 |
+
message = bold(f"Invalid {kind} audio interface {device}.\n")
|
71 |
+
message += (
|
72 |
+
"If you are on Mac OS X, try installing Soundflower "
|
73 |
+
"(https://github.com/mattingalls/Soundflower).\n"
|
74 |
+
"You can list available interfaces with `python3 -m sounddevice` on Linux and OS X, "
|
75 |
+
"and `python.exe -m sounddevice` on Windows. You must have at least one loopback "
|
76 |
+
"audio interface to use this.")
|
77 |
+
print(message, file=sys.stderr)
|
78 |
+
sys.exit(1)
|
79 |
+
return caps
|
80 |
+
|
81 |
+
|
82 |
+
def main():
|
83 |
+
args = get_parser().parse_args()
|
84 |
+
if args.num_threads:
|
85 |
+
torch.set_num_threads(args.num_threads)
|
86 |
+
|
87 |
+
model = get_model(args).to(args.device)
|
88 |
+
model.eval()
|
89 |
+
print("Model loaded.")
|
90 |
+
streamer = DemucsStreamer(model, dry=args.dry, num_frames=args.num_frames)
|
91 |
+
|
92 |
+
device_in = parse_audio_device(args.in_)
|
93 |
+
caps = query_devices(device_in, "input")
|
94 |
+
channels_in = min(caps['max_input_channels'], 2)
|
95 |
+
stream_in = sd.InputStream(
|
96 |
+
device=device_in,
|
97 |
+
samplerate=args.sample_rate,
|
98 |
+
channels=channels_in)
|
99 |
+
|
100 |
+
device_out = parse_audio_device(args.out)
|
101 |
+
caps = query_devices(device_out, "output")
|
102 |
+
channels_out = min(caps['max_output_channels'], 2)
|
103 |
+
stream_out = sd.OutputStream(
|
104 |
+
device=device_out,
|
105 |
+
samplerate=args.sample_rate,
|
106 |
+
channels=channels_out)
|
107 |
+
|
108 |
+
stream_in.start()
|
109 |
+
stream_out.start()
|
110 |
+
first = True
|
111 |
+
current_time = 0
|
112 |
+
last_log_time = 0
|
113 |
+
last_error_time = 0
|
114 |
+
cooldown_time = 2
|
115 |
+
log_delta = 10
|
116 |
+
sr_ms = args.sample_rate / 1000
|
117 |
+
stride_ms = streamer.stride / sr_ms
|
118 |
+
print(f"Ready to process audio, total lag: {streamer.total_length / sr_ms:.1f}ms.")
|
119 |
+
while True:
|
120 |
+
try:
|
121 |
+
if current_time > last_log_time + log_delta:
|
122 |
+
last_log_time = current_time
|
123 |
+
tpf = streamer.time_per_frame * 1000
|
124 |
+
rtf = tpf / stride_ms
|
125 |
+
print(f"time per frame: {tpf:.1f}ms, ", end='')
|
126 |
+
print(f"RTF: {rtf:.1f}")
|
127 |
+
streamer.reset_time_per_frame()
|
128 |
+
|
129 |
+
length = streamer.total_length if first else streamer.stride
|
130 |
+
first = False
|
131 |
+
current_time += length / args.sample_rate
|
132 |
+
frame, overflow = stream_in.read(length)
|
133 |
+
frame = torch.from_numpy(frame).mean(dim=1).to(args.device)
|
134 |
+
with torch.no_grad():
|
135 |
+
out = streamer.feed(frame[None])[0]
|
136 |
+
if not out.numel():
|
137 |
+
continue
|
138 |
+
if args.compressor:
|
139 |
+
out = 0.99 * torch.tanh(out)
|
140 |
+
out = out[:, None].repeat(1, channels_out)
|
141 |
+
mx = out.abs().max().item()
|
142 |
+
if mx > 1:
|
143 |
+
print("Clipping!!")
|
144 |
+
out.clamp_(-1, 1)
|
145 |
+
out = out.cpu().numpy()
|
146 |
+
underflow = stream_out.write(out)
|
147 |
+
if overflow or underflow:
|
148 |
+
if current_time >= last_error_time + cooldown_time:
|
149 |
+
last_error_time = current_time
|
150 |
+
tpf = 1000 * streamer.time_per_frame
|
151 |
+
print(f"Not processing audio fast enough, time per frame is {tpf:.1f}ms "
|
152 |
+
f"(should be less than {stride_ms:.1f}ms).")
|
153 |
+
except KeyboardInterrupt:
|
154 |
+
print("Stopping")
|
155 |
+
break
|
156 |
+
stream_out.stop()
|
157 |
+
stream_in.stop()
|
158 |
+
|
159 |
+
|
160 |
+
if __name__ == "__main__":
|
161 |
+
main()
|
denoiser/pretrained.py
ADDED
@@ -0,0 +1,72 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
# author: adefossez
|
7 |
+
|
8 |
+
import logging
|
9 |
+
|
10 |
+
import torch.hub
|
11 |
+
|
12 |
+
from .demucs import Demucs
|
13 |
+
from .utils import deserialize_model
|
14 |
+
|
15 |
+
logger = logging.getLogger(__name__)
|
16 |
+
ROOT = "https://dl.fbaipublicfiles.com/adiyoss/denoiser/"
|
17 |
+
DNS_48_URL = ROOT + "dns48-11decc9d8e3f0998.th"
|
18 |
+
DNS_64_URL = ROOT + "dns64-a7761ff99a7d5bb6.th"
|
19 |
+
MASTER_64_URL = ROOT + "master64-8a5dfb4bb92753dd.th"
|
20 |
+
|
21 |
+
|
22 |
+
def _demucs(pretrained, url, **kwargs):
|
23 |
+
model = Demucs(**kwargs)
|
24 |
+
if pretrained:
|
25 |
+
state_dict = torch.hub.load_state_dict_from_url(url, map_location='cpu')
|
26 |
+
model.load_state_dict(state_dict)
|
27 |
+
return model
|
28 |
+
|
29 |
+
|
30 |
+
def dns48(pretrained=True):
|
31 |
+
return _demucs(pretrained, DNS_48_URL, hidden=48)
|
32 |
+
|
33 |
+
|
34 |
+
def dns64(pretrained=True):
|
35 |
+
return _demucs(pretrained, DNS_64_URL, hidden=64)
|
36 |
+
|
37 |
+
|
38 |
+
def master64(pretrained=True):
|
39 |
+
return _demucs(pretrained, MASTER_64_URL, hidden=64)
|
40 |
+
|
41 |
+
|
42 |
+
def add_model_flags(parser):
|
43 |
+
group = parser.add_mutually_exclusive_group(required=False)
|
44 |
+
group.add_argument("-m", "--model_path", help="Path to local trained model.")
|
45 |
+
group.add_argument("--dns48", action="store_true",
|
46 |
+
help="Use pre-trained real time H=48 model trained on DNS.")
|
47 |
+
group.add_argument("--dns64", action="store_true",
|
48 |
+
help="Use pre-trained real time H=64 model trained on DNS.")
|
49 |
+
group.add_argument("--master64", action="store_true",
|
50 |
+
help="Use pre-trained real time H=64 model trained on DNS and Valentini.")
|
51 |
+
|
52 |
+
|
53 |
+
def get_model(args):
|
54 |
+
"""
|
55 |
+
Load local model package or torchhub pre-trained model.
|
56 |
+
"""
|
57 |
+
if args.model_path:
|
58 |
+
logger.info("Loading model from %s", args.model_path)
|
59 |
+
model = Demucs(hidden=64)
|
60 |
+
pkg = torch.load(args.model_path, map_location='cpu')
|
61 |
+
model.load_state_dict(pkg)
|
62 |
+
elif args.dns64:
|
63 |
+
logger.info("Loading pre-trained real time H=64 model trained on DNS.")
|
64 |
+
model = dns64()
|
65 |
+
elif args.master64:
|
66 |
+
logger.info("Loading pre-trained real time H=64 model trained on DNS and Valentini.")
|
67 |
+
model = master64()
|
68 |
+
else:
|
69 |
+
logger.info("Loading pre-trained real time H=48 model trained on DNS.")
|
70 |
+
model = dns48()
|
71 |
+
logger.debug(model)
|
72 |
+
return model
|
denoiser/resample.py
ADDED
@@ -0,0 +1,75 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
# author: adefossez
|
7 |
+
|
8 |
+
import math
|
9 |
+
|
10 |
+
import torch as th
|
11 |
+
from torch.nn import functional as F
|
12 |
+
|
13 |
+
|
14 |
+
def sinc(t):
|
15 |
+
"""sinc.
|
16 |
+
|
17 |
+
:param t: the input tensor
|
18 |
+
"""
|
19 |
+
return th.where(t == 0, th.tensor(1., device=t.device, dtype=t.dtype), th.sin(t) / t)
|
20 |
+
|
21 |
+
|
22 |
+
def kernel_upsample2(zeros=56):
|
23 |
+
"""kernel_upsample2.
|
24 |
+
|
25 |
+
"""
|
26 |
+
win = th.hann_window(4 * zeros + 1, periodic=False)
|
27 |
+
winodd = win[1::2]
|
28 |
+
t = th.linspace(-zeros + 0.5, zeros - 0.5, 2 * zeros)
|
29 |
+
t *= math.pi
|
30 |
+
kernel = (sinc(t) * winodd).view(1, 1, -1)
|
31 |
+
return kernel
|
32 |
+
|
33 |
+
|
34 |
+
def upsample2(x, zeros=56):
|
35 |
+
"""
|
36 |
+
Upsampling the input by 2 using sinc interpolation.
|
37 |
+
Smith, Julius, and Phil Gossett. "A flexible sampling-rate conversion method."
|
38 |
+
ICASSP'84. IEEE International Conference on Acoustics, Speech, and Signal Processing.
|
39 |
+
Vol. 9. IEEE, 1984.
|
40 |
+
"""
|
41 |
+
*other, time = x.shape
|
42 |
+
kernel = kernel_upsample2(zeros).to(x)
|
43 |
+
out = F.conv1d(x.view(-1, 1, time), kernel, padding=zeros)[..., 1:].view(*other, time)
|
44 |
+
y = th.stack([x, out], dim=-1)
|
45 |
+
return y.view(*other, -1)
|
46 |
+
|
47 |
+
|
48 |
+
def kernel_downsample2(zeros=56):
|
49 |
+
"""kernel_downsample2.
|
50 |
+
|
51 |
+
"""
|
52 |
+
win = th.hann_window(4 * zeros + 1, periodic=False)
|
53 |
+
winodd = win[1::2]
|
54 |
+
t = th.linspace(-zeros + 0.5, zeros - 0.5, 2 * zeros)
|
55 |
+
t.mul_(math.pi)
|
56 |
+
kernel = (sinc(t) * winodd).view(1, 1, -1)
|
57 |
+
return kernel
|
58 |
+
|
59 |
+
|
60 |
+
def downsample2(x, zeros=56):
|
61 |
+
"""
|
62 |
+
Downsampling the input by 2 using sinc interpolation.
|
63 |
+
Smith, Julius, and Phil Gossett. "A flexible sampling-rate conversion method."
|
64 |
+
ICASSP'84. IEEE International Conference on Acoustics, Speech, and Signal Processing.
|
65 |
+
Vol. 9. IEEE, 1984.
|
66 |
+
"""
|
67 |
+
if x.shape[-1] % 2 != 0:
|
68 |
+
x = F.pad(x, (0, 1))
|
69 |
+
xeven = x[..., ::2]
|
70 |
+
xodd = x[..., 1::2]
|
71 |
+
*other, time = xodd.shape
|
72 |
+
kernel = kernel_downsample2(zeros).to(x)
|
73 |
+
out = xeven + F.conv1d(xodd.view(-1, 1, time), kernel, padding=zeros)[..., :-1].view(
|
74 |
+
*other, time)
|
75 |
+
return out.view(*other, -1).mul(0.5)
|
denoiser/solver.py
ADDED
@@ -0,0 +1,233 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
# author: adiyoss
|
7 |
+
|
8 |
+
import json
|
9 |
+
import logging
|
10 |
+
from pathlib import Path
|
11 |
+
import os
|
12 |
+
import time
|
13 |
+
|
14 |
+
import torch
|
15 |
+
import torch.nn.functional as F
|
16 |
+
|
17 |
+
from . import augment, distrib, pretrained
|
18 |
+
from .enhance import enhance
|
19 |
+
from .evaluate import evaluate
|
20 |
+
from .stft_loss import MultiResolutionSTFTLoss
|
21 |
+
from .utils import bold, copy_state, pull_metric, serialize_model, swap_state, LogProgress
|
22 |
+
|
23 |
+
logger = logging.getLogger(__name__)
|
24 |
+
|
25 |
+
|
26 |
+
class Solver(object):
|
27 |
+
def __init__(self, data, model, optimizer, args):
|
28 |
+
self.tr_loader = data['tr_loader']
|
29 |
+
self.cv_loader = data['cv_loader']
|
30 |
+
self.tt_loader = data['tt_loader']
|
31 |
+
self.model = model
|
32 |
+
self.dmodel = distrib.wrap(model)
|
33 |
+
self.optimizer = optimizer
|
34 |
+
|
35 |
+
# data augment
|
36 |
+
augments = []
|
37 |
+
if args.remix:
|
38 |
+
augments.append(augment.Remix())
|
39 |
+
if args.bandmask:
|
40 |
+
augments.append(augment.BandMask(args.bandmask, sample_rate=args.sample_rate))
|
41 |
+
if args.shift:
|
42 |
+
augments.append(augment.Shift(args.shift, args.shift_same))
|
43 |
+
if args.revecho:
|
44 |
+
augments.append(
|
45 |
+
augment.RevEcho(args.revecho))
|
46 |
+
self.augment = torch.nn.Sequential(*augments)
|
47 |
+
|
48 |
+
# Training config
|
49 |
+
self.device = args.device
|
50 |
+
self.epochs = args.epochs
|
51 |
+
|
52 |
+
# Checkpoints
|
53 |
+
self.continue_from = args.continue_from
|
54 |
+
self.eval_every = args.eval_every
|
55 |
+
self.checkpoint = args.checkpoint
|
56 |
+
if self.checkpoint:
|
57 |
+
self.checkpoint_file = Path(args.checkpoint_file)
|
58 |
+
self.best_file = Path(args.best_file)
|
59 |
+
logger.debug("Checkpoint will be saved to %s", self.checkpoint_file.resolve())
|
60 |
+
self.history_file = args.history_file
|
61 |
+
|
62 |
+
self.best_state = None
|
63 |
+
self.restart = args.restart
|
64 |
+
self.history = [] # Keep track of loss
|
65 |
+
self.samples_dir = args.samples_dir # Where to save samples
|
66 |
+
self.num_prints = args.num_prints # Number of times to log per epoch
|
67 |
+
self.args = args
|
68 |
+
self.mrstftloss = MultiResolutionSTFTLoss(factor_sc=args.stft_sc_factor,
|
69 |
+
factor_mag=args.stft_mag_factor)
|
70 |
+
self._reset()
|
71 |
+
|
72 |
+
def _serialize(self):
|
73 |
+
package = {}
|
74 |
+
package['model'] = serialize_model(self.model)
|
75 |
+
package['optimizer'] = self.optimizer.state_dict()
|
76 |
+
package['history'] = self.history
|
77 |
+
package['best_state'] = self.best_state
|
78 |
+
package['args'] = self.args
|
79 |
+
tmp_path = str(self.checkpoint_file) + ".tmp"
|
80 |
+
torch.save(package, tmp_path)
|
81 |
+
# renaming is sort of atomic on UNIX (not really true on NFS)
|
82 |
+
# but still less chances of leaving a half written checkpoint behind.
|
83 |
+
os.rename(tmp_path, self.checkpoint_file)
|
84 |
+
|
85 |
+
# Saving only the latest best model.
|
86 |
+
model = package['model']
|
87 |
+
model['state'] = self.best_state
|
88 |
+
tmp_path = str(self.best_file) + ".tmp"
|
89 |
+
torch.save(model, tmp_path)
|
90 |
+
os.rename(tmp_path, self.best_file)
|
91 |
+
|
92 |
+
def _reset(self):
|
93 |
+
"""_reset."""
|
94 |
+
load_from = None
|
95 |
+
load_best = False
|
96 |
+
keep_history = True
|
97 |
+
# Reset
|
98 |
+
if self.checkpoint and self.checkpoint_file.exists() and not self.restart:
|
99 |
+
load_from = self.checkpoint_file
|
100 |
+
elif self.continue_from:
|
101 |
+
load_from = self.continue_from
|
102 |
+
load_best = self.args.continue_best
|
103 |
+
keep_history = False
|
104 |
+
|
105 |
+
if load_from:
|
106 |
+
logger.info(f'Loading checkpoint model: {load_from}')
|
107 |
+
package = torch.load(load_from, 'cpu')
|
108 |
+
if load_best:
|
109 |
+
self.model.load_state_dict(package['best_state'])
|
110 |
+
else:
|
111 |
+
self.model.load_state_dict(package['model']['state'])
|
112 |
+
if 'optimizer' in package and not load_best:
|
113 |
+
self.optimizer.load_state_dict(package['optimizer'])
|
114 |
+
if keep_history:
|
115 |
+
self.history = package['history']
|
116 |
+
self.best_state = package['best_state']
|
117 |
+
continue_pretrained = self.args.continue_pretrained
|
118 |
+
if continue_pretrained:
|
119 |
+
logger.info("Fine tuning from pre-trained model %s", continue_pretrained)
|
120 |
+
model = getattr(pretrained, self.args.continue_pretrained)()
|
121 |
+
self.model.load_state_dict(model.state_dict())
|
122 |
+
|
123 |
+
def train(self):
|
124 |
+
# Optimizing the model
|
125 |
+
if self.history:
|
126 |
+
logger.info("Replaying metrics from previous run")
|
127 |
+
for epoch, metrics in enumerate(self.history):
|
128 |
+
info = " ".join(f"{k.capitalize()}={v:.5f}" for k, v in metrics.items())
|
129 |
+
logger.info(f"Epoch {epoch + 1}: {info}")
|
130 |
+
|
131 |
+
for epoch in range(len(self.history), self.epochs):
|
132 |
+
# Train one epoch
|
133 |
+
self.model.train()
|
134 |
+
start = time.time()
|
135 |
+
logger.info('-' * 70)
|
136 |
+
logger.info("Training...")
|
137 |
+
train_loss = self._run_one_epoch(epoch)
|
138 |
+
logger.info(
|
139 |
+
bold(f'Train Summary | End of Epoch {epoch + 1} | '
|
140 |
+
f'Time {time.time() - start:.2f}s | Train Loss {train_loss:.5f}'))
|
141 |
+
|
142 |
+
if self.cv_loader:
|
143 |
+
# Cross validation
|
144 |
+
logger.info('-' * 70)
|
145 |
+
logger.info('Cross validation...')
|
146 |
+
self.model.eval()
|
147 |
+
with torch.no_grad():
|
148 |
+
valid_loss = self._run_one_epoch(epoch, cross_valid=True)
|
149 |
+
logger.info(
|
150 |
+
bold(f'Valid Summary | End of Epoch {epoch + 1} | '
|
151 |
+
f'Time {time.time() - start:.2f}s | Valid Loss {valid_loss:.5f}'))
|
152 |
+
else:
|
153 |
+
valid_loss = 0
|
154 |
+
|
155 |
+
best_loss = min(pull_metric(self.history, 'valid') + [valid_loss])
|
156 |
+
metrics = {'train': train_loss, 'valid': valid_loss, 'best': best_loss}
|
157 |
+
# Save the best model
|
158 |
+
if valid_loss == best_loss:
|
159 |
+
logger.info(bold('New best valid loss %.4f'), valid_loss)
|
160 |
+
self.best_state = copy_state(self.model.state_dict())
|
161 |
+
|
162 |
+
# evaluate and enhance samples every 'eval_every' argument number of epochs
|
163 |
+
# also evaluate on last epoch
|
164 |
+
if (epoch + 1) % self.eval_every == 0 or epoch == self.epochs - 1:
|
165 |
+
# Evaluate on the testset
|
166 |
+
logger.info('-' * 70)
|
167 |
+
logger.info('Evaluating on the test set...')
|
168 |
+
# We switch to the best known model for testing
|
169 |
+
with swap_state(self.model, self.best_state):
|
170 |
+
pesq, stoi = evaluate(self.args, self.model, self.tt_loader)
|
171 |
+
|
172 |
+
metrics.update({'pesq': pesq, 'stoi': stoi})
|
173 |
+
|
174 |
+
# enhance some samples
|
175 |
+
logger.info('Enhance and save samples...')
|
176 |
+
enhance(self.args, self.model, self.samples_dir)
|
177 |
+
|
178 |
+
self.history.append(metrics)
|
179 |
+
info = " | ".join(f"{k.capitalize()} {v:.5f}" for k, v in metrics.items())
|
180 |
+
logger.info('-' * 70)
|
181 |
+
logger.info(bold(f"Overall Summary | Epoch {epoch + 1} | {info}"))
|
182 |
+
|
183 |
+
if distrib.rank == 0:
|
184 |
+
json.dump(self.history, open(self.history_file, "w"), indent=2)
|
185 |
+
# Save model each epoch
|
186 |
+
if self.checkpoint:
|
187 |
+
self._serialize()
|
188 |
+
logger.debug("Checkpoint saved to %s", self.checkpoint_file.resolve())
|
189 |
+
|
190 |
+
def _run_one_epoch(self, epoch, cross_valid=False):
|
191 |
+
total_loss = 0
|
192 |
+
data_loader = self.tr_loader if not cross_valid else self.cv_loader
|
193 |
+
|
194 |
+
# get a different order for distributed training, otherwise this will get ignored
|
195 |
+
data_loader.epoch = epoch
|
196 |
+
|
197 |
+
label = ["Train", "Valid"][cross_valid]
|
198 |
+
name = label + f" | Epoch {epoch + 1}"
|
199 |
+
logprog = LogProgress(logger, data_loader, updates=self.num_prints, name=name)
|
200 |
+
for i, data in enumerate(logprog):
|
201 |
+
noisy, clean = [x.to(self.device) for x in data]
|
202 |
+
if not cross_valid:
|
203 |
+
sources = torch.stack([noisy - clean, clean])
|
204 |
+
sources = self.augment(sources)
|
205 |
+
noise, clean = sources
|
206 |
+
noisy = noise + clean
|
207 |
+
estimate = self.dmodel(noisy)
|
208 |
+
# apply a loss function after each layer
|
209 |
+
with torch.autograd.set_detect_anomaly(True):
|
210 |
+
if self.args.loss == 'l1':
|
211 |
+
loss = F.l1_loss(clean, estimate)
|
212 |
+
elif self.args.loss == 'l2':
|
213 |
+
loss = F.mse_loss(clean, estimate)
|
214 |
+
elif self.args.loss == 'huber':
|
215 |
+
loss = F.smooth_l1_loss(clean, estimate)
|
216 |
+
else:
|
217 |
+
raise ValueError(f"Invalid loss {self.args.loss}")
|
218 |
+
# MultiResolution STFT loss
|
219 |
+
if self.args.stft_loss:
|
220 |
+
sc_loss, mag_loss = self.mrstftloss(estimate.squeeze(1), clean.squeeze(1))
|
221 |
+
loss += sc_loss + mag_loss
|
222 |
+
|
223 |
+
# optimize model in training mode
|
224 |
+
if not cross_valid:
|
225 |
+
self.optimizer.zero_grad()
|
226 |
+
loss.backward()
|
227 |
+
self.optimizer.step()
|
228 |
+
|
229 |
+
total_loss += loss.item()
|
230 |
+
logprog.update(loss=format(total_loss / (i + 1), ".5f"))
|
231 |
+
# Just in case, clear some memory
|
232 |
+
del loss, estimate
|
233 |
+
return distrib.average([total_loss / (i + 1)], i + 1)[0]
|
denoiser/stft_loss.py
ADDED
@@ -0,0 +1,144 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
3 |
+
# All rights reserved.
|
4 |
+
#
|
5 |
+
# This source code is licensed under the license found in the
|
6 |
+
# LICENSE file in the root directory of this source tree.
|
7 |
+
|
8 |
+
# Original copyright 2019 Tomoki Hayashi
|
9 |
+
# MIT License (https://opensource.org/licenses/MIT)
|
10 |
+
|
11 |
+
"""STFT-based Loss modules."""
|
12 |
+
|
13 |
+
import torch
|
14 |
+
import torch.nn.functional as F
|
15 |
+
|
16 |
+
|
17 |
+
def stft(x, fft_size, hop_size, win_length, window):
|
18 |
+
"""Perform STFT and convert to magnitude spectrogram.
|
19 |
+
Args:
|
20 |
+
x (Tensor): Input signal tensor (B, T).
|
21 |
+
fft_size (int): FFT size.
|
22 |
+
hop_size (int): Hop size.
|
23 |
+
win_length (int): Window length.
|
24 |
+
window (str): Window function type.
|
25 |
+
Returns:
|
26 |
+
Tensor: Magnitude spectrogram (B, #frames, fft_size // 2 + 1).
|
27 |
+
"""
|
28 |
+
x_stft = torch.stft(x, fft_size, hop_size, win_length, window)
|
29 |
+
real = x_stft[..., 0]
|
30 |
+
imag = x_stft[..., 1]
|
31 |
+
|
32 |
+
# NOTE(kan-bayashi): clamp is needed to avoid nan or inf
|
33 |
+
return torch.sqrt(torch.clamp(real ** 2 + imag ** 2, min=1e-7)).transpose(2, 1)
|
34 |
+
|
35 |
+
|
36 |
+
class SpectralConvergengeLoss(torch.nn.Module):
|
37 |
+
"""Spectral convergence loss module."""
|
38 |
+
|
39 |
+
def __init__(self):
|
40 |
+
"""Initilize spectral convergence loss module."""
|
41 |
+
super(SpectralConvergengeLoss, self).__init__()
|
42 |
+
|
43 |
+
def forward(self, x_mag, y_mag):
|
44 |
+
"""Calculate forward propagation.
|
45 |
+
Args:
|
46 |
+
x_mag (Tensor): Magnitude spectrogram of predicted signal (B, #frames, #freq_bins).
|
47 |
+
y_mag (Tensor): Magnitude spectrogram of groundtruth signal (B, #frames, #freq_bins).
|
48 |
+
Returns:
|
49 |
+
Tensor: Spectral convergence loss value.
|
50 |
+
"""
|
51 |
+
return torch.norm(y_mag - x_mag, p="fro") / torch.norm(y_mag, p="fro")
|
52 |
+
|
53 |
+
|
54 |
+
class LogSTFTMagnitudeLoss(torch.nn.Module):
|
55 |
+
"""Log STFT magnitude loss module."""
|
56 |
+
|
57 |
+
def __init__(self):
|
58 |
+
"""Initilize los STFT magnitude loss module."""
|
59 |
+
super(LogSTFTMagnitudeLoss, self).__init__()
|
60 |
+
|
61 |
+
def forward(self, x_mag, y_mag):
|
62 |
+
"""Calculate forward propagation.
|
63 |
+
Args:
|
64 |
+
x_mag (Tensor): Magnitude spectrogram of predicted signal (B, #frames, #freq_bins).
|
65 |
+
y_mag (Tensor): Magnitude spectrogram of groundtruth signal (B, #frames, #freq_bins).
|
66 |
+
Returns:
|
67 |
+
Tensor: Log STFT magnitude loss value.
|
68 |
+
"""
|
69 |
+
return F.l1_loss(torch.log(y_mag), torch.log(x_mag))
|
70 |
+
|
71 |
+
|
72 |
+
class STFTLoss(torch.nn.Module):
|
73 |
+
"""STFT loss module."""
|
74 |
+
|
75 |
+
def __init__(self, fft_size=1024, shift_size=120, win_length=600, window="hann_window"):
|
76 |
+
"""Initialize STFT loss module."""
|
77 |
+
super(STFTLoss, self).__init__()
|
78 |
+
self.fft_size = fft_size
|
79 |
+
self.shift_size = shift_size
|
80 |
+
self.win_length = win_length
|
81 |
+
self.window = getattr(torch, window)(win_length)
|
82 |
+
self.spectral_convergenge_loss = SpectralConvergengeLoss()
|
83 |
+
self.log_stft_magnitude_loss = LogSTFTMagnitudeLoss()
|
84 |
+
|
85 |
+
def forward(self, x, y):
|
86 |
+
"""Calculate forward propagation.
|
87 |
+
Args:
|
88 |
+
x (Tensor): Predicted signal (B, T).
|
89 |
+
y (Tensor): Groundtruth signal (B, T).
|
90 |
+
Returns:
|
91 |
+
Tensor: Spectral convergence loss value.
|
92 |
+
Tensor: Log STFT magnitude loss value.
|
93 |
+
"""
|
94 |
+
x_mag = stft(x, self.fft_size, self.shift_size, self.win_length, self.window)
|
95 |
+
y_mag = stft(y, self.fft_size, self.shift_size, self.win_length, self.window)
|
96 |
+
sc_loss = self.spectral_convergenge_loss(x_mag, y_mag)
|
97 |
+
mag_loss = self.log_stft_magnitude_loss(x_mag, y_mag)
|
98 |
+
|
99 |
+
return sc_loss, mag_loss
|
100 |
+
|
101 |
+
|
102 |
+
class MultiResolutionSTFTLoss(torch.nn.Module):
|
103 |
+
"""Multi resolution STFT loss module."""
|
104 |
+
|
105 |
+
def __init__(self,
|
106 |
+
fft_sizes=[1024, 2048, 512],
|
107 |
+
hop_sizes=[120, 240, 50],
|
108 |
+
win_lengths=[600, 1200, 240],
|
109 |
+
window="hann_window", factor_sc=0.1, factor_mag=0.1):
|
110 |
+
"""Initialize Multi resolution STFT loss module.
|
111 |
+
Args:
|
112 |
+
fft_sizes (list): List of FFT sizes.
|
113 |
+
hop_sizes (list): List of hop sizes.
|
114 |
+
win_lengths (list): List of window lengths.
|
115 |
+
window (str): Window function type.
|
116 |
+
factor (float): a balancing factor across different losses.
|
117 |
+
"""
|
118 |
+
super(MultiResolutionSTFTLoss, self).__init__()
|
119 |
+
assert len(fft_sizes) == len(hop_sizes) == len(win_lengths)
|
120 |
+
self.stft_losses = torch.nn.ModuleList()
|
121 |
+
for fs, ss, wl in zip(fft_sizes, hop_sizes, win_lengths):
|
122 |
+
self.stft_losses += [STFTLoss(fs, ss, wl, window)]
|
123 |
+
self.factor_sc = factor_sc
|
124 |
+
self.factor_mag = factor_mag
|
125 |
+
|
126 |
+
def forward(self, x, y):
|
127 |
+
"""Calculate forward propagation.
|
128 |
+
Args:
|
129 |
+
x (Tensor): Predicted signal (B, T).
|
130 |
+
y (Tensor): Groundtruth signal (B, T).
|
131 |
+
Returns:
|
132 |
+
Tensor: Multi resolution spectral convergence loss value.
|
133 |
+
Tensor: Multi resolution log STFT magnitude loss value.
|
134 |
+
"""
|
135 |
+
sc_loss = 0.0
|
136 |
+
mag_loss = 0.0
|
137 |
+
for f in self.stft_losses:
|
138 |
+
sc_l, mag_l = f(x, y)
|
139 |
+
sc_loss += sc_l
|
140 |
+
mag_loss += mag_l
|
141 |
+
sc_loss /= len(self.stft_losses)
|
142 |
+
mag_loss /= len(self.stft_losses)
|
143 |
+
|
144 |
+
return self.factor_sc*sc_loss, self.factor_mag*mag_loss
|
denoiser/utils.py
ADDED
@@ -0,0 +1,165 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
# author: adefossez
|
7 |
+
|
8 |
+
import functools
|
9 |
+
import logging
|
10 |
+
from contextlib import contextmanager
|
11 |
+
import inspect
|
12 |
+
import time
|
13 |
+
|
14 |
+
logger = logging.getLogger(__name__)
|
15 |
+
|
16 |
+
|
17 |
+
def capture_init(init):
|
18 |
+
"""capture_init.
|
19 |
+
|
20 |
+
Decorate `__init__` with this, and you can then
|
21 |
+
recover the *args and **kwargs passed to it in `self._init_args_kwargs`
|
22 |
+
"""
|
23 |
+
@functools.wraps(init)
|
24 |
+
def __init__(self, *args, **kwargs):
|
25 |
+
self._init_args_kwargs = (args, kwargs)
|
26 |
+
init(self, *args, **kwargs)
|
27 |
+
|
28 |
+
return __init__
|
29 |
+
|
30 |
+
|
31 |
+
def deserialize_model(package, strict=False):
|
32 |
+
"""deserialize_model.
|
33 |
+
|
34 |
+
"""
|
35 |
+
klass = package['class']
|
36 |
+
if strict:
|
37 |
+
model = klass(*package['args'], **package['kwargs'])
|
38 |
+
else:
|
39 |
+
sig = inspect.signature(klass)
|
40 |
+
kw = package['kwargs']
|
41 |
+
for key in list(kw):
|
42 |
+
if key not in sig.parameters:
|
43 |
+
logger.warning("Dropping inexistant parameter %s", key)
|
44 |
+
del kw[key]
|
45 |
+
model = klass(*package['args'], **kw)
|
46 |
+
model.load_state_dict(package['state'])
|
47 |
+
return model
|
48 |
+
|
49 |
+
|
50 |
+
def copy_state(state):
|
51 |
+
return {k: v.cpu().clone() for k, v in state.items()}
|
52 |
+
|
53 |
+
|
54 |
+
def serialize_model(model):
|
55 |
+
args, kwargs = model._init_args_kwargs
|
56 |
+
state = copy_state(model.state_dict())
|
57 |
+
return {"class": model.__class__, "args": args, "kwargs": kwargs, "state": state}
|
58 |
+
|
59 |
+
|
60 |
+
@contextmanager
|
61 |
+
def swap_state(model, state):
|
62 |
+
"""
|
63 |
+
Context manager that swaps the state of a model, e.g:
|
64 |
+
|
65 |
+
# model is in old state
|
66 |
+
with swap_state(model, new_state):
|
67 |
+
# model in new state
|
68 |
+
# model back to old state
|
69 |
+
"""
|
70 |
+
old_state = copy_state(model.state_dict())
|
71 |
+
model.load_state_dict(state)
|
72 |
+
try:
|
73 |
+
yield
|
74 |
+
finally:
|
75 |
+
model.load_state_dict(old_state)
|
76 |
+
|
77 |
+
|
78 |
+
def pull_metric(history, name):
|
79 |
+
out = []
|
80 |
+
for metrics in history:
|
81 |
+
if name in metrics:
|
82 |
+
out.append(metrics[name])
|
83 |
+
return out
|
84 |
+
|
85 |
+
|
86 |
+
class LogProgress:
|
87 |
+
"""
|
88 |
+
Sort of like tqdm but using log lines and not as real time.
|
89 |
+
Args:
|
90 |
+
- logger: logger obtained from `logging.getLogger`,
|
91 |
+
- iterable: iterable object to wrap
|
92 |
+
- updates (int): number of lines that will be printed, e.g.
|
93 |
+
if `updates=5`, log every 1/5th of the total length.
|
94 |
+
- total (int): length of the iterable, in case it does not support
|
95 |
+
`len`.
|
96 |
+
- name (str): prefix to use in the log.
|
97 |
+
- level: logging level (like `logging.INFO`).
|
98 |
+
"""
|
99 |
+
def __init__(self,
|
100 |
+
logger,
|
101 |
+
iterable,
|
102 |
+
updates=5,
|
103 |
+
total=None,
|
104 |
+
name="LogProgress",
|
105 |
+
level=logging.INFO):
|
106 |
+
self.iterable = iterable
|
107 |
+
self.total = total or len(iterable)
|
108 |
+
self.updates = updates
|
109 |
+
self.name = name
|
110 |
+
self.logger = logger
|
111 |
+
self.level = level
|
112 |
+
|
113 |
+
def update(self, **infos):
|
114 |
+
self._infos = infos
|
115 |
+
|
116 |
+
def __iter__(self):
|
117 |
+
self._iterator = iter(self.iterable)
|
118 |
+
self._index = -1
|
119 |
+
self._infos = {}
|
120 |
+
self._begin = time.time()
|
121 |
+
return self
|
122 |
+
|
123 |
+
def __next__(self):
|
124 |
+
self._index += 1
|
125 |
+
try:
|
126 |
+
value = next(self._iterator)
|
127 |
+
except StopIteration:
|
128 |
+
raise
|
129 |
+
else:
|
130 |
+
return value
|
131 |
+
finally:
|
132 |
+
log_every = max(1, self.total // self.updates)
|
133 |
+
# logging is delayed by 1 it, in order to have the metrics from update
|
134 |
+
if self._index >= 1 and self._index % log_every == 0:
|
135 |
+
self._log()
|
136 |
+
|
137 |
+
def _log(self):
|
138 |
+
self._speed = (1 + self._index) / (time.time() - self._begin)
|
139 |
+
infos = " | ".join(f"{k.capitalize()} {v}" for k, v in self._infos.items())
|
140 |
+
if self._speed < 1e-4:
|
141 |
+
speed = "oo sec/it"
|
142 |
+
elif self._speed < 0.1:
|
143 |
+
speed = f"{1/self._speed:.1f} sec/it"
|
144 |
+
else:
|
145 |
+
speed = f"{self._speed:.1f} it/sec"
|
146 |
+
out = f"{self.name} | {self._index}/{self.total} | {speed}"
|
147 |
+
if infos:
|
148 |
+
out += " | " + infos
|
149 |
+
self.logger.log(self.level, out)
|
150 |
+
|
151 |
+
|
152 |
+
def colorize(text, color):
|
153 |
+
"""
|
154 |
+
Display text with some ANSI color in the terminal.
|
155 |
+
"""
|
156 |
+
code = f"\033[{color}m"
|
157 |
+
restore = "\033[0m"
|
158 |
+
return "".join([code, text, restore])
|
159 |
+
|
160 |
+
|
161 |
+
def bold(text):
|
162 |
+
"""
|
163 |
+
Display text in bold in the terminal.
|
164 |
+
"""
|
165 |
+
return colorize(text, "1")
|