revise

Files changed:
- app.py +20 -3
- inference.py +55 -4
- modules/__init__.py +2 -0
- modules/front_back_end.py +240 -0
- modules/loss.py +432 -0
- networks/__init__.py +3 -0
- networks/architectures.py +405 -0
- networks/dasp_additionals.py +441 -0
- networks/network_utils.py +254 -0
app.py
CHANGED
@@ -25,7 +25,14 @@ def process_audio(input_audio, reference_audio, perform_ito):
     if ito_output_audio is not None:
         sf.write("ito_output_mastered.wav", ito_output_audio.T, sr)

-
+    # Generate parameter output strings
+    param_output = mastering_transfer.get_param_output_string(predicted_params)
+    ito_param_output = mastering_transfer.get_param_output_string(ito_predicted_params) if ito_predicted_params is not None else "ITO not performed"
+
+    # Generate top 10 differences if ITO was performed
+    top_10_diff = mastering_transfer.get_top_10_diff_string(predicted_params, ito_predicted_params) if ito_predicted_params is not None else "ITO not performed"
+
+    return "output_mastered.wav", "ito_output_mastered.wav" if ito_output_audio is not None else None, param_output, ito_param_output, top_10_diff

 def process_youtube(input_url, reference_url, perform_ito):
     input_audio = download_youtube_audio(input_url)
@@ -41,7 +48,12 @@ with gr.Blocks() as demo:
         submit_button = gr.Button("Process")
         output_audio = gr.Audio(label="Output Audio")
         ito_output_audio = gr.Audio(label="ITO Output Audio")
-
+        param_output = gr.Textbox(label="Predicted Parameters", lines=10)
+        ito_param_output = gr.Textbox(label="ITO Predicted Parameters", lines=10)
+        top_10_diff = gr.Textbox(label="Top 10 Parameter Differences", lines=10)
+        submit_button.click(process_audio,
+                            inputs=[input_audio, reference_audio, perform_ito],
+                            outputs=[output_audio, ito_output_audio, param_output, ito_param_output, top_10_diff])

     with gr.Tab("YouTube URLs"):
         input_url = gr.Textbox(label="Input YouTube URL")
@@ -50,6 +62,11 @@ with gr.Blocks() as demo:
         submit_button_yt = gr.Button("Process")
         output_audio_yt = gr.Audio(label="Output Audio")
         ito_output_audio_yt = gr.Audio(label="ITO Output Audio")
-
+        param_output_yt = gr.Textbox(label="Predicted Parameters", lines=10)
+        ito_param_output_yt = gr.Textbox(label="ITO Predicted Parameters", lines=10)
+        top_10_diff_yt = gr.Textbox(label="Top 10 Parameter Differences", lines=10)
+        submit_button_yt.click(process_youtube,
+                               inputs=[input_url, reference_url, perform_ito_yt],
+                               outputs=[output_audio_yt, ito_output_audio_yt, param_output_yt, ito_param_output_yt, top_10_diff_yt])

 demo.launch()
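Review note: process_audio now returns a five-element tuple, and Gradio maps each element positionally onto the components listed in outputs=[...] of the click wiring. A minimal self-contained sketch of that pattern (the component names below are illustrative, not the app's):

import gradio as gr

def process(x):
    # return one value per declared output component, in order
    return x.upper(), f"length: {len(x)}"

with gr.Blocks() as demo:
    inp = gr.Textbox(label="Input")
    out_text = gr.Textbox(label="Output")
    out_info = gr.Textbox(label="Info")
    gr.Button("Process").click(process, inputs=[inp], outputs=[out_text, out_info])

demo.launch()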
inference.py
CHANGED
@@ -9,8 +9,7 @@ import sys
 currentdir = os.path.dirname(os.path.realpath(__file__))
 sys.path.append(os.path.dirname(currentdir))
 from networks import Dasp_Mastering_Style_Transfer, Effects_Encoder
-from modules import
-from modules.loss import AudioFeatureLoss
+from modules.loss import AudioFeatureLoss, Loss

 class MasteringStyleTransfer:
     def __init__(self, args):
@@ -205,6 +204,59 @@ class MasteringStyleTransfer:
         else:
             print(f" {fx_params}")

+    def get_param_output_string(self, params):
+        if params is None:
+            return "No parameters available"
+
+        output = []
+        for fx_name, fx_params in params.items():
+            output.append(f"{fx_name.upper()}:")
+            if isinstance(fx_params, dict):
+                for param_name, param_value in fx_params.items():
+                    if isinstance(param_value, torch.Tensor):
+                        param_value = param_value.item()
+                    output.append(f"  {param_name}: {param_value:.4f}")
+            elif isinstance(fx_params, torch.Tensor):
+                output.append(f"  {fx_params.item():.4f}")
+            else:
+                output.append(f"  {fx_params:.4f}")
+
+        return "\n".join(output)
+
+    def get_top_10_diff_string(self, initial_params, ito_params):
+        if initial_params is None or ito_params is None:
+            return "Cannot compare parameters"
+
+        all_diffs = []
+        for fx_name in initial_params.keys():
+            if isinstance(initial_params[fx_name], dict):
+                for param_name in initial_params[fx_name].keys():
+                    initial_value = initial_params[fx_name][param_name]
+                    ito_value = ito_params[fx_name][param_name]
+
+                    param_range = self.mastering_converter.fx_processors[fx_name].param_ranges[param_name]
+                    normalized_diff = abs((ito_value - initial_value) / (param_range[1] - param_range[0]))
+
+                    all_diffs.append((fx_name, param_name, initial_value.item(), ito_value.item(), normalized_diff.item()))
+            else:
+                initial_value = initial_params[fx_name]
+                ito_value = ito_params[fx_name]
+                normalized_diff = abs(ito_value - initial_value)
+                all_diffs.append((fx_name, 'width', initial_value.item(), ito_value.item(), normalized_diff.item()))
+
+        top_diffs = sorted(all_diffs, key=lambda x: x[4], reverse=True)[:10]
+
+        output = ["Top 10 parameter differences (sorted by normalized difference):"]
+        for fx_name, param_name, initial_value, ito_value, normalized_diff in top_diffs:
+            output.append(f"{fx_name.upper()} - {param_name}:")
+            output.append(f"  Initial: {initial_value:.4f}")
+            output.append(f"  ITO: {ito_value:.4f}")
+            output.append(f"  Normalized Diff: {normalized_diff:.4f}")
+            output.append("")
+
+        return "\n".join(output)
+
+
 def reload_weights(model, ckpt_path, device):
     checkpoint = torch.load(ckpt_path, map_location=device)

@@ -215,6 +267,7 @@ def reload_weights(model, ckpt_path, device):
         new_state_dict[name] = v
     model.load_state_dict(new_state_dict, strict=False)

+
 if __name__ == "__main__":
     basis_path = '/data2/tony/Mastering_Style_Transfer/results/dasp_tcn_tuneenc_daspman_loudnessnorm/ckpt/1000/'

@@ -258,5 +311,3 @@ if __name__ == "__main__":
     if ito_output_audio is not None:
         sf.write("ito_output_mastered.wav", ito_output_audio.T, sr)

-
-
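Review note: get_top_10_diff_string makes heterogeneous parameters comparable by normalizing each change by its allowed range, |ito - initial| / (max - min), then keeps the ten largest; the non-dict branch hard-codes the name 'width' because the stereo imager is the only effect stored as a bare tensor rather than a parameter dict. A standalone sketch of the same ranking, using made-up effect names, values, and ranges:

# hypothetical parameter values and ranges, for illustration only
initial = {"eq": {"low_shelf_gain_db": 0.0}, "imager": {"width": 0.50}}
ito     = {"eq": {"low_shelf_gain_db": 2.0}, "imager": {"width": 0.55}}
ranges  = {"eq": {"low_shelf_gain_db": (-20.0, 20.0)}, "imager": {"width": (0.0, 1.0)}}

diffs = []
for fx, params in initial.items():
    for name, v0 in params.items():
        v1 = ito[fx][name]
        lo, hi = ranges[fx][name]
        diffs.append((fx, name, abs(v1 - v0) / (hi - lo)))

# largest range-normalized change first, as in get_top_10_diff_string
for fx, name, d in sorted(diffs, key=lambda t: t[2], reverse=True)[:10]:
    print(f"{fx.upper()} - {name}: normalized diff {d:.4f}")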
modules/__init__.py
ADDED
@@ -0,0 +1,2 @@
+from .front_back_end import *
+from .loss import *
modules/front_back_end.py
ADDED
@@ -0,0 +1,240 @@
+""" Front-end: processing raw data input """
+import torch
+import torch.nn as nn
+import torchaudio.functional as ta_F
+import torchaudio
+
+
+
+class FrontEnd(nn.Module):
+    def __init__(self, channel='stereo', \
+                        n_fft=2048, \
+                        n_mels=128, \
+                        sample_rate=44100, \
+                        hop_length=None, \
+                        win_length=None, \
+                        window="hann", \
+                        eps=1e-7, \
+                        device=torch.device("cpu")):
+        super(FrontEnd, self).__init__()
+        self.channel = channel
+        self.n_fft = n_fft
+        self.n_mels = n_mels
+        self.sample_rate = sample_rate
+        self.hop_length = n_fft//4 if hop_length==None else hop_length
+        self.win_length = n_fft if win_length==None else win_length
+        self.eps = eps
+        if window=="hann":
+            self.window = torch.hann_window(window_length=self.win_length, periodic=True).to(device)
+        elif window=="hamming":
+            self.window = torch.hamming_window(window_length=self.win_length, periodic=True).to(device)
+        self.melscale_transform = torchaudio.transforms.MelScale(n_mels=self.n_mels, \
+                                                                 sample_rate=self.sample_rate, \
+                                                                 n_stft=self.n_fft//2+1).to(device)
+
+
+    def forward(self, input, mode):
+        # front-end function which channel-wise combines all demanded features
+        # input shape : batch x channel x raw waveform
+        # output shape : batch x channel x frequency x time
+        phase_output = None
+
+        front_output_list = []
+        for cur_mode in mode:
+            # Real & Imaginary
+            if cur_mode=="cplx":
+                if self.channel=="mono":
+                    output = torch.stft(input, n_fft=self.n_fft, hop_length=self.hop_length, win_length=self.win_length, window=self.window)
+                elif self.channel=="stereo":
+                    output_l = torch.stft(input[:,0], n_fft=self.n_fft, hop_length=self.hop_length, win_length=self.win_length, window=self.window)
+                    output_r = torch.stft(input[:,1], n_fft=self.n_fft, hop_length=self.hop_length, win_length=self.win_length, window=self.window)
+                    output = torch.cat((output_l, output_r), axis=-1)
+                if input.shape[-1] % round(self.n_fft/4) == 0:
+                    output = output[:, :, :-1]
+                if self.n_fft % 2 == 0:
+                    output = output[:, :-1]
+                front_output_list.append(output.permute(0, 3, 1, 2))
+            # Magnitude & Phase or Mel
+            elif "mag" in cur_mode or "mel" in cur_mode:
+                if self.channel=="mono":
+                    cur_cplx = torch.stft(input, n_fft=self.n_fft, hop_length=self.hop_length, win_length=self.win_length, window=self.window, return_complex=True)
+                    output = self.mag(cur_cplx).unsqueeze(-1)[..., 0:1]
+                    if "mag_phase" in cur_mode:
+                        phase = self.phase(cur_cplx)
+                    if "mel" in cur_mode:
+                        output = self.melscale_transform(output.squeeze(-1)).unsqueeze(-1)
+                elif self.channel=="stereo":
+                    cplx_l = torch.stft(input[:,0], n_fft=self.n_fft, hop_length=self.hop_length, win_length=self.win_length, window=self.window, return_complex=True)
+                    cplx_r = torch.stft(input[:,1], n_fft=self.n_fft, hop_length=self.hop_length, win_length=self.win_length, window=self.window, return_complex=True)
+                    mag_l = self.mag(cplx_l).unsqueeze(-1)
+                    mag_r = self.mag(cplx_r).unsqueeze(-1)
+                    output = torch.cat((mag_l, mag_r), axis=-1)
+                    if "mag_phase" in cur_mode:
+                        phase_l = self.phase(cplx_l).unsqueeze(-1)
+                        phase_r = self.phase(cplx_r).unsqueeze(-1)
+                        output = torch.cat((mag_l, phase_l, mag_r, phase_r), axis=-1)
+                    if "mel" in cur_mode:
+                        output = torch.cat((self.melscale_transform(mag_l.squeeze(-1)).unsqueeze(-1), self.melscale_transform(mag_r.squeeze(-1)).unsqueeze(-1)), axis=-1)
+
+                if "log" in cur_mode:
+                    output = torch.log(output+self.eps)
+
+                if input.shape[-1] % round(self.n_fft/4) == 0:
+                    output = output[:, :, :-1]
+                if cur_mode!="mel" and self.n_fft % 2 == 0: # discard highest frequency
+                    output = output[:, 1:]
+                front_output_list.append(output.permute(0, 3, 1, 2))
+
+        # combine all demanded features
+        if not front_output_list:
+            raise NameError("NameError at FrontEnd: check using features for front-end")
+        elif len(mode)!=1:
+            for i, cur_output in enumerate(front_output_list):
+                if i==0:
+                    front_output = cur_output
+                else:
+                    front_output = torch.cat((front_output, cur_output), axis=1)
+        else:
+            front_output = front_output_list[0]
+
+        return front_output
+
+
+    def mag(self, cplx_input, eps=1e-07):
+        # mag_summed = cplx_input.pow(2.).sum(-1) + eps
+        mag_summed = cplx_input.real.pow(2.) + cplx_input.imag.pow(2.) + eps
+        return mag_summed.pow(0.5)
+
+
+    def phase(self, cplx_input, ):
+        return torch.atan2(cplx_input.imag, cplx_input.real)
+        # return torch.angle(cplx_input)
+
+
+
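Review note: a minimal smoke test for FrontEnd as defined above (a sketch, assuming the module is importable and default torch/torchaudio behavior such as center=True in torch.stft). With mode=["mag"] on stereo input it returns per-channel magnitude spectrograms with one of the 1025 frequency bins dropped. The file listing continues below.

import torch

# hypothetical smoke test for the FrontEnd module above
fe = FrontEnd(channel='stereo', n_fft=2048)
audio = torch.randn(4, 2, 44100)   # batch x channel x samples
spec = fe(audio, mode=["mag"])     # magnitude spectrogram path
print(spec.shape)                  # expected: (4, 2, 1024, 87)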
+class BackEnd(nn.Module):
+    def __init__(self, channel='stereo', \
+                        n_fft=2048, \
+                        hop_length=None, \
+                        win_length=None, \
+                        window="hann", \
+                        eps=1e-07, \
+                        orig_freq=44100, \
+                        new_freq=16000, \
+                        device=torch.device("cpu")):
+        super(BackEnd, self).__init__()
+        self.device = device
+        self.channel = channel
+        self.n_fft = n_fft
+        self.hop_length = n_fft//4 if hop_length==None else hop_length
+        self.win_length = n_fft if win_length==None else win_length
+        self.eps = eps
+        if window=="hann":
+            self.window = torch.hann_window(window_length=self.win_length, periodic=True).to(self.device)
+        elif window=="hamming":
+            self.window = torch.hamming_window(window_length=self.win_length, periodic=True).to(self.device)
+        self.resample_func_8k = torchaudio.transforms.Resample(orig_freq=orig_freq, new_freq=8000).to(self.device)
+        self.resample_func = torchaudio.transforms.Resample(orig_freq=orig_freq, new_freq=new_freq).to(self.device)
+
+    def magphase_to_cplx(self, magphase_spec):
+        real = magphase_spec[..., 0] * torch.cos(magphase_spec[..., 1])
+        imaginary = magphase_spec[..., 0] * torch.sin(magphase_spec[..., 1])
+        return torch.cat((real.unsqueeze(-1), imaginary.unsqueeze(-1)), dim=-1)
+
+
+    def forward(self, input, phase, mode):
+        # back-end function which converts output spectrograms into waveform
+        # input shape : batch x channel x frequency x time
+        # output shape : batch x channel x raw waveform
+
+        # convert to shape : batch x frequency x time x channel
+        input = input.permute(0, 2, 3, 1)
+        # pad highest frequency
+        pad = torch.zeros((input.shape[0], 1, input.shape[2], input.shape[3])).to(self.device)
+        input = torch.cat((pad, input), dim=1)
+
+        back_output_list = []
+        channel_count = 0
+        for i, cur_mode in enumerate(mode):
+            # Real & Imaginary
+            if cur_mode=="cplx":
+                if self.channel=="mono":
+                    output = ta_F.istft(input[...,channel_count:channel_count+2], n_fft=self.n_fft, hop_length=self.hop_length, win_length=self.win_length, window=self.window).unsqueeze(1)
+                    channel_count += 2
+                elif self.channel=="stereo":
+                    cplx_spec = torch.cat([input[...,channel_count:channel_count+2], input[...,channel_count+2:channel_count+4]], dim=0)
+                    output_wav = ta_F.istft(cplx_spec, n_fft=self.n_fft, hop_length=self.hop_length, win_length=self.win_length, window=self.window)
+                    output = torch.cat((output_wav[:output_wav.shape[0]//2].unsqueeze(1), output_wav[output_wav.shape[0]//2:].unsqueeze(1)), dim=1)
+                    channel_count += 4
+                back_output_list.append(output)
+            # Magnitude & Phase
+            elif cur_mode=="mag_phase" or cur_mode=="mag":
+                if self.channel=="mono":
+                    if cur_mode=="mag":
+                        input_spec = torch.cat((input[...,channel_count:channel_count+1], phase), axis=-1)
+                        channel_count += 1
+                    else:
+                        input_spec = input[...,channel_count:channel_count+2]
+                        channel_count += 2
+                    cplx_spec = self.magphase_to_cplx(input_spec)
+                    output = ta_F.istft(cplx_spec, n_fft=self.n_fft, hop_length=self.hop_length, win_length=self.win_length, window=self.window).unsqueeze(1)
+                elif self.channel=="stereo":
+                    if cur_mode=="mag":
+                        input_spec_l = torch.cat((input[...,channel_count:channel_count+1], phase[...,0:1]), axis=-1)
+                        input_spec_r = torch.cat((input[...,channel_count+1:channel_count+2], phase[...,1:2]), axis=-1)
+                        channel_count += 2
+                    else:
+                        input_spec_l = input[...,channel_count:channel_count+2]
+                        input_spec_r = input[...,channel_count+2:channel_count+4]
+                        channel_count += 4
+                    cplx_spec_l = self.magphase_to_cplx(input_spec_l)
+                    cplx_spec_r = self.magphase_to_cplx(input_spec_r)
+                    cplx_spec = torch.cat([cplx_spec_l, cplx_spec_r], dim=0)
+                    output_wav = torch.istft(cplx_spec, n_fft=self.n_fft, hop_length=self.hop_length, win_length=self.win_length, window=self.window)
+                    output = torch.cat((output_wav[:output_wav.shape[0]//2].unsqueeze(1), output_wav[output_wav.shape[0]//2:].unsqueeze(1)), dim=1)
+                    channel_count += 4
+                back_output_list.append(output)
+            elif cur_mode=="griff":
+                if self.channel=="mono":
+                    output = self.griffin_lim(input.squeeze(-1), input.device).unsqueeze(1)
+                    # output = self.griff(input.permute(0, 3, 1, 2))
+                else:
+                    output_l = self.griffin_lim(input[..., 0], input.device).unsqueeze(1)
+                    output_r = self.griffin_lim(input[..., 1], input.device).unsqueeze(1)
+                    output = torch.cat((output_l, output_r), axis=1)
+
+                back_output_list.append(output)
+
+        # combine all demanded feature outputs
+        if not back_output_list:
+            raise NameError("NameError at BackEnd: check using features for back-end")
+        elif len(mode)!=1:
+            for i, cur_output in enumerate(back_output_list):
+                if i==0:
+                    back_output = cur_output
+                else:
+                    back_output = torch.cat((back_output, cur_output), axis=1)
+        else:
+            back_output = back_output_list[0]
+
+        return back_output
+
+
+    def griffin_lim(self, l_est, gpu, n_iter=100):
+        l_est = l_est.cpu().detach()
+
+        l_est = torch.pow(l_est, 1/0.80)
+        # l_est [batch, channel, time]
+        l_mag = l_est.unsqueeze(-1)
+        l_phase = 2 * np.pi * torch.rand_like(l_mag) - np.pi
+        real = l_mag * torch.cos(l_phase)
+        imag = l_mag * torch.sin(l_phase)
+        S = torch.cat((real, imag), axis=-1)
+        S_mag = (real**2 + imag**2 + self.eps) ** 1/2
+        for i in range(n_iter):
+            x = ta_F.istft(S, n_fft=2048, hop_length=512, win_length=2048, window=torch.hann_window(2048))
+            S_new = torch.stft(x, n_fft=2048, hop_length=512, win_length=2048, window=torch.hann_window(2048))
+            S_new_phase = S_new/mag(S_new)
+            S = S_mag * S_new_phase
+        return x / torch.max(torch.abs(x))
+
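Review note: magphase_to_cplx rebuilds the Cartesian pair from stacked magnitude and phase via real = m*cos(phi), imag = m*sin(phi); a standalone check of that identity follows. Two caveats on this file as committed: griffin_lim references np and a bare mag(...) that the file never imports or defines, so the "griff" path would fail as written, and ta_F.istft was removed from torchaudio in favor of torch.istft, so the istft paths assume an older torchaudio.

import math
import torch

mag = torch.rand(2, 1025, 87, 1) + 0.1
phase = (torch.rand(2, 1025, 87, 1) * 2 - 1) * math.pi
magphase = torch.cat((mag, phase), dim=-1)

real = magphase[..., 0] * torch.cos(magphase[..., 1])
imag = magphase[..., 0] * torch.sin(magphase[..., 1])

# the magnitude is recovered exactly from the Cartesian pair
assert torch.allclose(torch.sqrt(real**2 + imag**2), mag[..., 0], atol=1e-5)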
modules/loss.py
ADDED
@@ -0,0 +1,432 @@
+"""
+    Implementation of objective functions used in the task 'ITO-Master'
+"""
+import numpy as np
+import torch
+import torch.nn.functional as F
+import torch.nn as nn
+import auraloss
+
+import os
+import sys
+currentdir = os.path.dirname(os.path.realpath(__file__))
+sys.path.append(os.path.dirname(currentdir))
+from modules.front_back_end import *
+
+
+
+# Root Mean Squared Loss
+#   penalizes the volume factor with non-linearity
+class RMSLoss(nn.Module):
+    def __init__(self, reduce, loss_type="l2"):
+        super(RMSLoss, self).__init__()
+        self.weight_factor = 100.
+        if loss_type=="l2":
+            self.loss = nn.MSELoss(reduce=None)
+
+
+    def forward(self, est_targets, targets):
+        est_targets = est_targets.reshape(est_targets.shape[0]*est_targets.shape[1], est_targets.shape[2])
+        targets = targets.reshape(targets.shape[0]*targets.shape[1], targets.shape[2])
+        normalized_est = torch.sqrt(torch.mean(est_targets**2, dim=-1))
+        normalized_tgt = torch.sqrt(torch.mean(targets**2, dim=-1))
+
+        weight = torch.clamp(torch.abs(normalized_tgt-normalized_est), min=1/self.weight_factor) * self.weight_factor
+
+        return torch.mean(weight**1.5 * self.loss(normalized_est, normalized_tgt))
+
+
+
+# Multi-Scale Spectral Loss proposed at the paper "DDSP: DIFFERENTIABLE DIGITAL SIGNAL PROCESSING" (https://arxiv.org/abs/2001.04643)
+#   we extend this loss by applying it to mid/side channels
+class MultiScale_Spectral_Loss_MidSide_DDSP(nn.Module):
+    def __init__(self, mode='midside', \
+                        reduce=True, \
+                        n_filters=None, \
+                        windows_size=None, \
+                        hops_size=None, \
+                        window="hann", \
+                        eps=1e-7, \
+                        device=torch.device("cpu")):
+        super(MultiScale_Spectral_Loss_MidSide_DDSP, self).__init__()
+        self.mode = mode
+        self.eps = eps
+        self.mid_weight = 0.5   # value in the range of 0.0 ~ 1.0
+        self.logmag_weight = 0.1
+
+        if n_filters is None:
+            n_filters = [4096, 2048, 1024, 512]
+        if windows_size is None:
+            windows_size = [4096, 2048, 1024, 512]
+        if hops_size is None:
+            hops_size = [1024, 512, 256, 128]
+
+        self.multiscales = []
+        for i in range(len(windows_size)):
+            cur_scale = {'window_size' : float(windows_size[i])}
+            if self.mode=='midside':
+                cur_scale['front_end'] = FrontEnd(channel='mono', \
+                                                    n_fft=n_filters[i], \
+                                                    hop_length=hops_size[i], \
+                                                    win_length=windows_size[i], \
+                                                    window=window, \
+                                                    device=device)
+            elif self.mode=='ori':
+                cur_scale['front_end'] = FrontEnd(channel='stereo', \
+                                                    n_fft=n_filters[i], \
+                                                    hop_length=hops_size[i], \
+                                                    win_length=windows_size[i], \
+                                                    window=window, \
+                                                    device=device)
+            self.multiscales.append(cur_scale)
+
+        self.objective_l1 = nn.L1Loss(reduce=reduce)
+        self.objective_l2 = nn.MSELoss(reduce=reduce)
+
+
+    def forward(self, est_targets, targets):
+        if self.mode=='midside':
+            return self.forward_midside(est_targets, targets)
+        elif self.mode=='ori':
+            return self.forward_ori(est_targets, targets)
+
+
+    def forward_ori(self, est_targets, targets):
+        total_loss = 0.0
+        total_mag_loss = 0.0
+        total_logmag_loss = 0.0
+        for cur_scale in self.multiscales:
+            est_mag = cur_scale['front_end'](est_targets, mode=["mag"])
+            tgt_mag = cur_scale['front_end'](targets, mode=["mag"])
+
+            mag_loss = self.magnitude_loss(est_mag, tgt_mag)
+            logmag_loss = self.log_magnitude_loss(est_mag, tgt_mag)
+            total_mag_loss += mag_loss
+            total_logmag_loss += logmag_loss
+        # return total_loss
+        return (1-self.logmag_weight)*total_mag_loss + \
+                (self.logmag_weight)*total_logmag_loss
+
+
+    def forward_midside(self, est_targets, targets):
+        est_mid, est_side = self.to_mid_side(est_targets)
+        tgt_mid, tgt_side = self.to_mid_side(targets)
+        total_loss = 0.0
+        total_mag_loss = 0.0
+        total_logmag_loss = 0.0
+        for cur_scale in self.multiscales:
+            est_mid_mag = cur_scale['front_end'](est_mid, mode=["mag"])
+            est_side_mag = cur_scale['front_end'](est_side, mode=["mag"])
+            tgt_mid_mag = cur_scale['front_end'](tgt_mid, mode=["mag"])
+            tgt_side_mag = cur_scale['front_end'](tgt_side, mode=["mag"])
+
+            mag_loss = self.mid_weight*self.magnitude_loss(est_mid_mag, tgt_mid_mag) + \
+                        (1-self.mid_weight)*self.magnitude_loss(est_side_mag, tgt_side_mag)
+            logmag_loss = self.mid_weight*self.log_magnitude_loss(est_mid_mag, tgt_mid_mag) + \
+                            (1-self.mid_weight)*self.log_magnitude_loss(est_side_mag, tgt_side_mag)
+            total_mag_loss += mag_loss
+            total_logmag_loss += logmag_loss
+        # return total_loss
+        return (1-self.logmag_weight)*total_mag_loss + \
+                (self.logmag_weight)*total_logmag_loss
+
+
+    def to_mid_side(self, stereo_in):
+        mid = stereo_in[:,0] + stereo_in[:,1]
+        side = stereo_in[:,0] - stereo_in[:,1]
+        return mid, side
+
+
+    def magnitude_loss(self, est_mag_spec, tgt_mag_spec):
+        return torch.norm(self.objective_l1(est_mag_spec, tgt_mag_spec))
+
+
+    def log_magnitude_loss(self, est_mag_spec, tgt_mag_spec):
+        est_log_mag_spec = torch.log10(est_mag_spec+self.eps)
+        tgt_log_mag_spec = torch.log10(tgt_mag_spec+self.eps)
+        return self.objective_l2(est_log_mag_spec, tgt_log_mag_spec)
+
+
+
+
# Class of available loss functions
|
152 |
+
class Loss:
|
153 |
+
def __init__(self, args, reduce=True):
|
154 |
+
device = torch.device("cpu")
|
155 |
+
if torch.cuda.is_available():
|
156 |
+
device = torch.device(f"cuda:{args.gpu}")
|
157 |
+
self.l1 = nn.L1Loss(reduce=reduce)
|
158 |
+
self.mse = nn.MSELoss(reduce=reduce)
|
159 |
+
self.ce = nn.CrossEntropyLoss()
|
160 |
+
self.triplet = nn.TripletMarginLoss(margin=1., p=2)
|
161 |
+
self.cos = nn.CosineSimilarity(eps=args.eps)
|
162 |
+
self.cosemb = nn.CosineEmbeddingLoss()
|
163 |
+
|
164 |
+
self.multi_scale_spectral_midside = MultiScale_Spectral_Loss_MidSide_DDSP(mode='midside', eps=args.eps, device=device)
|
165 |
+
self.multi_scale_spectral_ori = MultiScale_Spectral_Loss_MidSide_DDSP(mode='ori', eps=args.eps, device=device)
|
166 |
+
self.gain = RMSLoss(reduce=reduce)
|
167 |
+
self.infonce = infoNCE
|
168 |
+
# perceptual weighting with mel scaled spectrograms
|
169 |
+
self.mrs_mel_perceptual = auraloss.freq.MultiResolutionSTFTLoss(
|
170 |
+
fft_sizes=[1024, 2048, 8192],
|
171 |
+
hop_sizes=[256, 512, 2048],
|
172 |
+
win_lengths=[1024, 2048, 8192],
|
173 |
+
scale="mel",
|
174 |
+
n_bins=128,
|
175 |
+
sample_rate=args.sample_rate,
|
176 |
+
perceptual_weighting=True,
|
177 |
+
)
|
178 |
+
|
179 |
+
|
180 |
+
|
181 |
+
|
182 |
+
"""
|
183 |
+
Audio Feature Loss implementation
|
184 |
+
copied from https://github.com/sai-soum/Diff-MST/blob/main/mst/loss.py
|
185 |
+
"""
|
186 |
+
|
187 |
+
import librosa
|
188 |
+
|
189 |
+
from typing import List
|
190 |
+
from modules.filter import barkscale_fbanks
|
191 |
+
|
192 |
+
|
193 |
+
|
194 |
+
|
195 |
+
def compute_mid_side(x: torch.Tensor):
|
196 |
+
x_mid = x[:, 0, :] + x[:, 1, :]
|
197 |
+
x_side = x[:, 0, :] - x[:, 1, :]
|
198 |
+
return x_mid, x_side
|
199 |
+
|
200 |
+
|
201 |
+
def compute_melspectrum(
|
202 |
+
x: torch.Tensor,
|
203 |
+
sample_rate: int = 44100,
|
204 |
+
fft_size: int = 32768,
|
205 |
+
n_bins: int = 128,
|
206 |
+
**kwargs,
|
207 |
+
):
|
208 |
+
"""Compute mel-spectrogram.
|
209 |
+
|
210 |
+
Args:
|
211 |
+
x: (bs, 2, seq_len)
|
212 |
+
sample_rate: sample rate of audio
|
213 |
+
fft_size: size of fft
|
214 |
+
n_bins: number of mel bins
|
215 |
+
|
216 |
+
Returns:
|
217 |
+
X: (bs, n_bins)
|
218 |
+
|
219 |
+
"""
|
220 |
+
fb = librosa.filters.mel(sr=sample_rate, n_fft=fft_size, n_mels=n_bins)
|
221 |
+
fb = torch.tensor(fb).unsqueeze(0).type_as(x)
|
222 |
+
|
223 |
+
x = x.mean(dim=1, keepdim=True)
|
224 |
+
X = torch.fft.rfft(x, n=fft_size, dim=-1)
|
225 |
+
X = torch.abs(X)
|
226 |
+
X = torch.mean(X, dim=1, keepdim=True) # take mean over time
|
227 |
+
X = X.permute(0, 2, 1) # swap time and freq dims
|
228 |
+
X = torch.matmul(fb, X)
|
229 |
+
X = torch.log(X + 1e-8)
|
230 |
+
|
231 |
+
return X
|
232 |
+
|
233 |
+
|
234 |
+
def compute_barkspectrum(
|
235 |
+
x: torch.Tensor,
|
236 |
+
fft_size: int = 32768,
|
237 |
+
n_bands: int = 24,
|
238 |
+
sample_rate: int = 44100,
|
239 |
+
f_min: float = 20.0,
|
240 |
+
f_max: float = 20000.0,
|
241 |
+
mode: str = "mid-side",
|
242 |
+
**kwargs,
|
243 |
+
):
|
244 |
+
"""Compute bark-spectrogram.
|
245 |
+
|
246 |
+
Args:
|
247 |
+
x: (bs, 2, seq_len)
|
248 |
+
fft_size: size of fft
|
249 |
+
n_bands: number of bark bins
|
250 |
+
sample_rate: sample rate of audio
|
251 |
+
f_min: minimum frequency
|
252 |
+
f_max: maximum frequency
|
253 |
+
mode: "mono", "stereo", or "mid-side"
|
254 |
+
|
255 |
+
Returns:
|
256 |
+
X: (bs, 24)
|
257 |
+
|
258 |
+
"""
|
259 |
+
# compute filterbank
|
260 |
+
fb = barkscale_fbanks((fft_size // 2) + 1, f_min, f_max, n_bands, sample_rate)
|
261 |
+
fb = fb.unsqueeze(0).type_as(x)
|
262 |
+
fb = fb.permute(0, 2, 1)
|
263 |
+
|
264 |
+
if mode == "mono":
|
265 |
+
x = x.mean(dim=1) # average over channels
|
266 |
+
signals = [x]
|
267 |
+
elif mode == "stereo":
|
268 |
+
signals = [x[:, 0, :], x[:, 1, :]]
|
269 |
+
elif mode == "mid-side":
|
270 |
+
x_mid = x[:, 0, :] + x[:, 1, :]
|
271 |
+
x_side = x[:, 0, :] - x[:, 1, :]
|
272 |
+
signals = [x_mid, x_side]
|
273 |
+
else:
|
274 |
+
raise ValueError(f"Invalid mode {mode}")
|
275 |
+
|
276 |
+
outputs = []
|
277 |
+
for signal in signals:
|
278 |
+
X = torch.stft(
|
279 |
+
signal,
|
280 |
+
n_fft=fft_size,
|
281 |
+
hop_length=fft_size // 4,
|
282 |
+
return_complex=True,
|
283 |
+
window=torch.hann_window(fft_size).to(x.device),
|
284 |
+
) # compute stft
|
285 |
+
X = torch.abs(X) # take magnitude
|
286 |
+
X = torch.mean(X, dim=-1, keepdim=True) # take mean over time
|
287 |
+
# X = X.permute(0, 2, 1) # swap time and freq dims
|
288 |
+
X = torch.matmul(fb, X) # apply filterbank
|
289 |
+
X = torch.log(X + 1e-8)
|
290 |
+
# X = torch.cat([X, X_log], dim=-1)
|
291 |
+
outputs.append(X)
|
292 |
+
|
293 |
+
# stack into tensor
|
294 |
+
X = torch.cat(outputs, dim=-1)
|
295 |
+
|
296 |
+
return X
|
297 |
+
|
298 |
+
|
299 |
+
def compute_rms(x: torch.Tensor, **kwargs):
|
300 |
+
"""Compute root mean square energy.
|
301 |
+
|
302 |
+
Args:
|
303 |
+
x: (bs, 1, seq_len)
|
304 |
+
|
305 |
+
Returns:
|
306 |
+
rms: (bs, )
|
307 |
+
"""
|
308 |
+
rms = torch.sqrt(torch.mean(x**2, dim=-1).clamp(min=1e-8))
|
309 |
+
return rms
|
310 |
+
|
311 |
+
|
312 |
+
def compute_crest_factor(x: torch.Tensor, **kwargs):
|
313 |
+
"""Compute crest factor as ratio of peak to rms energy in dB.
|
314 |
+
|
315 |
+
Args:
|
316 |
+
x: (bs, 2, seq_len)
|
317 |
+
|
318 |
+
"""
|
319 |
+
num = torch.max(torch.abs(x), dim=-1)[0]
|
320 |
+
den = compute_rms(x).clamp(min=1e-8)
|
321 |
+
cf = 20 * torch.log10((num / den).clamp(min=1e-8))
|
322 |
+
return cf
|
323 |
+
|
324 |
+
|
325 |
+
def compute_stereo_width(x: torch.Tensor, **kwargs):
|
326 |
+
"""Compute stereo width as ratio of energy in sum and difference signals.
|
327 |
+
|
328 |
+
Args:
|
329 |
+
x: (bs, 2, seq_len)
|
330 |
+
|
331 |
+
"""
|
332 |
+
bs, chs, seq_len = x.size()
|
333 |
+
|
334 |
+
assert chs == 2, "Input must be stereo"
|
335 |
+
|
336 |
+
# compute sum and diff of stereo channels
|
337 |
+
x_sum = x[:, 0, :] + x[:, 1, :]
|
338 |
+
x_diff = x[:, 0, :] - x[:, 1, :]
|
339 |
+
|
340 |
+
# compute power of sum and diff
|
341 |
+
sum_energy = torch.mean(x_sum**2, dim=-1)
|
342 |
+
diff_energy = torch.mean(x_diff**2, dim=-1)
|
343 |
+
|
344 |
+
# compute stereo width as ratio
|
345 |
+
stereo_width = diff_energy / sum_energy.clamp(min=1e-8)
|
346 |
+
|
347 |
+
return stereo_width
|
348 |
+
|
349 |
+
|
350 |
+
def compute_stereo_imbalance(x: torch.Tensor, **kwargs):
|
351 |
+
"""Compute stereo imbalance as ratio of energy in left and right channels.
|
352 |
+
|
353 |
+
Args:
|
354 |
+
x: (bs, 2, seq_len)
|
355 |
+
|
356 |
+
Returns:
|
357 |
+
stereo_imbalance: (bs, )
|
358 |
+
|
359 |
+
"""
|
360 |
+
left_energy = torch.mean(x[:, 0, :] ** 2, dim=-1)
|
361 |
+
right_energy = torch.mean(x[:, 1, :] ** 2, dim=-1)
|
362 |
+
|
363 |
+
stereo_imbalance = (right_energy - left_energy) / (
|
364 |
+
right_energy + left_energy
|
365 |
+
).clamp(min=1e-8)
|
366 |
+
|
367 |
+
return stereo_imbalance
|
368 |
+
|
369 |
+
|
370 |
+
class AudioFeatureLoss(torch.nn.Module):
|
371 |
+
def __init__(
|
372 |
+
self,
|
373 |
+
weights: List[float],
|
374 |
+
sample_rate: int,
|
375 |
+
stem_separation: bool = False,
|
376 |
+
use_clap: bool = False,
|
377 |
+
) -> None:
|
378 |
+
"""Compute loss using a set of differentiable audio features.
|
379 |
+
|
380 |
+
Args:
|
381 |
+
weights: weights for each feature
|
382 |
+
sample_rate: sample rate of audio
|
383 |
+
stem_separation: whether to compute loss on stems or mix
|
384 |
+
|
385 |
+
Based on features proposed in:
|
386 |
+
|
387 |
+
Man, B. D., et al.
|
388 |
+
"An analysis and evaluation of audio features for multitrack music mixtures."
|
389 |
+
(2014).
|
390 |
+
|
391 |
+
"""
|
392 |
+
super().__init__()
|
393 |
+
self.weights = weights
|
394 |
+
self.sample_rate = sample_rate
|
395 |
+
self.stem_separation = stem_separation
|
396 |
+
self.sources_list = ["mix"]
|
397 |
+
self.source_weights = [1.0]
|
398 |
+
self.use_clap = use_clap
|
399 |
+
|
400 |
+
self.transforms = [
|
401 |
+
compute_rms,
|
402 |
+
compute_crest_factor,
|
403 |
+
compute_stereo_width,
|
404 |
+
compute_stereo_imbalance,
|
405 |
+
compute_barkspectrum,
|
406 |
+
]
|
407 |
+
|
408 |
+
assert len(self.transforms) == len(weights)
|
409 |
+
|
410 |
+
def forward(self, input: torch.Tensor, target: torch.Tensor):
|
411 |
+
losses = {}
|
412 |
+
|
413 |
+
# reshape for example stem dim
|
414 |
+
input_stems = input.unsqueeze(1)
|
415 |
+
target_stems = target.unsqueeze(1)
|
416 |
+
|
417 |
+
n_stems = input_stems.shape[1]
|
418 |
+
|
419 |
+
# iterate over each stem compute loss for each transform
|
420 |
+
for stem_idx in range(n_stems):
|
421 |
+
input_stem = input_stems[:, stem_idx, ...]
|
422 |
+
target_stem = target_stems[:, stem_idx, ...]
|
423 |
+
|
424 |
+
for transform, weight in zip(self.transforms, self.weights):
|
425 |
+
transform_name = "_".join(transform.__name__.split("_")[1:])
|
426 |
+
key = f"{self.sources_list[stem_idx]}-{transform_name}"
|
427 |
+
input_transform = transform(input_stem, sample_rate=self.sample_rate)
|
428 |
+
target_transform = transform(target_stem, sample_rate=self.sample_rate)
|
429 |
+
val = torch.nn.functional.mse_loss(input_transform, target_transform)
|
430 |
+
losses[key] = weight * val * self.source_weights[stem_idx]
|
431 |
+
|
432 |
+
return losses
|
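Review note: a usage sketch for AudioFeatureLoss, assuming modules.filter.barkscale_fbanks (imported above but not part of this commit) is available; the five weights below are illustrative placeholders, one per entry in self.transforms. Note also that the Loss class references an infoNCE function this file never defines, so it relies on another module in the repo.

import torch

# illustrative weights: rms, crest factor, stereo width, stereo imbalance, bark spectrum
loss_fn = AudioFeatureLoss(weights=[0.1, 0.001, 1.0, 1.0, 0.1], sample_rate=44100)

est = torch.randn(2, 2, 44100)   # (bs, 2, seq_len) estimate
ref = torch.randn(2, 2, 44100)   # (bs, 2, seq_len) reference
losses = loss_fn(est, ref)       # dict keyed "mix-rms", "mix-crest_factor", ...
total = sum(losses.values())     # scalar usable for backprop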
networks/__init__.py
ADDED
@@ -0,0 +1,3 @@
+from .architectures import *
+from .network_utils import *
+from .dasp_additionals import *
networks/architectures.py
ADDED
@@ -0,0 +1,405 @@
+"""
+    Implementation of neural networks used in the task 'Music Mastering Style Transfer'
+        - 'Effects Encoder'
+        - 'Mastering Style Transfer'
+        - 'Differentiable Mastering Style Transfer'
+"""
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.nn.init as init
+import dasp_pytorch
+
+import os
+import sys
+import time
+currentdir = os.path.dirname(os.path.realpath(__file__))
+sys.path.append(currentdir)
+from network_utils import *
+from dasp_additionals import Multiband_Compressor, Distortion, Limiter
+
+# compute receptive field
+def compute_receptive_field(kernels, strides, dilations):
+    rf = 0
+    for i in range(len(kernels)):
+        rf += rf * strides[i] + (kernels[i]-strides[i]) * dilations[i]
+    return rf
+
+
+
+# Encoder of music effects for contrastive learning of music effects
+class Effects_Encoder(nn.Module):
+    def __init__(self, config):
+        super(Effects_Encoder, self).__init__()
+        # input is stereo channeled audio
+        config["channels"].insert(0, 2)
+
+        # encoder layers
+        encoder = []
+        for i in range(len(config["kernels"])):
+            if config["conv_block"]=='res':
+                encoder.append(Res_ConvBlock(dimension=1, \
+                                                in_channels=config["channels"][i], \
+                                                out_channels=config["channels"][i+1], \
+                                                kernel_size=config["kernels"][i], \
+                                                stride=config["strides"][i], \
+                                                padding="SAME", \
+                                                dilation=config["dilation"][i], \
+                                                norm=config["norm"], \
+                                                activation=config["activation"], \
+                                                last_activation=config["activation"]))
+            elif config["conv_block"]=='conv':
+                encoder.append(ConvBlock(dimension=1, \
+                                            layer_num=1, \
+                                            in_channels=config["channels"][i], \
+                                            out_channels=config["channels"][i+1], \
+                                            kernel_size=config["kernels"][i], \
+                                            stride=config["strides"][i], \
+                                            padding="VALID", \
+                                            dilation=config["dilation"][i], \
+                                            norm=config["norm"], \
+                                            activation=config["activation"], \
+                                            last_activation=config["activation"], \
+                                            mode='conv'))
+        self.encoder = nn.Sequential(*encoder)
+
+        # pooling method
+        self.glob_pool = nn.AdaptiveAvgPool1d(1)
+
+
+    # network forward operation
+    def forward(self, input):
+        enc_output = self.encoder(input)
+        glob_pooled = self.glob_pool(enc_output).squeeze(-1)
+
+        # outputs c feature
+        return glob_pooled
+
+
+
+class TCNBlock(torch.nn.Module):
+    def __init__(self,
+                 in_ch,
+                 out_ch,
+                 kernel_size=3,
+                 stride=1,
+                 dilation=1,
+                 cond_dim=2048,
+                 grouped=False,
+                 causal=False,
+                 conditional=False,
+                 **kwargs):
+        super(TCNBlock, self).__init__()
+
+        self.in_ch = in_ch
+        self.out_ch = out_ch
+        self.kernel_size = kernel_size
+        self.dilation = dilation
+        self.grouped = grouped
+        self.causal = causal
+        self.conditional = conditional
+
+        groups = out_ch if grouped and (in_ch % out_ch == 0) else 1
+
+        self.pad_length = ((kernel_size-1)*dilation) if self.causal else ((kernel_size-1)*dilation)//2
+        self.conv1 = torch.nn.Conv1d(in_ch,
+                                     out_ch,
+                                     kernel_size=kernel_size,
+                                     stride=stride,
+                                     padding=self.pad_length,
+                                     dilation=dilation,
+                                     groups=groups,
+                                     bias=False)
+        if grouped:
+            self.conv1b = torch.nn.Conv1d(out_ch, out_ch, kernel_size=1)
+
+        if conditional:
+            self.film = FiLM(cond_dim, out_ch)
+        self.bn = torch.nn.BatchNorm1d(out_ch)
+
+        self.relu = torch.nn.LeakyReLU()
+        self.res = torch.nn.Conv1d(in_ch,
+                                   out_ch,
+                                   kernel_size=1,
+                                   stride=stride,
+                                   groups=in_ch,
+                                   bias=False)
+
+
+    def forward(self, x, p):
+        x_in = x
+
+        x = self.relu(self.bn(self.conv1(x)))
+        x = self.film(x, p)
+
+        x_res = self.res(x_in)
+
+        if self.causal:
+            x = x[..., :-self.pad_length]
+        x += x_res
+
+        return x
+
+
+
+import pytorch_lightning as pl
+class TCNModel(pl.LightningModule):
+    """ Temporal convolutional network with conditioning module.
+        Args:
+            nparams (int): Number of conditioning parameters.
+            ninputs (int): Number of input channels (mono = 1, stereo = 2). Default: 1
+            noutputs (int): Number of output channels (mono = 1, stereo = 2). Default: 1
+            nblocks (int): Number of total TCN blocks. Default: 10
+            kernel_size (int): Width of the convolutional kernels. Default: 3
+            dilation_growth (int): Compute the dilation factor at each block as dilation_growth ** (n % stack_size). Default: 1
+            channel_growth (int): Compute the output channels at each block as in_ch * channel_growth. Default: 2
+            channel_width (int): When channel_growth = 1 all blocks use convolutions with this many channels. Default: 64
+            stack_size (int): Number of blocks that constitute a single stack of blocks. Default: 10
+            grouped (bool): Use grouped convolutions to reduce the total number of parameters. Default: False
+            causal (bool): Causal TCN configuration does not consider future input values. Default: False
+            skip_connections (bool): Skip connections from each block to the output. Default: False
+            num_examples (int): Number of evaluation audio examples to log after each epoch. Default: 4
+    """
+    def __init__(self,
+                 nparams,
+                 ninputs=1,
+                 noutputs=1,
+                 nblocks=10,
+                 kernel_size=3,
+                 stride=1,
+                 dilation_growth=1,
+                 channel_growth=1,
+                 channel_width=32,
+                 stack_size=10,
+                 cond_dim=2048,
+                 grouped=False,
+                 causal=False,
+                 skip_connections=False,
+                 num_examples=4,
+                 save_dir=None,
+                 **kwargs):
+        super(TCNModel, self).__init__()
+        self.save_hyperparameters()
+
+        self.blocks = torch.nn.ModuleList()
+        for n in range(nblocks):
+            in_ch = out_ch if n > 0 else ninputs
+
+            if self.hparams.channel_growth > 1:
+                out_ch = in_ch * self.hparams.channel_growth
+            else:
+                out_ch = self.hparams.channel_width
+
+            dilation = self.hparams.dilation_growth ** (n % self.hparams.stack_size)
+            cur_stride = stride[n] if isinstance(stride, list) else stride
+            self.blocks.append(TCNBlock(in_ch,
+                                        out_ch,
+                                        kernel_size=self.hparams.kernel_size,
+                                        stride=cur_stride,
+                                        dilation=dilation,
+                                        padding="same" if self.hparams.causal else "valid",
+                                        causal=self.hparams.causal,
+                                        cond_dim=cond_dim,
+                                        grouped=self.hparams.grouped,
+                                        conditional=True if self.hparams.nparams > 0 else False))
+
+        self.output = torch.nn.Conv1d(out_ch, noutputs, kernel_size=1)
+
+    def forward(self, x, cond):
+        # iterate over blocks passing conditioning
+        for idx, block in enumerate(self.blocks):
+            # for SeFa
+            if isinstance(cond, list):
+                x = block(x, cond[idx])
+            else:
+                x = block(x, cond)
+            skips = 0
+
+        # out = torch.tanh(self.output(x + skips))
+        out = torch.clamp(self.output(x + skips), min=-1, max=1)
+
+        return out
+
+    def compute_receptive_field(self):
+        """ Compute the receptive field in samples."""
+        rf = self.hparams.kernel_size
+        for n in range(1, self.hparams.nblocks):
+            dilation = self.hparams.dilation_growth ** (n % self.hparams.stack_size)
+            rf = rf + ((self.hparams.kernel_size-1) * dilation)
+        return rf
+
+    # add any model hyperparameters here
+    @staticmethod
+    def add_model_specific_args(parent_parser):
+        parser = ArgumentParser(parents=[parent_parser], add_help=False)
+        # --- model related ---
+        parser.add_argument('--ninputs', type=int, default=1)
+        parser.add_argument('--noutputs', type=int, default=1)
+        parser.add_argument('--nblocks', type=int, default=4)
+        parser.add_argument('--kernel_size', type=int, default=5)
+        parser.add_argument('--dilation_growth', type=int, default=10)
+        parser.add_argument('--channel_growth', type=int, default=1)
+        parser.add_argument('--channel_width', type=int, default=32)
+        parser.add_argument('--stack_size', type=int, default=10)
+        parser.add_argument('--grouped', default=False, action='store_true')
+        parser.add_argument('--causal', default=False, action="store_true")
+        parser.add_argument('--skip_connections', default=False, action="store_true")
+
+        return parser
+
+
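Review note: TCNModel.compute_receptive_field adds (kernel_size - 1) * dilation per block. With the argparse defaults above (kernel_size=5, dilation_growth=10, nblocks=4, stack_size=10) this gives 5 + 4*(10 + 100 + 1000) = 4445 samples, roughly 100 ms at 44.1 kHz. The same arithmetic standalone; the file listing continues below.

def receptive_field(kernel_size=5, dilation_growth=10, nblocks=4, stack_size=10):
    # mirrors TCNModel.compute_receptive_field above
    rf = kernel_size
    for n in range(1, nblocks):
        rf += (kernel_size - 1) * dilation_growth ** (n % stack_size)
    return rf

print(receptive_field())                   # 4445 samples
print(receptive_field() / 44100 * 1000)    # ~100.8 ms at 44.1 kHz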
+# Module for fitting SeFa parameters
+class Dasp_Mastering_Style_Transfer(nn.Module):
+    def __init__(self, num_features, sample_rate, \
+                    tgt_fx_names = ['eq', 'comp', 'imager', 'gain'], \
+                    model_type='2mlp', \
+                    config=None, \
+                    batch_size=4):
+        super(Dasp_Mastering_Style_Transfer, self).__init__()
+        self.sample_rate = sample_rate
+        self.tgt_fx_names = tgt_fx_names
+
+        self.fx_processors = {}
+        self.last_predicted_params = None
+        for cur_fx in tgt_fx_names:
+            if cur_fx=='eq':
+                cur_fx_module = dasp_pytorch.ParametricEQ(sample_rate=sample_rate, \
+                                                            min_gain_db = -20.0, \
+                                                            max_gain_db = 20.0, \
+                                                            min_q_factor = 0.1, \
+                                                            max_q_factor=5.0)
+            elif cur_fx=='distortion':
+                cur_fx_module = Distortion(sample_rate=sample_rate,
+                                           min_gain_db = 0.0,
+                                           max_gain_db = 8.0)
+            elif cur_fx=='comp':
+                cur_fx_module = dasp_pytorch.Compressor(sample_rate=sample_rate)
+            elif cur_fx=='multiband_comp':
+                cur_fx_module = Multiband_Compressor(sample_rate=sample_rate)
+            elif cur_fx=='gain':
+                cur_fx_module = dasp_pytorch.Gain(sample_rate=sample_rate)
+            elif cur_fx=='imager':
+                continue
+            elif cur_fx=='limiter':
+                cur_fx_module = Limiter(sample_rate=sample_rate)
+            else:
+                raise AssertionError(f"current fx name ({cur_fx}) not found")
+            self.fx_processors[cur_fx] = cur_fx_module
+        total_num_param = sum([self.fx_processors[cur_fx].num_params for cur_fx in self.fx_processors])
+        if 'imager' in tgt_fx_names:
+            total_num_param += 1
+
+        ''' model architecture '''
+        self.model_type = model_type
+        if self.model_type.lower()=='tcn':
+            self.network = TCNModel(nparams=config["condition_dimension"], ninputs=2, \
+                                    noutputs=total_num_param, \
+                                    nblocks=config["nblocks"], \
+                                    dilation_growth=config["dilation_growth"], \
+                                    kernel_size=config["kernel_size"], \
+                                    stride=config['stride'], \
+                                    channel_width=config["channel_width"], \
+                                    stack_size=config["stack_size"], \
+                                    cond_dim=config["condition_dimension"], \
+                                    causal=config["causal"])
+        elif self.model_type.lower()=='ito':
+            self.params = torch.nn.Parameter(torch.ones((batch_size, total_num_param))*0.5)
+
+    # network forward operation
+    def forward(self, x, embedding):
+        # embedding mapper
+        if self.model_type.lower()=='tcn':
+            est_param = self.network(x, embedding)
+            est_param = est_param.mean(axis=-1)
+        elif self.model_type.lower()=='ito':
+            est_param = self.params
+            est_param = torch.clamp(est_param, min=0.0, max=1.0)
+
+        if self.model_type.lower()!='ito':
+            est_param = F.sigmoid(est_param)
+
+        self.last_predicted_params = est_param
+
+        # dafx chain
+        cur_param_idx = 0
+        for cur_fx in self.tgt_fx_names:
+            if cur_fx=='imager':
+                cur_param_count = 1
+                x = dasp_pytorch.functional.stereo_widener(x, \
+                                                            sample_rate=self.sample_rate, \
+                                                            width=est_param[:,cur_param_idx:cur_param_idx+1])
+            else:
+                cur_param_count = self.fx_processors[cur_fx].num_params
+                cur_input_param = est_param[:, cur_param_idx:cur_param_idx+cur_param_count]
+                x = self.fx_processors[cur_fx].process_normalized(x, cur_input_param)
+            # update param index
+            cur_param_idx += cur_param_count
+
+        return x
+
+
+    def reset_fx_chain(self, ):
+        self.fx_processors = {}
+        for cur_fx in self.tgt_fx_names:
+            if cur_fx=='eq':
+                cur_fx_module = dasp_pytorch.ParametricEQ(sample_rate=self.sample_rate, \
+                                                            min_gain_db = -20.0, \
+                                                            max_gain_db = 20.0, \
+                                                            min_q_factor = 0.1, \
+                                                            max_q_factor=5.0)
+            elif cur_fx=='distortion':
+                cur_fx_module = Distortion(sample_rate=self.sample_rate,
+                                           min_gain_db = 0.0,
+                                           max_gain_db = 8.0)
+            elif cur_fx=='comp':
+                cur_fx_module = dasp_pytorch.Compressor(sample_rate=self.sample_rate)
+            elif cur_fx=='multiband_comp':
+                cur_fx_module = Multiband_Compressor(sample_rate=self.sample_rate)
+            elif cur_fx=='gain':
+                cur_fx_module = dasp_pytorch.Gain(sample_rate=self.sample_rate)
+            elif cur_fx=='imager':
+                continue
+            elif cur_fx=='limiter':
+                cur_fx_module = Limiter(sample_rate=self.sample_rate)
+            else:
+                raise AssertionError(f"current fx name ({cur_fx}) not found")
+            self.fx_processors[cur_fx] = cur_fx_module
+
+    def get_last_predicted_params(self):
+        if self.last_predicted_params is None:
+            return None
+
+        params_dict = {}
+        cur_param_idx = 0
+
+        for cur_fx in self.tgt_fx_names:
+            if cur_fx == 'imager':
+                cur_param_count = 1
+                normalized_param = self.last_predicted_params[:, cur_param_idx:cur_param_idx+1]
+                original_param = self.denormalize_param(normalized_param, 0, 1)
+                params_dict[cur_fx] = original_param
+            else:
+                cur_param_count = self.fx_processors[cur_fx].num_params
+                normalized_params = self.last_predicted_params[:, cur_param_idx:cur_param_idx+cur_param_count]
+                original_params = self.denormalize_params(cur_fx, normalized_params)
+                params_dict[cur_fx] = original_params
+
+            cur_param_idx += cur_param_count
+
+        return params_dict
+
+    def denormalize_params(self, fx_name, normalized_params):
+        fx_processor = self.fx_processors[fx_name]
+        original_params = {}
+
+        for i, (param_name, (min_val, max_val)) in enumerate(fx_processor.param_ranges.items()):
+            original_param = self.denormalize_param(normalized_params[:, i:i+1], min_val, max_val)
+            original_params[param_name] = original_param
+
+        return original_params
+
+    @staticmethod
+    def denormalize_param(normalized_param, min_val, max_val):
+        return normalized_param * (max_val - min_val) + min_val
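Review note: all effect parameters are predicted in [0, 1] (sigmoid for the TCN path, clamping for ITO) and mapped to physical units by denormalize_param, a plain linear rescale p * (max - min) + min. For example, with the EQ gain range of (-20.0, 20.0) dB configured above:

def denormalize_param(normalized_param, min_val, max_val):
    # same linear rescale as the staticmethod above
    return normalized_param * (max_val - min_val) + min_val

print(denormalize_param(0.5, -20.0, 20.0))    # 0.0 dB (range midpoint)
print(denormalize_param(0.75, -20.0, 20.0))   # 10.0 dB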
networks/dasp_additionals.py
ADDED
@@ -0,0 +1,441 @@
+"""
+Implementation of differentiable mastering effects based on the DASP-pytorch and torchcomp libraries
+    - Distortion
+    - Multiband Compressor
+    - Limiter
+DASP-pytorch: https://github.com/csteinmetz1/dasp-pytorch
+torchcomp: https://github.com/yoyololicon/torchcomp
+"""
+import dasp_pytorch
+from dasp_pytorch.modules import Processor
+import torchcomp
+import torch
+import torch.nn.functional as F
+import torch.nn as nn
+import numpy as np
+import time
+
+
+EPS = 1e-6
+
+class Distortion(Processor):
+    def __init__(
+        self,
+        sample_rate: int,
+        min_gain_db: float = 0.0,
+        max_gain_db: float = 24.0,
+    ):
+        super().__init__()
+        self.sample_rate = sample_rate
+        self.process_fn = distortion
+        self.param_ranges = {
+            "drive_db": (min_gain_db, max_gain_db),
+            "parallel_weight_factor": (0.2, 0.7),
+        }
+        self.num_params = len(self.param_ranges)
+
+def distortion(x: torch.Tensor,
+               sample_rate: int,
+               drive_db: torch.Tensor,
+               parallel_weight_factor: torch.Tensor):
+    """Simple soft-clipping distortion with drive control.
+
+    Args:
+        x (torch.Tensor): Input audio tensor with shape (bs, chs, seq_len)
+        sample_rate (int): Audio sample rate.
+        drive_db (torch.Tensor): Drive in dB with shape (bs)
+        parallel_weight_factor (torch.Tensor): Wet/dry mix weight with shape (bs)
+
+    Returns:
+        torch.Tensor: Output audio tensor with shape (bs, chs, seq_len)
+
+    """
+    bs, chs, seq_len = x.size()
+    parallel_weight_factor = parallel_weight_factor.view(-1, 1, 1)
+
+    # note: reshaping the drive as (bs, chs, -1) would be wrong; it is a per-item scalar
+    x_dist = torch.tanh(x * (10 ** (drive_db.view(bs, 1, 1) / 20.0)))
+
+    # parallel computation
+    return parallel_weight_factor * x_dist + (1-parallel_weight_factor) * x
+
+
+
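Note (reviewer sketch, not part of the diff): the drive converts dB to a linear pre-gain via 10^(dB/20), so the 8 dB maximum used in architectures.py scales the input by about 2.51 before the tanh soft clipper; the parallel weight then blends the clipped and dry paths. A self-contained check of those numbers (values illustrative):

    import torch

    drive_db = torch.tensor([8.0])
    pre_gain = 10 ** (drive_db / 20.0)             # ~2.512 linear gain
    x = torch.randn(2, 2, 4096) * 0.1              # (bs, chs, seq_len) quiet stereo noise
    x_dist = torch.tanh(x * pre_gain.view(-1, 1, 1))

    w = torch.tensor([0.5, 0.5]).view(-1, 1, 1)    # parallel_weight_factor per item
    y = w * x_dist + (1 - w) * x                   # 50/50 wet/dry blend
    print(round(pre_gain.item(), 3), y.shape)      # 2.512 torch.Size([2, 2, 4096])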
+class Multiband_Compressor(Processor):
+    def __init__(
+        self,
+        sample_rate: int,
+        min_threshold_db_comp: float = -60.0,
+        max_threshold_db_comp: float = 0.0-EPS,
+        min_ratio_comp: float = 1.0+EPS,
+        max_ratio_comp: float = 20.0,
+        min_attack_ms_comp: float = 5.0,
+        max_attack_ms_comp: float = 100.0,
+        min_release_ms_comp: float = 5.0,
+        max_release_ms_comp: float = 100.0,
+        min_threshold_db_exp: float = -60.0,
+        max_threshold_db_exp: float = 0.0-EPS,
+        min_ratio_exp: float = 0.0+EPS,
+        max_ratio_exp: float = 1.0-EPS,
+        min_attack_ms_exp: float = 5.0,
+        max_attack_ms_exp: float = 100.0,
+        min_release_ms_exp: float = 5.0,
+        max_release_ms_exp: float = 100.0,
+    ):
+        super().__init__()
+        self.sample_rate = sample_rate
+        self.process_fn = multiband_compressor
+        self.param_ranges = {
+            "low_cutoff": (20, 300),
+            "high_cutoff": (2000, 12000),
+            "parallel_weight_factor": (0.2, 0.7),
+
+            "low_shelf_comp_thresh": (min_threshold_db_comp, max_threshold_db_comp),
+            "low_shelf_comp_ratio": (min_ratio_comp, max_ratio_comp),
+            "low_shelf_exp_thresh": (min_threshold_db_exp, max_threshold_db_exp),
+            "low_shelf_exp_ratio": (min_ratio_exp, max_ratio_exp),
+            "low_shelf_at": (min_attack_ms_exp, max_attack_ms_exp),
+            "low_shelf_rt": (min_release_ms_exp, max_release_ms_exp),
+
+            "mid_band_comp_thresh": (min_threshold_db_comp, max_threshold_db_comp),
+            "mid_band_comp_ratio": (min_ratio_comp, max_ratio_comp),
+            "mid_band_exp_thresh": (min_threshold_db_exp, max_threshold_db_exp),
+            "mid_band_exp_ratio": (min_ratio_exp, max_ratio_exp),
+            "mid_band_at": (min_attack_ms_exp, max_attack_ms_exp),
+            "mid_band_rt": (min_release_ms_exp, max_release_ms_exp),
+
+            "high_shelf_comp_thresh": (min_threshold_db_comp, max_threshold_db_comp),
+            "high_shelf_comp_ratio": (min_ratio_comp, max_ratio_comp),
+            "high_shelf_exp_thresh": (min_threshold_db_exp, max_threshold_db_exp),
+            "high_shelf_exp_ratio": (min_ratio_exp, max_ratio_exp),
+            "high_shelf_at": (min_attack_ms_exp, max_attack_ms_exp),
+            "high_shelf_rt": (min_release_ms_exp, max_release_ms_exp),
+        }
+        self.num_params = len(self.param_ranges)
+
+
+
+def linkwitz_riley_4th_order(
+    x: torch.Tensor,
+    cutoff_freq: torch.Tensor,
+    sample_rate: float,
+    filter_type: str):
+    q_factor = torch.ones(cutoff_freq.shape) / torch.sqrt(torch.tensor([2.0]))
+    gain_db = torch.zeros(cutoff_freq.shape)
+    q_factor = q_factor.to(x.device)
+    gain_db = gain_db.to(x.device)
+
+    b, a = dasp_pytorch.signal.biquad(
+        gain_db,
+        cutoff_freq,
+        q_factor,
+        sample_rate,
+        filter_type
+    )
+
+    del gain_db
+    del q_factor
+
+    eff_bs = x.size(0)
+    # stack the biquad coefficients into a single second-order section
+    sos = torch.cat((b, a), dim=-1).unsqueeze(1)
+
+    # apply the 2nd-order Butterworth filter (Q = 1/sqrt(2)) twice: the cascade gives
+    # the 4th-order Linkwitz-Riley response (360° total phase shift at the cutoff)
+    x = dasp_pytorch.signal.sosfilt_via_fsm(sos, x)
+    x_out = dasp_pytorch.signal.sosfilt_via_fsm(sos, x)
+
+    return x_out
+
+
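Note (reviewer sketch, not part of the diff): the crossover below defines the mid band by subtraction rather than by a band-pass filter, which guarantees that low + mid + high sums exactly back to the input no matter how the filters behave; the Linkwitz-Riley responses are chosen so each band is also individually well-behaved at the crossover points. The reconstruction identity holds by construction:

    import torch

    x = torch.randn(1, 2, 4096)                 # (bs, chs, seq_len)
    low = 0.3 * x                               # stand-in for linkwitz_riley_4th_order(x, f_lo, sr, "low_pass")
    high = 0.2 * x                              # stand-in for the "high_pass" branch
    mid = x - low - high                        # mid band defined by subtraction
    assert torch.allclose(low + mid + high, x)  # perfect reconstruction, whatever the filters do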
+def multiband_compressor(
+    x: torch.Tensor,
+    sample_rate: float,
+
+    low_cutoff: torch.Tensor,
+    high_cutoff: torch.Tensor,
+    parallel_weight_factor: torch.Tensor,
+
+    low_shelf_comp_thresh: torch.Tensor,
+    low_shelf_comp_ratio: torch.Tensor,
+    low_shelf_exp_thresh: torch.Tensor,
+    low_shelf_exp_ratio: torch.Tensor,
+    low_shelf_at: torch.Tensor,
+    low_shelf_rt: torch.Tensor,
+
+    mid_band_comp_thresh: torch.Tensor,
+    mid_band_comp_ratio: torch.Tensor,
+    mid_band_exp_thresh: torch.Tensor,
+    mid_band_exp_ratio: torch.Tensor,
+    mid_band_at: torch.Tensor,
+    mid_band_rt: torch.Tensor,
+
+    high_shelf_comp_thresh: torch.Tensor,
+    high_shelf_comp_ratio: torch.Tensor,
+    high_shelf_exp_thresh: torch.Tensor,
+    high_shelf_exp_ratio: torch.Tensor,
+    high_shelf_at: torch.Tensor,
+    high_shelf_rt: torch.Tensor,
+):
+    """Multiband (three-band) compressor.
+
+    Low-shelf -> Mid-band -> High-shelf
+
+    Args:
+        x (torch.Tensor): Time domain tensor with shape (bs, chs, seq_len)
+        sample_rate (float): Audio sample rate.
+        low_cutoff (torch.Tensor): Low/mid crossover frequency in Hz.
+        high_cutoff (torch.Tensor): Mid/high crossover frequency in Hz.
+        parallel_weight_factor (torch.Tensor): Wet/dry mix weight.
+        low_shelf_comp_thresh (torch.Tensor): Low-band compressor threshold in dB.
+        low_shelf_comp_ratio (torch.Tensor): Low-band compressor ratio.
+        low_shelf_exp_thresh (torch.Tensor): Low-band expander threshold in dB.
+        low_shelf_exp_ratio (torch.Tensor): Low-band expander ratio.
+        low_shelf_at (torch.Tensor): Low-band attack time in ms.
+        low_shelf_rt (torch.Tensor): Low-band release time in ms.
+        mid_band_comp_thresh (torch.Tensor): Mid-band compressor threshold in dB.
+        mid_band_comp_ratio (torch.Tensor): Mid-band compressor ratio.
+        mid_band_exp_thresh (torch.Tensor): Mid-band expander threshold in dB.
+        mid_band_exp_ratio (torch.Tensor): Mid-band expander ratio.
+        mid_band_at (torch.Tensor): Mid-band attack time in ms.
+        mid_band_rt (torch.Tensor): Mid-band release time in ms.
+        high_shelf_comp_thresh (torch.Tensor): High-band compressor threshold in dB.
+        high_shelf_comp_ratio (torch.Tensor): High-band compressor ratio.
+        high_shelf_exp_thresh (torch.Tensor): High-band expander threshold in dB.
+        high_shelf_exp_ratio (torch.Tensor): High-band expander ratio.
+        high_shelf_at (torch.Tensor): High-band attack time in ms.
+        high_shelf_rt (torch.Tensor): High-band release time in ms.
+
+    Returns:
+        y (torch.Tensor): Compressed signal with shape (bs, chs, seq_len).
+    """
+    bs, chs, seq_len = x.size()
+
+    low_cutoff = low_cutoff.view(-1, 1, 1)
+    high_cutoff = high_cutoff.view(-1, 1, 1)
+    parallel_weight_factor = parallel_weight_factor.view(-1, 1, 1)
+
+    eff_bs = x.size(0)
+
+    ''' crossover filter '''
+    # low band (low frequencies)
+    low_band = linkwitz_riley_4th_order(x, low_cutoff, sample_rate, filter_type="low_pass")
+    # high band (high frequencies)
+    high_band = linkwitz_riley_4th_order(x, high_cutoff, sample_rate, filter_type="high_pass")
+    # mid band (band-pass): subtract the low and high bands from the original signal
+    mid_band = x - low_band - high_band
+
+    ''' compressor '''
+    # fall back to the unprocessed band if the gain computation fails on extreme parameter values
+    try:
+        x_out_low = low_band * torchcomp.compexp_gain(low_band.sum(axis=1).abs(),
+                                                      comp_thresh=low_shelf_comp_thresh,
+                                                      comp_ratio=low_shelf_comp_ratio,
+                                                      exp_thresh=low_shelf_exp_thresh,
+                                                      exp_ratio=low_shelf_exp_ratio,
+                                                      at=torchcomp.ms2coef(low_shelf_at, sample_rate),
+                                                      rt=torchcomp.ms2coef(low_shelf_rt, sample_rate)).unsqueeze(1)
+    except Exception:
+        x_out_low = low_band
+        print('\t!!!failed computing low-band compression!!!')
+    try:
+        x_out_high = high_band * torchcomp.compexp_gain(high_band.sum(axis=1).abs(),
+                                                        comp_thresh=high_shelf_comp_thresh,
+                                                        comp_ratio=high_shelf_comp_ratio,
+                                                        exp_thresh=high_shelf_exp_thresh,
+                                                        exp_ratio=high_shelf_exp_ratio,
+                                                        at=torchcomp.ms2coef(high_shelf_at, sample_rate),
+                                                        rt=torchcomp.ms2coef(high_shelf_rt, sample_rate)).unsqueeze(1)
+    except Exception:
+        x_out_high = high_band
+        print('\t!!!failed computing high-band compression!!!')
+    try:
+        x_out_mid = mid_band * torchcomp.compexp_gain(mid_band.sum(axis=1).abs(),
+                                                      comp_thresh=mid_band_comp_thresh,
+                                                      comp_ratio=mid_band_comp_ratio,
+                                                      exp_thresh=mid_band_exp_thresh,
+                                                      exp_ratio=mid_band_exp_ratio,
+                                                      at=torchcomp.ms2coef(mid_band_at, sample_rate),
+                                                      rt=torchcomp.ms2coef(mid_band_rt, sample_rate)).unsqueeze(1)
+    except Exception:
+        x_out_mid = mid_band
+        print('\t!!!failed computing mid-band compression!!!')
+    x_out = x_out_low + x_out_high + x_out_mid
+
+    # parallel computation
+    x_out = parallel_weight_factor * x_out + (1-parallel_weight_factor) * x
+
+    # move channels back
+    x_out = x_out.view(bs, chs, seq_len)
+
+    return x_out
+
+
+
+
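Note (reviewer sketch, not part of the diff): each band is driven by a mono detector signal, i.e. the channels are summed and the resulting gain envelope is applied to all channels, so the stereo balance is preserved. A minimal mono-detector sketch using only the torchcomp calls already used above (the parameter values are illustrative assumptions):

    import torch
    import torchcomp

    sr = 44100
    band = torch.randn(2, 2, sr)                       # (bs, chs, seq_len)
    detector = band.sum(axis=1).abs()                  # mono level detector, (bs, seq_len)
    gain = torchcomp.compexp_gain(detector,
                                  comp_thresh=torch.full((2,), -20.0),
                                  comp_ratio=torch.full((2,), 4.0),
                                  exp_thresh=torch.full((2,), -40.0),
                                  exp_ratio=torch.full((2,), 0.5),
                                  at=torchcomp.ms2coef(torch.full((2,), 10.0), sr),
                                  rt=torchcomp.ms2coef(torch.full((2,), 100.0), sr))
    out = band * gain.unsqueeze(1)                     # the same gain on both channels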
+class Limiter(Processor):
+    def __init__(
+        self,
+        sample_rate: int,
+        min_threshold_db: float = -60.0,
+        max_threshold_db: float = 0.0-EPS,
+        min_attack_ms: float = 5.0,
+        max_attack_ms: float = 100.0,
+        min_release_ms: float = 5.0,
+        max_release_ms: float = 100.0,
+    ):
+        super().__init__()
+        self.sample_rate = sample_rate
+        self.process_fn = limiter
+        self.param_ranges = {
+            "threshold": (min_threshold_db, max_threshold_db),
+            "at": (min_attack_ms, max_attack_ms),
+            "rt": (min_release_ms, max_release_ms),
+        }
+        self.num_params = len(self.param_ranges)
+
+
+def limiter(
+    x: torch.Tensor,
+    sample_rate: float,
+    threshold: torch.Tensor,
+    at: torch.Tensor,
+    rt: torch.Tensor,
+):
+    """Limiter.
+
+    Based on the differentiable limiter from Chin-Yun Yu's torchcomp.
+
+    Args:
+        x (torch.Tensor): Time domain tensor with shape (bs, chs, seq_len)
+        sample_rate (float): Audio sample rate.
+        threshold (torch.Tensor): Limiter threshold in dB.
+        at (torch.Tensor): Attack time in ms.
+        rt (torch.Tensor): Release time in ms.
+
+    Returns:
+        y (torch.Tensor): Limited signal.
+    """
+    bs, chs, seq_len = x.size()
+
+    x_out = x * torchcomp.limiter_gain(x.sum(axis=1).abs(),
+                                       threshold=threshold,
+                                       at=torchcomp.ms2coef(at, sample_rate),
+                                       rt=torchcomp.ms2coef(rt, sample_rate)).unsqueeze(1)
+
+    # move channels back
+    x_out = x_out.view(bs, chs, seq_len)
+
+    return x_out
+
+
+
+
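Note (reviewer sketch, not part of the diff): `ms2coef` converts an attack/release time in milliseconds into the one-pole smoothing coefficient used by the gain ballistics. The classic textbook form of that conversion, based on a 10%-90% rise time, is shown below only to illustrate the idea; torchcomp's exact constant may differ:

    import math

    def ms2coef_textbook(t_ms: float, sr: int) -> float:
        # one-pole coefficient for a 10%-90% rise time of t_ms milliseconds
        return 1.0 - math.exp(-2200.0 / (t_ms * sr))

    print(ms2coef_textbook(10.0, 44100))   # fast 10 ms attack -> coefficient near 0.005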
+class Random_Augmentation_Dasp(nn.Module):
+    def __init__(self, sample_rate, \
+                    tgt_fx_names = ['eq', 'comp', 'imager', 'gain']):
+        super(Random_Augmentation_Dasp, self).__init__()
+        self.sample_rate = sample_rate
+        self.tgt_fx_names = tgt_fx_names
+
+        self.device = torch.device("cpu")
+        if torch.cuda.is_available():
+            self.device = torch.device("cuda")
+
+        self.fx_prob = {'eq': 0.9, \
+                        'distortion': 0.3, \
+                        'comp': 0.8, \
+                        'multiband_comp': 0.8, \
+                        'gain': 0.85, \
+                        'imager': 0.6, \
+                        'limiter': 1.0}
+        self.fx_processors = {}
+        for cur_fx in tgt_fx_names:
+            if cur_fx=='eq':
+                cur_fx_module = dasp_pytorch.ParametricEQ(sample_rate=sample_rate, \
+                                                          min_gain_db=-10.0, \
+                                                          max_gain_db=10.0, \
+                                                          min_q_factor=0.5, \
+                                                          max_q_factor=5.0)
+            elif cur_fx=='distortion':
+                cur_fx_module = Distortion(sample_rate=sample_rate,
+                                           min_gain_db=0.0,
+                                           max_gain_db=4.0)
+            elif cur_fx=='comp':
+                cur_fx_module = dasp_pytorch.Compressor(sample_rate=sample_rate)
+            elif cur_fx=='multiband_comp':
+                cur_fx_module = Multiband_Compressor(sample_rate=sample_rate,
+                                                     min_threshold_db_comp=-30.0,
+                                                     max_threshold_db_comp=-5.0,
+                                                     min_ratio_comp=1.5,
+                                                     max_ratio_comp=6.0,
+                                                     min_attack_ms_comp=1.0,
+                                                     max_attack_ms_comp=20.0,
+                                                     min_release_ms_comp=20.0,
+                                                     max_release_ms_comp=500.0,
+                                                     min_threshold_db_exp=-30.0,
+                                                     max_threshold_db_exp=-5.0,
+                                                     min_ratio_exp=0.0+EPS,
+                                                     max_ratio_exp=1.0-EPS,
+                                                     min_attack_ms_exp=1.0,
+                                                     max_attack_ms_exp=20.0,
+                                                     min_release_ms_exp=20.0,
+                                                     max_release_ms_exp=500.0,
+                                                     )
+            elif cur_fx=='gain':
+                cur_fx_module = dasp_pytorch.Gain(sample_rate=sample_rate,
+                                                  min_gain_db=0.0,
+                                                  max_gain_db=6.0,)
+            elif cur_fx=='imager':
+                continue
+            elif cur_fx=='limiter':
+                cur_fx_module = Limiter(sample_rate=sample_rate,
+                                        min_threshold_db=-20.0,
+                                        max_threshold_db=0.0-EPS,
+                                        min_attack_ms=0.1,
+                                        max_attack_ms=5.0,
+                                        min_release_ms=20.0,
+                                        max_release_ms=1000.0,)
+            else:
+                raise AssertionError(f"current fx name ({cur_fx}) not found")
+            self.fx_processors[cur_fx] = cur_fx_module
+        total_num_param = sum([self.fx_processors[cur_fx].num_params for cur_fx in self.fx_processors])
+        if 'imager' in tgt_fx_names:
+            total_num_param += 1
+        self.total_num_param = total_num_param
+
+
+    # network forward operation
+    def forward(self, x, rand_param=None, use_mask=None):
+        if rand_param is None:
+            rand_param = torch.rand((x.shape[0], self.total_num_param)).to(self.device)
+        else:
+            assert rand_param.shape[0]==x.shape[0] and rand_param.shape[1]==self.total_num_param
+        if use_mask is None:
+            use_mask = self.random_mask_generator(x.shape[0])
+
+        # dafx chain
+        cur_param_idx = 0
+        for cur_fx in self.tgt_fx_names:
+            cur_param_count = 1 if cur_fx=='imager' else self.fx_processors[cur_fx].num_params
+            if cur_fx=='imager':
+                x_processed = dasp_pytorch.functional.stereo_widener(x, \
+                                                                     sample_rate=self.sample_rate, \
+                                                                     width=rand_param[:,cur_param_idx:cur_param_idx+1])
+            else:
+                cur_input_param = rand_param[:, cur_param_idx:cur_param_idx+cur_param_count]
+                x_processed = self.fx_processors[cur_fx].process_normalized(x, cur_input_param)
+            # process every FX, but decide whether to keep the processed output based on the per-FX probability
+            cur_mask = use_mask[cur_fx]
+            x = x_processed*cur_mask + x*~cur_mask
+            # update param index
+            cur_param_idx += cur_param_count
+
+        return x
+
+
+    def random_mask_generator(self, batch_size, repeat=1):
+        mask = {}
+        for cur_fx in self.tgt_fx_names:
+            mask[cur_fx] = self.fx_prob[cur_fx] > torch.rand(batch_size).view(-1, 1, 1)
+            if repeat>1:
+                mask[cur_fx] = mask[cur_fx].repeat(repeat, 1, 1)
+            mask[cur_fx] = mask[cur_fx].to(self.device)
+        return mask
+
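Note (reviewer sketch, not part of the diff): a minimal usage sketch of the augmentation chain, assuming the classes in this file are importable. Each effect in tgt_fx_names consumes its own slice of a flat [0, 1] parameter vector and is applied with its fx_prob probability:

    import torch
    # assumes Random_Augmentation_Dasp from networks/dasp_additionals.py is importable
    aug = Random_Augmentation_Dasp(sample_rate=44100,
                                   tgt_fx_names=['eq', 'comp', 'gain'])
    x = torch.randn(4, 2, 44100).to(aug.device)   # (bs, chs, seq_len) stereo batch
    y = aug(x)                                    # random parameters and random FX mask
    print(aug.total_num_param, y.shape)           # flat parameter count, same output shape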
networks/network_utils.py
ADDED
@@ -0,0 +1,254 @@
+"""
+Utility file
+containing building-block modules for the neural networks
+"""
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.nn.init as init
+import torch
+import torchaudio
+
+
+
+# 2-dimensional convolutional layer
+# in the order of conv -> norm -> activation
+class Conv2d_layer(nn.Module):
+    def __init__(self, in_channels, out_channels, kernel_size, \
+                    stride=1, \
+                    padding="SAME", dilation=(1,1), bias=True, \
+                    norm="batch", activation="relu", \
+                    mode="conv"):
+        super(Conv2d_layer, self).__init__()
+
+        self.conv2d = nn.Sequential()
+
+        if isinstance(kernel_size, int):
+            kernel_size = [kernel_size, kernel_size]
+        if isinstance(stride, int):
+            stride = [stride, stride]
+        if isinstance(dilation, int):
+            dilation = [dilation, dilation]
+
+        ''' padding '''
+        if mode=="deconv":
+            padding = tuple(int((current_kernel - 1)/2) for current_kernel in kernel_size)
+            out_padding = tuple(0 if current_stride == 1 else 1 for current_stride in stride)
+        elif mode=="conv":
+            if padding == "SAME":
+                f_pad = int((kernel_size[0]-1) * dilation[0])
+                t_pad = int((kernel_size[1]-1) * dilation[1])
+                t_l_pad = int(t_pad//2)
+                t_r_pad = t_pad - t_l_pad
+                f_l_pad = int(f_pad//2)
+                f_r_pad = f_pad - f_l_pad
+                padding_area = (t_l_pad, t_r_pad, f_l_pad, f_r_pad)
+            elif padding == "VALID":
+                # fix: padding_area must be defined here, otherwise ReflectionPad2d below raises a NameError
+                padding_area = (0, 0, 0, 0)
+            else:
+                pass
+
+        ''' convolutional layer '''
+        if mode=="deconv":
+            self.conv2d.add_module("deconv2d", nn.ConvTranspose2d(in_channels, out_channels, \
+                                                                  (kernel_size[0], kernel_size[1]), \
+                                                                  stride=stride, \
+                                                                  padding=padding, output_padding=out_padding, \
+                                                                  dilation=dilation, \
+                                                                  bias=bias))
+        elif mode=="conv":
+            self.conv2d.add_module(f"{mode}2d_pad", nn.ReflectionPad2d(padding_area))
+            self.conv2d.add_module(f"{mode}2d", nn.Conv2d(in_channels, out_channels, \
+                                                          (kernel_size[0], kernel_size[1]), \
+                                                          stride=stride, \
+                                                          padding=0, \
+                                                          dilation=dilation, \
+                                                          bias=bias))
+
+        ''' normalization '''
+        if norm=="batch":
+            self.conv2d.add_module("batch_norm", nn.BatchNorm2d(out_channels))
+
+        ''' activation '''
+        if activation=="relu":
+            self.conv2d.add_module("relu", nn.ReLU())
+        elif activation=="lrelu":
+            self.conv2d.add_module("lrelu", nn.LeakyReLU())
+
+
+    def forward(self, input):
+        # input shape should be : batch x channel x height x width
+        output = self.conv2d(input)
+        return output
+
+
+
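Note (reviewer sketch, not part of the diff): with padding="SAME" the layer reflection-pads so that stride 1 preserves the spatial size. A quick shape check (sizes illustrative):

    import torch
    # assumes Conv2d_layer from networks/network_utils.py is importable
    layer = Conv2d_layer(in_channels=1, out_channels=16, kernel_size=3)
    spec = torch.randn(8, 1, 128, 256)        # batch x channel x freq x time
    out = layer(spec)
    print(out.shape)                          # torch.Size([8, 16, 128, 256])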
+# 1-dimensional convolutional layer
+# in the order of conv -> norm -> activation
+class Conv1d_layer(nn.Module):
+    def __init__(self, in_channels, out_channels, kernel_size, \
+                    stride=1, \
+                    padding="SAME", dilation=1, bias=True, \
+                    norm="batch", activation="relu", \
+                    mode="conv"):
+        super(Conv1d_layer, self).__init__()
+
+        self.conv1d = nn.Sequential()
+
+        ''' padding '''
+        if mode=="deconv":
+            padding = int(dilation * (kernel_size-1) / 2)
+            out_padding = 0 if stride==1 else 1
+        elif mode=="conv" or "alias_free" in mode:
+            if padding == "SAME":
+                pad = int((kernel_size-1) * dilation)
+                l_pad = int(pad//2)
+                r_pad = pad - l_pad
+                padding_area = (l_pad, r_pad)
+            elif padding == "VALID":
+                padding_area = (0, 0)
+            else:
+                pass
+
+        ''' convolutional layer '''
+        if mode=="deconv":
+            self.conv1d.add_module("deconv1d", nn.ConvTranspose1d(in_channels, out_channels, kernel_size, \
+                                                                  stride=stride, padding=padding, output_padding=out_padding, \
+                                                                  dilation=dilation, \
+                                                                  bias=bias))
+        elif mode=="conv":
+            self.conv1d.add_module(f"{mode}1d_pad", nn.ReflectionPad1d(padding_area))
+            self.conv1d.add_module(f"{mode}1d", nn.Conv1d(in_channels, out_channels, kernel_size, \
+                                                          stride=stride, padding=0, \
+                                                          dilation=dilation, \
+                                                          bias=bias))
+        elif "alias_free" in mode:
+            if "up" in mode:
+                up_factor = stride * 2
+                down_factor = 2
+            elif "down" in mode:
+                up_factor = 2
+                down_factor = stride * 2
+            else:
+                raise ValueError("choose alias-free method : 'up' or 'down'")
+            # procedure : conv -> upsample -> lrelu -> low-pass filter -> downsample
+            # torchaudio.transforms.Resample's default resampling_method ('sinc_interpolation')
+            # applies the anti-aliasing low-pass filter as part of the resampling
+            # details at https://pytorch.org/audio/stable/transforms.html
+            self.conv1d.add_module(f"{mode}1d_pad", nn.ReflectionPad1d(padding_area))
+            self.conv1d.add_module(f"{mode}1d", nn.Conv1d(in_channels, out_channels, kernel_size, \
+                                                          stride=1, padding=0, \
+                                                          dilation=dilation, \
+                                                          bias=bias))
+            self.conv1d.add_module(f"{mode}upsample", torchaudio.transforms.Resample(orig_freq=1, new_freq=up_factor))
+            self.conv1d.add_module(f"{mode}lrelu", nn.LeakyReLU())
+            self.conv1d.add_module(f"{mode}downsample", torchaudio.transforms.Resample(orig_freq=down_factor, new_freq=1))
+
+        ''' normalization '''
+        if norm=="batch":
+            self.conv1d.add_module("batch_norm", nn.BatchNorm1d(out_channels))
+            # self.conv1d.add_module("batch_norm", nn.SyncBatchNorm(out_channels))
+
+        ''' activation '''
+        if 'alias_free' not in mode:
+            if activation=="relu":
+                self.conv1d.add_module("relu", nn.ReLU())
+            elif activation=="lrelu":
+                self.conv1d.add_module("lrelu", nn.LeakyReLU())
+
+
+    def forward(self, input):
+        # input shape should be : batch x channel x length
+        output = self.conv1d(input)
+        return output
+
+
+
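Note (reviewer sketch, not part of the diff): the alias-free modes follow the oversampled-nonlinearity recipe (as popularized by StyleGAN3): run the activation at a higher sample rate so the harmonics it creates stay below Nyquist, then band-limit on the way back down. The net rate change is up_factor/down_factor, e.g. stride=2 in a "down" mode gives 2/4, i.e. overall 2x downsampling:

    import torch
    import torchaudio

    x = torch.randn(1, 16, 1024)                                    # batch x channel x length
    up = torchaudio.transforms.Resample(orig_freq=1, new_freq=2)    # 2x oversample
    down = torchaudio.transforms.Resample(orig_freq=4, new_freq=1)  # sinc low-pass + 4x decimate
    y = down(torch.nn.functional.leaky_relu(up(x)))                 # nonlinearity at the high rate
    print(y.shape)                                                  # torch.Size([1, 16, 512])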
+# Residual Block
+# the input is added back after the first convolutional layer, which keeps the original channel size
+# therefore, only the second convolutional layer may change the channel count
+class Res_ConvBlock(nn.Module):
+    def __init__(self, dimension, \
+                    in_channels, out_channels, \
+                    kernel_size, \
+                    stride=1, padding="SAME", \
+                    dilation=1, \
+                    bias=True, \
+                    norm="batch", \
+                    activation="relu", last_activation="relu", \
+                    mode="conv"):
+        super(Res_ConvBlock, self).__init__()
+
+        if dimension==1:
+            self.conv1 = Conv1d_layer(in_channels, in_channels, kernel_size, padding=padding, dilation=dilation, bias=bias, norm=norm, activation=activation)
+            self.conv2 = Conv1d_layer(in_channels, out_channels, kernel_size, stride=stride, padding=padding, dilation=dilation, bias=bias, norm=norm, activation=last_activation, mode=mode)
+        elif dimension==2:
+            self.conv1 = Conv2d_layer(in_channels, in_channels, kernel_size, padding=padding, dilation=dilation, bias=bias, norm=norm, activation=activation)
+            self.conv2 = Conv2d_layer(in_channels, out_channels, kernel_size, stride=stride, padding=padding, dilation=dilation, bias=bias, norm=norm, activation=last_activation, mode=mode)
+
+
+    def forward(self, input):
+        c1_out = self.conv1(input) + input
+        c2_out = self.conv2(c1_out)
+        return c2_out
+
+
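Note (reviewer sketch, not part of the diff): because the skip connection wraps only conv1 (whose input and output channels match), the block can still change channels and stride in conv2 without needing a projection shortcut. A quick shape check (sizes illustrative):

    import torch
    # assumes Res_ConvBlock from networks/network_utils.py is importable
    block = Res_ConvBlock(dimension=2, in_channels=16, out_channels=32,
                          kernel_size=3, stride=2)
    x = torch.randn(4, 16, 64, 64)
    print(block(x).shape)                     # torch.Size([4, 32, 32, 32])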
+# Convolutional Block
+# consists of multiple (layer_num) convolutional layers
+# only the final convolutional layer outputs the desired 'out_channels'
+class ConvBlock(nn.Module):
+    def __init__(self, dimension, layer_num, \
+                    in_channels, out_channels, \
+                    kernel_size, \
+                    stride=1, padding="SAME", \
+                    dilation=1, \
+                    bias=True, \
+                    norm="batch", \
+                    activation="relu", last_activation="relu", \
+                    mode="conv"):
+        super(ConvBlock, self).__init__()
+
+        conv_block = []
+        if dimension==1:
+            for i in range(layer_num-1):
+                conv_block.append(Conv1d_layer(in_channels, in_channels, kernel_size, padding=padding, dilation=dilation, bias=bias, norm=norm, activation=activation))
+            conv_block.append(Conv1d_layer(in_channels, out_channels, kernel_size, stride=stride, padding=padding, dilation=dilation, bias=bias, norm=norm, activation=last_activation, mode=mode))
+        elif dimension==2:
+            for i in range(layer_num-1):
+                conv_block.append(Conv2d_layer(in_channels, in_channels, kernel_size, padding=padding, dilation=dilation, bias=bias, norm=norm, activation=activation))
+            conv_block.append(Conv2d_layer(in_channels, out_channels, kernel_size, stride=stride, padding=padding, dilation=dilation, bias=bias, norm=norm, activation=last_activation, mode=mode))
+        self.conv_block = nn.Sequential(*conv_block)
+
+
+    def forward(self, input):
+        return self.conv_block(input)
+
+
+# Feature-wise Linear Modulation
+class FiLM(nn.Module):
+    def __init__(self, condition_len=2048, feature_len=1024):
+        super(FiLM, self).__init__()
+        self.film_fc = nn.Linear(condition_len, feature_len*2)
+        self.feat_len = feature_len
+
+
+    def forward(self, feature, condition, sefa=None):
+        # SeFa-style latent editing of the condition embedding
+        if sefa:
+            weight = self.film_fc.weight.T
+            weight = weight / torch.linalg.norm((weight+1e-07), dim=0, keepdim=True)
+            # NOTE: torch.eig was removed in recent PyTorch; torch.linalg.eig is the modern equivalent
+            eigen_values, eigen_vectors = torch.eig(torch.matmul(weight, weight.T), eigenvectors=True)
+
+            ####### custom parameters #######
+            chosen_eig_idx = sefa[0]
+            alpha = eigen_values[chosen_eig_idx][0] * sefa[1]
+            #################################
+
+            An = eigen_vectors[chosen_eig_idx].repeat(condition.shape[0], 1)
+            alpha_An = alpha * An
+
+            condition += alpha_An
+
+        film_factor = self.film_fc(condition).unsqueeze(-1)
+        r, b = torch.split(film_factor, self.feat_len, dim=1)
+        return r*feature + b
+
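Note (reviewer sketch, not part of the diff): the single linear layer emits both the scale r and the shift b, and the trailing unsqueeze broadcasts them across the time axis of a (batch, feature_len, time) feature map. A quick shape check (sizes illustrative):

    import torch
    # assumes FiLM from networks/network_utils.py is importable
    film = FiLM(condition_len=2048, feature_len=1024)
    feature = torch.randn(4, 1024, 128)        # batch x feature_len x time
    condition = torch.randn(4, 2048)           # reference embedding
    out = film(feature, condition)             # r*feature + b, broadcast over time
    print(out.shape)                           # torch.Size([4, 1024, 128])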