DeepLearning101 committed on
Commit
109bb65
·
1 Parent(s): 2917403

Upload 17 files

Browse files
denoiser/__init__.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
denoiser/audio.py ADDED
@@ -0,0 +1,89 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ # author: adefossez
7
+
8
+ import json
9
+ from pathlib import Path
10
+ import math
11
+ import os
12
+ import sys
13
+
14
+ import torchaudio
15
+ from torch.nn import functional as F
16
+
17
+
18
def find_audio_files(path, exts=(".wav",), progress=True):
    """Recursively list the audio files under `path` together with their length.

    Args:
        path: root directory, walked with `os.walk` (symbolic links are followed).
        exts: collection of accepted file suffixes. Each file's suffix is
            lower-cased before the membership test, so entries should be
            lower case. The default is an immutable tuple so that it cannot
            be modified between calls.
        progress: if true, print the completion percentage to stderr after
            each file has been inspected.

    Returns:
        A sorted list of `(resolved_path, length)` tuples, where `length` is
        `siginfo.length // siginfo.channels` as reported by `torchaudio.info`.
    """
    audio_files = []
    for root, _folders, names in os.walk(path, followlinks=True):
        for name in names:
            candidate = Path(root) / name
            if candidate.suffix.lower() in exts:
                audio_files.append(str(candidate.resolve()))

    total = len(audio_files)
    meta = []
    for idx, file in enumerate(audio_files):
        # NOTE(review): the two-value return of `torchaudio.info` and the
        # `length` / `channels` attributes presumably belong to an older
        # torchaudio API -- confirm against the pinned version.
        siginfo, _ = torchaudio.info(file)
        length = siginfo.length // siginfo.channels
        meta.append((file, length))
        if progress:
            print(format((1 + idx) / total, " 3.1%"), end='\r', file=sys.stderr)
    meta.sort()
    return meta
34
+
35
+
36
class Audioset:
    """Indexable collection of fixed-size excerpts taken from audio files.

    Every file contributes a number of examples that depends on its length,
    on `length` and on `stride`; indexes run over the files in the order
    given by `files`.
    """

    def __init__(self, files=None, length=None, stride=None,
                 pad=True, with_path=False, sample_rate=None):
        """
        Args:
            files: list of `(file, length)` pairs. `None` is treated as an
                empty list.
            length: number of frames of each example. If `None`, each file
                yields exactly one example.
            stride: number of frames between the starts of two consecutive
                examples of a file; defaults to `length`.
            pad: if true, a trailing example shorter than `length` is kept
                (and zero-padded on load), otherwise it is dropped.
            with_path: if true, `__getitem__` returns `(audio, file)`.
            sample_rate: if not `None`, loading a file whose sample rate
                differs raises `RuntimeError`.
        """
        # Iterating over `None` would raise a TypeError, so fall back to [].
        self.files = [] if files is None else files
        self.num_examples = []
        self.length = length
        self.stride = stride or length
        self.with_path = with_path
        self.sample_rate = sample_rate
        for _file, file_length in self.files:
            if length is None:
                examples = 1
            elif file_length < length:
                examples = 1 if pad else 0
            elif pad:
                examples = int(math.ceil((file_length - self.length) / self.stride) + 1)
            else:
                examples = (file_length - self.length) // self.stride + 1
            self.num_examples.append(examples)

    def __len__(self):
        """Return the total number of examples over all files."""
        return sum(self.num_examples)

    def __getitem__(self, index):
        """Load example number `index`.

        Returns:
            The loaded audio, or `(audio, file)` when `with_path` is set.

        Raises:
            IndexError: if `index` is out of range. Without this, iterating
                over the dataset through `__getitem__` would never stop.
            RuntimeError: if `sample_rate` was given and the file has a
                different sample rate.
        """
        requested = index
        if index < 0:
            # Follow the usual Python convention for negative indexes.
            index += len(self)
        if index < 0:
            raise IndexError(f"Audioset index {requested} is out of range")

        for (file, _), examples in zip(self.files, self.num_examples):
            if index >= examples:
                index -= examples
                continue
            num_frames = 0
            offset = 0
            if self.length is not None:
                offset = self.stride * index
                num_frames = self.length
            # NOTE(review): `offset=` is presumably the keyword of an older
            # torchaudio API -- confirm against the pinned version.
            out, sr = torchaudio.load(str(file), offset=offset, num_frames=num_frames)
            if self.sample_rate is not None:
                if sr != self.sample_rate:
                    raise RuntimeError(f"Expected {file} to have sample rate of "
                                       f"{self.sample_rate}, but got {sr}")
            if num_frames:
                # Zero-pad on the right up to the requested number of frames.
                out = F.pad(out, (0, num_frames - out.shape[-1]))
            if self.with_path:
                return out, file
            return out
        raise IndexError(f"Audioset index {requested} is out of range")
83
+
84
+
85
if __name__ == "__main__":
    # Scan every directory given on the command line and write the combined
    # metadata to stdout as JSON.
    meta = []
    for folder in sys.argv[1:]:
        meta.extend(find_audio_files(folder))
    json.dump(meta, sys.stdout, indent=4)
denoiser/augment.py ADDED
@@ -0,0 +1,191 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ # author: adefossez
7
+
8
+ import random
9
+ import torch as th
10
+ from torch import nn
11
+ from torch.nn import functional as F
12
+
13
+ from . import dsp
14
+
15
+
16
class Remix(nn.Module):
    """Remix.
    Shuffles the noise signals along the batch dimension, so that clean
    speech is mixed with a different noise of the same batch.
    """

    def forward(self, sources):
        """Return `stack([shuffled_noise, clean])` for `sources = (noise, clean)`."""
        noise, clean = sources
        batch_size, *_ = noise.shape
        # Sorting uniform random draws gives a random permutation of the batch,
        # generated directly on the device of the input.
        draws = th.rand(batch_size, device=noise.device)
        order = th.argsort(draws, dim=0)
        shuffled = noise[order]
        return th.stack([shuffled, clean])
27
+
28
+
29
class RevEcho(nn.Module):
    """
    Hacky Reverb but runs on GPU without slowing down training.
    This reverb adds a succession of attenuated echoes of the input
    signal to itself. Intuitively, the delay of the first echo will happen
    after roughly 2x the radius of the room and is controlled by `first_delay`.
    Then RevEcho keeps adding echoes with the same delay and further attenuation
    until the amplitude ratio between the last and first echo is 1e-3.
    The attenuation factor and the number of echoes to add is controlled
    by RT60 (measured in seconds). RT60 is the average time to get to -60dB
    (remember volume is measured over the squared amplitude so this matches
    the 1e-3 ratio).

    At each call to RevEcho, `first_delay`, `initial` and `RT60` are
    sampled from their range. Then, to prevent this reverb from being too regular,
    the delay time is resampled uniformly within `first_delay +- 10%`,
    as controlled by the `jitter` parameter. Finally, for a denser reverb,
    multiple trains of echoes are added with different jitter noises.

    Args:
        - proba: the reverb is applied only when a uniform draw in `[0, 1)`
            is below this value; otherwise the input is returned as is.
        - initial: amplitude of the first echo as a fraction
            of the input signal. For each sample, actually sampled from
            `[0, initial]`. Larger values mean louder reverb. Physically,
            this would depend on the absorption of the room walls.
        - rt60: range of values to sample the RT60 in seconds, i.e.
            after RT60 seconds, the echo amplitude is 1e-3 of the first echo.
            The default values follow the recommendations of
            https://arxiv.org/ftp/arxiv/papers/2001/2001.08662.pdf, Section 2.4.
            Physically this would also be related to the absorption of the
            room walls and there is likely a relation between `RT60` and
            `initial`, which we ignore here.
        - first_delay: range of values to sample the first echo delay in seconds.
            The default values are equivalent to sampling a room of 3 to 10 meters.
        - repeat: how many trains of echoes with different jitters to add.
            Higher values mean a denser reverb.
        - jitter: jitter used to make each repetition of the reverb echo train
            slightly different. For instance a jitter of 0.1 means
            the delay between two echoes will be in the range `first_delay +- 10%`,
            with the jittering noise being resampled after each single echo.
        - keep_clean: fraction of the reverb of the clean speech to add back
            to the ground truth. 0 = dereverberation, 1 = no dereverberation.
        - sample_rate: sample rate of the input signals.
    """

    def __init__(self, proba=0.5, initial=0.3, rt60=(0.3, 1.3), first_delay=(0.01, 0.03),
                 repeat=3, jitter=0.1, keep_clean=0.1, sample_rate=16000):
        super().__init__()
        self.proba = proba
        self.initial = initial
        self.rt60 = rt60
        self.first_delay = first_delay
        self.repeat = repeat
        self.jitter = jitter
        self.keep_clean = keep_clean
        self.sample_rate = sample_rate

    def _reverb(self, source, initial, first_delay, rt60):
        """
        Return the reverb for a single source (a 3 dimensional tensor whose
        last dimension is time).
        """
        num_samples = source.shape[-1]
        total = th.zeros_like(source)
        for _ in range(self.repeat):
            # Fraction of the first echo amplitude that is still present.
            remaining = 1
            echo = initial * source
            while remaining > 1e-3:
                # First jitter noise, applied to the delay.
                delay_jitter = 1 + self.jitter * random.uniform(-1, 1)
                shift = 1 + int(delay_jitter * first_delay * self.sample_rate)
                shift = min(shift, num_samples)
                # Delay the echo in time by dropping its tail and padding
                # with zeros on the left.
                echo = F.pad(echo[:, :, :-shift], (shift, 0))
                total += echo

                # Second jitter noise, applied to the attenuation.
                gain_jitter = 1 + self.jitter * random.uniform(-1, 1)
                # With `d` the attenuation, we want d**(rt60 / first_delay) = 1e-3,
                # i.e. log10(d) = -3 * first_delay / rt60.
                attenuation = 10**(-3 * gain_jitter * first_delay / rt60)
                echo *= attenuation
                remaining *= attenuation
        return total

    def forward(self, wav):
        """Apply the reverb to `wav = (noise, clean)` with probability `proba`.

        The additions below are done in place on the two items of `wav`.
        """
        if random.random() >= self.proba:
            return wav
        noise, clean = wav
        # Sample the characteristics of the reverb, shared by both sources.
        initial = random.random() * self.initial
        first_delay = random.uniform(*self.first_delay)
        rt60 = random.uniform(*self.rt60)

        # The reverb of the noise is always added back to the noise.
        noise += self._reverb(noise, initial, first_delay, rt60)
        # The reverb of the clean speech is split between speech and noise.
        clean_reverb = self._reverb(clean, initial, first_delay, rt60)
        clean += self.keep_clean * clean_reverb
        noise += (1 - self.keep_clean) * clean_reverb

        return th.stack([noise, clean])
131
+
132
+
133
class BandMask(nn.Module):
    """BandMask.
    Masks bands of frequencies. Similar to Park, Daniel S., et al.
    "Specaugment: A simple data augmentation method for automatic speech recognition."
    (https://arxiv.org/pdf/1904.08779.pdf) but over the waveform.
    """

    def __init__(self, maxwidth=0.2, bands=120, sample_rate=16_000):
        """__init__.

        :param maxwidth: the maximum width to remove, as a fraction of `bands`
        :param bands: number of bands
        :param sample_rate: signal sample rate
        """
        super().__init__()
        self.maxwidth = maxwidth
        self.bands = bands
        self.sample_rate = sample_rate

    def forward(self, wav):
        """Return `wav` with one randomly chosen band of frequencies altered."""
        bands = self.bands
        # A width of 0 (e.g. a small `maxwidth`) would make the `randrange`
        # below raise ValueError on an empty range; it is treated like a
        # width of 1, for which the two cutoffs coincide.
        bandwidth = max(int(abs(self.maxwidth) * bands), 1)
        mels = dsp.mel_frequencies(bands, 40, self.sample_rate/2) / self.sample_rate
        low = random.randrange(bands)
        high = random.randrange(low, min(bands, low + bandwidth))
        filters = dsp.LowPassFilters([mels[low], mels[high]]).to(wav.device)
        below_low, below_high = filters(wav)
        # NOTE(review): assuming `dsp.LowPassFilters` returns one low-passed
        # copy of the input per cutoff, this removes the content between the
        # two cutoffs -- confirm against `dsp`.
        out = wav - below_high + below_low
        return out
163
+
164
+
165
class Shift(nn.Module):
    """Shift.
    Randomly shifts the signals in time; the output is `shift` samples shorter
    than the input.
    """

    def __init__(self, shift=8192, same=False):
        """__init__.

        :param shift: randomly shifts the signals up to a given factor
        :param same: shifts both clean and noisy files by the same factor
        """
        super().__init__()
        self.shift = shift
        self.same = same

    def forward(self, wav):
        """Shift `wav`, a tensor of shape `[sources, batch, channels, length]`."""
        sources, batch, channels, length = wav.shape
        if not self.shift > 0:
            return wav

        kept = length - self.shift
        if not self.training:
            # Outside of training simply drop the last `shift` samples.
            return wav[..., :kept]

        # One offset per batch item, and per source unless `same` is set.
        offset_shape = [1 if self.same else sources, batch, 1, 1]
        offsets = th.randint(self.shift, offset_shape, device=wav.device)
        offsets = offsets.expand(sources, -1, channels, -1)
        positions = th.arange(kept, device=wav.device)
        return wav.gather(3, positions + offsets)
denoiser/conv_demucs.py ADDED
@@ -0,0 +1,661 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ # author: adefossez
7
+
8
+ import math
9
+ import time
10
+
11
+ import torch
12
+ from torch import nn
13
+ from torch.nn import functional as F
14
+
15
+ from .resample import downsample2, upsample2
16
+ from .utils import capture_init
17
+
18
+
19
+ # class BLSTM(nn.Module):
20
+ # def __init__(self, dim, layers=2, bi=True):
21
+ # super().__init__()
22
+ # klass = nn.LSTM
23
+ # self.lstm = klass(bidirectional=bi, num_layers=layers, hidden_size=dim, input_size=dim)
24
+ # self.linear = None
25
+ # if bi:
26
+ # self.linear = nn.Linear(2 * dim, dim)
27
+
28
+ # def forward(self, x, hidden=None):
29
+ # x, hidden = self.lstm(x, hidden)
30
+ # if self.linear:
31
+ # x = self.linear(x)
32
+ # return x, hidden
33
+
34
+ EPS = 1e-8
35
class Chomp1d(nn.Module):
    """To ensure the output length is the same as the input.

    Removes the last `chomp_size` time steps of its input.
    """
    def __init__(self, chomp_size):
        super(Chomp1d, self).__init__()
        self.chomp_size = chomp_size

    def forward(self, x):
        """
        Args:
            x: [M, H, Kpad]
        Returns:
            [M, H, K]
        """
        if self.chomp_size == 0:
            # `x[:, :, :-0]` is an empty slice, so nothing to remove must be
            # handled explicitly.
            return x.contiguous()
        return x[:, :, :-self.chomp_size].contiguous()
50
+
51
def chose_norm(norm_type, channel_size):
    """Build the normalization layer named by `norm_type`.

    The input of the normalization will be (M, C, K), where M is batch size,
    C is channel size and K is sequence length. "gLN" and "cLN" select the
    global and channel-wise layer norms; any other value selects batch norm.
    """
    factories = {
        "gLN": lambda: GlobalLayerNorm(channel_size),
        "cLN": lambda: ChannelwiseLayerNorm(channel_size),
    }
    # Given input (M, C, K), nn.BatchNorm1d(C) accumulates statistics
    # along M and K, so this BN usage is right.
    fallback = lambda: nn.BatchNorm1d(channel_size)  # noqa: E731
    return factories.get(norm_type, fallback)()
63
+
64
class ChannelwiseLayerNorm(nn.Module):
    """Channel-wise Layer Normalization (cLN): statistics are computed over
    the channel dimension, separately for every time step."""
    def __init__(self, channel_size):
        super().__init__()
        shape = (1, channel_size, 1)  # [1, N, 1]
        self.gamma = nn.Parameter(torch.ones(*shape))
        self.beta = nn.Parameter(torch.zeros(*shape))
        self.reset_parameters()

    def reset_parameters(self):
        """Reset the scale to 1 and the offset to 0."""
        self.gamma.data.fill_(1)
        self.beta.data.zero_()

    def forward(self, y):
        """
        Args:
            y: [M, N, K], M is batch size, N is channel size, K is length
        Returns:
            cLN_y: [M, N, K]
        """
        mean = y.mean(dim=1, keepdim=True)  # [M, 1, K]
        var = y.var(dim=1, keepdim=True, unbiased=False)  # [M, 1, K]
        centered = y - mean
        return self.gamma * centered / torch.pow(var + EPS, 0.5) + self.beta
87
+
88
class DepthwiseSeparableConv(nn.Module):
    """Depthwise convolution followed by PReLU, normalization and a
    pointwise (1x1) convolution."""
    def __init__(self, in_channels, out_channels, kernel_size,
                 stride, padding, dilation, norm_type="gLN", causal=False):
        super().__init__()
        # `groups=in_channels` makes the convolution depthwise.
        # [M, H, K] -> [M, H, K]
        layers = [nn.Conv1d(in_channels, in_channels, kernel_size,
                            stride=stride, padding=padding,
                            dilation=dilation, groups=in_channels,
                            bias=False)]
        if causal:
            # Drop the `padding` trailing steps added by the convolution.
            layers.append(Chomp1d(padding))
        layers.append(nn.PReLU())
        layers.append(chose_norm(norm_type, in_channels))
        # [M, H, K] -> [M, B, K]
        layers.append(nn.Conv1d(in_channels, out_channels, 1, bias=False))
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """
        Args:
            x: [M, H, K]
        Returns:
            result: [M, B, K]
        """
        return self.net(x)
120
+
121
class TemporalBlock(nn.Module):
    """Residual block: 1x1 convolution, PReLU, normalization and a depthwise
    separable convolution, whose output is added to the block input."""
    def __init__(self, in_channels, out_channels, kernel_size,
                 stride, padding, dilation, norm_type="gLN", causal=False):
        super().__init__()
        self.net = nn.Sequential(
            # [M, B, K] -> [M, H, K]
            nn.Conv1d(in_channels, out_channels, 1, bias=False),
            nn.PReLU(),
            chose_norm(norm_type, out_channels),
            # [M, H, K] -> [M, B, K]
            DepthwiseSeparableConv(out_channels, in_channels, kernel_size,
                                   stride, padding, dilation, norm_type,
                                   causal),
        )

    def forward(self, x):
        """
        Args:
            x: [M, B, K]
        Returns:
            [M, B, K]
        """
        # TODO: when P = 3 here works fine, but when P = 2 maybe need to pad?
        # No ReLU on the sum: it looked better without than with.
        return self.net(x) + x
148
+
149
class GlobalLayerNorm(nn.Module):
    """Global Layer Normalization (gLN): statistics are computed jointly over
    the channel and time dimensions of every batch item."""
    def __init__(self, channel_size):
        super().__init__()
        shape = (1, channel_size, 1)  # [1, N, 1]
        self.gamma = nn.Parameter(torch.ones(*shape))
        self.beta = nn.Parameter(torch.zeros(*shape))
        self.reset_parameters()

    def reset_parameters(self):
        """Reset the scale to 1 and the offset to 0."""
        self.gamma.data.fill_(1)
        self.beta.data.zero_()

    @staticmethod
    def _global_mean(t):
        # Mean over channels, then over time.  # [M, 1, 1]
        return t.mean(dim=1, keepdim=True).mean(dim=2, keepdim=True)

    def forward(self, y):
        """
        Args:
            y: [M, N, K], M is batch size, N is channel size, K is length
        Returns:
            gLN_y: [M, N, K]
        """
        mean = self._global_mean(y)
        var = self._global_mean(torch.pow(y - mean, 2))
        return self.gamma * (y - mean) / torch.pow(var + EPS, 0.5) + self.beta
173
+
174
class TemporalConvNet(nn.Module):
    """Stack of dilated temporal blocks that estimates a mask."""
    def __init__(self, N=768, B=256, H=512, P=3, X=8, R=4, C=1, norm_type="gLN", causal=1,
                 mask_nonlinear='relu'):
        """
        Args:
            N: Number of filters in autoencoder
            B: Number of channels in bottleneck 1 x 1-conv block
            H: Number of channels in convolutional blocks
            P: Kernel size in convolutional blocks
            X: Number of convolutional blocks in each repeat
            R: Number of repeats
            C: Number of speakers
            norm_type: BN, gLN, cLN
            causal: causal or non-causal
            mask_nonlinear: use which non-linear function to generate mask
        """
        super().__init__()
        # Hyper-parameters
        self.C = C
        self.mask_nonlinear = mask_nonlinear

        def make_block(exponent):
            # The dilation doubles with every block of a repeat.
            dilation = 2**exponent
            full = (P - 1) * dilation
            padding = full if causal else full // 2
            return TemporalBlock(B, H, P, stride=1,
                                 padding=padding,
                                 dilation=dilation,
                                 norm_type=norm_type,
                                 causal=causal)

        # [M, N, K] -> [M, N, K]
        layer_norm = ChannelwiseLayerNorm(N)
        # [M, N, K] -> [M, B, K]
        bottleneck_conv1x1 = nn.Conv1d(N, B, 1, bias=False)
        # [M, B, K] -> [M, B, K]
        repeats = []
        for _ in range(R):
            blocks = [make_block(exponent) for exponent in range(X)]
            repeats.append(nn.Sequential(*blocks))
        temporal_conv_net = nn.Sequential(*repeats)
        # [M, B, K] -> [M, C*N, K]
        mask_conv1x1 = nn.Conv1d(B, C*N, 1, bias=False)
        self.network = nn.Sequential(layer_norm,
                                     bottleneck_conv1x1,
                                     temporal_conv_net,
                                     mask_conv1x1)

    def forward(self, mixture_w):
        """
        Keep this API same with TasNet
        Args:
            mixture_w: [M, N, K], M is batch size
        returns:
            est_mask: [M, C, N, K]; the speaker dimension is squeezed out
                when C is 1.
        Raises:
            ValueError: if `mask_nonlinear` is neither 'softmax' nor 'relu'.
        """
        M, N, K = mixture_w.size()
        score = self.network(mixture_w)  # [M, N, K] -> [M, C*N, K]
        score = score.view(M, self.C, N, K)  # [M, C*N, K] -> [M, C, N, K]
        if self.mask_nonlinear == 'softmax':
            return F.softmax(score, dim=1).squeeze(1)
        if self.mask_nonlinear == 'relu':
            return F.relu(score).squeeze(1)
        raise ValueError("Unsupported mask non-linear function")
241
+
242
+
243
+
244
def rescale_conv(conv, reference):
    """Divide, in place, the weight and bias of `conv` by
    `sqrt(std(weight) / reference)`."""
    weight_std = conv.weight.std().detach()
    scale = (weight_std / reference)**0.5
    for parameter in (conv.weight, conv.bias):
        # The bias is None for convolutions built with `bias=False`.
        if parameter is not None:
            parameter.data /= scale
250
+
251
+
252
def rescale_module(module, reference):
    """Apply `rescale_conv` to every 1d (transposed) convolution found in
    `module`, including `module` itself."""
    conv_types = (nn.Conv1d, nn.ConvTranspose1d)
    for layer in module.modules():
        if not isinstance(layer, conv_types):
            continue
        rescale_conv(layer, reference)
256
+
257
+
258
class Demucs(nn.Module):
    """
    Demucs speech enhancement model, using a `TemporalConvNet` separator
    between the encoder and the decoder.
    Args:
        - chin (int): number of input channels.
        - chout (int): number of output channels.
        - hidden (int): number of initial hidden channels.
        - depth (int): number of layers.
        - kernel_size (int): kernel size for each layer.
        - stride (int): stride for each layer.
        - causal (bool): stored as `self.causal`; it is not forwarded to the
            separator built by this class.
        - resample (int): amount of resampling to apply to the input/output.
            Can be one of 1, 2 or 4.
        - growth (float): number of channels is multiplied by this for every layer.
        - max_hidden (int): maximum number of channels. Can be useful to
            control the size/speed of the model.
        - normalize (bool): if true, normalize the input.
        - glu (bool): if true uses GLU instead of ReLU in 1x1 convolutions.
        - rescale (float): controls custom weight initialization.
            See https://arxiv.org/abs/1911.13254.
        - floor (float): stability flooring when normalizing.

    """
    @capture_init
    def __init__(self,
                 chin=1,
                 chout=1,
                 hidden=48,
                 depth=5,
                 kernel_size=8,
                 stride=4,
                 causal=True,
                 resample=4,
                 growth=2,
                 max_hidden=10_000,
                 normalize=True,
                 glu=True,
                 rescale=0.1,
                 floor=1e-3):

        super().__init__()
        if resample not in [1, 2, 4]:
            raise ValueError("Resample should be 1, 2 or 4.")

        self.chin = chin
        self.chout = chout
        self.hidden = hidden
        self.depth = depth
        self.kernel_size = kernel_size
        self.stride = stride
        self.causal = causal
        self.floor = floor
        self.resample = resample
        self.normalize = normalize

        self.encoder = nn.ModuleList()
        self.decoder = nn.ModuleList()
        # GLU halves the channels, hence the 1x1 convolutions output twice as many.
        if glu:
            activation, ch_scale = nn.GLU(1), 2
        else:
            activation, ch_scale = nn.ReLU(), 1

        for index in range(depth):
            self.encoder.append(nn.Sequential(
                nn.Conv1d(chin, hidden, kernel_size, stride),
                nn.ReLU(),
                nn.Conv1d(hidden, hidden * ch_scale, 1),
                activation,
            ))

            decode = [
                nn.Conv1d(hidden, ch_scale * hidden, 1),
                activation,
                nn.ConvTranspose1d(hidden, chout, kernel_size, stride),
            ]
            if index > 0:
                # Every decoder layer but the one producing the output ends with a ReLU.
                decode.append(nn.ReLU())
            # Decoder layers are stored in the order in which they are applied.
            self.decoder.insert(0, nn.Sequential(*decode))
            chout = chin = hidden
            hidden = min(int(growth * hidden), max_hidden)

        # At this point `chout` is the number of channels of the last encoder layer.
        self.separator = TemporalConvNet(N=chout)
        if rescale:
            rescale_module(self, reference=rescale)

    def valid_length(self, length):
        """
        Return the nearest valid length to use with the model so that
        there are no time steps left over in a convolution, e.g. for all
        layers, size of the input - kernel_size % stride = 0.

        If the mixture has a valid length, the estimated sources
        will have exactly the same length.
        """
        steps = math.ceil(length * self.resample)
        # Going down the encoder...
        for _ in range(self.depth):
            steps = math.ceil((steps - self.kernel_size) / self.stride) + 1
            steps = max(steps, 1)
        # ...and back up the decoder.
        for _ in range(self.depth):
            steps = (steps - 1) * self.stride + self.kernel_size
        steps = int(math.ceil(steps / self.resample))
        return int(steps)

    @property
    def total_stride(self):
        """Product of the strides of all layers, divided by `resample`."""
        return self.stride ** self.depth // self.resample

    def forward(self, mix):
        """Enhance `mix`; a 2 dimensional input gets a channel dimension added.

        The output has the same length as the input.
        """
        if mix.dim() == 2:
            mix = mix.unsqueeze(1)

        std = 1
        if self.normalize:
            mono = mix.mean(dim=1, keepdim=True)
            std = mono.std(dim=-1, keepdim=True)
            mix = mix / (self.floor + std)

        length = mix.shape[-1]
        x = F.pad(mix, (0, self.valid_length(length) - length))
        # Resampling by 2 or 4 is one or two applications of the 2x resamplers.
        doublings = {2: 1, 4: 2}.get(self.resample, 0)
        for _ in range(doublings):
            x = upsample2(x)

        skips = []
        for encode in self.encoder:
            x = encode(x)
            skips.append(x)
        x = self.separator(x)
        for decode in self.decoder:
            skip = skips.pop()
            x = x + skip[..., :x.shape[-1]]
            x = decode(x)

        for _ in range(doublings):
            x = downsample2(x)

        return std * x[..., :length]
405
+
406
+
407
def fast_conv(conv, x):
    """
    Faster convolution evaluation if either kernel size is 1
    or length of sequence is 1.

    The matrix product shortcut is only taken when it matches the
    convolution: no padding, no dilation, a single group and, for kernels of
    size 1, a stride of 1. Otherwise `conv` is applied normally.

    Args:
        conv: 1d convolution, with or without bias.
        x: input of shape `[1, channels, length]`.
    Returns:
        Tensor of shape `[1, out_channels, out_length]`.
    """
    batch, chin, length = x.shape
    chout, chin, kernel = conv.weight.shape
    assert batch == 1

    plain = (conv.groups == 1
             and tuple(conv.padding) == (0,)
             and tuple(conv.dilation) == (1,))
    if plain and kernel == 1 and tuple(conv.stride) == (1,):
        # A 1x1 convolution is a matrix product over the channels.
        weight = conv.weight.view(chout, chin)
        columns = x.reshape(chin, length)
    elif plain and length == kernel:
        # The kernel covers the whole input: a single output time step.
        weight = conv.weight.view(chout, chin * kernel)
        columns = x.reshape(chin * kernel, 1)
    else:
        return conv(x).view(batch, chout, -1)

    if conv.bias is None:
        out = torch.mm(weight, columns)
    else:
        out = torch.addmm(conv.bias.view(-1, 1), weight, columns)
    return out.view(batch, chout, -1)
426
+
427
+
428
class DemucsStreamer:
    """
    Streaming implementation for Demucs. It supports being fed with any amount
    of audio at a time. You will get back as much audio as possible at that
    point.

    NOTE(review): `_separate_frame` calls `demucs.lstm`, while the `Demucs`
    class of this file builds `self.separator` and leaves `self.lstm`
    commented out; streaming presumably fails with an AttributeError for
    that class -- confirm and decide how the separator should be streamed.

    Args:
        - demucs (Demucs): Demucs model.
        - dry (float): amount of dry (e.g. input) signal to keep. 0 is maximum
            noise removal, 1 just returns the input signal. Small values > 0
            allows to limit distortions.
        - num_frames (int): number of frames to process at once. Higher values
            will increase overall latency but improve the real time factor.
        - resample_lookahead (int): extra lookahead used for the resampling.
        - resample_buffer (int): size of the buffer of previous inputs/outputs
            kept for resampling.
    """
    def __init__(self, demucs,
                 dry=0,
                 num_frames=1,
                 resample_lookahead=64,
                 resample_buffer=256):
        # All buffers are created on the device of the first model parameter.
        device = next(iter(demucs.parameters())).device
        self.demucs = demucs
        self.lstm_state = None
        # Per-layer tensors carried over from the previous frame
        # (None until the first frame has been processed).
        self.conv_state = None
        self.dry = dry
        self.resample_lookahead = resample_lookahead
        self.resample_buffer = resample_buffer
        # Samples needed for one call of the model, each extra frame adding
        # `total_stride` samples.
        self.frame_length = demucs.valid_length(1) + demucs.total_stride * (num_frames - 1)
        # Samples that must be pending before a frame can be processed.
        self.total_length = self.frame_length + self.resample_lookahead
        # Samples consumed and produced per processed frame.
        self.stride = demucs.total_stride * num_frames
        # Previous inputs / outputs, used as left context for the resampling.
        self.resample_in = torch.zeros(demucs.chin, resample_buffer, device=device)
        self.resample_out = torch.zeros(demucs.chin, resample_buffer, device=device)

        # Timing statistics and running variance used for normalization.
        self.frames = 0
        self.total_time = 0
        self.variance = 0
        # Input received but not processed yet.
        self.pending = torch.zeros(demucs.chin, 0, device=device)

        # Reshaped copies of the bias and weight of the transposed convolution
        # of the first decoder layer; they are not referenced by the other
        # methods visible in this class.
        bias = demucs.decoder[0][2].bias
        weight = demucs.decoder[0][2].weight
        chin, chout, kernel = weight.shape
        self._bias = bias.view(-1, 1).repeat(1, kernel).view(-1, 1)
        self._weight = weight.permute(1, 2, 0).contiguous()

    def reset_time_per_frame(self):
        """Reset the statistics used by `time_per_frame`."""
        self.total_time = 0
        self.frames = 0

    @property
    def time_per_frame(self):
        """Average time spent in `feed` per processed frame, in seconds.

        Raises ZeroDivisionError if no frame has been processed yet.
        """
        return self.total_time / self.frames

    def flush(self):
        """
        Flush remaining audio by padding it with zero. Call this
        when you have no more input and want to get back the last chunk of audio.
        """
        pending_length = self.pending.shape[1]
        padding = torch.zeros(self.demucs.chin, self.total_length, device=self.pending.device)
        out = self.feed(padding)
        # Only the part corresponding to real (non padding) input is returned.
        return out[:, :pending_length]

    def feed(self, wav):
        """
        Apply the model to mix using true real time evaluation.
        Normalization is done online as is the resampling.

        Args:
            wav: tensor of shape `[demucs.chin, samples]`.
        Returns:
            Tensor with `demucs.chin` rows holding the audio that could be
            produced so far, possibly with 0 samples.
        Raises:
            ValueError: if `wav` is not two dimensional or does not have
                `demucs.chin` channels.
        """
        begin = time.time()
        demucs = self.demucs
        resample_buffer = self.resample_buffer
        stride = self.stride
        resample = demucs.resample

        if wav.dim() != 2:
            raise ValueError("input wav should be two dimensional.")
        chin, _ = wav.shape
        if chin != demucs.chin:
            raise ValueError(f"Expected {demucs.chin} channels, got {chin}")

        self.pending = torch.cat([self.pending, wav], dim=1)
        outs = []
        # Process as many frames as the pending input allows.
        while self.pending.shape[1] >= self.total_length:
            self.frames += 1
            frame = self.pending[:, :self.total_length]
            dry_signal = frame[:, :stride]
            if demucs.normalize:
                mono = frame.mean(0)
                variance = (mono**2).mean()
                # Running average of the per-frame variance over all frames so far.
                self.variance = variance / self.frames + (1 - 1 / self.frames) * self.variance
                frame = frame / (demucs.floor + math.sqrt(self.variance))
            # Prepend past input as context for the upsampling, and remember
            # the input that will precede the next frame.
            frame = torch.cat([self.resample_in, frame], dim=-1)
            self.resample_in[:] = frame[:, stride - resample_buffer:stride]

            if resample == 4:
                frame = upsample2(upsample2(frame))
            elif resample == 2:
                frame = upsample2(frame)
            frame = frame[:, resample * resample_buffer:]  # remove pre sampling buffer
            frame = frame[:, :resample * self.frame_length]  # remove extra samples after window

            out, extra = self._separate_frame(frame)
            # Past output on the left and `extra` on the right serve as
            # context for the downsampling.
            padded_out = torch.cat([self.resample_out, out, extra], 1)
            self.resample_out[:] = out[:, -resample_buffer:]
            if resample == 4:
                out = downsample2(downsample2(padded_out))
            elif resample == 2:
                out = downsample2(padded_out)
            else:
                out = padded_out

            # Drop the left context, then keep one stride worth of samples.
            out = out[:, resample_buffer // resample:]
            out = out[:, :stride]

            if demucs.normalize:
                # Undo the input normalization.
                out *= math.sqrt(self.variance)
            out = self.dry * dry_signal + (1 - self.dry) * out
            outs.append(out)
            self.pending = self.pending[:, stride:]

        self.total_time += time.time() - begin
        if outs:
            out = torch.cat(outs, 1)
        else:
            out = torch.zeros(chin, 0, device=wav.device)
        return out

    def _separate_frame(self, frame):
        """Run the model on one (upsampled) frame, reusing the per-layer
        state saved while processing the previous frame.

        Returns `(x, extra)`, both without the batch dimension; see the
        comment before the decoder loop for their meaning.
        """
        demucs = self.demucs
        skips = []
        next_state = []
        first = self.conv_state is None
        stride = self.stride * demucs.resample
        x = frame[None]
        for idx, encode in enumerate(demucs.encoder):
            # Number of new time steps per frame at the output of this layer.
            stride //= demucs.stride
            length = x.shape[2]
            if idx == demucs.depth - 1:
                # This is slightly faster for the last conv
                x = fast_conv(encode[0], x)
                x = encode[1](x)
                x = fast_conv(encode[2], x)
                x = encode[3](x)
            else:
                if not first:
                    # Reuse the outputs already computed for the previous
                    # frame and only evaluate the missing time steps.
                    prev = self.conv_state.pop(0)
                    prev = prev[..., stride:]
                    tgt = (length - demucs.kernel_size) // demucs.stride + 1
                    missing = tgt - prev.shape[-1]
                    offset = length - demucs.kernel_size - demucs.stride * (missing - 1)
                    x = x[..., offset:]
                x = encode[1](encode[0](x))
                x = fast_conv(encode[2], x)
                x = encode[3](x)
                if not first:
                    x = torch.cat([prev, x], -1)
                next_state.append(x)
            skips.append(x)

        # NOTE(review): the `Demucs` class of this file defines no `lstm`
        # attribute (see the class docstring) -- confirm.
        x = x.permute(2, 0, 1)
        x, self.lstm_state = demucs.lstm(x, self.lstm_state)
        x = x.permute(1, 2, 0)
        # In the following, x contains only correct samples, i.e. the one
        # for which each time position is covered by two window of the upper layer.
        # extra contains extra samples to the right, and is used only as a
        # better padding for the online resampling.
        extra = None
        for idx, decode in enumerate(demucs.decoder):
            skip = skips.pop(-1)
            x += skip[..., :x.shape[-1]]
            x = fast_conv(decode[0], x)
            x = decode[1](x)

            if extra is not None:
                skip = skip[..., x.shape[-1]:]
                extra += skip[..., :extra.shape[-1]]
                extra = decode[2](decode[1](decode[0](extra)))
            x = decode[2](x)
            # Save the last `stride` outputs without the bias; they are added
            # to the beginning of the output of the next frame below.
            next_state.append(x[..., -demucs.stride:] - decode[2].bias.view(-1, 1))
            if extra is None:
                extra = x[..., -demucs.stride:]
            else:
                extra[..., :demucs.stride] += next_state[-1]
            x = x[..., :-demucs.stride]

            if not first:
                prev = self.conv_state.pop(0)
                x[..., :demucs.stride] += prev
            if idx != demucs.depth - 1:
                # Every decoder layer but the last one ends with an activation.
                x = decode[3](x)
                extra = decode[3](extra)
        self.conv_state = next_state
        return x[0], extra[0]
622
+
623
+
624
def test():
    """Benchmark the streaming implementation on 4 seconds of random audio
    and print its lag, speed and relative difference with the offline model."""
    import argparse
    parser = argparse.ArgumentParser(
        "denoiser.demucs",
        description="Benchmark the streaming Demucs implementation, "
                    "as well as checking the delta with the offline implementation.")
    parser.add_argument("--resample", default=4, type=int)
    parser.add_argument("--hidden", default=48, type=int)
    parser.add_argument("--device", default="cpu")
    parser.add_argument("-t", "--num_threads", type=int)
    parser.add_argument("-f", "--num_frames", type=int, default=1)
    args = parser.parse_args()
    if args.num_threads:
        torch.set_num_threads(args.num_threads)

    sr = 16_000
    sr_ms = sr / 1000
    demucs = Demucs(hidden=args.hidden, resample=args.resample).to(args.device)
    x = torch.randn(1, sr * 4).to(args.device)
    # Offline reference output.
    out = demucs(x[None])[0]

    streamer = DemucsStreamer(demucs, num_frames=args.num_frames)
    chunks = []
    # The first chunk fills a whole frame, the following ones bring one
    # stride of new audio each.
    frame_size = streamer.total_length
    with torch.no_grad():
        while x.shape[1] > 0:
            chunk, x = x[:, :frame_size], x[:, frame_size:]
            chunks.append(streamer.feed(chunk))
            frame_size = streamer.demucs.total_stride
    chunks.append(streamer.flush())
    out_rt = torch.cat(chunks, 1)

    print(f"total lag: {streamer.total_length / sr_ms:.1f}ms, ", end='')
    print(f"stride: {streamer.stride / sr_ms:.1f}ms, ", end='')
    print(f"time per frame: {1000 * streamer.time_per_frame:.1f}ms, ", end='')
    print(f"delta: {torch.norm(out - out_rt) / torch.norm(out):.2%}, ", end='')
    print(f"RTF: {((1000 * streamer.time_per_frame) / (streamer.stride / sr_ms)):.1f}")


if __name__ == "__main__":
    test()
denoiser/data.py ADDED
@@ -0,0 +1,99 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ # author: adefossez and adiyoss
7
+
8
+ import json
9
+ import logging
10
+ import os
11
+ import re
12
+
13
+ from .audio import Audioset
14
+
15
+ logger = logging.getLogger(__name__)
16
+
17
+
18
def match_dns(noisy, clean):
    """match_dns.
    Pair noisy and clean DNS dataset entries by their `fileid_<N>.wav` suffix.

    Both lists are rewritten in place. Matched entries come first, in the
    order of the clean list, followed by the sorted entries whose path does
    not carry a file id. Matched entries are stored as `(path, size)` tuples.
    A clean file id that has no noisy counterpart raises `KeyError`.

    :param noisy: list of (path, size) for the noisy files
    :param clean: list of (path, size) for the clean files
    """
    logger.debug("Matching noisy and clean for dns dataset")
    fileid = re.compile(r'fileid_(\d+)\.wav$')

    noisy_by_id = {}
    unmatched_noisy = []
    for path, size in noisy:
        found = fileid.search(path)
        if found is not None:
            noisy_by_id[found.group(1)] = (path, size)
        else:
            # maybe we are mixing some other dataset in
            unmatched_noisy.append((path, size))
    del noisy[:]

    unmatched_clean = []
    clean_entries = list(clean)
    del clean[:]
    for path, size in clean_entries:
        found = fileid.search(path)
        if found is None:
            unmatched_clean.append((path, size))
            continue
        noisy.append(noisy_by_id[found.group(1)])
        clean.append((path, size))

    unmatched_noisy.sort()
    unmatched_clean.sort()
    clean.extend(unmatched_clean)
    noisy.extend(unmatched_noisy)
50
+
51
+
52
def match_files(noisy, clean, matching="sort"):
    """match_files.
    Reorder both lists in place so that noisy and clean entries line up.

    :param noisy: list of the noisy filenames
    :param clean: list of the clean filenames
    :param matching: "sort" sorts both lists, "dns" pairs them with `match_dns`;
        any other value raises `ValueError`
    """
    if matching == "sort":
        noisy.sort()
        clean.sort()
        return
    if matching == "dns":
        # dns dataset filenames don't match when sorted, we have to manually match them
        match_dns(noisy, clean)
        return
    raise ValueError(f"Invalid value for matching {matching}")
67
+
68
+
69
class NoisyCleanSet:
    """Dataset of aligned (noisy, clean) examples built from two json listings."""

    def __init__(self, json_dir, matching="sort", length=None, stride=None,
                 pad=True, sample_rate=None):
        """__init__.

        :param json_dir: directory containing both clean.json and noisy.json
        :param matching: matching function for the files
        :param length: maximum sequence length
        :param stride: the stride used for splitting audio sequences
        :param pad: pad the end of the sequence with zeros
        :param sample_rate: the signals sampling rate
        """
        def read_listing(name):
            with open(os.path.join(json_dir, name), 'r') as handle:
                return json.load(handle)

        noisy = read_listing('noisy.json')
        clean = read_listing('clean.json')

        match_files(noisy, clean, matching)
        audioset_options = dict(length=length, stride=stride, pad=pad, sample_rate=sample_rate)
        self.clean_set = Audioset(clean, **audioset_options)
        self.noisy_set = Audioset(noisy, **audioset_options)

        # Both sets must yield the same number of examples to be indexed in lockstep.
        assert len(self.clean_set) == len(self.noisy_set)

    def __getitem__(self, index):
        noisy_example = self.noisy_set[index]
        clean_example = self.clean_set[index]
        return noisy_example, clean_example

    def __len__(self):
        return len(self.noisy_set)
denoiser/demucs.py ADDED
@@ -0,0 +1,449 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ # author: adefossez
7
+
8
+ import math
9
+ import time
10
+
11
+ import torch as th
12
+ from torch import nn
13
+ from torch.nn import functional as F
14
+
15
+ from .resample import downsample2, upsample2
16
+ from .utils import capture_init
17
+
18
+
19
class BLSTM(nn.Module):
    """
    LSTM with `dim` input and hidden features. When `bi` is true the LSTM is
    bidirectional and a linear layer maps its `2 * dim` outputs back to `dim`.
    """
    def __init__(self, dim, layers=2, bi=True):
        super().__init__()
        self.lstm = nn.LSTM(input_size=dim, hidden_size=dim,
                            num_layers=layers, bidirectional=bi)
        self.linear = nn.Linear(2 * dim, dim) if bi else None

    def forward(self, x, hidden=None):
        """Return `(output, new_hidden_state)` for input `x` and optional state."""
        output, state = self.lstm(x, hidden)
        if self.linear is not None:
            output = self.linear(output)
        return output, state
33
+
34
+
35
def rescale_conv(conv, reference):
    """
    Divide the weight (and bias, when present) of `conv` in place by
    `sqrt(std / reference)`, where `std` is the current weight standard deviation.
    """
    weight_std = conv.weight.std().detach()
    factor = (weight_std / reference) ** 0.5
    conv.weight.data /= factor
    bias = conv.bias
    if bias is not None:
        bias.data /= factor
41
+
42
+
43
def rescale_module(module, reference):
    """Apply `rescale_conv` to every Conv1d / ConvTranspose1d found in `module`."""
    conv_types = (nn.Conv1d, nn.ConvTranspose1d)
    for layer in module.modules():
        if not isinstance(layer, conv_types):
            continue
        rescale_conv(layer, reference)
47
+
48
+
49
class Demucs(nn.Module):
    """
    Demucs speech enhancement model.
    Args:
        - chin (int): number of input channels.
        - chout (int): number of output channels.
        - hidden (int): number of initial hidden channels.
        - depth (int): number of layers.
        - kernel_size (int): kernel size for each layer.
        - stride (int): stride for each layer.
        - causal (bool): if false, uses BiLSTM instead of LSTM.
        - resample (int): amount of resampling to apply to the input/output.
            Can be one of 1, 2 or 4.
        - growth (float): number of channels is multiplied by this for every layer.
        - max_hidden (int): maximum number of channels. Can be useful to
            control the size/speed of the model.
        - normalize (bool): if true, normalize the input.
        - glu (bool): if true uses GLU instead of ReLU in 1x1 convolutions.
        - rescale (float): controls custom weight initialization.
            See https://arxiv.org/abs/1911.13254.
        - floor (float): stability flooring when normalizing.

    """
    @capture_init
    def __init__(self,
                 chin=1,
                 chout=1,
                 hidden=48,
                 depth=5,
                 kernel_size=8,
                 stride=4,
                 causal=True,
                 resample=4,
                 growth=2,
                 max_hidden=10_000,
                 normalize=True,
                 glu=True,
                 rescale=0.1,
                 floor=1e-3):

        super().__init__()
        if resample not in [1, 2, 4]:
            raise ValueError("Resample should be 1, 2 or 4.")

        self.chin = chin
        self.chout = chout
        self.hidden = hidden
        self.depth = depth
        self.kernel_size = kernel_size
        self.stride = stride
        self.causal = causal
        self.floor = floor
        self.resample = resample
        self.normalize = normalize

        self.encoder = nn.ModuleList()
        self.decoder = nn.ModuleList()
        # A single activation instance is shared by all the 1x1 convolutions.
        if glu:
            activation, ch_scale = nn.GLU(1), 2
        else:
            activation, ch_scale = nn.ReLU(), 1

        for index in range(depth):
            encoder_layer = nn.Sequential(
                nn.Conv1d(chin, hidden, kernel_size, stride),
                nn.ReLU(),
                nn.Conv1d(hidden, hidden * ch_scale, 1),
                activation,
            )
            self.encoder.append(encoder_layer)

            decoder_modules = [
                nn.Conv1d(hidden, ch_scale * hidden, 1),
                activation,
                nn.ConvTranspose1d(hidden, chout, kernel_size, stride),
            ]
            if index > 0:
                # Every decoder layer but the outermost one ends with a ReLU.
                decoder_modules.append(nn.ReLU())
            # Decoder layers are stored innermost first.
            self.decoder.insert(0, nn.Sequential(*decoder_modules))
            chout = chin = hidden
            hidden = min(int(growth * hidden), max_hidden)

        self.lstm = BLSTM(chin, bi=not causal)
        if rescale:
            rescale_module(self, reference=rescale)

    def valid_length(self, length):
        """
        Return the nearest valid length to use with the model so that
        there is no time steps left over in a convolutions, e.g. for all
        layers, size of the input - kernel_size % stride = 0.

        If the mixture has a valid length, the estimated sources
        will have exactly the same length.
        """
        steps = math.ceil(length * self.resample)
        # Go down the encoder...
        for _ in range(self.depth):
            steps = math.ceil((steps - self.kernel_size) / self.stride) + 1
            steps = max(steps, 1)
        # ... and back up through the decoder.
        for _ in range(self.depth):
            steps = (steps - 1) * self.stride + self.kernel_size
        return int(math.ceil(steps / self.resample))

    @property
    def total_stride(self):
        """Product of the strides of all layers, divided by the resampling factor."""
        return self.stride ** self.depth // self.resample

    def forward(self, mix):
        """Run the model on `mix`; a 2 dimensional input gets a channel dimension added."""
        if mix.dim() == 2:
            mix = mix.unsqueeze(1)

        if self.normalize:
            mono = mix.mean(dim=1, keepdim=True)
            std = mono.std(dim=-1, keepdim=True)
            mix = mix / (self.floor + std)
        else:
            std = 1

        length = mix.shape[-1]
        x = F.pad(mix, (0, self.valid_length(length) - length))

        # One x2 resampling pass for resample == 2, two for resample == 4.
        passes = {2: 1, 4: 2}.get(self.resample, 0)
        for _ in range(passes):
            x = upsample2(x)

        skips = []
        for encode in self.encoder:
            x = encode(x)
            skips.append(x)

        x = x.permute(2, 0, 1)
        x, _ = self.lstm(x)
        x = x.permute(1, 2, 0)

        for decode in self.decoder:
            skip = skips.pop()
            x = x + skip[..., :x.shape[-1]]
            x = decode(x)

        for _ in range(passes):
            x = downsample2(x)

        x = x[..., :length]
        return std * x
193
+
194
+
195
def fast_conv(conv, x):
    """
    Faster convolution evaluation if either kernel size is 1
    or length of sequence is 1.

    `x` must have a batch size of 1. When the kernel size is 1, or when the
    input is exactly one kernel long, the convolution is evaluated as a single
    matrix product; any other case falls back to calling `conv` directly.
    Convolutions created with `bias=False` are supported.

    NOTE(review): the matrix product shortcuts ignore stride, padding and
    dilation and read the weight as (chout, chin, kernel), so they presumably
    expect a plain `nn.Conv1d` without padding or dilation -- confirm if the
    function gets used on other layers.
    """
    batch, chin, length = x.shape
    chout, chin, kernel = conv.weight.shape
    assert batch == 1
    if kernel == 1:
        flat_weight = conv.weight.view(chout, chin)
        flat_input = x.view(chin, length)
    elif length == kernel:
        flat_weight = conv.weight.view(chout, chin * kernel)
        flat_input = x.view(chin * kernel, 1)
    else:
        return conv(x).view(batch, chout, -1)

    bias = conv.bias
    if bias is None:
        # No bias to add: a plain matrix product is all that is needed.
        out = th.mm(flat_weight, flat_input)
    else:
        out = th.addmm(bias.view(-1, 1), flat_weight, flat_input)
    return out.view(batch, chout, -1)
214
+
215
+
216
class DemucsStreamer:
    """
    Streaming implementation for Demucs. It supports being fed with any amount
    of audio at a time. You will get back as much audio as possible at that
    point.

    Args:
        - demucs (Demucs): Demucs model.
        - dry (float): amount of dry (e.g. input) signal to keep. 0 is maximum
            noise removal, 1 just returns the input signal. Small values > 0
            allows to limit distortions.
        - num_frames (int): number of frames to process at once. Higher values
            will increase overall latency but improve the real time factor.
        - resample_lookahead (int): extra lookahead used for the resampling.
        - resample_buffer (int): size of the buffer of previous inputs/outputs
            kept for resampling.
    """
    def __init__(self, demucs,
                 dry=0,
                 num_frames=1,
                 resample_lookahead=64,
                 resample_buffer=256):
        device = next(iter(demucs.parameters())).device
        self.demucs = demucs
        # Recurrent and convolutional states carried from one frame to the next.
        self.lstm_state = None
        self.conv_state = None
        self.dry = dry
        self.resample_lookahead = resample_lookahead
        self.resample_buffer = resample_buffer
        self.frame_length = demucs.valid_length(1) + demucs.total_stride * (num_frames - 1)
        self.total_length = self.frame_length + self.resample_lookahead
        self.stride = demucs.total_stride * num_frames
        # Past input / output samples used as left context by the resampling filters.
        self.resample_in = th.zeros(demucs.chin, resample_buffer, device=device)
        self.resample_out = th.zeros(demucs.chin, resample_buffer, device=device)

        self.frames = 0
        self.total_time = 0
        self.variance = 0
        # Audio received through `feed` but not yet consumed.
        self.pending = th.zeros(demucs.chin, 0, device=device)

        bias = demucs.decoder[0][2].bias
        weight = demucs.decoder[0][2].weight
        chin, chout, kernel = weight.shape
        self._bias = bias.view(-1, 1).repeat(1, kernel).view(-1, 1)
        self._weight = weight.permute(1, 2, 0).contiguous()

    def reset_time_per_frame(self):
        """Reset the counters used by `time_per_frame`."""
        self.total_time = 0
        self.frames = 0

    @property
    def time_per_frame(self):
        """
        Average time spent in `feed` per processed frame.
        Raises `ZeroDivisionError` as long as no frame has been processed.
        """
        return self.total_time / self.frames

    def flush(self):
        """
        Flush remaining audio by padding it with zero. Call this
        when you have no more input and want to get back the last chunk of audio.
        """
        pending_length = self.pending.shape[1]
        padding = th.zeros(self.demucs.chin, self.total_length, device=self.pending.device)
        out = self.feed(padding)
        # Only return the part of the output that matches actual input.
        return out[:, :pending_length]

    def feed(self, wav):
        """
        Apply the model to mix using true real time evaluation.
        Normalization is done online as is the resampling.

        `wav` must be two dimensional with `demucs.chin` rows, otherwise a
        `ValueError` is raised. Returns the audio that could be produced so
        far, which has zero columns while less than a full frame is pending.
        """
        begin = time.time()
        demucs = self.demucs
        resample_buffer = self.resample_buffer
        stride = self.stride
        resample = demucs.resample

        if wav.dim() != 2:
            raise ValueError("input wav should be two dimensional.")
        chin, _ = wav.shape
        if chin != demucs.chin:
            raise ValueError(f"Expected {demucs.chin} channels, got {chin}")

        self.pending = th.cat([self.pending, wav], dim=1)
        outs = []
        while self.pending.shape[1] >= self.total_length:
            self.frames += 1
            frame = self.pending[:, :self.total_length]
            dry_signal = frame[:, :stride]
            if demucs.normalize:
                mono = frame.mean(0)
                variance = (mono**2).mean()
                # Running average of the per frame variance.
                self.variance = variance / self.frames + (1 - 1 / self.frames) * self.variance
                frame = frame / (demucs.floor + math.sqrt(self.variance))
            padded_frame = th.cat([self.resample_in, frame], dim=-1)
            # Keep the `resample_buffer` samples that precede the start of the
            # next frame. The indices are relative to the padded frame, hence
            # the shift by `resample_buffer` compared to the raw frame.
            self.resample_in[:] = padded_frame[:, stride:stride + resample_buffer]
            frame = padded_frame

            if resample == 4:
                frame = upsample2(upsample2(frame))
            elif resample == 2:
                frame = upsample2(frame)
            frame = frame[:, resample * resample_buffer:]  # remove pre sampling buffer
            frame = frame[:, :resample * self.frame_length]  # remove extra samples after window

            out, extra = self._separate_frame(frame)
            padded_out = th.cat([self.resample_out, out, extra], 1)
            self.resample_out[:] = out[:, -resample_buffer:]
            if resample == 4:
                out = downsample2(downsample2(padded_out))
            elif resample == 2:
                out = downsample2(padded_out)
            else:
                out = padded_out

            out = out[:, resample_buffer // resample:]
            out = out[:, :stride]

            if demucs.normalize:
                # Undo the input normalization.
                out *= math.sqrt(self.variance)
            out = self.dry * dry_signal + (1 - self.dry) * out
            outs.append(out)
            self.pending = self.pending[:, stride:]

        self.total_time += time.time() - begin
        if outs:
            out = th.cat(outs, 1)
        else:
            out = th.zeros(chin, 0, device=wav.device)
        return out

    def _separate_frame(self, frame):
        """
        Run encoder, LSTM and decoder on one (resampled) frame, reusing the
        convolution outputs saved from the previous frame. Returns `(x, extra)`.
        """
        demucs = self.demucs
        skips = []
        next_state = []
        first = self.conv_state is None
        stride = self.stride * demucs.resample
        x = frame[None]
        for idx, encode in enumerate(demucs.encoder):
            stride //= demucs.stride
            length = x.shape[2]
            if idx == demucs.depth - 1:
                # This is slightly faster for the last conv
                x = fast_conv(encode[0], x)
                x = encode[1](x)
                x = fast_conv(encode[2], x)
                x = encode[3](x)
            else:
                if not first:
                    # Reuse the outputs already computed for the previous frame
                    # and only evaluate the convolution on the missing positions.
                    prev = self.conv_state.pop(0)
                    prev = prev[..., stride:]
                    tgt = (length - demucs.kernel_size) // demucs.stride + 1
                    missing = tgt - prev.shape[-1]
                    offset = length - demucs.kernel_size - demucs.stride * (missing - 1)
                    x = x[..., offset:]
                x = encode[1](encode[0](x))
                x = fast_conv(encode[2], x)
                x = encode[3](x)
                if not first:
                    x = th.cat([prev, x], -1)
                next_state.append(x)
            skips.append(x)

        x = x.permute(2, 0, 1)
        x, self.lstm_state = demucs.lstm(x, self.lstm_state)
        x = x.permute(1, 2, 0)
        # In the following, x contains only correct samples, i.e. the one
        # for which each time position is covered by two window of the upper layer.
        # extra contains extra samples to the right, and is used only as a
        # better padding for the online resampling.
        extra = None
        for idx, decode in enumerate(demucs.decoder):
            skip = skips.pop(-1)
            x += skip[..., :x.shape[-1]]
            x = fast_conv(decode[0], x)
            x = decode[1](x)

            if extra is not None:
                skip = skip[..., x.shape[-1]:]
                extra += skip[..., :extra.shape[-1]]
                extra = decode[2](decode[1](decode[0](extra)))
            x = decode[2](x)
            # Save the last `stride` outputs without the bias: they are added
            # to the beginning of the next frame output for this layer.
            next_state.append(x[..., -demucs.stride:] - decode[2].bias.view(-1, 1))
            if extra is None:
                extra = x[..., -demucs.stride:]
            else:
                extra[..., :demucs.stride] += next_state[-1]
            x = x[..., :-demucs.stride]

            if not first:
                prev = self.conv_state.pop(0)
                x[..., :demucs.stride] += prev
            if idx != demucs.depth - 1:
                x = decode[3](x)
                extra = decode[3](extra)
        self.conv_state = next_state
        return x[0], extra[0]
410
+
411
+
412
def test():
    """
    Command line benchmark: run the offline model and the streaming model
    on the same random signal, then print latency, timing and the relative
    difference between both outputs.
    """
    import argparse
    cli = argparse.ArgumentParser(
        "denoiser.demucs",
        description="Benchmark the streaming Demucs implementation, "
                    "as well as checking the delta with the offline implementation.")
    cli.add_argument("--resample", default=4, type=int)
    cli.add_argument("--hidden", default=48, type=int)
    cli.add_argument("--device", default="cpu")
    cli.add_argument("-t", "--num_threads", type=int)
    cli.add_argument("-f", "--num_frames", type=int, default=1)
    args = cli.parse_args()
    if args.num_threads:
        th.set_num_threads(args.num_threads)

    sample_rate = 16_000
    samples_per_ms = sample_rate / 1000
    demucs = Demucs(hidden=args.hidden, resample=args.resample).to(args.device)
    signal = th.randn(1, sample_rate * 4).to(args.device)
    # Reference output, computed on the whole signal at once.
    offline = demucs(signal[None])[0]

    streamer = DemucsStreamer(demucs, num_frames=args.num_frames)
    chunks = []
    # The first chunk fills a complete frame, later ones advance by one stride.
    chunk_size = streamer.total_length
    with th.no_grad():
        while signal.shape[1] > 0:
            chunks.append(streamer.feed(signal[:, :chunk_size]))
            signal = signal[:, chunk_size:]
            chunk_size = streamer.demucs.total_stride
    chunks.append(streamer.flush())
    streamed = th.cat(chunks, 1)

    print(f"total lag: {streamer.total_length / samples_per_ms:.1f}ms, ", end='')
    print(f"stride: {streamer.stride / samples_per_ms:.1f}ms, ", end='')
    print(f"time per frame: {1000 * streamer.time_per_frame:.1f}ms, ", end='')
    print(f"delta: {th.norm(offline - streamed) / th.norm(offline):.2%}, ", end='')
    print(f"RTF: {((1000 * streamer.time_per_frame) / (streamer.stride / samples_per_ms)):.1f}")
446
+
447
+
448
+ if __name__ == "__main__":
449
+ test()
denoiser/distrib.py ADDED
@@ -0,0 +1,100 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ # author: adefossez
7
+
8
+ import logging
9
+ import os
10
+
11
+ import torch
12
+ from torch.utils.data.distributed import DistributedSampler
13
+ from torch.utils.data import DataLoader, Subset
14
+ from torch.nn.parallel.distributed import DistributedDataParallel
15
+
16
+ logger = logging.getLogger(__name__)
17
+ rank = 0
18
+ world_size = 1
19
+
20
+
21
def init(args):
    """init.

    Initialize DDP using the given rendezvous file. When `args.ddp` is set the
    module level `rank` and `world_size` are taken from `args`; nothing else
    happens when the world size is 1.
    """
    global rank, world_size
    if args.ddp:
        assert not (args.rank is None or args.world_size is None)
        rank, world_size = args.rank, args.world_size
    if world_size == 1:
        return

    torch.cuda.set_device(rank)
    rendezvous = 'file://' + os.path.abspath(args.rendezvous_file)
    torch.distributed.init_process_group(
        backend=args.ddp_backend,
        init_method=rendezvous,
        world_size=world_size,
        rank=rank)
    logger.debug("Distributed rendezvous went well, rank %d/%d", rank, world_size)
40
+
41
+
42
def average(metrics, count=1.):
    """average.

    Average all the relevant metrics across processes.
    `metrics` should be a 1D float32 vector. Returns the average of `metrics`
    over all hosts. You can use `count` to control the weight of each worker.
    With a single process `metrics` is returned untouched.
    """
    if world_size == 1:
        return metrics
    # The trailing 1 accumulates the total weight once scaled by `count`.
    weighted = torch.tensor([*metrics, 1], device='cuda', dtype=torch.float32)
    weighted *= count
    torch.distributed.all_reduce(weighted, op=torch.distributed.ReduceOp.SUM)
    averaged = weighted[:-1] / weighted[-1]
    return averaged.cpu().numpy().tolist()
55
+
56
+
57
def wrap(model):
    """wrap.

    Wrap a model with DDP if distributed training is enabled,
    otherwise return the model itself.
    """
    if world_size == 1:
        return model
    device = torch.cuda.current_device()
    return DistributedDataParallel(model, device_ids=[device], output_device=device)
69
+
70
+
71
def barrier():
    """Synchronize all the processes; does nothing with a single process."""
    if world_size <= 1:
        return
    torch.distributed.barrier()
74
+
75
+
76
def loader(dataset, *args, shuffle=False, klass=DataLoader, **kwargs):
    """loader.

    Create a dataloader properly in case of distributed training.
    If a gradient is going to be computed you must set `shuffle=True`.

    With a single process this is a plain `klass(dataset, ...)`. Otherwise a
    shuffled loader relies on a `DistributedSampler`, and a non shuffled one
    receives a manual shard of the dataset (one example out of `world_size`,
    starting at `rank`). `args` and `kwargs` are forwarded in every case.

    :param dataset: the dataset to be parallelized
    :param args: relevant args for the loader
    :param shuffle: shuffle examples
    :param klass: loader class
    :param kwargs: relevant args
    """

    if world_size == 1:
        return klass(dataset, *args, shuffle=shuffle, **kwargs)

    if shuffle:
        # train means we will compute backward, we use DistributedSampler
        sampler = DistributedSampler(dataset)
        # We ignore shuffle, DistributedSampler already shuffles
        return klass(dataset, *args, **kwargs, sampler=sampler)

    # We make a manual shard, as DistributedSampler otherwise replicate some examples
    shard = Subset(dataset, list(range(rank, len(dataset), world_size)))
    # Forward kwargs too, so that batch_size, num_workers, etc. are honored.
    return klass(shard, *args, shuffle=shuffle, **kwargs)
denoiser/dsp.py ADDED
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ # author: adefossez
7
+
8
+ import numpy as np
9
+ import torch
10
+ from torch.nn import functional as F
11
+
12
+
13
def hz_to_mel(f):
    """Convert a frequency in Hz to mels as `2595 * log10(1 + f / 700)`."""
    ratio = 1 + f / 700
    return 2595 * np.log10(ratio)
15
+
16
+
17
def mel_to_hz(m):
    """Convert mels back to a frequency in Hz as `700 * (10**(m / 2595) - 1)`."""
    power = 10**(m / 2595)
    return 700 * (power - 1)
19
+
20
+
21
def mel_frequencies(n_mels, fmin, fmax):
    """Return `n_mels` frequencies in Hz, evenly spaced on the mel scale from `fmin` to `fmax`."""
    mel_low = hz_to_mel(fmin)
    mel_high = hz_to_mel(fmax)
    return mel_to_hz(np.linspace(mel_low, mel_high, n_mels))
26
+
27
+
28
class LowPassFilters(torch.nn.Module):
    """
    Bank of low pass filters.

    Args:
        cutoffs (list[float]): list of cutoff frequencies, in [0, 1] expressed as `f/f_s` where
            f_s is the samplerate.
        width (int): width of the filters (i.e. kernel_size=2 * width + 1).
            Default to `2 / min(cutoffs)`. Longer filters will have better attenuation
            but more side effects.
    Shape:
        - Input: `(*, T)`
        - Output: `(F, *, T)` with `F` the len of `cutoffs`.
    """

    def __init__(self, cutoffs: list, width: int = None):
        super().__init__()
        self.cutoffs = cutoffs
        if width is None:
            width = int(2 / min(cutoffs))
        self.width = width
        window = torch.hamming_window(2 * width + 1, periodic=False)
        t = np.arange(-width, width + 1, dtype=np.float32)
        filters = []
        for cutoff in cutoffs:
            # Windowed sinc low pass filter for this cutoff.
            sinc = torch.from_numpy(np.sinc(2 * cutoff * t))
            filters.append(2 * cutoff * sinc * window)
        # One filter per cutoff, stored with a singleton input channel for conv1d.
        self.register_buffer("filters", torch.stack(filters).unsqueeze(1))

    def forward(self, input):
        *others, t = input.shape
        # Flatten all the leading dimensions into the batch dimension.
        input = input.view(-1, 1, t)
        out = F.conv1d(input, self.filters, padding=self.width)
        # Bring the filter dimension first and restore the leading dimensions.
        return out.permute(1, 0, 2).reshape(-1, *others, t)

    def __repr__(self):
        return "LowPassFilters(width={},cutoffs={})".format(self.width, self.cutoffs)
denoiser/enhance.py ADDED
@@ -0,0 +1,138 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ # author: adiyoss
7
+
8
+ import argparse
9
+ import json
10
+ import logging
11
+ import os
12
+ import sys
13
+
14
+ import torch
15
+ import torchaudio
16
+
17
+ from .audio import Audioset, find_audio_files
18
+ from . import distrib, pretrained
19
+ from .demucs import DemucsStreamer
20
+
21
+ from .utils import LogProgress
22
+
23
+ logger = logging.getLogger(__name__)
24
+
25
+
26
def add_flags(parser):
    """
    Add the flags for the argument parser that are related to model loading and evaluation.

    :param parser: the `argparse` parser to extend; it is modified in place.
    """
    pretrained.add_model_flags(parser)
    parser.add_argument('--device', default="cpu")
    # The estimate is computed as `(1 - dry) * denoised + dry * input`, so 0
    # keeps only the denoised signal and 1 only the input signal.
    parser.add_argument('--dry', type=float, default=0,
                        help='dry/wet knob coefficient. 0 is only denoised, 1 only input signal.')
    parser.add_argument('--sample_rate', default=16_000, type=int, help='sample rate')
    parser.add_argument('--num_workers', type=int, default=10)
    parser.add_argument('--streaming', action="store_true",
                        help="true streaming evaluation for Demucs")
38
+
39
+
40
+ parser = argparse.ArgumentParser(
41
+ 'denoiser.enhance',
42
+ description="Speech enhancement using Demucs - Generate enhanced files")
43
+ add_flags(parser)
44
+ parser.add_argument("--out_dir", type=str, default="enhanced",
45
+ help="directory putting enhanced wav files")
46
+ parser.add_argument("--batch_size", default=1, type=int, help="batch size")
47
+ parser.add_argument('-v', '--verbose', action='store_const', const=logging.DEBUG,
48
+ default=logging.INFO, help="more loggging")
49
+
50
+ group = parser.add_mutually_exclusive_group()
51
+ group.add_argument("--noisy_dir", type=str, default=None,
52
+ help="directory including noisy wav files")
53
+ group.add_argument("--noisy_json", type=str, default=None,
54
+ help="json file including noisy wav files")
55
+
56
+
57
def get_estimate(model, noisy, args):
    """
    Return the enhanced version of `noisy`, mixed with the input signal
    according to `args.dry`. With `args.streaming` the first item of `noisy`
    goes through a `DemucsStreamer`, otherwise `model` is called directly.
    Torch is limited to one thread as a side effect.
    """
    torch.set_num_threads(1)
    if not args.streaming:
        with torch.no_grad():
            wet = model(noisy)
            return (1 - args.dry) * wet + args.dry * noisy

    # The streamer applies the dry/wet mix itself.
    streamer = DemucsStreamer(model, dry=args.dry)
    with torch.no_grad():
        head = streamer.feed(noisy[0])
        tail = streamer.flush()
        return torch.cat([head, tail], dim=1)[None]
70
+
71
+
72
def save_wavs(estimates, noisy_sigs, filenames, out_dir, sr=16_000):
    """
    Write every noisy signal and its estimate to `out_dir`, as
    `<name>_noisy.wav` and `<name>_enhanced.wav`, where `<name>` is the
    original file name without its extension.
    """
    for estimate, noisy, source in zip(estimates, noisy_sigs, filenames):
        stem = os.path.basename(source).rsplit(".", 1)[0]
        target = os.path.join(out_dir, stem)
        write(noisy, target + "_noisy.wav", sr=sr)
        write(estimate, target + "_enhanced.wav", sr=sr)
78
+
79
+
80
def write(wav, filename, sr=16_000):
    """Save `wav` to `filename`, scaling it down first if its peak exceeds 1."""
    # Normalize audio if it prevents clipping
    peak = wav.abs().max().item()
    wav = wav / max(peak, 1)
    torchaudio.save(filename, wav.cpu(), sr)
84
+
85
+
86
def get_dataset(args):
    """
    Build the `Audioset` of noisy files to enhance. The file list comes from
    `noisy_json` or, failing that, `noisy_dir`, both read from `args.dset` when
    it exists and from `args` otherwise. Returns None when neither is set.
    """
    paths = getattr(args, 'dset', args)
    if paths.noisy_json:
        with open(paths.noisy_json) as listing:
            files = json.load(listing)
    elif paths.noisy_dir:
        files = find_audio_files(paths.noisy_dir)
    else:
        logger.warning(
            "Small sample set was not provided by either noisy_dir or noisy_json. "
            "Skipping enhancement.")
        return None
    return Audioset(files, with_path=True, sample_rate=args.sample_rate)
102
+
103
+
104
def enhance(args, model=None, local_out_dir=None):
    """
    Enhance every file of the dataset described by `args` and save the noisy
    and enhanced signals. Files go to `local_out_dir` when given, otherwise to
    `args.out_dir`. The pretrained model is loaded when `model` is not given.
    """
    # Load model
    if not model:
        model = pretrained.get_model(args).to(args.device)
    model.eval()
    out_dir = local_out_dir or args.out_dir

    dset = get_dataset(args)
    if dset is None:
        return
    loader = distrib.loader(dset, batch_size=1)

    # Only the first worker creates the folder, the others wait for it.
    if distrib.rank == 0:
        os.makedirs(out_dir, exist_ok=True)
    distrib.barrier()

    with torch.no_grad():
        batches = LogProgress(logger, loader, name="Generate enhanced files")
        for noisy_signals, filenames in batches:
            noisy_signals = noisy_signals.to(args.device)
            # Forward
            estimate = get_estimate(model, noisy_signals, args)
            save_wavs(estimate, noisy_signals, filenames, out_dir, sr=args.sample_rate)
132
+
133
+
134
+ if __name__ == "__main__":
135
+ args = parser.parse_args()
136
+ logging.basicConfig(stream=sys.stderr, level=args.verbose)
137
+ logger.debug(args)
138
+ enhance(args, local_out_dir=args.out_dir)
denoiser/evaluate.py ADDED
@@ -0,0 +1,136 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ # author: adiyoss
7
+
8
+ import argparse
9
+ from concurrent.futures import ProcessPoolExecutor
10
+ import json
11
+ import logging
12
+ import sys
13
+
14
+ from pesq import pesq
15
+ from pystoi import stoi
16
+ import torch
17
+
18
+ from .data import NoisyCleanSet
19
+ from .enhance import add_flags, get_estimate
20
+ from . import distrib, pretrained
21
+ from .utils import bold, LogProgress
22
+
23
+ logger = logging.getLogger(__name__)
24
+
25
+ parser = argparse.ArgumentParser(
26
+ 'denoiser.evaluate',
27
+ description='Speech enhancement using Demucs - Evaluate model performance')
28
+ add_flags(parser)
29
+ parser.add_argument('--data_dir', help='directory including noisy.json and clean.json files')
30
+ parser.add_argument('--matching', default="sort", help='set this to dns for the dns dataset.')
31
+ parser.add_argument('--no_pesq', action="store_false", dest="pesq", default=True,
32
+ help="Don't compute PESQ.")
33
+ parser.add_argument('-v', '--verbose', action='store_const', const=logging.DEBUG,
34
+ default=logging.INFO, help="More loggging")
35
+
36
+
37
def evaluate(args, model=None, data_loader=None):
    """
    Compute the average PESQ and STOI of the model over a noisy/clean dataset.
    The pretrained model and the dataset described by `args` are loaded when
    `model` or `data_loader` are not provided. Returns `(pesq, stoi)`.
    """
    pesq_sum = 0
    stoi_sum = 0
    example_count = 0
    updates = 5

    # Load model
    if not model:
        model = pretrained.get_model(args).to(args.device)
    model.eval()

    # Load data
    if data_loader is None:
        dataset = NoisyCleanSet(args.data_dir, matching=args.matching, sample_rate=args.sample_rate)
        data_loader = distrib.loader(dataset, batch_size=1, num_workers=2)

    futures = []
    with ProcessPoolExecutor(args.num_workers) as pool:
        with torch.no_grad():
            batches = LogProgress(logger, data_loader, name="Eval estimates")
            for data in batches:
                # Get batch data
                noisy, clean = [x.to(args.device) for x in data]
                if args.device == 'cpu':
                    # If device is CPU, we do parallel evaluation in each CPU worker.
                    job = pool.submit(_estimate_and_run_metrics, clean, model, noisy, args)
                    futures.append(job)
                else:
                    # Otherwise only the metrics are computed in the workers.
                    estimate = get_estimate(model, noisy, args)
                    estimate = estimate.cpu()
                    clean = clean.cpu()
                    job = pool.submit(_run_metrics, clean, estimate, args)
                    futures.append(job)
                example_count += clean.shape[0]

        for job in LogProgress(logger, futures, updates, name="Eval metrics"):
            pesq_i, stoi_i = job.result()
            pesq_sum += pesq_i
            stoi_sum += stoi_i

    per_example = [total / example_count for total in (pesq_sum, stoi_sum)]
    avg_pesq, avg_stoi = distrib.average(per_example, example_count)
    logger.info(bold(f'Test set performance:PESQ={avg_pesq}, STOI={avg_stoi}.'))
    return avg_pesq, avg_stoi
80
+
81
+
82
def _estimate_and_run_metrics(clean, model, noisy, args):
    """Enhance `noisy` with `model`, then return the metrics against `clean`."""
    return _run_metrics(clean, get_estimate(model, noisy, args), args)
85
+
86
+
87
def _run_metrics(clean, estimate, args):
    """
    Return `(pesq, stoi)` for a batch, using index 0 of the second dimension
    of both tensors. PESQ is reported as 0 when `args.pesq` is false.
    """
    estimate = estimate.numpy()[:, 0]
    clean = clean.numpy()[:, 0]
    pesq_i = get_pesq(clean, estimate, sr=args.sample_rate) if args.pesq else 0
    stoi_i = get_stoi(clean, estimate, sr=args.sample_rate)
    return pesq_i, stoi_i
96
+
97
+
98
def get_pesq(ref_sig, out_sig, sr):
    """Calculate PESQ.
    Args:
        ref_sig: numpy.ndarray, [B, T]
        out_sig: numpy.ndarray, [B, T]
    Returns:
        PESQ, summed over the batch ('wb' mode)
    """
    return sum(pesq(sr, reference, out_sig[index], 'wb')
               for index, reference in enumerate(ref_sig))
110
+
111
+
112
def get_stoi(ref_sig, out_sig, sr):
    """Calculate STOI.
    Args:
        ref_sig: numpy.ndarray, [B, T]
        out_sig: numpy.ndarray, [B, T]
    Returns:
        STOI, summed over the batch (non extended variant)
    """
    return sum(stoi(reference, out_sig[index], sr, extended=False)
               for index, reference in enumerate(ref_sig))
124
+
125
+
126
def main():
    """Command line entry point: evaluate the model and print the metrics as json."""
    args = parser.parse_args()
    logging.basicConfig(stream=sys.stderr, level=args.verbose)
    logger.debug(args)
    avg_pesq, avg_stoi = evaluate(args)
    report = {'pesq': avg_pesq, 'stoi': avg_stoi}
    json.dump(report, sys.stdout)
    sys.stdout.write('\n')
133
+
134
+
135
+ if __name__ == '__main__':
136
+ main()
denoiser/executor.py ADDED
@@ -0,0 +1,79 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ # author: adefossez
7
+ """
8
+ Start multiple process locally for DDP.
9
+ """
10
+
11
+ import logging
12
+ import subprocess as sp
13
+ import sys
14
+
15
+ from hydra import utils
16
+
17
+ logger = logging.getLogger(__name__)
18
+
19
+
20
class ChildrenManager:
    """Context manager that supervises a group of worker subprocesses.

    Children are registered with `add` inside the `with` block. On exit the
    manager polls them until all have finished or one has failed, then
    terminates whatever is still running. `failed` tells the caller whether
    anything went wrong.
    """

    def __init__(self):
        self.children = []
        self.failed = False

    def add(self, child):
        """Register `child` and tag it with its rank (its registration index)."""
        child.rank = len(self.children)
        self.children.append(child)

    def __enter__(self):
        return self

    def _sweep(self):
        """Wait up to 0.1s on each live child, dropping the ones that have exited."""
        for worker in tuple(self.children):
            try:
                status = worker.wait(0.1)
            except sp.TimeoutExpired:
                # Still running, check again on the next sweep.
                continue
            self.children.remove(worker)
            if status:
                logger.error(f"Worker {worker.rank} died, killing all workers")
                self.failed = True

    def __exit__(self, exc_type, exc_value, traceback):
        if exc_value is not None:
            logger.error("An exception happened while starting workers %r", exc_value)
            self.failed = True
        try:
            while self.children and not self.failed:
                self._sweep()
        except KeyboardInterrupt:
            logger.error("Received keyboard interrupt, trying to kill all workers.")
            self.failed = True
        # Anything left at this point is either orphaned by a failure or interrupted.
        for worker in self.children:
            worker.terminate()
        if not self.failed:
            logger.info("All workers completed successfully")
55
+
56
+
57
def start_ddp_workers():
    """Spawn one copy of the current command per CUDA device and wait for them.

    Each worker is re-launched with `world_size=` and `rank=` overrides appended
    to `sys.argv`. Workers with rank > 0 have their standard streams sent to
    DEVNULL and get their own hydra log file (the configured filename with
    `.<rank>` appended). The process then exits with status 1 if any worker
    failed and 0 otherwise; it also exits with 1 right away when no GPU is found.
    """
    import torch as th

    world_size = th.cuda.device_count()
    if not world_size:
        logger.error(
            "DDP is only available on GPU. Make sure GPUs are properly configured with cuda.")
        sys.exit(1)
    logger.info(f"Starting {world_size} worker processes for DDP.")
    with ChildrenManager() as manager:
        for rank in range(world_size):
            command = [sys.executable, *sys.argv, f"world_size={world_size}", f"rank={rank}"]
            popen_options = {}
            if rank > 0:
                popen_options = dict(stdin=sp.DEVNULL, stdout=sp.DEVNULL, stderr=sp.DEVNULL)
                log_file = utils.HydraConfig().hydra.job_logging.handlers.file.filename
                log_file += f".{rank}"
                command.append("hydra.job_logging.handlers.file.filename=" + log_file)
            manager.add(sp.Popen(command, cwd=utils.get_original_cwd(), **popen_options))
    sys.exit(int(manager.failed))
denoiser/live.py ADDED
@@ -0,0 +1,161 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ # author: adefossez
7
+
8
+ import argparse
9
+ import sys
10
+
11
+ import sounddevice as sd
12
+ import torch
13
+
14
+ from .demucs import DemucsStreamer
15
+ from .pretrained import add_model_flags, get_model
16
+ from .utils import bold
17
+
18
+
19
def get_parser():
    """Build the argument parser for the `denoiser.live` command."""
    description = (
        "Performs live speech enhancement, reading audio from "
        "the default mic (or interface specified by --in) and "
        "writing the enhanced version to 'Soundflower (2ch)' "
        "(or the interface specified by --out).")
    parser = argparse.ArgumentParser("denoiser.live", description=description)
    add = parser.add_argument

    # Audio interfaces.
    add("-i", "--in", dest="in_", help="name or index of input interface.")
    add("-o", "--out", default="Soundflower (2ch)", help="name or index of output interface.")

    # Model selection flags are added here so they keep their place in --help.
    add_model_flags(parser)

    # Processing options.
    add("--sample_rate", type=int, default=16_000, help="Sample rate")
    add("--no_compressor", action="store_false", dest="compressor",
        help="Deactivate compressor on output, might lead to clipping.")
    add("--device", default="cpu")
    add("--dry", type=float, default=0.04,
        help="Dry/wet knob, between 0 and 1. 0=maximum noise removal "
             "but it might cause distortions. Default is 0.04")
    add("-t", "--num_threads", type=int,
        help="Number of threads. If you have DDR3 RAM, setting -t 1 can "
             "improve performance.")
    add("-f", "--num_frames", type=int, default=1,
        help="Number of frames to process at once. Larger values increase "
             "the overall lag, but will improve speed.")
    return parser
55
+
56
+
57
def parse_audio_device(device):
    """Normalise a command-line audio device spec.

    `None` is passed through. A value that `int()` accepts becomes an integer
    index; a value that makes `int()` raise ValueError is returned unchanged
    (it is then treated as a device name by the caller).
    """
    if device is not None:
        try:
            device = int(device)
        except ValueError:
            pass  # not a number: keep the name as given
    return device
64
+
65
+
66
def query_devices(device, kind):
    """Return sounddevice's description of `device` for the given `kind`.

    When `sd.query_devices` raises ValueError, a help message is printed to
    stderr and the process exits with status 1.
    """
    try:
        return sd.query_devices(device, kind=kind)
    except ValueError:
        hint = (
            "If you are on Mac OS X, try installing Soundflower "
            "(https://github.com/mattingalls/Soundflower).\n"
            "You can list available interfaces with `python3 -m sounddevice` on Linux and OS X, "
            "and `python.exe -m sounddevice` on Windows. You must have at least one loopback "
            "audio interface to use this.")
        message = bold(f"Invalid {kind} audio interface {device}.\n") + hint
        print(message, file=sys.stderr)
        sys.exit(1)
80
+
81
+
82
def main():
    """Run the live denoising loop until interrupted with Ctrl-C.

    Audio is read from the input interface, averaged over its channels, fed to
    a `DemucsStreamer`, and the result is written to the output interface on
    every output channel. Timing statistics are printed every `log_delta`
    seconds of processed audio, and a warning (rate limited by `cooldown_time`)
    is printed when the input overflows or the output underflows.

    Both streams are stopped when the loop ends for any reason, including an
    unexpected exception.
    """
    args = get_parser().parse_args()
    if args.num_threads:
        torch.set_num_threads(args.num_threads)

    model = get_model(args).to(args.device)
    model.eval()
    print("Model loaded.")
    streamer = DemucsStreamer(model, dry=args.dry, num_frames=args.num_frames)

    device_in = parse_audio_device(args.in_)
    caps = query_devices(device_in, "input")
    channels_in = min(caps['max_input_channels'], 2)
    stream_in = sd.InputStream(
        device=device_in,
        samplerate=args.sample_rate,
        channels=channels_in)

    device_out = parse_audio_device(args.out)
    caps = query_devices(device_out, "output")
    channels_out = min(caps['max_output_channels'], 2)
    stream_out = sd.OutputStream(
        device=device_out,
        samplerate=args.sample_rate,
        channels=channels_out)

    try:
        stream_in.start()
        stream_out.start()
        first = True
        current_time = 0      # seconds of audio requested from the input so far
        last_log_time = 0
        last_error_time = 0
        cooldown_time = 2     # minimum seconds between two "too slow" warnings
        log_delta = 10        # seconds between two timing reports
        sr_ms = args.sample_rate / 1000
        stride_ms = streamer.stride / sr_ms
        print(f"Ready to process audio, total lag: {streamer.total_length / sr_ms:.1f}ms.")
        try:
            while True:
                if current_time > last_log_time + log_delta:
                    last_log_time = current_time
                    tpf = streamer.time_per_frame * 1000
                    rtf = tpf / stride_ms
                    print(f"time per frame: {tpf:.1f}ms, ", end='')
                    print(f"RTF: {rtf:.1f}")
                    streamer.reset_time_per_frame()

                # The first read is `total_length` samples, later ones are `stride`.
                length = streamer.total_length if first else streamer.stride
                first = False
                current_time += length / args.sample_rate
                frame, overflow = stream_in.read(length)
                frame = torch.from_numpy(frame).mean(dim=1).to(args.device)
                with torch.no_grad():
                    out = streamer.feed(frame[None])[0]
                if not out.numel():
                    continue
                if args.compressor:
                    out = 0.99 * torch.tanh(out)
                out = out[:, None].repeat(1, channels_out)
                mx = out.abs().max().item()
                if mx > 1:
                    print("Clipping!!")
                out.clamp_(-1, 1)
                out = out.cpu().numpy()
                underflow = stream_out.write(out)
                if overflow or underflow:
                    if current_time >= last_error_time + cooldown_time:
                        last_error_time = current_time
                        tpf = 1000 * streamer.time_per_frame
                        print(f"Not processing audio fast enough, time per frame is {tpf:.1f}ms "
                              f"(should be less than {stride_ms:.1f}ms).")
        except KeyboardInterrupt:
            print("Stopping")
    finally:
        # Runs on normal exit and on any error, so the audio devices are not
        # left running.
        stream_out.stop()
        stream_in.stop()
158
+
159
+
160
+ if __name__ == "__main__":
161
+ main()
denoiser/pretrained.py ADDED
@@ -0,0 +1,72 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ # author: adefossez
7
+
8
+ import logging
9
+
10
+ import torch.hub
11
+
12
+ from .demucs import Demucs
13
+ from .utils import deserialize_model
14
+
15
+ logger = logging.getLogger(__name__)
16
+ ROOT = "https://dl.fbaipublicfiles.com/adiyoss/denoiser/"
17
+ DNS_48_URL = ROOT + "dns48-11decc9d8e3f0998.th"
18
+ DNS_64_URL = ROOT + "dns64-a7761ff99a7d5bb6.th"
19
+ MASTER_64_URL = ROOT + "master64-8a5dfb4bb92753dd.th"
20
+
21
+
22
def _demucs(pretrained, url, **kwargs):
    """Build a `Demucs(**kwargs)` model, optionally loading weights from `url`.

    When `pretrained` is truthy the state dict is fetched through torch.hub
    (mapped to CPU) and loaded into the model.
    """
    model = Demucs(**kwargs)
    if not pretrained:
        return model
    weights = torch.hub.load_state_dict_from_url(url, map_location='cpu')
    model.load_state_dict(weights)
    return model
28
+
29
+
30
def dns48(pretrained=True):
    """Return a Demucs model with hidden=48, using the DNS_48_URL weights if `pretrained`."""
    return _demucs(pretrained=pretrained, url=DNS_48_URL, hidden=48)


def dns64(pretrained=True):
    """Return a Demucs model with hidden=64, using the DNS_64_URL weights if `pretrained`."""
    return _demucs(pretrained=pretrained, url=DNS_64_URL, hidden=64)


def master64(pretrained=True):
    """Return a Demucs model with hidden=64, using the MASTER_64_URL weights if `pretrained`."""
    return _demucs(pretrained=pretrained, url=MASTER_64_URL, hidden=64)
40
+
41
+
42
+ def add_model_flags(parser):
43
+ group = parser.add_mutually_exclusive_group(required=False)
44
+ group.add_argument("-m", "--model_path", help="Path to local trained model.")
45
+ group.add_argument("--dns48", action="store_true",
46
+ help="Use pre-trained real time H=48 model trained on DNS.")
47
+ group.add_argument("--dns64", action="store_true",
48
+ help="Use pre-trained real time H=64 model trained on DNS.")
49
+ group.add_argument("--master64", action="store_true",
50
+ help="Use pre-trained real time H=64 model trained on DNS and Valentini.")
51
+
52
+
53
def get_model(args):
    """Return the model selected by the command-line flags.

    With `args.model_path`, the file is read with `torch.load` (mapped to CPU)
    and loaded as a state dict into a `Demucs(hidden=64)`. Otherwise a
    pre-trained model is fetched: `dns64`, then `master64`, and `dns48` when no
    flag is set.
    """
    if args.model_path:
        logger.info("Loading model from %s", args.model_path)
        model = Demucs(hidden=64)
        pkg = torch.load(args.model_path, map_location='cpu')
        model.load_state_dict(pkg)
    else:
        if args.dns64:
            message = "Loading pre-trained real time H=64 model trained on DNS."
            factory = dns64
        elif args.master64:
            message = "Loading pre-trained real time H=64 model trained on DNS and Valentini."
            factory = master64
        else:
            message = "Loading pre-trained real time H=48 model trained on DNS."
            factory = dns48
        logger.info(message)
        model = factory()
    logger.debug(model)
    return model
denoiser/resample.py ADDED
@@ -0,0 +1,75 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ # author: adefossez
7
+
8
+ import math
9
+
10
+ import torch as th
11
+ from torch.nn import functional as F
12
+
13
+
14
def sinc(t):
    """Elementwise unnormalised sinc, sin(t) / t, with the value 1 where t == 0.

    :param t: the input tensor
    """
    one = th.ones((), device=t.device, dtype=t.dtype)
    ratio = th.sin(t) / t
    return th.where(t == 0, one, ratio)
20
+
21
+
22
def kernel_upsample2(zeros=56):
    """Build the windowed-sinc filter used by `upsample2`.

    The filter has `2 * zeros` taps, located at the half-integer positions
    between `-zeros + 0.5` and `zeros - 0.5`, and is weighted by the odd-indexed
    samples of a symmetric Hann window of length `4 * zeros + 1`.
    Returns a tensor of shape (1, 1, 2 * zeros).
    """
    hann = th.hann_window(4 * zeros + 1, periodic=False)
    odd_window = hann[1::2]
    positions = th.linspace(-zeros + 0.5, zeros - 0.5, 2 * zeros)
    positions = positions * math.pi
    taps = sinc(positions) * odd_window
    return taps.view(1, 1, -1)
32
+
33
+
34
def upsample2(x, zeros=56):
    """
    Upsampling the input by 2 using sinc interpolation.
    Smith, Julius, and Phil Gossett. "A flexible sampling-rate conversion method."
    ICASSP'84. IEEE International Conference on Acoustics, Speech, and Signal Processing.
    Vol. 9. IEEE, 1984.

    The last dimension of `x` is time; the output has the same leading
    dimensions and twice as many time steps. Original samples are kept and
    interleaved with the filtered (interpolated) ones.
    """
    *batch_dims, n_samples = x.shape
    taps = kernel_upsample2(zeros).to(x)
    flat = x.view(-1, 1, n_samples)
    filtered = F.conv1d(flat, taps, padding=zeros)
    # The convolution yields n_samples + 1 values; the first one is dropped.
    midpoints = filtered[..., 1:].view(*batch_dims, n_samples)
    interleaved = th.stack([x, midpoints], dim=-1)
    return interleaved.view(*batch_dims, -1)
46
+
47
+
48
def kernel_downsample2(zeros=56):
    """Build the windowed-sinc filter used by `downsample2`.

    Same construction as the upsampling kernel: `2 * zeros` taps at half-integer
    positions, weighted by the odd-indexed samples of a symmetric Hann window of
    length `4 * zeros + 1`. Returns a tensor of shape (1, 1, 2 * zeros).
    """
    n_taps = 2 * zeros
    odd_window = th.hann_window(2 * n_taps + 1, periodic=False)[1::2]
    positions = th.linspace(-zeros + 0.5, zeros - 0.5, n_taps)
    positions.mul_(math.pi)
    return (sinc(positions) * odd_window).view(1, 1, -1)
58
+
59
+
60
def downsample2(x, zeros=56):
    """
    Downsampling the input by 2 using sinc interpolation.
    Smith, Julius, and Phil Gossett. "A flexible sampling-rate conversion method."
    ICASSP'84. IEEE International Conference on Acoustics, Speech, and Signal Processing.
    Vol. 9. IEEE, 1984.

    The last dimension of `x` is time. An odd-length input is first padded with
    one trailing zero; the output has half as many time steps as the (padded)
    input.
    """
    if x.shape[-1] % 2:
        x = F.pad(x, (0, 1))
    even_samples = x[..., ::2]
    odd_samples = x[..., 1::2]
    *batch_dims, n_samples = odd_samples.shape
    taps = kernel_downsample2(zeros).to(x)
    filtered = F.conv1d(odd_samples.view(-1, 1, n_samples), taps, padding=zeros)
    # The convolution yields n_samples + 1 values; the last one is dropped.
    filtered = filtered[..., :-1].view(*batch_dims, n_samples)
    mixed = even_samples + filtered
    return mixed.view(*batch_dims, -1).mul(0.5)
denoiser/solver.py ADDED
@@ -0,0 +1,233 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ # author: adiyoss
7
+
8
+ import json
9
+ import logging
10
+ from pathlib import Path
11
+ import os
12
+ import time
13
+
14
+ import torch
15
+ import torch.nn.functional as F
16
+
17
+ from . import augment, distrib, pretrained
18
+ from .enhance import enhance
19
+ from .evaluate import evaluate
20
+ from .stft_loss import MultiResolutionSTFTLoss
21
+ from .utils import bold, copy_state, pull_metric, serialize_model, swap_state, LogProgress
22
+
23
+ logger = logging.getLogger(__name__)
24
+
25
+
26
class Solver(object):
    """Training driver: runs epochs, validation, periodic evaluation and checkpointing.

    Args:
        data: dict with the keys 'tr_loader', 'cv_loader' and 'tt_loader'.
        model: the model to train; it is also wrapped with `distrib.wrap` for
            the forward passes.
        optimizer: optimizer stepping the model parameters.
        args: configuration object; the attributes read are visible in
            `__init__`, `_reset` and `_run_one_epoch`.
    """

    def __init__(self, data, model, optimizer, args):
        self.tr_loader = data['tr_loader']
        self.cv_loader = data['cv_loader']
        self.tt_loader = data['tt_loader']
        self.model = model
        self.dmodel = distrib.wrap(model)
        self.optimizer = optimizer

        # data augment
        augments = []
        if args.remix:
            augments.append(augment.Remix())
        if args.bandmask:
            augments.append(augment.BandMask(args.bandmask, sample_rate=args.sample_rate))
        if args.shift:
            augments.append(augment.Shift(args.shift, args.shift_same))
        if args.revecho:
            augments.append(
                augment.RevEcho(args.revecho))
        self.augment = torch.nn.Sequential(*augments)

        # Training config
        self.device = args.device
        self.epochs = args.epochs

        # Checkpoints
        self.continue_from = args.continue_from
        self.eval_every = args.eval_every
        self.checkpoint = args.checkpoint
        if self.checkpoint:
            self.checkpoint_file = Path(args.checkpoint_file)
            self.best_file = Path(args.best_file)
            logger.debug("Checkpoint will be saved to %s", self.checkpoint_file.resolve())
        self.history_file = args.history_file

        self.best_state = None
        self.restart = args.restart
        self.history = []  # Keep track of loss
        self.samples_dir = args.samples_dir  # Where to save samples
        self.num_prints = args.num_prints  # Number of times to log per epoch
        self.args = args
        self.mrstftloss = MultiResolutionSTFTLoss(factor_sc=args.stft_sc_factor,
                                                  factor_mag=args.stft_mag_factor)
        self._reset()

    def _serialize(self):
        """Write the full checkpoint and the best-model file.

        Each file is written to a `.tmp` path first and then renamed over the
        final path.
        """
        package = {}
        package['model'] = serialize_model(self.model)
        package['optimizer'] = self.optimizer.state_dict()
        package['history'] = self.history
        package['best_state'] = self.best_state
        package['args'] = self.args
        tmp_path = str(self.checkpoint_file) + ".tmp"
        torch.save(package, tmp_path)
        # renaming is sort of atomic on UNIX (not really true on NFS)
        # but still less chances of leaving a half written checkpoint behind.
        os.rename(tmp_path, self.checkpoint_file)

        # Saving only the latest best model.
        model = package['model']
        model['state'] = self.best_state
        tmp_path = str(self.best_file) + ".tmp"
        torch.save(model, tmp_path)
        os.rename(tmp_path, self.best_file)

    def _reset(self):
        """Restore state from a checkpoint and/or a pre-trained model.

        The solver's own checkpoint file takes priority (unless `restart` is
        set); otherwise `continue_from` is used, in which case the history is
        not restored. `args.continue_pretrained`, when set, names a factory in
        the `pretrained` module whose weights are loaded last.
        """
        load_from = None
        load_best = False
        keep_history = True
        # Reset
        if self.checkpoint and self.checkpoint_file.exists() and not self.restart:
            load_from = self.checkpoint_file
        elif self.continue_from:
            load_from = self.continue_from
            load_best = self.args.continue_best
            keep_history = False

        if load_from:
            logger.info(f'Loading checkpoint model: {load_from}')
            package = torch.load(load_from, 'cpu')
            if load_best:
                self.model.load_state_dict(package['best_state'])
            else:
                self.model.load_state_dict(package['model']['state'])
            if 'optimizer' in package and not load_best:
                self.optimizer.load_state_dict(package['optimizer'])
            if keep_history:
                self.history = package['history']
            self.best_state = package['best_state']
        continue_pretrained = self.args.continue_pretrained
        if continue_pretrained:
            logger.info("Fine tuning from pre-trained model %s", continue_pretrained)
            model = getattr(pretrained, self.args.continue_pretrained)()
            self.model.load_state_dict(model.state_dict())

    def train(self):
        """Run the training loop from the first epoch not yet in `history`.

        Each epoch trains, optionally cross-validates, tracks the best state,
        periodically evaluates on the test set and enhances samples, and (on
        rank 0) writes the history file and the checkpoint.
        """
        # Optimizing the model
        if self.history:
            logger.info("Replaying metrics from previous run")
        for epoch, metrics in enumerate(self.history):
            info = " ".join(f"{k.capitalize()}={v:.5f}" for k, v in metrics.items())
            logger.info(f"Epoch {epoch + 1}: {info}")

        for epoch in range(len(self.history), self.epochs):
            # Train one epoch
            self.model.train()
            start = time.time()
            logger.info('-' * 70)
            logger.info("Training...")
            train_loss = self._run_one_epoch(epoch)
            logger.info(
                bold(f'Train Summary | End of Epoch {epoch + 1} | '
                     f'Time {time.time() - start:.2f}s | Train Loss {train_loss:.5f}'))

            if self.cv_loader:
                # Cross validation
                logger.info('-' * 70)
                logger.info('Cross validation...')
                self.model.eval()
                with torch.no_grad():
                    valid_loss = self._run_one_epoch(epoch, cross_valid=True)
                logger.info(
                    bold(f'Valid Summary | End of Epoch {epoch + 1} | '
                         f'Time {time.time() - start:.2f}s | Valid Loss {valid_loss:.5f}'))
            else:
                valid_loss = 0

            best_loss = min(pull_metric(self.history, 'valid') + [valid_loss])
            metrics = {'train': train_loss, 'valid': valid_loss, 'best': best_loss}
            # Save the best model
            if valid_loss == best_loss:
                logger.info(bold('New best valid loss %.4f'), valid_loss)
                self.best_state = copy_state(self.model.state_dict())

            # evaluate and enhance samples every 'eval_every' argument number of epochs
            # also evaluate on last epoch
            if (epoch + 1) % self.eval_every == 0 or epoch == self.epochs - 1:
                # Evaluate on the testset
                logger.info('-' * 70)
                logger.info('Evaluating on the test set...')
                # We switch to the best known model for testing
                with swap_state(self.model, self.best_state):
                    pesq, stoi = evaluate(self.args, self.model, self.tt_loader)

                metrics.update({'pesq': pesq, 'stoi': stoi})

                # enhance some samples
                logger.info('Enhance and save samples...')
                enhance(self.args, self.model, self.samples_dir)

            self.history.append(metrics)
            info = " | ".join(f"{k.capitalize()} {v:.5f}" for k, v in metrics.items())
            logger.info('-' * 70)
            logger.info(bold(f"Overall Summary | Epoch {epoch + 1} | {info}"))

            if distrib.rank == 0:
                # The context manager closes (and flushes) the history file
                # right away instead of leaving the handle open.
                with open(self.history_file, "w") as history_fp:
                    json.dump(self.history, history_fp, indent=2)
                # Save model each epoch
                if self.checkpoint:
                    self._serialize()
                    logger.debug("Checkpoint saved to %s", self.checkpoint_file.resolve())

    def _run_one_epoch(self, epoch, cross_valid=False):
        """Run one pass over the training (or validation) loader.

        Returns the mean per-batch loss, averaged through `distrib.average`.
        In training mode the batch is augmented and the optimizer is stepped;
        with `cross_valid=True` neither happens.
        """
        total_loss = 0
        data_loader = self.tr_loader if not cross_valid else self.cv_loader

        # get a different order for distributed training, otherwise this will get ignored
        data_loader.epoch = epoch

        label = ["Train", "Valid"][cross_valid]
        name = label + f" | Epoch {epoch + 1}"
        logprog = LogProgress(logger, data_loader, updates=self.num_prints, name=name)
        for i, data in enumerate(logprog):
            noisy, clean = [x.to(self.device) for x in data]
            if not cross_valid:
                # Augmentations operate on the stacked (noise, clean) pair.
                sources = torch.stack([noisy - clean, clean])
                sources = self.augment(sources)
                noise, clean = sources
                noisy = noise + clean
            estimate = self.dmodel(noisy)
            # apply a loss function after each layer
            with torch.autograd.set_detect_anomaly(True):
                if self.args.loss == 'l1':
                    loss = F.l1_loss(clean, estimate)
                elif self.args.loss == 'l2':
                    loss = F.mse_loss(clean, estimate)
                elif self.args.loss == 'huber':
                    loss = F.smooth_l1_loss(clean, estimate)
                else:
                    raise ValueError(f"Invalid loss {self.args.loss}")
                # MultiResolution STFT loss
                if self.args.stft_loss:
                    sc_loss, mag_loss = self.mrstftloss(estimate.squeeze(1), clean.squeeze(1))
                    loss += sc_loss + mag_loss

                # optimize model in training mode
                if not cross_valid:
                    self.optimizer.zero_grad()
                    loss.backward()
                    self.optimizer.step()

            total_loss += loss.item()
            logprog.update(loss=format(total_loss / (i + 1), ".5f"))
            # Just in case, clear some memory
            del loss, estimate
        return distrib.average([total_loss / (i + 1)], i + 1)[0]
denoiser/stft_loss.py ADDED
@@ -0,0 +1,144 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) Facebook, Inc. and its affiliates.
3
+ # All rights reserved.
4
+ #
5
+ # This source code is licensed under the license found in the
6
+ # LICENSE file in the root directory of this source tree.
7
+
8
+ # Original copyright 2019 Tomoki Hayashi
9
+ # MIT License (https://opensource.org/licenses/MIT)
10
+
11
+ """STFT-based Loss modules."""
12
+
13
+ import torch
14
+ import torch.nn.functional as F
15
+
16
+
17
def stft(x, fft_size, hop_size, win_length, window):
    """Perform STFT and convert to magnitude spectrogram.
    Args:
        x (Tensor): Input signal tensor (B, T).
        fft_size (int): FFT size.
        hop_size (int): Hop size.
        win_length (int): Window length.
        window (Tensor): Window tensor of length `win_length`; it is moved to
            the device of `x` before use.
    Returns:
        Tensor: Magnitude spectrogram (B, #frames, fft_size // 2 + 1).
    """
    # `return_complex=True` is mandatory for real inputs on current PyTorch;
    # the magnitude is then computed from the real and imaginary parts.
    x_stft = torch.stft(x, fft_size, hop_size, win_length, window.to(x.device),
                        return_complex=True)
    power = x_stft.real ** 2 + x_stft.imag ** 2

    # NOTE(kan-bayashi): clamp is needed to avoid nan or inf
    return torch.sqrt(torch.clamp(power, min=1e-7)).transpose(2, 1)
34
+
35
+
36
class SpectralConvergengeLoss(torch.nn.Module):
    """Spectral convergence loss module."""

    def __init__(self):
        """Initialize spectral convergence loss module."""
        super().__init__()

    def forward(self, x_mag, y_mag):
        """Calculate forward propagation.
        Args:
            x_mag (Tensor): Magnitude spectrogram of predicted signal (B, #frames, #freq_bins).
            y_mag (Tensor): Magnitude spectrogram of groundtruth signal (B, #frames, #freq_bins).
        Returns:
            Tensor: Spectral convergence loss value, i.e. the Frobenius norm of
                the error divided by the Frobenius norm of the groundtruth.
        """
        error_norm = torch.norm(y_mag - x_mag, p="fro")
        target_norm = torch.norm(y_mag, p="fro")
        return error_norm / target_norm
52
+
53
+
54
class LogSTFTMagnitudeLoss(torch.nn.Module):
    """Log STFT magnitude loss module."""

    def __init__(self):
        """Initialize log STFT magnitude loss module."""
        super().__init__()

    def forward(self, x_mag, y_mag):
        """Calculate forward propagation.
        Args:
            x_mag (Tensor): Magnitude spectrogram of predicted signal (B, #frames, #freq_bins).
            y_mag (Tensor): Magnitude spectrogram of groundtruth signal (B, #frames, #freq_bins).
        Returns:
            Tensor: Log STFT magnitude loss value, the L1 distance between the
                logarithms of the two spectrograms.
        """
        log_target = torch.log(y_mag)
        log_predicted = torch.log(x_mag)
        return F.l1_loss(log_target, log_predicted)
70
+
71
+
72
class STFTLoss(torch.nn.Module):
    """STFT loss module."""

    def __init__(self, fft_size=1024, shift_size=120, win_length=600, window="hann_window"):
        """Initialize STFT loss module.
        Args:
            fft_size (int): FFT size.
            shift_size (int): Hop size between frames.
            win_length (int): Window length.
            window (str): Name of the `torch` window function to build the window with.
        """
        super(STFTLoss, self).__init__()
        self.fft_size = fft_size
        self.shift_size = shift_size
        self.win_length = win_length
        # Registered as a buffer (still reachable as `self.window`) so that the
        # window follows the module through `.to(device)` / `.cuda()`.
        self.register_buffer("window", getattr(torch, window)(win_length))
        self.spectral_convergenge_loss = SpectralConvergengeLoss()
        self.log_stft_magnitude_loss = LogSTFTMagnitudeLoss()

    def forward(self, x, y):
        """Calculate forward propagation.
        Args:
            x (Tensor): Predicted signal (B, T).
            y (Tensor): Groundtruth signal (B, T).
        Returns:
            Tensor: Spectral convergence loss value.
            Tensor: Log STFT magnitude loss value.
        """
        x_mag = stft(x, self.fft_size, self.shift_size, self.win_length, self.window)
        y_mag = stft(y, self.fft_size, self.shift_size, self.win_length, self.window)
        sc_loss = self.spectral_convergenge_loss(x_mag, y_mag)
        mag_loss = self.log_stft_magnitude_loss(x_mag, y_mag)

        return sc_loss, mag_loss
100
+
101
+
102
class MultiResolutionSTFTLoss(torch.nn.Module):
    """Multi resolution STFT loss module."""

    def __init__(self,
                 fft_sizes=[1024, 2048, 512],
                 hop_sizes=[120, 240, 50],
                 win_lengths=[600, 1200, 240],
                 window="hann_window", factor_sc=0.1, factor_mag=0.1):
        """Initialize Multi resolution STFT loss module.
        Args:
            fft_sizes (list): List of FFT sizes.
            hop_sizes (list): List of hop sizes.
            win_lengths (list): List of window lengths.
            window (str): Window function type.
            factor_sc (float): weight applied to the spectral convergence loss.
            factor_mag (float): weight applied to the log STFT magnitude loss.
        """
        super().__init__()
        assert len(fft_sizes) == len(hop_sizes) == len(win_lengths)
        self.factor_sc = factor_sc
        self.factor_mag = factor_mag
        # One STFTLoss per (fft size, hop size, window length) triple.
        self.stft_losses = torch.nn.ModuleList()
        for fft_size, hop_size, win_length in zip(fft_sizes, hop_sizes, win_lengths):
            self.stft_losses.append(STFTLoss(fft_size, hop_size, win_length, window))

    def forward(self, x, y):
        """Calculate forward propagation.
        Args:
            x (Tensor): Predicted signal (B, T).
            y (Tensor): Groundtruth signal (B, T).
        Returns:
            Tensor: Multi resolution spectral convergence loss value.
            Tensor: Multi resolution log STFT magnitude loss value.
        Both values are means over the resolutions, scaled by their factor.
        """
        per_resolution = [loss_fn(x, y) for loss_fn in self.stft_losses]
        count = len(self.stft_losses)
        sc_loss = sum((sc for sc, _ in per_resolution), 0.0) / count
        mag_loss = sum((mag for _, mag in per_resolution), 0.0) / count
        return self.factor_sc*sc_loss, self.factor_mag*mag_loss
denoiser/utils.py ADDED
@@ -0,0 +1,165 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ # author: adefossez
7
+
8
+ import functools
9
+ import logging
10
+ from contextlib import contextmanager
11
+ import inspect
12
+ import time
13
+
14
+ logger = logging.getLogger(__name__)
15
+
16
+
17
def capture_init(init):
    """Decorator for `__init__` that records its call arguments.

    The positional and keyword arguments of each call are stored on the
    instance as the tuple `self._init_args_kwargs = (args, kwargs)` before
    the wrapped `init` runs.
    """
    def wrapped_init(self, *args, **kwargs):
        # Record first, so the attribute is set even while `init` is running.
        self._init_args_kwargs = (args, kwargs)
        init(self, *args, **kwargs)

    # Carry over the name, docstring and signature metadata of `init`.
    return functools.wraps(init)(wrapped_init)
29
+
30
+
31
def deserialize_model(package, strict=False):
    """Rebuild a model from a serialized package.

    Args:
        package (dict): mapping with the keys `'class'` (callable used to
            build the model), `'args'`, `'kwargs'` and `'state'`.
        strict (bool): if True, pass `package['kwargs']` unchanged to the
            class. Otherwise, keyword arguments that are not parameters of
            the class signature are dropped with a warning.

    Returns:
        The instantiated model, after `load_state_dict(package['state'])`.

    The `package` dict is never modified.
    """
    klass = package['class']
    kwargs = package['kwargs']
    if not strict:
        accepted = inspect.signature(klass).parameters
        # Filter into a fresh dict rather than deleting in place, so the
        # caller's package keeps its original kwargs and can be reused.
        filtered = {}
        for key, value in kwargs.items():
            if key in accepted:
                filtered[key] = value
            else:
                logger.warning("Dropping inexistant parameter %s", key)
        kwargs = filtered
    model = klass(*package['args'], **kwargs)
    model.load_state_dict(package['state'])
    return model
48
+
49
+
50
def copy_state(state):
    """Return a new dict mapping each key of `state` to `value.cpu().clone()`.

    Values are presumably tensors (e.g. from `state_dict()`); only the
    `.cpu()` and `.clone()` methods are required.
    """
    snapshot = {}
    for name, value in state.items():
        snapshot[name] = value.cpu().clone()
    return snapshot
52
+
53
+
54
def serialize_model(model):
    """Package a model into a plain dict.

    Reads the constructor arguments from `model._init_args_kwargs` and
    returns a dict with the keys `"class"`, `"args"`, `"kwargs"` and
    `"state"`, where `"state"` is `copy_state(model.state_dict())`.
    """
    init_args, init_kwargs = model._init_args_kwargs
    return {
        "class": model.__class__,
        "args": init_args,
        "kwargs": init_kwargs,
        "state": copy_state(model.state_dict()),
    }
58
+
59
+
60
@contextmanager
def swap_state(model, state):
    """
    Context manager that swaps the state of a model, e.g:

        # model is in old state
        with swap_state(model, new_state):
            # model in new state
        # model back to old state

    The previous state is snapshotted with `copy_state` and reloaded on
    exit, including when the body raises or when loading `state` fails.
    """
    old_state = copy_state(model.state_dict())
    try:
        # Loading happens inside the `try` so that a load which fails part
        # way through still gets rolled back to the snapshot.
        model.load_state_dict(state)
        yield
    finally:
        model.load_state_dict(old_state)
76
+
77
+
78
def pull_metric(history, name):
    """Return the values stored under `name` in each entry of `history`.

    Entries that do not contain `name` are skipped; order is preserved.
    """
    return [metrics[name] for metrics in history if name in metrics]
84
+
85
+
86
class LogProgress:
    """
    Sort of like tqdm but using log lines and not as real time.
    Args:
        - logger: logger obtained from `logging.getLogger`,
        - iterable: iterable object to wrap
        - updates (int): number of lines that will be printed, e.g.
            if `updates=5`, log every 1/5th of the total length.
        - total (int): length of the iterable, in case it does not support
            `len`.
        - name (str): prefix to use in the log.
        - level: logging level (like `logging.INFO`).
    """
    def __init__(self,
                 logger,
                 iterable,
                 updates=5,
                 total=None,
                 name="LogProgress",
                 level=logging.INFO):
        self.iterable = iterable
        self.total = total or len(iterable)
        self.updates = updates
        self.name = name
        self.logger = logger
        self.level = level

    def update(self, **infos):
        """Set the key/value infos appended to the next log line."""
        self._infos = infos

    def __iter__(self):
        """Reset the counters, start the timer and return `self`."""
        self._iterator = iter(self.iterable)
        self._index = -1
        self._infos = {}
        self._begin = time.time()
        return self

    def __next__(self):
        """Return the next item, emitting a log line every `total // updates` items.

        The `finally` clause also runs when the underlying iterator raises
        `StopIteration`, so a final line can be logged at the end.
        """
        self._index += 1
        try:
            return next(self._iterator)
        finally:
            log_every = max(1, self.total // self.updates)
            # logging is delayed by 1 it, in order to have the metrics from update
            if self._index >= 1 and self._index % log_every == 0:
                self._log()

    def _log(self):
        """Format the progress, speed and infos, and send them to the logger."""
        elapsed = time.time() - self._begin
        # A coarse clock can report zero elapsed time for fast iterations;
        # treat that as infinite speed instead of raising ZeroDivisionError.
        self._speed = (1 + self._index) / elapsed if elapsed else float("inf")
        infos = " | ".join(f"{k.capitalize()} {v}" for k, v in self._infos.items())
        if self._speed < 1e-4:
            speed = "oo sec/it"
        elif self._speed < 0.1:
            speed = f"{1/self._speed:.1f} sec/it"
        else:
            speed = f"{self._speed:.1f} it/sec"
        out = f"{self.name} | {self._index}/{self.total} | {speed}"
        if infos:
            out += " | " + infos
        self.logger.log(self.level, out)
150
+
151
+
152
def colorize(text, color):
    """
    Wrap `text` in the ANSI escape sequence for `color`, followed by the
    reset sequence, for display in a terminal.
    """
    prefix = f"\033[{color}m"
    reset = "\033[0m"
    return prefix + text + reset
159
+
160
+
161
def bold(text):
    """
    Return `text` wrapped by `colorize` with the ANSI code "1" (bold).
    """
    return colorize(text, color="1")