yangwang825 commited on
Commit
db8f1e0
1 Parent(s): 47c2e51

Create helpers_xvector.py

Browse files
Files changed (1) hide show
  1. helpers_xvector.py +744 -0
helpers_xvector.py ADDED
@@ -0,0 +1,744 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import torch
3
+ import torch.nn as nn
4
+
5
+
6
+ class Deltas(torch.nn.Module):
7
+ """Computes delta coefficients (time derivatives).
8
+
9
+ Arguments
10
+ ---------
11
+ win_length : int
12
+ Length of the window used to compute the time derivatives.
13
+
14
+ Example
15
+ -------
16
+ >>> inputs = torch.randn([10, 101, 20])
17
+ >>> compute_deltas = Deltas(input_size=inputs.size(-1))
18
+ >>> features = compute_deltas(inputs)
19
+ >>> features.shape
20
+ torch.Size([10, 101, 20])
21
+ """
22
+
23
+ def __init__(
24
+ self, input_size, window_length=5,
25
+ ):
26
+ super().__init__()
27
+ self.n = (window_length - 1) // 2
28
+ self.denom = self.n * (self.n + 1) * (2 * self.n + 1) / 3
29
+
30
+ self.register_buffer(
31
+ "kernel",
32
+ torch.arange(-self.n, self.n + 1, dtype=torch.float32,).repeat(
33
+ input_size, 1, 1
34
+ ),
35
+ )
36
+
37
+ def forward(self, x):
38
+ """Returns the delta coefficients.
39
+
40
+ Arguments
41
+ ---------
42
+ x : tensor
43
+ A batch of tensors.
44
+ """
45
+ # Managing multi-channel deltas reshape tensor (batch*channel,time)
46
+ x = x.transpose(1, 2).transpose(2, -1)
47
+ or_shape = x.shape
48
+ if len(or_shape) == 4:
49
+ x = x.reshape(or_shape[0] * or_shape[2], or_shape[1], or_shape[3])
50
+
51
+ # Padding for time borders
52
+ x = torch.nn.functional.pad(x, (self.n, self.n), mode="replicate")
53
+
54
+ # Derivative estimation (with a fixed convolutional kernel)
55
+ delta_coeff = (
56
+ torch.nn.functional.conv1d(
57
+ x, self.kernel.to(x.device), groups=x.shape[1]
58
+ )
59
+ / self.denom
60
+ )
61
+
62
+ # Retrieving the original dimensionality (for multi-channel case)
63
+ if len(or_shape) == 4:
64
+ delta_coeff = delta_coeff.reshape(
65
+ or_shape[0], or_shape[1], or_shape[2], or_shape[3],
66
+ )
67
+ delta_coeff = delta_coeff.transpose(1, -1).transpose(2, -1)
68
+
69
+ return delta_coeff
70
+
71
+
72
+ class Filterbank(torch.nn.Module):
73
+ """computes filter bank (FBANK) features given spectral magnitudes.
74
+
75
+ Arguments
76
+ ---------
77
+ n_mels : float
78
+ Number of Mel filters used to average the spectrogram.
79
+ log_mel : bool
80
+ If True, it computes the log of the FBANKs.
81
+ filter_shape : str
82
+ Shape of the filters ('triangular', 'rectangular', 'gaussian').
83
+ f_min : int
84
+ Lowest frequency for the Mel filters.
85
+ f_max : int
86
+ Highest frequency for the Mel filters.
87
+ n_fft : int
88
+ Number of fft points of the STFT. It defines the frequency resolution
89
+ (n_fft should be<= than win_len).
90
+ sample_rate : int
91
+ Sample rate of the input audio signal (e.g, 16000)
92
+ power_spectrogram : float
93
+ Exponent used for spectrogram computation.
94
+ amin : float
95
+ Minimum amplitude (used for numerical stability).
96
+ ref_value : float
97
+ Reference value used for the dB scale.
98
+ top_db : float
99
+ Minimum negative cut-off in decibels.
100
+ freeze : bool
101
+ If False, it the central frequency and the band of each filter are
102
+ added into nn.parameters. If True, the standard frozen features
103
+ are computed.
104
+ param_change_factor: bool
105
+ If freeze=False, this parameter affects the speed at which the filter
106
+ parameters (i.e., central_freqs and bands) can be changed. When high
107
+ (e.g., param_change_factor=1) the filters change a lot during training.
108
+ When low (e.g. param_change_factor=0.1) the filter parameters are more
109
+ stable during training
110
+ param_rand_factor: float
111
+ This parameter can be used to randomly change the filter parameters
112
+ (i.e, central frequencies and bands) during training. It is thus a
113
+ sort of regularization. param_rand_factor=0 does not affect, while
114
+ param_rand_factor=0.15 allows random variations within +-15% of the
115
+ standard values of the filter parameters (e.g., if the central freq
116
+ is 100 Hz, we can randomly change it from 85 Hz to 115 Hz).
117
+
118
+ Example
119
+ -------
120
+ >>> import torch
121
+ >>> compute_fbanks = Filterbank()
122
+ >>> inputs = torch.randn([10, 101, 201])
123
+ >>> features = compute_fbanks(inputs)
124
+ >>> features.shape
125
+ torch.Size([10, 101, 40])
126
+ """
127
+
128
+ def __init__(
129
+ self,
130
+ n_mels=40,
131
+ log_mel=True,
132
+ filter_shape="triangular",
133
+ f_min=0,
134
+ f_max=8000,
135
+ n_fft=400,
136
+ sample_rate=16000,
137
+ power_spectrogram=2,
138
+ amin=1e-10,
139
+ ref_value=1.0,
140
+ top_db=80.0,
141
+ param_change_factor=1.0,
142
+ param_rand_factor=0.0,
143
+ freeze=True,
144
+ ):
145
+ super().__init__()
146
+ self.n_mels = n_mels
147
+ self.log_mel = log_mel
148
+ self.filter_shape = filter_shape
149
+ self.f_min = f_min
150
+ self.f_max = f_max
151
+ self.n_fft = n_fft
152
+ self.sample_rate = sample_rate
153
+ self.power_spectrogram = power_spectrogram
154
+ self.amin = amin
155
+ self.ref_value = ref_value
156
+ self.top_db = top_db
157
+ self.freeze = freeze
158
+ self.n_stft = self.n_fft // 2 + 1
159
+ self.db_multiplier = math.log10(max(self.amin, self.ref_value))
160
+ self.device_inp = torch.device("cpu")
161
+ self.param_change_factor = param_change_factor
162
+ self.param_rand_factor = param_rand_factor
163
+
164
+ if self.power_spectrogram == 2:
165
+ self.multiplier = 10
166
+ else:
167
+ self.multiplier = 20
168
+
169
+ # Make sure f_min < f_max
170
+ if self.f_min >= self.f_max:
171
+ err_msg = "Require f_min: %f < f_max: %f" % (
172
+ self.f_min,
173
+ self.f_max,
174
+ )
175
+ print(err_msg)
176
+
177
+ # Filter definition
178
+ mel = torch.linspace(
179
+ self._to_mel(self.f_min), self._to_mel(self.f_max), self.n_mels + 2
180
+ )
181
+ hz = self._to_hz(mel)
182
+
183
+ # Computation of the filter bands
184
+ band = hz[1:] - hz[:-1]
185
+ self.band = band[:-1]
186
+ self.f_central = hz[1:-1]
187
+
188
+ # Adding the central frequency and the band to the list of nn param
189
+ if not self.freeze:
190
+ self.f_central = torch.nn.Parameter(
191
+ self.f_central / (self.sample_rate * self.param_change_factor)
192
+ )
193
+ self.band = torch.nn.Parameter(
194
+ self.band / (self.sample_rate * self.param_change_factor)
195
+ )
196
+
197
+ # Frequency axis
198
+ all_freqs = torch.linspace(0, self.sample_rate // 2, self.n_stft)
199
+
200
+ # Replicating for all the filters
201
+ self.all_freqs_mat = all_freqs.repeat(self.f_central.shape[0], 1)
202
+
203
+ def forward(self, spectrogram):
204
+ """Returns the FBANks.
205
+
206
+ Arguments
207
+ ---------
208
+ x : tensor
209
+ A batch of spectrogram tensors.
210
+ """
211
+ # Computing central frequency and bandwidth of each filter
212
+ f_central_mat = self.f_central.repeat(
213
+ self.all_freqs_mat.shape[1], 1
214
+ ).transpose(0, 1)
215
+ band_mat = self.band.repeat(self.all_freqs_mat.shape[1], 1).transpose(
216
+ 0, 1
217
+ )
218
+
219
+ # Uncomment to print filter parameters
220
+ # print(self.f_central*self.sample_rate * self.param_change_factor)
221
+ # print(self.band*self.sample_rate* self.param_change_factor)
222
+
223
+ # Creation of the multiplication matrix. It is used to create
224
+ # the filters that average the computed spectrogram.
225
+ if not self.freeze:
226
+ f_central_mat = f_central_mat * (
227
+ self.sample_rate
228
+ * self.param_change_factor
229
+ * self.param_change_factor
230
+ )
231
+ band_mat = band_mat * (
232
+ self.sample_rate
233
+ * self.param_change_factor
234
+ * self.param_change_factor
235
+ )
236
+
237
+ # Regularization with random changes of filter central frequency and band
238
+ elif self.param_rand_factor != 0 and self.training:
239
+ rand_change = (
240
+ 1.0
241
+ + torch.rand(2) * 2 * self.param_rand_factor
242
+ - self.param_rand_factor
243
+ )
244
+ f_central_mat = f_central_mat * rand_change[0]
245
+ band_mat = band_mat * rand_change[1]
246
+
247
+ fbank_matrix = self._create_fbank_matrix(f_central_mat, band_mat).to(
248
+ spectrogram.device
249
+ )
250
+
251
+ sp_shape = spectrogram.shape
252
+
253
+ # Managing multi-channels case (batch, time, channels)
254
+ if len(sp_shape) == 4:
255
+ spectrogram = spectrogram.permute(0, 3, 1, 2)
256
+ spectrogram = spectrogram.reshape(
257
+ sp_shape[0] * sp_shape[3], sp_shape[1], sp_shape[2]
258
+ )
259
+
260
+ # FBANK computation
261
+ fbanks = torch.matmul(spectrogram, fbank_matrix)
262
+ if self.log_mel:
263
+ fbanks = self._amplitude_to_DB(fbanks)
264
+
265
+ # Reshaping in the case of multi-channel inputs
266
+ if len(sp_shape) == 4:
267
+ fb_shape = fbanks.shape
268
+ fbanks = fbanks.reshape(
269
+ sp_shape[0], sp_shape[3], fb_shape[1], fb_shape[2]
270
+ )
271
+ fbanks = fbanks.permute(0, 2, 3, 1)
272
+
273
+ return fbanks
274
+
275
+ @staticmethod
276
+ def _to_mel(hz):
277
+ """Returns mel-frequency value corresponding to the input
278
+ frequency value in Hz.
279
+
280
+ Arguments
281
+ ---------
282
+ x : float
283
+ The frequency point in Hz.
284
+ """
285
+ return 2595 * math.log10(1 + hz / 700)
286
+
287
+ @staticmethod
288
+ def _to_hz(mel):
289
+ """Returns hz-frequency value corresponding to the input
290
+ mel-frequency value.
291
+
292
+ Arguments
293
+ ---------
294
+ x : float
295
+ The frequency point in the mel-scale.
296
+ """
297
+ return 700 * (10 ** (mel / 2595) - 1)
298
+
299
+ def _triangular_filters(self, all_freqs, f_central, band):
300
+ """Returns fbank matrix using triangular filters.
301
+
302
+ Arguments
303
+ ---------
304
+ all_freqs : Tensor
305
+ Tensor gathering all the frequency points.
306
+ f_central : Tensor
307
+ Tensor gathering central frequencies of each filter.
308
+ band : Tensor
309
+ Tensor gathering the bands of each filter.
310
+ """
311
+
312
+ # Computing the slops of the filters
313
+ slope = (all_freqs - f_central) / band
314
+ left_side = slope + 1.0
315
+ right_side = -slope + 1.0
316
+
317
+ # Adding zeros for negative values
318
+ zero = torch.zeros(1, device=self.device_inp)
319
+ fbank_matrix = torch.max(
320
+ zero, torch.min(left_side, right_side)
321
+ ).transpose(0, 1)
322
+
323
+ return fbank_matrix
324
+
325
+ def _rectangular_filters(self, all_freqs, f_central, band):
326
+ """Returns fbank matrix using rectangular filters.
327
+
328
+ Arguments
329
+ ---------
330
+ all_freqs : Tensor
331
+ Tensor gathering all the frequency points.
332
+ f_central : Tensor
333
+ Tensor gathering central frequencies of each filter.
334
+ band : Tensor
335
+ Tensor gathering the bands of each filter.
336
+ """
337
+
338
+ # cut-off frequencies of the filters
339
+ low_hz = f_central - band
340
+ high_hz = f_central + band
341
+
342
+ # Left/right parts of the filter
343
+ left_side = right_size = all_freqs.ge(low_hz)
344
+ right_size = all_freqs.le(high_hz)
345
+
346
+ fbank_matrix = (left_side * right_size).float().transpose(0, 1)
347
+
348
+ return fbank_matrix
349
+
350
+ def _gaussian_filters(
351
+ self, all_freqs, f_central, band, smooth_factor=torch.tensor(2)
352
+ ):
353
+ """Returns fbank matrix using gaussian filters.
354
+
355
+ Arguments
356
+ ---------
357
+ all_freqs : Tensor
358
+ Tensor gathering all the frequency points.
359
+ f_central : Tensor
360
+ Tensor gathering central frequencies of each filter.
361
+ band : Tensor
362
+ Tensor gathering the bands of each filter.
363
+ smooth_factor: Tensor
364
+ Smoothing factor of the gaussian filter. It can be used to employ
365
+ sharper or flatter filters.
366
+ """
367
+ fbank_matrix = torch.exp(
368
+ -0.5 * ((all_freqs - f_central) / (band / smooth_factor)) ** 2
369
+ ).transpose(0, 1)
370
+
371
+ return fbank_matrix
372
+
373
+ def _create_fbank_matrix(self, f_central_mat, band_mat):
374
+ """Returns fbank matrix to use for averaging the spectrum with
375
+ the set of filter-banks.
376
+
377
+ Arguments
378
+ ---------
379
+ f_central : Tensor
380
+ Tensor gathering central frequencies of each filter.
381
+ band : Tensor
382
+ Tensor gathering the bands of each filter.
383
+ smooth_factor: Tensor
384
+ Smoothing factor of the gaussian filter. It can be used to employ
385
+ sharper or flatter filters.
386
+ """
387
+ if self.filter_shape == "triangular":
388
+ fbank_matrix = self._triangular_filters(
389
+ self.all_freqs_mat, f_central_mat, band_mat
390
+ )
391
+
392
+ elif self.filter_shape == "rectangular":
393
+ fbank_matrix = self._rectangular_filters(
394
+ self.all_freqs_mat, f_central_mat, band_mat
395
+ )
396
+
397
+ else:
398
+ fbank_matrix = self._gaussian_filters(
399
+ self.all_freqs_mat, f_central_mat, band_mat
400
+ )
401
+
402
+ return fbank_matrix
403
+
404
+ def _amplitude_to_DB(self, x):
405
+ """Converts linear-FBANKs to log-FBANKs.
406
+
407
+ Arguments
408
+ ---------
409
+ x : Tensor
410
+ A batch of linear FBANK tensors.
411
+
412
+ """
413
+
414
+ x_db = self.multiplier * torch.log10(torch.clamp(x, min=self.amin))
415
+ x_db -= self.multiplier * self.db_multiplier
416
+
417
+ # Setting up dB max. It is the max over time and frequency,
418
+ # Hence, of a whole sequence (sequence-dependent)
419
+ new_x_db_max = x_db.amax(dim=(-2, -1)) - self.top_db
420
+
421
+ # Clipping to dB max. The view is necessary as only a scalar is obtained
422
+ # per sequence.
423
+ x_db = torch.max(x_db, new_x_db_max.view(x_db.shape[0], 1, 1))
424
+
425
+ return x_db
426
+
427
+
428
+ class STFT(torch.nn.Module):
429
+ """computes the Short-Term Fourier Transform (STFT).
430
+
431
+ This class computes the Short-Term Fourier Transform of an audio signal.
432
+ It supports multi-channel audio inputs (batch, time, channels).
433
+
434
+ Arguments
435
+ ---------
436
+ sample_rate : int
437
+ Sample rate of the input audio signal (e.g 16000).
438
+ win_length : float
439
+ Length (in ms) of the sliding window used to compute the STFT.
440
+ hop_length : float
441
+ Length (in ms) of the hope of the sliding window used to compute
442
+ the STFT.
443
+ n_fft : int
444
+ Number of fft point of the STFT. It defines the frequency resolution
445
+ (n_fft should be <= than win_len).
446
+ window_fn : function
447
+ A function that takes an integer (number of samples) and outputs a
448
+ tensor to be multiplied with each window before fft.
449
+ normalized_stft : bool
450
+ If True, the function returns the normalized STFT results,
451
+ i.e., multiplied by win_length^-0.5 (default is False).
452
+ center : bool
453
+ If True (default), the input will be padded on both sides so that the
454
+ t-th frame is centered at time t×hop_length. Otherwise, the t-th frame
455
+ begins at time t×hop_length.
456
+ pad_mode : str
457
+ It can be 'constant','reflect','replicate', 'circular', 'reflect'
458
+ (default). 'constant' pads the input tensor boundaries with a
459
+ constant value. 'reflect' pads the input tensor using the reflection
460
+ of the input boundary. 'replicate' pads the input tensor using
461
+ replication of the input boundary. 'circular' pads using circular
462
+ replication.
463
+ onesided : True
464
+ If True (default) only returns nfft/2 values. Note that the other
465
+ samples are redundant due to the Fourier transform conjugate symmetry.
466
+
467
+ Example
468
+ -------
469
+ >>> import torch
470
+ >>> compute_STFT = STFT(
471
+ ... sample_rate=16000, win_length=25, hop_length=10, n_fft=400
472
+ ... )
473
+ >>> inputs = torch.randn([10, 16000])
474
+ >>> features = compute_STFT(inputs)
475
+ >>> features.shape
476
+ torch.Size([10, 101, 201, 2])
477
+ """
478
+
479
+ def __init__(
480
+ self,
481
+ sample_rate,
482
+ win_length=25,
483
+ hop_length=10,
484
+ n_fft=400,
485
+ window_fn=torch.hamming_window,
486
+ normalized_stft=False,
487
+ center=True,
488
+ pad_mode="constant",
489
+ onesided=True,
490
+ ):
491
+ super().__init__()
492
+ self.sample_rate = sample_rate
493
+ self.win_length = win_length
494
+ self.hop_length = hop_length
495
+ self.n_fft = n_fft
496
+ self.normalized_stft = normalized_stft
497
+ self.center = center
498
+ self.pad_mode = pad_mode
499
+ self.onesided = onesided
500
+
501
+ # Convert win_length and hop_length from ms to samples
502
+ self.win_length = int(
503
+ round((self.sample_rate / 1000.0) * self.win_length)
504
+ )
505
+ self.hop_length = int(
506
+ round((self.sample_rate / 1000.0) * self.hop_length)
507
+ )
508
+
509
+ self.window = window_fn(self.win_length)
510
+
511
+ def forward(self, x):
512
+ """Returns the STFT generated from the input waveforms.
513
+
514
+ Arguments
515
+ ---------
516
+ x : tensor
517
+ A batch of audio signals to transform.
518
+ """
519
+
520
+ # Managing multi-channel stft
521
+ or_shape = x.shape
522
+ if len(or_shape) == 3:
523
+ x = x.transpose(1, 2)
524
+ x = x.reshape(or_shape[0] * or_shape[2], or_shape[1])
525
+
526
+ stft = torch.stft(
527
+ x,
528
+ self.n_fft,
529
+ self.hop_length,
530
+ self.win_length,
531
+ self.window.to(x.device),
532
+ self.center,
533
+ self.pad_mode,
534
+ self.normalized_stft,
535
+ self.onesided,
536
+ return_complex=True,
537
+ )
538
+
539
+ stft = torch.view_as_real(stft)
540
+
541
+ # Retrieving the original dimensionality (batch,time, channels)
542
+ if len(or_shape) == 3:
543
+ stft = stft.reshape(
544
+ or_shape[0],
545
+ or_shape[2],
546
+ stft.shape[1],
547
+ stft.shape[2],
548
+ stft.shape[3],
549
+ )
550
+ stft = stft.permute(0, 3, 2, 4, 1)
551
+ else:
552
+ # (batch, time, channels)
553
+ stft = stft.transpose(2, 1)
554
+
555
+ return stft
556
+
557
+
558
+ def spectral_magnitude(
559
+ stft, power: int = 1, log: bool = False, eps: float = 1e-14
560
+ ):
561
+ """Returns the magnitude of a complex spectrogram.
562
+
563
+ Arguments
564
+ ---------
565
+ stft : torch.Tensor
566
+ A tensor, output from the stft function.
567
+ power : int
568
+ What power to use in computing the magnitude.
569
+ Use power=1 for the power spectrogram.
570
+ Use power=0.5 for the magnitude spectrogram.
571
+ log : bool
572
+ Whether to apply log to the spectral features.
573
+
574
+ Example
575
+ -------
576
+ >>> a = torch.Tensor([[3, 4]])
577
+ >>> spectral_magnitude(a, power=0.5)
578
+ tensor([5.])
579
+ """
580
+ spectr = stft.pow(2).sum(-1)
581
+
582
+ # Add eps avoids NaN when spectr is zero
583
+ if power < 1:
584
+ spectr = spectr + eps
585
+ spectr = spectr.pow(power)
586
+
587
+ if log:
588
+ return torch.log(spectr + eps)
589
+ return spectr
590
+
591
+
592
+ class ContextWindow(torch.nn.Module):
593
+ """Computes the context window.
594
+
595
+ This class applies a context window by gathering multiple time steps
596
+ in a single feature vector. The operation is performed with a
597
+ convolutional layer based on a fixed kernel designed for that.
598
+
599
+ Arguments
600
+ ---------
601
+ left_frames : int
602
+ Number of left frames (i.e, past frames) to collect.
603
+ right_frames : int
604
+ Number of right frames (i.e, future frames) to collect.
605
+
606
+ Example
607
+ -------
608
+ >>> import torch
609
+ >>> compute_cw = ContextWindow(left_frames=5, right_frames=5)
610
+ >>> inputs = torch.randn([10, 101, 20])
611
+ >>> features = compute_cw(inputs)
612
+ >>> features.shape
613
+ torch.Size([10, 101, 220])
614
+ """
615
+
616
+ def __init__(
617
+ self, left_frames=0, right_frames=0,
618
+ ):
619
+ super().__init__()
620
+ self.left_frames = left_frames
621
+ self.right_frames = right_frames
622
+ self.context_len = self.left_frames + self.right_frames + 1
623
+ self.kernel_len = 2 * max(self.left_frames, self.right_frames) + 1
624
+
625
+ # Kernel definition
626
+ self.kernel = torch.eye(self.context_len, self.kernel_len)
627
+
628
+ if self.right_frames > self.left_frames:
629
+ lag = self.right_frames - self.left_frames
630
+ self.kernel = torch.roll(self.kernel, lag, 1)
631
+
632
+ self.first_call = True
633
+
634
+ def forward(self, x):
635
+ """Returns the tensor with the surrounding context.
636
+
637
+ Arguments
638
+ ---------
639
+ x : tensor
640
+ A batch of tensors.
641
+ """
642
+
643
+ x = x.transpose(1, 2)
644
+
645
+ if self.first_call is True:
646
+ self.first_call = False
647
+ self.kernel = (
648
+ self.kernel.repeat(x.shape[1], 1, 1)
649
+ .view(x.shape[1] * self.context_len, self.kernel_len,)
650
+ .unsqueeze(1)
651
+ )
652
+
653
+ # Managing multi-channel case
654
+ or_shape = x.shape
655
+ if len(or_shape) == 4:
656
+ x = x.reshape(or_shape[0] * or_shape[2], or_shape[1], or_shape[3])
657
+
658
+ # Compute context (using the estimated convolutional kernel)
659
+ cw_x = torch.nn.functional.conv1d(
660
+ x,
661
+ self.kernel.to(x.device),
662
+ groups=x.shape[1],
663
+ padding=max(self.left_frames, self.right_frames),
664
+ )
665
+
666
+ # Retrieving the original dimensionality (for multi-channel case)
667
+ if len(or_shape) == 4:
668
+ cw_x = cw_x.reshape(
669
+ or_shape[0], cw_x.shape[1], or_shape[2], cw_x.shape[-1]
670
+ )
671
+
672
+ cw_x = cw_x.transpose(1, 2)
673
+
674
+ return cw_x
675
+
676
+
677
+ class Fbank(torch.nn.Module):
678
+
679
+ def __init__(
680
+ self,
681
+ deltas=False,
682
+ context=False,
683
+ requires_grad=False,
684
+ sample_rate=16000,
685
+ f_min=0,
686
+ f_max=None,
687
+ n_fft=400,
688
+ n_mels=40,
689
+ filter_shape="triangular",
690
+ param_change_factor=1.0,
691
+ param_rand_factor=0.0,
692
+ left_frames=5,
693
+ right_frames=5,
694
+ win_length=25,
695
+ hop_length=10,
696
+ ):
697
+ super().__init__()
698
+ self.deltas = deltas
699
+ self.context = context
700
+ self.requires_grad = requires_grad
701
+
702
+ if f_max is None:
703
+ f_max = sample_rate / 2
704
+
705
+ self.compute_STFT = STFT(
706
+ sample_rate=sample_rate,
707
+ n_fft=n_fft,
708
+ win_length=win_length,
709
+ hop_length=hop_length,
710
+ )
711
+ self.compute_fbanks = Filterbank(
712
+ sample_rate=sample_rate,
713
+ n_fft=n_fft,
714
+ n_mels=n_mels,
715
+ f_min=f_min,
716
+ f_max=f_max,
717
+ freeze=not requires_grad,
718
+ filter_shape=filter_shape,
719
+ param_change_factor=param_change_factor,
720
+ param_rand_factor=param_rand_factor,
721
+ )
722
+ self.compute_deltas = Deltas(input_size=n_mels)
723
+ self.context_window = ContextWindow(
724
+ left_frames=left_frames, right_frames=right_frames,
725
+ )
726
+
727
+ def forward(self, wav):
728
+ """Returns a set of features generated from the input waveforms.
729
+
730
+ Arguments
731
+ ---------
732
+ wav : tensor
733
+ A batch of audio signals to transform to features.
734
+ """
735
+ STFT = self.compute_STFT(wav)
736
+ mag = spectral_magnitude(STFT)
737
+ fbanks = self.compute_fbanks(mag)
738
+ if self.deltas:
739
+ delta1 = self.compute_deltas(fbanks)
740
+ delta2 = self.compute_deltas(delta1)
741
+ fbanks = torch.cat([fbanks, delta1, delta2], dim=2)
742
+ if self.context:
743
+ fbanks = self.context_window(fbanks)
744
+ return fbanks