jhtonyKoo committed on
Commit
6fc042a
1 Parent(s): 694161d
app.py CHANGED
@@ -25,7 +25,14 @@ def process_audio(input_audio, reference_audio, perform_ito):
     if ito_output_audio is not None:
         sf.write("ito_output_mastered.wav", ito_output_audio.T, sr)
 
-    return "output_mastered.wav", "ito_output_mastered.wav" if ito_output_audio is not None else None
+    # Generate parameter output strings
+    param_output = mastering_transfer.get_param_output_string(predicted_params)
+    ito_param_output = mastering_transfer.get_param_output_string(ito_predicted_params) if ito_predicted_params is not None else "ITO not performed"
+
+    # Generate top 10 differences if ITO was performed
+    top_10_diff = mastering_transfer.get_top_10_diff_string(predicted_params, ito_predicted_params) if ito_predicted_params is not None else "ITO not performed"
+
+    return "output_mastered.wav", "ito_output_mastered.wav" if ito_output_audio is not None else None, param_output, ito_param_output, top_10_diff
 
 def process_youtube(input_url, reference_url, perform_ito):
     input_audio = download_youtube_audio(input_url)
@@ -41,7 +48,12 @@ with gr.Blocks() as demo:
         submit_button = gr.Button("Process")
         output_audio = gr.Audio(label="Output Audio")
         ito_output_audio = gr.Audio(label="ITO Output Audio")
-        submit_button.click(process_audio, inputs=[input_audio, reference_audio, perform_ito], outputs=[output_audio, ito_output_audio])
+        param_output = gr.Textbox(label="Predicted Parameters", lines=10)
+        ito_param_output = gr.Textbox(label="ITO Predicted Parameters", lines=10)
+        top_10_diff = gr.Textbox(label="Top 10 Parameter Differences", lines=10)
+        submit_button.click(process_audio,
+                            inputs=[input_audio, reference_audio, perform_ito],
+                            outputs=[output_audio, ito_output_audio, param_output, ito_param_output, top_10_diff])
 
     with gr.Tab("YouTube URLs"):
         input_url = gr.Textbox(label="Input YouTube URL")
@@ -50,6 +62,11 @@ with gr.Blocks() as demo:
         submit_button_yt = gr.Button("Process")
         output_audio_yt = gr.Audio(label="Output Audio")
         ito_output_audio_yt = gr.Audio(label="ITO Output Audio")
-        submit_button_yt.click(process_youtube, inputs=[input_url, reference_url, perform_ito_yt], outputs=[output_audio_yt, ito_output_audio_yt])
+        param_output_yt = gr.Textbox(label="Predicted Parameters", lines=10)
+        ito_param_output_yt = gr.Textbox(label="ITO Predicted Parameters", lines=10)
+        top_10_diff_yt = gr.Textbox(label="Top 10 Parameter Differences", lines=10)
+        submit_button_yt.click(process_youtube,
+                               inputs=[input_url, reference_url, perform_ito_yt],
+                               outputs=[output_audio_yt, ito_output_audio_yt, param_output_yt, ito_param_output_yt, top_10_diff_yt])
 
 demo.launch()
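
A minimal sketch (not part of the commit; component and handler names are illustrative) of the Gradio wiring pattern the change above relies on: the click handler returns a tuple whose elements are matched positionally to the components listed in outputs.

import gradio as gr

def handler(audio_path):
    # return one value per output component, in order
    return "output_mastered.wav", "EQ:\n band0_gain_db: 1.2345"

with gr.Blocks() as demo:
    audio_in = gr.Audio(label="Input Audio", type="filepath")
    btn = gr.Button("Process")
    audio_out = gr.Audio(label="Output Audio")
    params_out = gr.Textbox(label="Predicted Parameters", lines=10)
    btn.click(handler, inputs=[audio_in], outputs=[audio_out, params_out])

demo.launch()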
inference.py CHANGED
@@ -9,8 +9,7 @@ import sys
 currentdir = os.path.dirname(os.path.realpath(__file__))
 sys.path.append(os.path.dirname(currentdir))
 from networks import Dasp_Mastering_Style_Transfer, Effects_Encoder
-from modules import FrontEnd, BackEnd
-from modules.loss import AudioFeatureLoss
+from modules.loss import AudioFeatureLoss, Loss
 
 class MasteringStyleTransfer:
     def __init__(self, args):
@@ -205,6 +204,59 @@ class MasteringStyleTransfer:
             else:
                 print(f" {fx_params}")
 
+    def get_param_output_string(self, params):
+        if params is None:
+            return "No parameters available"
+
+        output = []
+        for fx_name, fx_params in params.items():
+            output.append(f"{fx_name.upper()}:")
+            if isinstance(fx_params, dict):
+                for param_name, param_value in fx_params.items():
+                    if isinstance(param_value, torch.Tensor):
+                        param_value = param_value.item()
+                    output.append(f" {param_name}: {param_value:.4f}")
+            elif isinstance(fx_params, torch.Tensor):
+                output.append(f" {fx_params.item():.4f}")
+            else:
+                output.append(f" {fx_params:.4f}")
+
+        return "\n".join(output)
+
+    def get_top_10_diff_string(self, initial_params, ito_params):
+        if initial_params is None or ito_params is None:
+            return "Cannot compare parameters"
+
+        all_diffs = []
+        for fx_name in initial_params.keys():
+            if isinstance(initial_params[fx_name], dict):
+                for param_name in initial_params[fx_name].keys():
+                    initial_value = initial_params[fx_name][param_name]
+                    ito_value = ito_params[fx_name][param_name]
+
+                    param_range = self.mastering_converter.fx_processors[fx_name].param_ranges[param_name]
+                    normalized_diff = abs((ito_value - initial_value) / (param_range[1] - param_range[0]))
+
+                    all_diffs.append((fx_name, param_name, initial_value.item(), ito_value.item(), normalized_diff.item()))
+            else:
+                initial_value = initial_params[fx_name]
+                ito_value = ito_params[fx_name]
+                normalized_diff = abs(ito_value - initial_value)
+                all_diffs.append((fx_name, 'width', initial_value.item(), ito_value.item(), normalized_diff.item()))
+
+        top_diffs = sorted(all_diffs, key=lambda x: x[4], reverse=True)[:10]
+
+        output = ["Top 10 parameter differences (sorted by normalized difference):"]
+        for fx_name, param_name, initial_value, ito_value, normalized_diff in top_diffs:
+            output.append(f"{fx_name.upper()} - {param_name}:")
+            output.append(f" Initial: {initial_value:.4f}")
+            output.append(f" ITO: {ito_value:.4f}")
+            output.append(f" Normalized Diff: {normalized_diff:.4f}")
+            output.append("")
+
+        return "\n".join(output)
+
+
 def reload_weights(model, ckpt_path, device):
     checkpoint = torch.load(ckpt_path, map_location=device)
 
@@ -215,6 +267,7 @@ def reload_weights(model, ckpt_path, device):
         new_state_dict[name] = v
     model.load_state_dict(new_state_dict, strict=False)
 
+
 if __name__ == "__main__":
     basis_path = '/data2/tony/Mastering_Style_Transfer/results/dasp_tcn_tuneenc_daspman_loudnessnorm/ckpt/1000/'
 
@@ -258,5 +311,3 @@ if __name__ == "__main__":
     if ito_output_audio is not None:
         sf.write("ito_output_mastered.wav", ito_output_audio.T, sr)
 
-
-
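For reference, a hedged sketch (not from the repository) of the ranking that get_top_10_diff_string performs: each parameter change is scaled by its processor's (min, max) range before sorting, so parameters with different units become comparable. The gain range below is a made-up example value.

import torch

initial = {"gain": {"gain_db": torch.tensor([[-3.0]])}}
ito     = {"gain": {"gain_db": torch.tensor([[-1.5]])}}
param_range = (-24.0, 24.0)  # hypothetical (min, max) for the gain processor

diff = abs((ito["gain"]["gain_db"] - initial["gain"]["gain_db"])
           / (param_range[1] - param_range[0]))
print(diff.item())  # 0.03125 -> used as the sort key for the top-10 list
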
modules/__init__.py ADDED
@@ -0,0 +1,2 @@
+from .front_back_end import *
+from .loss import *
modules/front_back_end.py ADDED
@@ -0,0 +1,240 @@
1
+ """ Front-end: processing raw data input """
2
+ import torch
3
+ import torch.nn as nn
4
+ import torchaudio.functional as ta_F
5
+ import torchaudio
6
+
7
+
8
+
9
+ class FrontEnd(nn.Module):
10
+ def __init__(self, channel='stereo', \
11
+ n_fft=2048, \
12
+ n_mels=128, \
13
+ sample_rate=44100, \
14
+ hop_length=None, \
15
+ win_length=None, \
16
+ window="hann", \
17
+ eps=1e-7, \
18
+ device=torch.device("cpu")):
19
+ super(FrontEnd, self).__init__()
20
+ self.channel = channel
21
+ self.n_fft = n_fft
22
+ self.n_mels = n_mels
23
+ self.sample_rate = sample_rate
24
+ self.hop_length = n_fft//4 if hop_length==None else hop_length
25
+ self.win_length = n_fft if win_length==None else win_length
26
+ self.eps = eps
27
+ if window=="hann":
28
+ self.window = torch.hann_window(window_length=self.win_length, periodic=True).to(device)
29
+ elif window=="hamming":
30
+ self.window = torch.hamming_window(window_length=self.win_length, periodic=True).to(device)
31
+ self.melscale_transform = torchaudio.transforms.MelScale(n_mels=self.n_mels, \
32
+ sample_rate=self.sample_rate, \
33
+ n_stft=self.n_fft//2+1).to(device)
34
+
35
+
36
+ def forward(self, input, mode):
37
+ # front-end function which channel-wise combines all demanded features
38
+ # input shape : batch x channel x raw waveform
39
+ # output shape : batch x channel x frequency x time
40
+ phase_output = None
41
+
42
+ front_output_list = []
43
+ for cur_mode in mode:
44
+ # Real & Imaginary
45
+ if cur_mode=="cplx":
46
+ if self.channel=="mono":
47
+ output = torch.stft(input, n_fft=self.n_fft, hop_length=self.hop_length, win_length=self.win_length, window=self.window)
48
+ elif self.channel=="stereo":
49
+ output_l = torch.stft(input[:,0], n_fft=self.n_fft, hop_length=self.hop_length, win_length=self.win_length, window=self.window)
50
+ output_r = torch.stft(input[:,1], n_fft=self.n_fft, hop_length=self.hop_length, win_length=self.win_length, window=self.window)
51
+ output = torch.cat((output_l, output_r), axis=-1)
52
+ if input.shape[-1] % round(self.n_fft/4) == 0:
53
+ output = output[:, :, :-1]
54
+ if self.n_fft % 2 == 0:
55
+ output = output[:, :-1]
56
+ front_output_list.append(output.permute(0, 3, 1, 2))
57
+ # Magnitude & Phase or Mel
58
+ elif "mag" in cur_mode or "mel" in cur_mode:
59
+ if self.channel=="mono":
60
+ cur_cplx = torch.stft(input, n_fft=self.n_fft, hop_length=self.hop_length, win_length=self.win_length, window=self.window, return_complex=True)
61
+ output = self.mag(cur_cplx).unsqueeze(-1)[..., 0:1]
62
+ if "mag_phase" in cur_mode:
63
+ phase = self.phase(cur_cplx)
64
+ if "mel" in cur_mode:
65
+ output = self.melscale_transform(output.squeeze(-1)).unsqueeze(-1)
66
+ elif self.channel=="stereo":
67
+ cplx_l = torch.stft(input[:,0], n_fft=self.n_fft, hop_length=self.hop_length, win_length=self.win_length, window=self.window, return_complex=True)
68
+ cplx_r = torch.stft(input[:,1], n_fft=self.n_fft, hop_length=self.hop_length, win_length=self.win_length, window=self.window, return_complex=True)
69
+ mag_l = self.mag(cplx_l).unsqueeze(-1)
70
+ mag_r = self.mag(cplx_r).unsqueeze(-1)
71
+ output = torch.cat((mag_l, mag_r), axis=-1)
72
+ if "mag_phase" in cur_mode:
73
+ phase_l = self.phase(cplx_l).unsqueeze(-1)
74
+ phase_r = self.phase(cplx_r).unsqueeze(-1)
75
+ output = torch.cat((mag_l, phase_l, mag_r, phase_r), axis=-1)
76
+ if "mel" in cur_mode:
77
+ output = torch.cat((self.melscale_transform(mag_l.squeeze(-1)).unsqueeze(-1), self.melscale_transform(mag_r.squeeze(-1)).unsqueeze(-1)), axis=-1)
78
+
79
+ if "log" in cur_mode:
80
+ output = torch.log(output+self.eps)
81
+
82
+ if input.shape[-1] % round(self.n_fft/4) == 0:
83
+ output = output[:, :, :-1]
84
+ if cur_mode!="mel" and self.n_fft % 2 == 0: # discard highest frequency
85
+ output = output[:, 1:]
86
+ front_output_list.append(output.permute(0, 3, 1, 2))
87
+
88
+ # combine all demanded features
89
+ if not front_output_list:
90
+ raise NameError("NameError at FrontEnd: check using features for front-end")
91
+ elif len(mode)!=1:
92
+ for i, cur_output in enumerate(front_output_list):
93
+ if i==0:
94
+ front_output = cur_output
95
+ else:
96
+ front_output = torch.cat((front_output, cur_output), axis=1)
97
+ else:
98
+ front_output = front_output_list[0]
99
+
100
+ return front_output
101
+
102
+
103
+ def mag(self, cplx_input, eps=1e-07):
104
+ # mag_summed = cplx_input.pow(2.).sum(-1) + eps
105
+ mag_summed = cplx_input.real.pow(2.) + cplx_input.imag.pow(2.) + eps
106
+ return mag_summed.pow(0.5)
107
+
108
+
109
+ def phase(self, cplx_input, ):
110
+ return torch.atan2(cplx_input.imag, cplx_input.real)
111
+ # return torch.angle(cplx_input)
112
+
113
+
114
+
115
+ class BackEnd(nn.Module):
116
+ def __init__(self, channel='stereo', \
117
+ n_fft=2048, \
118
+ hop_length=None, \
119
+ win_length=None, \
120
+ window="hann", \
121
+ eps=1e-07, \
122
+ orig_freq=44100, \
123
+ new_freq=16000, \
124
+ device=torch.device("cpu")):
125
+ super(BackEnd, self).__init__()
126
+ self.device = device
127
+ self.channel = channel
128
+ self.n_fft = n_fft
129
+ self.hop_length = n_fft//4 if hop_length==None else hop_length
130
+ self.win_length = n_fft if win_length==None else win_length
131
+ self.eps = eps
132
+ if window=="hann":
133
+ self.window = torch.hann_window(window_length=self.win_length, periodic=True).to(self.device)
134
+ elif window=="hamming":
135
+ self.window = torch.hamming_window(window_length=self.win_length, periodic=True).to(self.device)
136
+ self.resample_func_8k = torchaudio.transforms.Resample(orig_freq=orig_freq, new_freq=8000).to(self.device)
137
+ self.resample_func = torchaudio.transforms.Resample(orig_freq=orig_freq, new_freq=new_freq).to(self.device)
138
+
139
+ def magphase_to_cplx(self, magphase_spec):
140
+ real = magphase_spec[..., 0] * torch.cos(magphase_spec[..., 1])
141
+ imaginary = magphase_spec[..., 0] * torch.sin(magphase_spec[..., 1])
142
+ return torch.cat((real.unsqueeze(-1), imaginary.unsqueeze(-1)), dim=-1)
143
+
144
+
145
+ def forward(self, input, phase, mode):
146
+ # back-end function which convert output spectrograms into waveform
147
+ # input shape : batch x channel x frequency x time
148
+ # output shape : batch x channel x raw waveform
149
+
150
+ # convert to shape : batch x frequency x time x channel
151
+ input = input.permute(0, 2, 3, 1)
152
+ # pad highest frequency
153
+ pad = torch.zeros((input.shape[0], 1, input.shape[2], input.shape[3])).to(self.device)
154
+ input = torch.cat((pad, input), dim=1)
155
+
156
+ back_output_list = []
157
+ channel_count = 0
158
+ for i, cur_mode in enumerate(mode):
159
+ # Real & Imaginary
160
+ if cur_mode=="cplx":
161
+ if self.channel=="mono":
162
+ output = ta_F.istft(input[...,channel_count:channel_count+2], n_fft=self.n_fft, hop_length=self.hop_length, win_length=self.win_length, window=self.window).unsqueeze(1)
163
+ channel_count += 2
164
+ elif self.channel=="stereo":
165
+ cplx_spec = torch.cat([input[...,channel_count:channel_count+2], input[...,channel_count+2:channel_count+4]], dim=0)
166
+ output_wav = ta_F.istft(cplx_spec, n_fft=self.n_fft, hop_length=self.hop_length, win_length=self.win_length, window=self.window)
167
+ output = torch.cat((output_wav[:output_wav.shape[0]//2].unsqueeze(1), output_wav[output_wav.shape[0]//2:].unsqueeze(1)), dim=1)
168
+ channel_count += 4
169
+ back_output_list.append(output)
170
+ # Magnitude & Phase
171
+ elif cur_mode=="mag_phase" or cur_mode=="mag":
172
+ if self.channel=="mono":
173
+ if cur_mode=="mag":
174
+ input_spec = torch.cat((input[...,channel_count:channel_count+1], phase), axis=-1)
175
+ channel_count += 1
176
+ else:
177
+ input_spec = input[...,channel_count:channel_count+2]
178
+ channel_count += 2
179
+ cplx_spec = self.magphase_to_cplx(input_spec)
180
+ output = ta_F.istft(cplx_spec, n_fft=self.n_fft, hop_length=self.hop_length, win_length=self.win_length, window=self.window).unsqueeze(1)
181
+ elif self.channel=="stereo":
182
+ if cur_mode=="mag":
183
+ input_spec_l = torch.cat((input[...,channel_count:channel_count+1], phase[...,0:1]), axis=-1)
184
+ input_spec_r = torch.cat((input[...,channel_count+1:channel_count+2], phase[...,1:2]), axis=-1)
185
+ channel_count += 2
186
+ else:
187
+ input_spec_l = input[...,channel_count:channel_count+2]
188
+ input_spec_r = input[...,channel_count+2:channel_count+4]
189
+ channel_count += 4
190
+ cplx_spec_l = self.magphase_to_cplx(input_spec_l)
191
+ cplx_spec_r = self.magphase_to_cplx(input_spec_r)
192
+ cplx_spec = torch.cat([cplx_spec_l, cplx_spec_r], dim=0)
193
+ output_wav = torch.istft(cplx_spec, n_fft=self.n_fft, hop_length=self.hop_length, win_length=self.win_length, window=self.window)
194
+ output = torch.cat((output_wav[:output_wav.shape[0]//2].unsqueeze(1), output_wav[output_wav.shape[0]//2:].unsqueeze(1)), dim=1)
195
+ channel_count += 4
196
+ back_output_list.append(output)
197
+ elif cur_mode=="griff":
198
+ if self.channel=="mono":
199
+ output = self.griffin_lim(input.squeeze(-1), input.device).unsqueeze(1)
200
+ # output = self.griff(input.permute(0, 3, 1, 2))
201
+ else:
202
+ output_l = self.griffin_lim(input[..., 0], input.device).unsqueeze(1)
203
+ output_r = self.griffin_lim(input[..., 1], input.device).unsqueeze(1)
204
+ output = torch.cat((output_l, output_r), axis=1)
205
+
206
+ back_output_list.append(output)
207
+
208
+ # combine all demanded feature outputs
209
+ if not back_output_list:
210
+ raise NameError("NameError at BackEnd: check using features for back-end")
211
+ elif len(mode)!=1:
212
+ for i, cur_output in enumerate(back_output_list):
213
+ if i==0:
214
+ back_output = cur_output
215
+ else:
216
+ back_output = torch.cat((back_output, cur_output), axis=1)
217
+ else:
218
+ back_output = back_output_list[0]
219
+
220
+ return back_output
221
+
222
+
223
+ def griffin_lim(self, l_est, gpu, n_iter=100):
224
+ l_est = l_est.cpu().detach()
225
+
226
+ l_est = torch.pow(l_est, 1/0.80)
227
+ # l_est [batch, channel, time]
228
+ l_mag = l_est.unsqueeze(-1)
229
+ l_phase = 2 * np.pi * torch.rand_like(l_mag) - np.pi
230
+ real = l_mag * torch.cos(l_phase)
231
+ imag = l_mag * torch.sin(l_phase)
232
+ S = torch.cat((real, imag), axis=-1)
233
+ S_mag = (real**2 + imag**2 + self.eps) ** 1/2
234
+ for i in range(n_iter):
235
+ x = ta_F.istft(S, n_fft=2048, hop_length=512, win_length=2048, window=torch.hann_window(2048))
236
+ S_new = torch.stft(x, n_fft=2048, hop_length=512, win_length=2048, window=torch.hann_window(2048))
237
+ S_new_phase = S_new/mag(S_new)
238
+ S = S_mag * S_new_phase
239
+ return x / torch.max(torch.abs(x))
240
+
modules/loss.py ADDED
@@ -0,0 +1,432 @@
1
+ """
2
+ Implementation of objective functions used in the task 'ITO-Master'
3
+ """
4
+ import numpy as np
5
+ import torch
6
+ import torch.nn.functional as F
7
+ import torch.nn as nn
8
+ import auraloss
9
+
10
+ import os
11
+ import sys
12
+ currentdir = os.path.dirname(os.path.realpath(__file__))
13
+ sys.path.append(os.path.dirname(currentdir))
14
+ from modules.front_back_end import *
15
+
16
+
17
+
18
+ # Root Mean Squared Loss
19
+ # penalizes the volume factor with non-linearlity
20
+ class RMSLoss(nn.Module):
21
+ def __init__(self, reduce, loss_type="l2"):
22
+ super(RMSLoss, self).__init__()
23
+ self.weight_factor = 100.
24
+ if loss_type=="l2":
25
+ self.loss = nn.MSELoss(reduce=None)
26
+
27
+
28
+ def forward(self, est_targets, targets):
29
+ est_targets = est_targets.reshape(est_targets.shape[0]*est_targets.shape[1], est_targets.shape[2])
30
+ targets = targets.reshape(targets.shape[0]*targets.shape[1], targets.shape[2])
31
+ normalized_est = torch.sqrt(torch.mean(est_targets**2, dim=-1))
32
+ normalized_tgt = torch.sqrt(torch.mean(targets**2, dim=-1))
33
+
34
+ weight = torch.clamp(torch.abs(normalized_tgt-normalized_est), min=1/self.weight_factor) * self.weight_factor
35
+
36
+ return torch.mean(weight**1.5 * self.loss(normalized_est, normalized_tgt))
37
+
38
+
39
+
40
+ # Multi-Scale Spectral Loss proposed at the paper "DDSP: DIFFERENTIABLE DIGITAL SIGNAL PROCESSING" (https://arxiv.org/abs/2001.04643)
41
+ # we extend this loss by applying it to mid/side channels
42
+ class MultiScale_Spectral_Loss_MidSide_DDSP(nn.Module):
43
+ def __init__(self, mode='midside', \
44
+ reduce=True, \
45
+ n_filters=None, \
46
+ windows_size=None, \
47
+ hops_size=None, \
48
+ window="hann", \
49
+ eps=1e-7, \
50
+ device=torch.device("cpu")):
51
+ super(MultiScale_Spectral_Loss_MidSide_DDSP, self).__init__()
52
+ self.mode = mode
53
+ self.eps = eps
54
+ self.mid_weight = 0.5 # value in the range of 0.0 ~ 1.0
55
+ self.logmag_weight = 0.1
56
+
57
+ if n_filters is None:
58
+ n_filters = [4096, 2048, 1024, 512]
59
+ if windows_size is None:
60
+ windows_size = [4096, 2048, 1024, 512]
61
+ if hops_size is None:
62
+ hops_size = [1024, 512, 256, 128]
63
+
64
+ self.multiscales = []
65
+ for i in range(len(windows_size)):
66
+ cur_scale = {'window_size' : float(windows_size[i])}
67
+ if self.mode=='midside':
68
+ cur_scale['front_end'] = FrontEnd(channel='mono', \
69
+ n_fft=n_filters[i], \
70
+ hop_length=hops_size[i], \
71
+ win_length=windows_size[i], \
72
+ window=window, \
73
+ device=device)
74
+ elif self.mode=='ori':
75
+ cur_scale['front_end'] = FrontEnd(channel='stereo', \
76
+ n_fft=n_filters[i], \
77
+ hop_length=hops_size[i], \
78
+ win_length=windows_size[i], \
79
+ window=window, \
80
+ device=device)
81
+ self.multiscales.append(cur_scale)
82
+
83
+ self.objective_l1 = nn.L1Loss(reduce=reduce)
84
+ self.objective_l2 = nn.MSELoss(reduce=reduce)
85
+
86
+
87
+ def forward(self, est_targets, targets):
88
+ if self.mode=='midside':
89
+ return self.forward_midside(est_targets, targets)
90
+ elif self.mode=='ori':
91
+ return self.forward_ori(est_targets, targets)
92
+
93
+
94
+ def forward_ori(self, est_targets, targets):
95
+ total_loss = 0.0
96
+ total_mag_loss = 0.0
97
+ total_logmag_loss = 0.0
98
+ for cur_scale in self.multiscales:
99
+ est_mag = cur_scale['front_end'](est_targets, mode=["mag"])
100
+ tgt_mag = cur_scale['front_end'](targets, mode=["mag"])
101
+
102
+ mag_loss = self.magnitude_loss(est_mag, tgt_mag)
103
+ logmag_loss = self.log_magnitude_loss(est_mag, tgt_mag)
104
+ total_mag_loss += mag_loss
105
+ total_logmag_loss += logmag_loss
106
+ # return total_loss
107
+ return (1-self.logmag_weight)*total_mag_loss + \
108
+ (self.logmag_weight)*total_logmag_loss
109
+
110
+
111
+ def forward_midside(self, est_targets, targets):
112
+ est_mid, est_side = self.to_mid_side(est_targets)
113
+ tgt_mid, tgt_side = self.to_mid_side(targets)
114
+ total_loss = 0.0
115
+ total_mag_loss = 0.0
116
+ total_logmag_loss = 0.0
117
+ for cur_scale in self.multiscales:
118
+ est_mid_mag = cur_scale['front_end'](est_mid, mode=["mag"])
119
+ est_side_mag = cur_scale['front_end'](est_side, mode=["mag"])
120
+ tgt_mid_mag = cur_scale['front_end'](tgt_mid, mode=["mag"])
121
+ tgt_side_mag = cur_scale['front_end'](tgt_side, mode=["mag"])
122
+
123
+ mag_loss = self.mid_weight*self.magnitude_loss(est_mid_mag, tgt_mid_mag) + \
124
+ (1-self.mid_weight)*self.magnitude_loss(est_side_mag, tgt_side_mag)
125
+ logmag_loss = self.mid_weight*self.log_magnitude_loss(est_mid_mag, tgt_mid_mag) + \
126
+ (1-self.mid_weight)*self.log_magnitude_loss(est_side_mag, tgt_side_mag)
127
+ total_mag_loss += mag_loss
128
+ total_logmag_loss += logmag_loss
129
+ # return total_loss
130
+ return (1-self.logmag_weight)*total_mag_loss + \
131
+ (self.logmag_weight)*total_logmag_loss
132
+
133
+
134
+ def to_mid_side(self, stereo_in):
135
+ mid = stereo_in[:,0] + stereo_in[:,1]
136
+ side = stereo_in[:,0] - stereo_in[:,1]
137
+ return mid, side
138
+
139
+
140
+ def magnitude_loss(self, est_mag_spec, tgt_mag_spec):
141
+ return torch.norm(self.objective_l1(est_mag_spec, tgt_mag_spec))
142
+
143
+
144
+ def log_magnitude_loss(self, est_mag_spec, tgt_mag_spec):
145
+ est_log_mag_spec = torch.log10(est_mag_spec+self.eps)
146
+ tgt_log_mag_spec = torch.log10(tgt_mag_spec+self.eps)
147
+ return self.objective_l2(est_log_mag_spec, tgt_log_mag_spec)
148
+
149
+
150
+
151
+ # Class of available loss functions
152
+ class Loss:
153
+ def __init__(self, args, reduce=True):
154
+ device = torch.device("cpu")
155
+ if torch.cuda.is_available():
156
+ device = torch.device(f"cuda:{args.gpu}")
157
+ self.l1 = nn.L1Loss(reduce=reduce)
158
+ self.mse = nn.MSELoss(reduce=reduce)
159
+ self.ce = nn.CrossEntropyLoss()
160
+ self.triplet = nn.TripletMarginLoss(margin=1., p=2)
161
+ self.cos = nn.CosineSimilarity(eps=args.eps)
162
+ self.cosemb = nn.CosineEmbeddingLoss()
163
+
164
+ self.multi_scale_spectral_midside = MultiScale_Spectral_Loss_MidSide_DDSP(mode='midside', eps=args.eps, device=device)
165
+ self.multi_scale_spectral_ori = MultiScale_Spectral_Loss_MidSide_DDSP(mode='ori', eps=args.eps, device=device)
166
+ self.gain = RMSLoss(reduce=reduce)
167
+ self.infonce = infoNCE
168
+ # perceptual weighting with mel scaled spectrograms
169
+ self.mrs_mel_perceptual = auraloss.freq.MultiResolutionSTFTLoss(
170
+ fft_sizes=[1024, 2048, 8192],
171
+ hop_sizes=[256, 512, 2048],
172
+ win_lengths=[1024, 2048, 8192],
173
+ scale="mel",
174
+ n_bins=128,
175
+ sample_rate=args.sample_rate,
176
+ perceptual_weighting=True,
177
+ )
178
+
179
+
180
+
181
+
182
+ """
183
+ Audio Feature Loss implementation
184
+ copied from https://github.com/sai-soum/Diff-MST/blob/main/mst/loss.py
185
+ """
186
+
187
+ import librosa
188
+
189
+ from typing import List
190
+ from modules.filter import barkscale_fbanks
191
+
192
+
193
+
194
+
195
+ def compute_mid_side(x: torch.Tensor):
196
+ x_mid = x[:, 0, :] + x[:, 1, :]
197
+ x_side = x[:, 0, :] - x[:, 1, :]
198
+ return x_mid, x_side
199
+
200
+
201
+ def compute_melspectrum(
202
+ x: torch.Tensor,
203
+ sample_rate: int = 44100,
204
+ fft_size: int = 32768,
205
+ n_bins: int = 128,
206
+ **kwargs,
207
+ ):
208
+ """Compute mel-spectrogram.
209
+
210
+ Args:
211
+ x: (bs, 2, seq_len)
212
+ sample_rate: sample rate of audio
213
+ fft_size: size of fft
214
+ n_bins: number of mel bins
215
+
216
+ Returns:
217
+ X: (bs, n_bins)
218
+
219
+ """
220
+ fb = librosa.filters.mel(sr=sample_rate, n_fft=fft_size, n_mels=n_bins)
221
+ fb = torch.tensor(fb).unsqueeze(0).type_as(x)
222
+
223
+ x = x.mean(dim=1, keepdim=True)
224
+ X = torch.fft.rfft(x, n=fft_size, dim=-1)
225
+ X = torch.abs(X)
226
+ X = torch.mean(X, dim=1, keepdim=True) # take mean over time
227
+ X = X.permute(0, 2, 1) # swap time and freq dims
228
+ X = torch.matmul(fb, X)
229
+ X = torch.log(X + 1e-8)
230
+
231
+ return X
232
+
233
+
234
+ def compute_barkspectrum(
235
+ x: torch.Tensor,
236
+ fft_size: int = 32768,
237
+ n_bands: int = 24,
238
+ sample_rate: int = 44100,
239
+ f_min: float = 20.0,
240
+ f_max: float = 20000.0,
241
+ mode: str = "mid-side",
242
+ **kwargs,
243
+ ):
244
+ """Compute bark-spectrogram.
245
+
246
+ Args:
247
+ x: (bs, 2, seq_len)
248
+ fft_size: size of fft
249
+ n_bands: number of bark bins
250
+ sample_rate: sample rate of audio
251
+ f_min: minimum frequency
252
+ f_max: maximum frequency
253
+ mode: "mono", "stereo", or "mid-side"
254
+
255
+ Returns:
256
+ X: (bs, 24)
257
+
258
+ """
259
+ # compute filterbank
260
+ fb = barkscale_fbanks((fft_size // 2) + 1, f_min, f_max, n_bands, sample_rate)
261
+ fb = fb.unsqueeze(0).type_as(x)
262
+ fb = fb.permute(0, 2, 1)
263
+
264
+ if mode == "mono":
265
+ x = x.mean(dim=1) # average over channels
266
+ signals = [x]
267
+ elif mode == "stereo":
268
+ signals = [x[:, 0, :], x[:, 1, :]]
269
+ elif mode == "mid-side":
270
+ x_mid = x[:, 0, :] + x[:, 1, :]
271
+ x_side = x[:, 0, :] - x[:, 1, :]
272
+ signals = [x_mid, x_side]
273
+ else:
274
+ raise ValueError(f"Invalid mode {mode}")
275
+
276
+ outputs = []
277
+ for signal in signals:
278
+ X = torch.stft(
279
+ signal,
280
+ n_fft=fft_size,
281
+ hop_length=fft_size // 4,
282
+ return_complex=True,
283
+ window=torch.hann_window(fft_size).to(x.device),
284
+ ) # compute stft
285
+ X = torch.abs(X) # take magnitude
286
+ X = torch.mean(X, dim=-1, keepdim=True) # take mean over time
287
+ # X = X.permute(0, 2, 1) # swap time and freq dims
288
+ X = torch.matmul(fb, X) # apply filterbank
289
+ X = torch.log(X + 1e-8)
290
+ # X = torch.cat([X, X_log], dim=-1)
291
+ outputs.append(X)
292
+
293
+ # stack into tensor
294
+ X = torch.cat(outputs, dim=-1)
295
+
296
+ return X
297
+
298
+
299
+ def compute_rms(x: torch.Tensor, **kwargs):
300
+ """Compute root mean square energy.
301
+
302
+ Args:
303
+ x: (bs, 1, seq_len)
304
+
305
+ Returns:
306
+ rms: (bs, )
307
+ """
308
+ rms = torch.sqrt(torch.mean(x**2, dim=-1).clamp(min=1e-8))
309
+ return rms
310
+
311
+
312
+ def compute_crest_factor(x: torch.Tensor, **kwargs):
313
+ """Compute crest factor as ratio of peak to rms energy in dB.
314
+
315
+ Args:
316
+ x: (bs, 2, seq_len)
317
+
318
+ """
319
+ num = torch.max(torch.abs(x), dim=-1)[0]
320
+ den = compute_rms(x).clamp(min=1e-8)
321
+ cf = 20 * torch.log10((num / den).clamp(min=1e-8))
322
+ return cf
323
+
324
+
325
+ def compute_stereo_width(x: torch.Tensor, **kwargs):
326
+ """Compute stereo width as ratio of energy in sum and difference signals.
327
+
328
+ Args:
329
+ x: (bs, 2, seq_len)
330
+
331
+ """
332
+ bs, chs, seq_len = x.size()
333
+
334
+ assert chs == 2, "Input must be stereo"
335
+
336
+ # compute sum and diff of stereo channels
337
+ x_sum = x[:, 0, :] + x[:, 1, :]
338
+ x_diff = x[:, 0, :] - x[:, 1, :]
339
+
340
+ # compute power of sum and diff
341
+ sum_energy = torch.mean(x_sum**2, dim=-1)
342
+ diff_energy = torch.mean(x_diff**2, dim=-1)
343
+
344
+ # compute stereo width as ratio
345
+ stereo_width = diff_energy / sum_energy.clamp(min=1e-8)
346
+
347
+ return stereo_width
348
+
349
+
350
+ def compute_stereo_imbalance(x: torch.Tensor, **kwargs):
351
+ """Compute stereo imbalance as ratio of energy in left and right channels.
352
+
353
+ Args:
354
+ x: (bs, 2, seq_len)
355
+
356
+ Returns:
357
+ stereo_imbalance: (bs, )
358
+
359
+ """
360
+ left_energy = torch.mean(x[:, 0, :] ** 2, dim=-1)
361
+ right_energy = torch.mean(x[:, 1, :] ** 2, dim=-1)
362
+
363
+ stereo_imbalance = (right_energy - left_energy) / (
364
+ right_energy + left_energy
365
+ ).clamp(min=1e-8)
366
+
367
+ return stereo_imbalance
368
+
369
+
370
+ class AudioFeatureLoss(torch.nn.Module):
371
+ def __init__(
372
+ self,
373
+ weights: List[float],
374
+ sample_rate: int,
375
+ stem_separation: bool = False,
376
+ use_clap: bool = False,
377
+ ) -> None:
378
+ """Compute loss using a set of differentiable audio features.
379
+
380
+ Args:
381
+ weights: weights for each feature
382
+ sample_rate: sample rate of audio
383
+ stem_separation: whether to compute loss on stems or mix
384
+
385
+ Based on features proposed in:
386
+
387
+ Man, B. D., et al.
388
+ "An analysis and evaluation of audio features for multitrack music mixtures."
389
+ (2014).
390
+
391
+ """
392
+ super().__init__()
393
+ self.weights = weights
394
+ self.sample_rate = sample_rate
395
+ self.stem_separation = stem_separation
396
+ self.sources_list = ["mix"]
397
+ self.source_weights = [1.0]
398
+ self.use_clap = use_clap
399
+
400
+ self.transforms = [
401
+ compute_rms,
402
+ compute_crest_factor,
403
+ compute_stereo_width,
404
+ compute_stereo_imbalance,
405
+ compute_barkspectrum,
406
+ ]
407
+
408
+ assert len(self.transforms) == len(weights)
409
+
410
+ def forward(self, input: torch.Tensor, target: torch.Tensor):
411
+ losses = {}
412
+
413
+ # reshape for example stem dim
414
+ input_stems = input.unsqueeze(1)
415
+ target_stems = target.unsqueeze(1)
416
+
417
+ n_stems = input_stems.shape[1]
418
+
419
+ # iterate over each stem compute loss for each transform
420
+ for stem_idx in range(n_stems):
421
+ input_stem = input_stems[:, stem_idx, ...]
422
+ target_stem = target_stems[:, stem_idx, ...]
423
+
424
+ for transform, weight in zip(self.transforms, self.weights):
425
+ transform_name = "_".join(transform.__name__.split("_")[1:])
426
+ key = f"{self.sources_list[stem_idx]}-{transform_name}"
427
+ input_transform = transform(input_stem, sample_rate=self.sample_rate)
428
+ target_transform = transform(target_stem, sample_rate=self.sample_rate)
429
+ val = torch.nn.functional.mse_loss(input_transform, target_transform)
430
+ losses[key] = weight * val * self.source_weights[stem_idx]
431
+
432
+ return losses
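
A minimal usage sketch (illustrative, not part of the commit) of the AudioFeatureLoss added above: it compares RMS, crest factor, stereo width, stereo imbalance, and bark-spectrum statistics between two stereo signals and returns a dictionary of weighted terms. The five weights are arbitrary example values.

import torch
from modules.loss import AudioFeatureLoss  # assumes this repository's layout

loss_fn = AudioFeatureLoss(weights=[0.1, 0.001, 1.0, 1.0, 0.1], sample_rate=44100)
est = torch.randn(2, 2, 44100)  # batch x stereo channels x samples
ref = torch.randn(2, 2, 44100)
losses = loss_fn(est, ref)      # keys like "mix-rms", "mix-crest_factor", ...
total = sum(losses.values())    # single scalar combining all weighted feature terms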
networks/__init__.py ADDED
@@ -0,0 +1,3 @@
+from .architectures import *
+from .network_utils import *
+from .dasp_additionals import *
networks/architectures.py ADDED
@@ -0,0 +1,405 @@
1
+ """
2
+ Implementation of neural networks used in the task 'Music Mastering Style Transfer'
3
+ - 'Effects Encoder'
4
+ - 'Mastering Style Transfer'
5
+ - 'Differentiable Mastering Style Transfer'
6
+ """
7
+ import torch
8
+ import torch.nn as nn
9
+ import torch.nn.functional as F
10
+ import torch.nn.init as init
11
+ import dasp_pytorch
12
+
13
+ import os
14
+ import sys
15
+ import time
16
+ currentdir = os.path.dirname(os.path.realpath(__file__))
17
+ sys.path.append(currentdir)
18
+ from network_utils import *
19
+ from dasp_additionals import Multiband_Compressor, Distortion, Limiter
20
+
21
+ # compute receptive field
22
+ def compute_receptive_field(kernels, strides, dilations):
23
+ rf = 0
24
+ for i in range(len(kernels)):
25
+ rf += rf * strides[i] + (kernels[i]-strides[i]) * dilations[i]
26
+ return rf
27
+
28
+
29
+
30
+ # Encoder of music effects for contrastive learning of music effects
31
+ class Effects_Encoder(nn.Module):
32
+ def __init__(self, config):
33
+ super(Effects_Encoder, self).__init__()
34
+ # input is stereo channeled audio
35
+ config["channels"].insert(0, 2)
36
+
37
+ # encoder layers
38
+ encoder = []
39
+ for i in range(len(config["kernels"])):
40
+ if config["conv_block"]=='res':
41
+ encoder.append(Res_ConvBlock(dimension=1, \
42
+ in_channels=config["channels"][i], \
43
+ out_channels=config["channels"][i+1], \
44
+ kernel_size=config["kernels"][i], \
45
+ stride=config["strides"][i], \
46
+ padding="SAME", \
47
+ dilation=config["dilation"][i], \
48
+ norm=config["norm"], \
49
+ activation=config["activation"], \
50
+ last_activation=config["activation"]))
51
+ elif config["conv_block"]=='conv':
52
+ encoder.append(ConvBlock(dimension=1, \
53
+ layer_num=1, \
54
+ in_channels=config["channels"][i], \
55
+ out_channels=config["channels"][i+1], \
56
+ kernel_size=config["kernels"][i], \
57
+ stride=config["strides"][i], \
58
+ padding="VALID", \
59
+ dilation=config["dilation"][i], \
60
+ norm=config["norm"], \
61
+ activation=config["activation"], \
62
+ last_activation=config["activation"], \
63
+ mode='conv'))
64
+ self.encoder = nn.Sequential(*encoder)
65
+
66
+ # pooling method
67
+ self.glob_pool = nn.AdaptiveAvgPool1d(1)
68
+
69
+
70
+ # network forward operation
71
+ def forward(self, input):
72
+ enc_output = self.encoder(input)
73
+ glob_pooled = self.glob_pool(enc_output).squeeze(-1)
74
+
75
+ # outputs c feature
76
+ return glob_pooled
77
+
78
+
79
+
80
+ class TCNBlock(torch.nn.Module):
81
+ def __init__(self,
82
+ in_ch,
83
+ out_ch,
84
+ kernel_size=3,
85
+ stride=1,
86
+ dilation=1,
87
+ cond_dim=2048,
88
+ grouped=False,
89
+ causal=False,
90
+ conditional=False,
91
+ **kwargs):
92
+ super(TCNBlock, self).__init__()
93
+
94
+ self.in_ch = in_ch
95
+ self.out_ch = out_ch
96
+ self.kernel_size = kernel_size
97
+ self.dilation = dilation
98
+ self.grouped = grouped
99
+ self.causal = causal
100
+ self.conditional = conditional
101
+
102
+ groups = out_ch if grouped and (in_ch % out_ch == 0) else 1
103
+
104
+ self.pad_length = ((kernel_size-1)*dilation) if self.causal else ((kernel_size-1)*dilation)//2
105
+ self.conv1 = torch.nn.Conv1d(in_ch,
106
+ out_ch,
107
+ kernel_size=kernel_size,
108
+ stride=stride,
109
+ padding=self.pad_length,
110
+ dilation=dilation,
111
+ groups=groups,
112
+ bias=False)
113
+ if grouped:
114
+ self.conv1b = torch.nn.Conv1d(out_ch, out_ch, kernel_size=1)
115
+
116
+ if conditional:
117
+ self.film = FiLM(cond_dim, out_ch)
118
+ self.bn = torch.nn.BatchNorm1d(out_ch)
119
+
120
+ self.relu = torch.nn.LeakyReLU()
121
+ self.res = torch.nn.Conv1d(in_ch,
122
+ out_ch,
123
+ kernel_size=1,
124
+ stride=stride,
125
+ groups=in_ch,
126
+ bias=False)
127
+
128
+
129
+ def forward(self, x, p):
130
+ x_in = x
131
+
132
+ x = self.relu(self.bn(self.conv1(x)))
133
+ x = self.film(x, p)
134
+
135
+ x_res = self.res(x_in)
136
+
137
+ if self.causal:
138
+ x = x[..., :-self.pad_length]
139
+ x += x_res
140
+
141
+ return x
142
+
143
+
144
+
145
+ import pytorch_lightning as pl
146
+ class TCNModel(pl.LightningModule):
147
+ """ Temporal convolutional network with conditioning module.
148
+ Args:
149
+ nparams (int): Number of conditioning parameters.
150
+ ninputs (int): Number of input channels (mono = 1, stereo 2). Default: 1
151
+ noutputs (int): Number of output channels (mono = 1, stereo 2). Default: 1
152
+ nblocks (int): Number of total TCN blocks. Default: 10
153
+ kernel_size (int): Width of the convolutional kernels. Default: 3
154
+ dialation_growth (int): Compute the dilation factor at each block as dilation_growth ** (n % stack_size). Default: 1
155
+ channel_growth (int): Compute the output channels at each black as in_ch * channel_growth. Default: 2
156
+ channel_width (int): When channel_growth = 1 all blocks use convolutions with this many channels. Default: 64
157
+ stack_size (int): Number of blocks that constitute a single stack of blocks. Default: 10
158
+ grouped (bool): Use grouped convolutions to reduce the total number of parameters. Default: False
159
+ causal (bool): Causal TCN configuration does not consider future input values. Default: False
160
+ skip_connections (bool): Skip connections from each block to the output. Default: False
161
+ num_examples (int): Number of evaluation audio examples to log after each epochs. Default: 4
162
+ """
163
+ def __init__(self,
164
+ nparams,
165
+ ninputs=1,
166
+ noutputs=1,
167
+ nblocks=10,
168
+ kernel_size=3,
169
+ stride=1,
170
+ dilation_growth=1,
171
+ channel_growth=1,
172
+ channel_width=32,
173
+ stack_size=10,
174
+ cond_dim=2048,
175
+ grouped=False,
176
+ causal=False,
177
+ skip_connections=False,
178
+ num_examples=4,
179
+ save_dir=None,
180
+ **kwargs):
181
+ super(TCNModel, self).__init__()
182
+ self.save_hyperparameters()
183
+
184
+ self.blocks = torch.nn.ModuleList()
185
+ for n in range(nblocks):
186
+ in_ch = out_ch if n > 0 else ninputs
187
+
188
+ if self.hparams.channel_growth > 1:
189
+ out_ch = in_ch * self.hparams.channel_growth
190
+ else:
191
+ out_ch = self.hparams.channel_width
192
+
193
+ dilation = self.hparams.dilation_growth ** (n % self.hparams.stack_size)
194
+ cur_stride = stride[n] if isinstance(stride, list) else stride
195
+ self.blocks.append(TCNBlock(in_ch,
196
+ out_ch,
197
+ kernel_size=self.hparams.kernel_size,
198
+ stride=cur_stride,
199
+ dilation=dilation,
200
+ padding="same" if self.hparams.causal else "valid",
201
+ causal=self.hparams.causal,
202
+ cond_dim=cond_dim,
203
+ grouped=self.hparams.grouped,
204
+ conditional=True if self.hparams.nparams > 0 else False))
205
+
206
+ self.output = torch.nn.Conv1d(out_ch, noutputs, kernel_size=1)
207
+
208
+ def forward(self, x, cond):
209
+ # iterate over blocks passing conditioning
210
+ for idx, block in enumerate(self.blocks):
211
+ # for SeFa
212
+ if isinstance(cond, list):
213
+ x = block(x, cond[idx])
214
+ else:
215
+ x = block(x, cond)
216
+ skips = 0
217
+
218
+ # out = torch.tanh(self.output(x + skips))
219
+ out = torch.clamp(self.output(x + skips), min=-1, max=1)
220
+
221
+ return out
222
+
223
+ def compute_receptive_field(self):
224
+ """ Compute the receptive field in samples."""
225
+ rf = self.hparams.kernel_size
226
+ for n in range(1,self.hparams.nblocks):
227
+ dilation = self.hparams.dilation_growth ** (n % self.hparams.stack_size)
228
+ rf = rf + ((self.hparams.kernel_size-1) * dilation)
229
+ return rf
230
+
231
+ # add any model hyperparameters here
232
+ @staticmethod
233
+ def add_model_specific_args(parent_parser):
234
+ parser = ArgumentParser(parents=[parent_parser], add_help=False)
235
+ # --- model related ---
236
+ parser.add_argument('--ninputs', type=int, default=1)
237
+ parser.add_argument('--noutputs', type=int, default=1)
238
+ parser.add_argument('--nblocks', type=int, default=4)
239
+ parser.add_argument('--kernel_size', type=int, default=5)
240
+ parser.add_argument('--dilation_growth', type=int, default=10)
241
+ parser.add_argument('--channel_growth', type=int, default=1)
242
+ parser.add_argument('--channel_width', type=int, default=32)
243
+ parser.add_argument('--stack_size', type=int, default=10)
244
+ parser.add_argument('--grouped', default=False, action='store_true')
245
+ parser.add_argument('--causal', default=False, action="store_true")
246
+ parser.add_argument('--skip_connections', default=False, action="store_true")
247
+
248
+ return parser
249
+
250
+
251
+
252
+ # Module for fitting SeFa parameters
253
+ class Dasp_Mastering_Style_Transfer(nn.Module):
254
+ def __init__(self, num_features, sample_rate, \
255
+ tgt_fx_names = ['eq', 'comp', 'imager', 'gain'], \
256
+ model_type='2mlp', \
257
+ config=None, \
258
+ batch_size=4):
259
+ super(Dasp_Mastering_Style_Transfer, self).__init__()
260
+ self.sample_rate = sample_rate
261
+ self.tgt_fx_names = tgt_fx_names
262
+
263
+ self.fx_processors = {}
264
+ self.last_predicted_params = None
265
+ for cur_fx in tgt_fx_names:
266
+ if cur_fx=='eq':
267
+ cur_fx_module = dasp_pytorch.ParametricEQ(sample_rate=sample_rate, \
268
+ min_gain_db = -20.0, \
269
+ max_gain_db = 20.0, \
270
+ min_q_factor = 0.1, \
271
+ max_q_factor=5.0)
272
+ elif cur_fx=='distortion':
273
+ cur_fx_module = Distortion(sample_rate=sample_rate,
274
+ min_gain_db = 0.0,
275
+ max_gain_db = 8.0)
276
+ elif cur_fx=='comp':
277
+ cur_fx_module = dasp_pytorch.Compressor(sample_rate=sample_rate)
278
+ elif cur_fx=='multiband_comp':
279
+ cur_fx_module = Multiband_Compressor(sample_rate=sample_rate)
280
+ elif cur_fx=='gain':
281
+ cur_fx_module = dasp_pytorch.Gain(sample_rate=sample_rate)
282
+ elif cur_fx=='imager':
283
+ continue
284
+ elif cur_fx=='limiter':
285
+ cur_fx_module = Limiter(sample_rate=sample_rate)
286
+ else:
287
+ raise AssertionError(f"current fx name ({cur_fx}) not found")
288
+ self.fx_processors[cur_fx] = cur_fx_module
289
+ total_num_param = sum([self.fx_processors[cur_fx].num_params for cur_fx in self.fx_processors])
290
+ if 'imager' in tgt_fx_names:
291
+ total_num_param += 1
292
+
293
+ ''' model architecture '''
294
+ self.model_type = model_type
295
+ if self.model_type.lower()=='tcn':
296
+ self.network = TCNModel(nparams=config["condition_dimension"], ninputs=2, \
297
+ noutputs=total_num_param, \
298
+ nblocks=config["nblocks"], \
299
+ dilation_growth=config["dilation_growth"], \
300
+ kernel_size=config["kernel_size"], \
301
+ stride=config['stride'], \
302
+ channel_width=config["channel_width"], \
303
+ stack_size=config["stack_size"], \
304
+ cond_dim=config["condition_dimension"], \
305
+ causal=config["causal"])
306
+ elif self.model_type.lower()=='ito':
307
+ self.params = torch.nn.Parameter(torch.ones((batch_size,total_num_param))*0.5)
308
+
309
+ # network forward operation
310
+ def forward(self, x, embedding):
311
+ # embedding mapper
312
+ if self.model_type.lower()=='tcn':
313
+ est_param = self.network(x, embedding)
314
+ est_param = est_param.mean(axis=-1)
315
+ elif self.model_type.lower()=='ito':
316
+ est_param = self.params
317
+ est_param = torch.clamp(est_param, min=0.0, max=1.0)
318
+
319
+ if self.model_type.lower()!='ito':
320
+ est_param = F.sigmoid(est_param)
321
+
322
+ self.last_predicted_params = est_param
323
+
324
+ # dafx chain
325
+ cur_param_idx = 0
326
+ for cur_fx in self.tgt_fx_names:
327
+ if cur_fx=='imager':
328
+ cur_param_count = 1
329
+ x = dasp_pytorch.functional.stereo_widener(x, \
330
+ sample_rate=self.sample_rate, \
331
+ width=est_param[:,cur_param_idx:cur_param_idx+1])
332
+ else:
333
+ cur_param_count = self.fx_processors[cur_fx].num_params
334
+ cur_input_param = est_param[:, cur_param_idx:cur_param_idx+cur_param_count]
335
+ x = self.fx_processors[cur_fx].process_normalized(x, cur_input_param)
336
+ # update param index
337
+ cur_param_idx += cur_param_count
338
+
339
+ return x
340
+
341
+
342
+ def reset_fx_chain(self, ):
343
+ self.fx_processors = {}
344
+ for cur_fx in self.tgt_fx_names:
345
+ if cur_fx=='eq':
346
+ cur_fx_module = dasp_pytorch.ParametricEQ(sample_rate=self.sample_rate, \
347
+ min_gain_db = -20.0, \
348
+ max_gain_db = 20.0, \
349
+ min_q_factor = 0.1, \
350
+ max_q_factor=5.0)
351
+ elif cur_fx=='distortion':
352
+ cur_fx_module = Distortion(sample_rate=self.sample_rate,
353
+ min_gain_db = 0.0,
354
+ max_gain_db = 8.0)
355
+ elif cur_fx=='comp':
356
+ cur_fx_module = dasp_pytorch.Compressor(sample_rate=self.sample_rate)
357
+ elif cur_fx=='multiband_comp':
358
+ cur_fx_module = Multiband_Compressor(sample_rate=self.sample_rate)
359
+ elif cur_fx=='gain':
360
+ cur_fx_module = dasp_pytorch.Gain(sample_rate=self.sample_rate)
361
+ elif cur_fx=='imager':
362
+ continue
363
+ elif cur_fx=='limiter':
364
+ cur_fx_module = Limiter(sample_rate=self.sample_rate)
365
+ else:
366
+ raise AssertionError(f"current fx name ({cur_fx}) not found")
367
+ self.fx_processors[cur_fx] = cur_fx_module
368
+
369
+ def get_last_predicted_params(self):
370
+ if self.last_predicted_params is None:
371
+ return None
372
+
373
+ params_dict = {}
374
+ cur_param_idx = 0
375
+
376
+ for cur_fx in self.tgt_fx_names:
377
+ if cur_fx == 'imager':
378
+ cur_param_count = 1
379
+ normalized_param = self.last_predicted_params[:, cur_param_idx:cur_param_idx+1]
380
+ original_param = self.denormalize_param(normalized_param, 0, 1)
381
+ params_dict[cur_fx] = original_param
382
+ else:
383
+ cur_param_count = self.fx_processors[cur_fx].num_params
384
+ normalized_params = self.last_predicted_params[:, cur_param_idx:cur_param_idx+cur_param_count]
385
+ original_params = self.denormalize_params(cur_fx, normalized_params)
386
+ params_dict[cur_fx] = original_params
387
+
388
+ cur_param_idx += cur_param_count
389
+
390
+ return params_dict
391
+
392
+ def denormalize_params(self, fx_name, normalized_params):
393
+ fx_processor = self.fx_processors[fx_name]
394
+ original_params = {}
395
+
396
+ for i, (param_name, (min_val, max_val)) in enumerate(fx_processor.param_ranges.items()):
397
+ original_param = self.denormalize_param(normalized_params[:, i:i+1], min_val, max_val)
398
+ original_params[param_name] = original_param
399
+
400
+ return original_params
401
+
402
+ @staticmethod
403
+ def denormalize_param(normalized_param, min_val, max_val):
404
+ return normalized_param * (max_val - min_val) + min_val
405
+
networks/dasp_additionals.py ADDED
@@ -0,0 +1,441 @@
1
+ """
2
+ Implementation of differentiable mastering effects based on DASP-pytorch and torchcomp libraries
3
+ - Distortion
4
+ - Multiband Compressor
5
+ - Limiter
6
+ DASP-pytorch: https://github.com/csteinmetz1/dasp-pytorch
7
+ torchcomp: https://github.com/yoyololicon/torchcomp
8
+ """
9
+ import dasp_pytorch
10
+ from dasp_pytorch.modules import Processor
11
+ import torchcomp
12
+ import torch
13
+ import torch.nn.functional as F
14
+ import torch.nn as nn
15
+ import numpy as np
16
+ import time
17
+
18
+
19
+ EPS = 1e-6
20
+
21
+ class Distortion(Processor):
22
+ def __init__(
23
+ self,
24
+ sample_rate: int,
25
+ min_gain_db: float = 0.0,
26
+ max_gain_db: float = 24.0,
27
+ ):
28
+ super().__init__()
29
+ self.sample_rate = sample_rate
30
+ self.process_fn = distortion
31
+ self.param_ranges = {
32
+ "drive_db": (min_gain_db, max_gain_db),
33
+ "parallel_weight_factor": (0.2, 0.7),
34
+ }
35
+ self.num_params = len(self.param_ranges)
36
+
37
+ def distortion(x: torch.Tensor,
38
+ sample_rate: int,
39
+ drive_db: torch.Tensor,
40
+ parallel_weight_factor: torch.Tensor()):
41
+ """Simple soft-clipping distortion with drive control.
42
+
43
+ Args:
44
+ x (torch.Tensor): Input audio tensor with shape (bs, chs, seq_len)
45
+ sample_rate (int): Audio sample rate.
46
+ drive_db (torch.Tensor): Drive in dB with shape (bs)
47
+
48
+ Returns:
49
+ torch.Tensor: Output audio tensor with shape (bs, chs, seq_len)
50
+
51
+ """
52
+ bs, chs, seq_len = x.size()
53
+ parallel_weight_factor = parallel_weight_factor.view(-1, 1, 1)
54
+
55
+ # return torch.tanh(x * (10 ** (drive_db.view(bs, chs, -1) / 20.0))) -> wrong?
56
+ x_dist = torch.tanh(x * (10 ** (drive_db.view(bs, 1, 1) / 20.0)))
57
+
58
+ # parallel compuatation
59
+ return parallel_weight_factor * x_dist + (1-parallel_weight_factor) * x
60
+
61
+
62
+
63
+ class Multiband_Compressor(Processor):
64
+ def __init__(
65
+ self,
66
+ sample_rate: int,
67
+ min_threshold_db_comp: float = -60.0,
68
+ max_threshold_db_comp: float = 0.0-EPS,
69
+ min_ratio_comp: float = 1.0+EPS,
70
+ max_ratio_comp: float = 20.0,
71
+ min_attack_ms_comp: float = 5.0,
72
+ max_attack_ms_comp: float = 100.0,
73
+ min_release_ms_comp: float = 5.0,
74
+ max_release_ms_comp: float = 100.0,
75
+ min_threshold_db_exp: float = -60.0,
76
+ max_threshold_db_exp: float = 0.0-EPS,
77
+ min_ratio_exp: float = 0.0+EPS,
78
+ max_ratio_exp: float = 1.0-EPS,
79
+ min_attack_ms_exp: float = 5.0,
80
+ max_attack_ms_exp: float = 100.0,
81
+ min_release_ms_exp: float = 5.0,
82
+ max_release_ms_exp: float = 100.0,
83
+ ):
84
+ super().__init__()
85
+ self.sample_rate = sample_rate
86
+ self.process_fn = multiband_compressor
87
+ self.param_ranges = {
88
+ "low_cutoff": (20, 300),
89
+ "high_cutoff": (2000, 12000),
90
+ "parallel_weight_factor": (0.2, 0.7),
91
+
92
+ "low_shelf_comp_thresh": (min_threshold_db_comp, max_threshold_db_comp),
93
+ "low_shelf_comp_ratio": (min_ratio_comp, max_ratio_comp),
94
+ "low_shelf_exp_thresh": (min_threshold_db_exp, max_threshold_db_exp),
95
+ "low_shelf_exp_ratio": (min_ratio_exp, max_ratio_exp),
96
+ "low_shelf_at": (min_attack_ms_exp, max_attack_ms_exp),
97
+ "low_shelf_rt": (min_release_ms_exp, max_release_ms_exp),
98
+
99
+ "mid_band_comp_thresh": (min_threshold_db_comp, max_threshold_db_comp),
100
+ "mid_band_comp_ratio": (min_ratio_comp, max_ratio_comp),
101
+ "mid_band_exp_thresh": (min_threshold_db_exp, max_threshold_db_exp),
102
+ "mid_band_exp_ratio": (min_ratio_exp, max_ratio_exp),
103
+ "mid_band_at": (min_attack_ms_exp, max_attack_ms_exp),
104
+ "mid_band_rt": (min_release_ms_exp, max_release_ms_exp),
105
+
106
+ "high_shelf_comp_thresh": (min_threshold_db_comp, max_threshold_db_comp),
107
+ "high_shelf_comp_ratio": (min_ratio_comp, max_ratio_comp),
108
+ "high_shelf_exp_thresh": (min_threshold_db_exp, max_threshold_db_exp),
109
+ "high_shelf_exp_ratio": (min_ratio_exp, max_ratio_exp),
110
+ "high_shelf_at": (min_attack_ms_exp, max_attack_ms_exp),
111
+ "high_shelf_rt": (min_release_ms_exp, max_release_ms_exp),
112
+ }
113
+ self.num_params = len(self.param_ranges)
114
+
115
+
116
+
117
+ def linkwitz_riley_4th_order(
118
+ x: torch.Tensor,
119
+ cutoff_freq: torch.Tensor,
120
+ sample_rate: float,
121
+ filter_type: str):
122
+ q_factor = torch.ones(cutoff_freq.shape) / torch.sqrt(torch.tensor([2.0]))
123
+ gain_db = torch.zeros(cutoff_freq.shape)
124
+ q_factor = q_factor.to(x.device)
125
+ gain_db = gain_db.to(x.device)
126
+
127
+ b, a = dasp_pytorch.signal.biquad(
128
+ gain_db,
129
+ cutoff_freq,
130
+ q_factor,
131
+ sample_rate,
132
+ filter_type
133
+ )
134
+
135
+ del gain_db
136
+ del q_factor
137
+
138
+ eff_bs = x.size(0)
139
+ # six second order sections
140
+ sos = torch.cat((b, a), dim=-1).unsqueeze(1)
141
+
142
+ # apply filter twice to phase difference amounts of 360°
143
+ x = dasp_pytorch.signal.sosfilt_via_fsm(sos, x)
144
+ x_out = dasp_pytorch.signal.sosfilt_via_fsm(sos, x)
145
+
146
+ return x_out
147
+
148
+
149
+ def multiband_compressor(
150
+ x: torch.Tensor,
151
+ sample_rate: float,
152
+
153
+ low_cutoff: torch.Tensor,
154
+ high_cutoff: torch.Tensor,
155
+ parallel_weight_factor: torch.Tensor,
156
+
157
+ low_shelf_comp_thresh: torch.Tensor,
158
+ low_shelf_comp_ratio: torch.Tensor,
159
+ low_shelf_exp_thresh: torch.Tensor,
160
+ low_shelf_exp_ratio: torch.Tensor,
161
+ low_shelf_at: torch.Tensor,
162
+ low_shelf_rt: torch.Tensor,
163
+
164
+ mid_band_comp_thresh: torch.Tensor,
165
+ mid_band_comp_ratio: torch.Tensor,
166
+ mid_band_exp_thresh: torch.Tensor,
167
+ mid_band_exp_ratio: torch.Tensor,
168
+ mid_band_at: torch.Tensor,
169
+ mid_band_rt: torch.Tensor,
170
+
171
+ high_shelf_comp_thresh: torch.Tensor,
172
+ high_shelf_comp_ratio: torch.Tensor,
173
+ high_shelf_exp_thresh: torch.Tensor,
174
+ high_shelf_exp_ratio: torch.Tensor,
175
+ high_shelf_at: torch.Tensor,
176
+ high_shelf_rt: torch.Tensor,
177
+ ):
178
+ """Multiband (Three-band) Compressor.
179
+
180
+ Low-shelf -> Mid-band -> High-shelf
181
+
182
+ Args:
183
+ x (torch.Tensor): Time domain tensor with shape (bs, chs, seq_len)
184
+ sample_rate (float): Audio sample rate.
185
+ low_cutoff (torch.Tensor): Low-shelf filter cutoff frequency in Hz.
186
+ high_cutoff (torch.Tensor): High-shelf filter cutoff frequency in Hz.
187
+ low_shelf_comp_thresh (torch.Tensor):
188
+ low_shelf_comp_ratio (torch.Tensor):
189
+ low_shelf_exp_thresh (torch.Tensor):
190
+ low_shelf_exp_ratio (torch.Tensor):
191
+ low_shelf_at (torch.Tensor):
192
+ low_shelf_rt (torch.Tensor):
193
+ mid_band_comp_thresh (torch.Tensor):
194
+ mid_band_comp_ratio (torch.Tensor):
195
+ mid_band_exp_thresh (torch.Tensor):
196
+ mid_band_exp_ratio (torch.Tensor):
197
+ mid_band_at (torch.Tensor):
198
+ mid_band_rt (torch.Tensor):
199
+ high_shelf_comp_thresh (torch.Tensor):
200
+ high_shelf_comp_ratio (torch.Tensor):
201
+ high_shelf_exp_thresh (torch.Tensor):
202
+ high_shelf_exp_ratio (torch.Tensor):
203
+ high_shelf_at (torch.Tensor):
204
+ high_shelf_rt (torch.Tensor):
205
+
206
+ Returns:
207
+ y (torch.Tensor): Filtered signal.
208
+ """
209
+ bs, chs, seq_len = x.size()
210
+
211
+ low_cutoff = low_cutoff.view(-1, 1, 1)
212
+ high_cutoff = high_cutoff.view(-1, 1, 1)
213
+ parallel_weight_factor = parallel_weight_factor.view(-1, 1, 1)
214
+
215
+ eff_bs = x.size(0)
216
+
217
+ ''' crossover filter '''
218
+ # Low-shelf band (low frequencies)
219
+ low_band = linkwitz_riley_4th_order(x, low_cutoff, sample_rate, filter_type="low_pass")
220
+ # High-shelf band (high frequencies)
221
+ high_band = linkwitz_riley_4th_order(x, high_cutoff, sample_rate, filter_type="high_pass")
222
+ # Mid-band (band-pass)
223
+ mid_band = x - low_band - high_band # Subtract low and high bands from original signal
224
+
225
+ ''' compressor '''
226
+ try:
227
+ x_out_low = low_band * torchcomp.compexp_gain(low_band.sum(axis=1).abs(),
228
+ comp_thresh=low_shelf_comp_thresh, \
229
+ comp_ratio=low_shelf_comp_ratio, \
230
+ exp_thresh=low_shelf_exp_thresh, \
231
+ exp_ratio=low_shelf_exp_ratio, \
232
+ at=torchcomp.ms2coef(low_shelf_at, sample_rate), \
233
+ rt=torchcomp.ms2coef(low_shelf_rt, sample_rate)).unsqueeze(1)
234
+ except:
235
+ x_out_low = low_band
236
+ print('\t!!!failed computing low-band compression!!!')
237
+ try:
238
+ x_out_high = high_band * torchcomp.compexp_gain(high_band.sum(axis=1).abs(),
239
+ comp_thresh=high_shelf_comp_thresh, \
240
+ comp_ratio=high_shelf_comp_ratio, \
241
+ exp_thresh=high_shelf_exp_thresh, \
242
+ exp_ratio=high_shelf_exp_ratio, \
243
+ at=torchcomp.ms2coef(high_shelf_at, sample_rate), \
244
+ rt=torchcomp.ms2coef(high_shelf_rt, sample_rate)).unsqueeze(1)
245
+ except:
246
+ x_out_high = high_band
247
+ print('\t!!!failed computing high-band compression!!!')
248
+ try:
249
+ x_out_mid = mid_band * torchcomp.compexp_gain(mid_band.sum(axis=1).abs(),
250
+ comp_thresh=mid_band_comp_thresh, \
251
+ comp_ratio=mid_band_comp_ratio, \
252
+ exp_thresh=mid_band_exp_thresh, \
253
+ exp_ratio=mid_band_exp_ratio, \
254
+ at=torchcomp.ms2coef(mid_band_at, sample_rate), \
255
+ rt=torchcomp.ms2coef(mid_band_rt, sample_rate)).unsqueeze(1)
256
+ except:
257
+ x_out_mid = mid_band
258
+ print('\t!!!failed computing mid-band compression!!!')
259
+ x_out = x_out_low + x_out_high + x_out_mid
260
+
261
+ # parallel compression: blend the processed and dry signals
262
+ x_out = parallel_weight_factor * x_out + (1-parallel_weight_factor) * x
263
+
264
+ # move channels back
265
+ x_out = x_out.view(bs, chs, seq_len)
266
+
267
+ return x_out
268
+
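A usage sketch for multiband_compressor with one parameter value per batch item. The values follow the ranges declared in Multiband_Compressor, and all dynamics parameters are assumed to be plain (bs,) tensors, as they are passed straight through to torchcomp.compexp_gain:

import torch

bs = 2
x = torch.randn(bs, 2, 44100)                        # (bs, chs, seq_len)
band_params = {}
for band in ["low_shelf", "mid_band", "high_shelf"]:
    band_params[f"{band}_comp_thresh"] = torch.full((bs,), -15.0)   # dB
    band_params[f"{band}_comp_ratio"]  = torch.full((bs,), 3.0)
    band_params[f"{band}_exp_thresh"]  = torch.full((bs,), -25.0)   # dB
    band_params[f"{band}_exp_ratio"]   = torch.full((bs,), 0.9)
    band_params[f"{band}_at"]          = torch.full((bs,), 10.0)    # ms
    band_params[f"{band}_rt"]          = torch.full((bs,), 100.0)   # ms
y = multiband_compressor(
    x, 44100.0,
    low_cutoff=torch.full((bs,), 200.0),
    high_cutoff=torch.full((bs,), 4000.0),
    parallel_weight_factor=torch.full((bs,), 0.7),   # 1.0 = fully processed
    **band_params,
)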
269
+
270
+
271
+
272
+ class Limiter(Processor):
273
+ def __init__(
274
+ self,
275
+ sample_rate: int,
276
+ min_threshold_db: float = -60.0,
277
+ max_threshold_db: float = 0.0-EPS,
278
+ min_attack_ms: float = 5.0,
279
+ max_attack_ms: float = 100.0,
280
+ min_release_ms: float = 5.0,
281
+ max_release_ms: float = 100.0,
282
+ ):
283
+ super().__init__()
284
+ self.sample_rate = sample_rate
285
+ self.process_fn = limiter
286
+ self.param_ranges = {
287
+ "threshold": (min_threshold_db, max_threshold_db),
288
+ "at": (min_attack_ms, max_attack_ms),
289
+ "rt": (min_release_ms, max_release_ms),
290
+ }
291
+ self.num_params = len(self.param_ranges)
292
+
293
+
294
+ def limiter(
295
+ x: torch.Tensor,
296
+ sample_rate: float,
297
+ threshold: float,
298
+ at: float,
299
+ rt: float,
300
+ ):
301
+ """Limiter.
302
+
303
+ based on the differentiable limiter gain from Chin-Yun Yu's torchcomp
304
+
305
+ Args:
306
+ x (torch.Tensor): Time domain tensor with shape (bs, chs, seq_len)
307
+ sample_rate (float): Audio sample rate.
308
+ threshold (torch.Tensor): Limiter threshold in dB.
309
+ at (torch.Tensor): Attack time.
310
+ rt (torch.Tensor): Release time.
311
+
312
+ Returns:
313
+ y (torch.Tensor): Limited signal.
314
+ """
315
+ bs, chs, seq_len = x.size()
316
+
317
+ x_out = x * torchcomp.limiter_gain(x.sum(axis=1).abs(),
318
+ threshold=threshold,
319
+ at=torchcomp.ms2coef(at, sample_rate),
320
+ rt=torchcomp.ms2coef(rt, sample_rate)).unsqueeze(1)
321
+
322
+ # move channels back
323
+ x_out = x_out.view(bs, chs, seq_len)
324
+
325
+ return x_out
326
+
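A usage sketch for the limiter above; the parameter tensors are assumed to be (bs,)-shaped, with the threshold in dB and the attack/release times in milliseconds, matching the ranges declared in the Limiter processor:

import torch

bs = 2
x = torch.randn(bs, 2, 44100)                        # (bs, chs, seq_len)
y = limiter(
    x, 44100.0,
    threshold=torch.full((bs,), -3.0),               # dB
    at=torch.full((bs,), 1.0),                       # ms
    rt=torch.full((bs,), 100.0),                     # ms
)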
327
+
328
+
329
+
330
+ class Random_Augmentation_Dasp(nn.Module):
331
+ def __init__(self, sample_rate, \
332
+ tgt_fx_names = ['eq', 'comp', 'imager', 'gain']):
333
+ super(Random_Augmentation_Dasp, self).__init__()
334
+ self.sample_rate = sample_rate
335
+ self.tgt_fx_names = tgt_fx_names
336
+
337
+ self.device = torch.device("cpu")
338
+ if torch.cuda.is_available():
339
+ self.device = torch.device("cuda")
340
+
341
+ self.fx_prob = {'eq': 0.9, \
342
+ 'distortion': 0.3, \
343
+ 'comp': 0.8, \
344
+ 'multiband_comp': 0.8, \
345
+ 'gain': 0.85, \
346
+ 'imager': 0.6, \
347
+ 'limiter': 1.0}
348
+ self.fx_processors = {}
349
+ for cur_fx in tgt_fx_names:
350
+ if cur_fx=='eq':
351
+ cur_fx_module = dasp_pytorch.ParametricEQ(sample_rate=sample_rate, \
352
+ min_gain_db = -10.0, \
353
+ max_gain_db = 10.0, \
354
+ min_q_factor = 0.5, \
355
+ max_q_factor=5.0)
356
+ elif cur_fx=='distortion':
357
+ cur_fx_module = Distortion(sample_rate=sample_rate,
358
+ min_gain_db = 0.0,
359
+ max_gain_db = 4.0)
360
+ elif cur_fx=='comp':
361
+ cur_fx_module = dasp_pytorch.Compressor(sample_rate=sample_rate)
362
+ elif cur_fx=='multiband_comp':
363
+ cur_fx_module = Multiband_Compressor(sample_rate=sample_rate,
364
+ min_threshold_db_comp = -30.0,
365
+ max_threshold_db_comp = -5.0,
366
+ min_ratio_comp = 1.5,
367
+ max_ratio_comp = 6.0,
368
+ min_attack_ms_comp = 1.0,
369
+ max_attack_ms_comp = 20.0,
370
+ min_release_ms_comp = 20.0,
371
+ max_release_ms_comp = 500.0,
372
+ min_threshold_db_exp = -30.0,
373
+ max_threshold_db_exp = -5.0,
374
+ min_ratio_exp = 0.0+EPS,
375
+ max_ratio_exp = 1.0-EPS,
376
+ min_attack_ms_exp = 1.0,
377
+ max_attack_ms_exp = 20.0,
378
+ min_release_ms_exp = 20.0,
379
+ max_release_ms_exp = 500.0,
380
+ )
381
+ elif cur_fx=='gain':
382
+ cur_fx_module = dasp_pytorch.Gain(sample_rate=sample_rate,
383
+ min_gain_db = 0.0,
384
+ max_gain_db = 6.0,)
385
+ elif cur_fx=='imager':
386
+ continue
387
+ elif cur_fx=='limiter':
388
+ cur_fx_module = Limiter(sample_rate=sample_rate,
389
+ min_threshold_db = -20.0,
390
+ max_threshold_db = 0.0-EPS,
391
+ min_attack_ms = 0.1,
392
+ max_attack_ms = 5.0,
393
+ min_release_ms = 20.0,
394
+ max_release_ms = 1000.0,)
395
+ else:
396
+ raise AssertionError(f"current fx name ({cur_fx}) not found")
397
+ self.fx_processors[cur_fx] = cur_fx_module
398
+ total_num_param = sum([self.fx_processors[cur_fx].num_params for cur_fx in self.fx_processors])
399
+ if 'imager' in tgt_fx_names:
400
+ total_num_param += 1
401
+ self.total_num_param = total_num_param
402
+
403
+
404
+ # network forward operation
405
+ def forward(self, x, rand_param=None, use_mask=None):
406
+ if rand_param is None:
407
+ rand_param = torch.rand((x.shape[0], self.total_num_param)).to(self.device)
408
+ else:
409
+ assert rand_param.shape[0]==x.shape[0] and rand_param.shape[1]==self.total_num_param
410
+ if use_mask is None:
411
+ use_mask = self.random_mask_generator(x.shape[0])
412
+
413
+ # dafx chain
414
+ cur_param_idx = 0
415
+ for cur_fx in self.tgt_fx_names:
416
+ cur_param_count = 1 if cur_fx=='imager' else self.fx_processors[cur_fx].num_params
417
+ if cur_fx=='imager':
418
+ x_processed = dasp_pytorch.functional.stereo_widener(x, \
419
+ sample_rate=self.sample_rate, \
420
+ width=rand_param[:,cur_param_idx:cur_param_idx+1])
421
+ else:
422
+ cur_input_param = rand_param[:, cur_param_idx:cur_param_idx+cur_param_count]
423
+ x_processed = self.fx_processors[cur_fx].process_normalized(x, cur_input_param)
424
+ # process all FX but decide to use the processed output based on probability
425
+ cur_mask = use_mask[cur_fx]
426
+ x = x_processed*cur_mask + x*~cur_mask
427
+ # update param index
428
+ cur_param_idx += cur_param_count
429
+
430
+ return x
431
+
432
+
433
+ def random_mask_generator(self, batch_size, repeat=1):
434
+ mask = {}
435
+ for cur_fx in self.tgt_fx_names:
436
+ mask[cur_fx] = self.fx_prob[cur_fx] > torch.rand(batch_size).view(-1, 1, 1)
437
+ if repeat>1:
438
+ mask[cur_fx] = mask[cur_fx].repeat(repeat, 1, 1)
439
+ mask[cur_fx] = mask[cur_fx].to(self.device)
440
+ return mask
441
+
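A usage sketch for the augmentation chain above: every effect in tgt_fx_names is computed, and the per-effect Bernoulli mask drawn from fx_prob decides whether the processed or the dry signal is carried forward. Stereo input is assumed since the 'imager' stage calls the stereo widener; device placement is left to the caller:

import torch

augment = Random_Augmentation_Dasp(
    sample_rate=44100,
    tgt_fx_names=['eq', 'comp', 'multiband_comp', 'gain', 'imager', 'limiter'],
)
x = torch.randn(4, 2, 44100).to(augment.device)               # (bs, chs, seq_len)
rand_param = torch.rand(4, augment.total_num_param).to(augment.device)
y = augment(x, rand_param=rand_param)                          # same shape as x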
networks/network_utils.py ADDED
@@ -0,0 +1,254 @@
1
+ """
2
+ Utility File
3
+ containing functions for neural networks
4
+ """
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+ import torch.nn.init as init
8
+ import torch
9
+ import torchaudio
10
+
11
+
12
+
13
+ # 2-dimensional convolutional layer
14
+ # in the order of conv -> norm -> activation
15
+ class Conv2d_layer(nn.Module):
16
+ def __init__(self, in_channels, out_channels, kernel_size, \
17
+ stride=1, \
18
+ padding="SAME", dilation=(1,1), bias=True, \
19
+ norm="batch", activation="relu", \
20
+ mode="conv"):
21
+ super(Conv2d_layer, self).__init__()
22
+
23
+ self.conv2d = nn.Sequential()
24
+
25
+ if isinstance(kernel_size, int):
26
+ kernel_size = [kernel_size, kernel_size]
27
+ if isinstance(stride, int):
28
+ stride = [stride, stride]
29
+ if isinstance(dilation, int):
30
+ dilation = [dilation, dilation]
31
+
32
+ ''' padding '''
33
+ if mode=="deconv":
34
+ padding = tuple(int((current_kernel - 1)/2) for current_kernel in kernel_size)
35
+ out_padding = tuple(0 if current_stride == 1 else 1 for current_stride in stride)
36
+ elif mode=="conv":
37
+ if padding == "SAME":
38
+ f_pad = int((kernel_size[0]-1) * dilation[0])
39
+ t_pad = int((kernel_size[1]-1) * dilation[1])
40
+ t_l_pad = int(t_pad//2)
41
+ t_r_pad = t_pad - t_l_pad
42
+ f_l_pad = int(f_pad//2)
43
+ f_r_pad = f_pad - f_l_pad
44
+ padding_area = (t_l_pad, t_r_pad, f_l_pad, f_r_pad)
45
+ elif padding == "VALID":
46
+ padding = 0
47
+ else:
48
+ pass
49
+
50
+ ''' convolutional layer '''
51
+ if mode=="deconv":
52
+ self.conv2d.add_module("deconv2d", nn.ConvTranspose2d(in_channels, out_channels, \
53
+ (kernel_size[0], kernel_size[1]), \
54
+ stride=stride, \
55
+ padding=padding, output_padding=out_padding, \
56
+ dilation=dilation, \
57
+ bias=bias))
58
+ elif mode=="conv":
59
+ self.conv2d.add_module(f"{mode}2d_pad", nn.ReflectionPad2d(padding_area))
60
+ self.conv2d.add_module(f"{mode}2d", nn.Conv2d(in_channels, out_channels, \
61
+ (kernel_size[0], kernel_size[1]), \
62
+ stride=stride, \
63
+ padding=0, \
64
+ dilation=dilation, \
65
+ bias=bias))
66
+
67
+ ''' normalization '''
68
+ if norm=="batch":
69
+ self.conv2d.add_module("batch_norm", nn.BatchNorm2d(out_channels))
70
+
71
+ ''' activation '''
72
+ if activation=="relu":
73
+ self.conv2d.add_module("relu", nn.ReLU())
74
+ elif activation=="lrelu":
75
+ self.conv2d.add_module("lrelu", nn.LeakyReLU())
76
+
77
+
78
+ def forward(self, input):
79
+ # input shape should be : batch x channel x height x width
80
+ output = self.conv2d(input)
81
+ return output
82
+
83
+
84
+
85
+ # 1-dimensional convolutional layer
86
+ # in the order of conv -> norm -> activation
87
+ class Conv1d_layer(nn.Module):
88
+ def __init__(self, in_channels, out_channels, kernel_size, \
89
+ stride=1, \
90
+ padding="SAME", dilation=1, bias=True, \
91
+ norm="batch", activation="relu", \
92
+ mode="conv"):
93
+ super(Conv1d_layer, self).__init__()
94
+
95
+ self.conv1d = nn.Sequential()
96
+
97
+ ''' padding '''
98
+ if mode=="deconv":
99
+ padding = int(dilation * (kernel_size-1) / 2)
100
+ out_padding = 0 if stride==1 else 1
101
+ elif mode=="conv" or "alias_free" in mode:
102
+ if padding == "SAME":
103
+ pad = int((kernel_size-1) * dilation)
104
+ l_pad = int(pad//2)
105
+ r_pad = pad - l_pad
106
+ padding_area = (l_pad, r_pad)
107
+ elif padding == "VALID":
108
+ padding_area = (0, 0)
109
+ else:
110
+ pass
111
+
112
+ ''' convolutional layer '''
113
+ if mode=="deconv":
114
+ self.conv1d.add_module("deconv1d", nn.ConvTranspose1d(in_channels, out_channels, kernel_size, \
115
+ stride=stride, padding=padding, output_padding=out_padding, \
116
+ dilation=dilation, \
117
+ bias=bias))
118
+ elif mode=="conv":
119
+ self.conv1d.add_module(f"{mode}1d_pad", nn.ReflectionPad1d(padding_area))
120
+ self.conv1d.add_module(f"{mode}1d", nn.Conv1d(in_channels, out_channels, kernel_size, \
121
+ stride=stride, padding=0, \
122
+ dilation=dilation, \
123
+ bias=bias))
124
+ elif "alias_free" in mode:
125
+ if "up" in mode:
126
+ up_factor = stride * 2
127
+ down_factor = 2
128
+ elif "down" in mode:
129
+ up_factor = 2
130
+ down_factor = stride * 2
131
+ else:
132
+ raise ValueError("choose alias-free method : 'up' or 'down'")
133
+ # procedure : conv -> upsample -> lrelu -> low-pass filter -> downsample
134
+ # torchaudio.transforms.Resample defaults to the 'sinc_interpolation' resampling method, which applies a low-pass filter as part of resampling
135
+ # details at https://pytorch.org/audio/stable/transforms.html
136
+ self.conv1d.add_module(f"{mode}1d_pad", nn.ReflectionPad1d(padding_area))
137
+ self.conv1d.add_module(f"{mode}1d", nn.Conv1d(in_channels, out_channels, kernel_size, \
138
+ stride=1, padding=0, \
139
+ dilation=dilation, \
140
+ bias=bias))
141
+ self.conv1d.add_module(f"{mode}upsample", torchaudio.transforms.Resample(orig_freq=1, new_freq=up_factor))
142
+ self.conv1d.add_module(f"{mode}lrelu", nn.LeakyReLU())
143
+ self.conv1d.add_module(f"{mode}downsample", torchaudio.transforms.Resample(orig_freq=down_factor, new_freq=1))
144
+
145
+ ''' normalization '''
146
+ if norm=="batch":
147
+ self.conv1d.add_module("batch_norm", nn.BatchNorm1d(out_channels))
148
+ # self.conv1d.add_module("batch_norm", nn.SyncBatchNorm(out_channels))
149
+
150
+ ''' activation '''
151
+ if 'alias_free' not in mode:
152
+ if activation=="relu":
153
+ self.conv1d.add_module("relu", nn.ReLU())
154
+ elif activation=="lrelu":
155
+ self.conv1d.add_module("lrelu", nn.LeakyReLU())
156
+
157
+
158
+ def forward(self, input):
159
+ # input shape should be : batch x channel x height x width
160
+ output = self.conv1d(input)
161
+ return output
162
+
163
+
164
+
165
+ # Residual Block
166
+ # the input is added to the output of the first convolutional layer, which therefore keeps the original channel count
167
+ # only the second convolutional layer may change the number of output channels
168
+ class Res_ConvBlock(nn.Module):
169
+ def __init__(self, dimension, \
170
+ in_channels, out_channels, \
171
+ kernel_size, \
172
+ stride=1, padding="SAME", \
173
+ dilation=1, \
174
+ bias=True, \
175
+ norm="batch", \
176
+ activation="relu", last_activation="relu", \
177
+ mode="conv"):
178
+ super(Res_ConvBlock, self).__init__()
179
+
180
+ if dimension==1:
181
+ self.conv1 = Conv1d_layer(in_channels, in_channels, kernel_size, padding=padding, dilation=dilation, bias=bias, norm=norm, activation=activation)
182
+ self.conv2 = Conv1d_layer(in_channels, out_channels, kernel_size, stride=stride, padding=padding, dilation=dilation, bias=bias, norm=norm, activation=last_activation, mode=mode)
183
+ elif dimension==2:
184
+ self.conv1 = Conv2d_layer(in_channels, in_channels, kernel_size, padding=padding, dilation=dilation, bias=bias, norm=norm, activation=activation)
185
+ self.conv2 = Conv2d_layer(in_channels, out_channels, kernel_size, stride=stride, padding=padding, dilation=dilation, bias=bias, norm=norm, activation=last_activation, mode=mode)
186
+
187
+
188
+ def forward(self, input):
189
+ c1_out = self.conv1(input) + input
190
+ c2_out = self.conv2(c1_out)
191
+ return c2_out
192
+
193
+
194
+
195
+ # Convoluaionl Block
196
+ # consists of multiple (number of layer_num) convolutional layers
197
+ # only the final convoluational layer outputs the desired 'out_channels'
198
+ class ConvBlock(nn.Module):
199
+ def __init__(self, dimension, layer_num, \
200
+ in_channels, out_channels, \
201
+ kernel_size, \
202
+ stride=1, padding="SAME", \
203
+ dilation=1, \
204
+ bias=True, \
205
+ norm="batch", \
206
+ activation="relu", last_activation="relu", \
207
+ mode="conv"):
208
+ super(ConvBlock, self).__init__()
209
+
210
+ conv_block = []
211
+ if dimension==1:
212
+ for i in range(layer_num-1):
213
+ conv_block.append(Conv1d_layer(in_channels, in_channels, kernel_size, padding=padding, dilation=dilation, bias=bias, norm=norm, activation=activation))
214
+ conv_block.append(Conv1d_layer(in_channels, out_channels, kernel_size, stride=stride, padding=padding, dilation=dilation, bias=bias, norm=norm, activation=last_activation, mode=mode))
215
+ elif dimension==2:
216
+ for i in range(layer_num-1):
217
+ conv_block.append(Conv2d_layer(in_channels, in_channels, kernel_size, padding=padding, dilation=dilation, bias=bias, norm=norm, activation=activation))
218
+ conv_block.append(Conv2d_layer(in_channels, out_channels, kernel_size, stride=stride, padding=padding, dilation=dilation, bias=bias, norm=norm, activation=last_activation, mode=mode))
219
+ self.conv_block = nn.Sequential(*conv_block)
220
+
221
+
222
+ def forward(self, input):
223
+ return self.conv_block(input)
224
+
225
+
226
+ # Feature-wise Linear Modulation
227
+ class FiLM(nn.Module):
228
+ def __init__(self, condition_len=2048, feature_len=1024):
229
+ super(FiLM, self).__init__()
230
+ self.film_fc = nn.Linear(condition_len, feature_len*2)
231
+ self.feat_len = feature_len
232
+
233
+
234
+ def forward(self, feature, condition, sefa=None):
235
+ # SeFA
236
+ if sefa:
237
+ weight = self.film_fc.weight.T
238
+ weight = weight / torch.linalg.norm((weight+1e-07), dim=0, keepdims=True)
239
+ eigen_values, eigen_vectors = torch.eig(torch.matmul(weight, weight.T), eigenvectors=True)
240
+
241
+ ####### custom parameters #######
242
+ chosen_eig_idx = sefa[0]
243
+ alpha = eigen_values[chosen_eig_idx][0] * sefa[1]
244
+ #################################
245
+
246
+ An = eigen_vectors[chosen_eig_idx].repeat(condition.shape[0], 1)
247
+ alpha_An = alpha * An
248
+
249
+ condition += alpha_An
250
+
251
+ film_factor = self.film_fc(condition).unsqueeze(-1)
252
+ r, b = torch.split(film_factor, self.feat_len, dim=1)
253
+ return r*feature + b
254
+