Commit 5019931 · committed by akhaliq3
Parent(s): bee7f54

spaces demo
This view is limited to 50 files because it contains too many changes.
- LICENSE +13 -0
- bytesep/__init__.py +1 -0
- bytesep/callbacks/__init__.py +76 -0
- bytesep/callbacks/base_callbacks.py +44 -0
- bytesep/callbacks/instruments_callbacks.py +200 -0
- bytesep/callbacks/musdb18.py +485 -0
- bytesep/callbacks/voicebank_demand.py +231 -0
- bytesep/data/__init__.py +0 -0
- bytesep/data/augmentors.py +157 -0
- bytesep/data/batch_data_preprocessors.py +141 -0
- bytesep/data/data_modules.py +187 -0
- bytesep/data/samplers.py +188 -0
- bytesep/dataset_creation/__init__.py +0 -0
- bytesep/dataset_creation/create_evaluation_audios/__init__.py +0 -0
- bytesep/dataset_creation/create_evaluation_audios/piano-symphony.py +160 -0
- bytesep/dataset_creation/create_evaluation_audios/vctk-musdb18.py +164 -0
- bytesep/dataset_creation/create_evaluation_audios/violin-piano.py +162 -0
- bytesep/dataset_creation/create_indexes/__init__.py +0 -0
- bytesep/dataset_creation/create_indexes/create_indexes.py +142 -0
- bytesep/dataset_creation/pack_audios_to_hdf5s/__init__.py +0 -0
- bytesep/dataset_creation/pack_audios_to_hdf5s/instruments_solo.py +173 -0
- bytesep/dataset_creation/pack_audios_to_hdf5s/maestro.py +136 -0
- bytesep/dataset_creation/pack_audios_to_hdf5s/musdb18.py +207 -0
- bytesep/dataset_creation/pack_audios_to_hdf5s/vctk.py +114 -0
- bytesep/dataset_creation/pack_audios_to_hdf5s/voicebank-demand.py +143 -0
- bytesep/inference.py +402 -0
- bytesep/inference_many.py +163 -0
- bytesep/losses.py +106 -0
- bytesep/models/__init__.py +0 -0
- bytesep/models/conditional_unet.py +496 -0
- bytesep/models/lightning_modules.py +188 -0
- bytesep/models/pytorch_modules.py +204 -0
- bytesep/models/resunet.py +516 -0
- bytesep/models/resunet_ismir2021.py +534 -0
- bytesep/models/resunet_subbandtime.py +545 -0
- bytesep/models/subband_tools/__init__.py +0 -0
- bytesep/models/subband_tools/fDomainHelper.py +255 -0
- bytesep/models/subband_tools/filters/f_4_64.mat +0 -0
- bytesep/models/subband_tools/filters/h_4_64.mat +0 -0
- bytesep/models/subband_tools/pqmf.py +136 -0
- bytesep/models/unet.py +532 -0
- bytesep/models/unet_subbandtime.py +389 -0
- bytesep/optimizers/__init__.py +0 -0
- bytesep/optimizers/lr_schedulers.py +20 -0
- bytesep/plot_results/__init__.py +0 -0
- bytesep/plot_results/musdb18.py +198 -0
- bytesep/plot_results/plot_vctk-musdb18.py +87 -0
- bytesep/train.py +299 -0
- bytesep/utils.py +189 -0
- pyproject.toml +21 -0
LICENSE
ADDED
@@ -0,0 +1,13 @@
+Copyright 2021 ByteDance
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
bytesep/__init__.py
ADDED
@@ -0,0 +1 @@
+from bytesep.inference import Separator
bytesep/callbacks/__init__.py
ADDED
@@ -0,0 +1,76 @@
+from typing import List
+
+import pytorch_lightning as pl
+import torch.nn as nn
+
+
+def get_callbacks(
+    task_name: str,
+    config_yaml: str,
+    workspace: str,
+    checkpoints_dir: str,
+    statistics_path: str,
+    logger: pl.loggers.TensorBoardLogger,
+    model: nn.Module,
+    evaluate_device: str,
+) -> List[pl.Callback]:
+    r"""Get callbacks of a task and config yaml file.
+
+    Args:
+        task_name: str
+        config_yaml: str
+        workspace: str, containing useful files such as audios for evaluation
+        checkpoints_dir: str, directory to save checkpoints
+        statistics_path: str, path to save statistics
+        logger: pl.loggers.TensorBoardLogger
+        model: nn.Module
+        evaluate_device: str
+
+    Returns:
+        callbacks: List[pl.Callback]
+    """
+    if task_name == 'musdb18':
+
+        from bytesep.callbacks.musdb18 import get_musdb18_callbacks
+
+        return get_musdb18_callbacks(
+            config_yaml=config_yaml,
+            workspace=workspace,
+            checkpoints_dir=checkpoints_dir,
+            statistics_path=statistics_path,
+            logger=logger,
+            model=model,
+            evaluate_device=evaluate_device,
+        )
+
+    elif task_name == 'voicebank-demand':
+
+        from bytesep.callbacks.voicebank_demand import get_voicebank_demand_callbacks
+
+        return get_voicebank_demand_callbacks(
+            config_yaml=config_yaml,
+            workspace=workspace,
+            checkpoints_dir=checkpoints_dir,
+            statistics_path=statistics_path,
+            logger=logger,
+            model=model,
+            evaluate_device=evaluate_device,
+        )
+
+    elif task_name in ['vctk-musdb18', 'violin-piano', 'piano-symphony']:
+
+        from bytesep.callbacks.instruments_callbacks import get_instruments_callbacks
+
+        return get_instruments_callbacks(
+            config_yaml=config_yaml,
+            workspace=workspace,
+            checkpoints_dir=checkpoints_dir,
+            statistics_path=statistics_path,
+            logger=logger,
+            model=model,
+            evaluate_device=evaluate_device,
+        )
+
+    else:
+        raise NotImplementedError
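For orientation, here is a minimal usage sketch of the dispatcher above; the config path, workspace layout, and trainer setup are hypothetical placeholders, not part of this commit.

# Hypothetical wiring of get_callbacks into a training script; all paths are placeholders.
import pytorch_lightning as pl
import torch.nn as nn

from bytesep.callbacks import get_callbacks

def build_trainer(model: nn.Module) -> pl.Trainer:
    logger = pl.loggers.TensorBoardLogger(save_dir="./tb_logs")  # assumed log location

    callbacks = get_callbacks(
        task_name='musdb18',                           # dispatches to get_musdb18_callbacks
        config_yaml='./configs/musdb18.yaml',          # hypothetical config path
        workspace='./workspace',                       # holds evaluation audios, etc.
        checkpoints_dir='./workspace/checkpoints',
        statistics_path='./workspace/statistics.pkl',
        logger=logger,
        model=model,
        evaluate_device='cuda',
    )

    return pl.Trainer(callbacks=callbacks, logger=logger)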
bytesep/callbacks/base_callbacks.py
ADDED
@@ -0,0 +1,44 @@
+import logging
+import os
+from typing import NoReturn
+
+import pytorch_lightning as pl
+import torch
+import torch.nn as nn
+from pytorch_lightning.utilities import rank_zero_only
+
+
+class SaveCheckpointsCallback(pl.Callback):
+    def __init__(
+        self,
+        model: nn.Module,
+        checkpoints_dir: str,
+        save_step_frequency: int,
+    ):
+        r"""Callback to save checkpoints every #save_step_frequency steps.
+
+        Args:
+            model: nn.Module
+            checkpoints_dir: str, directory to save checkpoints
+            save_step_frequency: int
+        """
+        self.model = model
+        self.checkpoints_dir = checkpoints_dir
+        self.save_step_frequency = save_step_frequency
+        os.makedirs(self.checkpoints_dir, exist_ok=True)
+
+    @rank_zero_only
+    def on_batch_end(self, trainer: pl.Trainer, _) -> NoReturn:
+        r"""Save checkpoint."""
+        global_step = trainer.global_step
+
+        if global_step % self.save_step_frequency == 0:
+
+            checkpoint_path = os.path.join(
+                self.checkpoints_dir, "step={}.pth".format(global_step)
+            )
+
+            checkpoint = {'step': global_step, 'model': self.model.state_dict()}
+
+            torch.save(checkpoint, checkpoint_path)
+            logging.info("Save checkpoint to {}".format(checkpoint_path))
bytesep/callbacks/instruments_callbacks.py
ADDED
@@ -0,0 +1,200 @@
+import logging
+import os
+import time
+from typing import List, NoReturn
+
+import librosa
+import numpy as np
+import pytorch_lightning as pl
+import torch.nn as nn
+from pytorch_lightning.utilities import rank_zero_only
+
+from bytesep.callbacks.base_callbacks import SaveCheckpointsCallback
+from bytesep.inference import Separator
+from bytesep.utils import StatisticsContainer, calculate_sdr, read_yaml
+
+
+def get_instruments_callbacks(
+    config_yaml: str,
+    workspace: str,
+    checkpoints_dir: str,
+    statistics_path: str,
+    logger: pl.loggers.TensorBoardLogger,
+    model: nn.Module,
+    evaluate_device: str,
+) -> List[pl.Callback]:
+    r"""Get instruments callbacks of a config yaml.
+
+    Args:
+        config_yaml: str
+        workspace: str
+        checkpoints_dir: str, directory to save checkpoints
+        statistics_path: str, path to save statistics
+        logger: pl.loggers.TensorBoardLogger
+        model: nn.Module
+        evaluate_device: str
+
+    Returns:
+        callbacks: List[pl.Callback]
+    """
+    configs = read_yaml(config_yaml)
+    task_name = configs['task_name']
+    target_source_types = configs['train']['target_source_types']
+    input_channels = configs['train']['channels']
+    mono = True if input_channels == 1 else False
+    test_audios_dir = os.path.join(workspace, "evaluation_audios", task_name, "test")
+    sample_rate = configs['train']['sample_rate']
+    evaluate_step_frequency = configs['train']['evaluate_step_frequency']
+    save_step_frequency = configs['train']['save_step_frequency']
+    test_batch_size = configs['evaluate']['batch_size']
+    test_segment_seconds = configs['evaluate']['segment_seconds']
+
+    test_segment_samples = int(test_segment_seconds * sample_rate)
+    assert len(target_source_types) == 1
+    target_source_type = target_source_types[0]
+
+    # save checkpoint callback
+    save_checkpoints_callback = SaveCheckpointsCallback(
+        model=model,
+        checkpoints_dir=checkpoints_dir,
+        save_step_frequency=save_step_frequency,
+    )
+
+    # statistics container
+    statistics_container = StatisticsContainer(statistics_path)
+
+    # evaluation callback
+    evaluate_test_callback = EvaluationCallback(
+        model=model,
+        target_source_type=target_source_type,
+        input_channels=input_channels,
+        sample_rate=sample_rate,
+        mono=mono,
+        evaluation_audios_dir=test_audios_dir,
+        segment_samples=test_segment_samples,
+        batch_size=test_batch_size,
+        device=evaluate_device,
+        evaluate_step_frequency=evaluate_step_frequency,
+        logger=logger,
+        statistics_container=statistics_container,
+    )
+
+    callbacks = [save_checkpoints_callback, evaluate_test_callback]
+    # callbacks = [save_checkpoints_callback]
+
+    return callbacks
+
+
+class EvaluationCallback(pl.Callback):
+    def __init__(
+        self,
+        model: nn.Module,
+        input_channels: int,
+        evaluation_audios_dir: str,
+        target_source_type: str,
+        sample_rate: int,
+        mono: bool,
+        segment_samples: int,
+        batch_size: int,
+        device: str,
+        evaluate_step_frequency: int,
+        logger: pl.loggers.TensorBoardLogger,
+        statistics_container: StatisticsContainer,
+    ):
+        r"""Callback to evaluate every #evaluate_step_frequency steps.
+
+        Args:
+            model: nn.Module
+            input_channels: int
+            evaluation_audios_dir: str, directory containing audios for evaluation
+            target_source_type: str, e.g., 'violin'
+            sample_rate: int
+            mono: bool
+            segment_samples: int, length of segments to be input to a model, e.g., 44100*30
+            batch_size: int, e.g., 12
+            device: str, e.g., 'cuda'
+            evaluate_step_frequency: int, evaluate every #evaluate_step_frequency steps
+            logger: pl.loggers.TensorBoardLogger
+            statistics_container: StatisticsContainer
+        """
+        self.model = model
+        self.target_source_type = target_source_type
+        self.sample_rate = sample_rate
+        self.mono = mono
+        self.segment_samples = segment_samples
+        self.evaluate_step_frequency = evaluate_step_frequency
+        self.logger = logger
+        self.statistics_container = statistics_container
+
+        self.evaluation_audios_dir = evaluation_audios_dir
+
+        # separator
+        self.separator = Separator(model, self.segment_samples, batch_size, device)
+
+    @rank_zero_only
+    def on_batch_end(self, trainer: pl.Trainer, _) -> NoReturn:
+        r"""Evaluate separation SDRs of audio recordings."""
+
+        global_step = trainer.global_step
+
+        if global_step % self.evaluate_step_frequency == 0:
+
+            mixture_audios_dir = os.path.join(self.evaluation_audios_dir, 'mixture')
+            clean_audios_dir = os.path.join(
+                self.evaluation_audios_dir, self.target_source_type
+            )
+
+            audio_names = sorted(os.listdir(mixture_audios_dir))
+
+            error_str = "Directory {} does not contain audios for evaluation!".format(
+                self.evaluation_audios_dir
+            )
+            assert len(audio_names) > 0, error_str
+
+            logging.info("--- Step {} ---".format(global_step))
+            logging.info("Total {} pieces for evaluation:".format(len(audio_names)))
+
+            eval_time = time.time()
+
+            sdrs = []
+
+            for n, audio_name in enumerate(audio_names):
+
+                # Load audio.
+                mixture_path = os.path.join(mixture_audios_dir, audio_name)
+                clean_path = os.path.join(clean_audios_dir, audio_name)
+
+                mixture, origin_fs = librosa.core.load(
+                    mixture_path, sr=self.sample_rate, mono=self.mono
+                )
+
+                # Target
+                clean, origin_fs = librosa.core.load(
+                    clean_path, sr=self.sample_rate, mono=self.mono
+                )
+
+                if mixture.ndim == 1:
+                    mixture = mixture[None, :]
+                    # (channels_num, audio_length)
+
+                input_dict = {'waveform': mixture}
+
+                # separate
+                sep_wav = self.separator.separate(input_dict)
+                # (channels_num, audio_length)
+
+                sdr = calculate_sdr(ref=clean, est=sep_wav)
+
+                print("{} SDR: {:.3f}".format(audio_name, sdr))
+                sdrs.append(sdr)
+
+            logging.info("-----------------------------")
+            logging.info('Avg SDR: {:.3f}'.format(np.mean(sdrs)))
+
+            logging.info("Evaluation time: {:.3f}".format(time.time() - eval_time))
+
+            statistics = {"sdr": np.mean(sdrs)}
+            self.statistics_container.append(global_step, statistics, 'test')
+            self.statistics_container.dump()
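As a side note, the Separator interface used by this callback can be driven directly; a minimal sketch under the positional constructor signature seen above (the model, segment length, batch size, and device are placeholders):

# Minimal sketch of driving the Separator outside of a callback; the model
# and the numeric values below are placeholders.
import numpy as np
import torch.nn as nn

from bytesep.inference import Separator

def separate_one_mixture(model: nn.Module, mixture: np.ndarray) -> np.ndarray:
    # mixture: (channels_num, audio_samples), as loaded in the callback above.
    segment_samples = 44100 * 30  # segment length fed to the model
    separator = Separator(model, segment_samples, 12, 'cuda')  # positional, as above
    return separator.separate({'waveform': mixture})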
bytesep/callbacks/musdb18.py
ADDED
@@ -0,0 +1,485 @@
+import logging
+import os
+import time
+from typing import Dict, List, NoReturn
+
+import librosa
+import musdb
+import museval
+import numpy as np
+import pytorch_lightning as pl
+import torch.nn as nn
+from pytorch_lightning.utilities import rank_zero_only
+
+from bytesep.callbacks.base_callbacks import SaveCheckpointsCallback
+from bytesep.dataset_creation.pack_audios_to_hdf5s.musdb18 import preprocess_audio
+from bytesep.inference import Separator
+from bytesep.utils import StatisticsContainer, read_yaml
+
+
+def get_musdb18_callbacks(
+    config_yaml: str,
+    workspace: str,
+    checkpoints_dir: str,
+    statistics_path: str,
+    logger: pl.loggers.TensorBoardLogger,
+    model: nn.Module,
+    evaluate_device: str,
+) -> List[pl.Callback]:
+    r"""Get MUSDB18 callbacks of a config yaml.
+
+    Args:
+        config_yaml: str
+        workspace: str
+        checkpoints_dir: str, directory to save checkpoints
+        statistics_path: str, path to save statistics
+        logger: pl.loggers.TensorBoardLogger
+        model: nn.Module
+        evaluate_device: str
+
+    Returns:
+        callbacks: List[pl.Callback]
+    """
+    configs = read_yaml(config_yaml)
+    task_name = configs['task_name']
+    evaluation_callback = configs['train']['evaluation_callback']
+    target_source_types = configs['train']['target_source_types']
+    input_channels = configs['train']['channels']
+    evaluation_audios_dir = os.path.join(workspace, "evaluation_audios", task_name)
+    test_segment_seconds = configs['evaluate']['segment_seconds']
+    sample_rate = configs['train']['sample_rate']
+    test_segment_samples = int(test_segment_seconds * sample_rate)
+    test_batch_size = configs['evaluate']['batch_size']
+
+    evaluate_step_frequency = configs['train']['evaluate_step_frequency']
+    save_step_frequency = configs['train']['save_step_frequency']
+
+    # save checkpoint callback
+    save_checkpoints_callback = SaveCheckpointsCallback(
+        model=model,
+        checkpoints_dir=checkpoints_dir,
+        save_step_frequency=save_step_frequency,
+    )
+
+    # evaluation callback class
+    EvaluationCallback = _get_evaluation_callback_class(evaluation_callback)
+
+    # statistics container
+    statistics_container = StatisticsContainer(statistics_path)
+
+    # evaluation callbacks
+    evaluate_train_callback = EvaluationCallback(
+        dataset_dir=evaluation_audios_dir,
+        model=model,
+        target_source_types=target_source_types,
+        input_channels=input_channels,
+        sample_rate=sample_rate,
+        split='train',
+        segment_samples=test_segment_samples,
+        batch_size=test_batch_size,
+        device=evaluate_device,
+        evaluate_step_frequency=evaluate_step_frequency,
+        logger=logger,
+        statistics_container=statistics_container,
+    )
+
+    evaluate_test_callback = EvaluationCallback(
+        dataset_dir=evaluation_audios_dir,
+        model=model,
+        target_source_types=target_source_types,
+        input_channels=input_channels,
+        sample_rate=sample_rate,
+        split='test',
+        segment_samples=test_segment_samples,
+        batch_size=test_batch_size,
+        device=evaluate_device,
+        evaluate_step_frequency=evaluate_step_frequency,
+        logger=logger,
+        statistics_container=statistics_container,
+    )
+
+    # callbacks = [save_checkpoints_callback, evaluate_train_callback, evaluate_test_callback]
+    callbacks = [save_checkpoints_callback, evaluate_test_callback]
+
+    return callbacks
+
+
+def _get_evaluation_callback_class(evaluation_callback) -> pl.Callback:
+    r"""Get evaluation callback class."""
+    if evaluation_callback == "Musdb18EvaluationCallback":
+        return Musdb18EvaluationCallback
+
+    elif evaluation_callback == 'Musdb18ConditionalEvaluationCallback':
+        return Musdb18ConditionalEvaluationCallback
+
+    else:
+        raise NotImplementedError
+
+
+class Musdb18EvaluationCallback(pl.Callback):
+    def __init__(
+        self,
+        dataset_dir: str,
+        model: nn.Module,
+        target_source_types: List[str],
+        input_channels: int,
+        split: str,
+        sample_rate: int,
+        segment_samples: int,
+        batch_size: int,
+        device: str,
+        evaluate_step_frequency: int,
+        logger: pl.loggers.TensorBoardLogger,
+        statistics_container: StatisticsContainer,
+    ):
+        r"""Callback to evaluate every #evaluate_step_frequency steps.
+
+        Args:
+            dataset_dir: str
+            model: nn.Module
+            target_source_types: List[str], e.g., ['vocals', 'bass', ...]
+            input_channels: int
+            split: 'train' | 'test'
+            sample_rate: int
+            segment_samples: int, length of segments to be input to a model, e.g., 44100*30
+            batch_size: int, e.g., 12
+            device: str, e.g., 'cuda'
+            evaluate_step_frequency: int, evaluate every #evaluate_step_frequency steps
+            logger: object
+            statistics_container: StatisticsContainer
+        """
+        self.model = model
+        self.target_source_types = target_source_types
+        self.input_channels = input_channels
+        self.sample_rate = sample_rate
+        self.split = split
+        self.segment_samples = segment_samples
+        self.evaluate_step_frequency = evaluate_step_frequency
+        self.logger = logger
+        self.statistics_container = statistics_container
+        self.mono = input_channels == 1
+        self.resample_type = "kaiser_fast"
+
+        self.mus = musdb.DB(root=dataset_dir, subsets=[split])
+
+        error_msg = "The directory {} is empty!".format(dataset_dir)
+        assert len(self.mus) > 0, error_msg
+
+        # separator
+        self.separator = Separator(model, self.segment_samples, batch_size, device)
+
+    @rank_zero_only
+    def on_batch_end(self, trainer: pl.Trainer, _) -> NoReturn:
+        r"""Evaluate separation SDRs of audio recordings."""
+        global_step = trainer.global_step
+
+        if global_step % self.evaluate_step_frequency == 0:
+
+            sdr_dict = {}
+
+            logging.info("--- Step {} ---".format(global_step))
+            logging.info("Total {} pieces for evaluation:".format(len(self.mus.tracks)))
+
+            eval_time = time.time()
+
+            for track in self.mus.tracks:
+
+                audio_name = track.name
+
+                # Get waveform of mixture.
+                mixture = track.audio.T
+                # (channels_num, audio_samples)
+
+                mixture = preprocess_audio(
+                    audio=mixture,
+                    mono=self.mono,
+                    origin_sr=track.rate,
+                    sr=self.sample_rate,
+                    resample_type=self.resample_type,
+                )
+                # (channels_num, audio_samples)
+
+                target_dict = {}
+                sdr_dict[audio_name] = {}
+
+                # Get waveform of all target source types.
+                for j, source_type in enumerate(self.target_source_types):
+                    # E.g., ['vocals', 'bass', ...]
+
+                    audio = track.targets[source_type].audio.T
+
+                    audio = preprocess_audio(
+                        audio=audio,
+                        mono=self.mono,
+                        origin_sr=track.rate,
+                        sr=self.sample_rate,
+                        resample_type=self.resample_type,
+                    )
+                    # (channels_num, audio_samples)
+
+                    target_dict[source_type] = audio
+                    # (channels_num, audio_samples)
+
+                # Separate.
+                input_dict = {'waveform': mixture}
+
+                sep_wavs = self.separator.separate(input_dict)
+                # sep_wavs: (target_sources_num * channels_num, audio_samples)
+
+                # Post process separation results.
+                sep_wavs = preprocess_audio(
+                    audio=sep_wavs,
+                    mono=self.mono,
+                    origin_sr=self.sample_rate,
+                    sr=track.rate,
+                    resample_type=self.resample_type,
+                )
+                # sep_wavs: (target_sources_num * channels_num, audio_samples)
+
+                sep_wavs = librosa.util.fix_length(
+                    sep_wavs, size=mixture.shape[1], axis=1
+                )
+                # sep_wavs: (target_sources_num * channels_num, audio_samples)
+
+                sep_wav_dict = get_separated_wavs_from_simo_output(
+                    sep_wavs, self.input_channels, self.target_source_types
+                )
+                # sep_wav_dict: dict, e.g., {
+                #     'vocals': (channels_num, audio_samples),
+                #     'bass': (channels_num, audio_samples),
+                #     ...,
+                # }
+
+                # Evaluate for all target source types.
+                for source_type in self.target_source_types:
+                    # E.g., ['vocals', 'bass', ...]
+
+                    # Calculate SDR using museval, input shape should be: (nsrc, nsampl, nchan).
+                    (sdrs, _, _, _) = museval.evaluate(
+                        [target_dict[source_type].T], [sep_wav_dict[source_type].T]
+                    )
+
+                    sdr = np.nanmedian(sdrs)
+                    sdr_dict[audio_name][source_type] = sdr
+
+                    logging.info(
+                        "{}, {}, sdr: {:.3f}".format(audio_name, source_type, sdr)
+                    )
+
+            logging.info("-----------------------------")
+            median_sdr_dict = {}
+
+            # Calculate median SDRs of all songs.
+            for source_type in self.target_source_types:
+                # E.g., ['vocals', 'bass', ...]
+
+                median_sdr = np.median(
+                    [
+                        sdr_dict[audio_name][source_type]
+                        for audio_name in sdr_dict.keys()
+                    ]
+                )
+
+                median_sdr_dict[source_type] = median_sdr
+
+                logging.info(
+                    "Step: {}, {}, Median SDR: {:.3f}".format(
+                        global_step, source_type, median_sdr
+                    )
+                )
+
+            logging.info("Evaluation time: {:.3f}".format(time.time() - eval_time))
+
+            statistics = {"sdr_dict": sdr_dict, "median_sdr_dict": median_sdr_dict}
+            self.statistics_container.append(global_step, statistics, self.split)
+            self.statistics_container.dump()
+
+
+def get_separated_wavs_from_simo_output(x, input_channels, target_source_types) -> Dict:
+    r"""Get separated waveforms of target sources from a single input multiple
+    output (SIMO) system.
+
+    Args:
+        x: (target_sources_num * channels_num, audio_samples)
+        input_channels: int
+        target_source_types: List[str], e.g., ['vocals', 'bass', ...]
+
+    Returns:
+        output_dict: dict, e.g., {
+            'vocals': (channels_num, audio_samples),
+            'bass': (channels_num, audio_samples),
+            ...,
+        }
+    """
+    output_dict = {}
+
+    for j, source_type in enumerate(target_source_types):
+        output_dict[source_type] = x[j * input_channels : (j + 1) * input_channels]
+
+    return output_dict
+
+
+class Musdb18ConditionalEvaluationCallback(pl.Callback):
+    def __init__(
+        self,
+        dataset_dir: str,
+        model: nn.Module,
+        target_source_types: List[str],
+        input_channels: int,
+        split: str,
+        sample_rate: int,
+        segment_samples: int,
+        batch_size: int,
+        device: str,
+        evaluate_step_frequency: int,
+        logger: pl.loggers.TensorBoardLogger,
+        statistics_container: StatisticsContainer,
+    ):
+        r"""Callback to evaluate every #evaluate_step_frequency steps.
+
+        Args:
+            dataset_dir: str
+            model: nn.Module
+            target_source_types: List[str], e.g., ['vocals', 'bass', ...]
+            input_channels: int
+            split: 'train' | 'test'
+            sample_rate: int
+            segment_samples: int, length of segments to be input to a model, e.g., 44100*30
+            batch_size: int, e.g., 12
+            device: str, e.g., 'cuda'
+            evaluate_step_frequency: int, evaluate every #evaluate_step_frequency steps
+            logger: object
+            statistics_container: StatisticsContainer
+        """
+        self.model = model
+        self.target_source_types = target_source_types
+        self.input_channels = input_channels
+        self.sample_rate = sample_rate
+        self.split = split
+        self.segment_samples = segment_samples
+        self.evaluate_step_frequency = evaluate_step_frequency
+        self.logger = logger
+        self.statistics_container = statistics_container
+        self.mono = input_channels == 1
+        self.resample_type = "kaiser_fast"
+
+        self.mus = musdb.DB(root=dataset_dir, subsets=[split])
+
+        error_msg = "The directory {} is empty!".format(dataset_dir)
+        assert len(self.mus) > 0, error_msg
+
+        # separator
+        self.separator = Separator(model, self.segment_samples, batch_size, device)
+
+    @rank_zero_only
+    def on_batch_end(self, trainer: pl.Trainer, _) -> NoReturn:
+        r"""Evaluate separation SDRs of audio recordings."""
+        global_step = trainer.global_step
+
+        if global_step % self.evaluate_step_frequency == 0:
+
+            sdr_dict = {}
+
+            logging.info("--- Step {} ---".format(global_step))
+            logging.info("Total {} pieces for evaluation:".format(len(self.mus.tracks)))
+
+            eval_time = time.time()
+
+            for track in self.mus.tracks:
+
+                audio_name = track.name
+
+                # Get waveform of mixture.
+                mixture = track.audio.T
+                # (channels_num, audio_samples)
+
+                mixture = preprocess_audio(
+                    audio=mixture,
+                    mono=self.mono,
+                    origin_sr=track.rate,
+                    sr=self.sample_rate,
+                    resample_type=self.resample_type,
+                )
+                # (channels_num, audio_samples)
+
+                target_dict = {}
+                sdr_dict[audio_name] = {}
+
+                # Get waveform of all target source types.
+                for j, source_type in enumerate(self.target_source_types):
+                    # E.g., ['vocals', 'bass', ...]
+
+                    audio = track.targets[source_type].audio.T
+
+                    audio = preprocess_audio(
+                        audio=audio,
+                        mono=self.mono,
+                        origin_sr=track.rate,
+                        sr=self.sample_rate,
+                        resample_type=self.resample_type,
+                    )
+                    # (channels_num, audio_samples)
+
+                    target_dict[source_type] = audio
+                    # (channels_num, audio_samples)
+
+                    condition = np.zeros(len(self.target_source_types))
+                    condition[j] = 1
+
+                    input_dict = {'waveform': mixture, 'condition': condition}
+
+                    sep_wav = self.separator.separate(input_dict)
+                    # sep_wav: (channels_num, audio_samples)
+
+                    sep_wav = preprocess_audio(
+                        audio=sep_wav,
+                        mono=self.mono,
+                        origin_sr=self.sample_rate,
+                        sr=track.rate,
+                        resample_type=self.resample_type,
+                    )
+                    # sep_wav: (channels_num, audio_samples)
+
+                    sep_wav = librosa.util.fix_length(
+                        sep_wav, size=mixture.shape[1], axis=1
+                    )
+                    # sep_wav: (channels_num, audio_samples)
+
+                    # Calculate SDR using museval, input shape should be: (nsrc, nsampl, nchan)
+                    (sdrs, _, _, _) = museval.evaluate(
+                        [target_dict[source_type].T], [sep_wav.T]
+                    )
+
+                    sdr = np.nanmedian(sdrs)
+                    sdr_dict[audio_name][source_type] = sdr
+
+                    logging.info(
+                        "{}, {}, sdr: {:.3f}".format(audio_name, source_type, sdr)
+                    )
+
+            logging.info("-----------------------------")
+            median_sdr_dict = {}
+
+            # Calculate median SDRs of all songs.
+            for source_type in self.target_source_types:
+
+                median_sdr = np.median(
+                    [
+                        sdr_dict[audio_name][source_type]
+                        for audio_name in sdr_dict.keys()
+                    ]
+                )
+
+                median_sdr_dict[source_type] = median_sdr
+
+                logging.info(
+                    "Step: {}, {}, Median SDR: {:.3f}".format(
+                        global_step, source_type, median_sdr
+                    )
+                )
+
+            logging.info("Evaluation time: {:.3f}".format(time.time() - eval_time))
+
+            statistics = {"sdr_dict": sdr_dict, "median_sdr_dict": median_sdr_dict}
+            self.statistics_container.append(global_step, statistics, self.split)
+            self.statistics_container.dump()
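To make the SIMO de-interleaving concrete, here is a small self-contained check of the slicing performed by get_separated_wavs_from_simo_output, using dummy shapes rather than real audio:

# Dummy check of the SIMO output layout: with 2-channel input and targets
# ['vocals', 'bass'], rows 0-1 of the stacked output belong to 'vocals'
# and rows 2-3 to 'bass'.
import numpy as np

input_channels = 2
target_source_types = ['vocals', 'bass']

# Stand-in for separator output: (target_sources_num * channels_num, audio_samples).
x = np.arange(4 * 5).reshape(4, 5)

output_dict = {}
for j, source_type in enumerate(target_source_types):
    output_dict[source_type] = x[j * input_channels : (j + 1) * input_channels]

assert output_dict['vocals'].shape == (2, 5)
assert (output_dict['bass'] == x[2:4]).all()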
bytesep/callbacks/voicebank_demand.py
ADDED
@@ -0,0 +1,231 @@
+import logging
+import os
+import time
+from typing import List, NoReturn
+
+import librosa
+import numpy as np
+import pysepm
+import pytorch_lightning as pl
+import torch.nn as nn
+from pesq import pesq
+from pytorch_lightning.utilities import rank_zero_only
+
+from bytesep.callbacks.base_callbacks import SaveCheckpointsCallback
+from bytesep.inference import Separator
+from bytesep.utils import StatisticsContainer, read_yaml
+
+
+def get_voicebank_demand_callbacks(
+    config_yaml: str,
+    workspace: str,
+    checkpoints_dir: str,
+    statistics_path: str,
+    logger: pl.loggers.TensorBoardLogger,
+    model: nn.Module,
+    evaluate_device: str,
+) -> List[pl.Callback]:
+    r"""Get Voicebank-Demand callbacks of a config yaml.
+
+    Args:
+        config_yaml: str
+        workspace: str
+        checkpoints_dir: str, directory to save checkpoints
+        statistics_path: str, path to save statistics
+        logger: pl.loggers.TensorBoardLogger
+        model: nn.Module
+        evaluate_device: str
+
+    Returns:
+        callbacks: List[pl.Callback]
+    """
+    configs = read_yaml(config_yaml)
+    task_name = configs['task_name']
+    target_source_types = configs['train']['target_source_types']
+    input_channels = configs['train']['channels']
+    evaluation_audios_dir = os.path.join(workspace, "evaluation_audios", task_name)
+    sample_rate = configs['train']['sample_rate']
+    evaluate_step_frequency = configs['train']['evaluate_step_frequency']
+    save_step_frequency = configs['train']['save_step_frequency']
+    test_batch_size = configs['evaluate']['batch_size']
+    test_segment_seconds = configs['evaluate']['segment_seconds']
+
+    test_segment_samples = int(test_segment_seconds * sample_rate)
+    assert len(target_source_types) == 1
+    target_source_type = target_source_types[0]
+    assert target_source_type == 'speech'
+
+    # save checkpoint callback
+    save_checkpoints_callback = SaveCheckpointsCallback(
+        model=model,
+        checkpoints_dir=checkpoints_dir,
+        save_step_frequency=save_step_frequency,
+    )
+
+    # statistics container
+    statistics_container = StatisticsContainer(statistics_path)
+
+    # evaluation callback
+    evaluate_test_callback = EvaluationCallback(
+        model=model,
+        input_channels=input_channels,
+        sample_rate=sample_rate,
+        evaluation_audios_dir=evaluation_audios_dir,
+        segment_samples=test_segment_samples,
+        batch_size=test_batch_size,
+        device=evaluate_device,
+        evaluate_step_frequency=evaluate_step_frequency,
+        logger=logger,
+        statistics_container=statistics_container,
+    )
+
+    callbacks = [save_checkpoints_callback, evaluate_test_callback]
+
+    return callbacks
+
+
+class EvaluationCallback(pl.Callback):
+    def __init__(
+        self,
+        model: nn.Module,
+        input_channels: int,
+        evaluation_audios_dir: str,
+        sample_rate: int,
+        segment_samples: int,
+        batch_size: int,
+        device: str,
+        evaluate_step_frequency: int,
+        logger: pl.loggers.TensorBoardLogger,
+        statistics_container: StatisticsContainer,
+    ):
+        r"""Callback to evaluate every #evaluate_step_frequency steps.
+
+        Args:
+            model: nn.Module
+            input_channels: int
+            evaluation_audios_dir: str, directory containing audios for evaluation
+            sample_rate: int
+            segment_samples: int, length of segments to be input to a model, e.g., 44100*30
+            batch_size: int, e.g., 12
+            device: str, e.g., 'cuda'
+            evaluate_step_frequency: int, evaluate every #evaluate_step_frequency steps
+            logger: pl.loggers.TensorBoardLogger
+            statistics_container: StatisticsContainer
+        """
+        self.model = model
+        self.mono = True
+        self.sample_rate = sample_rate
+        self.segment_samples = segment_samples
+        self.evaluate_step_frequency = evaluate_step_frequency
+        self.logger = logger
+        self.statistics_container = statistics_container
+
+        self.clean_dir = os.path.join(evaluation_audios_dir, "clean_testset_wav")
+        self.noisy_dir = os.path.join(evaluation_audios_dir, "noisy_testset_wav")
+
+        self.EVALUATION_SAMPLE_RATE = 16000  # Evaluation sample rate of the
+        # Voicebank-Demand task.
+
+        # separator
+        self.separator = Separator(model, self.segment_samples, batch_size, device)
+
+    @rank_zero_only
+    def on_batch_end(self, trainer: pl.Trainer, _) -> NoReturn:
+        r"""Evaluate speech enhancement metrics (PESQ, CSIG, CBAK, COVL, SSNR)
+        of audio recordings.
+        """
+
+        global_step = trainer.global_step
+
+        if global_step % self.evaluate_step_frequency == 0:
+
+            audio_names = sorted(
+                [
+                    audio_name
+                    for audio_name in sorted(os.listdir(self.clean_dir))
+                    if audio_name.endswith('.wav')
+                ]
+            )
+
+            error_str = "Directory {} does not contain audios for evaluation!".format(
+                self.clean_dir
+            )
+            assert len(audio_names) > 0, error_str
+
+            pesqs, csigs, cbaks, covls, ssnrs = [], [], [], [], []
+
+            logging.info("--- Step {} ---".format(global_step))
+            logging.info("Total {} pieces for evaluation:".format(len(audio_names)))
+
+            eval_time = time.time()
+
+            for n, audio_name in enumerate(audio_names):
+
+                # Load audio.
+                clean_path = os.path.join(self.clean_dir, audio_name)
+                mixture_path = os.path.join(self.noisy_dir, audio_name)
+
+                mixture, _ = librosa.core.load(
+                    mixture_path, sr=self.sample_rate, mono=self.mono
+                )
+
+                if mixture.ndim == 1:
+                    mixture = mixture[None, :]
+                    # (channels_num, audio_length)
+
+                # Separate.
+                input_dict = {'waveform': mixture}
+
+                sep_wav = self.separator.separate(input_dict)
+                # (channels_num, audio_length)
+
+                # Target
+                clean, _ = librosa.core.load(
+                    clean_path, sr=self.EVALUATION_SAMPLE_RATE, mono=self.mono
+                )
+
+                # to mono
+                sep_wav = np.squeeze(sep_wav)
+
+                # Resample for evaluation.
+                sep_wav = librosa.resample(
+                    sep_wav,
+                    orig_sr=self.sample_rate,
+                    target_sr=self.EVALUATION_SAMPLE_RATE,
+                )
+
+                sep_wav = librosa.util.fix_length(sep_wav, size=len(clean), axis=0)
+                # (audio_length,)
+
+                # Evaluate metrics
+                pesq_ = pesq(self.EVALUATION_SAMPLE_RATE, clean, sep_wav, 'wb')
+
+                (csig, cbak, covl) = pysepm.composite(
+                    clean, sep_wav, self.EVALUATION_SAMPLE_RATE
+                )
+
+                ssnr = pysepm.SNRseg(clean, sep_wav, self.EVALUATION_SAMPLE_RATE)
+
+                pesqs.append(pesq_)
+                csigs.append(csig)
+                cbaks.append(cbak)
+                covls.append(covl)
+                ssnrs.append(ssnr)
+                print(
+                    '{}, {}, PESQ: {:.3f}, CSIG: {:.3f}, CBAK: {:.3f}, COVL: {:.3f}, SSNR: {:.3f}'.format(
+                        n, audio_name, pesq_, csig, cbak, covl, ssnr
+                    )
+                )
+
+            logging.info("-----------------------------")
+            logging.info('Avg PESQ: {:.3f}'.format(np.mean(pesqs)))
+            logging.info('Avg CSIG: {:.3f}'.format(np.mean(csigs)))
+            logging.info('Avg CBAK: {:.3f}'.format(np.mean(cbaks)))
+            logging.info('Avg COVL: {:.3f}'.format(np.mean(covls)))
+            logging.info('Avg SSNR: {:.3f}'.format(np.mean(ssnrs)))
+
+            logging.info("Evaluation time: {:.3f}".format(time.time() - eval_time))
+
+            statistics = {"pesq": np.mean(pesqs)}
+            self.statistics_container.append(global_step, statistics, 'test')
+            self.statistics_container.dump()
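The metric calls above can be exercised in isolation. A toy sketch on synthetic 16 kHz signals, following the same call signatures as the callback (real evaluation uses matched clean/enhanced speech pairs; scores on synthetic noise carry no meaning and only show the API):

# Toy metric computation at the 16 kHz evaluation rate; the signals are
# synthetic stand-ins, so the resulting scores are meaningless.
import numpy as np
import pysepm
from pesq import pesq

sr = 16000
clean = np.random.randn(2 * sr).astype(np.float32)
enhanced = clean + 0.05 * np.random.randn(2 * sr).astype(np.float32)

pesq_score = pesq(sr, clean, enhanced, 'wb')              # wide-band PESQ
csig, cbak, covl = pysepm.composite(clean, enhanced, sr)  # composite measures
ssnr = pysepm.SNRseg(clean, enhanced, sr)                 # segmental SNR
print(pesq_score, csig, cbak, covl, ssnr)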
bytesep/data/__init__.py
ADDED
File without changes
bytesep/data/augmentors.py
ADDED
@@ -0,0 +1,157 @@
+from typing import Dict
+
+import librosa
+import numpy as np
+
+from bytesep.utils import db_to_magnitude, get_pitch_shift_factor, magnitude_to_db
+
+
+class Augmentor:
+    def __init__(self, augmentations: Dict, random_seed=1234):
+        r"""Augmentor for data augmentation of a waveform.
+
+        Args:
+            augmentations: Dict, e.g., {
+                'mixaudio': {'vocals': 2, 'accompaniment': 2},
+                'pitch_shift': {'vocals': 4, 'accompaniment': 4},
+                ...,
+            }
+            random_seed: int
+        """
+        self.augmentations = augmentations
+        self.random_state = np.random.RandomState(random_seed)
+
+    def __call__(self, waveform: np.array, source_type: str) -> np.array:
+        r"""Augment a waveform.
+
+        Args:
+            waveform: (channels_num, audio_samples)
+            source_type: str
+
+        Returns:
+            new_waveform: (channels_num, new_audio_samples)
+        """
+        if 'pitch_shift' in self.augmentations.keys():
+            waveform = self.pitch_shift(waveform, source_type)
+
+        if 'magnitude_scale' in self.augmentations.keys():
+            waveform = self.magnitude_scale(waveform, source_type)
+
+        if 'swap_channel' in self.augmentations.keys():
+            waveform = self.swap_channel(waveform, source_type)
+
+        if 'flip_axis' in self.augmentations.keys():
+            waveform = self.flip_axis(waveform, source_type)
+
+        return waveform
+
+    def pitch_shift(self, waveform: np.array, source_type: str) -> np.array:
+        r"""Shift the pitch of a waveform. We use resampling for fast pitch
+        shifting, so the speed will also be changed. The length of the returned
+        waveform will be changed.
+
+        Args:
+            waveform: (channels_num, audio_samples)
+            source_type: str
+
+        Returns:
+            new_waveform: (channels_num, new_audio_samples)
+        """
+
+        # maximum pitch shift in semitones
+        max_pitch_shift = self.augmentations['pitch_shift'][source_type]
+
+        if max_pitch_shift == 0:  # No pitch shift augmentations.
+            return waveform
+
+        # random pitch shift
+        rand_pitch = self.random_state.uniform(
+            low=-max_pitch_shift, high=max_pitch_shift
+        )
+
+        # We use librosa.resample instead of librosa.effects.pitch_shift
+        # because it is about 10x faster.
+        pitch_shift_factor = get_pitch_shift_factor(rand_pitch)
+        dummy_sample_rate = 10000  # Dummy constant.
+
+        channels_num = waveform.shape[0]
+
+        if channels_num == 1:
+            waveform = np.squeeze(waveform)
+
+        new_waveform = librosa.resample(
+            y=waveform,
+            orig_sr=dummy_sample_rate,
+            target_sr=dummy_sample_rate / pitch_shift_factor,
+            res_type='linear',
+            axis=-1,
+        )
+
+        if channels_num == 1:
+            new_waveform = new_waveform[None, :]
+
+        return new_waveform
+
+    def magnitude_scale(self, waveform: np.array, source_type: str) -> np.array:
+        r"""Scale the magnitude of a waveform.
+
+        Args:
+            waveform: (channels_num, audio_samples)
+            source_type: str
+
+        Returns:
+            new_waveform: (channels_num, audio_samples)
+        """
+        lower_db = self.augmentations['magnitude_scale'][source_type]['lower_db']
+        higher_db = self.augmentations['magnitude_scale'][source_type]['higher_db']
+
+        if lower_db == 0 and higher_db == 0:  # No magnitude scale augmentation.
+            return waveform
+
+        # The magnitude (in dB) of the sample with the maximum value.
+        waveform_db = magnitude_to_db(np.max(np.abs(waveform)))
+
+        new_waveform_db = self.random_state.uniform(
+            waveform_db + lower_db, min(waveform_db + higher_db, 0)
+        )
+
+        relative_db = new_waveform_db - waveform_db
+
+        relative_scale = db_to_magnitude(relative_db)
+
+        new_waveform = waveform * relative_scale
+
+        return new_waveform
+
+    def swap_channel(self, waveform: np.array, source_type: str) -> np.array:
+        r"""Randomly swap channels.
+
+        Args:
+            waveform: (channels_num, audio_samples)
+            source_type: str
+
+        Returns:
+            new_waveform: (channels_num, audio_samples)
+        """
+        ndim = waveform.shape[0]
+
+        if ndim == 1:
+            return waveform
+        else:
+            random_axes = self.random_state.permutation(ndim)
+            return waveform[random_axes, :]
+
+    def flip_axis(self, waveform: np.array, source_type: str) -> np.array:
+        r"""Randomly flip the sign of each channel of the waveform.
+
+        Args:
+            waveform: (channels_num, audio_samples)
+            source_type: str
+
+        Returns:
+            new_waveform: (channels_num, audio_samples)
+        """
+        ndim = waveform.shape[0]
+        random_values = self.random_state.choice([-1, 1], size=ndim)
+
+        return waveform * random_values[:, None]
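A short sketch of driving the Augmentor on synthetic stereo audio; the augmentation dict follows the docstring's format, with the numeric settings chosen arbitrarily for illustration:

# Sketch with synthetic audio; the augmentation settings are arbitrary examples.
import numpy as np

from bytesep.data.augmentors import Augmentor

augmentations = {
    'pitch_shift': {'vocals': 4},  # up to +/-4 semitones
    'magnitude_scale': {'vocals': {'lower_db': -20, 'higher_db': 0}},
    'swap_channel': {},            # presence of the key enables the augmentation
    'flip_axis': {},
}

augmentor = Augmentor(augmentations, random_seed=1234)

waveform = np.random.randn(2, 44100).astype(np.float32)  # (channels_num, audio_samples)
new_waveform = augmentor(waveform, source_type='vocals')

# Pitch shift via resampling changes the length, so only the channel count is stable.
assert new_waveform.shape[0] == 2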
bytesep/data/batch_data_preprocessors.py
ADDED
@@ -0,0 +1,141 @@
+from typing import Dict, List
+
+import torch
+
+
+class BasicBatchDataPreprocessor:
+    def __init__(self, target_source_types: List[str]):
+        r"""Batch data preprocessor. Used for preparing mixtures and targets for
+        training. If there are multiple target source types, the waveforms of
+        those sources will be stacked along the channel dimension.
+
+        Args:
+            target_source_types: List[str], e.g., ['vocals', 'bass', ...]
+        """
+        self.target_source_types = target_source_types
+
+    def __call__(self, batch_data_dict: Dict) -> List[Dict]:
+        r"""Format waveforms and targets for training.
+
+        Args:
+            batch_data_dict: dict, e.g., {
+                'mixture': (batch_size, channels_num, segment_samples),
+                'vocals': (batch_size, channels_num, segment_samples),
+                'bass': (batch_size, channels_num, segment_samples),
+                ...,
+            }
+
+        Returns:
+            input_dict: dict, e.g., {
+                'waveform': (batch_size, channels_num, segment_samples),
+            }
+            target_dict: dict, e.g., {
+                'waveform': (batch_size, target_sources_num * channels_num, segment_samples)
+            }
+        """
+        mixtures = batch_data_dict['mixture']
+        # mixtures: (batch_size, channels_num, segment_samples)
+
+        # Concatenate waveforms of multiple targets along the channel axis.
+        targets = torch.cat(
+            [batch_data_dict[source_type] for source_type in self.target_source_types],
+            dim=1,
+        )
+        # targets: (batch_size, target_sources_num * channels_num, segment_samples)
+
+        input_dict = {'waveform': mixtures}
+        target_dict = {'waveform': targets}
+
+        return input_dict, target_dict
+
+
+class ConditionalSisoBatchDataPreprocessor:
+    def __init__(self, target_source_types: List[str]):
+        r"""Conditional single input single output (SISO) batch data
+        preprocessor. Select one target source from several target sources as
+        training target and prepare the corresponding conditional vector.
+
+        Args:
+            target_source_types: List[str], e.g., ['vocals', 'bass', ...]
+        """
+        self.target_source_types = target_source_types
+
+    def __call__(self, batch_data_dict: Dict) -> List[Dict]:
+        r"""Format waveforms and targets for training.
+
+        Args:
+            batch_data_dict: dict, e.g., {
+                'mixture': (batch_size, channels_num, segment_samples),
+                'vocals': (batch_size, channels_num, segment_samples),
+                'bass': (batch_size, channels_num, segment_samples),
+                ...,
+            }
+
+        Returns:
+            input_dict: dict, e.g., {
+                'waveform': (batch_size, channels_num, segment_samples),
+                'condition': (batch_size, target_sources_num),
+            }
+            target_dict: dict, e.g., {
+                'waveform': (batch_size, channels_num, segment_samples)
+            }
+        """
+
+        batch_size = len(batch_data_dict['mixture'])
+        target_sources_num = len(self.target_source_types)
+
+        assert (
+            batch_size % target_sources_num == 0
+        ), "Batch size should be evenly divided by target sources number."
+
+        mixtures = batch_data_dict['mixture']
+        # mixtures: (batch_size, channels_num, segment_samples)
+
+        conditions = torch.zeros(batch_size, target_sources_num).to(mixtures.device)
+        # conditions: (batch_size, target_sources_num)
+
+        targets = []
+
+        for n in range(batch_size):
+
+            k = n % target_sources_num  # source class index
+            source_type = self.target_source_types[k]
+
+            targets.append(batch_data_dict[source_type][n])
+
+            conditions[n, k] = 1
+
+        # conditions will look like:
+        # [[1, 0, 0, 0],
+        #  [0, 1, 0, 0],
+        #  [0, 0, 1, 0],
+        #  [0, 0, 0, 1],
+        #  [1, 0, 0, 0],
+        #  [0, 1, 0, 0],
+        #  ...,
+        # ]
+
+        targets = torch.stack(targets, dim=0)
+        # targets: (batch_size, channels_num, segment_samples)
+
+        input_dict = {
+            'waveform': mixtures,
+            'condition': conditions,
+        }
+
+        target_dict = {'waveform': targets}
+
+        return input_dict, target_dict
+
+
+def get_batch_data_preprocessor_class(batch_data_preprocessor_type: str) -> object:
+    r"""Get batch data preprocessor class."""
+    if batch_data_preprocessor_type == 'BasicBatchDataPreprocessor':
+        return BasicBatchDataPreprocessor
+
+    elif batch_data_preprocessor_type == 'ConditionalSisoBatchDataPreprocessor':
+        return ConditionalSisoBatchDataPreprocessor
+
+    else:
+        raise NotImplementedError
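To illustrate the round-robin target selection, here is a toy run of ConditionalSisoBatchDataPreprocessor with a batch of four and two target sources; random tensors stand in for audio:

# Toy run: batch_size=4 and two target sources, so samples alternate
# vocals/accompaniment and get one-hot conditions [1,0],[0,1],[1,0],[0,1].
import torch

from bytesep.data.batch_data_preprocessors import ConditionalSisoBatchDataPreprocessor

preprocessor = ConditionalSisoBatchDataPreprocessor(['vocals', 'accompaniment'])

batch_data_dict = {
    'mixture': torch.randn(4, 2, 1000),  # (batch_size, channels_num, segment_samples)
    'vocals': torch.randn(4, 2, 1000),
    'accompaniment': torch.randn(4, 2, 1000),
}

input_dict, target_dict = preprocessor(batch_data_dict)

assert input_dict['condition'].tolist() == [[1, 0], [0, 1], [1, 0], [0, 1]]
assert target_dict['waveform'].shape == (4, 2, 1000)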
bytesep/data/data_modules.py
ADDED
@@ -0,0 +1,187 @@
from typing import Dict, List, NoReturn, Optional

import h5py
import librosa
import numpy as np
import torch
from pytorch_lightning.core.datamodule import LightningDataModule

from bytesep.data.samplers import DistributedSamplerWrapper
from bytesep.utils import int16_to_float32


class DataModule(LightningDataModule):
    def __init__(
        self,
        train_sampler: object,
        train_dataset: object,
        num_workers: int,
        distributed: bool,
    ):
        r"""Data module.

        Args:
            train_sampler: Sampler object
            train_dataset: Dataset object
            num_workers: int
            distributed: bool
        """
        super().__init__()
        self._train_sampler = train_sampler
        self.train_dataset = train_dataset
        self.num_workers = num_workers
        self.distributed = distributed

    def setup(self, stage: Optional[str] = None) -> NoReturn:
        r"""Called on every device."""

        # SegmentSampler is used for selecting segments for training.
        # On multiple devices, each SegmentSampler samples a part of mini-batch
        # data.
        if self.distributed:
            self.train_sampler = DistributedSamplerWrapper(self._train_sampler)

        else:
            self.train_sampler = self._train_sampler

    def train_dataloader(self) -> torch.utils.data.DataLoader:
        r"""Get train loader."""
        train_loader = torch.utils.data.DataLoader(
            dataset=self.train_dataset,
            batch_sampler=self.train_sampler,
            collate_fn=collate_fn,
            num_workers=self.num_workers,
            pin_memory=True,
        )

        return train_loader


class Dataset:
    def __init__(self, augmentor: object, segment_samples: int):
        r"""Used for getting data according to a meta.

        Args:
            augmentor: Augmentor class
            segment_samples: int
        """
        self.augmentor = augmentor
        self.segment_samples = segment_samples

    def __getitem__(self, meta: Dict) -> Dict:
        r"""Return data according to a meta. E.g., an input meta looks like: {
            'vocals': [['song_A.h5', 6332760, 6465060], ['song_B.h5', 198450, 330750]],
            'accompaniment': [['song_C.h5', 24232920, 24365250], ['song_D.h5', 1569960, 1702260]]}

        Then, vocals segments of song_A and song_B will be mixed (mix-audio augmentation).
        Accompaniment segments of song_C and song_D will be mixed (mix-audio augmentation).
        Finally, the mixture is created by summing vocals and accompaniment.

        Args:
            meta: dict, e.g., {
                'vocals': [['song_A.h5', 6332760, 6465060], ['song_B.h5', 198450, 330750]],
                'accompaniment': [['song_C.h5', 24232920, 24365250], ['song_D.h5', 1569960, 1702260]]}

        Returns:
            data_dict: dict, e.g., {
                'vocals': (channels_num, segment_samples),
                'accompaniment': (channels_num, segment_samples),
                'mixture': (channels_num, segment_samples),
            }
        """
        source_types = meta.keys()
        data_dict = {}

        for source_type in source_types:
            # E.g., ['vocals', 'bass', ...]

            waveforms = []  # Audio segments to be mix-audio augmented.

            for m in meta[source_type]:
                # E.g., {
                #     'hdf5_path': '.../song_A.h5',
                #     'key_in_hdf5': 'vocals',
                #     'begin_sample': 13406400,
                #     'end_sample': 13538700,
                # }

                hdf5_path = m['hdf5_path']
                key_in_hdf5 = m['key_in_hdf5']
                bgn_sample = m['begin_sample']
                end_sample = m['end_sample']

                with h5py.File(hdf5_path, 'r') as hf:

                    if source_type == 'audioset':
                        index_in_hdf5 = m['index_in_hdf5']
                        waveform = int16_to_float32(
                            hf['waveform'][index_in_hdf5][bgn_sample:end_sample]
                        )
                        waveform = waveform[None, :]
                    else:
                        waveform = int16_to_float32(
                            hf[key_in_hdf5][:, bgn_sample:end_sample]
                        )

                if self.augmentor:
                    waveform = self.augmentor(waveform, source_type)

                waveform = librosa.util.fix_length(
                    waveform, size=self.segment_samples, axis=1
                )
                # (channels_num, segment_samples)

                waveforms.append(waveform)
            # E.g., waveforms: [(channels_num, audio_samples), (channels_num, audio_samples)]

            # mix-audio augmentation
            data_dict[source_type] = np.sum(waveforms, axis=0)
            # data_dict[source_type]: (channels_num, audio_samples)

        # data_dict looks like: {
        #     'vocals': (channels_num, audio_samples),
        #     'accompaniment': (channels_num, audio_samples)
        # }

        # Mix segments from different sources.
        mixture = np.sum(
            [data_dict[source_type] for source_type in source_types], axis=0
        )
        data_dict['mixture'] = mixture
        # shape: (channels_num, audio_samples)

        return data_dict


def collate_fn(list_data_dict: List[Dict]) -> Dict:
    r"""Collate mini-batch data to inputs and targets for training.

    Args:
        list_data_dict: e.g., [
            {'vocals': (channels_num, segment_samples),
             'accompaniment': (channels_num, segment_samples),
             'mixture': (channels_num, segment_samples)
            },
            {'vocals': (channels_num, segment_samples),
             'accompaniment': (channels_num, segment_samples),
             'mixture': (channels_num, segment_samples)
            },
            ...]

    Returns:
        data_dict: e.g. {
            'vocals': (batch_size, channels_num, segment_samples),
            'accompaniment': (batch_size, channels_num, segment_samples),
            'mixture': (batch_size, channels_num, segment_samples)
        }
    """
    data_dict = {}

    for key in list_data_dict[0].keys():
        data_dict[key] = torch.Tensor(
            np.array([single_data_dict[key] for single_data_dict in list_data_dict])
        )

    return data_dict

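For orientation, a sketch of how Dataset, a segment sampler, and collate_fn are intended to be wired together; the paths and hyper-parameters below are placeholders, not values from this commit:

from bytesep.data.data_modules import DataModule, Dataset
from bytesep.data.samplers import SegmentSampler

segment_samples = 44100 * 3  # assumed 3-second segments at 44.1 kHz

train_sampler = SegmentSampler(
    indexes_path='workspaces/indexes/train_indexes.pkl',  # hypothetical path
    segment_samples=segment_samples,
    mixaudio_dict={'vocals': 2, 'accompaniment': 2},
    batch_size=16,
    steps_per_epoch=10000,
)
train_dataset = Dataset(augmentor=None, segment_samples=segment_samples)

data_module = DataModule(
    train_sampler=train_sampler,
    train_dataset=train_dataset,
    num_workers=8,
    distributed=False,
)
data_module.setup()
train_loader = data_module.train_dataloader()
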
bytesep/data/samplers.py
ADDED
@@ -0,0 +1,188 @@
import pickle
from typing import Dict, List, NoReturn

import numpy as np
import torch.distributed as dist


class SegmentSampler:
    def __init__(
        self,
        indexes_path: str,
        segment_samples: int,
        mixaudio_dict: Dict,
        batch_size: int,
        steps_per_epoch: int,
        random_seed=1234,
    ):
        r"""Sample training indexes of sources.

        Args:
            indexes_path: str, path of indexes dict
            segment_samples: int
            mixaudio_dict: dict, including hyper-parameters for mix-audio data
                augmentation, e.g., {'vocals': 2, 'accompaniment': 2}
            batch_size: int
            steps_per_epoch: int, #steps_per_epoch is called an `epoch`
            random_seed: int
        """
        self.segment_samples = segment_samples
        self.mixaudio_dict = mixaudio_dict
        self.batch_size = batch_size
        self.steps_per_epoch = steps_per_epoch

        self.meta_dict = pickle.load(open(indexes_path, "rb"))
        # E.g., {
        #     'vocals': [
        #         {'hdf5_path': 'songA.h5', 'key_in_hdf5': 'vocals', 'begin_sample': 0, 'end_sample': 132300},
        #         {'hdf5_path': 'songB.h5', 'key_in_hdf5': 'vocals', 'begin_sample': 4410, 'end_sample': 445410},
        #         ...
        #     ],
        #     'accompaniment': [
        #         {'hdf5_path': 'songA.h5', 'key_in_hdf5': 'accompaniment', 'begin_sample': 0, 'end_sample': 132300},
        #         {'hdf5_path': 'songB.h5', 'key_in_hdf5': 'accompaniment', 'begin_sample': 4410, 'end_sample': 445410},
        #         ...
        #     ]
        # }

        self.source_types = self.meta_dict.keys()
        # E.g., ['vocals', 'accompaniment']

        self.pointers_dict = {source_type: 0 for source_type in self.source_types}
        # E.g., {'vocals': 0, 'accompaniment': 0}

        self.indexes_dict = {
            source_type: np.arange(len(self.meta_dict[source_type]))
            for source_type in self.source_types
        }
        # E.g. {
        #     'vocals': [0, 1, ..., 225751],
        #     'accompaniment': [0, 1, ..., 225751]
        # }

        self.random_state = np.random.RandomState(random_seed)

        # Shuffle indexes.
        for source_type in self.source_types:
            self.random_state.shuffle(self.indexes_dict[source_type])
            print("{}: {}".format(source_type, len(self.indexes_dict[source_type])))

    def __iter__(self) -> List[Dict]:
        r"""Yield a batch of meta info.

        Returns:
            batch_meta_list: (batch_size,) e.g., when mix-audio is 2, looks like [
                {'vocals': [
                    {'hdf5_path': 'songA.h5', 'key_in_hdf5': 'vocals', 'begin_sample': 13406400, 'end_sample': 13538700},
                    {'hdf5_path': 'songB.h5', 'key_in_hdf5': 'vocals', 'begin_sample': 4440870, 'end_sample': 4573170}],
                 'accompaniment': [
                    {'hdf5_path': 'songE.h5', 'key_in_hdf5': 'accompaniment', 'begin_sample': 14579460, 'end_sample': 14711760},
                    {'hdf5_path': 'songF.h5', 'key_in_hdf5': 'accompaniment', 'begin_sample': 3995460, 'end_sample': 4127760}]
                }
                ...
            ]
        """
        batch_size = self.batch_size

        while True:
            batch_meta_dict = {source_type: [] for source_type in self.source_types}

            for source_type in self.source_types:
                # E.g., 'vocals' | 'accompaniment'

                # Loop until we get a mini-batch.
                while len(batch_meta_dict[source_type]) != batch_size:

                    largest_index = (
                        len(self.indexes_dict[source_type])
                        - self.mixaudio_dict[source_type]
                    )
                    # E.g., 225750 = 225752 - 2

                    if self.pointers_dict[source_type] > largest_index:

                        # Reset pointer, and shuffle indexes.
                        self.pointers_dict[source_type] = 0
                        self.random_state.shuffle(self.indexes_dict[source_type])

                    source_metas = []
                    mix_audios_num = self.mixaudio_dict[source_type]

                    for _ in range(mix_audios_num):

                        pointer = self.pointers_dict[source_type]
                        # E.g., 1

                        index = self.indexes_dict[source_type][pointer]
                        # E.g., 12231

                        self.pointers_dict[source_type] += 1

                        source_meta = self.meta_dict[source_type][index]
                        # E.g., {'hdf5_path': 'song_A.h5', 'key_in_hdf5': 'vocals', 'begin_sample': 198450, 'end_sample': 330750}

                        source_metas.append(source_meta)

                    batch_meta_dict[source_type].append(source_metas)
            # When mix-audio is 2, batch_meta_dict looks like: {
            #     'vocals': [
            #         [{'hdf5_path': 'songA.h5', 'key_in_hdf5': 'vocals', 'begin_sample': 13406400, 'end_sample': 13538700},
            #          {'hdf5_path': 'songB.h5', 'key_in_hdf5': 'vocals', 'begin_sample': 4440870, 'end_sample': 4573170}],
            #         [{'hdf5_path': 'songC.h5', 'key_in_hdf5': 'vocals', 'begin_sample': 1186290, 'end_sample': 1318590},
            #          {'hdf5_path': 'songD.h5', 'key_in_hdf5': 'vocals', 'begin_sample': 8462790, 'end_sample': 8595090}]
            #     ],
            #     'accompaniment': [
            #         [{'hdf5_path': 'songE.h5', 'key_in_hdf5': 'accompaniment', 'begin_sample': 24232950, 'end_sample': 24365250},
            #          {'hdf5_path': 'songF.h5', 'key_in_hdf5': 'accompaniment', 'begin_sample': 1569960, 'end_sample': 1702260}],
            #         [{'hdf5_path': 'songG.h5', 'key_in_hdf5': 'accompaniment', 'begin_sample': 2795940, 'end_sample': 2928240},
            #          {'hdf5_path': 'songH.h5', 'key_in_hdf5': 'accompaniment', 'begin_sample': 10923570, 'end_sample': 11055870}]
            #     ]
            # }

            batch_meta_list = [
                {
                    source_type: batch_meta_dict[source_type][i]
                    for source_type in self.source_types
                }
                for i in range(batch_size)
            ]
            # When mix-audio is 2, batch_meta_list looks like: [
            #     {'vocals': [
            #         {'hdf5_path': 'songA.h5', 'key_in_hdf5': 'vocals', 'begin_sample': 13406400, 'end_sample': 13538700},
            #         {'hdf5_path': 'songB.h5', 'key_in_hdf5': 'vocals', 'begin_sample': 4440870, 'end_sample': 4573170}],
            #      'accompaniment': [
            #         {'hdf5_path': 'songE.h5', 'key_in_hdf5': 'accompaniment', 'begin_sample': 14579460, 'end_sample': 14711760},
            #         {'hdf5_path': 'songF.h5', 'key_in_hdf5': 'accompaniment', 'begin_sample': 3995460, 'end_sample': 4127760}]
            #     }
            #     ...
            # ]

            yield batch_meta_list

    def __len__(self) -> int:
        return self.steps_per_epoch

    def state_dict(self) -> Dict:
        state = {'pointers_dict': self.pointers_dict, 'indexes_dict': self.indexes_dict}
        return state

    def load_state_dict(self, state) -> NoReturn:
        self.pointers_dict = state['pointers_dict']
        self.indexes_dict = state['indexes_dict']


class DistributedSamplerWrapper:
    def __init__(self, sampler):
        r"""Distributed wrapper of sampler."""
        self.sampler = sampler

    def __iter__(self):
        num_replicas = dist.get_world_size()
        rank = dist.get_rank()

        for indices in self.sampler:
            yield indices[rank::num_replicas]

    def __len__(self) -> int:
        return len(self.sampler)

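DistributedSamplerWrapper shards each yielded batch across ranks rather than sharding the underlying indexes, so the effective per-device batch size becomes batch_size / world_size. A single-process illustration of the `indices[rank::num_replicas]` slicing it relies on:

batch_meta_list = list(range(8))  # stand-in for one yielded batch
num_replicas = 2  # assumed world size

for rank in range(num_replicas):
    print(rank, batch_meta_list[rank::num_replicas])
# 0 [0, 2, 4, 6]
# 1 [1, 3, 5, 7]
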
bytesep/dataset_creation/__init__.py
ADDED
File without changes
|
bytesep/dataset_creation/create_evaluation_audios/__init__.py
ADDED
File without changes
|
bytesep/dataset_creation/create_evaluation_audios/piano-symphony.py
ADDED
@@ -0,0 +1,160 @@
import argparse
import os
from typing import NoReturn

import librosa
import numpy as np
import soundfile

from bytesep.dataset_creation.pack_audios_to_hdf5s.instruments_solo import (
    read_csv as read_instruments_solo_csv,
)
from bytesep.dataset_creation.pack_audios_to_hdf5s.maestro import (
    read_csv as read_maestro_csv,
)
from bytesep.utils import load_random_segment


def create_evaluation(args) -> NoReturn:
    r"""Randomly mix and write out audios for evaluation.

    Args:
        piano_dataset_dir: str, the directory of the piano dataset
        symphony_dataset_dir: str, the directory of the symphony dataset
        evaluation_audios_dir: str, the directory to write out randomly selected and mixed audio segments
        sample_rate: int
        channels: int, e.g., 1 | 2
        evaluation_segments_num: int
        mono: bool

    Returns:
        NoReturn
    """

    # arguments & parameters
    piano_dataset_dir = args.piano_dataset_dir
    symphony_dataset_dir = args.symphony_dataset_dir
    evaluation_audios_dir = args.evaluation_audios_dir
    sample_rate = args.sample_rate
    channels = args.channels
    evaluation_segments_num = args.evaluation_segments_num
    mono = True if channels == 1 else False

    split = 'test'
    segment_seconds = 10.0

    random_state = np.random.RandomState(1234)

    piano_meta_csv = os.path.join(piano_dataset_dir, 'maestro-v2.0.0.csv')
    piano_names_dict = read_maestro_csv(piano_meta_csv)
    piano_audio_names = piano_names_dict[split]

    symphony_meta_csv = os.path.join(symphony_dataset_dir, 'validation.csv')
    symphony_names_dict = read_instruments_solo_csv(symphony_meta_csv)
    symphony_audio_names = symphony_names_dict[split]

    for source_type in ['piano', 'symphony', 'mixture']:
        output_dir = os.path.join(evaluation_audios_dir, split, source_type)
        os.makedirs(output_dir, exist_ok=True)

    for n in range(evaluation_segments_num):

        print('{} / {}'.format(n, evaluation_segments_num))

        # Randomly select and write out a clean piano segment.
        piano_audio_name = random_state.choice(piano_audio_names)
        piano_audio_path = os.path.join(piano_dataset_dir, piano_audio_name)

        piano_audio = load_random_segment(
            audio_path=piano_audio_path,
            random_state=random_state,
            segment_seconds=segment_seconds,
            mono=mono,
            sample_rate=sample_rate,
        )

        output_piano_path = os.path.join(
            evaluation_audios_dir, split, 'piano', '{:04d}.wav'.format(n)
        )
        soundfile.write(
            file=output_piano_path, data=piano_audio.T, samplerate=sample_rate
        )
        print("Write out to {}".format(output_piano_path))

        # Randomly select and write out a clean symphony segment.
        symphony_audio_name = random_state.choice(symphony_audio_names)
        symphony_audio_path = os.path.join(
            symphony_dataset_dir, "mp3s", symphony_audio_name
        )

        symphony_audio = load_random_segment(
            audio_path=symphony_audio_path,
            random_state=random_state,
            segment_seconds=segment_seconds,
            mono=mono,
            sample_rate=sample_rate,
        )

        output_symphony_path = os.path.join(
            evaluation_audios_dir, split, 'symphony', '{:04d}.wav'.format(n)
        )
        soundfile.write(
            file=output_symphony_path, data=symphony_audio.T, samplerate=sample_rate
        )
        print("Write out to {}".format(output_symphony_path))

        # Mix piano and symphony segments and write out a mixture segment.
        mixture_audio = symphony_audio + piano_audio
        output_mixture_path = os.path.join(
            evaluation_audios_dir, split, 'mixture', '{:04d}.wav'.format(n)
        )
        soundfile.write(
            file=output_mixture_path, data=mixture_audio.T, samplerate=sample_rate
        )
        print("Write out to {}".format(output_mixture_path))


if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--piano_dataset_dir",
        type=str,
        required=True,
        help="The directory of the piano dataset.",
    )
    parser.add_argument(
        "--symphony_dataset_dir",
        type=str,
        required=True,
        help="The directory of the symphony dataset.",
    )
    parser.add_argument(
        "--evaluation_audios_dir",
        type=str,
        required=True,
        help="The directory to write out randomly selected and mixed audio segments.",
    )
    parser.add_argument(
        "--sample_rate",
        type=int,
        required=True,
        help="Sample rate.",
    )
    parser.add_argument(
        "--channels",
        type=int,
        required=True,
        help="Audio channels, e.g., 1 or 2.",
    )
    parser.add_argument(
        "--evaluation_segments_num",
        type=int,
        required=True,
        help="The number of segments to create for evaluation.",
    )

    # Parse arguments.
    args = parser.parse_args()

    create_evaluation(args)

bytesep/dataset_creation/create_evaluation_audios/vctk-musdb18.py
ADDED
@@ -0,0 +1,164 @@
import argparse
import os
from typing import NoReturn

import musdb
import numpy as np
import soundfile

from bytesep.utils import load_audio


def create_evaluation(args) -> NoReturn:
    r"""Randomly mix and write out audios for evaluation.

    Args:
        vctk_dataset_dir: str, the directory of the VCTK dataset
        musdb18_dataset_dir: str, the directory of the MUSDB18 dataset
        evaluation_audios_dir: str, the directory to write out randomly selected and mixed audio segments
        sample_rate: int
        channels: int, e.g., 1 | 2
        evaluation_segments_num: int
        mono: bool

    Returns:
        NoReturn
    """

    # arguments & parameters
    vctk_dataset_dir = args.vctk_dataset_dir
    musdb18_dataset_dir = args.musdb18_dataset_dir
    evaluation_audios_dir = args.evaluation_audios_dir
    sample_rate = args.sample_rate
    channels = args.channels
    evaluation_segments_num = args.evaluation_segments_num
    mono = True if channels == 1 else False

    split = 'test'
    random_state = np.random.RandomState(1234)

    # paths
    audios_dir = os.path.join(vctk_dataset_dir, "wav48", split)

    for source_type in ['speech', 'music', 'mixture']:
        output_dir = os.path.join(evaluation_audios_dir, split, source_type)
        os.makedirs(output_dir, exist_ok=True)

    # Get VCTK audio paths.
    speech_audio_paths = []
    speaker_ids = sorted(os.listdir(audios_dir))

    for speaker_id in speaker_ids:
        speaker_audios_dir = os.path.join(audios_dir, speaker_id)

        audio_names = sorted(os.listdir(speaker_audios_dir))

        for audio_name in audio_names:
            speaker_audio_path = os.path.join(speaker_audios_dir, audio_name)
            speech_audio_paths.append(speaker_audio_path)

    # Get MUSDB18 audio paths.
    mus = musdb.DB(root=musdb18_dataset_dir, subsets=[split])
    track_indexes = np.arange(len(mus.tracks))

    for n in range(evaluation_segments_num):

        print('{} / {}'.format(n, evaluation_segments_num))

        # Randomly select and write out a clean speech segment.
        speech_audio_path = random_state.choice(speech_audio_paths)

        speech_audio = load_audio(
            audio_path=speech_audio_path, mono=mono, sample_rate=sample_rate
        )
        # (channels_num, audio_samples)

        if channels == 2:
            speech_audio = np.tile(speech_audio, (2, 1))
            # (channels_num, audio_samples)

        output_speech_path = os.path.join(
            evaluation_audios_dir, split, 'speech', '{:04d}.wav'.format(n)
        )
        soundfile.write(
            file=output_speech_path, data=speech_audio.T, samplerate=sample_rate
        )
        print("Write out to {}".format(output_speech_path))

        # Randomly select and write out a clean music segment.
        track_index = random_state.choice(track_indexes)
        track = mus[track_index]

        # Randomly choose where the music segment begins inside the track.
        segment_samples = speech_audio.shape[1]
        start_sample = int(
            random_state.uniform(0.0, track.audio.shape[0] - segment_samples)
        )

        music_audio = track.audio[start_sample : start_sample + segment_samples, :].T
        # (channels_num, audio_samples)

        output_music_path = os.path.join(
            evaluation_audios_dir, split, 'music', '{:04d}.wav'.format(n)
        )
        soundfile.write(
            file=output_music_path, data=music_audio.T, samplerate=sample_rate
        )
        print("Write out to {}".format(output_music_path))

        # Mix speech and music segments and write out a mixture segment.
        mixture_audio = speech_audio + music_audio
        # (channels_num, audio_samples)

        output_mixture_path = os.path.join(
            evaluation_audios_dir, split, 'mixture', '{:04d}.wav'.format(n)
        )
        soundfile.write(
            file=output_mixture_path, data=mixture_audio.T, samplerate=sample_rate
        )
        print("Write out to {}".format(output_mixture_path))


if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--vctk_dataset_dir",
        type=str,
        required=True,
        help="The directory of the VCTK dataset.",
    )
    parser.add_argument(
        "--musdb18_dataset_dir",
        type=str,
        required=True,
        help="The directory of the MUSDB18 dataset.",
    )
    parser.add_argument(
        "--evaluation_audios_dir",
        type=str,
        required=True,
        help="The directory to write out randomly selected and mixed audio segments.",
    )
    parser.add_argument(
        "--sample_rate",
        type=int,
        required=True,
        help="Sample rate.",
    )
    parser.add_argument(
        "--channels",
        type=int,
        required=True,
        help="Audio channels, e.g., 1 or 2.",
    )
    parser.add_argument(
        "--evaluation_segments_num",
        type=int,
        required=True,
        help="The number of segments to create for evaluation.",
    )

    # Parse arguments.
    args = parser.parse_args()

    create_evaluation(args)

bytesep/dataset_creation/create_evaluation_audios/violin-piano.py
ADDED
@@ -0,0 +1,162 @@
import argparse
import os
from typing import NoReturn

import librosa
import numpy as np
import soundfile

from bytesep.dataset_creation.pack_audios_to_hdf5s.instruments_solo import (
    read_csv as read_instruments_solo_csv,
)
from bytesep.dataset_creation.pack_audios_to_hdf5s.maestro import (
    read_csv as read_maestro_csv,
)
from bytesep.utils import load_random_segment


def create_evaluation(args) -> NoReturn:
    r"""Randomly mix and write out audios for evaluation.

    Args:
        violin_dataset_dir: str, the directory of the violin dataset
        piano_dataset_dir: str, the directory of the piano dataset
        evaluation_audios_dir: str, the directory to write out randomly selected and mixed audio segments
        sample_rate: int
        channels: int, e.g., 1 | 2
        evaluation_segments_num: int
        mono: bool

    Returns:
        NoReturn
    """

    # arguments & parameters
    violin_dataset_dir = args.violin_dataset_dir
    piano_dataset_dir = args.piano_dataset_dir
    evaluation_audios_dir = args.evaluation_audios_dir
    sample_rate = args.sample_rate
    channels = args.channels
    evaluation_segments_num = args.evaluation_segments_num
    mono = True if channels == 1 else False

    split = 'test'
    segment_seconds = 10.0

    random_state = np.random.RandomState(1234)

    violin_meta_csv = os.path.join(violin_dataset_dir, 'validation.csv')
    violin_names_dict = read_instruments_solo_csv(violin_meta_csv)
    violin_audio_names = violin_names_dict[split]

    piano_meta_csv = os.path.join(piano_dataset_dir, 'maestro-v2.0.0.csv')
    piano_names_dict = read_maestro_csv(piano_meta_csv)
    piano_audio_names = piano_names_dict[split]

    for source_type in ['violin', 'piano', 'mixture']:
        output_dir = os.path.join(evaluation_audios_dir, split, source_type)
        os.makedirs(output_dir, exist_ok=True)

    for n in range(evaluation_segments_num):

        print('{} / {}'.format(n, evaluation_segments_num))

        # Randomly select and write out a clean violin segment.
        violin_audio_name = random_state.choice(violin_audio_names)
        violin_audio_path = os.path.join(violin_dataset_dir, "mp3s", violin_audio_name)

        violin_audio = load_random_segment(
            audio_path=violin_audio_path,
            random_state=random_state,
            segment_seconds=segment_seconds,
            mono=mono,
            sample_rate=sample_rate,
        )
        # (channels_num, audio_samples)

        output_violin_path = os.path.join(
            evaluation_audios_dir, split, 'violin', '{:04d}.wav'.format(n)
        )
        soundfile.write(
            file=output_violin_path, data=violin_audio.T, samplerate=sample_rate
        )
        print("Write out to {}".format(output_violin_path))

        # Randomly select and write out a clean piano segment.
        piano_audio_name = random_state.choice(piano_audio_names)
        piano_audio_path = os.path.join(piano_dataset_dir, piano_audio_name)

        piano_audio = load_random_segment(
            audio_path=piano_audio_path,
            random_state=random_state,
            segment_seconds=segment_seconds,
            mono=mono,
            sample_rate=sample_rate,
        )
        # (channels_num, audio_samples)

        output_piano_path = os.path.join(
            evaluation_audios_dir, split, 'piano', '{:04d}.wav'.format(n)
        )
        soundfile.write(
            file=output_piano_path, data=piano_audio.T, samplerate=sample_rate
        )
        print("Write out to {}".format(output_piano_path))

        # Mix violin and piano segments and write out a mixture segment.
        mixture_audio = violin_audio + piano_audio
        # (channels_num, audio_samples)

        output_mixture_path = os.path.join(
            evaluation_audios_dir, split, 'mixture', '{:04d}.wav'.format(n)
        )
        soundfile.write(
            file=output_mixture_path, data=mixture_audio.T, samplerate=sample_rate
        )
        print("Write out to {}".format(output_mixture_path))


if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--violin_dataset_dir",
        type=str,
        required=True,
        help="The directory of the violin dataset.",
    )
    parser.add_argument(
        "--piano_dataset_dir",
        type=str,
        required=True,
        help="The directory of the piano dataset.",
    )
    parser.add_argument(
        "--evaluation_audios_dir",
        type=str,
        required=True,
        help="The directory to write out randomly selected and mixed audio segments.",
    )
    parser.add_argument(
        "--sample_rate",
        type=int,
        required=True,
        help="Sample rate.",
    )
    parser.add_argument(
        "--channels",
        type=int,
        required=True,
        help="Audio channels, e.g., 1 or 2.",
    )
    parser.add_argument(
        "--evaluation_segments_num",
        type=int,
        required=True,
        help="The number of segments to create for evaluation.",
    )

    # Parse arguments.
    args = parser.parse_args()

    create_evaluation(args)

bytesep/dataset_creation/create_indexes/__init__.py
ADDED
File without changes
|
bytesep/dataset_creation/create_indexes/create_indexes.py
ADDED
@@ -0,0 +1,142 @@
import argparse
import os
import pickle
from typing import NoReturn

import h5py

from bytesep.utils import read_yaml


def create_indexes(args) -> NoReturn:
    r"""Create and write out training indexes into disk. The indexes may contain
    information from multiple datasets. During training, training indexes will
    be shuffled and iterated for selecting segments to be mixed. E.g., the
    training indexes_dict looks like: {
        'vocals': [
            {'hdf5_path': '.../songA.h5', 'key_in_hdf5': 'vocals', 'begin_sample': 0, 'end_sample': 132300},
            {'hdf5_path': '.../songB.h5', 'key_in_hdf5': 'vocals', 'begin_sample': 4410, 'end_sample': 136710},
            ...
        ],
        'accompaniment': [
            {'hdf5_path': '.../songA.h5', 'key_in_hdf5': 'accompaniment', 'begin_sample': 0, 'end_sample': 132300},
            {'hdf5_path': '.../songB.h5', 'key_in_hdf5': 'accompaniment', 'begin_sample': 4410, 'end_sample': 136710},
            ...
        ]
    }
    """

    # Arguments & parameters
    workspace = args.workspace
    config_yaml = args.config_yaml

    # Only create indexes for training, because evaluation is on entire pieces.
    split = "train"

    # Read config file.
    configs = read_yaml(config_yaml)

    sample_rate = configs["sample_rate"]
    segment_samples = int(configs["segment_seconds"] * sample_rate)

    # Path to write out index.
    indexes_path = os.path.join(workspace, configs[split]["indexes"])
    os.makedirs(os.path.dirname(indexes_path), exist_ok=True)

    source_types = configs[split]["source_types"].keys()
    # E.g., ['vocals', 'accompaniment']

    indexes_dict = {source_type: [] for source_type in source_types}
    # E.g., indexes_dict will look like: {
    #     'vocals': [
    #         {'hdf5_path': '.../songA.h5', 'key_in_hdf5': 'vocals', 'begin_sample': 0, 'end_sample': 132300},
    #         {'hdf5_path': '.../songB.h5', 'key_in_hdf5': 'vocals', 'begin_sample': 4410, 'end_sample': 136710},
    #         ...
    #     ],
    #     'accompaniment': [
    #         {'hdf5_path': '.../songA.h5', 'key_in_hdf5': 'accompaniment', 'begin_sample': 0, 'end_sample': 132300},
    #         {'hdf5_path': '.../songB.h5', 'key_in_hdf5': 'accompaniment', 'begin_sample': 4410, 'end_sample': 136710},
    #         ...
    #     ]
    # }

    # Get training indexes for each source type.
    for source_type in source_types:
        # E.g., 'vocals' | 'accompaniment'

        print("--- {} ---".format(source_type))

        dataset_types = configs[split]["source_types"][source_type]
        # E.g., ['musdb18', ...]

        # Each source can come from multiple datasets.
        for dataset_type in dataset_types:

            hdf5s_dir = os.path.join(
                workspace, dataset_types[dataset_type]["hdf5s_directory"]
            )

            hop_samples = int(dataset_types[dataset_type]["hop_seconds"] * sample_rate)

            key_in_hdf5 = dataset_types[dataset_type]["key_in_hdf5"]
            # E.g., 'vocals'

            hdf5_names = sorted(os.listdir(hdf5s_dir))
            print("Hdf5 files num: {}".format(len(hdf5_names)))

            # Traverse all packed hdf5 files of a dataset.
            for n, hdf5_name in enumerate(hdf5_names):

                print(n, hdf5_name)
                hdf5_path = os.path.join(hdf5s_dir, hdf5_name)

                with h5py.File(hdf5_path, "r") as hf:

                    bgn_sample = 0
                    while bgn_sample + segment_samples < hf[key_in_hdf5].shape[-1]:
                        meta = {
                            'hdf5_path': hdf5_path,
                            'key_in_hdf5': key_in_hdf5,
                            'begin_sample': bgn_sample,
                            'end_sample': bgn_sample + segment_samples,
                        }
                        indexes_dict[source_type].append(meta)

                        bgn_sample += hop_samples

                    # If the audio length is shorter than the segment length,
                    # then use the entire audio as a segment.
                    if bgn_sample == 0:
                        meta = {
                            'hdf5_path': hdf5_path,
                            'key_in_hdf5': key_in_hdf5,
                            'begin_sample': 0,
                            'end_sample': segment_samples,
                        }
                        indexes_dict[source_type].append(meta)

        print(
            "Total indexes for {}: {}".format(
                source_type, len(indexes_dict[source_type])
            )
        )

    pickle.dump(indexes_dict, open(indexes_path, "wb"))
    print("Write index dict to {}".format(indexes_path))


if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--workspace", type=str, required=True, help="Directory of workspace."
    )
    parser.add_argument(
        "--config_yaml", type=str, required=True, help="User defined config file."
    )

    # Parse arguments.
    args = parser.parse_args()

    # Create training indexes.
    create_indexes(args)

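The while-loop above yields roughly floor((audio_samples - segment_samples - 1) / hop_samples) + 1 indexes per audio and source. A quick sanity check of that relation, with made-up durations:

sample_rate = 44100
audio_samples = sample_rate * 60    # a one-minute clip
segment_samples = sample_rate * 3
hop_samples = sample_rate * 1

count = 0
bgn_sample = 0
while bgn_sample + segment_samples < audio_samples:
    count += 1
    bgn_sample += hop_samples
print(count)  # 57
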
bytesep/dataset_creation/pack_audios_to_hdf5s/__init__.py
ADDED
File without changes
|
bytesep/dataset_creation/pack_audios_to_hdf5s/instruments_solo.py
ADDED
@@ -0,0 +1,173 @@
import argparse
import os
import pathlib
import time
from concurrent.futures import ProcessPoolExecutor
from typing import Dict, List, NoReturn

import h5py
import numpy as np
import pandas as pd

from bytesep.utils import float32_to_int16, load_audio


def read_csv(meta_csv) -> Dict:
    r"""Get train & test names from csv.

    Args:
        meta_csv: str

    Returns:
        names_dict: dict, e.g., {
            'train': ['songA.mp3', 'songB.mp3', ...],
            'test': ['songE.mp3', 'songF.mp3', ...]
        }
    """
    df = pd.read_csv(meta_csv, sep=',')

    names_dict = {}

    for split in ['train', 'test']:
        audio_indexes = df['split'] == split
        audio_names = list(df['audio_name'][audio_indexes])
        audio_names = [
            '{}.mp3'.format(pathlib.Path(audio_name).stem) for audio_name in audio_names
        ]
        names_dict[split] = audio_names

    return names_dict


def pack_audios_to_hdf5s(args) -> NoReturn:
    r"""Pack (resampled) audio files into hdf5 files to speed up loading.

    Args:
        dataset_dir: str
        split: str, 'train' | 'test'
        source_type: str
        hdf5s_dir: str, directory to write out hdf5 files
        sample_rate: int
        channels: int
        mono: bool

    Returns:
        NoReturn
    """

    # arguments & parameters
    dataset_dir = args.dataset_dir
    split = args.split
    source_type = args.source_type
    hdf5s_dir = args.hdf5s_dir
    sample_rate = args.sample_rate
    channels = args.channels
    mono = True if channels == 1 else False

    # Only pack data for training data.
    assert split == "train"

    # paths
    audios_dir = os.path.join(dataset_dir, 'mp3s')
    meta_csv = os.path.join(dataset_dir, 'validation.csv')

    os.makedirs(hdf5s_dir, exist_ok=True)

    # Read train & test names.
    names_dict = read_csv(meta_csv)

    audio_names = names_dict[split]

    params = []

    for audio_index, audio_name in enumerate(audio_names):

        audio_path = os.path.join(audios_dir, audio_name)

        hdf5_path = os.path.join(
            hdf5s_dir, "{}.h5".format(pathlib.Path(audio_name).stem)
        )

        param = (
            audio_index,
            audio_name,
            source_type,
            audio_path,
            mono,
            sample_rate,
            hdf5_path,
        )
        params.append(param)

    # Uncomment for debug.
    # write_single_audio_to_hdf5(params[0])
    # os._exit(0)

    pack_hdf5s_time = time.time()

    with ProcessPoolExecutor(max_workers=None) as pool:
        # Use the maximum number of workers on the machine.
        pool.map(write_single_audio_to_hdf5, params)

    print("Pack hdf5 time: {:.3f} s".format(time.time() - pack_hdf5s_time))


def write_single_audio_to_hdf5(param: List) -> NoReturn:
    r"""Write single audio into hdf5 file."""

    (
        audio_index,
        audio_name,
        source_type,
        audio_path,
        mono,
        sample_rate,
        hdf5_path,
    ) = param

    with h5py.File(hdf5_path, "w") as hf:

        hf.attrs.create("audio_name", data=audio_name, dtype="S100")
        hf.attrs.create("sample_rate", data=sample_rate, dtype=np.int32)

        audio = load_audio(audio_path=audio_path, mono=mono, sample_rate=sample_rate)
        # audio: (channels_num, audio_samples)

        hf.create_dataset(
            name=source_type, data=float32_to_int16(audio), dtype=np.int16
        )

    print('{} Write hdf5 to {}'.format(audio_index, hdf5_path))


if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--dataset_dir",
        type=str,
        required=True,
        help="Directory of the instruments solo dataset.",
    )
    parser.add_argument("--split", type=str, required=True, choices=["train", "test"])
    parser.add_argument(
        "--source_type",
        type=str,
        required=True,
    )
    parser.add_argument(
        "--hdf5s_dir",
        type=str,
        required=True,
        help="Directory to write out hdf5 files.",
    )
    parser.add_argument("--sample_rate", type=int, required=True, help="Sample rate.")
    parser.add_argument(
        "--channels", type=int, required=True, help="Use 1 for mono, 2 for stereo."
    )

    # Parse arguments.
    args = parser.parse_args()

    # Pack audios to hdf5 files.
    pack_audios_to_hdf5s(args)

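read_csv above assumes validation.csv carries at least 'split' and 'audio_name' columns; a minimal, made-up metadata file that parses correctly:

import pandas as pd

from bytesep.dataset_creation.pack_audios_to_hdf5s.instruments_solo import read_csv

# Hypothetical metadata with the two columns read_csv expects.
df = pd.DataFrame({
    'split': ['train', 'train', 'test'],
    'audio_name': ['songA.mp3', 'songB.mp3', 'songE.mp3'],
})
df.to_csv('validation.csv', index=False)

names_dict = read_csv('validation.csv')
print(names_dict['train'])  # ['songA.mp3', 'songB.mp3']
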
bytesep/dataset_creation/pack_audios_to_hdf5s/maestro.py
ADDED
@@ -0,0 +1,136 @@
import argparse
import os
import pathlib
import time
from concurrent.futures import ProcessPoolExecutor
from typing import Dict, NoReturn

import pandas as pd

from bytesep.dataset_creation.pack_audios_to_hdf5s.instruments_solo import (
    write_single_audio_to_hdf5,
)


def read_csv(meta_csv) -> Dict:
    r"""Get train & test names from csv.

    Args:
        meta_csv: str

    Returns:
        names_dict: dict, e.g., {
            'train': ['a1.mp3', 'a2.mp3'],
            'test': ['b1.mp3', 'b2.mp3']
        }
    """
    df = pd.read_csv(meta_csv, sep=',')

    names_dict = {}

    for split in ['train', 'test']:
        audio_indexes = df['split'] == split
        audio_names = list(df['audio_filename'][audio_indexes])
        names_dict[split] = audio_names

    return names_dict


def pack_audios_to_hdf5s(args) -> NoReturn:
    r"""Pack (resampled) audio files into hdf5 files to speed up loading.

    Args:
        dataset_dir: str
        split: str, 'train' | 'test'
        hdf5s_dir: str, directory to write out hdf5 files
        sample_rate: int
        channels: int
        mono: bool

    Returns:
        NoReturn
    """

    # arguments & parameters
    dataset_dir = args.dataset_dir
    split = args.split
    hdf5s_dir = args.hdf5s_dir
    sample_rate = args.sample_rate
    channels = args.channels
    mono = True if channels == 1 else False

    source_type = "piano"

    # Only pack data for training data.
    assert split == "train"

    # paths
    meta_csv = os.path.join(dataset_dir, 'maestro-v2.0.0.csv')

    os.makedirs(hdf5s_dir, exist_ok=True)

    # Read train & test names.
    names_dict = read_csv(meta_csv)

    audio_names = names_dict[split]

    params = []

    for audio_index, audio_name in enumerate(audio_names):

        audio_path = os.path.join(dataset_dir, audio_name)

        hdf5_path = os.path.join(
            hdf5s_dir, "{}.h5".format(pathlib.Path(audio_name).stem)
        )

        param = (
            audio_index,
            audio_name,
            source_type,
            audio_path,
            mono,
            sample_rate,
            hdf5_path,
        )
        params.append(param)

    # Uncomment for debug.
    # write_single_audio_to_hdf5(params[0])
    # os._exit(0)

    pack_hdf5s_time = time.time()

    with ProcessPoolExecutor(max_workers=None) as pool:
        # Use the maximum number of workers on the machine.
        pool.map(write_single_audio_to_hdf5, params)

    print("Pack hdf5 time: {:.3f} s".format(time.time() - pack_hdf5s_time))


if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--dataset_dir",
        type=str,
        required=True,
        help="Directory of the MAESTRO dataset.",
    )
    parser.add_argument("--split", type=str, required=True, choices=["train", "test"])
    parser.add_argument(
        "--hdf5s_dir",
        type=str,
        required=True,
        help="Directory to write out hdf5 files.",
    )
    parser.add_argument("--sample_rate", type=int, required=True, help="Sample rate.")
    parser.add_argument(
        "--channels", type=int, required=True, help="Use 1 for mono, 2 for stereo."
    )

    # Parse arguments.
    args = parser.parse_args()

    # Pack audios to hdf5 files.
    pack_audios_to_hdf5s(args)

bytesep/dataset_creation/pack_audios_to_hdf5s/musdb18.py
ADDED
@@ -0,0 +1,207 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import os
|
3 |
+
import time
|
4 |
+
from concurrent.futures import ProcessPoolExecutor
|
5 |
+
from typing import NoReturn
|
6 |
+
|
7 |
+
import h5py
|
8 |
+
import librosa
|
9 |
+
import musdb
|
10 |
+
import numpy as np
|
11 |
+
|
12 |
+
from bytesep.utils import float32_to_int16
|
13 |
+
|
14 |
+
# Source types of the MUSDB18 dataset.
|
15 |
+
SOURCE_TYPES = ["vocals", "drums", "bass", "other", "accompaniment"]
|
16 |
+
|
17 |
+
|
18 |
+
def pack_audios_to_hdf5s(args) -> NoReturn:
|
19 |
+
r"""Pack (resampled) audio files into hdf5 files to speed up loading.
|
20 |
+
|
21 |
+
Args:
|
22 |
+
dataset_dir: str
|
23 |
+
subset: str, 'train' | 'test'
|
24 |
+
split: str, '' | 'train' | 'valid'
|
25 |
+
        hdf5s_dir: str, directory to write out hdf5 files
        sample_rate: int
        channels_num: int
        mono: bool

    Returns:
        NoReturn
    """

    # arguments & parameters
    dataset_dir = args.dataset_dir
    subset = args.subset
    split = None if args.split == "" else args.split
    hdf5s_dir = args.hdf5s_dir
    sample_rate = args.sample_rate
    channels = args.channels

    mono = True if channels == 1 else False
    source_types = SOURCE_TYPES
    resample_type = "kaiser_fast"

    # Paths
    os.makedirs(hdf5s_dir, exist_ok=True)

    # Dataset of corresponding subset and split.
    mus = musdb.DB(root=dataset_dir, subsets=[subset], split=split)
    print("Subset: {}, Split: {}, Total pieces: {}".format(subset, split, len(mus)))

    params = []  # A list of params for multiprocessing.

    for track_index in range(len(mus.tracks)):

        param = (
            dataset_dir,
            subset,
            split,
            track_index,
            source_types,
            mono,
            sample_rate,
            resample_type,
            hdf5s_dir,
        )

        params.append(param)

    # Uncomment for debug.
    # write_single_audio_to_hdf5(params[0])
    # os._exit(0)

    pack_hdf5s_time = time.time()

    with ProcessPoolExecutor(max_workers=None) as pool:
        # Maximum workers on the machine.
        pool.map(write_single_audio_to_hdf5, params)

    print("Pack hdf5 time: {:.3f} s".format(time.time() - pack_hdf5s_time))


def write_single_audio_to_hdf5(param) -> NoReturn:
    r"""Write a single audio file into an hdf5 file."""
    (
        dataset_dir,
        subset,
        split,
        track_index,
        source_types,
        mono,
        sample_rate,
        resample_type,
        hdf5s_dir,
    ) = param

    # Dataset of corresponding subset and split.
    mus = musdb.DB(root=dataset_dir, subsets=[subset], split=split)
    track = mus.tracks[track_index]

    # Path to write out hdf5 file.
    hdf5_path = os.path.join(hdf5s_dir, "{}.h5".format(track.name))

    with h5py.File(hdf5_path, "w") as hf:

        hf.attrs.create("audio_name", data=track.name.encode(), dtype="S100")
        hf.attrs.create("sample_rate", data=sample_rate, dtype=np.int32)

        for source_type in source_types:

            audio = track.targets[source_type].audio.T
            # (channels_num, audio_samples)

            # Preprocess audio to mono / stereo, and resample.
            audio = preprocess_audio(
                audio, mono, track.rate, sample_rate, resample_type
            )
            # audio = load_audio(audio_path=audio_path, mono=mono, sample_rate=sample_rate)
            # (channels_num, audio_samples) | (audio_samples,)

            hf.create_dataset(
                name=source_type, data=float32_to_int16(audio), dtype=np.int16
            )

        # Mixture
        audio = track.audio.T
        # (channels_num, audio_samples)

        # Preprocess audio to mono / stereo, and resample.
        audio = preprocess_audio(audio, mono, track.rate, sample_rate, resample_type)
        # (channels_num, audio_samples)

        hf.create_dataset(name="mixture", data=float32_to_int16(audio), dtype=np.int16)

    print("{} Write to {}, {}".format(track_index, hdf5_path, audio.shape))


def preprocess_audio(audio, mono, origin_sr, sr, resample_type) -> np.array:
    r"""Preprocess audio to mono / stereo, and resample.

    Args:
        audio: (channels_num, audio_samples), input audio
        mono: bool
        origin_sr: float, original sample rate
        sr: float, target sample rate
        resample_type: str, e.g., 'kaiser_fast'

    Returns:
        output: ndarray, output audio
    """
    if mono:
        audio = np.mean(audio, axis=0)
        # (audio_samples,)

    output = librosa.core.resample(
        audio, orig_sr=origin_sr, target_sr=sr, res_type=resample_type
    )
    # (audio_samples,) | (channels_num, audio_samples)

    if output.ndim == 1:
        output = output[None, :]
        # (1, audio_samples)

    return output


if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--dataset_dir",
        type=str,
        required=True,
        help="Directory of the MUSDB18 dataset.",
    )
    parser.add_argument(
        "--subset",
        type=str,
        required=True,
        choices=["train", "test"],
        help="Train subset: 100 pieces; test subset: 50 pieces.",
    )
    parser.add_argument(
        "--split",
        type=str,
        required=True,
        choices=["", "train", "valid"],
        help="Use '' to use all 100 pieces for training. Use 'train' to use 86 \
            pieces for training, and use 'valid' to use 14 pieces for validation.",
    )
    parser.add_argument(
        "--hdf5s_dir",
        type=str,
        required=True,
        help="Directory to write out hdf5 files.",
    )
    parser.add_argument("--sample_rate", type=int, required=True, help="Sample rate.")
    parser.add_argument(
        "--channels", type=int, required=True, help="Use 1 for mono, 2 for stereo."
    )

    # Parse arguments.
    args = parser.parse_args()

    # Pack audios into hdf5 files.
    pack_audios_to_hdf5s(args)
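The packed hdf5s store each source as int16 to save disk space. A minimal read-back sketch is below; the file path is a hypothetical placeholder, and the 1/32767 rescaling assumes the usual symmetric int16 convention of float32_to_int16 in bytesep.utils.

import h5py
import numpy as np

with h5py.File("./hdf5s/musdb18/train/SongName.h5", "r") as hf:
    # Datasets are stored as int16; rescale back to float32 in [-1, 1].
    mixture = (hf["mixture"][:] / 32767.0).astype(np.float32)
    vocals = (hf["vocals"][:] / 32767.0).astype(np.float32)
    # Both: (channels_num, audio_samples)
    print(hf.attrs["audio_name"], hf.attrs["sample_rate"], mixture.shape)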
bytesep/dataset_creation/pack_audios_to_hdf5s/vctk.py
ADDED
@@ -0,0 +1,114 @@
import argparse
import os
import pathlib
import time
from concurrent.futures import ProcessPoolExecutor
from typing import NoReturn

from bytesep.dataset_creation.pack_audios_to_hdf5s.instruments_solo import (
    write_single_audio_to_hdf5,
)


def pack_audios_to_hdf5s(args) -> NoReturn:
    r"""Pack (resampled) audio files into hdf5 files to speed up loading.

    Args:
        dataset_dir: str
        split: str, 'train' | 'test'
        hdf5s_dir: str, directory to write out hdf5 files
        sample_rate: int
        channels_num: int
        mono: bool

    Returns:
        NoReturn
    """

    # arguments & parameters
    dataset_dir = args.dataset_dir
    split = args.split
    hdf5s_dir = args.hdf5s_dir
    sample_rate = args.sample_rate
    channels = args.channels
    mono = True if channels == 1 else False

    source_type = "speech"

    # Only pack training data.
    assert split == "train"

    audios_dir = os.path.join(dataset_dir, 'wav48', split)
    os.makedirs(hdf5s_dir, exist_ok=True)

    speaker_ids = sorted(os.listdir(audios_dir))

    params = []
    audio_index = 0

    for speaker_id in speaker_ids:

        speaker_audios_dir = os.path.join(audios_dir, speaker_id)

        audio_names = sorted(os.listdir(speaker_audios_dir))

        for audio_name in audio_names:

            audio_path = os.path.join(speaker_audios_dir, audio_name)

            hdf5_path = os.path.join(
                hdf5s_dir, "{}.h5".format(pathlib.Path(audio_name).stem)
            )

            param = (
                audio_index,
                audio_name,
                source_type,
                audio_path,
                mono,
                sample_rate,
                hdf5_path,
            )
            params.append(param)

            audio_index += 1

    # Uncomment for debug.
    # write_single_audio_to_hdf5(params[0])
    # os._exit(0)

    pack_hdf5s_time = time.time()

    with ProcessPoolExecutor(max_workers=None) as pool:
        # Maximum workers on the machine.
        pool.map(write_single_audio_to_hdf5, params)

    print("Pack hdf5 time: {:.3f} s".format(time.time() - pack_hdf5s_time))


if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--dataset_dir",
        type=str,
        required=True,
        help="Directory of the VCTK dataset.",
    )
    parser.add_argument("--split", type=str, required=True, choices=["train", "test"])
    parser.add_argument(
        "--hdf5s_dir",
        type=str,
        required=True,
        help="Directory to write out hdf5 files.",
    )
    parser.add_argument("--sample_rate", type=int, required=True, help="Sample rate.")
    parser.add_argument(
        "--channels", type=int, required=True, help="Use 1 for mono, 2 for stereo."
    )

    # Parse arguments.
    args = parser.parse_args()

    # Pack audios into hdf5 files.
    pack_audios_to_hdf5s(args)
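A quick sanity-check sketch for the directory layout the script reads, namely dataset_dir/wav48/train/<speaker_id>/<utterance>.wav; the dataset path is a hypothetical placeholder, and the wav48/train arrangement is assumed to have been prepared for this repo beforehand.

import os

dataset_dir = "./datasets/vctk"  # hypothetical path
audios_dir = os.path.join(dataset_dir, "wav48", "train")

# Print the first few speakers and how many utterances each has.
for speaker_id in sorted(os.listdir(audios_dir))[:3]:
    speaker_dir = os.path.join(audios_dir, speaker_id)
    print(speaker_id, len(os.listdir(speaker_dir)))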
bytesep/dataset_creation/pack_audios_to_hdf5s/voicebank-demand.py
ADDED
@@ -0,0 +1,143 @@
import argparse
import os
import pathlib
import time
from concurrent.futures import ProcessPoolExecutor
from typing import List, NoReturn

import h5py
import numpy as np

from bytesep.utils import float32_to_int16, load_audio


def pack_audios_to_hdf5s(args) -> NoReturn:
    r"""Pack (resampled) audio files into hdf5 files to speed up loading.

    Args:
        dataset_dir: str
        split: str, 'train' | 'test'
        hdf5s_dir: str, directory to write out hdf5 files
        sample_rate: int
        channels_num: int
        mono: bool

    Returns:
        NoReturn
    """

    # arguments & parameters
    dataset_dir = args.dataset_dir
    split = args.split
    hdf5s_dir = args.hdf5s_dir
    sample_rate = args.sample_rate
    channels = args.channels
    mono = True if channels == 1 else False

    # Only pack training data.
    assert split == "train"

    speech_dir = os.path.join(dataset_dir, "clean_{}set_wav".format(split))
    mixture_dir = os.path.join(dataset_dir, "noisy_{}set_wav".format(split))

    os.makedirs(hdf5s_dir, exist_ok=True)

    # Read names.
    audio_names = sorted(os.listdir(speech_dir))

    params = []

    for audio_index, audio_name in enumerate(audio_names):

        speech_path = os.path.join(speech_dir, audio_name)
        mixture_path = os.path.join(mixture_dir, audio_name)

        hdf5_path = os.path.join(
            hdf5s_dir, "{}.h5".format(pathlib.Path(audio_name).stem)
        )

        param = (
            audio_index,
            audio_name,
            speech_path,
            mixture_path,
            mono,
            sample_rate,
            hdf5_path,
        )
        params.append(param)

    # Uncomment for debug.
    # write_single_audio_to_hdf5(params[0])
    # os._exit(0)

    pack_hdf5s_time = time.time()

    with ProcessPoolExecutor(max_workers=None) as pool:
        # Maximum workers on the machine.
        pool.map(write_single_audio_to_hdf5, params)

    print("Pack hdf5 time: {:.3f} s".format(time.time() - pack_hdf5s_time))


def write_single_audio_to_hdf5(param: List) -> NoReturn:
    r"""Write a single audio file into an hdf5 file."""

    (
        audio_index,
        audio_name,
        speech_path,
        mixture_path,
        mono,
        sample_rate,
        hdf5_path,
    ) = param

    with h5py.File(hdf5_path, "w") as hf:

        hf.attrs.create("audio_name", data=audio_name, dtype="S100")
        hf.attrs.create("sample_rate", data=sample_rate, dtype=np.int32)

        speech = load_audio(audio_path=speech_path, mono=mono, sample_rate=sample_rate)
        # speech: (channels_num, audio_samples)

        mixture = load_audio(
            audio_path=mixture_path, mono=mono, sample_rate=sample_rate
        )
        # mixture: (channels_num, audio_samples)

        # The noise source is whatever remains after subtracting clean speech
        # from the noisy mixture.
        noise = mixture - speech
        # noise: (channels_num, audio_samples)

        hf.create_dataset(name='speech', data=float32_to_int16(speech), dtype=np.int16)
        hf.create_dataset(name='noise', data=float32_to_int16(noise), dtype=np.int16)

    print('{} Write hdf5 to {}'.format(audio_index, hdf5_path))


if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--dataset_dir",
        type=str,
        required=True,
        help="Directory of the Voicebank-Demand dataset.",
    )
    parser.add_argument("--split", type=str, required=True, choices=["train", "test"])
    parser.add_argument(
        "--hdf5s_dir",
        type=str,
        required=True,
        help="Directory to write out hdf5 files.",
    )
    parser.add_argument("--sample_rate", type=int, required=True, help="Sample rate.")
    parser.add_argument(
        "--channels", type=int, required=True, help="Use 1 for mono, 2 for stereo."
    )

    # Parse arguments.
    args = parser.parse_args()

    # Pack audios into hdf5 files.
    pack_audios_to_hdf5s(args)
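A toy numeric sketch of the additive assumption used above: because the noise target is defined as mixture minus speech, speech plus noise reconstructs the mixture exactly (the sample values below are illustrative).

import numpy as np

speech = np.array([[0.1, -0.2, 0.3]], dtype=np.float32)    # (channels, samples)
mixture = np.array([[0.15, -0.1, 0.25]], dtype=np.float32)

noise = mixture - speech  # [[0.05, 0.1, -0.05]]

# The decomposition is exact by construction.
assert np.allclose(speech + noise, mixture)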
bytesep/inference.py
ADDED
@@ -0,0 +1,402 @@
import argparse
import os
import time
from typing import Dict
import pathlib

import librosa
import numpy as np
import soundfile
import torch
import torch.nn as nn

from bytesep.models.lightning_modules import get_model_class
from bytesep.utils import read_yaml


class Separator:
    def __init__(
        self, model: nn.Module, segment_samples: int, batch_size: int, device: str
    ):
        r"""Separator separates an audio clip into a target source.

        Args:
            model: nn.Module, trained model
            segment_samples: int, length of segments to be input to a model, e.g., 44100*30
            batch_size: int, e.g., 12
            device: str, e.g., 'cuda'
        """
        self.model = model
        self.segment_samples = segment_samples
        self.batch_size = batch_size
        self.device = device

    def separate(self, input_dict: Dict) -> np.array:
        r"""Separate an audio clip into a target source.

        Args:
            input_dict: dict, e.g., {
                waveform: (channels_num, audio_samples),
                ...,
            }

        Returns:
            sep_audio: (channels_num, audio_samples) | (target_sources_num, channels_num, audio_samples)
        """
        audio = input_dict['waveform']

        audio_samples = audio.shape[-1]

        # Pad the audio with zeros at the end so that the length of audio can
        # be evenly divided by segment_samples.
        audio = self.pad_audio(audio)

        # Enframe long audio into segments.
        segments = self.enframe(audio, self.segment_samples)
        # (segments_num, channels_num, segment_samples)

        segments_input_dict = {'waveform': segments}

        if 'condition' in input_dict.keys():
            segments_num = len(segments)
            segments_input_dict['condition'] = np.tile(
                input_dict['condition'][None, :], (segments_num, 1)
            )
            # (segments_num, condition_dim)

        # Separate in mini-batches.
        sep_segments = self._forward_in_mini_batches(
            self.model, segments_input_dict, self.batch_size
        )['waveform']
        # (segments_num, channels_num, segment_samples)

        # Deframe segments into long audio.
        sep_audio = self.deframe(sep_segments)
        # (channels_num, padded_audio_samples)

        sep_audio = sep_audio[:, 0:audio_samples]
        # (channels_num, audio_samples)

        return sep_audio

    def pad_audio(self, audio: np.array) -> np.array:
        r"""Pad the audio with zeros at the end so that the length of audio can
        be evenly divided by segment_samples.

        Args:
            audio: (channels_num, audio_samples)

        Returns:
            padded_audio: (channels_num, audio_samples)
        """
        channels_num, audio_samples = audio.shape

        # Number of segments
        segments_num = int(np.ceil(audio_samples / self.segment_samples))

        pad_samples = segments_num * self.segment_samples - audio_samples

        padded_audio = np.concatenate(
            (audio, np.zeros((channels_num, pad_samples))), axis=1
        )
        # (channels_num, padded_audio_samples)

        return padded_audio

    def enframe(self, audio: np.array, segment_samples: int) -> np.array:
        r"""Enframe long audio into segments with 50% overlap.

        Args:
            audio: (channels_num, audio_samples)
            segment_samples: int

        Returns:
            segments: (segments_num, channels_num, segment_samples)
        """
        audio_samples = audio.shape[1]
        assert audio_samples % segment_samples == 0

        hop_samples = segment_samples // 2
        segments = []

        pointer = 0
        while pointer + segment_samples <= audio_samples:
            segments.append(audio[:, pointer : pointer + segment_samples])
            pointer += hop_samples

        segments = np.array(segments)

        return segments

    def deframe(self, segments: np.array) -> np.array:
        r"""Deframe segments into long audio by keeping the center part of each
        overlapped segment.

        Args:
            segments: (segments_num, channels_num, segment_samples)

        Returns:
            output: (channels_num, audio_samples)
        """
        (segments_num, _, segment_samples) = segments.shape

        if segments_num == 1:
            return segments[0]

        assert self._is_integer(segment_samples * 0.25)
        assert self._is_integer(segment_samples * 0.75)

        output = []

        output.append(segments[0, :, 0 : int(segment_samples * 0.75)])

        for i in range(1, segments_num - 1):
            output.append(
                segments[
                    i, :, int(segment_samples * 0.25) : int(segment_samples * 0.75)
                ]
            )

        output.append(segments[-1, :, int(segment_samples * 0.25) :])

        output = np.concatenate(output, axis=-1)

        return output

    def _is_integer(self, x: float) -> bool:
        if x - int(x) < 1e-10:
            return True
        else:
            return False

    def _forward_in_mini_batches(
        self, model: nn.Module, segments_input_dict: Dict, batch_size: int
    ) -> Dict:
        r"""Forward data to model in mini-batches.

        Args:
            model: nn.Module
            segments_input_dict: dict, e.g., {
                'waveform': (segments_num, channels_num, segment_samples),
                ...,
            }
            batch_size: int

        Returns:
            output_dict: dict, e.g. {
                'waveform': (segments_num, channels_num, segment_samples),
            }
        """
        output_dict = {}

        pointer = 0
        segments_num = len(segments_input_dict['waveform'])

        while True:
            if pointer >= segments_num:
                break

            batch_input_dict = {}

            for key in segments_input_dict.keys():
                batch_input_dict[key] = torch.Tensor(
                    segments_input_dict[key][pointer : pointer + batch_size]
                ).to(self.device)

            pointer += batch_size

            with torch.no_grad():
                model.eval()
                batch_output_dict = model(batch_input_dict)

            for key in batch_output_dict.keys():
                self._append_to_dict(
                    output_dict, key, batch_output_dict[key].data.cpu().numpy()
                )

        for key in output_dict.keys():
            output_dict[key] = np.concatenate(output_dict[key], axis=0)

        return output_dict

    def _append_to_dict(self, dict, key, value):
        if key in dict.keys():
            dict[key].append(value)
        else:
            dict[key] = [value]


class SeparatorWrapper:
    def __init__(
        self, source_type='vocals', model=None, checkpoint_path=None, device='cuda'
    ):

        input_channels = 2
        target_sources_num = 1
        model_type = "ResUNet143_Subbandtime"
        segment_samples = 44100 * 10
        batch_size = 1

        self.checkpoint_path = self.download_checkpoints(checkpoint_path, source_type)

        if device == 'cuda' and torch.cuda.is_available():
            self.device = 'cuda'
        else:
            self.device = 'cpu'

        # Get model class.
        Model = get_model_class(model_type)

        # Create model.
        self.model = Model(
            input_channels=input_channels, target_sources_num=target_sources_num
        )

        # Load checkpoint.
        checkpoint = torch.load(self.checkpoint_path, map_location='cpu')
        self.model.load_state_dict(checkpoint["model"])

        # Move model to device.
        self.model.to(self.device)

        # Create separator.
        self.separator = Separator(
            model=self.model,
            segment_samples=segment_samples,
            batch_size=batch_size,
            device=self.device,
        )

    def download_checkpoints(self, checkpoint_path, source_type):

        # The bare name excludes the '.pth' suffix; it is appended both to the
        # local path and to the download URL below.
        if source_type == "vocals":
            checkpoint_bare_name = "resunet143_subbtandtime_vocals_8.8dB_350k_steps"

        elif source_type == "accompaniment":
            checkpoint_bare_name = (
                "resunet143_subbtandtime_accompaniment_16.4dB_350k_steps"
            )

        else:
            raise NotImplementedError

        if not checkpoint_path:
            checkpoint_path = '{}/bytesep_data/{}.pth'.format(
                str(pathlib.Path.home()), checkpoint_bare_name
            )

        print('Checkpoint path: {}'.format(checkpoint_path))

        if (
            not os.path.exists(checkpoint_path)
            or os.path.getsize(checkpoint_path) < 4e8
        ):

            os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)

            zenodo_dir = "https://zenodo.org/record/5507029/files"
            zenodo_path = os.path.join(
                zenodo_dir, "{}.pth?download=1".format(checkpoint_bare_name)
            )

            os.system('wget -O "{}" "{}"'.format(checkpoint_path, zenodo_path))

        return checkpoint_path

    def separate(self, audio):

        input_dict = {'waveform': audio}

        sep_wav = self.separator.separate(input_dict)

        return sep_wav


def inference(args):

    # Need to use torch.distributed if models contain inplace_abn.abn.InPlaceABNSync.
    import torch.distributed as dist

    dist.init_process_group(
        'gloo', init_method='file:///tmp/somefile', rank=0, world_size=1
    )

    # Arguments & parameters
    config_yaml = args.config_yaml
    checkpoint_path = args.checkpoint_path
    audio_path = args.audio_path
    output_path = args.output_path
    device = (
        torch.device('cuda')
        if args.cuda and torch.cuda.is_available()
        else torch.device('cpu')
    )

    configs = read_yaml(config_yaml)
    sample_rate = configs['train']['sample_rate']
    input_channels = configs['train']['channels']
    target_source_types = configs['train']['target_source_types']
    target_sources_num = len(target_source_types)
    model_type = configs['train']['model_type']

    segment_samples = int(30 * sample_rate)
    batch_size = 1

    print("Using {} for separating ..".format(device))

    # paths
    if os.path.dirname(output_path) != "":
        os.makedirs(os.path.dirname(output_path), exist_ok=True)

    # Get model class.
    Model = get_model_class(model_type)

    # Create model.
    model = Model(input_channels=input_channels, target_sources_num=target_sources_num)

    # Load checkpoint.
    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    model.load_state_dict(checkpoint["model"])

    # Move model to device.
    model.to(device)

    # Create separator.
    separator = Separator(
        model=model,
        segment_samples=segment_samples,
        batch_size=batch_size,
        device=device,
    )

    # Load audio.
    audio, _ = librosa.load(audio_path, sr=sample_rate, mono=False)

    # audio = audio[None, :]

    input_dict = {'waveform': audio}

    # Separate
    separate_time = time.time()

    sep_wav = separator.separate(input_dict)
    # (channels_num, audio_samples)

    print('Separate time: {:.3f} s'.format(time.time() - separate_time))

    # Write out separated audio to a temporary wav, then convert with ffmpeg.
    soundfile.write(file='_zz.wav', data=sep_wav.T, samplerate=sample_rate)
    os.system("ffmpeg -y -loglevel panic -i _zz.wav {}".format(output_path))
    print('Write out to {}'.format(output_path))


if __name__ == "__main__":

    parser = argparse.ArgumentParser(description="")
    parser.add_argument("--config_yaml", type=str, required=True)
    parser.add_argument("--checkpoint_path", type=str, required=True)
    parser.add_argument("--audio_path", type=str, required=True)
    parser.add_argument("--output_path", type=str, required=True)
    parser.add_argument("--cuda", action='store_true', default=True)

    args = parser.parse_args()
    inference(args)
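A minimal usage sketch of the wrapper above; the input file name is a hypothetical placeholder, and the default vocals checkpoint is downloaded on first use.

import librosa
from bytesep.inference import SeparatorWrapper

separator = SeparatorWrapper(source_type='vocals', device='cuda')

audio, _ = librosa.load('mixture.wav', sr=44100, mono=False)
# audio: (channels_num, audio_samples); the default model expects stereo input.

sep_wav = separator.separate(audio)
# sep_wav: (channels_num, audio_samples)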
bytesep/inference_many.py
ADDED
@@ -0,0 +1,163 @@
import argparse
import os
import pathlib
import time
from typing import NoReturn

import librosa
import numpy as np
import soundfile
import torch

from bytesep.inference import Separator
from bytesep.models.lightning_modules import get_model_class
from bytesep.utils import read_yaml


def inference(args) -> NoReturn:
    r"""Separate all audios in a directory.

    Args:
        config_yaml: str, the config file of a model being trained
        checkpoint_path: str, the path of checkpoint to be loaded
        audios_dir: str, the directory of audios to be separated
        output_dir: str, the directory to write out separated audios
        scale_volume: bool, if True then the volume is scaled to the maximum value of 1.

    Returns:
        NoReturn
    """

    # Arguments & parameters
    config_yaml = args.config_yaml
    checkpoint_path = args.checkpoint_path
    audios_dir = args.audios_dir
    output_dir = args.output_dir
    scale_volume = args.scale_volume
    device = (
        torch.device('cuda')
        if args.cuda and torch.cuda.is_available()
        else torch.device('cpu')
    )

    configs = read_yaml(config_yaml)
    sample_rate = configs['train']['sample_rate']
    input_channels = configs['train']['channels']
    target_source_types = configs['train']['target_source_types']
    target_sources_num = len(target_source_types)
    model_type = configs['train']['model_type']
    mono = input_channels == 1

    segment_samples = int(30 * sample_rate)
    batch_size = 1

    models_contains_inplaceabn = True

    # Need to use torch.distributed if models contain inplace_abn.abn.InPlaceABNSync.
    if models_contains_inplaceabn:

        import torch.distributed as dist

        dist.init_process_group(
            'gloo', init_method='file:///tmp/somefile', rank=0, world_size=1
        )

    print("Using {} for separating ..".format(device))

    # paths
    os.makedirs(output_dir, exist_ok=True)

    # Get model class.
    Model = get_model_class(model_type)

    # Create model.
    model = Model(input_channels=input_channels, target_sources_num=target_sources_num)

    # Load checkpoint.
    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    model.load_state_dict(checkpoint["model"])

    # Move model to device.
    model.to(device)

    # Create separator.
    separator = Separator(
        model=model,
        segment_samples=segment_samples,
        batch_size=batch_size,
        device=device,
    )

    audio_names = sorted(os.listdir(audios_dir))

    for audio_name in audio_names:
        audio_path = os.path.join(audios_dir, audio_name)

        # Load audio.
        audio, _ = librosa.load(audio_path, sr=sample_rate, mono=mono)

        if audio.ndim == 1:
            audio = audio[None, :]

        input_dict = {'waveform': audio}

        # Separate
        separate_time = time.time()

        sep_wav = separator.separate(input_dict)
        # (channels_num, audio_samples)

        print('Separate time: {:.3f} s'.format(time.time() - separate_time))

        # Write out separated audio.
        if scale_volume:
            sep_wav /= np.max(np.abs(sep_wav))

        soundfile.write(file='_zz.wav', data=sep_wav.T, samplerate=sample_rate)

        output_path = os.path.join(
            output_dir, '{}.mp3'.format(pathlib.Path(audio_name).stem)
        )
        os.system('ffmpeg -y -loglevel panic -i _zz.wav "{}"'.format(output_path))
        print('Write out to {}'.format(output_path))


if __name__ == "__main__":

    parser = argparse.ArgumentParser(description="")
    parser.add_argument(
        "--config_yaml",
        type=str,
        required=True,
        help="The config file of a model being trained.",
    )
    parser.add_argument(
        "--checkpoint_path",
        type=str,
        required=True,
        help="The path of checkpoint to be loaded.",
    )
    parser.add_argument(
        "--audios_dir",
        type=str,
        required=True,
        help="The directory of audios to be separated.",
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        required=True,
        help="The directory to write out separated audios.",
    )
    parser.add_argument(
        '--scale_volume',
        action='store_true',
        default=False,
        help="Set to True if separated audios are scaled to the maximum value of 1.",
    )
    parser.add_argument("--cuda", action='store_true', default=True)

    args = parser.parse_args()

    inference(args)
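A small numeric check of the enframe/deframe round trip used by the Separator that both inference scripts rely on: segments overlap by 50%, and deframe keeps the center half of each inner segment, so the round trip reconstructs the signal exactly. Passing model=None is a sketch-only shortcut, since enframe and deframe never touch the model.

import numpy as np
from bytesep.inference import Separator

separator = Separator(model=None, segment_samples=4, batch_size=1, device='cpu')

audio = np.arange(8, dtype=np.float32)[None, :]  # (1, 8), a multiple of 4

segments = separator.enframe(audio, segment_samples=4)
# hop = 2, so segments start at offsets 0, 2, 4: shape (3, 1, 4)

output = separator.deframe(segments)
# [0,1,2] + [3,4] + [5,6,7] stitched back together
assert np.allclose(output, audio)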
bytesep/losses.py
ADDED
@@ -0,0 +1,106 @@
import math
from typing import Callable

import torch
import torch.nn as nn
from torchlibrosa.stft import STFT

from bytesep.models.pytorch_modules import Base


def l1(output: torch.Tensor, target: torch.Tensor, **kwargs) -> torch.Tensor:
    r"""L1 loss.

    Args:
        output: torch.Tensor
        target: torch.Tensor

    Returns:
        loss: torch.float
    """
    return torch.mean(torch.abs(output - target))


def l1_wav(output: torch.Tensor, target: torch.Tensor, **kwargs) -> torch.Tensor:
    r"""L1 loss in the time-domain.

    Args:
        output: torch.Tensor
        target: torch.Tensor

    Returns:
        loss: torch.float
    """
    return l1(output, target)


class L1_Wav_L1_Sp(nn.Module, Base):
    def __init__(self):
        r"""L1 loss in the time-domain and L1 loss on the spectrogram."""
        super(L1_Wav_L1_Sp, self).__init__()

        self.window_size = 2048
        hop_size = 441
        center = True
        pad_mode = "reflect"
        window = "hann"

        self.stft = STFT(
            n_fft=self.window_size,
            hop_length=hop_size,
            win_length=self.window_size,
            window=window,
            center=center,
            pad_mode=pad_mode,
            freeze_parameters=True,
        )

    def __call__(
        self, output: torch.Tensor, target: torch.Tensor, **kwargs
    ) -> torch.Tensor:
        r"""L1 loss in the time-domain and on the spectrogram.

        Args:
            output: torch.Tensor
            target: torch.Tensor

        Returns:
            loss: torch.float
        """

        # L1 loss in the time-domain.
        wav_loss = l1_wav(output, target)

        # L1 loss on the spectrogram.
        sp_loss = l1(
            self.wav_to_spectrogram(output, eps=1e-8),
            self.wav_to_spectrogram(target, eps=1e-8),
        )

        # sp_loss /= math.sqrt(self.window_size)
        # sp_loss *= 1.

        # Total loss.
        return wav_loss + sp_loss


def get_loss_function(loss_type: str) -> Callable:
    r"""Get loss function.

    Args:
        loss_type: str

    Returns:
        loss function: Callable
    """

    if loss_type == "l1_wav":
        return l1_wav

    elif loss_type == "l1_wav_l1_sp":
        return L1_Wav_L1_Sp()

    else:
        raise NotImplementedError
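A minimal sketch of selecting and applying a loss; the shapes are illustrative, and the unused mixture kwarg mirrors how LitSourceSeparation (below) calls the loss function.

import torch
from bytesep.losses import get_loss_function

loss_fn = get_loss_function("l1_wav")

output = torch.rand(4, 2, 1000)  # (batch_size, channels_num, segment_samples)
target = torch.rand(4, 2, 1000)

# Extra kwargs such as mixture are absorbed by **kwargs in the loss functions.
loss = loss_fn(output=output, target=target, mixture=None)
print(loss)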
bytesep/models/__init__.py
ADDED
File without changes
bytesep/models/conditional_unet.py
ADDED
@@ -0,0 +1,496 @@
import math
from typing import List

import numpy as np
import matplotlib.pyplot as plt
import pytorch_lightning as pl
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim.lr_scheduler import LambdaLR
from torchlibrosa.stft import STFT, ISTFT, magphase

from bytesep.models.pytorch_modules import (
    Base,
    init_bn,
    init_embedding,
    init_layer,
    act,
    Subband,
)


class ConvBlock(nn.Module):
    def __init__(
        self,
        in_channels,
        out_channels,
        condition_size,
        kernel_size,
        activation,
        momentum,
    ):
        super(ConvBlock, self).__init__()

        self.activation = activation
        padding = (kernel_size[0] // 2, kernel_size[1] // 2)

        self.conv1 = nn.Conv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=(1, 1),
            dilation=(1, 1),
            padding=padding,
            bias=False,
        )

        self.bn1 = nn.BatchNorm2d(out_channels, momentum=momentum)

        self.conv2 = nn.Conv2d(
            in_channels=out_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=(1, 1),
            dilation=(1, 1),
            padding=padding,
            bias=False,
        )

        self.bn2 = nn.BatchNorm2d(out_channels, momentum=momentum)

        self.beta1 = nn.Linear(condition_size, out_channels, bias=True)
        self.beta2 = nn.Linear(condition_size, out_channels, bias=True)

        self.init_weights()

    def init_weights(self):
        init_layer(self.conv1)
        init_layer(self.conv2)
        init_bn(self.bn1)
        init_bn(self.bn2)
        init_embedding(self.beta1)
        init_embedding(self.beta2)

    def forward(self, x, condition):

        b1 = self.beta1(condition)[:, :, None, None]
        b2 = self.beta2(condition)[:, :, None, None]

        x = act(self.bn1(self.conv1(x)) + b1, self.activation)
        x = act(self.bn2(self.conv2(x)) + b2, self.activation)
        return x


class EncoderBlock(nn.Module):
    def __init__(
        self,
        in_channels,
        out_channels,
        condition_size,
        kernel_size,
        downsample,
        activation,
        momentum,
    ):
        super(EncoderBlock, self).__init__()

        self.conv_block = ConvBlock(
            in_channels, out_channels, condition_size, kernel_size, activation, momentum
        )
        self.downsample = downsample

    def forward(self, x, condition):
        encoder = self.conv_block(x, condition)
        encoder_pool = F.avg_pool2d(encoder, kernel_size=self.downsample)
        return encoder_pool, encoder


class DecoderBlock(nn.Module):
    def __init__(
        self,
        in_channels,
        out_channels,
        condition_size,
        kernel_size,
        upsample,
        activation,
        momentum,
    ):
        super(DecoderBlock, self).__init__()
        self.kernel_size = kernel_size
        self.stride = upsample
        self.activation = activation

        self.conv1 = torch.nn.ConvTranspose2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=self.stride,
            stride=self.stride,
            padding=(0, 0),
            bias=False,
            dilation=(1, 1),
        )

        self.bn1 = nn.BatchNorm2d(out_channels, momentum=momentum)

        self.conv_block2 = ConvBlock(
            out_channels * 2,
            out_channels,
            condition_size,
            kernel_size,
            activation,
            momentum,
        )

        self.beta1 = nn.Linear(condition_size, out_channels, bias=True)

        self.init_weights()

    def init_weights(self):
        init_layer(self.conv1)
        init_bn(self.bn1)
        init_embedding(self.beta1)

    def forward(self, input_tensor, concat_tensor, condition):
        b1 = self.beta1(condition)[:, :, None, None]
        x = act(self.bn1(self.conv1(input_tensor)) + b1, self.activation)
        x = torch.cat((x, concat_tensor), dim=1)
        x = self.conv_block2(x, condition)
        return x


class ConditionalUNet(nn.Module, Base):
    def __init__(self, input_channels, target_sources_num):
        super(ConditionalUNet, self).__init__()

        self.input_channels = input_channels
        condition_size = target_sources_num
        self.output_sources_num = 1

        window_size = 2048
        hop_size = 441
        center = True
        pad_mode = "reflect"
        window = "hann"
        activation = "relu"
        momentum = 0.01

        self.subbands_num = 4
        self.K = 3  # outputs: |M|, cos∠M, sin∠M

        self.downsample_ratio = 2 ** 6  # This number equals 2^{#encoder_blocks}.

        self.stft = STFT(
            n_fft=window_size,
            hop_length=hop_size,
            win_length=window_size,
            window=window,
            center=center,
            pad_mode=pad_mode,
            freeze_parameters=True,
        )

        self.istft = ISTFT(
            n_fft=window_size,
            hop_length=hop_size,
            win_length=window_size,
            window=window,
            center=center,
            pad_mode=pad_mode,
            freeze_parameters=True,
        )

        self.bn0 = nn.BatchNorm2d(window_size // 2 + 1, momentum=momentum)

        self.subband = Subband(subbands_num=self.subbands_num)

        self.encoder_block1 = EncoderBlock(
            in_channels=input_channels * self.subbands_num,
            out_channels=32,
            condition_size=condition_size,
            kernel_size=(3, 3),
            downsample=(2, 2),
            activation=activation,
            momentum=momentum,
        )
        self.encoder_block2 = EncoderBlock(
            in_channels=32,
            out_channels=64,
            condition_size=condition_size,
            kernel_size=(3, 3),
            downsample=(2, 2),
            activation=activation,
            momentum=momentum,
        )
        self.encoder_block3 = EncoderBlock(
            in_channels=64,
            out_channels=128,
            condition_size=condition_size,
            kernel_size=(3, 3),
            downsample=(2, 2),
            activation=activation,
            momentum=momentum,
        )
        self.encoder_block4 = EncoderBlock(
            in_channels=128,
            out_channels=256,
            condition_size=condition_size,
            kernel_size=(3, 3),
            downsample=(2, 2),
            activation=activation,
            momentum=momentum,
        )
        self.encoder_block5 = EncoderBlock(
            in_channels=256,
            out_channels=384,
            condition_size=condition_size,
            kernel_size=(3, 3),
            downsample=(2, 2),
            activation=activation,
            momentum=momentum,
        )
        self.encoder_block6 = EncoderBlock(
            in_channels=384,
            out_channels=384,
            condition_size=condition_size,
            kernel_size=(3, 3),
            downsample=(2, 2),
            activation=activation,
            momentum=momentum,
        )
        self.conv_block7 = ConvBlock(
            in_channels=384,
            out_channels=384,
            condition_size=condition_size,
            kernel_size=(3, 3),
            activation=activation,
            momentum=momentum,
        )
        self.decoder_block1 = DecoderBlock(
            in_channels=384,
            out_channels=384,
            condition_size=condition_size,
            kernel_size=(3, 3),
            upsample=(2, 2),
            activation=activation,
            momentum=momentum,
        )
        self.decoder_block2 = DecoderBlock(
            in_channels=384,
            out_channels=384,
            condition_size=condition_size,
            kernel_size=(3, 3),
            upsample=(2, 2),
            activation=activation,
            momentum=momentum,
        )
        self.decoder_block3 = DecoderBlock(
            in_channels=384,
            out_channels=256,
            condition_size=condition_size,
            kernel_size=(3, 3),
            upsample=(2, 2),
            activation=activation,
            momentum=momentum,
        )
        self.decoder_block4 = DecoderBlock(
            in_channels=256,
            out_channels=128,
            condition_size=condition_size,
            kernel_size=(3, 3),
            upsample=(2, 2),
            activation=activation,
            momentum=momentum,
        )
        self.decoder_block5 = DecoderBlock(
            in_channels=128,
            out_channels=64,
            condition_size=condition_size,
            kernel_size=(3, 3),
            upsample=(2, 2),
            activation=activation,
            momentum=momentum,
        )
        self.decoder_block6 = DecoderBlock(
            in_channels=64,
            out_channels=32,
            condition_size=condition_size,
            kernel_size=(3, 3),
            upsample=(2, 2),
            activation=activation,
            momentum=momentum,
        )

        self.after_conv_block1 = ConvBlock(
            in_channels=32,
            out_channels=32,
            condition_size=condition_size,
            kernel_size=(3, 3),
            activation=activation,
            momentum=momentum,
        )

        self.after_conv2 = nn.Conv2d(
            in_channels=32,
            out_channels=input_channels
            * self.subbands_num
            * self.output_sources_num
            * self.K,
            kernel_size=(1, 1),
            stride=(1, 1),
            padding=(0, 0),
            bias=True,
        )

        self.init_weights()

    def init_weights(self):
        init_bn(self.bn0)
        init_layer(self.after_conv2)

    def feature_maps_to_wav(self, x, sp, sin_in, cos_in, audio_length):

        batch_size, _, time_steps, freq_bins = x.shape

        x = x.reshape(
            batch_size,
            self.output_sources_num,
            self.input_channels,
            self.K,
            time_steps,
            freq_bins,
        )
        # x: (batch_size, output_sources_num, input_channels, K, time_steps, freq_bins)

        mask_mag = torch.sigmoid(x[:, :, :, 0, :, :])
        _mask_real = torch.tanh(x[:, :, :, 1, :, :])
        _mask_imag = torch.tanh(x[:, :, :, 2, :, :])
        _, mask_cos, mask_sin = magphase(_mask_real, _mask_imag)
        # mask_cos, mask_sin: (batch_size, output_sources_num, input_channels, time_steps, freq_bins)

        # Y = |Y|cos∠Y + j|Y|sin∠Y
        #   = |Y|cos(∠X + ∠M) + j|Y|sin(∠X + ∠M)
        #   = |Y|(cos∠X cos∠M - sin∠X sin∠M) + j|Y|(sin∠X cos∠M + cos∠X sin∠M)
        out_cos = (
            cos_in[:, None, :, :, :] * mask_cos - sin_in[:, None, :, :, :] * mask_sin
        )
        out_sin = (
            sin_in[:, None, :, :, :] * mask_cos + cos_in[:, None, :, :, :] * mask_sin
        )
        # out_cos, out_sin: (batch_size, output_sources_num, input_channels, time_steps, freq_bins)

        # Calculate |Y|.
        out_mag = F.relu_(sp[:, None, :, :, :] * mask_mag)
        # out_mag: (batch_size, output_sources_num, input_channels, time_steps, freq_bins)

        # Calculate Y_{real} and Y_{imag} for ISTFT.
        out_real = out_mag * out_cos
        out_imag = out_mag * out_sin
        # out_real, out_imag: (batch_size, output_sources_num, input_channels, time_steps, freq_bins)

        # Reformat shape to (n, 1, time_steps, freq_bins) for ISTFT.
        shape = (
            batch_size * self.output_sources_num * self.input_channels,
            1,
            time_steps,
            freq_bins,
        )
        out_real = out_real.reshape(shape)
        out_imag = out_imag.reshape(shape)

        # ISTFT.
        wav_out = self.istft(out_real, out_imag, audio_length)
        # (batch_size * output_sources_num * input_channels, segment_samples)

        # Reshape.
        wav_out = wav_out.reshape(
            batch_size, self.output_sources_num * self.input_channels, audio_length
        )
        # (batch_size, output_sources_num * input_channels, segment_samples)

        return wav_out

    def forward(self, input_dict):
        """
        Args:
            input_dict: dict, e.g., {
                'waveform': (batch_size, channels_num, segment_samples),
                'condition': (batch_size, condition_size),
            }

        Outputs:
            output_dict: {
                'waveform': (batch_size, output_sources_num * channels_num, segment_samples),
            }
        """

        mixture = input_dict['waveform']
        condition = input_dict['condition']

        sp, cos_in, sin_in = self.wav_to_spectrogram_phase(mixture)
        """(batch_size, channels_num, time_steps, freq_bins)"""

        # Batch normalization
        x = sp.transpose(1, 3)
        x = self.bn0(x)
        x = x.transpose(1, 3)
        """(batch_size, channels_num, time_steps, freq_bins)"""

        # Pad spectrogram to be evenly divided by downsample ratio.
        origin_len = x.shape[2]
        pad_len = (
            int(np.ceil(x.shape[2] / self.downsample_ratio)) * self.downsample_ratio
            - origin_len
        )
        x = F.pad(x, pad=(0, 0, 0, pad_len))
        """(batch_size, channels_num, padded_time_steps, freq_bins)"""

        # Let frequency bins be evenly divided by 2, e.g., 513 -> 512
        x = x[..., 0 : x.shape[-1] - 1]  # (bs, channels, T, F)

        x = self.subband.analysis(x)

        # UNet
        (x1_pool, x1) = self.encoder_block1(
            x, condition
        )  # x1_pool: (bs, 32, T / 2, F / 2)
        (x2_pool, x2) = self.encoder_block2(
            x1_pool, condition
        )  # x2_pool: (bs, 64, T / 4, F / 4)
        (x3_pool, x3) = self.encoder_block3(
            x2_pool, condition
        )  # x3_pool: (bs, 128, T / 8, F / 8)
        (x4_pool, x4) = self.encoder_block4(
            x3_pool, condition
        )  # x4_pool: (bs, 256, T / 16, F / 16)
        (x5_pool, x5) = self.encoder_block5(
            x4_pool, condition
        )  # x5_pool: (bs, 384, T / 32, F / 32)
        (x6_pool, x6) = self.encoder_block6(
            x5_pool, condition
        )  # x6_pool: (bs, 384, T / 64, F / 64)
        x_center = self.conv_block7(x6_pool, condition)  # (bs, 384, T / 64, F / 64)
        x7 = self.decoder_block1(x_center, x6, condition)  # (bs, 384, T / 32, F / 32)
        x8 = self.decoder_block2(x7, x5, condition)  # (bs, 384, T / 16, F / 16)
        x9 = self.decoder_block3(x8, x4, condition)  # (bs, 256, T / 8, F / 8)
        x10 = self.decoder_block4(x9, x3, condition)  # (bs, 128, T / 4, F / 4)
        x11 = self.decoder_block5(x10, x2, condition)  # (bs, 64, T / 2, F / 2)
        x12 = self.decoder_block6(x11, x1, condition)  # (bs, 32, T, F)
        x = self.after_conv_block1(x12, condition)  # (bs, 32, T, F)
        x = self.after_conv2(x)
        # (batch_size, input_channels * subbands_num * output_sources_num * K, T, F // subbands_num)

        x = self.subband.synthesis(x)
        # (batch_size, input_channels * output_sources_num * K, T, F)

        # Recover shape
        x = F.pad(x, pad=(0, 1))  # Pad frequency, e.g., 1024 -> 1025.
        x = x[:, :, 0:origin_len, :]  # (bs, feature_maps, T, F)

        audio_length = mixture.shape[2]

        separated_audio = self.feature_maps_to_wav(x, sp, sin_in, cos_in, audio_length)
        # separated_audio: (batch_size, output_sources_num * input_channels, segment_samples)

        output_dict = {'waveform': separated_audio}

        return output_dict
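A shape sketch of the conditioned forward pass, with toy sizes; the condition vector (for example a one-hot vector over target sources) is injected into every block as per-channel biases via the beta linear layers above.

import torch
from bytesep.models.conditional_unet import ConditionalUNet

model = ConditionalUNet(input_channels=2, target_sources_num=4)

batch_size, segment_samples = 2, 44100
input_dict = {
    'waveform': torch.rand(batch_size, 2, segment_samples),
    'condition': torch.eye(4)[[0, 2]],  # one-hot vectors, (batch_size, 4)
}

output_dict = model(input_dict)
# output_dict['waveform']: (batch_size, 2, segment_samples)
print(output_dict['waveform'].shape)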
bytesep/models/lightning_modules.py
ADDED
@@ -0,0 +1,188 @@
from typing import Any, Callable, Dict

import pytorch_lightning as pl
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import LambdaLR


class LitSourceSeparation(pl.LightningModule):
    def __init__(
        self,
        batch_data_preprocessor,
        model: nn.Module,
        loss_function: Callable,
        optimizer_type: str,
        learning_rate: float,
        lr_lambda: Callable,
    ):
        r"""PyTorch Lightning wrapper of a PyTorch model, including the
        forward pass, optimization of the model, etc.

        Args:
            batch_data_preprocessor: object used for preparing inputs and
                targets for training, e.g., BasicBatchDataPreprocessor
                converts a batch data dictionary into tensors.
            model: nn.Module
            loss_function: function
            optimizer_type: str
            learning_rate: float
            lr_lambda: function
        """
        super().__init__()

        self.batch_data_preprocessor = batch_data_preprocessor
        self.model = model
        self.optimizer_type = optimizer_type
        self.loss_function = loss_function
        self.learning_rate = learning_rate
        self.lr_lambda = lr_lambda

    def training_step(self, batch_data_dict: Dict, batch_idx: int) -> torch.float:
        r"""Forward a mini-batch of data through the model, calculate the loss
        function, and train for one step. A mini-batch is evenly distributed
        across multiple devices (if present) for parallel training.

        Args:
            batch_data_dict: e.g., {
                'vocals': (batch_size, channels_num, segment_samples),
                'accompaniment': (batch_size, channels_num, segment_samples),
                'mixture': (batch_size, channels_num, segment_samples)
            }
            batch_idx: int

        Returns:
            loss: float, loss function of this mini-batch
        """
        input_dict, target_dict = self.batch_data_preprocessor(batch_data_dict)
        # input_dict: {
        #     'waveform': (batch_size, channels_num, segment_samples),
        #     (if_exist) 'condition': (batch_size, channels_num),
        # }
        # target_dict: {
        #     'waveform': (batch_size, target_sources_num * channels_num, segment_samples),
        # }

        # Forward.
        self.model.train()

        output_dict = self.model(input_dict)
        # output_dict: {
        #     'waveform': (batch_size, target_sources_num * channels_num, segment_samples),
        # }

        outputs = output_dict['waveform']
        # outputs: e.g., (batch_size, target_sources_num * channels_num, segment_samples)

        # Calculate loss.
        loss = self.loss_function(
            output=outputs,
            target=target_dict['waveform'],
            mixture=input_dict['waveform'],
        )

        return loss

    def configure_optimizers(self) -> Any:
        r"""Configure the optimizer and learning rate scheduler."""

        if self.optimizer_type == "Adam":
            optimizer = optim.Adam(
                self.model.parameters(),
                lr=self.learning_rate,
                betas=(0.9, 0.999),
                eps=1e-08,
                weight_decay=0.0,
                amsgrad=True,
            )

        elif self.optimizer_type == "AdamW":
            optimizer = optim.AdamW(
                self.model.parameters(),
                lr=self.learning_rate,
                betas=(0.9, 0.999),
                eps=1e-08,
                weight_decay=0.0,
                amsgrad=True,
            )

        else:
            raise NotImplementedError

        scheduler = {
            'scheduler': LambdaLR(optimizer, self.lr_lambda),
            'interval': 'step',
            'frequency': 1,
        }

        return [optimizer], [scheduler]


def get_model_class(model_type):
    r"""Get a model class by name.

    Args:
        model_type: str, e.g., 'ResUNet143_DecouplePlusInplaceABN_ISMIR2021'

    Returns:
        nn.Module
    """
    if model_type == 'ResUNet143_DecouplePlusInplaceABN_ISMIR2021':
        from bytesep.models.resunet_ismir2021 import (
            ResUNet143_DecouplePlusInplaceABN_ISMIR2021,
        )

        return ResUNet143_DecouplePlusInplaceABN_ISMIR2021

    elif model_type == 'UNet':
        from bytesep.models.unet import UNet

        return UNet

    elif model_type == 'UNetSubbandTime':
        from bytesep.models.unet_subbandtime import UNetSubbandTime

        return UNetSubbandTime

    elif model_type == 'ResUNet143_Subbandtime':
        from bytesep.models.resunet_subbandtime import ResUNet143_Subbandtime

        return ResUNet143_Subbandtime

    elif model_type == 'ResUNet143_DecouplePlus':
        from bytesep.models.resunet import ResUNet143_DecouplePlus

        return ResUNet143_DecouplePlus

    elif model_type == 'ConditionalUNet':
        from bytesep.models.conditional_unet import ConditionalUNet

        return ConditionalUNet

    elif model_type == 'LevelRNN':
        from bytesep.models.levelrnn import LevelRNN

        return LevelRNN

    elif model_type == 'WavUNet':
        from bytesep.models.wavunet import WavUNet

        return WavUNet

    elif model_type == 'WavUNetLevelRNN':
        from bytesep.models.wavunet_levelrnn import WavUNetLevelRNN

        return WavUNetLevelRNN

    elif model_type == 'TTnet':
        from bytesep.models.ttnet import TTnet

        return TTnet

    elif model_type == 'TTnetNoTransformer':
        from bytesep.models.ttnet_no_transformer import TTnetNoTransformer

        return TTnetNoTransformer

    else:
        raise NotImplementedError
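
A minimal sketch of how these pieces wire together (illustrative only: the L1 loss and the constant lr_lambda are stand-ins for the repository's real loss functions and schedules in bytesep/losses.py and bytesep/optimizers/lr_schedulers.py, and the UNet constructor arguments are assumed from the sibling models):

    import torch

    Model = get_model_class(model_type='UNet')
    model = Model(input_channels=2, target_sources_num=1)  # assumed constructor signature

    def l1_wav(output, target, mixture):
        # Stand-in loss with the keyword signature training_step expects.
        return torch.mean(torch.abs(output - target))

    def preprocessor(batch_data_dict):
        # Stand-in for BasicBatchDataPreprocessor: mixture in, vocals as target.
        return (
            {'waveform': batch_data_dict['mixture']},
            {'waveform': batch_data_dict['vocals']},
        )

    lit_model = LitSourceSeparation(
        batch_data_preprocessor=preprocessor,
        model=model,
        loss_function=l1_wav,
        optimizer_type='Adam',
        learning_rate=1e-3,
        lr_lambda=lambda step: 1.0,  # constant schedule; 'interval': 'step' applies it per optimizer step
    )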
bytesep/models/pytorch_modules.py
ADDED
@@ -0,0 +1,204 @@
from typing import List, NoReturn

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F


def init_embedding(layer: nn.Module) -> NoReturn:
    r"""Initialize an embedding layer."""
    nn.init.uniform_(layer.weight, -1.0, 1.0)

    if hasattr(layer, 'bias'):
        if layer.bias is not None:
            layer.bias.data.fill_(0.0)


def init_layer(layer: nn.Module) -> NoReturn:
    r"""Initialize a Linear or Convolutional layer."""
    nn.init.xavier_uniform_(layer.weight)

    if hasattr(layer, "bias"):
        if layer.bias is not None:
            layer.bias.data.fill_(0.0)


def init_bn(bn: nn.Module) -> NoReturn:
    r"""Initialize a Batchnorm layer."""
    bn.bias.data.fill_(0.0)
    bn.weight.data.fill_(1.0)
    bn.running_mean.data.fill_(0.0)
    bn.running_var.data.fill_(1.0)


def act(x: torch.Tensor, activation: str) -> torch.Tensor:
    r"""Apply the named activation function (in place where possible)."""

    if activation == "relu":
        return F.relu_(x)

    elif activation == "leaky_relu":
        return F.leaky_relu_(x, negative_slope=0.01)

    elif activation == "swish":
        return x * torch.sigmoid(x)

    else:
        raise Exception("Incorrect activation!")


class Base:
    def __init__(self):
        r"""Base class for extracting spectrogram, cos, sin, etc."""
        pass

    def spectrogram(self, input: torch.Tensor, eps: float = 0.0) -> torch.Tensor:
        r"""Calculate the spectrogram.

        Args:
            input: (batch_size, segments_num)
            eps: float

        Returns:
            spectrogram: (batch_size, time_steps, freq_bins)
        """
        (real, imag) = self.stft(input)
        return torch.clamp(real ** 2 + imag ** 2, eps, np.inf) ** 0.5

    def spectrogram_phase(
        self, input: torch.Tensor, eps: float = 0.0
    ) -> List[torch.Tensor]:
        r"""Calculate the magnitude, cos, and sin of the STFT of the input.

        Args:
            input: (batch_size, segments_num)
            eps: float

        Returns:
            mag: (batch_size, time_steps, freq_bins)
            cos: (batch_size, time_steps, freq_bins)
            sin: (batch_size, time_steps, freq_bins)
        """
        (real, imag) = self.stft(input)
        mag = torch.clamp(real ** 2 + imag ** 2, eps, np.inf) ** 0.5
        cos = real / mag
        sin = imag / mag
        return mag, cos, sin

    def wav_to_spectrogram_phase(
        self, input: torch.Tensor, eps: float = 1e-10
    ) -> List[torch.Tensor]:
        r"""Convert waveforms to the magnitude, cos, and sin of the STFT.

        Args:
            input: (batch_size, channels_num, segment_samples)
            eps: float

        Outputs:
            mag: (batch_size, channels_num, time_steps, freq_bins)
            cos: (batch_size, channels_num, time_steps, freq_bins)
            sin: (batch_size, channels_num, time_steps, freq_bins)
        """
        batch_size, channels_num, segment_samples = input.shape

        # Reshape the input to shape (n, segments_num) to meet the
        # requirements of the stft function.
        x = input.reshape(batch_size * channels_num, segment_samples)

        mag, cos, sin = self.spectrogram_phase(x, eps=eps)
        # mag, cos, sin: (batch_size * channels_num, 1, time_steps, freq_bins)

        _, _, time_steps, freq_bins = mag.shape
        mag = mag.reshape(batch_size, channels_num, time_steps, freq_bins)
        cos = cos.reshape(batch_size, channels_num, time_steps, freq_bins)
        sin = sin.reshape(batch_size, channels_num, time_steps, freq_bins)

        return mag, cos, sin

    def wav_to_spectrogram(
        self, input: torch.Tensor, eps: float = 1e-10
    ) -> List[torch.Tensor]:

        mag, cos, sin = self.wav_to_spectrogram_phase(input, eps)
        return mag


class Subband:
    def __init__(self, subbands_num: int):
        r"""Warning!! This class is not used!!

        This class does not work as well as [1], which splits subbands in the
        time domain. Please refer to [1] for a formal implementation.

        [1] Liu, Haohe, et al. "Channel-wise subband input for better voice and
        accompaniment separation on high resolution music." arXiv preprint arXiv:2008.05216 (2020).

        Args:
            subbands_num: int, e.g., 4
        """
        self.subbands_num = subbands_num

    def analysis(self, x: torch.Tensor) -> torch.Tensor:
        r"""Analyze a time-frequency representation into subbands. Stack the
        subbands along the channel axis.

        Args:
            x: (batch_size, channels_num, time_steps, freq_bins)

        Returns:
            output: (batch_size, channels_num * subbands_num, time_steps, freq_bins // subbands_num)
        """
        batch_size, channels_num, time_steps, freq_bins = x.shape

        x = x.reshape(
            batch_size,
            channels_num,
            time_steps,
            self.subbands_num,
            freq_bins // self.subbands_num,
        )
        # x: (batch_size, channels_num, time_steps, subbands_num, freq_bins // subbands_num)

        x = x.transpose(2, 3)

        output = x.reshape(
            batch_size,
            channels_num * self.subbands_num,
            time_steps,
            freq_bins // self.subbands_num,
        )
        # output: (batch_size, channels_num * subbands_num, time_steps, freq_bins // subbands_num)

        return output

    def synthesis(self, x: torch.Tensor) -> torch.Tensor:
        r"""Synthesize subband time-frequency representations into the
        original time-frequency representation.

        Args:
            x: (batch_size, channels_num * subbands_num, time_steps, freq_bins // subbands_num)

        Returns:
            output: (batch_size, channels_num, time_steps, freq_bins)
        """
        batch_size, subband_channels_num, time_steps, subband_freq_bins = x.shape

        channels_num = subband_channels_num // self.subbands_num
        freq_bins = subband_freq_bins * self.subbands_num

        x = x.reshape(
            batch_size,
            channels_num,
            self.subbands_num,
            time_steps,
            subband_freq_bins,
        )
        # x: (batch_size, channels_num, subbands_num, time_steps, freq_bins // subbands_num)

        x = x.transpose(2, 3)
        # x: (batch_size, channels_num, time_steps, subbands_num, freq_bins // subbands_num)

        output = x.reshape(batch_size, channels_num, time_steps, freq_bins)
        # output: (batch_size, channels_num, time_steps, freq_bins)

        return output
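
Two quick self-contained checks of the helpers above (sketches, not repository code). First, mag * cos and mag * sin recover the real and imaginary STFT parts, since the triple is just a polar form of the STFT. Second, Subband.synthesis exactly inverts Subband.analysis, because both are pure reshapes and no information is lost:

    import torch
    from torchlibrosa.stft import STFT

    class _Demo(Base):
        def __init__(self):
            # Any STFT settings work; these mirror the models below.
            self.stft = STFT(n_fft=2048, hop_length=441, win_length=2048,
                             window='hann', center=True, pad_mode='reflect',
                             freeze_parameters=True)

    demo = _Demo()
    x = torch.randn(1, 44100)
    real, imag = demo.stft(x)
    mag, cos, sin = demo.spectrogram_phase(x, eps=1e-10)
    assert torch.allclose(mag * cos, real, atol=1e-5)   # mag * cos == real
    assert torch.allclose(mag * sin, imag, atol=1e-5)   # mag * sin == imag

    subband = Subband(subbands_num=4)
    y = torch.randn(2, 2, 100, 1024)   # (batch, channels, time_steps, freq_bins)
    sub = subband.analysis(y)          # (2, 8, 100, 256): subbands stacked on channels
    assert torch.equal(subband.synthesis(sub), y)  # exact round trip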
bytesep/models/resunet.py
ADDED
@@ -0,0 +1,516 @@
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchlibrosa.stft import ISTFT, STFT, magphase

from bytesep.models.pytorch_modules import Base, Subband, act, init_bn, init_layer


class ConvBlockRes(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, activation, momentum):
        r"""Residual block."""
        super(ConvBlockRes, self).__init__()

        self.activation = activation
        padding = [kernel_size[0] // 2, kernel_size[1] // 2]

        self.bn1 = nn.BatchNorm2d(in_channels, momentum=momentum)
        self.bn2 = nn.BatchNorm2d(out_channels, momentum=momentum)

        self.conv1 = nn.Conv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=(1, 1),
            dilation=(1, 1),
            padding=padding,
            bias=False,
        )

        self.conv2 = nn.Conv2d(
            in_channels=out_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=(1, 1),
            dilation=(1, 1),
            padding=padding,
            bias=False,
        )

        if in_channels != out_channels:
            self.shortcut = nn.Conv2d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=(1, 1),
                stride=(1, 1),
                padding=(0, 0),
            )
            self.is_shortcut = True
        else:
            self.is_shortcut = False

        self.init_weights()

    def init_weights(self):
        init_bn(self.bn1)
        init_bn(self.bn2)
        init_layer(self.conv1)
        init_layer(self.conv2)

        if self.is_shortcut:
            init_layer(self.shortcut)

    def forward(self, x):
        origin = x
        x = self.conv1(act(self.bn1(x), self.activation))
        x = self.conv2(act(self.bn2(x), self.activation))

        if self.is_shortcut:
            return self.shortcut(origin) + x
        else:
            return origin + x


class EncoderBlockRes4B(nn.Module):
    def __init__(
        self, in_channels, out_channels, kernel_size, downsample, activation, momentum
    ):
        r"""Encoder block, contains 8 convolutional layers."""
        super(EncoderBlockRes4B, self).__init__()

        self.conv_block1 = ConvBlockRes(
            in_channels, out_channels, kernel_size, activation, momentum
        )
        self.conv_block2 = ConvBlockRes(
            out_channels, out_channels, kernel_size, activation, momentum
        )
        self.conv_block3 = ConvBlockRes(
            out_channels, out_channels, kernel_size, activation, momentum
        )
        self.conv_block4 = ConvBlockRes(
            out_channels, out_channels, kernel_size, activation, momentum
        )
        self.downsample = downsample

    def forward(self, x):
        encoder = self.conv_block1(x)
        encoder = self.conv_block2(encoder)
        encoder = self.conv_block3(encoder)
        encoder = self.conv_block4(encoder)
        encoder_pool = F.avg_pool2d(encoder, kernel_size=self.downsample)
        return encoder_pool, encoder


class DecoderBlockRes4B(nn.Module):
    def __init__(
        self, in_channels, out_channels, kernel_size, upsample, activation, momentum
    ):
        r"""Decoder block, contains 1 transposed convolutional and 8 convolutional layers."""
        super(DecoderBlockRes4B, self).__init__()
        self.kernel_size = kernel_size
        self.stride = upsample
        self.activation = activation

        self.conv1 = torch.nn.ConvTranspose2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=self.stride,
            stride=self.stride,
            padding=(0, 0),
            bias=False,
            dilation=(1, 1),
        )

        self.bn1 = nn.BatchNorm2d(in_channels, momentum=momentum)
        self.conv_block2 = ConvBlockRes(
            out_channels * 2, out_channels, kernel_size, activation, momentum
        )
        self.conv_block3 = ConvBlockRes(
            out_channels, out_channels, kernel_size, activation, momentum
        )
        self.conv_block4 = ConvBlockRes(
            out_channels, out_channels, kernel_size, activation, momentum
        )
        self.conv_block5 = ConvBlockRes(
            out_channels, out_channels, kernel_size, activation, momentum
        )

        self.init_weights()

    def init_weights(self):
        init_bn(self.bn1)
        init_layer(self.conv1)

    def forward(self, input_tensor, concat_tensor):
        x = self.conv1(act(self.bn1(input_tensor), self.activation))
        x = torch.cat((x, concat_tensor), dim=1)
        x = self.conv_block2(x)
        x = self.conv_block3(x)
        x = self.conv_block4(x)
        x = self.conv_block5(x)
        return x


class ResUNet143_DecouplePlus(nn.Module, Base):
    def __init__(self, input_channels, target_sources_num):
        super(ResUNet143_DecouplePlus, self).__init__()

        self.input_channels = input_channels
        self.target_sources_num = target_sources_num

        window_size = 2048
        hop_size = 441
        center = True
        pad_mode = "reflect"
        window = "hann"
        activation = "relu"
        momentum = 0.01

        self.subbands_num = 4
        self.K = 4  # Outputs: |M|, cos∠M, sin∠M, Q.

        self.downsample_ratio = 2 ** 6  # This number equals 2^{#encoder_blocks}.

        self.stft = STFT(
            n_fft=window_size,
            hop_length=hop_size,
            win_length=window_size,
            window=window,
            center=center,
            pad_mode=pad_mode,
            freeze_parameters=True,
        )

        self.istft = ISTFT(
            n_fft=window_size,
            hop_length=hop_size,
            win_length=window_size,
            window=window,
            center=center,
            pad_mode=pad_mode,
            freeze_parameters=True,
        )

        self.bn0 = nn.BatchNorm2d(window_size // 2 + 1, momentum=momentum)

        self.subband = Subband(subbands_num=self.subbands_num)

        self.encoder_block1 = EncoderBlockRes4B(
            in_channels=input_channels * self.subbands_num,
            out_channels=32,
            kernel_size=(3, 3),
            downsample=(2, 2),
            activation=activation,
            momentum=momentum,
        )
        self.encoder_block2 = EncoderBlockRes4B(
            in_channels=32,
            out_channels=64,
            kernel_size=(3, 3),
            downsample=(2, 2),
            activation=activation,
            momentum=momentum,
        )
        self.encoder_block3 = EncoderBlockRes4B(
            in_channels=64,
            out_channels=128,
            kernel_size=(3, 3),
            downsample=(2, 2),
            activation=activation,
            momentum=momentum,
        )
        self.encoder_block4 = EncoderBlockRes4B(
            in_channels=128,
            out_channels=256,
            kernel_size=(3, 3),
            downsample=(2, 2),
            activation=activation,
            momentum=momentum,
        )
        self.encoder_block5 = EncoderBlockRes4B(
            in_channels=256,
            out_channels=384,
            kernel_size=(3, 3),
            downsample=(2, 2),
            activation=activation,
            momentum=momentum,
        )
        self.encoder_block6 = EncoderBlockRes4B(
            in_channels=384,
            out_channels=384,
            kernel_size=(3, 3),
            downsample=(1, 2),
            activation=activation,
            momentum=momentum,
        )
        self.conv_block7a = EncoderBlockRes4B(
            in_channels=384,
            out_channels=384,
            kernel_size=(3, 3),
            downsample=(1, 1),
            activation=activation,
            momentum=momentum,
        )
        self.conv_block7b = EncoderBlockRes4B(
            in_channels=384,
            out_channels=384,
            kernel_size=(3, 3),
            downsample=(1, 1),
            activation=activation,
            momentum=momentum,
        )
        self.conv_block7c = EncoderBlockRes4B(
            in_channels=384,
            out_channels=384,
            kernel_size=(3, 3),
            downsample=(1, 1),
            activation=activation,
            momentum=momentum,
        )
        self.conv_block7d = EncoderBlockRes4B(
            in_channels=384,
            out_channels=384,
            kernel_size=(3, 3),
            downsample=(1, 1),
            activation=activation,
            momentum=momentum,
        )
        self.decoder_block1 = DecoderBlockRes4B(
            in_channels=384,
            out_channels=384,
            kernel_size=(3, 3),
            upsample=(1, 2),
            activation=activation,
            momentum=momentum,
        )
        self.decoder_block2 = DecoderBlockRes4B(
            in_channels=384,
            out_channels=384,
            kernel_size=(3, 3),
            upsample=(2, 2),
            activation=activation,
            momentum=momentum,
        )
        self.decoder_block3 = DecoderBlockRes4B(
            in_channels=384,
            out_channels=256,
            kernel_size=(3, 3),
            upsample=(2, 2),
            activation=activation,
            momentum=momentum,
        )
        self.decoder_block4 = DecoderBlockRes4B(
            in_channels=256,
            out_channels=128,
            kernel_size=(3, 3),
            upsample=(2, 2),
            activation=activation,
            momentum=momentum,
        )
        self.decoder_block5 = DecoderBlockRes4B(
            in_channels=128,
            out_channels=64,
            kernel_size=(3, 3),
            upsample=(2, 2),
            activation=activation,
            momentum=momentum,
        )
        self.decoder_block6 = DecoderBlockRes4B(
            in_channels=64,
            out_channels=32,
            kernel_size=(3, 3),
            upsample=(2, 2),
            activation=activation,
            momentum=momentum,
        )

        self.after_conv_block1 = EncoderBlockRes4B(
            in_channels=32,
            out_channels=32,
            kernel_size=(3, 3),
            downsample=(1, 1),
            activation=activation,
            momentum=momentum,
        )

        self.after_conv2 = nn.Conv2d(
            in_channels=32,
            out_channels=input_channels
            * self.subbands_num
            * target_sources_num
            * self.K,
            kernel_size=(1, 1),
            stride=(1, 1),
            padding=(0, 0),
            bias=True,
        )

        self.init_weights()

    def init_weights(self):
        init_bn(self.bn0)
        init_layer(self.after_conv2)

    def feature_maps_to_wav(
        self,
        input_tensor: torch.Tensor,
        sp: torch.Tensor,
        sin_in: torch.Tensor,
        cos_in: torch.Tensor,
        audio_length: int,
    ) -> torch.Tensor:
        r"""Convert feature maps to a waveform.

        Args:
            input_tensor: (batch_size, target_sources_num * input_channels * self.K, time_steps, freq_bins)
            sp: (batch_size, input_channels, time_steps, freq_bins)
            sin_in: (batch_size, input_channels, time_steps, freq_bins)
            cos_in: (batch_size, input_channels, time_steps, freq_bins)
            audio_length: int

        Outputs:
            waveform: (batch_size, target_sources_num * input_channels, segment_samples)
        """
        batch_size, _, time_steps, freq_bins = input_tensor.shape

        x = input_tensor.reshape(
            batch_size,
            self.target_sources_num,
            self.input_channels,
            self.K,
            time_steps,
            freq_bins,
        )
        # x: (batch_size, target_sources_num, input_channels, K, time_steps, freq_bins)

        mask_mag = torch.sigmoid(x[:, :, :, 0, :, :])
        _mask_real = torch.tanh(x[:, :, :, 1, :, :])
        _mask_imag = torch.tanh(x[:, :, :, 2, :, :])
        linear_mag = x[:, :, :, 3, :, :]
        _, mask_cos, mask_sin = magphase(_mask_real, _mask_imag)
        # mask_cos, mask_sin: (batch_size, target_sources_num, input_channels, time_steps, freq_bins)

        # Y = |Y|cos∠Y + j|Y|sin∠Y
        #   = |Y|cos(∠X + ∠M) + j|Y|sin(∠X + ∠M)
        #   = |Y|(cos∠X cos∠M - sin∠X sin∠M) + j|Y|(sin∠X cos∠M + cos∠X sin∠M)
        out_cos = (
            cos_in[:, None, :, :, :] * mask_cos - sin_in[:, None, :, :, :] * mask_sin
        )
        out_sin = (
            sin_in[:, None, :, :, :] * mask_cos + cos_in[:, None, :, :, :] * mask_sin
        )
        # out_cos: (batch_size, target_sources_num, input_channels, time_steps, freq_bins)
        # out_sin: (batch_size, target_sources_num, input_channels, time_steps, freq_bins)

        # Calculate |Y|.
        out_mag = F.relu_(sp[:, None, :, :, :] * mask_mag + linear_mag)
        # out_mag: (batch_size, target_sources_num, input_channels, time_steps, freq_bins)

        # Calculate Y_{real} and Y_{imag} for ISTFT.
        out_real = out_mag * out_cos
        out_imag = out_mag * out_sin
        # out_real, out_imag: (batch_size, target_sources_num, input_channels, time_steps, freq_bins)

        # Reformat shape to (n, 1, time_steps, freq_bins) for ISTFT.
        shape = (
            batch_size * self.target_sources_num * self.input_channels,
            1,
            time_steps,
            freq_bins,
        )
        out_real = out_real.reshape(shape)
        out_imag = out_imag.reshape(shape)

        # ISTFT.
        x = self.istft(out_real, out_imag, audio_length)
        # (batch_size * target_sources_num * input_channels, segment_samples)

        # Reshape.
        waveform = x.reshape(
            batch_size, self.target_sources_num * self.input_channels, audio_length
        )
        # (batch_size, target_sources_num * input_channels, segment_samples)

        return waveform

    def forward(self, input_dict):
        r"""
        Args:
            input_dict: dict, e.g., {
                'waveform': (batch_size, input_channels, segment_samples),
            }

        Outputs:
            output_dict: dict, e.g., {
                'waveform': (batch_size, target_sources_num * input_channels, segment_samples),
            }
        """
        mixtures = input_dict['waveform']
        # (batch_size, input_channels, segment_samples)

        mag, cos_in, sin_in = self.wav_to_spectrogram_phase(mixtures)
        # mag, cos_in, sin_in: (batch_size, input_channels, time_steps, freq_bins)

        # Batch normalization over individual frequency bins.
        x = mag.transpose(1, 3)
        x = self.bn0(x)
        x = x.transpose(1, 3)
        # x: (batch_size, input_channels, time_steps, freq_bins)

        # Pad the spectrogram so the time axis is evenly divisible by the downsample ratio.
        origin_len = x.shape[2]
        pad_len = (
            int(np.ceil(x.shape[2] / self.downsample_ratio)) * self.downsample_ratio
            - origin_len
        )
        x = F.pad(x, pad=(0, 0, 0, pad_len))
        # x: (batch_size, input_channels, padded_time_steps, freq_bins)

        # Let the frequency bins be evenly divisible by 2, e.g., 1025 -> 1024.
        x = x[..., 0 : x.shape[-1] - 1]  # (bs, input_channels, T, F)

        x = self.subband.analysis(x)
        # (bs, input_channels * subbands_num, T, F'), where F' = F // subbands_num

        # UNet
        (x1_pool, x1) = self.encoder_block1(x)  # x1_pool: (bs, 32, T / 2, F / 2)
        (x2_pool, x2) = self.encoder_block2(x1_pool)  # x2_pool: (bs, 64, T / 4, F / 4)
        (x3_pool, x3) = self.encoder_block3(x2_pool)  # x3_pool: (bs, 128, T / 8, F / 8)
        (x4_pool, x4) = self.encoder_block4(x3_pool)  # x4_pool: (bs, 256, T / 16, F / 16)
        (x5_pool, x5) = self.encoder_block5(x4_pool)  # x5_pool: (bs, 384, T / 32, F / 32)
        (x6_pool, x6) = self.encoder_block6(x5_pool)  # x6_pool: (bs, 384, T / 32, F / 64)
        (x_center, _) = self.conv_block7a(x6_pool)  # (bs, 384, T / 32, F / 64)
        (x_center, _) = self.conv_block7b(x_center)  # (bs, 384, T / 32, F / 64)
        (x_center, _) = self.conv_block7c(x_center)  # (bs, 384, T / 32, F / 64)
        (x_center, _) = self.conv_block7d(x_center)  # (bs, 384, T / 32, F / 64)
        x7 = self.decoder_block1(x_center, x6)  # (bs, 384, T / 32, F / 32)
        x8 = self.decoder_block2(x7, x5)  # (bs, 384, T / 16, F / 16)
        x9 = self.decoder_block3(x8, x4)  # (bs, 256, T / 8, F / 8)
        x10 = self.decoder_block4(x9, x3)  # (bs, 128, T / 4, F / 4)
        x11 = self.decoder_block5(x10, x2)  # (bs, 64, T / 2, F / 2)
        x12 = self.decoder_block6(x11, x1)  # (bs, 32, T, F)
        (x, _) = self.after_conv_block1(x12)  # (bs, 32, T, F)

        x = self.after_conv2(x)
        # (batch_size, input_channels * subbands_num * targets_num * K, T, F')

        x = self.subband.synthesis(x)
        # (batch_size, input_channels * targets_num * K, T, F)

        # Recover shape.
        x = F.pad(x, pad=(0, 1))  # Pad frequency, e.g., 1024 -> 1025.
        x = x[:, :, 0:origin_len, :]  # (bs, feature_maps, time_steps, freq_bins)

        audio_length = mixtures.shape[2]

        separated_audio = self.feature_maps_to_wav(x, mag, sin_in, cos_in, audio_length)
        # separated_audio: (batch_size, target_sources_num * input_channels, segment_samples)

        output_dict = {'waveform': separated_audio}

        return output_dict
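
The commented derivation in feature_maps_to_wav is the angle-sum identity applied to the mixture phase ∠X and the estimated mask phase ∠M. A tiny numeric check of that identity (illustrative only, not repository code):

    import torch

    theta_x = torch.rand(100) * 6.2832   # mixture phase ∠X
    theta_m = torch.rand(100) * 6.2832   # estimated mask phase ∠M
    cos_in, sin_in = torch.cos(theta_x), torch.sin(theta_x)
    mask_cos, mask_sin = torch.cos(theta_m), torch.sin(theta_m)

    out_cos = cos_in * mask_cos - sin_in * mask_sin  # = cos(∠X + ∠M)
    out_sin = sin_in * mask_cos + cos_in * mask_sin  # = sin(∠X + ∠M)

    assert torch.allclose(out_cos, torch.cos(theta_x + theta_m), atol=1e-5)
    assert torch.allclose(out_sin, torch.sin(theta_x + theta_m), atol=1e-5)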
bytesep/models/resunet_ismir2021.py
ADDED
@@ -0,0 +1,534 @@
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from inplace_abn.abn import InPlaceABNSync
from torchlibrosa.stft import ISTFT, STFT, magphase

from bytesep.models.pytorch_modules import Base, init_bn, init_layer


class ConvBlockRes(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, activation, momentum):
        r"""Residual block."""
        super(ConvBlockRes, self).__init__()

        self.activation = activation
        padding = [kernel_size[0] // 2, kernel_size[1] // 2]

        # ABN is not used for bn1 because we found that using it degrades performance.
        self.bn1 = nn.BatchNorm2d(in_channels, momentum=momentum)

        self.abn2 = InPlaceABNSync(
            num_features=out_channels, momentum=momentum, activation='leaky_relu'
        )

        self.conv1 = nn.Conv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=(1, 1),
            dilation=(1, 1),
            padding=padding,
            bias=False,
        )

        self.conv2 = nn.Conv2d(
            in_channels=out_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=(1, 1),
            dilation=(1, 1),
            padding=padding,
            bias=False,
        )

        if in_channels != out_channels:
            self.shortcut = nn.Conv2d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=(1, 1),
                stride=(1, 1),
                padding=(0, 0),
            )
            self.is_shortcut = True
        else:
            self.is_shortcut = False

        self.init_weights()

    def init_weights(self):
        init_bn(self.bn1)
        init_layer(self.conv1)
        init_layer(self.conv2)

        if self.is_shortcut:
            init_layer(self.shortcut)

    def forward(self, x):
        origin = x
        x = self.conv1(F.leaky_relu_(self.bn1(x), negative_slope=0.01))
        x = self.conv2(self.abn2(x))

        if self.is_shortcut:
            return self.shortcut(origin) + x
        else:
            return origin + x


class EncoderBlockRes4B(nn.Module):
    def __init__(
        self, in_channels, out_channels, kernel_size, downsample, activation, momentum
    ):
        r"""Encoder block, contains 8 convolutional layers."""
        super(EncoderBlockRes4B, self).__init__()

        self.conv_block1 = ConvBlockRes(
            in_channels, out_channels, kernel_size, activation, momentum
        )
        self.conv_block2 = ConvBlockRes(
            out_channels, out_channels, kernel_size, activation, momentum
        )
        self.conv_block3 = ConvBlockRes(
            out_channels, out_channels, kernel_size, activation, momentum
        )
        self.conv_block4 = ConvBlockRes(
            out_channels, out_channels, kernel_size, activation, momentum
        )
        self.downsample = downsample

    def forward(self, x):
        encoder = self.conv_block1(x)
        encoder = self.conv_block2(encoder)
        encoder = self.conv_block3(encoder)
        encoder = self.conv_block4(encoder)
        encoder_pool = F.avg_pool2d(encoder, kernel_size=self.downsample)
        return encoder_pool, encoder


class DecoderBlockRes4B(nn.Module):
    def __init__(
        self, in_channels, out_channels, kernel_size, upsample, activation, momentum
    ):
        r"""Decoder block, contains 1 transposed convolutional and 8 convolutional layers."""
        super(DecoderBlockRes4B, self).__init__()
        self.kernel_size = kernel_size
        self.stride = upsample
        self.activation = activation

        self.conv1 = torch.nn.ConvTranspose2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=self.stride,
            stride=self.stride,
            padding=(0, 0),
            bias=False,
            dilation=(1, 1),
        )

        self.bn1 = nn.BatchNorm2d(in_channels, momentum=momentum)
        self.conv_block2 = ConvBlockRes(
            out_channels * 2, out_channels, kernel_size, activation, momentum
        )
        self.conv_block3 = ConvBlockRes(
            out_channels, out_channels, kernel_size, activation, momentum
        )
        self.conv_block4 = ConvBlockRes(
            out_channels, out_channels, kernel_size, activation, momentum
        )
        self.conv_block5 = ConvBlockRes(
            out_channels, out_channels, kernel_size, activation, momentum
        )

        self.init_weights()

    def init_weights(self):
        init_bn(self.bn1)
        init_layer(self.conv1)

    def forward(self, input_tensor, concat_tensor):
        x = self.conv1(F.relu_(self.bn1(input_tensor)))
        x = torch.cat((x, concat_tensor), dim=1)
        x = self.conv_block2(x)
        x = self.conv_block3(x)
        x = self.conv_block4(x)
        x = self.conv_block5(x)
        return x


class ResUNet143_DecouplePlusInplaceABN_ISMIR2021(nn.Module, Base):
    def __init__(self, input_channels, target_sources_num):
        super(ResUNet143_DecouplePlusInplaceABN_ISMIR2021, self).__init__()

        self.input_channels = input_channels
        self.target_sources_num = target_sources_num

        window_size = 2048
        hop_size = 441
        center = True
        pad_mode = 'reflect'
        window = 'hann'
        activation = 'leaky_relu'
        momentum = 0.01

        self.subbands_num = 1

        assert (
            self.subbands_num == 1
        ), "Using subbands_num > 1 on the spectrogram \
will sometimes lead to unexpected performance. Suggest using the \
subband method on the waveform instead."

        self.K = 4  # Outputs: |M|, cos∠M, sin∠M, Q.

        # Downsample rate along the time axis.
        self.time_downsample_ratio = 2 ** 5  # This number equals 2^{#encoder_blocks that downsample along time}.

        self.stft = STFT(
            n_fft=window_size,
            hop_length=hop_size,
            win_length=window_size,
            window=window,
            center=center,
            pad_mode=pad_mode,
            freeze_parameters=True,
        )

        self.istft = ISTFT(
            n_fft=window_size,
            hop_length=hop_size,
            win_length=window_size,
            window=window,
            center=center,
            pad_mode=pad_mode,
            freeze_parameters=True,
        )

        self.bn0 = nn.BatchNorm2d(window_size // 2 + 1, momentum=momentum)

        self.encoder_block1 = EncoderBlockRes4B(
            in_channels=input_channels * self.subbands_num,
            out_channels=32,
            kernel_size=(3, 3),
            downsample=(2, 2),
            activation=activation,
            momentum=momentum,
        )
        self.encoder_block2 = EncoderBlockRes4B(
            in_channels=32,
            out_channels=64,
            kernel_size=(3, 3),
            downsample=(2, 2),
            activation=activation,
            momentum=momentum,
        )
        self.encoder_block3 = EncoderBlockRes4B(
            in_channels=64,
            out_channels=128,
            kernel_size=(3, 3),
            downsample=(2, 2),
            activation=activation,
            momentum=momentum,
        )
        self.encoder_block4 = EncoderBlockRes4B(
            in_channels=128,
            out_channels=256,
            kernel_size=(3, 3),
            downsample=(2, 2),
            activation=activation,
            momentum=momentum,
        )
        self.encoder_block5 = EncoderBlockRes4B(
            in_channels=256,
            out_channels=384,
            kernel_size=(3, 3),
            downsample=(2, 2),
            activation=activation,
            momentum=momentum,
        )
        self.encoder_block6 = EncoderBlockRes4B(
            in_channels=384,
            out_channels=384,
            kernel_size=(3, 3),
            downsample=(1, 2),
            activation=activation,
            momentum=momentum,
        )
        self.conv_block7a = EncoderBlockRes4B(
            in_channels=384,
            out_channels=384,
            kernel_size=(3, 3),
            downsample=(1, 1),
            activation=activation,
            momentum=momentum,
        )
        self.conv_block7b = EncoderBlockRes4B(
            in_channels=384,
            out_channels=384,
            kernel_size=(3, 3),
            downsample=(1, 1),
            activation=activation,
            momentum=momentum,
        )
        self.conv_block7c = EncoderBlockRes4B(
            in_channels=384,
            out_channels=384,
            kernel_size=(3, 3),
            downsample=(1, 1),
            activation=activation,
            momentum=momentum,
        )
        self.conv_block7d = EncoderBlockRes4B(
            in_channels=384,
            out_channels=384,
            kernel_size=(3, 3),
            downsample=(1, 1),
            activation=activation,
            momentum=momentum,
        )
        self.decoder_block1 = DecoderBlockRes4B(
            in_channels=384,
            out_channels=384,
            kernel_size=(3, 3),
            upsample=(1, 2),
            activation=activation,
            momentum=momentum,
        )
        self.decoder_block2 = DecoderBlockRes4B(
            in_channels=384,
            out_channels=384,
            kernel_size=(3, 3),
            upsample=(2, 2),
            activation=activation,
            momentum=momentum,
        )
        self.decoder_block3 = DecoderBlockRes4B(
            in_channels=384,
            out_channels=256,
            kernel_size=(3, 3),
            upsample=(2, 2),
            activation=activation,
            momentum=momentum,
        )
        self.decoder_block4 = DecoderBlockRes4B(
            in_channels=256,
            out_channels=128,
            kernel_size=(3, 3),
            upsample=(2, 2),
            activation=activation,
            momentum=momentum,
        )
        self.decoder_block5 = DecoderBlockRes4B(
            in_channels=128,
            out_channels=64,
            kernel_size=(3, 3),
            upsample=(2, 2),
            activation=activation,
            momentum=momentum,
        )
        self.decoder_block6 = DecoderBlockRes4B(
            in_channels=64,
            out_channels=32,
            kernel_size=(3, 3),
            upsample=(2, 2),
            activation=activation,
            momentum=momentum,
        )

        self.after_conv_block1 = EncoderBlockRes4B(
            in_channels=32,
            out_channels=32,
            kernel_size=(3, 3),
            downsample=(1, 1),
            activation=activation,
            momentum=momentum,
        )

        self.after_conv2 = nn.Conv2d(
            in_channels=32,
            out_channels=target_sources_num
            * input_channels
            * self.K
            * self.subbands_num,
            kernel_size=(1, 1),
            stride=(1, 1),
            padding=(0, 0),
            bias=True,
        )

        self.init_weights()

    def init_weights(self):
        init_bn(self.bn0)
        init_layer(self.after_conv2)

    def feature_maps_to_wav(
        self,
        input_tensor: torch.Tensor,
        sp: torch.Tensor,
        sin_in: torch.Tensor,
        cos_in: torch.Tensor,
        audio_length: int,
    ) -> torch.Tensor:
        r"""Convert feature maps to a waveform.

        Args:
            input_tensor: (batch_size, target_sources_num * input_channels * self.K, time_steps, freq_bins)
            sp: (batch_size, input_channels, time_steps, freq_bins)
            sin_in: (batch_size, input_channels, time_steps, freq_bins)
            cos_in: (batch_size, input_channels, time_steps, freq_bins)
            audio_length: int

        Outputs:
            waveform: (batch_size, target_sources_num * input_channels, segment_samples)
        """
        batch_size, _, time_steps, freq_bins = input_tensor.shape

        x = input_tensor.reshape(
            batch_size,
            self.target_sources_num,
            self.input_channels,
            self.K,
            time_steps,
            freq_bins,
        )
        # x: (batch_size, target_sources_num, input_channels, K, time_steps, freq_bins)

        mask_mag = torch.sigmoid(x[:, :, :, 0, :, :])
        _mask_real = torch.tanh(x[:, :, :, 1, :, :])
        _mask_imag = torch.tanh(x[:, :, :, 2, :, :])
        _, mask_cos, mask_sin = magphase(_mask_real, _mask_imag)
        linear_mag = x[:, :, :, 3, :, :]
        # mask_cos, mask_sin: (batch_size, target_sources_num, input_channels, time_steps, freq_bins)

        # Y = |Y|cos∠Y + j|Y|sin∠Y
        #   = |Y|cos(∠X + ∠M) + j|Y|sin(∠X + ∠M)
        #   = |Y|(cos∠X cos∠M - sin∠X sin∠M) + j|Y|(sin∠X cos∠M + cos∠X sin∠M)
        out_cos = (
            cos_in[:, None, :, :, :] * mask_cos - sin_in[:, None, :, :, :] * mask_sin
        )
        out_sin = (
            sin_in[:, None, :, :, :] * mask_cos + cos_in[:, None, :, :, :] * mask_sin
        )
        # out_cos: (batch_size, target_sources_num, input_channels, time_steps, freq_bins)
        # out_sin: (batch_size, target_sources_num, input_channels, time_steps, freq_bins)

        # Calculate |Y|.
        out_mag = F.relu_(sp[:, None, :, :, :] * mask_mag + linear_mag)
        # out_mag: (batch_size, target_sources_num, input_channels, time_steps, freq_bins)

        # Calculate Y_{real} and Y_{imag} for ISTFT.
        out_real = out_mag * out_cos
        out_imag = out_mag * out_sin
        # out_real, out_imag: (batch_size, target_sources_num, input_channels, time_steps, freq_bins)

        # Reformat shape to (n, 1, time_steps, freq_bins) for ISTFT.
        shape = (
            batch_size * self.target_sources_num * self.input_channels,
            1,
            time_steps,
            freq_bins,
        )
        out_real = out_real.reshape(shape)
        out_imag = out_imag.reshape(shape)

        # ISTFT.
        x = self.istft(out_real, out_imag, audio_length)
        # (batch_size * target_sources_num * input_channels, segment_samples)

        # Reshape.
        waveform = x.reshape(
            batch_size, self.target_sources_num * self.input_channels, audio_length
        )
        # (batch_size, target_sources_num * input_channels, segment_samples)

        return waveform

    def forward(self, input_dict):
        r"""Forward data into the module.

        Args:
            input_dict: dict, e.g., {
                'waveform': (batch_size, input_channels, segment_samples),
                ...,
            }

        Outputs:
            output_dict: dict, e.g., {
                'waveform': (batch_size, input_channels, segment_samples),
                ...,
            }
        """
        mixtures = input_dict['waveform']
        # (batch_size, input_channels, segment_samples)

        mag, cos_in, sin_in = self.wav_to_spectrogram_phase(mixtures)
        # mag, cos_in, sin_in: (batch_size, input_channels, time_steps, freq_bins)

        # Batch normalization over individual frequency bins.
        x = mag.transpose(1, 3)
        x = self.bn0(x)
        x = x.transpose(1, 3)
        # x: (batch_size, input_channels, time_steps, freq_bins)

        # Pad the spectrogram so the time axis is evenly divisible by the downsample ratio.
        origin_len = x.shape[2]
        pad_len = (
            int(np.ceil(x.shape[2] / self.time_downsample_ratio))
            * self.time_downsample_ratio
            - origin_len
        )
        x = F.pad(x, pad=(0, 0, 0, pad_len))
        # (batch_size, channels, padded_time_steps, freq_bins)

        # Let the frequency bins be evenly divisible by 2, e.g., 1025 -> 1024.
        x = x[..., 0 : x.shape[-1] - 1]  # (bs, channels, T, F)

        if self.subbands_num > 1:
            # Unreachable while subbands_num == 1 is asserted above.
            x = self.subband.analysis(x)
            # (bs, input_channels * subbands_num, T, F'), where F' = F // subbands_num

        # UNet
        (x1_pool, x1) = self.encoder_block1(x)  # x1_pool: (bs, 32, T / 2, F / 2)
        (x2_pool, x2) = self.encoder_block2(x1_pool)  # x2_pool: (bs, 64, T / 4, F / 4)
        (x3_pool, x3) = self.encoder_block3(x2_pool)  # x3_pool: (bs, 128, T / 8, F / 8)
        (x4_pool, x4) = self.encoder_block4(x3_pool)  # x4_pool: (bs, 256, T / 16, F / 16)
        (x5_pool, x5) = self.encoder_block5(x4_pool)  # x5_pool: (bs, 384, T / 32, F / 32)
        (x6_pool, x6) = self.encoder_block6(x5_pool)  # x6_pool: (bs, 384, T / 32, F / 64)
        (x_center, _) = self.conv_block7a(x6_pool)  # (bs, 384, T / 32, F / 64)
        (x_center, _) = self.conv_block7b(x_center)  # (bs, 384, T / 32, F / 64)
        (x_center, _) = self.conv_block7c(x_center)  # (bs, 384, T / 32, F / 64)
        (x_center, _) = self.conv_block7d(x_center)  # (bs, 384, T / 32, F / 64)
        x7 = self.decoder_block1(x_center, x6)  # (bs, 384, T / 32, F / 32)
        x8 = self.decoder_block2(x7, x5)  # (bs, 384, T / 16, F / 16)
        x9 = self.decoder_block3(x8, x4)  # (bs, 256, T / 8, F / 8)
        x10 = self.decoder_block4(x9, x3)  # (bs, 128, T / 4, F / 4)
        x11 = self.decoder_block5(x10, x2)  # (bs, 64, T / 2, F / 2)
        x12 = self.decoder_block6(x11, x1)  # (bs, 32, T, F)
        (x, _) = self.after_conv_block1(x12)  # (bs, 32, T, F)

        x = self.after_conv2(x)
        # (batch_size, target_sources_num * input_channels * self.K * subbands_num, T, F')

        if self.subbands_num > 1:
            x = self.subband.synthesis(x)
            # (batch_size, target_sources_num * input_channels * self.K, T, F)

        # Recover shape.
        x = F.pad(x, pad=(0, 1))  # Pad frequency, e.g., 1024 -> 1025.

        x = x[:, :, 0:origin_len, :]
        # (batch_size, target_sources_num * input_channels * self.K, T, F)

        audio_length = mixtures.shape[2]

        separated_audio = self.feature_maps_to_wav(x, mag, sin_in, cos_in, audio_length)
        # separated_audio: (batch_size, target_sources_num * input_channels, segment_samples)

        output_dict = {'waveform': separated_audio}

        return output_dict
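
The bn0 transpose trick used in forward above normalizes each frequency bin independently by temporarily moving the frequency axis into BatchNorm2d's channel slot. The same pattern in isolation (a sketch, not repository code):

    import torch
    import torch.nn as nn

    freq_bins = 1025
    bn0 = nn.BatchNorm2d(freq_bins, momentum=0.01)

    mag = torch.randn(4, 2, 100, freq_bins)  # (batch, channels, time_steps, freq_bins)
    x = mag.transpose(1, 3)                  # (batch, freq_bins, time_steps, channels)
    x = bn0(x)                               # statistics computed per frequency bin
    x = x.transpose(1, 3)                    # back to (batch, channels, time_steps, freq_bins)
    assert x.shape == mag.shape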
bytesep/models/resunet_subbandtime.py
ADDED
@@ -0,0 +1,545 @@
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchlibrosa.stft import ISTFT, STFT, magphase

from bytesep.models.pytorch_modules import Base, init_bn, init_layer
from bytesep.models.subband_tools.pqmf import PQMF


class ConvBlockRes(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, activation, momentum):
        r"""Residual block."""
        super(ConvBlockRes, self).__init__()

        self.activation = activation
        padding = [kernel_size[0] // 2, kernel_size[1] // 2]

        self.bn1 = nn.BatchNorm2d(in_channels, momentum=momentum)
        self.bn2 = nn.BatchNorm2d(out_channels, momentum=momentum)

        self.conv1 = nn.Conv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=(1, 1),
            dilation=(1, 1),
            padding=padding,
            bias=False,
        )

        self.conv2 = nn.Conv2d(
            in_channels=out_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=(1, 1),
            dilation=(1, 1),
            padding=padding,
            bias=False,
        )

        if in_channels != out_channels:
            self.shortcut = nn.Conv2d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=(1, 1),
                stride=(1, 1),
                padding=(0, 0),
            )
            self.is_shortcut = True
        else:
            self.is_shortcut = False

        self.init_weights()

    def init_weights(self):
        init_bn(self.bn1)
        init_bn(self.bn2)
        init_layer(self.conv1)
        init_layer(self.conv2)

        if self.is_shortcut:
            init_layer(self.shortcut)

    def forward(self, x):
        origin = x
        x = self.conv1(F.leaky_relu_(self.bn1(x), negative_slope=0.01))
        x = self.conv2(F.leaky_relu_(self.bn2(x), negative_slope=0.01))

        if self.is_shortcut:
            return self.shortcut(origin) + x
        else:
            return origin + x


class EncoderBlockRes4B(nn.Module):
    def __init__(
        self, in_channels, out_channels, kernel_size, downsample, activation, momentum
    ):
        r"""Encoder block, contains 8 convolutional layers."""
        super(EncoderBlockRes4B, self).__init__()

        self.conv_block1 = ConvBlockRes(
            in_channels, out_channels, kernel_size, activation, momentum
        )
        self.conv_block2 = ConvBlockRes(
            out_channels, out_channels, kernel_size, activation, momentum
        )
        self.conv_block3 = ConvBlockRes(
            out_channels, out_channels, kernel_size, activation, momentum
        )
        self.conv_block4 = ConvBlockRes(
            out_channels, out_channels, kernel_size, activation, momentum
        )
        self.downsample = downsample

    def forward(self, x):
        encoder = self.conv_block1(x)
        encoder = self.conv_block2(encoder)
        encoder = self.conv_block3(encoder)
        encoder = self.conv_block4(encoder)
        encoder_pool = F.avg_pool2d(encoder, kernel_size=self.downsample)
        return encoder_pool, encoder


class DecoderBlockRes4B(nn.Module):
    def __init__(
        self, in_channels, out_channels, kernel_size, upsample, activation, momentum
    ):
        r"""Decoder block, contains 1 transposed convolutional and 8 convolutional layers."""
        super(DecoderBlockRes4B, self).__init__()
        self.kernel_size = kernel_size
        self.stride = upsample
        self.activation = activation

        self.conv1 = torch.nn.ConvTranspose2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=self.stride,
            stride=self.stride,
            padding=(0, 0),
            bias=False,
            dilation=(1, 1),
        )

        self.bn1 = nn.BatchNorm2d(in_channels, momentum=momentum)
        self.conv_block2 = ConvBlockRes(
            out_channels * 2, out_channels, kernel_size, activation, momentum
        )
        self.conv_block3 = ConvBlockRes(
            out_channels, out_channels, kernel_size, activation, momentum
        )
        self.conv_block4 = ConvBlockRes(
            out_channels, out_channels, kernel_size, activation, momentum
        )
        self.conv_block5 = ConvBlockRes(
            out_channels, out_channels, kernel_size, activation, momentum
|
138 |
+
)
|
139 |
+
|
140 |
+
self.init_weights()
|
141 |
+
|
142 |
+
def init_weights(self):
|
143 |
+
init_bn(self.bn1)
|
144 |
+
init_layer(self.conv1)
|
145 |
+
|
146 |
+
def forward(self, input_tensor, concat_tensor):
|
147 |
+
x = self.conv1(F.relu_(self.bn1(input_tensor)))
|
148 |
+
x = torch.cat((x, concat_tensor), dim=1)
|
149 |
+
x = self.conv_block2(x)
|
150 |
+
x = self.conv_block3(x)
|
151 |
+
x = self.conv_block4(x)
|
152 |
+
x = self.conv_block5(x)
|
153 |
+
return x
|
154 |
+
|
155 |
+
|
156 |
+
class ResUNet143_Subbandtime(nn.Module, Base):
|
157 |
+
def __init__(self, input_channels, target_sources_num):
|
158 |
+
super(ResUNet143_Subbandtime, self).__init__()
|
159 |
+
|
160 |
+
self.input_channels = input_channels
|
161 |
+
self.target_sources_num = target_sources_num
|
162 |
+
|
163 |
+
window_size = 512
|
164 |
+
hop_size = 110
|
165 |
+
center = True
|
166 |
+
pad_mode = "reflect"
|
167 |
+
window = "hann"
|
168 |
+
activation = "leaky_relu"
|
169 |
+
momentum = 0.01
|
170 |
+
|
171 |
+
self.subbands_num = 4
|
172 |
+
self.K = 4 # outputs: |M|, cos∠M, sin∠M, Q
|
173 |
+
|
174 |
+
self.downsample_ratio = 2 ** 5 # Equals 2^{#encoder blocks that downsample along time}
|
175 |
+
|
176 |
+
self.pqmf = PQMF(
|
177 |
+
N=self.subbands_num,
|
178 |
+
M=64,
|
179 |
+
project_root='bytesep/models/subband_tools/filters',
|
180 |
+
)
|
181 |
+
|
182 |
+
self.stft = STFT(
|
183 |
+
n_fft=window_size,
|
184 |
+
hop_length=hop_size,
|
185 |
+
win_length=window_size,
|
186 |
+
window=window,
|
187 |
+
center=center,
|
188 |
+
pad_mode=pad_mode,
|
189 |
+
freeze_parameters=True,
|
190 |
+
)
|
191 |
+
|
192 |
+
self.istft = ISTFT(
|
193 |
+
n_fft=window_size,
|
194 |
+
hop_length=hop_size,
|
195 |
+
win_length=window_size,
|
196 |
+
window=window,
|
197 |
+
center=center,
|
198 |
+
pad_mode=pad_mode,
|
199 |
+
freeze_parameters=True,
|
200 |
+
)
|
201 |
+
|
202 |
+
self.bn0 = nn.BatchNorm2d(window_size // 2 + 1, momentum=momentum)
|
203 |
+
|
204 |
+
self.encoder_block1 = EncoderBlockRes4B(
|
205 |
+
in_channels=input_channels * self.subbands_num,
|
206 |
+
out_channels=32,
|
207 |
+
kernel_size=(3, 3),
|
208 |
+
downsample=(2, 2),
|
209 |
+
activation=activation,
|
210 |
+
momentum=momentum,
|
211 |
+
)
|
212 |
+
self.encoder_block2 = EncoderBlockRes4B(
|
213 |
+
in_channels=32,
|
214 |
+
out_channels=64,
|
215 |
+
kernel_size=(3, 3),
|
216 |
+
downsample=(2, 2),
|
217 |
+
activation=activation,
|
218 |
+
momentum=momentum,
|
219 |
+
)
|
220 |
+
self.encoder_block3 = EncoderBlockRes4B(
|
221 |
+
in_channels=64,
|
222 |
+
out_channels=128,
|
223 |
+
kernel_size=(3, 3),
|
224 |
+
downsample=(2, 2),
|
225 |
+
activation=activation,
|
226 |
+
momentum=momentum,
|
227 |
+
)
|
228 |
+
self.encoder_block4 = EncoderBlockRes4B(
|
229 |
+
in_channels=128,
|
230 |
+
out_channels=256,
|
231 |
+
kernel_size=(3, 3),
|
232 |
+
downsample=(2, 2),
|
233 |
+
activation=activation,
|
234 |
+
momentum=momentum,
|
235 |
+
)
|
236 |
+
self.encoder_block5 = EncoderBlockRes4B(
|
237 |
+
in_channels=256,
|
238 |
+
out_channels=384,
|
239 |
+
kernel_size=(3, 3),
|
240 |
+
downsample=(2, 2),
|
241 |
+
activation=activation,
|
242 |
+
momentum=momentum,
|
243 |
+
)
|
244 |
+
self.encoder_block6 = EncoderBlockRes4B(
|
245 |
+
in_channels=384,
|
246 |
+
out_channels=384,
|
247 |
+
kernel_size=(3, 3),
|
248 |
+
downsample=(1, 2),
|
249 |
+
activation=activation,
|
250 |
+
momentum=momentum,
|
251 |
+
)
|
252 |
+
self.conv_block7a = EncoderBlockRes4B(
|
253 |
+
in_channels=384,
|
254 |
+
out_channels=384,
|
255 |
+
kernel_size=(3, 3),
|
256 |
+
downsample=(1, 1),
|
257 |
+
activation=activation,
|
258 |
+
momentum=momentum,
|
259 |
+
)
|
260 |
+
self.conv_block7b = EncoderBlockRes4B(
|
261 |
+
in_channels=384,
|
262 |
+
out_channels=384,
|
263 |
+
kernel_size=(3, 3),
|
264 |
+
downsample=(1, 1),
|
265 |
+
activation=activation,
|
266 |
+
momentum=momentum,
|
267 |
+
)
|
268 |
+
self.conv_block7c = EncoderBlockRes4B(
|
269 |
+
in_channels=384,
|
270 |
+
out_channels=384,
|
271 |
+
kernel_size=(3, 3),
|
272 |
+
downsample=(1, 1),
|
273 |
+
activation=activation,
|
274 |
+
momentum=momentum,
|
275 |
+
)
|
276 |
+
self.conv_block7d = EncoderBlockRes4B(
|
277 |
+
in_channels=384,
|
278 |
+
out_channels=384,
|
279 |
+
kernel_size=(3, 3),
|
280 |
+
downsample=(1, 1),
|
281 |
+
activation=activation,
|
282 |
+
momentum=momentum,
|
283 |
+
)
|
284 |
+
self.decoder_block1 = DecoderBlockRes4B(
|
285 |
+
in_channels=384,
|
286 |
+
out_channels=384,
|
287 |
+
kernel_size=(3, 3),
|
288 |
+
upsample=(1, 2),
|
289 |
+
activation=activation,
|
290 |
+
momentum=momentum,
|
291 |
+
)
|
292 |
+
self.decoder_block2 = DecoderBlockRes4B(
|
293 |
+
in_channels=384,
|
294 |
+
out_channels=384,
|
295 |
+
kernel_size=(3, 3),
|
296 |
+
upsample=(2, 2),
|
297 |
+
activation=activation,
|
298 |
+
momentum=momentum,
|
299 |
+
)
|
300 |
+
self.decoder_block3 = DecoderBlockRes4B(
|
301 |
+
in_channels=384,
|
302 |
+
out_channels=256,
|
303 |
+
kernel_size=(3, 3),
|
304 |
+
upsample=(2, 2),
|
305 |
+
activation=activation,
|
306 |
+
momentum=momentum,
|
307 |
+
)
|
308 |
+
self.decoder_block4 = DecoderBlockRes4B(
|
309 |
+
in_channels=256,
|
310 |
+
out_channels=128,
|
311 |
+
kernel_size=(3, 3),
|
312 |
+
upsample=(2, 2),
|
313 |
+
activation=activation,
|
314 |
+
momentum=momentum,
|
315 |
+
)
|
316 |
+
self.decoder_block5 = DecoderBlockRes4B(
|
317 |
+
in_channels=128,
|
318 |
+
out_channels=64,
|
319 |
+
kernel_size=(3, 3),
|
320 |
+
upsample=(2, 2),
|
321 |
+
activation=activation,
|
322 |
+
momentum=momentum,
|
323 |
+
)
|
324 |
+
self.decoder_block6 = DecoderBlockRes4B(
|
325 |
+
in_channels=64,
|
326 |
+
out_channels=32,
|
327 |
+
kernel_size=(3, 3),
|
328 |
+
upsample=(2, 2),
|
329 |
+
activation=activation,
|
330 |
+
momentum=momentum,
|
331 |
+
)
|
332 |
+
|
333 |
+
self.after_conv_block1 = EncoderBlockRes4B(
|
334 |
+
in_channels=32,
|
335 |
+
out_channels=32,
|
336 |
+
kernel_size=(3, 3),
|
337 |
+
downsample=(1, 1),
|
338 |
+
activation=activation,
|
339 |
+
momentum=momentum,
|
340 |
+
)
|
341 |
+
|
342 |
+
self.after_conv2 = nn.Conv2d(
|
343 |
+
in_channels=32,
|
344 |
+
out_channels=target_sources_num
|
345 |
+
* input_channels
|
346 |
+
* self.K
|
347 |
+
* self.subbands_num,
|
348 |
+
kernel_size=(1, 1),
|
349 |
+
stride=(1, 1),
|
350 |
+
padding=(0, 0),
|
351 |
+
bias=True,
|
352 |
+
)
|
353 |
+
|
354 |
+
self.init_weights()
|
355 |
+
|
356 |
+
def init_weights(self):
|
357 |
+
init_bn(self.bn0)
|
358 |
+
init_layer(self.after_conv2)
|
359 |
+
|
360 |
+
def feature_maps_to_wav(
|
361 |
+
self,
|
362 |
+
input_tensor: torch.Tensor,
|
363 |
+
sp: torch.Tensor,
|
364 |
+
sin_in: torch.Tensor,
|
365 |
+
cos_in: torch.Tensor,
|
366 |
+
audio_length: int,
|
367 |
+
) -> torch.Tensor:
|
368 |
+
r"""Convert feature maps to waveform.
|
369 |
+
|
370 |
+
Args:
|
371 |
+
input_tensor: (batch_size, target_sources_num * input_channels * self.K, time_steps, freq_bins)
|
372 |
+
sp: (batch_size, target_sources_num * input_channels, time_steps, freq_bins)
|
373 |
+
sin_in: (batch_size, target_sources_num * input_channels, time_steps, freq_bins)
|
374 |
+
cos_in: (batch_size, target_sources_num * input_channels, time_steps, freq_bins)
|
375 |
+
|
376 |
+
Outputs:
|
377 |
+
waveform: (batch_size, target_sources_num * input_channels, segment_samples)
|
378 |
+
"""
|
379 |
+
batch_size, _, time_steps, freq_bins = input_tensor.shape
|
380 |
+
|
381 |
+
x = input_tensor.reshape(
|
382 |
+
batch_size,
|
383 |
+
self.target_sources_num,
|
384 |
+
self.input_channels,
|
385 |
+
self.K,
|
386 |
+
time_steps,
|
387 |
+
freq_bins,
|
388 |
+
)
|
389 |
+
# x: (batch_size, target_sources_num, input_channels, K, time_steps, freq_bins)
|
390 |
+
|
391 |
+
mask_mag = torch.sigmoid(x[:, :, :, 0, :, :])
|
392 |
+
_mask_real = torch.tanh(x[:, :, :, 1, :, :])
|
393 |
+
_mask_imag = torch.tanh(x[:, :, :, 2, :, :])
|
394 |
+
linear_mag = torch.tanh(x[:, :, :, 3, :, :])
|
395 |
+
_, mask_cos, mask_sin = magphase(_mask_real, _mask_imag)
|
396 |
+
# mask_cos, mask_sin: (batch_size, target_sources_num, input_channels, time_steps, freq_bins)
|
397 |
+
|
398 |
+
# Y = |Y|cos∠Y + j|Y|sin∠Y
|
399 |
+
# = |Y|cos(∠X + ∠M) + j|Y|sin(∠X + ∠M)
|
400 |
+
# = |Y|(cos∠X cos∠M - sin∠X sin∠M) + j|Y|(sin∠X cos∠M + cos∠X sin∠M)
|
401 |
+
out_cos = (
|
402 |
+
cos_in[:, None, :, :, :] * mask_cos - sin_in[:, None, :, :, :] * mask_sin
|
403 |
+
)
|
404 |
+
out_sin = (
|
405 |
+
sin_in[:, None, :, :, :] * mask_cos + cos_in[:, None, :, :, :] * mask_sin
|
406 |
+
)
|
407 |
+
# out_cos: (batch_size, target_sources_num, input_channels, time_steps, freq_bins)
|
408 |
+
# out_sin: (batch_size, target_sources_num, input_channels, time_steps, freq_bins)
|
409 |
+
|
410 |
+
# Calculate |Y|.
|
411 |
+
out_mag = F.relu_(sp[:, None, :, :, :] * mask_mag + linear_mag)
|
412 |
+
# out_mag: (batch_size, target_sources_num, input_channels, time_steps, freq_bins)
|
413 |
+
|
414 |
+
# Calculate Y_{real} and Y_{imag} for ISTFT.
|
415 |
+
out_real = out_mag * out_cos
|
416 |
+
out_imag = out_mag * out_sin
|
417 |
+
# out_real, out_imag: (batch_size, target_sources_num, input_channels, time_steps, freq_bins)
|
418 |
+
|
419 |
+
# Reformat shape to (n, 1, time_steps, freq_bins) for ISTFT.
|
420 |
+
shape = (
|
421 |
+
batch_size * self.target_sources_num * self.input_channels,
|
422 |
+
1,
|
423 |
+
time_steps,
|
424 |
+
freq_bins,
|
425 |
+
)
|
426 |
+
out_real = out_real.reshape(shape)
|
427 |
+
out_imag = out_imag.reshape(shape)
|
428 |
+
|
429 |
+
# ISTFT.
|
430 |
+
x = self.istft(out_real, out_imag, audio_length)
|
431 |
+
# (batch_size * target_sources_num * input_channels, segment_samples)
|
432 |
+
|
433 |
+
# Reshape.
|
434 |
+
waveform = x.reshape(
|
435 |
+
batch_size, self.target_sources_num * self.input_channels, audio_length
|
436 |
+
)
|
437 |
+
# (batch_size, target_sources_num * input_channels, segment_samples)
|
438 |
+
|
439 |
+
return waveform
|
440 |
+
|
441 |
+
def forward(self, input_dict):
|
442 |
+
r"""Forward data into the module.
|
443 |
+
|
444 |
+
Args:
|
445 |
+
input_dict: dict, e.g., {
|
446 |
+
waveform: (batch_size, input_channels, segment_samples),
|
447 |
+
...,
|
448 |
+
}
|
449 |
+
|
450 |
+
Outputs:
|
451 |
+
output_dict: dict, e.g., {
|
452 |
+
'waveform': (batch_size, input_channels, segment_samples),
|
453 |
+
...,
|
454 |
+
}
|
455 |
+
"""
|
456 |
+
mixtures = input_dict['waveform']
|
457 |
+
# (batch_size, input_channels, segment_samples)
|
458 |
+
|
459 |
+
subband_x = self.pqmf.analysis(mixtures)
|
460 |
+
# subband_x: (batch_size, input_channels * subbands_num, segment_samples)
|
461 |
+
|
462 |
+
mag, cos_in, sin_in = self.wav_to_spectrogram_phase(subband_x)
|
463 |
+
# mag, cos_in, sin_in: (batch_size, input_channels * subbands_num, time_steps, freq_bins)
|
464 |
+
|
465 |
+
# Batch normalize on individual frequency bins.
|
466 |
+
x = mag.transpose(1, 3)
|
467 |
+
x = self.bn0(x)
|
468 |
+
x = x.transpose(1, 3)
|
469 |
+
# (batch_size, input_channels * subbands_num, time_steps, freq_bins)
|
470 |
+
|
471 |
+
# Pad spectrogram to be evenly divided by downsample ratio.
|
472 |
+
origin_len = x.shape[2]
|
473 |
+
pad_len = (
|
474 |
+
int(np.ceil(x.shape[2] / self.downsample_ratio)) * self.downsample_ratio
|
475 |
+
- origin_len
|
476 |
+
)
|
477 |
+
x = F.pad(x, pad=(0, 0, 0, pad_len))
|
478 |
+
# x: (batch_size, input_channels * subbands_num, padded_time_steps, freq_bins)
|
479 |
+
|
480 |
+
# Let frequency bins be evenly divided by 2, e.g., 257 -> 256
|
481 |
+
x = x[..., 0 : x.shape[-1] - 1] # (bs, input_channels, T, F)
|
482 |
+
# x: (batch_size, input_channels * subbands_num, padded_time_steps, freq_bins)
|
483 |
+
|
484 |
+
# UNet
|
485 |
+
(x1_pool, x1) = self.encoder_block1(x) # x1_pool: (bs, 32, T / 2, F / 2)
|
486 |
+
(x2_pool, x2) = self.encoder_block2(x1_pool) # x2_pool: (bs, 64, T / 4, F / 4)
|
487 |
+
(x3_pool, x3) = self.encoder_block3(x2_pool) # x3_pool: (bs, 128, T / 8, F / 8)
|
488 |
+
(x4_pool, x4) = self.encoder_block4(
|
489 |
+
x3_pool
|
490 |
+
) # x4_pool: (bs, 256, T / 16, F / 16)
|
491 |
+
(x5_pool, x5) = self.encoder_block5(
|
492 |
+
x4_pool
|
493 |
+
) # x5_pool: (bs, 384, T / 32, F / 32)
|
494 |
+
(x6_pool, x6) = self.encoder_block6(
|
495 |
+
x5_pool
|
496 |
+
) # x6_pool: (bs, 384, T / 32, F / 64)
|
497 |
+
(x_center, _) = self.conv_block7a(x6_pool) # (bs, 384, T / 32, F / 64)
|
498 |
+
(x_center, _) = self.conv_block7b(x_center) # (bs, 384, T / 32, F / 64)
|
499 |
+
(x_center, _) = self.conv_block7c(x_center) # (bs, 384, T / 32, F / 64)
|
500 |
+
(x_center, _) = self.conv_block7d(x_center) # (bs, 384, T / 32, F / 64)
|
501 |
+
x7 = self.decoder_block1(x_center, x6) # (bs, 384, T / 32, F / 32)
|
502 |
+
x8 = self.decoder_block2(x7, x5) # (bs, 384, T / 16, F / 16)
|
503 |
+
x9 = self.decoder_block3(x8, x4) # (bs, 256, T / 8, F / 8)
|
504 |
+
x10 = self.decoder_block4(x9, x3) # (bs, 128, T / 4, F / 4)
|
505 |
+
x11 = self.decoder_block5(x10, x2) # (bs, 64, T / 2, F / 2)
|
506 |
+
x12 = self.decoder_block6(x11, x1) # (bs, 32, T, F)
|
507 |
+
(x, _) = self.after_conv_block1(x12) # (bs, 32, T, F)
|
508 |
+
|
509 |
+
x = self.after_conv2(x)
|
510 |
+
# (batch_size, subbands_num * target_sources_num * input_channels * self.K, T, F')
|
511 |
+
|
512 |
+
# Recover shape
|
513 |
+
x = F.pad(x, pad=(0, 1)) # Pad frequency, e.g., 256 -> 257.
|
514 |
+
|
515 |
+
x = x[:, :, 0:origin_len, :]
|
516 |
+
# (batch_size, subbands_num * target_sources_num * input_channels * self.K, T, F')
|
517 |
+
|
518 |
+
audio_length = subband_x.shape[2]
|
519 |
+
|
520 |
+
# Recover each subband spectrogram to a subband waveform, then synthesize
|
521 |
+
# the subband waveforms into the full-band waveform.
|
522 |
+
C1 = x.shape[1] // self.subbands_num
|
523 |
+
C2 = mag.shape[1] // self.subbands_num
|
524 |
+
|
525 |
+
separated_subband_audio = torch.cat(
|
526 |
+
[
|
527 |
+
self.feature_maps_to_wav(
|
528 |
+
input_tensor=x[:, j * C1 : (j + 1) * C1, :, :],
|
529 |
+
sp=mag[:, j * C2 : (j + 1) * C2, :, :],
|
530 |
+
sin_in=sin_in[:, j * C2 : (j + 1) * C2, :, :],
|
531 |
+
cos_in=cos_in[:, j * C2 : (j + 1) * C2, :, :],
|
532 |
+
audio_length=audio_length,
|
533 |
+
)
|
534 |
+
for j in range(self.subbands_num)
|
535 |
+
],
|
536 |
+
dim=1,
|
537 |
+
)
|
538 |
+
# (batch_size, subbands_num * target_sources_num * input_channels, segment_samples)
|
539 |
+
|
540 |
+
separated_audio = self.pqmf.synthesis(separated_subband_audio)
|
541 |
+
# (batch_size, input_channels, segment_samples)
|
542 |
+
|
543 |
+
output_dict = {'waveform': separated_audio}
|
544 |
+
|
545 |
+
return output_dict
|
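A minimal usage sketch for the subband ResUNet added above, assuming the process runs from the repository root so the relative PQMF filter path ('bytesep/models/subband_tools/filters') resolves (not part of the diff):

    import torch
    from bytesep.models.resunet_subbandtime import ResUNet143_Subbandtime

    model = ResUNet143_Subbandtime(input_channels=2, target_sources_num=1)
    model.eval()
    waveform = torch.randn(1, 2, 44100 * 3)  # 3 s of stereo audio at 44.1 kHz
    with torch.no_grad():
        out = model({'waveform': waveform})
    print(out['waveform'].shape)  # (1, target_sources_num * input_channels, 132300)

The PQMF analysis pads the input by 64 samples before splitting into 4 subbands, and the synthesis step removes that padding, so the output length matches the input length.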
bytesep/models/subband_tools/__init__.py
ADDED
File without changes
|
bytesep/models/subband_tools/fDomainHelper.py
ADDED
@@ -0,0 +1,255 @@
1 |
+
from torchlibrosa.stft import STFT, ISTFT, magphase
|
2 |
+
import torch
|
3 |
+
import torch.nn as nn
|
4 |
+
import numpy as np
|
5 |
+
from bytesep.models.subband_tools.pqmf import PQMF
|
6 |
+
|
7 |
+
|
8 |
+
class FDomainHelper(nn.Module):
|
9 |
+
def __init__(
|
10 |
+
self,
|
11 |
+
window_size=2048,
|
12 |
+
hop_size=441,
|
13 |
+
center=True,
|
14 |
+
pad_mode='reflect',
|
15 |
+
window='hann',
|
16 |
+
freeze_parameters=True,
|
17 |
+
subband=None,
|
18 |
+
root="/Users/admin/Documents/projects/",
|
19 |
+
):
|
20 |
+
super(FDomainHelper, self).__init__()
|
21 |
+
self.subband = subband
|
22 |
+
if self.subband is None:
|
23 |
+
self.stft = STFT(
|
24 |
+
n_fft=window_size,
|
25 |
+
hop_length=hop_size,
|
26 |
+
win_length=window_size,
|
27 |
+
window=window,
|
28 |
+
center=center,
|
29 |
+
pad_mode=pad_mode,
|
30 |
+
freeze_parameters=freeze_parameters,
|
31 |
+
)
|
32 |
+
|
33 |
+
self.istft = ISTFT(
|
34 |
+
n_fft=window_size,
|
35 |
+
hop_length=hop_size,
|
36 |
+
win_length=window_size,
|
37 |
+
window=window,
|
38 |
+
center=center,
|
39 |
+
pad_mode=pad_mode,
|
40 |
+
freeze_parameters=freeze_parameters,
|
41 |
+
)
|
42 |
+
else:
|
43 |
+
self.stft = STFT(
|
44 |
+
n_fft=window_size // self.subband,
|
45 |
+
hop_length=hop_size // self.subband,
|
46 |
+
win_length=window_size // self.subband,
|
47 |
+
window=window,
|
48 |
+
center=center,
|
49 |
+
pad_mode=pad_mode,
|
50 |
+
freeze_parameters=freeze_parameters,
|
51 |
+
)
|
52 |
+
|
53 |
+
self.istft = ISTFT(
|
54 |
+
n_fft=window_size // self.subband,
|
55 |
+
hop_length=hop_size // self.subband,
|
56 |
+
win_length=window_size // self.subband,
|
57 |
+
window=window,
|
58 |
+
center=center,
|
59 |
+
pad_mode=pad_mode,
|
60 |
+
freeze_parameters=freeze_parameters,
|
61 |
+
)
|
62 |
+
|
63 |
+
if subband is not None and root is not None:
|
64 |
+
self.qmf = PQMF(subband, 64, root)
|
65 |
+
|
66 |
+
def complex_spectrogram(self, input, eps=0.0):
|
67 |
+
# [batchsize, samples]
|
68 |
+
# return [batchsize, 2, t-steps, f-bins]
|
69 |
+
real, imag = self.stft(input)
|
70 |
+
return torch.cat([real, imag], dim=1)
|
71 |
+
|
72 |
+
def reverse_complex_spectrogram(self, input, eps=0.0, length=None):
|
73 |
+
# [batchsize, 2[real,imag], t-steps, f-bins]
|
74 |
+
wav = self.istft(input[:, 0:1, ...], input[:, 1:2, ...], length=length)
|
75 |
+
return wav
|
76 |
+
|
77 |
+
def spectrogram(self, input, eps=0.0):
|
78 |
+
(real, imag) = self.stft(input.float())
|
79 |
+
return torch.clamp(real ** 2 + imag ** 2, eps, np.inf) ** 0.5
|
80 |
+
|
81 |
+
def spectrogram_phase(self, input, eps=0.0):
|
82 |
+
(real, imag) = self.stft(input.float())
|
83 |
+
mag = torch.clamp(real ** 2 + imag ** 2, eps, np.inf) ** 0.5
|
84 |
+
cos = real / mag
|
85 |
+
sin = imag / mag
|
86 |
+
return mag, cos, sin
|
87 |
+
|
88 |
+
def wav_to_spectrogram_phase(self, input, eps=1e-8):
|
89 |
+
"""Waveform to spectrogram.
|
90 |
+
|
91 |
+
Args:
|
92 |
+
input: (batch_size, channels_num, segment_samples)
|
93 |
+
|
94 |
+
Outputs:
|
95 |
+
output: (batch_size, channels_num, time_steps, freq_bins)
|
96 |
+
"""
|
97 |
+
sp_list = []
|
98 |
+
cos_list = []
|
99 |
+
sin_list = []
|
100 |
+
channels_num = input.shape[1]
|
101 |
+
for channel in range(channels_num):
|
102 |
+
mag, cos, sin = self.spectrogram_phase(input[:, channel, :], eps=eps)
|
103 |
+
sp_list.append(mag)
|
104 |
+
cos_list.append(cos)
|
105 |
+
sin_list.append(sin)
|
106 |
+
|
107 |
+
sps = torch.cat(sp_list, dim=1)
|
108 |
+
coss = torch.cat(cos_list, dim=1)
|
109 |
+
sins = torch.cat(sin_list, dim=1)
|
110 |
+
return sps, coss, sins
|
111 |
+
|
112 |
+
def spectrogram_phase_to_wav(self, sps, coss, sins, length):
|
113 |
+
channels_num = sps.size()[1]
|
114 |
+
res = []
|
115 |
+
for i in range(channels_num):
|
116 |
+
res.append(
|
117 |
+
self.istft(
|
118 |
+
sps[:, i : i + 1, ...] * coss[:, i : i + 1, ...],
|
119 |
+
sps[:, i : i + 1, ...] * sins[:, i : i + 1, ...],
|
120 |
+
length,
|
121 |
+
)
|
122 |
+
)
|
123 |
+
res[-1] = res[-1].unsqueeze(1)
|
124 |
+
return torch.cat(res, dim=1)
|
125 |
+
|
126 |
+
def wav_to_spectrogram(self, input, eps=1e-8):
|
127 |
+
"""Waveform to spectrogram.
|
128 |
+
|
129 |
+
Args:
|
130 |
+
input: (batch_size,channels_num, segment_samples)
|
131 |
+
|
132 |
+
Outputs:
|
133 |
+
output: (batch_size, channels_num, time_steps, freq_bins)
|
134 |
+
"""
|
135 |
+
sp_list = []
|
136 |
+
channels_num = input.shape[1]
|
137 |
+
for channel in range(channels_num):
|
138 |
+
sp_list.append(self.spectrogram(input[:, channel, :], eps=eps))
|
139 |
+
output = torch.cat(sp_list, dim=1)
|
140 |
+
return output
|
141 |
+
|
142 |
+
def spectrogram_to_wav(self, input, spectrogram, length=None):
|
143 |
+
"""Spectrogram to waveform.
|
144 |
+
Args:
|
145 |
+
input: (batch_size, segment_samples, channels_num)
|
146 |
+
spectrogram: (batch_size, channels_num, time_steps, freq_bins)
|
147 |
+
|
148 |
+
Outputs:
|
149 |
+
output: (batch_size, segment_samples, channels_num)
|
150 |
+
"""
|
151 |
+
channels_num = input.shape[1]
|
152 |
+
wav_list = []
|
153 |
+
for channel in range(channels_num):
|
154 |
+
(real, imag) = self.stft(input[:, channel, :])
|
155 |
+
(_, cos, sin) = magphase(real, imag)
|
156 |
+
wav_list.append(
|
157 |
+
self.istft(
|
158 |
+
spectrogram[:, channel : channel + 1, :, :] * cos,
|
159 |
+
spectrogram[:, channel : channel + 1, :, :] * sin,
|
160 |
+
length,
|
161 |
+
)
|
162 |
+
)
|
163 |
+
|
164 |
+
output = torch.stack(wav_list, dim=1)
|
165 |
+
return output
|
166 |
+
|
167 |
+
# TODO: the following code is not bug-free!
|
168 |
+
def wav_to_complex_spectrogram(self, input, eps=0.0):
|
169 |
+
# [batchsize , channels, samples]
|
170 |
+
# [batchsize, 2[real,imag]*channels, t-steps, f-bins]
|
171 |
+
res = []
|
172 |
+
channels_num = input.shape[1]
|
173 |
+
for channel in range(channels_num):
|
174 |
+
res.append(self.complex_spectrogram(input[:, channel, :], eps=eps))
|
175 |
+
return torch.cat(res, dim=1)
|
176 |
+
|
177 |
+
def complex_spectrogram_to_wav(self, input, eps=0.0, length=None):
|
178 |
+
# [batchsize, 2[real,imag]*channels, t-steps, f-bins]
|
179 |
+
# return [batchsize, channels, samples]
|
180 |
+
channels = input.size()[1] // 2
|
181 |
+
wavs = []
|
182 |
+
for i in range(channels):
|
183 |
+
wavs.append(
|
184 |
+
self.reverse_complex_spectrogram(
|
185 |
+
input[:, 2 * i : 2 * i + 2, ...], eps=eps, length=length
|
186 |
+
)
|
187 |
+
)
|
188 |
+
wavs[-1] = wavs[-1].unsqueeze(1)
|
189 |
+
return torch.cat(wavs, dim=1)
|
190 |
+
|
191 |
+
def wav_to_complex_subband_spectrogram(self, input, eps=0.0):
|
192 |
+
# [batchsize, channels, samples]
|
193 |
+
# [batchsize, 2[real,imag]*subband*channels, t-steps, f-bins]
|
194 |
+
subwav = self.qmf.analysis(input) # [batchsize, subband*channels, samples]
|
195 |
+
subspec = self.wav_to_complex_spectrogram(subwav)
|
196 |
+
return subspec
|
197 |
+
|
198 |
+
def complex_subband_spectrogram_to_wav(self, input, eps=0.0):
|
199 |
+
# [batchsize, 2[real,imag]*subband*channels, t-steps, f-bins]
|
200 |
+
# [batchsize, channels, samples]
|
201 |
+
subwav = self.complex_spectrogram_to_wav(input)
|
202 |
+
data = self.qmf.synthesis(subwav)
|
203 |
+
return data
|
204 |
+
|
205 |
+
def wav_to_mag_phase_subband_spectrogram(self, input, eps=1e-8):
|
206 |
+
"""
|
207 |
+
:param input:
|
208 |
+
:param eps:
|
209 |
+
:return:
|
210 |
+
loss = torch.nn.L1Loss()
|
211 |
+
model = FDomainHelper(subband=4)
|
212 |
+
data = torch.randn((3,1, 44100*3))
|
213 |
+
|
214 |
+
sps, coss, sins = model.wav_to_mag_phase_subband_spectrogram(data)
|
215 |
+
wav = model.mag_phase_subband_spectrogram_to_wav(sps,coss,sins,44100*3//4)
|
216 |
+
|
217 |
+
print(loss(data,wav))
|
218 |
+
print(torch.max(torch.abs(data-wav)))
|
219 |
+
|
220 |
+
"""
|
221 |
+
# [batchsize, channels, samples]
|
222 |
+
# [batchsize, 2[real,imag]*subband*channels, t-steps, f-bins]
|
223 |
+
subwav = self.qmf.analysis(input) # [batchsize, subband*channels, samples]
|
224 |
+
sps, coss, sins = self.wav_to_spectrogram_phase(subwav, eps=eps)
|
225 |
+
return sps, coss, sins
|
226 |
+
|
227 |
+
def mag_phase_subband_spectrogram_to_wav(self, sps, coss, sins, length, eps=0.0):
|
228 |
+
# [batchsize, 2[real,imag]*subband*channels, t-steps, f-bins]
|
229 |
+
# [batchsize, channels, samples]
|
230 |
+
subwav = self.spectrogram_phase_to_wav(sps, coss, sins, length)
|
231 |
+
data = self.qmf.synthesis(subwav)
|
232 |
+
return data
|
233 |
+
|
234 |
+
|
235 |
+
if __name__ == "__main__":
|
236 |
+
# from thop import profile
|
237 |
+
# from thop import clever_format
|
238 |
+
# from tools.file.wav import *
|
239 |
+
# import time
|
240 |
+
#
|
241 |
+
# wav = torch.randn((1,2,44100))
|
242 |
+
# model = FDomainHelper()
|
243 |
+
|
244 |
+
# from tools.file.wav import *  # external helper, unavailable here and unused below
|
245 |
+
|
246 |
+
loss = torch.nn.L1Loss()
|
247 |
+
model = FDomainHelper()
|
248 |
+
data = torch.randn((3, 1, 44100 * 5))
|
249 |
+
|
250 |
+
sps = model.wav_to_complex_spectrogram(data)
|
251 |
+
print(sps.size())
|
252 |
+
wav = model.complex_spectrogram_to_wav(sps, length=44100 * 5)
|
253 |
+
|
254 |
+
print(loss(data, wav))
|
255 |
+
print(torch.max(torch.abs(data - wav)))
|
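Note on FDomainHelper: wav_to_spectrogram_phase factorizes each channel's STFT as real = mag * cos and imag = mag * sin, so the magnitude/phase triplet is lossless up to the eps clamp. A minimal round-trip check with the default subband=None (no PQMF filters needed); the import path follows this commit's layout (sketch, not part of the diff):

    import torch
    from bytesep.models.subband_tools.fDomainHelper import FDomainHelper

    helper = FDomainHelper()              # subband=None: plain STFT/ISTFT helper
    x = torch.randn(2, 1, 44100)
    mag, cos, sin = helper.wav_to_spectrogram_phase(x)
    real, imag = helper.stft(x[:, 0, :])  # reference STFT of the only channel
    print(torch.max(torch.abs(mag * cos - real)))  # ~0 away from the eps clamp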
bytesep/models/subband_tools/filters/f_4_64.mat
ADDED
Binary file (2.19 kB).
|
|
bytesep/models/subband_tools/filters/h_4_64.mat
ADDED
Binary file (2.19 kB).
|
|
bytesep/models/subband_tools/pqmf.py
ADDED
@@ -0,0 +1,136 @@
1 |
+
'''
|
2 |
+
@File : pqmf.py
|
3 |
+
@Contact : liu.8948@buckeyemail.osu.edu
|
4 |
+
@License : (C)Copyright 2020-2021
|
5 |
+
@Modify Time @Author @Version @Description
|
6 |
+
------------ ------- -------- -----------
|
7 |
+
2020/4/3 4:54 PM Haohe Liu 1.0 None
|
8 |
+
'''
|
9 |
+
|
10 |
+
import torch
|
11 |
+
import torch.nn.functional as F
|
12 |
+
import torch.nn as nn
|
13 |
+
import numpy as np
|
14 |
+
import os.path as op
|
15 |
+
from scipy.io import loadmat
|
16 |
+
|
17 |
+
|
18 |
+
def load_mat2numpy(fname=""):
|
19 |
+
'''
|
20 |
+
Args:
|
21 |
+
fname: pth to mat
|
22 |
+
type:
|
23 |
+
Returns: dic object
|
24 |
+
'''
|
25 |
+
if len(fname) == 0:
|
26 |
+
return None
|
27 |
+
else:
|
28 |
+
return loadmat(fname)
|
29 |
+
|
30 |
+
|
31 |
+
class PQMF(nn.Module):
|
32 |
+
def __init__(self, N, M, project_root):
|
33 |
+
super().__init__()
|
34 |
+
self.N = N # nsubband
|
35 |
+
self.M = M # nfilter
|
36 |
+
try:
|
37 |
+
assert (N, M) in [(8, 64), (4, 64), (2, 64)]
|
38 |
+
except:
|
39 |
+
print("Warning:", N, "subbandand ", M, " filter is not supported")
|
40 |
+
self.pad_samples = 64
|
41 |
+
self.name = str(N) + "_" + str(M) + ".mat"
|
42 |
+
self.ana_conv_filter = nn.Conv1d(
|
43 |
+
1, out_channels=N, kernel_size=M, stride=N, bias=False
|
44 |
+
)
|
45 |
+
data = load_mat2numpy(op.join(project_root, "f_" + self.name))
|
46 |
+
data = data['f'].astype(np.float32) / N
|
47 |
+
data = np.flipud(data.T).T
|
48 |
+
data = np.reshape(data, (N, 1, M)).copy()
|
49 |
+
dict_new = self.ana_conv_filter.state_dict().copy()
|
50 |
+
dict_new['weight'] = torch.from_numpy(data)
|
51 |
+
self.ana_pad = nn.ConstantPad1d((M - N, 0), 0)
|
52 |
+
self.ana_conv_filter.load_state_dict(dict_new)
|
53 |
+
|
54 |
+
self.syn_pad = nn.ConstantPad1d((0, M // N - 1), 0)
|
55 |
+
self.syn_conv_filter = nn.Conv1d(
|
56 |
+
N, out_channels=N, kernel_size=M // N, stride=1, bias=False
|
57 |
+
)
|
58 |
+
gk = load_mat2numpy(op.join(project_root, "h_" + self.name))
|
59 |
+
gk = gk['h'].astype(np.float32)
|
60 |
+
gk = np.transpose(np.reshape(gk, (N, M // N, N)), (1, 0, 2)) * N
|
61 |
+
gk = np.transpose(gk[::-1, :, :], (2, 1, 0)).copy()
|
62 |
+
dict_new = self.syn_conv_filter.state_dict().copy()
|
63 |
+
dict_new['weight'] = torch.from_numpy(gk)
|
64 |
+
self.syn_conv_filter.load_state_dict(dict_new)
|
65 |
+
|
66 |
+
for param in self.parameters():
|
67 |
+
param.requires_grad = False
|
68 |
+
|
69 |
+
def __analysis_channel(self, inputs):
|
70 |
+
return self.ana_conv_filter(self.ana_pad(inputs))
|
71 |
+
|
72 |
+
def __synthesis_channel(self, inputs):
|
73 |
+
ret = self.syn_conv_filter(self.syn_pad(inputs)).permute(0, 2, 1)
|
74 |
+
return torch.reshape(ret, (ret.shape[0], 1, -1))
|
75 |
+
|
76 |
+
def analysis(self, inputs):
|
77 |
+
'''
|
78 |
+
:param inputs: [batchsize,channel,raw_wav],value:[0,1]
|
79 |
+
:return:
|
80 |
+
'''
|
81 |
+
inputs = F.pad(inputs, ((0, self.pad_samples)))
|
82 |
+
ret = None
|
83 |
+
for i in range(inputs.size()[1]): # channels
|
84 |
+
if ret is None:
|
85 |
+
ret = self.__analysis_channel(inputs[:, i : i + 1, :])
|
86 |
+
else:
|
87 |
+
ret = torch.cat(
|
88 |
+
(ret, self.__analysis_channel(inputs[:, i : i + 1, :])), dim=1
|
89 |
+
)
|
90 |
+
return ret
|
91 |
+
|
92 |
+
def synthesis(self, data):
|
93 |
+
'''
|
94 |
+
:param data: [batchsize,self.N*K,raw_wav_sub],value:[0,1]
|
95 |
+
:return:
|
96 |
+
'''
|
97 |
+
ret = None
|
98 |
+
# data = F.pad(data,((0,self.pad_samples//self.N)))
|
99 |
+
for i in range(data.size()[1]): # channels
|
100 |
+
if i % self.N == 0:
|
101 |
+
if ret is None:
|
102 |
+
ret = self.__synthesis_channel(data[:, i : i + self.N, :])
|
103 |
+
else:
|
104 |
+
new = self.__synthesis_channel(data[:, i : i + self.N, :])
|
105 |
+
ret = torch.cat((ret, new), dim=1)
|
106 |
+
ret = ret[..., : -self.pad_samples]
|
107 |
+
return ret
|
108 |
+
|
109 |
+
def forward(self, inputs):
|
110 |
+
return self.ana_conv_filter(self.ana_pad(inputs))
|
111 |
+
|
112 |
+
|
113 |
+
if __name__ == "__main__":
|
114 |
+
import torch
|
115 |
+
import numpy as np
|
116 |
+
import matplotlib.pyplot as plt
|
117 |
+
# from tools.file.wav import *  # external helper, unavailable here and unused below
|
118 |
+
|
119 |
+
pqmf = PQMF(N=4, M=64, project_root="/Users/admin/Documents/projects")
|
120 |
+
|
121 |
+
rs = np.random.RandomState(0)
|
122 |
+
x = torch.tensor(rs.rand(4, 2, 32000), dtype=torch.float32)
|
123 |
+
|
124 |
+
a1 = pqmf.analysis(x)
|
125 |
+
a2 = pqmf.synthesis(a1)
|
126 |
+
|
127 |
+
print(a2.size(), x.size())
|
128 |
+
|
129 |
+
plt.subplot(211)
|
130 |
+
plt.plot(x[0, 0, -500:])
|
131 |
+
plt.subplot(212)
|
132 |
+
plt.plot(a2[0, 0, -500:])
|
133 |
+
plt.plot(x[0, 0, -500:] - a2[0, 0, -500:])
|
134 |
+
plt.show()
|
135 |
+
|
136 |
+
print(torch.sum(torch.abs(x[...] - a2[...])))
|
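A minimal perfect-reconstruction sketch for the PQMF bank above, assuming the f_4_64.mat / h_4_64.mat filters are reachable at the given relative path (sketch, not part of the diff):

    import torch
    from bytesep.models.subband_tools.pqmf import PQMF

    pqmf = PQMF(N=4, M=64, project_root='bytesep/models/subband_tools/filters')
    x = torch.randn(2, 1, 44100)
    sub = pqmf.analysis(x)      # (2, 4, (44100 + 64) // 4): 4 subbands at 1/4 rate
    y = pqmf.synthesis(sub)     # (2, 1, 44100): the 64-sample pad is removed again
    print((x - y).abs().max())  # small reconstruction error

Because analysis runs one Conv1d per input channel and concatenates along dim 1, a C-channel input yields C * N subband channels, which is why the models above multiply their first encoder's in_channels by subbands_num.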
bytesep/models/unet.py
ADDED
@@ -0,0 +1,532 @@
1 |
+
import math
|
2 |
+
from typing import Dict, List, NoReturn, Tuple
|
3 |
+
|
4 |
+
import matplotlib.pyplot as plt
|
5 |
+
import numpy as np
|
6 |
+
import pytorch_lightning as pl
|
7 |
+
import torch
|
8 |
+
import torch.nn as nn
|
9 |
+
import torch.nn.functional as F
|
10 |
+
import torch.optim as optim
|
11 |
+
from torch.optim.lr_scheduler import LambdaLR
|
12 |
+
from torchlibrosa.stft import ISTFT, STFT, magphase
|
13 |
+
|
14 |
+
from bytesep.models.pytorch_modules import Base, Subband, act, init_bn, init_layer
|
15 |
+
|
16 |
+
|
17 |
+
class ConvBlock(nn.Module):
|
18 |
+
def __init__(
|
19 |
+
self,
|
20 |
+
in_channels: int,
|
21 |
+
out_channels: int,
|
22 |
+
kernel_size: Tuple,
|
23 |
+
activation: str,
|
24 |
+
momentum: float,
|
25 |
+
):
|
26 |
+
r"""Convolutional block."""
|
27 |
+
super(ConvBlock, self).__init__()
|
28 |
+
|
29 |
+
self.activation = activation
|
30 |
+
padding = (kernel_size[0] // 2, kernel_size[1] // 2)
|
31 |
+
|
32 |
+
self.conv1 = nn.Conv2d(
|
33 |
+
in_channels=in_channels,
|
34 |
+
out_channels=out_channels,
|
35 |
+
kernel_size=kernel_size,
|
36 |
+
stride=(1, 1),
|
37 |
+
dilation=(1, 1),
|
38 |
+
padding=padding,
|
39 |
+
bias=False,
|
40 |
+
)
|
41 |
+
|
42 |
+
self.bn1 = nn.BatchNorm2d(out_channels, momentum=momentum)
|
43 |
+
|
44 |
+
self.conv2 = nn.Conv2d(
|
45 |
+
in_channels=out_channels,
|
46 |
+
out_channels=out_channels,
|
47 |
+
kernel_size=kernel_size,
|
48 |
+
stride=(1, 1),
|
49 |
+
dilation=(1, 1),
|
50 |
+
padding=padding,
|
51 |
+
bias=False,
|
52 |
+
)
|
53 |
+
|
54 |
+
self.bn2 = nn.BatchNorm2d(out_channels, momentum=momentum)
|
55 |
+
|
56 |
+
self.init_weights()
|
57 |
+
|
58 |
+
def init_weights(self) -> NoReturn:
|
59 |
+
r"""Initialize weights."""
|
60 |
+
init_layer(self.conv1)
|
61 |
+
init_layer(self.conv2)
|
62 |
+
init_bn(self.bn1)
|
63 |
+
init_bn(self.bn2)
|
64 |
+
|
65 |
+
def forward(self, input_tensor: torch.Tensor) -> torch.Tensor:
|
66 |
+
r"""Forward data into the module.
|
67 |
+
|
68 |
+
Args:
|
69 |
+
input_tensor: (batch_size, in_feature_maps, time_steps, freq_bins)
|
70 |
+
|
71 |
+
Returns:
|
72 |
+
output_tensor: (batch_size, out_feature_maps, time_steps, freq_bins)
|
73 |
+
"""
|
74 |
+
x = act(self.bn1(self.conv1(input_tensor)), self.activation)
|
75 |
+
x = act(self.bn2(self.conv2(x)), self.activation)
|
76 |
+
output_tensor = x
|
77 |
+
|
78 |
+
return output_tensor
|
79 |
+
|
80 |
+
|
81 |
+
class EncoderBlock(nn.Module):
|
82 |
+
def __init__(
|
83 |
+
self,
|
84 |
+
in_channels: int,
|
85 |
+
out_channels: int,
|
86 |
+
kernel_size: Tuple,
|
87 |
+
downsample: Tuple,
|
88 |
+
activation: str,
|
89 |
+
momentum: float,
|
90 |
+
):
|
91 |
+
r"""Encoder block."""
|
92 |
+
super(EncoderBlock, self).__init__()
|
93 |
+
|
94 |
+
self.conv_block = ConvBlock(
|
95 |
+
in_channels, out_channels, kernel_size, activation, momentum
|
96 |
+
)
|
97 |
+
self.downsample = downsample
|
98 |
+
|
99 |
+
def forward(self, input_tensor: torch.Tensor) -> torch.Tensor:
|
100 |
+
r"""Forward data into the module.
|
101 |
+
|
102 |
+
Args:
|
103 |
+
input_tensor: (batch_size, in_feature_maps, time_steps, freq_bins)
|
104 |
+
|
105 |
+
Returns:
|
106 |
+
encoder_pool: (batch_size, out_feature_maps, downsampled_time_steps, downsampled_freq_bins)
|
107 |
+
encoder: (batch_size, out_feature_maps, time_steps, freq_bins)
|
108 |
+
"""
|
109 |
+
encoder_tensor = self.conv_block(input_tensor)
|
110 |
+
# encoder: (batch_size, out_feature_maps, time_steps, freq_bins)
|
111 |
+
|
112 |
+
encoder_pool = F.avg_pool2d(encoder_tensor, kernel_size=self.downsample)
|
113 |
+
# encoder_pool: (batch_size, out_feature_maps, downsampled_time_steps, downsampled_freq_bins)
|
114 |
+
|
115 |
+
return encoder_pool, encoder_tensor
|
116 |
+
|
117 |
+
|
118 |
+
class DecoderBlock(nn.Module):
|
119 |
+
def __init__(
|
120 |
+
self,
|
121 |
+
in_channels: int,
|
122 |
+
out_channels: int,
|
123 |
+
kernel_size: Tuple,
|
124 |
+
upsample: Tuple,
|
125 |
+
activation: str,
|
126 |
+
momentum: float,
|
127 |
+
):
|
128 |
+
r"""Decoder block."""
|
129 |
+
super(DecoderBlock, self).__init__()
|
130 |
+
|
131 |
+
self.kernel_size = kernel_size
|
132 |
+
self.stride = upsample
|
133 |
+
self.activation = activation
|
134 |
+
|
135 |
+
self.conv1 = torch.nn.ConvTranspose2d(
|
136 |
+
in_channels=in_channels,
|
137 |
+
out_channels=out_channels,
|
138 |
+
kernel_size=self.stride,
|
139 |
+
stride=self.stride,
|
140 |
+
padding=(0, 0),
|
141 |
+
bias=False,
|
142 |
+
dilation=(1, 1),
|
143 |
+
)
|
144 |
+
|
145 |
+
self.bn1 = nn.BatchNorm2d(out_channels, momentum=momentum)
|
146 |
+
|
147 |
+
self.conv_block2 = ConvBlock(
|
148 |
+
out_channels * 2, out_channels, kernel_size, activation, momentum
|
149 |
+
)
|
150 |
+
|
151 |
+
self.init_weights()
|
152 |
+
|
153 |
+
def init_weights(self):
|
154 |
+
r"""Initialize weights."""
|
155 |
+
init_layer(self.conv1)
|
156 |
+
init_bn(self.bn1)
|
157 |
+
|
158 |
+
def forward(
|
159 |
+
self, input_tensor: torch.Tensor, concat_tensor: torch.Tensor
|
160 |
+
) -> torch.Tensor:
|
161 |
+
r"""Forward data into the module.
|
162 |
+
|
163 |
+
Args:
|
164 |
+
input_tensor: (batch_size, in_feature_maps, downsampled_time_steps, downsampled_freq_bins)
|
165 |
+
concat_tensor: (batch_size, in_feature_maps, time_steps, freq_bins)
|
166 |
+
|
167 |
+
Returns:
|
168 |
+
output_tensor: (batch_size, out_feature_maps, time_steps, freq_bins)
|
169 |
+
"""
|
170 |
+
x = act(self.bn1(self.conv1(input_tensor)), self.activation)
|
171 |
+
# (batch_size, in_feature_maps, time_steps, freq_bins)
|
172 |
+
|
173 |
+
x = torch.cat((x, concat_tensor), dim=1)
|
174 |
+
# (batch_size, in_feature_maps * 2, time_steps, freq_bins)
|
175 |
+
|
176 |
+
output_tensor = self.conv_block2(x)
|
177 |
+
# output_tensor: (batch_size, out_feature_maps, time_steps, freq_bins)
|
178 |
+
|
179 |
+
return output_tensor
|
180 |
+
|
181 |
+
|
182 |
+
class UNet(nn.Module, Base):
|
183 |
+
def __init__(self, input_channels: int, target_sources_num: int):
|
184 |
+
r"""UNet."""
|
185 |
+
super(UNet, self).__init__()
|
186 |
+
|
187 |
+
self.input_channels = input_channels
|
188 |
+
self.target_sources_num = target_sources_num
|
189 |
+
|
190 |
+
window_size = 2048
|
191 |
+
hop_size = 441
|
192 |
+
center = True
|
193 |
+
pad_mode = "reflect"
|
194 |
+
window = "hann"
|
195 |
+
activation = "leaky_relu"
|
196 |
+
momentum = 0.01
|
197 |
+
|
198 |
+
self.subbands_num = 1
|
199 |
+
|
200 |
+
assert (
|
201 |
+
self.subbands_num == 1
|
202 |
+
), "Using subbands_num > 1 on spectrogram \
|
203 |
+
will lead to unexpected performance sometimes. Suggest to use \
|
204 |
+
subband method on waveform."
|
205 |
+
|
206 |
+
self.K = 3 # outputs: |M|, cos∠M, sin∠M
|
207 |
+
self.downsample_ratio = 2 ** 6 # This number equals 2^{#encoder_blocks}
|
208 |
+
|
209 |
+
self.stft = STFT(
|
210 |
+
n_fft=window_size,
|
211 |
+
hop_length=hop_size,
|
212 |
+
win_length=window_size,
|
213 |
+
window=window,
|
214 |
+
center=center,
|
215 |
+
pad_mode=pad_mode,
|
216 |
+
freeze_parameters=True,
|
217 |
+
)
|
218 |
+
|
219 |
+
self.istft = ISTFT(
|
220 |
+
n_fft=window_size,
|
221 |
+
hop_length=hop_size,
|
222 |
+
win_length=window_size,
|
223 |
+
window=window,
|
224 |
+
center=center,
|
225 |
+
pad_mode=pad_mode,
|
226 |
+
freeze_parameters=True,
|
227 |
+
)
|
228 |
+
|
229 |
+
self.bn0 = nn.BatchNorm2d(window_size // 2 + 1, momentum=momentum)
|
230 |
+
|
231 |
+
self.subband = Subband(subbands_num=self.subbands_num)
|
232 |
+
|
233 |
+
self.encoder_block1 = EncoderBlock(
|
234 |
+
in_channels=input_channels * self.subbands_num,
|
235 |
+
out_channels=32,
|
236 |
+
kernel_size=(3, 3),
|
237 |
+
downsample=(2, 2),
|
238 |
+
activation=activation,
|
239 |
+
momentum=momentum,
|
240 |
+
)
|
241 |
+
self.encoder_block2 = EncoderBlock(
|
242 |
+
in_channels=32,
|
243 |
+
out_channels=64,
|
244 |
+
kernel_size=(3, 3),
|
245 |
+
downsample=(2, 2),
|
246 |
+
activation=activation,
|
247 |
+
momentum=momentum,
|
248 |
+
)
|
249 |
+
self.encoder_block3 = EncoderBlock(
|
250 |
+
in_channels=64,
|
251 |
+
out_channels=128,
|
252 |
+
kernel_size=(3, 3),
|
253 |
+
downsample=(2, 2),
|
254 |
+
activation=activation,
|
255 |
+
momentum=momentum,
|
256 |
+
)
|
257 |
+
self.encoder_block4 = EncoderBlock(
|
258 |
+
in_channels=128,
|
259 |
+
out_channels=256,
|
260 |
+
kernel_size=(3, 3),
|
261 |
+
downsample=(2, 2),
|
262 |
+
activation=activation,
|
263 |
+
momentum=momentum,
|
264 |
+
)
|
265 |
+
self.encoder_block5 = EncoderBlock(
|
266 |
+
in_channels=256,
|
267 |
+
out_channels=384,
|
268 |
+
kernel_size=(3, 3),
|
269 |
+
downsample=(2, 2),
|
270 |
+
activation=activation,
|
271 |
+
momentum=momentum,
|
272 |
+
)
|
273 |
+
self.encoder_block6 = EncoderBlock(
|
274 |
+
in_channels=384,
|
275 |
+
out_channels=384,
|
276 |
+
kernel_size=(3, 3),
|
277 |
+
downsample=(2, 2),
|
278 |
+
activation=activation,
|
279 |
+
momentum=momentum,
|
280 |
+
)
|
281 |
+
self.conv_block7 = ConvBlock(
|
282 |
+
in_channels=384,
|
283 |
+
out_channels=384,
|
284 |
+
kernel_size=(3, 3),
|
285 |
+
activation=activation,
|
286 |
+
momentum=momentum,
|
287 |
+
)
|
288 |
+
self.decoder_block1 = DecoderBlock(
|
289 |
+
in_channels=384,
|
290 |
+
out_channels=384,
|
291 |
+
kernel_size=(3, 3),
|
292 |
+
upsample=(2, 2),
|
293 |
+
activation=activation,
|
294 |
+
momentum=momentum,
|
295 |
+
)
|
296 |
+
self.decoder_block2 = DecoderBlock(
|
297 |
+
in_channels=384,
|
298 |
+
out_channels=384,
|
299 |
+
kernel_size=(3, 3),
|
300 |
+
upsample=(2, 2),
|
301 |
+
activation=activation,
|
302 |
+
momentum=momentum,
|
303 |
+
)
|
304 |
+
self.decoder_block3 = DecoderBlock(
|
305 |
+
in_channels=384,
|
306 |
+
out_channels=256,
|
307 |
+
kernel_size=(3, 3),
|
308 |
+
upsample=(2, 2),
|
309 |
+
activation=activation,
|
310 |
+
momentum=momentum,
|
311 |
+
)
|
312 |
+
self.decoder_block4 = DecoderBlock(
|
313 |
+
in_channels=256,
|
314 |
+
out_channels=128,
|
315 |
+
kernel_size=(3, 3),
|
316 |
+
upsample=(2, 2),
|
317 |
+
activation=activation,
|
318 |
+
momentum=momentum,
|
319 |
+
)
|
320 |
+
self.decoder_block5 = DecoderBlock(
|
321 |
+
in_channels=128,
|
322 |
+
out_channels=64,
|
323 |
+
kernel_size=(3, 3),
|
324 |
+
upsample=(2, 2),
|
325 |
+
activation=activation,
|
326 |
+
momentum=momentum,
|
327 |
+
)
|
328 |
+
|
329 |
+
self.decoder_block6 = DecoderBlock(
|
330 |
+
in_channels=64,
|
331 |
+
out_channels=32,
|
332 |
+
kernel_size=(3, 3),
|
333 |
+
upsample=(2, 2),
|
334 |
+
activation=activation,
|
335 |
+
momentum=momentum,
|
336 |
+
)
|
337 |
+
|
338 |
+
self.after_conv_block1 = ConvBlock(
|
339 |
+
in_channels=32,
|
340 |
+
out_channels=32,
|
341 |
+
kernel_size=(3, 3),
|
342 |
+
activation=activation,
|
343 |
+
momentum=momentum,
|
344 |
+
)
|
345 |
+
|
346 |
+
self.after_conv2 = nn.Conv2d(
|
347 |
+
in_channels=32,
|
348 |
+
out_channels=target_sources_num
|
349 |
+
* input_channels
|
350 |
+
* self.K
|
351 |
+
* self.subbands_num,
|
352 |
+
kernel_size=(1, 1),
|
353 |
+
stride=(1, 1),
|
354 |
+
padding=(0, 0),
|
355 |
+
bias=True,
|
356 |
+
)
|
357 |
+
|
358 |
+
self.init_weights()
|
359 |
+
|
360 |
+
def init_weights(self):
|
361 |
+
r"""Initialize weights."""
|
362 |
+
init_bn(self.bn0)
|
363 |
+
init_layer(self.after_conv2)
|
364 |
+
|
365 |
+
def feature_maps_to_wav(
|
366 |
+
self,
|
367 |
+
input_tensor: torch.Tensor,
|
368 |
+
sp: torch.Tensor,
|
369 |
+
sin_in: torch.Tensor,
|
370 |
+
cos_in: torch.Tensor,
|
371 |
+
audio_length: int,
|
372 |
+
) -> torch.Tensor:
|
373 |
+
r"""Convert feature maps to waveform.
|
374 |
+
|
375 |
+
Args:
|
376 |
+
input_tensor: (batch_size, target_sources_num * input_channels * self.K, time_steps, freq_bins)
|
377 |
+
sp: (batch_size, target_sources_num * input_channels, time_steps, freq_bins)
|
378 |
+
sin_in: (batch_size, target_sources_num * input_channels, time_steps, freq_bins)
|
379 |
+
cos_in: (batch_size, target_sources_num * input_channels, time_steps, freq_bins)
|
380 |
+
|
381 |
+
Outputs:
|
382 |
+
waveform: (batch_size, target_sources_num * input_channels, segment_samples)
|
383 |
+
"""
|
384 |
+
batch_size, _, time_steps, freq_bins = input_tensor.shape
|
385 |
+
|
386 |
+
x = input_tensor.reshape(
|
387 |
+
batch_size,
|
388 |
+
self.target_sources_num,
|
389 |
+
self.input_channels,
|
390 |
+
self.K,
|
391 |
+
time_steps,
|
392 |
+
freq_bins,
|
393 |
+
)
|
394 |
+
# x: (batch_size, target_sources_num, input_channels, K, time_steps, freq_bins)
|
395 |
+
|
396 |
+
mask_mag = torch.sigmoid(x[:, :, :, 0, :, :])
|
397 |
+
_mask_real = torch.tanh(x[:, :, :, 1, :, :])
|
398 |
+
_mask_imag = torch.tanh(x[:, :, :, 2, :, :])
|
399 |
+
_, mask_cos, mask_sin = magphase(_mask_real, _mask_imag)
|
400 |
+
# mask_cos, mask_sin: (batch_size, target_sources_num, input_channels, time_steps, freq_bins)
|
401 |
+
|
402 |
+
# Y = |Y|cos∠Y + j|Y|sin∠Y
|
403 |
+
# = |Y|cos(∠X + ∠M) + j|Y|sin(∠X + ∠M)
|
404 |
+
# = |Y|(cos∠X cos∠M - sin∠X sin∠M) + j|Y|(sin∠X cos∠M + cos∠X sin∠M)
|
405 |
+
out_cos = (
|
406 |
+
cos_in[:, None, :, :, :] * mask_cos - sin_in[:, None, :, :, :] * mask_sin
|
407 |
+
)
|
408 |
+
out_sin = (
|
409 |
+
sin_in[:, None, :, :, :] * mask_cos + cos_in[:, None, :, :, :] * mask_sin
|
410 |
+
)
|
411 |
+
# out_cos: (batch_size, target_sources_num, input_channels, time_steps, freq_bins)
|
412 |
+
# out_sin: (batch_size, target_sources_num, input_channels, time_steps, freq_bins)
|
413 |
+
|
414 |
+
# Calculate |Y|.
|
415 |
+
out_mag = F.relu_(sp[:, None, :, :, :] * mask_mag)
|
416 |
+
# out_mag: (batch_size, target_sources_num, input_channels, time_steps, freq_bins)
|
417 |
+
|
418 |
+
# Calculate Y_{real} and Y_{imag} for ISTFT.
|
419 |
+
out_real = out_mag * out_cos
|
420 |
+
out_imag = out_mag * out_sin
|
421 |
+
# out_real, out_imag: (batch_size, target_sources_num, input_channels, time_steps, freq_bins)
|
422 |
+
|
423 |
+
# Reformat shape to (n, 1, time_steps, freq_bins) for ISTFT.
|
424 |
+
shape = (
|
425 |
+
batch_size * self.target_sources_num * self.input_channels,
|
426 |
+
1,
|
427 |
+
time_steps,
|
428 |
+
freq_bins,
|
429 |
+
)
|
430 |
+
out_real = out_real.reshape(shape)
|
431 |
+
out_imag = out_imag.reshape(shape)
|
432 |
+
|
433 |
+
# ISTFT.
|
434 |
+
x = self.istft(out_real, out_imag, audio_length)
|
435 |
+
# (batch_size * target_sources_num * input_channels, segment_samples)
|
436 |
+
|
437 |
+
# Reshape.
|
438 |
+
waveform = x.reshape(
|
439 |
+
batch_size, self.target_sources_num * self.input_channels, audio_length
|
440 |
+
)
|
441 |
+
# (batch_size, target_sources_num * input_channels, segment_samples)
|
442 |
+
|
443 |
+
return waveform
|
444 |
+
|
445 |
+
def forward(self, input_dict: Dict) -> Dict:
|
446 |
+
r"""Forward data into the module.
|
447 |
+
|
448 |
+
Args:
|
449 |
+
input_dict: dict, e.g., {
|
450 |
+
waveform: (batch_size, input_channels, segment_samples),
|
451 |
+
...,
|
452 |
+
}
|
453 |
+
|
454 |
+
Outputs:
|
455 |
+
output_dict: dict, e.g., {
|
456 |
+
'waveform': (batch_size, input_channels, segment_samples),
|
457 |
+
...,
|
458 |
+
}
|
459 |
+
"""
|
460 |
+
mixtures = input_dict['waveform']
|
461 |
+
# (batch_size, input_channels, segment_samples)
|
462 |
+
|
463 |
+
mag, cos_in, sin_in = self.wav_to_spectrogram_phase(mixtures)
|
464 |
+
# mag, cos_in, sin_in: (batch_size, input_channels, time_steps, freq_bins)
|
465 |
+
|
466 |
+
# Batch normalize on individual frequency bins.
|
467 |
+
x = mag.transpose(1, 3)
|
468 |
+
x = self.bn0(x)
|
469 |
+
x = x.transpose(1, 3)
|
470 |
+
# x: (batch_size, input_channels, time_steps, freq_bins)
|
471 |
+
|
472 |
+
# Pad spectrogram to be evenly divided by downsample ratio.
|
473 |
+
origin_len = x.shape[2]
|
474 |
+
pad_len = (
|
475 |
+
int(np.ceil(x.shape[2] / self.downsample_ratio)) * self.downsample_ratio
|
476 |
+
- origin_len
|
477 |
+
)
|
478 |
+
x = F.pad(x, pad=(0, 0, 0, pad_len))
|
479 |
+
# x: (batch_size, input_channels, padded_time_steps, freq_bins)
|
480 |
+
|
481 |
+
# Let frequency bins be evenly divided by 2, e.g., 1025 -> 1024
|
482 |
+
x = x[..., 0 : x.shape[-1] - 1] # (bs, input_channels, T, F)
|
483 |
+
|
484 |
+
if self.subbands_num > 1:
|
485 |
+
x = self.subband.analysis(x)
|
486 |
+
# (bs, input_channels, T, F'), where F' = F // subbands_num
|
487 |
+
|
488 |
+
# UNet
|
489 |
+
(x1_pool, x1) = self.encoder_block1(x) # x1_pool: (bs, 32, T / 2, F' / 2)
|
490 |
+
(x2_pool, x2) = self.encoder_block2(x1_pool) # x2_pool: (bs, 64, T / 4, F' / 4)
|
491 |
+
(x3_pool, x3) = self.encoder_block3(
|
492 |
+
x2_pool
|
493 |
+
) # x3_pool: (bs, 128, T / 8, F' / 8)
|
494 |
+
(x4_pool, x4) = self.encoder_block4(
|
495 |
+
x3_pool
|
496 |
+
) # x4_pool: (bs, 256, T / 16, F' / 16)
|
497 |
+
(x5_pool, x5) = self.encoder_block5(
|
498 |
+
x4_pool
|
499 |
+
) # x5_pool: (bs, 384, T / 32, F' / 32)
|
500 |
+
(x6_pool, x6) = self.encoder_block6(
|
501 |
+
x5_pool
|
502 |
+
) # x6_pool: (bs, 384, T / 64, F' / 64)
|
503 |
+
x_center = self.conv_block7(x6_pool) # (bs, 384, T / 64, F' / 64)
|
504 |
+
x7 = self.decoder_block1(x_center, x6) # (bs, 384, T / 32, F' / 32)
|
505 |
+
x8 = self.decoder_block2(x7, x5) # (bs, 384, T / 16, F' / 16)
|
506 |
+
x9 = self.decoder_block3(x8, x4) # (bs, 256, T / 8, F' / 8)
|
507 |
+
x10 = self.decoder_block4(x9, x3) # (bs, 128, T / 4, F' / 4)
|
508 |
+
x11 = self.decoder_block5(x10, x2) # (bs, 64, T / 2, F' / 2)
|
509 |
+
x12 = self.decoder_block6(x11, x1) # (bs, 32, T, F')
|
510 |
+
x = self.after_conv_block1(x12) # (bs, 32, T, F')
|
511 |
+
|
512 |
+
x = self.after_conv2(x)
|
513 |
+
# (batch_size, target_sources_num * input_channels * self.K * subbands_num, T, F')
|
514 |
+
|
515 |
+
if self.subbands_num > 1:
|
516 |
+
x = self.subband.synthesis(x)
|
517 |
+
# (batch_size, target_sources_num * input_channels * self.K, T, F)
|
518 |
+
|
519 |
+
# Recover shape
|
520 |
+
x = F.pad(x, pad=(0, 1)) # Pad frequency, e.g., 1024 -> 1025.
|
521 |
+
|
522 |
+
x = x[:, :, 0:origin_len, :]
|
523 |
+
# (batch_size, target_sources_num * input_channels * self.K, T, F)
|
524 |
+
|
525 |
+
audio_length = mixtures.shape[2]
|
526 |
+
|
527 |
+
separated_audio = self.feature_maps_to_wav(x, mag, sin_in, cos_in, audio_length)
|
528 |
+
# separated_audio: (batch_size, target_sources_num * input_channels, segment_samples)
|
529 |
+
|
530 |
+
output_dict = {'waveform': separated_audio}
|
531 |
+
|
532 |
+
return output_dict
|
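A minimal usage sketch for the full-band UNet added above; shapes follow the forward docstring (sketch, not part of the diff):

    import torch
    from bytesep.models.unet import UNet

    model = UNet(input_channels=2, target_sources_num=1)
    model.eval()
    with torch.no_grad():
        out = model({'waveform': torch.randn(1, 2, 44100 * 3)})
    print(out['waveform'].shape)  # (1, target_sources_num * input_channels, 44100 * 3)

With subbands_num fixed to 1 here, the spectrogram subband analysis/synthesis is a no-op; the subband variants below move the subband split to the waveform domain via PQMF instead.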
bytesep/models/unet_subbandtime.py
ADDED
@@ -0,0 +1,389 @@
from typing import Dict

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchlibrosa.stft import ISTFT, STFT, magphase

from bytesep.models.pytorch_modules import Base, init_bn, init_layer
from bytesep.models.subband_tools.pqmf import PQMF
from bytesep.models.unet import ConvBlock, DecoderBlock, EncoderBlock


class UNetSubbandTime(nn.Module, Base):
    def __init__(self, input_channels: int, target_sources_num: int):
        r"""Subband waveform UNet."""
        super(UNetSubbandTime, self).__init__()

        self.input_channels = input_channels
        self.target_sources_num = target_sources_num

        window_size = 512  # 2048 // 4
        hop_size = 110  # 441 // 4
        center = True
        pad_mode = "reflect"
        window = "hann"
        activation = "leaky_relu"
        momentum = 0.01

        self.subbands_num = 4
        self.K = 3  # outputs: |M|, cos∠M, sin∠M

        self.downsample_ratio = 2 ** 6  # This number equals 2^{#encoder_blocks}.

        self.pqmf = PQMF(
            N=self.subbands_num,
            M=64,
            project_root='bytesep/models/subband_tools/filters',
        )

        self.stft = STFT(
            n_fft=window_size,
            hop_length=hop_size,
            win_length=window_size,
            window=window,
            center=center,
            pad_mode=pad_mode,
            freeze_parameters=True,
        )

        self.istft = ISTFT(
            n_fft=window_size,
            hop_length=hop_size,
            win_length=window_size,
            window=window,
            center=center,
            pad_mode=pad_mode,
            freeze_parameters=True,
        )

        self.bn0 = nn.BatchNorm2d(window_size // 2 + 1, momentum=momentum)

        self.encoder_block1 = EncoderBlock(
            in_channels=input_channels * self.subbands_num,
            out_channels=32,
            kernel_size=(3, 3),
            downsample=(2, 2),
            activation=activation,
            momentum=momentum,
        )
        self.encoder_block2 = EncoderBlock(
            in_channels=32,
            out_channels=64,
            kernel_size=(3, 3),
            downsample=(2, 2),
            activation=activation,
            momentum=momentum,
        )
        self.encoder_block3 = EncoderBlock(
            in_channels=64,
            out_channels=128,
            kernel_size=(3, 3),
            downsample=(2, 2),
            activation=activation,
            momentum=momentum,
        )
        self.encoder_block4 = EncoderBlock(
            in_channels=128,
            out_channels=256,
            kernel_size=(3, 3),
            downsample=(2, 2),
            activation=activation,
            momentum=momentum,
        )
        self.encoder_block5 = EncoderBlock(
            in_channels=256,
            out_channels=384,
            kernel_size=(3, 3),
            downsample=(2, 2),
            activation=activation,
            momentum=momentum,
        )
        self.encoder_block6 = EncoderBlock(
            in_channels=384,
            out_channels=384,
            kernel_size=(3, 3),
            downsample=(2, 2),
            activation=activation,
            momentum=momentum,
        )
        self.conv_block7 = ConvBlock(
            in_channels=384,
            out_channels=384,
            kernel_size=(3, 3),
            activation=activation,
            momentum=momentum,
        )
        self.decoder_block1 = DecoderBlock(
            in_channels=384,
            out_channels=384,
            kernel_size=(3, 3),
            upsample=(2, 2),
            activation=activation,
            momentum=momentum,
        )
        self.decoder_block2 = DecoderBlock(
            in_channels=384,
            out_channels=384,
            kernel_size=(3, 3),
            upsample=(2, 2),
            activation=activation,
            momentum=momentum,
        )
        self.decoder_block3 = DecoderBlock(
            in_channels=384,
            out_channels=256,
            kernel_size=(3, 3),
            upsample=(2, 2),
            activation=activation,
            momentum=momentum,
        )
        self.decoder_block4 = DecoderBlock(
            in_channels=256,
            out_channels=128,
            kernel_size=(3, 3),
            upsample=(2, 2),
            activation=activation,
            momentum=momentum,
        )
        self.decoder_block5 = DecoderBlock(
            in_channels=128,
            out_channels=64,
            kernel_size=(3, 3),
            upsample=(2, 2),
            activation=activation,
            momentum=momentum,
        )

        self.decoder_block6 = DecoderBlock(
            in_channels=64,
            out_channels=32,
            kernel_size=(3, 3),
            upsample=(2, 2),
            activation=activation,
            momentum=momentum,
        )

        self.after_conv_block1 = ConvBlock(
            in_channels=32,
            out_channels=32,
            kernel_size=(3, 3),
            activation=activation,
            momentum=momentum,
        )

        self.after_conv2 = nn.Conv2d(
            in_channels=32,
            out_channels=target_sources_num
            * input_channels
            * self.K
            * self.subbands_num,
            kernel_size=(1, 1),
            stride=(1, 1),
            padding=(0, 0),
            bias=True,
        )

        self.init_weights()

    def init_weights(self):
        r"""Initialize weights."""
        init_bn(self.bn0)
        init_layer(self.after_conv2)

    def feature_maps_to_wav(
        self,
        input_tensor: torch.Tensor,
        sp: torch.Tensor,
        sin_in: torch.Tensor,
        cos_in: torch.Tensor,
        audio_length: int,
    ) -> torch.Tensor:
        r"""Convert feature maps to a waveform.

        Args:
            input_tensor: (batch_size, target_sources_num * input_channels * self.K, time_steps, freq_bins)
            sp: (batch_size, target_sources_num * input_channels, time_steps, freq_bins)
            sin_in: (batch_size, target_sources_num * input_channels, time_steps, freq_bins)
            cos_in: (batch_size, target_sources_num * input_channels, time_steps, freq_bins)

        Outputs:
            waveform: (batch_size, target_sources_num * input_channels, segment_samples)
        """
        batch_size, _, time_steps, freq_bins = input_tensor.shape

        x = input_tensor.reshape(
            batch_size,
            self.target_sources_num,
            self.input_channels,
            self.K,
            time_steps,
            freq_bins,
        )
        # x: (batch_size, target_sources_num, input_channels, K, time_steps, freq_bins)

        mask_mag = torch.sigmoid(x[:, :, :, 0, :, :])
        _mask_real = torch.tanh(x[:, :, :, 1, :, :])
        _mask_imag = torch.tanh(x[:, :, :, 2, :, :])
        _, mask_cos, mask_sin = magphase(_mask_real, _mask_imag)
        # mask_cos, mask_sin: (batch_size, target_sources_num, input_channels, time_steps, freq_bins)

        # Y = |Y|cos∠Y + j|Y|sin∠Y
        #   = |Y|cos(∠X + ∠M) + j|Y|sin(∠X + ∠M)
        #   = |Y|(cos∠X cos∠M - sin∠X sin∠M) + j|Y|(sin∠X cos∠M + cos∠X sin∠M)
        out_cos = (
            cos_in[:, None, :, :, :] * mask_cos - sin_in[:, None, :, :, :] * mask_sin
        )
        out_sin = (
            sin_in[:, None, :, :, :] * mask_cos + cos_in[:, None, :, :, :] * mask_sin
        )
        # out_cos: (batch_size, target_sources_num, input_channels, time_steps, freq_bins)
        # out_sin: (batch_size, target_sources_num, input_channels, time_steps, freq_bins)

        # Calculate |Y|.
        out_mag = F.relu_(sp[:, None, :, :, :] * mask_mag)
        # out_mag: (batch_size, target_sources_num, input_channels, time_steps, freq_bins)

        # Calculate Y_{real} and Y_{imag} for ISTFT.
        out_real = out_mag * out_cos
        out_imag = out_mag * out_sin
        # out_real, out_imag: (batch_size, target_sources_num, input_channels, time_steps, freq_bins)

        # Reformat shape to (n, 1, time_steps, freq_bins) for ISTFT.
        shape = (
            batch_size * self.target_sources_num * self.input_channels,
            1,
            time_steps,
            freq_bins,
        )
        out_real = out_real.reshape(shape)
        out_imag = out_imag.reshape(shape)

        # ISTFT.
        x = self.istft(out_real, out_imag, audio_length)
        # (batch_size * target_sources_num * input_channels, segment_samples)

        # Reshape.
        waveform = x.reshape(
            batch_size, self.target_sources_num * self.input_channels, audio_length
        )
        # (batch_size, target_sources_num * input_channels, segment_samples)

        return waveform

    def forward(self, input_dict: Dict) -> Dict:
        """Forward data through the module.

        Args:
            input_dict: dict, e.g., {
                'waveform': (batch_size, input_channels, segment_samples),
                ...,
            }

        Outputs:
            output_dict: dict, e.g., {
                'waveform': (batch_size, input_channels, segment_samples),
                ...,
            }
        """
        mixtures = input_dict['waveform']
        # (batch_size, input_channels, segment_samples)

        if self.subbands_num > 1:
            subband_x = self.pqmf.analysis(mixtures)
            # subband_x: (batch_size, input_channels * subbands_num, segment_samples)
        else:
            subband_x = mixtures

        mag, cos_in, sin_in = self.wav_to_spectrogram_phase(subband_x)
        # mag, cos_in, sin_in: (batch_size, input_channels * subbands_num, time_steps, freq_bins)

        # Batch normalize along individual frequency bins.
        x = mag.transpose(1, 3)
        x = self.bn0(x)
        x = x.transpose(1, 3)
        # (batch_size, input_channels * subbands_num, time_steps, freq_bins)

        # Pad the spectrogram so its time dimension is evenly divisible by the
        # downsample ratio.
        origin_len = x.shape[2]
        pad_len = (
            int(np.ceil(x.shape[2] / self.downsample_ratio)) * self.downsample_ratio
            - origin_len
        )
        x = F.pad(x, pad=(0, 0, 0, pad_len))
        # x: (batch_size, input_channels * subbands_num, padded_time_steps, freq_bins)

        # Let frequency bins be evenly divisible by 2, e.g., 257 -> 256.
        x = x[..., 0 : x.shape[-1] - 1]
        # x: (batch_size, input_channels * subbands_num, padded_time_steps, freq_bins - 1)

        # UNet
        (x1_pool, x1) = self.encoder_block1(x)  # x1_pool: (bs, 32, T / 2, F' / 2)
        (x2_pool, x2) = self.encoder_block2(x1_pool)  # x2_pool: (bs, 64, T / 4, F' / 4)
        (x3_pool, x3) = self.encoder_block3(x2_pool)  # x3_pool: (bs, 128, T / 8, F' / 8)
        (x4_pool, x4) = self.encoder_block4(x3_pool)  # x4_pool: (bs, 256, T / 16, F' / 16)
        (x5_pool, x5) = self.encoder_block5(x4_pool)  # x5_pool: (bs, 384, T / 32, F' / 32)
        (x6_pool, x6) = self.encoder_block6(x5_pool)  # x6_pool: (bs, 384, T / 64, F' / 64)
        x_center = self.conv_block7(x6_pool)  # (bs, 384, T / 64, F' / 64)
        x7 = self.decoder_block1(x_center, x6)  # (bs, 384, T / 32, F' / 32)
        x8 = self.decoder_block2(x7, x5)  # (bs, 384, T / 16, F' / 16)
        x9 = self.decoder_block3(x8, x4)  # (bs, 256, T / 8, F' / 8)
        x10 = self.decoder_block4(x9, x3)  # (bs, 128, T / 4, F' / 4)
        x11 = self.decoder_block5(x10, x2)  # (bs, 64, T / 2, F' / 2)
        x12 = self.decoder_block6(x11, x1)  # (bs, 32, T, F')
        x = self.after_conv_block1(x12)  # (bs, 32, T, F')

        x = self.after_conv2(x)
        # (batch_size, subbands_num * target_sources_num * input_channels * self.K, T, F')

        # Recover shape.
        x = F.pad(x, pad=(0, 1))  # Pad frequency, e.g., 256 -> 257.

        x = x[:, :, 0:origin_len, :]
        # (batch_size, subbands_num * target_sources_num * input_channels * self.K, T, F)

        audio_length = subband_x.shape[2]

        # Recover each subband spectrogram to a subband waveform, then
        # synthesize the subband waveforms into a full-band waveform.
        C1 = x.shape[1] // self.subbands_num
        C2 = mag.shape[1] // self.subbands_num

        separated_subband_audio = torch.cat(
            [
                self.feature_maps_to_wav(
                    input_tensor=x[:, j * C1 : (j + 1) * C1, :, :],
                    sp=mag[:, j * C2 : (j + 1) * C2, :, :],
                    sin_in=sin_in[:, j * C2 : (j + 1) * C2, :, :],
                    cos_in=cos_in[:, j * C2 : (j + 1) * C2, :, :],
                    audio_length=audio_length,
                )
                for j in range(self.subbands_num)
            ],
            dim=1,
        )
        # (batch_size, subbands_num * target_sources_num * input_channels, segment_samples)

        if self.subbands_num > 1:
            separated_audio = self.pqmf.synthesis(separated_subband_audio)
            # (batch_size, target_sources_num * input_channels, segment_samples)
        else:
            separated_audio = separated_subband_audio

        output_dict = {'waveform': separated_audio}

        return output_dict
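A minimal sketch (not part of the commit) of running this model on a dummy batch; the batch and segment sizes are arbitrary, and it should be run from the repository root so the relative PQMF filter path above resolves.

import torch

from bytesep.models.unet_subbandtime import UNetSubbandTime

model = UNetSubbandTime(input_channels=2, target_sources_num=1)

# Dummy stereo batch: (batch_size, input_channels, segment_samples).
input_dict = {'waveform': torch.zeros(2, 2, 44100 * 3)}

with torch.no_grad():
    output_dict = model(input_dict)

# Expected to mirror the input shape: (2, 2, 44100 * 3).
print(output_dict['waveform'].shape)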
bytesep/optimizers/__init__.py
ADDED
File without changes
bytesep/optimizers/lr_schedulers.py
ADDED
@@ -0,0 +1,20 @@
def get_lr_lambda(step, warm_up_steps: int, reduce_lr_steps: int):
    r"""Get lr_lambda for LambdaLR. E.g.,

    .. code-block:: python

        lr_lambda = lambda step: get_lr_lambda(step, warm_up_steps=1000, reduce_lr_steps=10000)

        from torch.optim.lr_scheduler import LambdaLR
        LambdaLR(optimizer, lr_lambda)

    Args:
        warm_up_steps: int, number of warm-up steps
        reduce_lr_steps: int, reduce the learning rate by a factor of 0.9 every #reduce_lr_steps steps

    Returns:
        learning rate scaling factor: float
    """
    if step <= warm_up_steps:
        return step / warm_up_steps
    else:
        return 0.9 ** (step // reduce_lr_steps)
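A small sketch (not part of the commit) of wiring this scheduler to an optimizer, following the docstring above; the parameter list and learning rate are placeholders.

from functools import partial

import torch
from torch.optim.lr_scheduler import LambdaLR

from bytesep.optimizers.lr_schedulers import get_lr_lambda

params = [torch.nn.Parameter(torch.zeros(1))]  # placeholder parameters
optimizer = torch.optim.Adam(params, lr=1e-3)

lr_lambda = partial(get_lr_lambda, warm_up_steps=1000, reduce_lr_steps=10000)
scheduler = LambdaLR(optimizer, lr_lambda)

# The factor ramps linearly to 1.0 during warm-up, then decays stepwise:
print(get_lr_lambda(500, 1000, 10000))    # 0.5 (half-way through warm-up)
print(get_lr_lambda(5000, 1000, 10000))   # 0.9 ** 0 = 1.0
print(get_lr_lambda(25000, 1000, 10000))  # 0.9 ** 2 = 0.81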
bytesep/plot_results/__init__.py
ADDED
File without changes
bytesep/plot_results/musdb18.py
ADDED
@@ -0,0 +1,198 @@
import argparse
import os
import pickle

import matplotlib.pyplot as plt
import numpy as np


def load_sdrs(workspace, task_name, filename, config, gpus, source_type):

    stat_path = os.path.join(
        workspace,
        "statistics",
        task_name,
        filename,
        "config={},gpus={}".format(config, gpus),
        "statistics.pkl",
    )

    stat_dict = pickle.load(open(stat_path, 'rb'))

    median_sdrs = [e['median_sdr_dict'][source_type] for e in stat_dict['test']]

    return median_sdrs


def plot_statistics(args):

    # arguments & parameters
    workspace = args.workspace
    select = args.select
    task_name = "musdb18"
    filename = "train"

    # paths
    fig_path = os.path.join('results', task_name, "sdr_{}.pdf".format(select))
    os.makedirs(os.path.dirname(fig_path), exist_ok=True)

    linewidth = 1
    lines = []
    fig, ax = plt.subplots(1, 1, figsize=(8, 6))

    if select == '1a':
        sdrs = load_sdrs(
            workspace,
            task_name,
            filename,
            config='vocals-accompaniment,unet',
            gpus=1,
            source_type="vocals",
        )
        (line,) = ax.plot(sdrs, label='UNet,l1_wav', linewidth=linewidth)
        lines.append(line)
        ylim = 15

    elif select == '1b':
        sdrs = load_sdrs(
            workspace,
            task_name,
            filename,
            config='accompaniment-vocals,unet',
            gpus=1,
            source_type="accompaniment",
        )
        (line,) = ax.plot(sdrs, label='UNet,l1_wav', linewidth=linewidth)
        lines.append(line)
        ylim = 20

    elif select == '1c':
        sdrs = load_sdrs(
            workspace,
            task_name,
            filename,
            config='vocals-accompaniment,unet',
            gpus=1,
            source_type="vocals",
        )
        (line,) = ax.plot(sdrs, label='UNet,l1_wav', linewidth=linewidth)
        lines.append(line)

        sdrs = load_sdrs(
            workspace,
            task_name,
            filename,
            config='vocals-accompaniment,resunet',
            gpus=2,
            source_type="vocals",
        )
        (line,) = ax.plot(sdrs, label='ResUNet_ISMIR2021,l1_wav', linewidth=linewidth)
        lines.append(line)

        sdrs = load_sdrs(
            workspace,
            task_name,
            filename,
            config='vocals-accompaniment,unet_subbandtime',
            gpus=1,
            source_type="vocals",
        )
        (line,) = ax.plot(sdrs, label='unet_subband,l1_wav', linewidth=linewidth)
        lines.append(line)

        sdrs = load_sdrs(
            workspace,
            task_name,
            filename,
            config='vocals-accompaniment,resunet_subbandtime',
            gpus=1,
            source_type="vocals",
        )
        (line,) = ax.plot(sdrs, label='resunet_subband,l1_wav', linewidth=linewidth)
        lines.append(line)

        ylim = 15

    elif select == '1d':
        sdrs = load_sdrs(
            workspace,
            task_name,
            filename,
            config='accompaniment-vocals,unet',
            gpus=1,
            source_type="accompaniment",
        )
        (line,) = ax.plot(sdrs, label='UNet,l1_wav', linewidth=linewidth)
        lines.append(line)

        sdrs = load_sdrs(
            workspace,
            task_name,
            filename,
            config='accompaniment-vocals,resunet',
            gpus=2,
            source_type="accompaniment",
        )
        (line,) = ax.plot(sdrs, label='ResUNet_ISMIR2021,l1_wav', linewidth=linewidth)
        lines.append(line)

        # sdrs = load_sdrs(
        #     workspace,
        #     task_name,
        #     filename,
        #     config='accompaniment-vocals,unet_subbandtime',
        #     gpus=1,
        #     source_type="accompaniment",
        # )
        # (line,) = ax.plot(sdrs, label='UNet_subbandtime,l1_wav', linewidth=linewidth)
        # lines.append(line)

        sdrs = load_sdrs(
            workspace,
            task_name,
            filename,
            config='accompaniment-vocals,resunet_subbandtime',
            gpus=1,
            source_type="accompaniment",
        )
        (line,) = ax.plot(
            sdrs, label='ResUNet_subbandtime,l1_wav', linewidth=linewidth
        )
        lines.append(line)

        ylim = 20

    else:
        raise Exception('Error!')

    eval_every_iterations = 10000
    total_ticks = 50
    ticks_freq = 10

    ax.set_ylim(0, ylim)
    ax.set_xlim(0, total_ticks)
    ax.xaxis.set_ticks(np.arange(0, total_ticks + 1, ticks_freq))
    ax.xaxis.set_ticklabels(
        np.arange(
            0,
            total_ticks * eval_every_iterations + 1,
            ticks_freq * eval_every_iterations,
        )
    )
    ax.yaxis.set_ticks(np.arange(ylim + 1))
    ax.yaxis.set_ticklabels(np.arange(ylim + 1))
    ax.grid(color='b', linestyle='solid', linewidth=0.3)
    plt.legend(handles=lines, loc=4)

    plt.savefig(fig_path)
    print('Save figure to {}'.format(fig_path))


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--workspace', type=str, required=True)
    parser.add_argument('--select', type=str, required=True)

    args = parser.parse_args()

    plot_statistics(args)
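For reference, a sketch of the statistics.pkl layout that load_sdrs indexes into, inferred from the code above and the StatisticsContainer in bytesep/utils.py; the numbers are illustrative only.

# Each entry under 'test' holds the median SDR per source at one evaluation
# checkpoint. All values below are made up.
stat_dict = {
    'train': [],
    'test': [
        {'median_sdr_dict': {'vocals': 4.8, 'accompaniment': 12.1}, 'steps': 10000},
        {'median_sdr_dict': {'vocals': 5.6, 'accompaniment': 12.9}, 'steps': 20000},
    ],
}

median_vocal_sdrs = [e['median_sdr_dict']['vocals'] for e in stat_dict['test']]
print(median_vocal_sdrs)  # [4.8, 5.6]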
bytesep/plot_results/plot_vctk-musdb18.py
ADDED
@@ -0,0 +1,87 @@
import argparse
import os
import pickle

import matplotlib.pyplot as plt
import numpy as np


def load_sdrs(workspace, task_name, filename, config, gpus):

    stat_path = os.path.join(
        workspace,
        "statistics",
        task_name,
        filename,
        "config={},gpus={}".format(config, gpus),
        "statistics.pkl",
    )

    stat_dict = pickle.load(open(stat_path, 'rb'))

    median_sdrs = [e['sdr'] for e in stat_dict['test']]

    return median_sdrs


def plot_statistics(args):

    # arguments & parameters
    workspace = args.workspace
    select = args.select
    task_name = "vctk-musdb18"
    filename = "train"

    # paths
    fig_path = os.path.join('results', task_name, "sdr_{}.pdf".format(select))
    os.makedirs(os.path.dirname(fig_path), exist_ok=True)

    linewidth = 1
    lines = []
    fig, ax = plt.subplots(1, 1, figsize=(8, 6))
    ylim = 30

    if select == '1a':
        sdrs = load_sdrs(workspace, task_name, filename, config='unet', gpus=1)
        (line,) = ax.plot(sdrs, label='UNet,l1_wav', linewidth=linewidth)
        lines.append(line)

    else:
        raise Exception('Error!')

    eval_every_iterations = 10000
    total_ticks = 50
    ticks_freq = 10

    ax.set_ylim(0, ylim)
    ax.set_xlim(0, total_ticks)
    ax.xaxis.set_ticks(np.arange(0, total_ticks + 1, ticks_freq))
    ax.xaxis.set_ticklabels(
        np.arange(
            0,
            total_ticks * eval_every_iterations + 1,
            ticks_freq * eval_every_iterations,
        )
    )
    ax.yaxis.set_ticks(np.arange(ylim + 1))
    ax.yaxis.set_ticklabels(np.arange(ylim + 1))
    ax.grid(color='b', linestyle='solid', linewidth=0.3)
    plt.legend(handles=lines, loc=4)

    plt.savefig(fig_path)
    print('Save figure to {}'.format(fig_path))


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--workspace', type=str, required=True)
    parser.add_argument('--select', type=str, required=True)

    args = parser.parse_args()

    plot_statistics(args)
bytesep/train.py
ADDED
@@ -0,0 +1,299 @@
import argparse
import logging
import os
import pathlib
from functools import partial
from typing import List, NoReturn

import pytorch_lightning as pl
from pytorch_lightning.plugins import DDPPlugin

from bytesep.callbacks import get_callbacks
from bytesep.data.augmentors import Augmentor
from bytesep.data.batch_data_preprocessors import (
    get_batch_data_preprocessor_class,
)
from bytesep.data.data_modules import DataModule, Dataset
from bytesep.data.samplers import SegmentSampler
from bytesep.losses import get_loss_function
from bytesep.models.lightning_modules import (
    LitSourceSeparation,
    get_model_class,
)
from bytesep.optimizers.lr_schedulers import get_lr_lambda
from bytesep.utils import (
    create_logging,
    get_pitch_shift_factor,
    read_yaml,
    check_configs_gramma,
)


def get_dirs(
    workspace: str, task_name: str, filename: str, config_yaml: str, gpus: int
) -> List[str]:
    r"""Get directories.

    Args:
        workspace: str
        task_name: str, e.g., 'musdb18'
        filename: str
        config_yaml: str
        gpus: int, e.g., 0 for cpu and 8 for training with 8 gpu cards

    Returns:
        checkpoints_dir: str
        logs_dir: str
        logger: pl.loggers.TensorBoardLogger
        statistics_path: str
    """

    # directory to save checkpoints
    checkpoints_dir = os.path.join(
        workspace,
        "checkpoints",
        task_name,
        filename,
        "config={},gpus={}".format(pathlib.Path(config_yaml).stem, gpus),
    )
    os.makedirs(checkpoints_dir, exist_ok=True)

    # logs dir
    logs_dir = os.path.join(
        workspace,
        "logs",
        task_name,
        filename,
        "config={},gpus={}".format(pathlib.Path(config_yaml).stem, gpus),
    )
    os.makedirs(logs_dir, exist_ok=True)

    # loggings
    create_logging(logs_dir, filemode='w')
    logging.info(args)

    # tensorboard logs dir
    tb_logs_dir = os.path.join(workspace, "tensorboard_logs")
    os.makedirs(tb_logs_dir, exist_ok=True)

    experiment_name = os.path.join(task_name, filename, pathlib.Path(config_yaml).stem)
    logger = pl.loggers.TensorBoardLogger(save_dir=tb_logs_dir, name=experiment_name)

    # statistics path
    statistics_path = os.path.join(
        workspace,
        "statistics",
        task_name,
        filename,
        "config={},gpus={}".format(pathlib.Path(config_yaml).stem, gpus),
        "statistics.pkl",
    )
    os.makedirs(os.path.dirname(statistics_path), exist_ok=True)

    return checkpoints_dir, logs_dir, logger, statistics_path


def _get_data_module(
    workspace: str, config_yaml: str, num_workers: int, distributed: bool
) -> DataModule:
    r"""Create data_module. Mini-batch data can be obtained by:

    .. code-block:: python

        data_module.setup()
        for batch_data_dict in data_module.train_dataloader():
            print(batch_data_dict.keys())
            break

    Args:
        workspace: str
        config_yaml: str
        num_workers: int, e.g., 0 for non-parallel and 8 for using cpu cores
            for preparing data in parallel
        distributed: bool

    Returns:
        data_module: DataModule
    """

    configs = read_yaml(config_yaml)
    input_source_types = configs['train']['input_source_types']
    indexes_path = os.path.join(workspace, configs['train']['indexes_dict'])
    sample_rate = configs['train']['sample_rate']
    segment_seconds = configs['train']['segment_seconds']
    mixaudio_dict = configs['train']['augmentations']['mixaudio']
    augmentations = configs['train']['augmentations']
    max_pitch_shift = max(
        [
            augmentations['pitch_shift'][source_type]
            for source_type in input_source_types
        ]
    )
    batch_size = configs['train']['batch_size']
    steps_per_epoch = configs['train']['steps_per_epoch']

    segment_samples = int(segment_seconds * sample_rate)
    ex_segment_samples = int(segment_samples * get_pitch_shift_factor(max_pitch_shift))

    # sampler
    train_sampler = SegmentSampler(
        indexes_path=indexes_path,
        segment_samples=ex_segment_samples,
        mixaudio_dict=mixaudio_dict,
        batch_size=batch_size,
        steps_per_epoch=steps_per_epoch,
    )

    # augmentor
    augmentor = Augmentor(augmentations=augmentations)

    # dataset
    train_dataset = Dataset(augmentor, segment_samples)

    # data module
    data_module = DataModule(
        train_sampler=train_sampler,
        train_dataset=train_dataset,
        num_workers=num_workers,
        distributed=distributed,
    )

    return data_module


def train(args) -> NoReturn:
    r"""Train & evaluate and save checkpoints.

    Args:
        workspace: str, directory of workspace
        gpus: int
        config_yaml: str, path of config file for training
    """

    # arguments & parameters
    workspace = args.workspace
    gpus = args.gpus
    config_yaml = args.config_yaml
    filename = args.filename

    num_workers = 8
    distributed = True if gpus > 1 else False
    evaluate_device = "cuda" if gpus > 0 else "cpu"

    # Read config file.
    configs = read_yaml(config_yaml)
    check_configs_gramma(configs)
    task_name = configs['task_name']
    target_source_types = configs['train']['target_source_types']
    target_sources_num = len(target_source_types)
    channels = configs['train']['channels']
    batch_data_preprocessor_type = configs['train']['batch_data_preprocessor']
    model_type = configs['train']['model_type']
    loss_type = configs['train']['loss_type']
    optimizer_type = configs['train']['optimizer_type']
    learning_rate = float(configs['train']['learning_rate'])
    precision = configs['train']['precision']
    early_stop_steps = configs['train']['early_stop_steps']
    warm_up_steps = configs['train']['warm_up_steps']
    reduce_lr_steps = configs['train']['reduce_lr_steps']

    # paths
    checkpoints_dir, logs_dir, logger, statistics_path = get_dirs(
        workspace, task_name, filename, config_yaml, gpus
    )

    # training data module
    data_module = _get_data_module(
        workspace=workspace,
        config_yaml=config_yaml,
        num_workers=num_workers,
        distributed=distributed,
    )

    # batch data preprocessor
    BatchDataPreprocessor = get_batch_data_preprocessor_class(
        batch_data_preprocessor_type=batch_data_preprocessor_type
    )

    batch_data_preprocessor = BatchDataPreprocessor(
        target_source_types=target_source_types
    )

    # model
    Model = get_model_class(model_type=model_type)
    model = Model(input_channels=channels, target_sources_num=target_sources_num)

    # loss function
    loss_function = get_loss_function(loss_type=loss_type)

    # callbacks
    callbacks = get_callbacks(
        task_name=task_name,
        config_yaml=config_yaml,
        workspace=workspace,
        checkpoints_dir=checkpoints_dir,
        statistics_path=statistics_path,
        logger=logger,
        model=model,
        evaluate_device=evaluate_device,
    )

    # learning rate reduce function
    lr_lambda = partial(
        get_lr_lambda, warm_up_steps=warm_up_steps, reduce_lr_steps=reduce_lr_steps
    )

    # pytorch-lightning model
    pl_model = LitSourceSeparation(
        batch_data_preprocessor=batch_data_preprocessor,
        model=model,
        optimizer_type=optimizer_type,
        loss_function=loss_function,
        learning_rate=learning_rate,
        lr_lambda=lr_lambda,
    )

    # trainer
    trainer = pl.Trainer(
        checkpoint_callback=False,
        gpus=gpus,
        callbacks=callbacks,
        max_steps=early_stop_steps,
        accelerator="ddp",
        sync_batchnorm=True,
        precision=precision,
        replace_sampler_ddp=False,
        plugins=[DDPPlugin(find_unused_parameters=True)],
        profiler='simple',
    )

    # Fit, evaluate, and save checkpoints.
    trainer.fit(pl_model, data_module)


if __name__ == "__main__":

    parser = argparse.ArgumentParser(description="")
    subparsers = parser.add_subparsers(dest="mode")

    parser_train = subparsers.add_parser("train")
    parser_train.add_argument(
        "--workspace", type=str, required=True, help="Directory of workspace."
    )
    parser_train.add_argument("--gpus", type=int, required=True)
    parser_train.add_argument(
        "--config_yaml",
        type=str,
        required=True,
        help="Path of config file for training.",
    )

    args = parser.parse_args()
    args.filename = pathlib.Path(__file__).stem

    if args.mode == "train":
        train(args)

    else:
        raise Exception("Error argument!")
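A minimal sketch (not part of the commit) of pulling a mini-batch out of the data module, following the _get_data_module docstring above; the workspace and config paths are hypothetical.

from bytesep.train import _get_data_module

data_module = _get_data_module(
    workspace="./workspaces/bytesep",    # hypothetical path
    config_yaml="./configs/train.yaml",  # hypothetical path
    num_workers=0,
    distributed=False,
)
data_module.setup()

for batch_data_dict in data_module.train_dataloader():
    print(batch_data_dict.keys())  # one waveform entry per source type
    break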
bytesep/utils.py
ADDED
@@ -0,0 +1,189 @@
import datetime
import logging
import os
import pickle
from typing import Dict, NoReturn

import librosa
import numpy as np
import yaml


def create_logging(log_dir: str, filemode: str) -> logging:
    r"""Create logging to write out log files.

    Args:
        log_dir: str, directory to write out logs
        filemode: str, e.g., "w"

    Returns:
        logging
    """
    os.makedirs(log_dir, exist_ok=True)
    i1 = 0

    while os.path.isfile(os.path.join(log_dir, "{:04d}.log".format(i1))):
        i1 += 1

    log_path = os.path.join(log_dir, "{:04d}.log".format(i1))
    logging.basicConfig(
        level=logging.DEBUG,
        format="%(asctime)s %(filename)s[line:%(lineno)d] %(levelname)s %(message)s",
        datefmt="%a, %d %b %Y %H:%M:%S",
        filename=log_path,
        filemode=filemode,
    )

    # Print to console.
    console = logging.StreamHandler()
    console.setLevel(logging.INFO)
    formatter = logging.Formatter("%(name)-12s: %(levelname)-8s %(message)s")
    console.setFormatter(formatter)
    logging.getLogger("").addHandler(console)

    return logging


def load_audio(
    audio_path: str,
    mono: bool,
    sample_rate: float,
    offset: float = 0.0,
    duration: float = None,
) -> np.array:
    r"""Load audio.

    Args:
        audio_path: str
        mono: bool
        sample_rate: float
    """
    audio, _ = librosa.core.load(
        audio_path, sr=sample_rate, mono=mono, offset=offset, duration=duration
    )
    # (audio_samples,) | (channels_num, audio_samples)

    if audio.ndim == 1:
        audio = audio[None, :]
        # (1, audio_samples)

    return audio


def load_random_segment(
    audio_path: str, random_state, segment_seconds: float, mono: bool, sample_rate: int
) -> np.array:
    r"""Randomly select an audio segment from a recording."""

    duration = librosa.get_duration(filename=audio_path)

    start_time = random_state.uniform(0.0, duration - segment_seconds)

    audio = load_audio(
        audio_path=audio_path,
        mono=mono,
        sample_rate=sample_rate,
        offset=start_time,
        duration=segment_seconds,
    )
    # (channels_num, audio_samples)

    return audio


def float32_to_int16(x: np.float32) -> np.int16:

    x = np.clip(x, a_min=-1, a_max=1)

    return (x * 32767.0).astype(np.int16)


def int16_to_float32(x: np.int16) -> np.float32:

    return (x / 32767.0).astype(np.float32)


def read_yaml(config_yaml: str):

    with open(config_yaml, "r") as fr:
        configs = yaml.load(fr, Loader=yaml.FullLoader)

    return configs


def check_configs_gramma(configs: Dict) -> NoReturn:
    r"""Check that the grammar of the training config dictionary is legal."""
    input_source_types = configs['train']['input_source_types']

    for augmentation_type in configs['train']['augmentations'].keys():
        augmentation_dict = configs['train']['augmentations'][augmentation_type]

        for source_type in augmentation_dict.keys():
            if source_type not in input_source_types:
                error_msg = (
                    "The source type '{}' in configs['train']['augmentations']['{}'] "
                    "must be one of input_source_types {}".format(
                        source_type, augmentation_type, input_source_types
                    )
                )
                raise Exception(error_msg)


def magnitude_to_db(x: float) -> float:
    eps = 1e-10
    return 20.0 * np.log10(max(x, eps))


def db_to_magnitude(x: float) -> float:
    return 10.0 ** (x / 20)


def get_pitch_shift_factor(shift_pitch: float) -> float:
    r"""The factor by which the audio length is scaled."""
    return 2 ** (shift_pitch / 12)


class StatisticsContainer(object):
    def __init__(self, statistics_path):
        self.statistics_path = statistics_path

        self.backup_statistics_path = "{}_{}.pkl".format(
            os.path.splitext(self.statistics_path)[0],
            datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S"),
        )

        self.statistics_dict = {"train": [], "test": []}

    def append(self, steps, statistics, split):
        statistics["steps"] = steps
        self.statistics_dict[split].append(statistics)

    def dump(self):
        pickle.dump(self.statistics_dict, open(self.statistics_path, "wb"))
        pickle.dump(self.statistics_dict, open(self.backup_statistics_path, "wb"))
        logging.info("    Dump statistics to {}".format(self.statistics_path))
        logging.info("    Dump statistics to {}".format(self.backup_statistics_path))

    '''
    def load_state_dict(self, resume_steps):
        self.statistics_dict = pickle.load(open(self.statistics_path, "rb"))

        resume_statistics_dict = {"train": [], "test": []}

        for key in self.statistics_dict.keys():
            for statistics in self.statistics_dict[key]:
                if statistics["steps"] <= resume_steps:
                    resume_statistics_dict[key].append(statistics)

        self.statistics_dict = resume_statistics_dict
    '''


def calculate_sdr(ref: np.array, est: np.array) -> float:
    s_true = ref
    s_artif = est - ref
    sdr = 10.0 * (
        np.log10(np.clip(np.mean(s_true ** 2), 1e-8, np.inf))
        - np.log10(np.clip(np.mean(s_artif ** 2), 1e-8, np.inf))
    )
    return sdr
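A quick worked example (not part of the commit) of calculate_sdr on toy signals: an estimate at 90% of the reference amplitude leaves a residual of 0.1x the reference, so the SDR is 10 * log10(1 / 0.01) = 20 dB.

import numpy as np

from bytesep.utils import calculate_sdr, float32_to_int16, int16_to_float32

rng = np.random.RandomState(0)
ref = rng.uniform(-0.5, 0.5, size=44100).astype(np.float32)

# A 10% amplitude error gives 10 * log10(1 / 0.01) = 20 dB.
print(calculate_sdr(ref, 0.9 * ref))  # ~20.0

# An int16 round trip only introduces quantization error, so the SDR is
# very high (roughly 90 dB for this signal level).
print(calculate_sdr(ref, int16_to_float32(float32_to_int16(ref))))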
pyproject.toml
ADDED
@@ -0,0 +1,21 @@
[tool.black]
line-length = 88
target-version = ['py37']
skip-string-normalization = true
include = '\.pyi?$'
exclude = '''
(
  /(
      \.eggs       # exclude a few common directories in the
    | \.git        # root of the project
    | \.hg
    | \.mypy_cache
    | \.tox
    | \.venv
    | _build
    | buck-out
    | build
    | dist
  )/
)
'''