akhaliq3 commited on
Commit
5019931
·
1 Parent(s): bee7f54

spaces demo

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. LICENSE +13 -0
  2. bytesep/__init__.py +1 -0
  3. bytesep/callbacks/__init__.py +76 -0
  4. bytesep/callbacks/base_callbacks.py +44 -0
  5. bytesep/callbacks/instruments_callbacks.py +200 -0
  6. bytesep/callbacks/musdb18.py +485 -0
  7. bytesep/callbacks/voicebank_demand.py +231 -0
  8. bytesep/data/__init__.py +0 -0
  9. bytesep/data/augmentors.py +157 -0
  10. bytesep/data/batch_data_preprocessors.py +141 -0
  11. bytesep/data/data_modules.py +187 -0
  12. bytesep/data/samplers.py +188 -0
  13. bytesep/dataset_creation/__init__.py +0 -0
  14. bytesep/dataset_creation/create_evaluation_audios/__init__.py +0 -0
  15. bytesep/dataset_creation/create_evaluation_audios/piano-symphony.py +160 -0
  16. bytesep/dataset_creation/create_evaluation_audios/vctk-musdb18.py +164 -0
  17. bytesep/dataset_creation/create_evaluation_audios/violin-piano.py +162 -0
  18. bytesep/dataset_creation/create_indexes/__init__.py +0 -0
  19. bytesep/dataset_creation/create_indexes/create_indexes.py +142 -0
  20. bytesep/dataset_creation/pack_audios_to_hdf5s/__init__.py +0 -0
  21. bytesep/dataset_creation/pack_audios_to_hdf5s/instruments_solo.py +173 -0
  22. bytesep/dataset_creation/pack_audios_to_hdf5s/maestro.py +136 -0
  23. bytesep/dataset_creation/pack_audios_to_hdf5s/musdb18.py +207 -0
  24. bytesep/dataset_creation/pack_audios_to_hdf5s/vctk.py +114 -0
  25. bytesep/dataset_creation/pack_audios_to_hdf5s/voicebank-demand.py +143 -0
  26. bytesep/inference.py +402 -0
  27. bytesep/inference_many.py +163 -0
  28. bytesep/losses.py +106 -0
  29. bytesep/models/__init__.py +0 -0
  30. bytesep/models/conditional_unet.py +496 -0
  31. bytesep/models/lightning_modules.py +188 -0
  32. bytesep/models/pytorch_modules.py +204 -0
  33. bytesep/models/resunet.py +516 -0
  34. bytesep/models/resunet_ismir2021.py +534 -0
  35. bytesep/models/resunet_subbandtime.py +545 -0
  36. bytesep/models/subband_tools/__init__.py +0 -0
  37. bytesep/models/subband_tools/fDomainHelper.py +255 -0
  38. bytesep/models/subband_tools/filters/f_4_64.mat +0 -0
  39. bytesep/models/subband_tools/filters/h_4_64.mat +0 -0
  40. bytesep/models/subband_tools/pqmf.py +136 -0
  41. bytesep/models/unet.py +532 -0
  42. bytesep/models/unet_subbandtime.py +389 -0
  43. bytesep/optimizers/__init__.py +0 -0
  44. bytesep/optimizers/lr_schedulers.py +20 -0
  45. bytesep/plot_results/__init__.py +0 -0
  46. bytesep/plot_results/musdb18.py +198 -0
  47. bytesep/plot_results/plot_vctk-musdb18.py +87 -0
  48. bytesep/train.py +299 -0
  49. bytesep/utils.py +189 -0
  50. pyproject.toml +21 -0
LICENSE ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Copyright 2021 ByteDance
2
+
3
+ Licensed under the Apache License, Version 2.0 (the "License");
4
+ you may not use this file except in compliance with the License.
5
+ You may obtain a copy of the License at
6
+
7
+ http://www.apache.org/licenses/LICENSE-2.0
8
+
9
+ Unless required by applicable law or agreed to in writing, software
10
+ distributed under the License is distributed on an "AS IS" BASIS,
11
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ See the License for the specific language governing permissions and
13
+ limitations under the License.
bytesep/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from bytesep.inference import Separator
bytesep/callbacks/__init__.py ADDED
@@ -0,0 +1,76 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List
2
+
3
+ import pytorch_lightning as pl
4
+ import torch.nn as nn
5
+
6
+
7
+ def get_callbacks(
8
+ task_name: str,
9
+ config_yaml: str,
10
+ workspace: str,
11
+ checkpoints_dir: str,
12
+ statistics_path: str,
13
+ logger: pl.loggers.TensorBoardLogger,
14
+ model: nn.Module,
15
+ evaluate_device: str,
16
+ ) -> List[pl.Callback]:
17
+ r"""Get callbacks of a task and config yaml file.
18
+
19
+ Args:
20
+ task_name: str
21
+ config_yaml: str
22
+ dataset_dir: str
23
+ workspace: str, containing useful files such as audios for evaluation
24
+ checkpoints_dir: str, directory to save checkpoints
25
+ statistics_dir: str, directory to save statistics
26
+ logger: pl.loggers.TensorBoardLogger
27
+ model: nn.Module
28
+ evaluate_device: str
29
+
30
+ Return:
31
+ callbacks: List[pl.Callback]
32
+ """
33
+ if task_name == 'musdb18':
34
+
35
+ from bytesep.callbacks.musdb18 import get_musdb18_callbacks
36
+
37
+ return get_musdb18_callbacks(
38
+ config_yaml=config_yaml,
39
+ workspace=workspace,
40
+ checkpoints_dir=checkpoints_dir,
41
+ statistics_path=statistics_path,
42
+ logger=logger,
43
+ model=model,
44
+ evaluate_device=evaluate_device,
45
+ )
46
+
47
+ elif task_name == 'voicebank-demand':
48
+
49
+ from bytesep.callbacks.voicebank_demand import get_voicebank_demand_callbacks
50
+
51
+ return get_voicebank_demand_callbacks(
52
+ config_yaml=config_yaml,
53
+ workspace=workspace,
54
+ checkpoints_dir=checkpoints_dir,
55
+ statistics_path=statistics_path,
56
+ logger=logger,
57
+ model=model,
58
+ evaluate_device=evaluate_device,
59
+ )
60
+
61
+ elif task_name in ['vctk-musdb18', 'violin-piano', 'piano-symphony']:
62
+
63
+ from bytesep.callbacks.instruments_callbacks import get_instruments_callbacks
64
+
65
+ return get_instruments_callbacks(
66
+ config_yaml=config_yaml,
67
+ workspace=workspace,
68
+ checkpoints_dir=checkpoints_dir,
69
+ statistics_path=statistics_path,
70
+ logger=logger,
71
+ model=model,
72
+ evaluate_device=evaluate_device,
73
+ )
74
+
75
+ else:
76
+ raise NotImplementedError
bytesep/callbacks/base_callbacks.py ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import os
3
+ from typing import NoReturn
4
+
5
+ import pytorch_lightning as pl
6
+ import torch
7
+ import torch.nn as nn
8
+ from pytorch_lightning.utilities import rank_zero_only
9
+
10
+
11
+ class SaveCheckpointsCallback(pl.Callback):
12
+ def __init__(
13
+ self,
14
+ model: nn.Module,
15
+ checkpoints_dir: str,
16
+ save_step_frequency: int,
17
+ ):
18
+ r"""Callback to save checkpoints every #save_step_frequency steps.
19
+
20
+ Args:
21
+ model: nn.Module
22
+ checkpoints_dir: str, directory to save checkpoints
23
+ save_step_frequency: int
24
+ """
25
+ self.model = model
26
+ self.checkpoints_dir = checkpoints_dir
27
+ self.save_step_frequency = save_step_frequency
28
+ os.makedirs(self.checkpoints_dir, exist_ok=True)
29
+
30
+ @rank_zero_only
31
+ def on_batch_end(self, trainer: pl.Trainer, _) -> NoReturn:
32
+ r"""Save checkpoint."""
33
+ global_step = trainer.global_step
34
+
35
+ if global_step % self.save_step_frequency == 0:
36
+
37
+ checkpoint_path = os.path.join(
38
+ self.checkpoints_dir, "step={}.pth".format(global_step)
39
+ )
40
+
41
+ checkpoint = {'step': global_step, 'model': self.model.state_dict()}
42
+
43
+ torch.save(checkpoint, checkpoint_path)
44
+ logging.info("Save checkpoint to {}".format(checkpoint_path))
bytesep/callbacks/instruments_callbacks.py ADDED
@@ -0,0 +1,200 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import os
3
+ import time
4
+ from typing import List, NoReturn
5
+
6
+ import librosa
7
+ import numpy as np
8
+ import pytorch_lightning as pl
9
+ import torch.nn as nn
10
+ from pytorch_lightning.utilities import rank_zero_only
11
+
12
+ from bytesep.callbacks.base_callbacks import SaveCheckpointsCallback
13
+ from bytesep.inference import Separator
14
+ from bytesep.utils import StatisticsContainer, calculate_sdr, read_yaml
15
+
16
+
17
+ def get_instruments_callbacks(
18
+ config_yaml: str,
19
+ workspace: str,
20
+ checkpoints_dir: str,
21
+ statistics_path: str,
22
+ logger: pl.loggers.TensorBoardLogger,
23
+ model: nn.Module,
24
+ evaluate_device: str,
25
+ ) -> List[pl.Callback]:
26
+ """Get Voicebank-Demand callbacks of a config yaml.
27
+
28
+ Args:
29
+ config_yaml: str
30
+ workspace: str
31
+ checkpoints_dir: str, directory to save checkpoints
32
+ statistics_dir: str, directory to save statistics
33
+ logger: pl.loggers.TensorBoardLogger
34
+ model: nn.Module
35
+ evaluate_device: str
36
+
37
+ Return:
38
+ callbacks: List[pl.Callback]
39
+ """
40
+ configs = read_yaml(config_yaml)
41
+ task_name = configs['task_name']
42
+ target_source_types = configs['train']['target_source_types']
43
+ input_channels = configs['train']['channels']
44
+ mono = True if input_channels == 1 else False
45
+ test_audios_dir = os.path.join(workspace, "evaluation_audios", task_name, "test")
46
+ sample_rate = configs['train']['sample_rate']
47
+ evaluate_step_frequency = configs['train']['evaluate_step_frequency']
48
+ save_step_frequency = configs['train']['save_step_frequency']
49
+ test_batch_size = configs['evaluate']['batch_size']
50
+ test_segment_seconds = configs['evaluate']['segment_seconds']
51
+
52
+ test_segment_samples = int(test_segment_seconds * sample_rate)
53
+ assert len(target_source_types) == 1
54
+ target_source_type = target_source_types[0]
55
+
56
+ # save checkpoint callback
57
+ save_checkpoints_callback = SaveCheckpointsCallback(
58
+ model=model,
59
+ checkpoints_dir=checkpoints_dir,
60
+ save_step_frequency=save_step_frequency,
61
+ )
62
+
63
+ # statistics container
64
+ statistics_container = StatisticsContainer(statistics_path)
65
+
66
+ # evaluation callback
67
+ evaluate_test_callback = EvaluationCallback(
68
+ model=model,
69
+ target_source_type=target_source_type,
70
+ input_channels=input_channels,
71
+ sample_rate=sample_rate,
72
+ mono=mono,
73
+ evaluation_audios_dir=test_audios_dir,
74
+ segment_samples=test_segment_samples,
75
+ batch_size=test_batch_size,
76
+ device=evaluate_device,
77
+ evaluate_step_frequency=evaluate_step_frequency,
78
+ logger=logger,
79
+ statistics_container=statistics_container,
80
+ )
81
+
82
+ callbacks = [save_checkpoints_callback, evaluate_test_callback]
83
+ # callbacks = [save_checkpoints_callback]
84
+
85
+ return callbacks
86
+
87
+
88
+ class EvaluationCallback(pl.Callback):
89
+ def __init__(
90
+ self,
91
+ model: nn.Module,
92
+ input_channels: int,
93
+ evaluation_audios_dir: str,
94
+ target_source_type: str,
95
+ sample_rate: int,
96
+ mono: bool,
97
+ segment_samples: int,
98
+ batch_size: int,
99
+ device: str,
100
+ evaluate_step_frequency: int,
101
+ logger: pl.loggers.TensorBoardLogger,
102
+ statistics_container: StatisticsContainer,
103
+ ):
104
+ r"""Callback to evaluate every #save_step_frequency steps.
105
+
106
+ Args:
107
+ model: nn.Module
108
+ input_channels: int
109
+ evaluation_audios_dir: str, directory containing audios for evaluation
110
+ target_source_type: str, e.g., 'violin'
111
+ sample_rate: int
112
+ mono: bool
113
+ segment_samples: int, length of segments to be input to a model, e.g., 44100*30
114
+ batch_size, int, e.g., 12
115
+ device: str, e.g., 'cuda'
116
+ evaluate_step_frequency: int, evaluate every #save_step_frequency steps
117
+ logger: pl.loggers.TensorBoardLogger
118
+ statistics_container: StatisticsContainer
119
+ """
120
+ self.model = model
121
+ self.target_source_type = target_source_type
122
+ self.sample_rate = sample_rate
123
+ self.mono = mono
124
+ self.segment_samples = segment_samples
125
+ self.evaluate_step_frequency = evaluate_step_frequency
126
+ self.logger = logger
127
+ self.statistics_container = statistics_container
128
+
129
+ self.evaluation_audios_dir = evaluation_audios_dir
130
+
131
+ # separator
132
+ self.separator = Separator(model, self.segment_samples, batch_size, device)
133
+
134
+ @rank_zero_only
135
+ def on_batch_end(self, trainer: pl.Trainer, _) -> NoReturn:
136
+ r"""Evaluate losses on a few mini-batches. Losses are only used for
137
+ observing training, and are not final F1 metrics.
138
+ """
139
+
140
+ global_step = trainer.global_step
141
+
142
+ if global_step % self.evaluate_step_frequency == 0:
143
+
144
+ mixture_audios_dir = os.path.join(self.evaluation_audios_dir, 'mixture')
145
+ clean_audios_dir = os.path.join(
146
+ self.evaluation_audios_dir, self.target_source_type
147
+ )
148
+
149
+ audio_names = sorted(os.listdir(mixture_audios_dir))
150
+
151
+ error_str = "Directory {} does not contain audios for evaluation!".format(
152
+ self.evaluation_audios_dir
153
+ )
154
+ assert len(audio_names) > 0, error_str
155
+
156
+ logging.info("--- Step {} ---".format(global_step))
157
+ logging.info("Total {} pieces for evaluation:".format(len(audio_names)))
158
+
159
+ eval_time = time.time()
160
+
161
+ sdrs = []
162
+
163
+ for n, audio_name in enumerate(audio_names):
164
+
165
+ # Load audio.
166
+ mixture_path = os.path.join(mixture_audios_dir, audio_name)
167
+ clean_path = os.path.join(clean_audios_dir, audio_name)
168
+
169
+ mixture, origin_fs = librosa.core.load(
170
+ mixture_path, sr=self.sample_rate, mono=self.mono
171
+ )
172
+
173
+ # Target
174
+ clean, origin_fs = librosa.core.load(
175
+ clean_path, sr=self.sample_rate, mono=self.mono
176
+ )
177
+
178
+ if mixture.ndim == 1:
179
+ mixture = mixture[None, :]
180
+ # (channels_num, audio_length)
181
+
182
+ input_dict = {'waveform': mixture}
183
+
184
+ # separate
185
+ sep_wav = self.separator.separate(input_dict)
186
+ # (channels_num, audio_length)
187
+
188
+ sdr = calculate_sdr(ref=clean, est=sep_wav)
189
+
190
+ print("{} SDR: {:.3f}".format(audio_name, sdr))
191
+ sdrs.append(sdr)
192
+
193
+ logging.info("-----------------------------")
194
+ logging.info('Avg SDR: {:.3f}'.format(np.mean(sdrs)))
195
+
196
+ logging.info("Evlauation time: {:.3f}".format(time.time() - eval_time))
197
+
198
+ statistics = {"sdr": np.mean(sdrs)}
199
+ self.statistics_container.append(global_step, statistics, 'test')
200
+ self.statistics_container.dump()
bytesep/callbacks/musdb18.py ADDED
@@ -0,0 +1,485 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import os
3
+ import time
4
+ from typing import Dict, List, NoReturn
5
+
6
+ import librosa
7
+ import musdb
8
+ import museval
9
+ import numpy as np
10
+ import pytorch_lightning as pl
11
+ import torch.nn as nn
12
+ from pytorch_lightning.utilities import rank_zero_only
13
+
14
+ from bytesep.callbacks.base_callbacks import SaveCheckpointsCallback
15
+ from bytesep.dataset_creation.pack_audios_to_hdf5s.musdb18 import preprocess_audio
16
+ from bytesep.inference import Separator
17
+ from bytesep.utils import StatisticsContainer, read_yaml
18
+
19
+
20
+ def get_musdb18_callbacks(
21
+ config_yaml: str,
22
+ workspace: str,
23
+ checkpoints_dir: str,
24
+ statistics_path: str,
25
+ logger: pl.loggers.TensorBoardLogger,
26
+ model: nn.Module,
27
+ evaluate_device: str,
28
+ ) -> List[pl.Callback]:
29
+ r"""Get MUSDB18 callbacks of a config yaml.
30
+
31
+ Args:
32
+ config_yaml: str
33
+ workspace: str
34
+ checkpoints_dir: str, directory to save checkpoints
35
+ statistics_dir: str, directory to save statistics
36
+ logger: pl.loggers.TensorBoardLogger
37
+ model: nn.Module
38
+ evaluate_device: str
39
+
40
+ Return:
41
+ callbacks: List[pl.Callback]
42
+ """
43
+ configs = read_yaml(config_yaml)
44
+ task_name = configs['task_name']
45
+ evaluation_callback = configs['train']['evaluation_callback']
46
+ target_source_types = configs['train']['target_source_types']
47
+ input_channels = configs['train']['channels']
48
+ evaluation_audios_dir = os.path.join(workspace, "evaluation_audios", task_name)
49
+ test_segment_seconds = configs['evaluate']['segment_seconds']
50
+ sample_rate = configs['train']['sample_rate']
51
+ test_segment_samples = int(test_segment_seconds * sample_rate)
52
+ test_batch_size = configs['evaluate']['batch_size']
53
+
54
+ evaluate_step_frequency = configs['train']['evaluate_step_frequency']
55
+ save_step_frequency = configs['train']['save_step_frequency']
56
+
57
+ # save checkpoint callback
58
+ save_checkpoints_callback = SaveCheckpointsCallback(
59
+ model=model,
60
+ checkpoints_dir=checkpoints_dir,
61
+ save_step_frequency=save_step_frequency,
62
+ )
63
+
64
+ # evaluation callback
65
+ EvaluationCallback = _get_evaluation_callback_class(evaluation_callback)
66
+
67
+ # statistics container
68
+ statistics_container = StatisticsContainer(statistics_path)
69
+
70
+ # evaluation callback
71
+ evaluate_train_callback = EvaluationCallback(
72
+ dataset_dir=evaluation_audios_dir,
73
+ model=model,
74
+ target_source_types=target_source_types,
75
+ input_channels=input_channels,
76
+ sample_rate=sample_rate,
77
+ split='train',
78
+ segment_samples=test_segment_samples,
79
+ batch_size=test_batch_size,
80
+ device=evaluate_device,
81
+ evaluate_step_frequency=evaluate_step_frequency,
82
+ logger=logger,
83
+ statistics_container=statistics_container,
84
+ )
85
+
86
+ evaluate_test_callback = EvaluationCallback(
87
+ dataset_dir=evaluation_audios_dir,
88
+ model=model,
89
+ target_source_types=target_source_types,
90
+ input_channels=input_channels,
91
+ sample_rate=sample_rate,
92
+ split='test',
93
+ segment_samples=test_segment_samples,
94
+ batch_size=test_batch_size,
95
+ device=evaluate_device,
96
+ evaluate_step_frequency=evaluate_step_frequency,
97
+ logger=logger,
98
+ statistics_container=statistics_container,
99
+ )
100
+
101
+ # callbacks = [save_checkpoints_callback, evaluate_train_callback, evaluate_test_callback]
102
+ callbacks = [save_checkpoints_callback, evaluate_test_callback]
103
+
104
+ return callbacks
105
+
106
+
107
+ def _get_evaluation_callback_class(evaluation_callback) -> pl.Callback:
108
+ r"""Get evaluation callback class."""
109
+ if evaluation_callback == "Musdb18EvaluationCallback":
110
+ return Musdb18EvaluationCallback
111
+
112
+ if evaluation_callback == 'Musdb18ConditionalEvaluationCallback':
113
+ return Musdb18ConditionalEvaluationCallback
114
+
115
+ else:
116
+ raise NotImplementedError
117
+
118
+
119
+ class Musdb18EvaluationCallback(pl.Callback):
120
+ def __init__(
121
+ self,
122
+ dataset_dir: str,
123
+ model: nn.Module,
124
+ target_source_types: str,
125
+ input_channels: int,
126
+ split: str,
127
+ sample_rate: int,
128
+ segment_samples: int,
129
+ batch_size: int,
130
+ device: str,
131
+ evaluate_step_frequency: int,
132
+ logger: pl.loggers.TensorBoardLogger,
133
+ statistics_container: StatisticsContainer,
134
+ ):
135
+ r"""Callback to evaluate every #save_step_frequency steps.
136
+
137
+ Args:
138
+ dataset_dir: str
139
+ model: nn.Module
140
+ target_source_types: List[str], e.g., ['vocals', 'bass', ...]
141
+ input_channels: int
142
+ split: 'train' | 'test'
143
+ sample_rate: int
144
+ segment_samples: int, length of segments to be input to a model, e.g., 44100*30
145
+ batch_size, int, e.g., 12
146
+ device: str, e.g., 'cuda'
147
+ evaluate_step_frequency: int, evaluate every #save_step_frequency steps
148
+ logger: object
149
+ statistics_container: StatisticsContainer
150
+ """
151
+ self.model = model
152
+ self.target_source_types = target_source_types
153
+ self.input_channels = input_channels
154
+ self.sample_rate = sample_rate
155
+ self.split = split
156
+ self.segment_samples = segment_samples
157
+ self.evaluate_step_frequency = evaluate_step_frequency
158
+ self.logger = logger
159
+ self.statistics_container = statistics_container
160
+ self.mono = input_channels == 1
161
+ self.resample_type = "kaiser_fast"
162
+
163
+ self.mus = musdb.DB(root=dataset_dir, subsets=[split])
164
+
165
+ error_msg = "The directory {} is empty!".format(dataset_dir)
166
+ assert len(self.mus) > 0, error_msg
167
+
168
+ # separator
169
+ self.separator = Separator(model, self.segment_samples, batch_size, device)
170
+
171
+ @rank_zero_only
172
+ def on_batch_end(self, trainer: pl.Trainer, _) -> NoReturn:
173
+ r"""Evaluate separation SDRs of audio recordings."""
174
+ global_step = trainer.global_step
175
+
176
+ if global_step % self.evaluate_step_frequency == 0:
177
+
178
+ sdr_dict = {}
179
+
180
+ logging.info("--- Step {} ---".format(global_step))
181
+ logging.info("Total {} pieces for evaluation:".format(len(self.mus.tracks)))
182
+
183
+ eval_time = time.time()
184
+
185
+ for track in self.mus.tracks:
186
+
187
+ audio_name = track.name
188
+
189
+ # Get waveform of mixture.
190
+ mixture = track.audio.T
191
+ # (channels_num, audio_samples)
192
+
193
+ mixture = preprocess_audio(
194
+ audio=mixture,
195
+ mono=self.mono,
196
+ origin_sr=track.rate,
197
+ sr=self.sample_rate,
198
+ resample_type=self.resample_type,
199
+ )
200
+ # (channels_num, audio_samples)
201
+
202
+ target_dict = {}
203
+ sdr_dict[audio_name] = {}
204
+
205
+ # Get waveform of all target source types.
206
+ for j, source_type in enumerate(self.target_source_types):
207
+ # E.g., ['vocals', 'bass', ...]
208
+
209
+ audio = track.targets[source_type].audio.T
210
+
211
+ audio = preprocess_audio(
212
+ audio=audio,
213
+ mono=self.mono,
214
+ origin_sr=track.rate,
215
+ sr=self.sample_rate,
216
+ resample_type=self.resample_type,
217
+ )
218
+ # (channels_num, audio_samples)
219
+
220
+ target_dict[source_type] = audio
221
+ # (channels_num, audio_samples)
222
+
223
+ # Separate.
224
+ input_dict = {'waveform': mixture}
225
+
226
+ sep_wavs = self.separator.separate(input_dict)
227
+ # sep_wavs: (target_sources_num * channels_num, audio_samples)
228
+
229
+ # Post process separation results.
230
+ sep_wavs = preprocess_audio(
231
+ audio=sep_wavs,
232
+ mono=self.mono,
233
+ origin_sr=self.sample_rate,
234
+ sr=track.rate,
235
+ resample_type=self.resample_type,
236
+ )
237
+ # sep_wavs: (target_sources_num * channels_num, audio_samples)
238
+
239
+ sep_wavs = librosa.util.fix_length(
240
+ sep_wavs, size=mixture.shape[1], axis=1
241
+ )
242
+ # sep_wavs: (target_sources_num * channels_num, audio_samples)
243
+
244
+ sep_wav_dict = get_separated_wavs_from_simo_output(
245
+ sep_wavs, self.input_channels, self.target_source_types
246
+ )
247
+ # output_dict: dict, e.g., {
248
+ # 'vocals': (channels_num, audio_samples),
249
+ # 'bass': (channels_num, audio_samples),
250
+ # ...,
251
+ # }
252
+
253
+ # Evaluate for all target source types.
254
+ for source_type in self.target_source_types:
255
+ # E.g., ['vocals', 'bass', ...]
256
+
257
+ # Calculate SDR using museval, input shape should be: (nsrc, nsampl, nchan).
258
+ (sdrs, _, _, _) = museval.evaluate(
259
+ [target_dict[source_type].T], [sep_wav_dict[source_type].T]
260
+ )
261
+
262
+ sdr = np.nanmedian(sdrs)
263
+ sdr_dict[audio_name][source_type] = sdr
264
+
265
+ logging.info(
266
+ "{}, {}, sdr: {:.3f}".format(audio_name, source_type, sdr)
267
+ )
268
+
269
+ logging.info("-----------------------------")
270
+ median_sdr_dict = {}
271
+
272
+ # Calculate median SDRs of all songs.
273
+ for source_type in self.target_source_types:
274
+ # E.g., ['vocals', 'bass', ...]
275
+
276
+ median_sdr = np.median(
277
+ [
278
+ sdr_dict[audio_name][source_type]
279
+ for audio_name in sdr_dict.keys()
280
+ ]
281
+ )
282
+
283
+ median_sdr_dict[source_type] = median_sdr
284
+
285
+ logging.info(
286
+ "Step: {}, {}, Median SDR: {:.3f}".format(
287
+ global_step, source_type, median_sdr
288
+ )
289
+ )
290
+
291
+ logging.info("Evlauation time: {:.3f}".format(time.time() - eval_time))
292
+
293
+ statistics = {"sdr_dict": sdr_dict, "median_sdr_dict": median_sdr_dict}
294
+ self.statistics_container.append(global_step, statistics, self.split)
295
+ self.statistics_container.dump()
296
+
297
+
298
+ def get_separated_wavs_from_simo_output(x, input_channels, target_source_types) -> Dict:
299
+ r"""Get separated waveforms of target sources from a single input multiple
300
+ output (SIMO) system.
301
+
302
+ Args:
303
+ x: (target_sources_num * channels_num, audio_samples)
304
+ input_channels: int
305
+ target_source_types: List[str], e.g., ['vocals', 'bass', ...]
306
+
307
+ Returns:
308
+ output_dict: dict, e.g., {
309
+ 'vocals': (channels_num, audio_samples),
310
+ 'bass': (channels_num, audio_samples),
311
+ ...,
312
+ }
313
+ """
314
+ output_dict = {}
315
+
316
+ for j, source_type in enumerate(target_source_types):
317
+ output_dict[source_type] = x[j * input_channels : (j + 1) * input_channels]
318
+
319
+ return output_dict
320
+
321
+
322
+ class Musdb18ConditionalEvaluationCallback(pl.Callback):
323
+ def __init__(
324
+ self,
325
+ dataset_dir: str,
326
+ model: nn.Module,
327
+ target_source_types: str,
328
+ input_channels: int,
329
+ split: str,
330
+ sample_rate: int,
331
+ segment_samples: int,
332
+ batch_size: int,
333
+ device: str,
334
+ evaluate_step_frequency: int,
335
+ logger: pl.loggers.TensorBoardLogger,
336
+ statistics_container: StatisticsContainer,
337
+ ):
338
+ r"""Callback to evaluate every #save_step_frequency steps.
339
+
340
+ Args:
341
+ dataset_dir: str
342
+ model: nn.Module
343
+ target_source_types: List[str], e.g., ['vocals', 'bass', ...]
344
+ input_channels: int
345
+ split: 'train' | 'test'
346
+ sample_rate: int
347
+ segment_samples: int, length of segments to be input to a model, e.g., 44100*30
348
+ batch_size, int, e.g., 12
349
+ device: str, e.g., 'cuda'
350
+ evaluate_step_frequency: int, evaluate every #save_step_frequency steps
351
+ logger: object
352
+ statistics_container: StatisticsContainer
353
+ """
354
+ self.model = model
355
+ self.target_source_types = target_source_types
356
+ self.input_channels = input_channels
357
+ self.sample_rate = sample_rate
358
+ self.split = split
359
+ self.segment_samples = segment_samples
360
+ self.evaluate_step_frequency = evaluate_step_frequency
361
+ self.logger = logger
362
+ self.statistics_container = statistics_container
363
+ self.mono = input_channels == 1
364
+ self.resample_type = "kaiser_fast"
365
+
366
+ self.mus = musdb.DB(root=dataset_dir, subsets=[split])
367
+
368
+ error_msg = "The directory {} is empty!".format(dataset_dir)
369
+ assert len(self.mus) > 0, error_msg
370
+
371
+ # separator
372
+ self.separator = Separator(model, self.segment_samples, batch_size, device)
373
+
374
+ @rank_zero_only
375
+ def on_batch_end(self, trainer: pl.Trainer, _) -> NoReturn:
376
+ r"""Evaluate separation SDRs of audio recordings."""
377
+ global_step = trainer.global_step
378
+
379
+ if global_step % self.evaluate_step_frequency == 0:
380
+
381
+ sdr_dict = {}
382
+
383
+ logging.info("--- Step {} ---".format(global_step))
384
+ logging.info("Total {} pieces for evaluation:".format(len(self.mus.tracks)))
385
+
386
+ eval_time = time.time()
387
+
388
+ for track in self.mus.tracks:
389
+
390
+ audio_name = track.name
391
+
392
+ # Get waveform of mixture.
393
+ mixture = track.audio.T
394
+ # (channels_num, audio_samples)
395
+
396
+ mixture = preprocess_audio(
397
+ audio=mixture,
398
+ mono=self.mono,
399
+ origin_sr=track.rate,
400
+ sr=self.sample_rate,
401
+ resample_type=self.resample_type,
402
+ )
403
+ # (channels_num, audio_samples)
404
+
405
+ target_dict = {}
406
+ sdr_dict[audio_name] = {}
407
+
408
+ # Get waveform of all target source types.
409
+ for j, source_type in enumerate(self.target_source_types):
410
+ # E.g., ['vocals', 'bass', ...]
411
+
412
+ audio = track.targets[source_type].audio.T
413
+
414
+ audio = preprocess_audio(
415
+ audio=audio,
416
+ mono=self.mono,
417
+ origin_sr=track.rate,
418
+ sr=self.sample_rate,
419
+ resample_type=self.resample_type,
420
+ )
421
+ # (channels_num, audio_samples)
422
+
423
+ target_dict[source_type] = audio
424
+ # (channels_num, audio_samples)
425
+
426
+ condition = np.zeros(len(self.target_source_types))
427
+ condition[j] = 1
428
+
429
+ input_dict = {'waveform': mixture, 'condition': condition}
430
+
431
+ sep_wav = self.separator.separate(input_dict)
432
+ # sep_wav: (channels_num, audio_samples)
433
+
434
+ sep_wav = preprocess_audio(
435
+ audio=sep_wav,
436
+ mono=self.mono,
437
+ origin_sr=self.sample_rate,
438
+ sr=track.rate,
439
+ resample_type=self.resample_type,
440
+ )
441
+ # sep_wav: (channels_num, audio_samples)
442
+
443
+ sep_wav = librosa.util.fix_length(
444
+ sep_wav, size=mixture.shape[1], axis=1
445
+ )
446
+ # sep_wav: (target_sources_num * channels_num, audio_samples)
447
+
448
+ # Calculate SDR using museval, input shape should be: (nsrc, nsampl, nchan)
449
+ (sdrs, _, _, _) = museval.evaluate(
450
+ [target_dict[source_type].T], [sep_wav.T]
451
+ )
452
+
453
+ sdr = np.nanmedian(sdrs)
454
+ sdr_dict[audio_name][source_type] = sdr
455
+
456
+ logging.info(
457
+ "{}, {}, sdr: {:.3f}".format(audio_name, source_type, sdr)
458
+ )
459
+
460
+ logging.info("-----------------------------")
461
+ median_sdr_dict = {}
462
+
463
+ # Calculate median SDRs of all songs.
464
+ for source_type in self.target_source_types:
465
+
466
+ median_sdr = np.median(
467
+ [
468
+ sdr_dict[audio_name][source_type]
469
+ for audio_name in sdr_dict.keys()
470
+ ]
471
+ )
472
+
473
+ median_sdr_dict[source_type] = median_sdr
474
+
475
+ logging.info(
476
+ "Step: {}, {}, Median SDR: {:.3f}".format(
477
+ global_step, source_type, median_sdr
478
+ )
479
+ )
480
+
481
+ logging.info("Evlauation time: {:.3f}".format(time.time() - eval_time))
482
+
483
+ statistics = {"sdr_dict": sdr_dict, "median_sdr_dict": median_sdr_dict}
484
+ self.statistics_container.append(global_step, statistics, self.split)
485
+ self.statistics_container.dump()
bytesep/callbacks/voicebank_demand.py ADDED
@@ -0,0 +1,231 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import os
3
+ import time
4
+ from typing import List, NoReturn
5
+
6
+ import librosa
7
+ import numpy as np
8
+ import pysepm
9
+ import pytorch_lightning as pl
10
+ import torch.nn as nn
11
+ from pesq import pesq
12
+ from pytorch_lightning.utilities import rank_zero_only
13
+
14
+ from bytesep.callbacks.base_callbacks import SaveCheckpointsCallback
15
+ from bytesep.inference import Separator
16
+ from bytesep.utils import StatisticsContainer, read_yaml
17
+
18
+
19
+ def get_voicebank_demand_callbacks(
20
+ config_yaml: str,
21
+ workspace: str,
22
+ checkpoints_dir: str,
23
+ statistics_path: str,
24
+ logger: pl.loggers.TensorBoardLogger,
25
+ model: nn.Module,
26
+ evaluate_device: str,
27
+ ) -> List[pl.Callback]:
28
+ """Get Voicebank-Demand callbacks of a config yaml.
29
+
30
+ Args:
31
+ config_yaml: str
32
+ workspace: str
33
+ checkpoints_dir: str, directory to save checkpoints
34
+ statistics_dir: str, directory to save statistics
35
+ logger: pl.loggers.TensorBoardLogger
36
+ model: nn.Module
37
+ evaluate_device: str
38
+
39
+ Return:
40
+ callbacks: List[pl.Callback]
41
+ """
42
+ configs = read_yaml(config_yaml)
43
+ task_name = configs['task_name']
44
+ target_source_types = configs['train']['target_source_types']
45
+ input_channels = configs['train']['channels']
46
+ evaluation_audios_dir = os.path.join(workspace, "evaluation_audios", task_name)
47
+ sample_rate = configs['train']['sample_rate']
48
+ evaluate_step_frequency = configs['train']['evaluate_step_frequency']
49
+ save_step_frequency = configs['train']['save_step_frequency']
50
+ test_batch_size = configs['evaluate']['batch_size']
51
+ test_segment_seconds = configs['evaluate']['segment_seconds']
52
+
53
+ test_segment_samples = int(test_segment_seconds * sample_rate)
54
+ assert len(target_source_types) == 1
55
+ target_source_type = target_source_types[0]
56
+ assert target_source_type == 'speech'
57
+
58
+ # save checkpoint callback
59
+ save_checkpoints_callback = SaveCheckpointsCallback(
60
+ model=model,
61
+ checkpoints_dir=checkpoints_dir,
62
+ save_step_frequency=save_step_frequency,
63
+ )
64
+
65
+ # statistics container
66
+ statistics_container = StatisticsContainer(statistics_path)
67
+
68
+ # evaluation callback
69
+ evaluate_test_callback = EvaluationCallback(
70
+ model=model,
71
+ input_channels=input_channels,
72
+ sample_rate=sample_rate,
73
+ evaluation_audios_dir=evaluation_audios_dir,
74
+ segment_samples=test_segment_samples,
75
+ batch_size=test_batch_size,
76
+ device=evaluate_device,
77
+ evaluate_step_frequency=evaluate_step_frequency,
78
+ logger=logger,
79
+ statistics_container=statistics_container,
80
+ )
81
+
82
+ callbacks = [save_checkpoints_callback, evaluate_test_callback]
83
+
84
+ return callbacks
85
+
86
+
87
+ class EvaluationCallback(pl.Callback):
88
+ def __init__(
89
+ self,
90
+ model: nn.Module,
91
+ input_channels: int,
92
+ evaluation_audios_dir,
93
+ sample_rate: int,
94
+ segment_samples: int,
95
+ batch_size: int,
96
+ device: str,
97
+ evaluate_step_frequency: int,
98
+ logger: pl.loggers.TensorBoardLogger,
99
+ statistics_container: StatisticsContainer,
100
+ ):
101
+ r"""Callback to evaluate every #save_step_frequency steps.
102
+
103
+ Args:
104
+ model: nn.Module
105
+ input_channels: int
106
+ evaluation_audios_dir: str, directory containing audios for evaluation
107
+ sample_rate: int
108
+ segment_samples: int, length of segments to be input to a model, e.g., 44100*30
109
+ batch_size, int, e.g., 12
110
+ device: str, e.g., 'cuda'
111
+ evaluate_step_frequency: int, evaluate every #save_step_frequency steps
112
+ logger: pl.loggers.TensorBoardLogger
113
+ statistics_container: StatisticsContainer
114
+ """
115
+ self.model = model
116
+ self.mono = True
117
+ self.sample_rate = sample_rate
118
+ self.segment_samples = segment_samples
119
+ self.evaluate_step_frequency = evaluate_step_frequency
120
+ self.logger = logger
121
+ self.statistics_container = statistics_container
122
+
123
+ self.clean_dir = os.path.join(evaluation_audios_dir, "clean_testset_wav")
124
+ self.noisy_dir = os.path.join(evaluation_audios_dir, "noisy_testset_wav")
125
+
126
+ self.EVALUATION_SAMPLE_RATE = 16000 # Evaluation sample rate of the
127
+ # Voicebank-Demand task.
128
+
129
+ # separator
130
+ self.separator = Separator(model, self.segment_samples, batch_size, device)
131
+
132
+ @rank_zero_only
133
+ def on_batch_end(self, trainer: pl.Trainer, _) -> NoReturn:
134
+ r"""Evaluate losses on a few mini-batches. Losses are only used for
135
+ observing training, and are not final F1 metrics.
136
+ """
137
+
138
+ global_step = trainer.global_step
139
+
140
+ if global_step % self.evaluate_step_frequency == 0:
141
+
142
+ audio_names = sorted(
143
+ [
144
+ audio_name
145
+ for audio_name in sorted(os.listdir(self.clean_dir))
146
+ if audio_name.endswith('.wav')
147
+ ]
148
+ )
149
+
150
+ error_str = "Directory {} does not contain audios for evaluation!".format(
151
+ self.clean_dir
152
+ )
153
+ assert len(audio_names) > 0, error_str
154
+
155
+ pesqs, csigs, cbaks, covls, ssnrs = [], [], [], [], []
156
+
157
+ logging.info("--- Step {} ---".format(global_step))
158
+ logging.info("Total {} pieces for evaluation:".format(len(audio_names)))
159
+
160
+ eval_time = time.time()
161
+
162
+ for n, audio_name in enumerate(audio_names):
163
+
164
+ # Load audio.
165
+ clean_path = os.path.join(self.clean_dir, audio_name)
166
+ mixture_path = os.path.join(self.noisy_dir, audio_name)
167
+
168
+ mixture, _ = librosa.core.load(
169
+ mixture_path, sr=self.sample_rate, mono=self.mono
170
+ )
171
+
172
+ if mixture.ndim == 1:
173
+ mixture = mixture[None, :]
174
+ # (channels_num, audio_length)
175
+
176
+ # Separate.
177
+ input_dict = {'waveform': mixture}
178
+
179
+ sep_wav = self.separator.separate(input_dict)
180
+ # (channels_num, audio_length)
181
+
182
+ # Target
183
+ clean, _ = librosa.core.load(
184
+ clean_path, sr=self.EVALUATION_SAMPLE_RATE, mono=self.mono
185
+ )
186
+
187
+ # to mono
188
+ sep_wav = np.squeeze(sep_wav)
189
+
190
+ # Resample for evaluation.
191
+ sep_wav = librosa.resample(
192
+ sep_wav,
193
+ orig_sr=self.sample_rate,
194
+ target_sr=self.EVALUATION_SAMPLE_RATE,
195
+ )
196
+
197
+ sep_wav = librosa.util.fix_length(sep_wav, size=len(clean), axis=0)
198
+ # (channels, audio_length)
199
+
200
+ # Evaluate metrics
201
+ pesq_ = pesq(self.EVALUATION_SAMPLE_RATE, clean, sep_wav, 'wb')
202
+
203
+ (csig, cbak, covl) = pysepm.composite(
204
+ clean, sep_wav, self.EVALUATION_SAMPLE_RATE
205
+ )
206
+
207
+ ssnr = pysepm.SNRseg(clean, sep_wav, self.EVALUATION_SAMPLE_RATE)
208
+
209
+ pesqs.append(pesq_)
210
+ csigs.append(csig)
211
+ cbaks.append(cbak)
212
+ covls.append(covl)
213
+ ssnrs.append(ssnr)
214
+ print(
215
+ '{}, {}, PESQ: {:.3f}, CSIG: {:.3f}, CBAK: {:.3f}, COVL: {:.3f}, SSNR: {:.3f}'.format(
216
+ n, audio_name, pesq_, csig, cbak, covl, ssnr
217
+ )
218
+ )
219
+
220
+ logging.info("-----------------------------")
221
+ logging.info('Avg PESQ: {:.3f}'.format(np.mean(pesqs)))
222
+ logging.info('Avg CSIG: {:.3f}'.format(np.mean(csigs)))
223
+ logging.info('Avg CBAK: {:.3f}'.format(np.mean(cbaks)))
224
+ logging.info('Avg COVL: {:.3f}'.format(np.mean(covls)))
225
+ logging.info('Avg SSNR: {:.3f}'.format(np.mean(ssnrs)))
226
+
227
+ logging.info("Evlauation time: {:.3f}".format(time.time() - eval_time))
228
+
229
+ statistics = {"pesq": np.mean(pesqs)}
230
+ self.statistics_container.append(global_step, statistics, 'test')
231
+ self.statistics_container.dump()
bytesep/data/__init__.py ADDED
File without changes
bytesep/data/augmentors.py ADDED
@@ -0,0 +1,157 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict
2
+
3
+ import librosa
4
+ import numpy as np
5
+
6
+ from bytesep.utils import db_to_magnitude, get_pitch_shift_factor, magnitude_to_db
7
+
8
+
9
+ class Augmentor:
10
+ def __init__(self, augmentations: Dict, random_seed=1234):
11
+ r"""Augmentor for data augmentation of a waveform.
12
+
13
+ Args:
14
+ augmentations: Dict, e.g, {
15
+ 'mixaudio': {'vocals': 2, 'accompaniment': 2}
16
+ 'pitch_shift': {'vocals': 4, 'accompaniment': 4},
17
+ ...,
18
+ }
19
+ random_seed: int
20
+ """
21
+ self.augmentations = augmentations
22
+ self.random_state = np.random.RandomState(random_seed)
23
+
24
+ def __call__(self, waveform: np.array, source_type: str) -> np.array:
25
+ r"""Augment a waveform.
26
+
27
+ Args:
28
+ waveform: (channels_num, audio_samples)
29
+ source_type: str
30
+
31
+ Returns:
32
+ new_waveform: (channels_num, new_audio_samples)
33
+ """
34
+ if 'pitch_shift' in self.augmentations.keys():
35
+ waveform = self.pitch_shift(waveform, source_type)
36
+
37
+ if 'magnitude_scale' in self.augmentations.keys():
38
+ waveform = self.magnitude_scale(waveform, source_type)
39
+
40
+ if 'swap_channel' in self.augmentations.keys():
41
+ waveform = self.swap_channel(waveform, source_type)
42
+
43
+ if 'flip_axis' in self.augmentations.keys():
44
+ waveform = self.flip_axis(waveform, source_type)
45
+
46
+ return waveform
47
+
48
+ def pitch_shift(self, waveform: np.array, source_type: str) -> np.array:
49
+ r"""Shift the pitch of a waveform. We use resampling for fast pitch
50
+ shifting, so the speed will also be chaneged. The length of the returned
51
+ waveform will be changed.
52
+
53
+ Args:
54
+ waveform: (channels_num, audio_samples)
55
+ source_type: str
56
+
57
+ Returns:
58
+ new_waveform: (channels_num, new_audio_samples)
59
+ """
60
+
61
+ # maximum pitch shift in semitones
62
+ max_pitch_shift = self.augmentations['pitch_shift'][source_type]
63
+
64
+ if max_pitch_shift == 0: # No pitch shift augmentations.
65
+ return waveform
66
+
67
+ # random pitch shift
68
+ rand_pitch = self.random_state.uniform(
69
+ low=-max_pitch_shift, high=max_pitch_shift
70
+ )
71
+
72
+ # We use librosa.resample instead of librosa.effects.pitch_shift
73
+ # because it is 10x times faster.
74
+ pitch_shift_factor = get_pitch_shift_factor(rand_pitch)
75
+ dummy_sample_rate = 10000 # Dummy constant.
76
+
77
+ channels_num = waveform.shape[0]
78
+
79
+ if channels_num == 1:
80
+ waveform = np.squeeze(waveform)
81
+
82
+ new_waveform = librosa.resample(
83
+ y=waveform,
84
+ orig_sr=dummy_sample_rate,
85
+ target_sr=dummy_sample_rate / pitch_shift_factor,
86
+ res_type='linear',
87
+ axis=-1,
88
+ )
89
+
90
+ if channels_num == 1:
91
+ new_waveform = new_waveform[None, :]
92
+
93
+ return new_waveform
94
+
95
+ def magnitude_scale(self, waveform: np.array, source_type: str) -> np.array:
96
+ r"""Scale the magnitude of a waveform.
97
+
98
+ Args:
99
+ waveform: (channels_num, audio_samples)
100
+ source_type: str
101
+
102
+ Returns:
103
+ new_waveform: (channels_num, audio_samples)
104
+ """
105
+ lower_db = self.augmentations['magnitude_scale'][source_type]['lower_db']
106
+ higher_db = self.augmentations['magnitude_scale'][source_type]['higher_db']
107
+
108
+ if lower_db == 0 and higher_db == 0: # No magnitude scale augmentation.
109
+ return waveform
110
+
111
+ # The magnitude (in dB) of the sample with the maximum value.
112
+ waveform_db = magnitude_to_db(np.max(np.abs(waveform)))
113
+
114
+ new_waveform_db = self.random_state.uniform(
115
+ waveform_db + lower_db, min(waveform_db + higher_db, 0)
116
+ )
117
+
118
+ relative_db = new_waveform_db - waveform_db
119
+
120
+ relative_scale = db_to_magnitude(relative_db)
121
+
122
+ new_waveform = waveform * relative_scale
123
+
124
+ return new_waveform
125
+
126
+ def swap_channel(self, waveform: np.array, source_type: str) -> np.array:
127
+ r"""Randomly swap channels.
128
+
129
+ Args:
130
+ waveform: (channels_num, audio_samples)
131
+ source_type: str
132
+
133
+ Returns:
134
+ new_waveform: (channels_num, audio_samples)
135
+ """
136
+ ndim = waveform.shape[0]
137
+
138
+ if ndim == 1:
139
+ return waveform
140
+ else:
141
+ random_axes = self.random_state.permutation(ndim)
142
+ return waveform[random_axes, :]
143
+
144
+ def flip_axis(self, waveform: np.array, source_type: str) -> np.array:
145
+ r"""Randomly flip the waveform along x-axis.
146
+
147
+ Args:
148
+ waveform: (channels_num, audio_samples)
149
+ source_type: str
150
+
151
+ Returns:
152
+ new_waveform: (channels_num, audio_samples)
153
+ """
154
+ ndim = waveform.shape[0]
155
+ random_values = self.random_state.choice([-1, 1], size=ndim)
156
+
157
+ return waveform * random_values[:, None]
bytesep/data/batch_data_preprocessors.py ADDED
@@ -0,0 +1,141 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, List
2
+
3
+ import torch
4
+
5
+
6
+ class BasicBatchDataPreprocessor:
7
+ def __init__(self, target_source_types: List[str]):
8
+ r"""Batch data preprocessor. Used for preparing mixtures and targets for
9
+ training. If there are multiple target source types, the waveforms of
10
+ those sources will be stacked along the channel dimension.
11
+
12
+ Args:
13
+ target_source_types: List[str], e.g., ['vocals', 'bass', ...]
14
+ """
15
+ self.target_source_types = target_source_types
16
+
17
+ def __call__(self, batch_data_dict: Dict) -> List[Dict]:
18
+ r"""Format waveforms and targets for training.
19
+
20
+ Args:
21
+ batch_data_dict: dict, e.g., {
22
+ 'mixture': (batch_size, channels_num, segment_samples),
23
+ 'vocals': (batch_size, channels_num, segment_samples),
24
+ 'bass': (batch_size, channels_num, segment_samples),
25
+ ...,
26
+ }
27
+
28
+ Returns:
29
+ input_dict: dict, e.g., {
30
+ 'waveform': (batch_size, channels_num, segment_samples),
31
+ }
32
+ output_dict: dict, e.g., {
33
+ 'target': (batch_size, target_sources_num * channels_num, segment_samples)
34
+ }
35
+ """
36
+ mixtures = batch_data_dict['mixture']
37
+ # mixtures: (batch_size, channels_num, segment_samples)
38
+
39
+ # Concatenate waveforms of multiple targets along the channel axis.
40
+ targets = torch.cat(
41
+ [batch_data_dict[source_type] for source_type in self.target_source_types],
42
+ dim=1,
43
+ )
44
+ # targets: (batch_size, target_sources_num * channels_num, segment_samples)
45
+
46
+ input_dict = {'waveform': mixtures}
47
+ target_dict = {'waveform': targets}
48
+
49
+ return input_dict, target_dict
50
+
51
+
52
+ class ConditionalSisoBatchDataPreprocessor:
53
+ def __init__(self, target_source_types: List[str]):
54
+ r"""Conditional single input single output (SISO) batch data
55
+ preprocessor. Select one target source from several target sources as
56
+ training target and prepare the corresponding conditional vector.
57
+
58
+ Args:
59
+ target_source_types: List[str], e.g., ['vocals', 'bass', ...]
60
+ """
61
+ self.target_source_types = target_source_types
62
+
63
+ def __call__(self, batch_data_dict: Dict) -> List[Dict]:
64
+ r"""Format waveforms and targets for training.
65
+
66
+ Args:
67
+ batch_data_dict: dict, e.g., {
68
+ 'mixture': (batch_size, channels_num, segment_samples),
69
+ 'vocals': (batch_size, channels_num, segment_samples),
70
+ 'bass': (batch_size, channels_num, segment_samples),
71
+ ...,
72
+ }
73
+
74
+ Returns:
75
+ input_dict: dict, e.g., {
76
+ 'waveform': (batch_size, channels_num, segment_samples),
77
+ 'condition': (batch_size, target_sources_num),
78
+ }
79
+ output_dict: dict, e.g., {
80
+ 'target': (batch_size, channels_num, segment_samples)
81
+ }
82
+ """
83
+
84
+ batch_size = len(batch_data_dict['mixture'])
85
+ target_sources_num = len(self.target_source_types)
86
+
87
+ assert (
88
+ batch_size % target_sources_num == 0
89
+ ), "Batch size should be \
90
+ evenly divided by target sources number."
91
+
92
+ mixtures = batch_data_dict['mixture']
93
+ # mixtures: (batch_size, channels_num, segment_samples)
94
+
95
+ conditions = torch.zeros(batch_size, target_sources_num).to(mixtures.device)
96
+ # conditions: (batch_size, target_sources_num)
97
+
98
+ targets = []
99
+
100
+ for n in range(batch_size):
101
+
102
+ k = n % target_sources_num # source class index
103
+ source_type = self.target_source_types[k]
104
+
105
+ targets.append(batch_data_dict[source_type][n])
106
+
107
+ conditions[n, k] = 1
108
+
109
+ # conditions will looks like:
110
+ # [[1, 0, 0, 0],
111
+ # [0, 1, 0, 0],
112
+ # [0, 0, 1, 0],
113
+ # [0, 0, 0, 1],
114
+ # [1, 0, 0, 0],
115
+ # [0, 1, 0, 0],
116
+ # ...,
117
+ # ]
118
+
119
+ targets = torch.stack(targets, dim=0)
120
+ # targets: (batch_size, channels_num, segment_samples)
121
+
122
+ input_dict = {
123
+ 'waveform': mixtures,
124
+ 'condition': conditions,
125
+ }
126
+
127
+ target_dict = {'waveform': targets}
128
+
129
+ return input_dict, target_dict
130
+
131
+
132
+ def get_batch_data_preprocessor_class(batch_data_preprocessor_type: str) -> object:
133
+ r"""Get batch data preprocessor class."""
134
+ if batch_data_preprocessor_type == 'BasicBatchDataPreprocessor':
135
+ return BasicBatchDataPreprocessor
136
+
137
+ elif batch_data_preprocessor_type == 'ConditionalSisoBatchDataPreprocessor':
138
+ return ConditionalSisoBatchDataPreprocessor
139
+
140
+ else:
141
+ raise NotImplementedError
bytesep/data/data_modules.py ADDED
@@ -0,0 +1,187 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, List, NoReturn, Optional
2
+
3
+ import h5py
4
+ import librosa
5
+ import numpy as np
6
+ import torch
7
+ from pytorch_lightning.core.datamodule import LightningDataModule
8
+
9
+ from bytesep.data.samplers import DistributedSamplerWrapper
10
+ from bytesep.utils import int16_to_float32
11
+
12
+
13
+ class DataModule(LightningDataModule):
14
+ def __init__(
15
+ self,
16
+ train_sampler: object,
17
+ train_dataset: object,
18
+ num_workers: int,
19
+ distributed: bool,
20
+ ):
21
+ r"""Data module.
22
+
23
+ Args:
24
+ train_sampler: Sampler object
25
+ train_dataset: Dataset object
26
+ num_workers: int
27
+ distributed: bool
28
+ """
29
+ super().__init__()
30
+ self._train_sampler = train_sampler
31
+ self.train_dataset = train_dataset
32
+ self.num_workers = num_workers
33
+ self.distributed = distributed
34
+
35
+ def setup(self, stage: Optional[str] = None) -> NoReturn:
36
+ r"""called on every device."""
37
+
38
+ # SegmentSampler is used for selecting segments for training.
39
+ # On multiple devices, each SegmentSampler samples a part of mini-batch
40
+ # data.
41
+ if self.distributed:
42
+ self.train_sampler = DistributedSamplerWrapper(self._train_sampler)
43
+
44
+ else:
45
+ self.train_sampler = self._train_sampler
46
+
47
+ def train_dataloader(self) -> torch.utils.data.DataLoader:
48
+ r"""Get train loader."""
49
+ train_loader = torch.utils.data.DataLoader(
50
+ dataset=self.train_dataset,
51
+ batch_sampler=self.train_sampler,
52
+ collate_fn=collate_fn,
53
+ num_workers=self.num_workers,
54
+ pin_memory=True,
55
+ )
56
+
57
+ return train_loader
58
+
59
+
60
+ class Dataset:
61
+ def __init__(self, augmentor: object, segment_samples: int):
62
+ r"""Used for getting data according to a meta.
63
+
64
+ Args:
65
+ augmentor: Augmentor class
66
+ segment_samples: int
67
+ """
68
+ self.augmentor = augmentor
69
+ self.segment_samples = segment_samples
70
+
71
+ def __getitem__(self, meta: Dict) -> Dict:
72
+ r"""Return data according to a meta. E.g., an input meta looks like: {
73
+ 'vocals': [['song_A.h5', 6332760, 6465060], ['song_B.h5', 198450, 330750]],
74
+ 'accompaniment': [['song_C.h5', 24232920, 24365250], ['song_D.h5', 1569960, 1702260]]}.
75
+ }
76
+
77
+ Then, vocals segments of song_A and song_B will be mixed (mix-audio augmentation).
78
+ Accompaniment segments of song_C and song_B will be mixed (mix-audio augmentation).
79
+ Finally, mixture is created by summing vocals and accompaniment.
80
+
81
+ Args:
82
+ meta: dict, e.g., {
83
+ 'vocals': [['song_A.h5', 6332760, 6465060], ['song_B.h5', 198450, 330750]],
84
+ 'accompaniment': [['song_C.h5', 24232920, 24365250], ['song_D.h5', 1569960, 1702260]]}
85
+ }
86
+
87
+ Returns:
88
+ data_dict: dict, e.g., {
89
+ 'vocals': (channels, segments_num),
90
+ 'accompaniment': (channels, segments_num),
91
+ 'mixture': (channels, segments_num),
92
+ }
93
+ """
94
+ source_types = meta.keys()
95
+ data_dict = {}
96
+
97
+ for source_type in source_types:
98
+ # E.g., ['vocals', 'bass', ...]
99
+
100
+ waveforms = [] # Audio segments to be mix-audio augmented.
101
+
102
+ for m in meta[source_type]:
103
+ # E.g., {
104
+ # 'hdf5_path': '.../song_A.h5',
105
+ # 'key_in_hdf5': 'vocals',
106
+ # 'begin_sample': '13406400',
107
+ # 'end_sample': 13538700,
108
+ # }
109
+
110
+ hdf5_path = m['hdf5_path']
111
+ key_in_hdf5 = m['key_in_hdf5']
112
+ bgn_sample = m['begin_sample']
113
+ end_sample = m['end_sample']
114
+
115
+ with h5py.File(hdf5_path, 'r') as hf:
116
+
117
+ if source_type == 'audioset':
118
+ index_in_hdf5 = m['index_in_hdf5']
119
+ waveform = int16_to_float32(
120
+ hf['waveform'][index_in_hdf5][bgn_sample:end_sample]
121
+ )
122
+ waveform = waveform[None, :]
123
+ else:
124
+ waveform = int16_to_float32(
125
+ hf[key_in_hdf5][:, bgn_sample:end_sample]
126
+ )
127
+
128
+ if self.augmentor:
129
+ waveform = self.augmentor(waveform, source_type)
130
+
131
+ waveform = librosa.util.fix_length(
132
+ waveform, size=self.segment_samples, axis=1
133
+ )
134
+ # (channels_num, segments_num)
135
+
136
+ waveforms.append(waveform)
137
+ # E.g., waveforms: [(channels_num, audio_samples), (channels_num, audio_samples)]
138
+
139
+ # mix-audio augmentation
140
+ data_dict[source_type] = np.sum(waveforms, axis=0)
141
+ # data_dict[source_type]: (channels_num, audio_samples)
142
+
143
+ # data_dict looks like: {
144
+ # 'voclas': (channels_num, audio_samples),
145
+ # 'accompaniment': (channels_num, audio_samples)
146
+ # }
147
+
148
+ # Mix segments from different sources.
149
+ mixture = np.sum(
150
+ [data_dict[source_type] for source_type in source_types], axis=0
151
+ )
152
+ data_dict['mixture'] = mixture
153
+ # shape: (channels_num, audio_samples)
154
+
155
+ return data_dict
156
+
157
+
158
+ def collate_fn(list_data_dict: List[Dict]) -> Dict:
159
+ r"""Collate mini-batch data to inputs and targets for training.
160
+
161
+ Args:
162
+ list_data_dict: e.g., [
163
+ {'vocals': (channels_num, segment_samples),
164
+ 'accompaniment': (channels_num, segment_samples),
165
+ 'mixture': (channels_num, segment_samples)
166
+ },
167
+ {'vocals': (channels_num, segment_samples),
168
+ 'accompaniment': (channels_num, segment_samples),
169
+ 'mixture': (channels_num, segment_samples)
170
+ },
171
+ ...]
172
+
173
+ Returns:
174
+ data_dict: e.g. {
175
+ 'vocals': (batch_size, channels_num, segment_samples),
176
+ 'accompaniment': (batch_size, channels_num, segment_samples),
177
+ 'mixture': (batch_size, channels_num, segment_samples)
178
+ }
179
+ """
180
+ data_dict = {}
181
+
182
+ for key in list_data_dict[0].keys():
183
+ data_dict[key] = torch.Tensor(
184
+ np.array([data_dict[key] for data_dict in list_data_dict])
185
+ )
186
+
187
+ return data_dict
bytesep/data/samplers.py ADDED
@@ -0,0 +1,188 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pickle
2
+ from typing import Dict, List, NoReturn
3
+
4
+ import numpy as np
5
+ import torch.distributed as dist
6
+
7
+
8
+ class SegmentSampler:
9
+ def __init__(
10
+ self,
11
+ indexes_path: str,
12
+ segment_samples: int,
13
+ mixaudio_dict: Dict,
14
+ batch_size: int,
15
+ steps_per_epoch: int,
16
+ random_seed=1234,
17
+ ):
18
+ r"""Sample training indexes of sources.
19
+
20
+ Args:
21
+ indexes_path: str, path of indexes dict
22
+ segment_samplers: int
23
+ mixaudio_dict, dict, including hyper-parameters for mix-audio data
24
+ augmentation, e.g., {'voclas': 2, 'accompaniment': 2}
25
+ batch_size: int
26
+ steps_per_epoch: int, #steps_per_epoch is called an `epoch`
27
+ random_seed: int
28
+ """
29
+ self.segment_samples = segment_samples
30
+ self.mixaudio_dict = mixaudio_dict
31
+ self.batch_size = batch_size
32
+ self.steps_per_epoch = steps_per_epoch
33
+
34
+ self.meta_dict = pickle.load(open(indexes_path, "rb"))
35
+ # E.g., {
36
+ # 'vocals': [
37
+ # {'hdf5_path': 'songA.h5', 'key_in_hdf5': 'vocals', 'begin_sample': 0, 'end_sample': 132300},
38
+ # {'hdf5_path': 'songB.h5', 'key_in_hdf5': 'vocals', 'begin_sample': 4410, 'end_sample': 445410},
39
+ # ...
40
+ # ],
41
+ # 'accompaniment': [
42
+ # {'hdf5_path': 'songA.h5', 'key_in_hdf5': 'vocals', 'begin_sample': 0, 'end_sample': 132300},
43
+ # {'hdf5_path': 'songB.h5', 'key_in_hdf5': 'vocals', 'begin_sample': 4410, 'end_sample': 445410},
44
+ # ...
45
+ # ]
46
+ # }
47
+
48
+ self.source_types = self.meta_dict.keys()
49
+ # E.g., ['vocals', 'accompaniment']
50
+
51
+ self.pointers_dict = {source_type: 0 for source_type in self.source_types}
52
+ # E.g., {'vocals': 0, 'accompaniment': 0}
53
+
54
+ self.indexes_dict = {
55
+ source_type: np.arange(len(self.meta_dict[source_type]))
56
+ for source_type in self.source_types
57
+ }
58
+ # E.g. {
59
+ # 'vocals': [0, 1, ..., 225751],
60
+ # 'accompaniment': [0, 1, ..., 225751]
61
+ # }
62
+
63
+ self.random_state = np.random.RandomState(random_seed)
64
+
65
+ # Shuffle indexes.
66
+ for source_type in self.source_types:
67
+ self.random_state.shuffle(self.indexes_dict[source_type])
68
+ print("{}: {}".format(source_type, len(self.indexes_dict[source_type])))
69
+
70
+ def __iter__(self) -> List[Dict]:
71
+ r"""Yield a batch of meta info.
72
+
73
+ Returns:
74
+ batch_meta_list: (batch_size,) e.g., when mix-audio is 2, looks like [
75
+ {'vocals': [
76
+ {'hdf5_path': 'songA.h5', 'key_in_hdf5': 'vocals', 'begin_sample': 13406400, 'end_sample': 13538700},
77
+ {'hdf5_path': 'songB.h5', 'key_in_hdf5': 'vocals', 'begin_sample': 4440870, 'end_sample': 4573170}]
78
+ 'accompaniment': [
79
+ {'hdf5_path': 'songE.h5', 'key_in_hdf5': 'accompaniment', 'begin_sample': 14579460, 'end_sample': 14711760},
80
+ {'hdf5_path': 'songF.h5', 'key_in_hdf5': 'accompaniment', 'begin_sample': 3995460, 'end_sample': 4127760}]
81
+ }
82
+ ...
83
+ ]
84
+ """
85
+ batch_size = self.batch_size
86
+
87
+ while True:
88
+ batch_meta_dict = {source_type: [] for source_type in self.source_types}
89
+
90
+ for source_type in self.source_types:
91
+ # E.g., ['vocals', 'accompaniment']
92
+
93
+ # Loop until get a mini-batch.
94
+ while len(batch_meta_dict[source_type]) != batch_size:
95
+
96
+ largest_index = (
97
+ len(self.indexes_dict[source_type])
98
+ - self.mixaudio_dict[source_type]
99
+ )
100
+ # E.g., 225750 = 225752 - 2
101
+
102
+ if self.pointers_dict[source_type] > largest_index:
103
+
104
+ # Reset pointer, and shuffle indexes.
105
+ self.pointers_dict[source_type] = 0
106
+ self.random_state.shuffle(self.indexes_dict[source_type])
107
+
108
+ source_metas = []
109
+ mix_audios_num = self.mixaudio_dict[source_type]
110
+
111
+ for _ in range(mix_audios_num):
112
+
113
+ pointer = self.pointers_dict[source_type]
114
+ # E.g., 1
115
+
116
+ index = self.indexes_dict[source_type][pointer]
117
+ # E.g., 12231
118
+
119
+ self.pointers_dict[source_type] += 1
120
+
121
+ source_meta = self.meta_dict[source_type][index]
122
+ # E.g., ['song_A.h5', 198450, 330750]
123
+
124
+ # source_metas.append(new_source_meta)
125
+ source_metas.append(source_meta)
126
+
127
+ batch_meta_dict[source_type].append(source_metas)
128
+ # When mix-audio is 2, batch_meta_dict looks like: {
129
+ # 'vocals': [
130
+ # [{'hdf5_path': 'songA.h5', 'key_in_hdf5': 'vocals', 'begin_sample': 13406400, 'end_sample': 13538700},
131
+ # {'hdf5_path': 'songB.h5', 'key_in_hdf5': 'vocals', 'begin_sample': 4440870, 'end_sample': 4573170}],
132
+ # [{'hdf5_path': 'songC.h5', 'key_in_hdf5': 'vocals', 'begin_sample': 1186290, 'end_sample': 1318590},
133
+ # {'hdf5_path': 'songD.h5', 'key_in_hdf5': 'vocals', 'begin_sample': 8462790, 'end_sample': 8595090}]
134
+ # ]
135
+ # 'accompaniment': [
136
+ # [{'hdf5_path': 'songE.h5', 'key_in_hdf5': 'vocals', 'begin_sample': 24232950, 'end_sample': 24365250},
137
+ # {'hdf5_path': 'songF.h5', 'key_in_hdf5': 'vocals', 'begin_sample': 1569960, 'end_sample': 1702260}],
138
+ # [{'hdf5_path': 'songG.h5', 'key_in_hdf5': 'vocals', 'begin_sample': 2795940, 'end_sample': 2928240},
139
+ # {'hdf5_path': 'songH.h5', 'key_in_hdf5': 'vocals', 'begin_sample': 10923570, 'end_sample': 11055870}]
140
+ # ]
141
+ # }
142
+
143
+ batch_meta_list = [
144
+ {
145
+ source_type: batch_meta_dict[source_type][i]
146
+ for source_type in self.source_types
147
+ }
148
+ for i in range(batch_size)
149
+ ]
150
+ # When mix-audio is 2, batch_meta_list looks like: [
151
+ # {'vocals': [
152
+ # {'hdf5_path': 'songA.h5', 'key_in_hdf5': 'vocals', 'begin_sample': 13406400, 'end_sample': 13538700},
153
+ # {'hdf5_path': 'songB.h5', 'key_in_hdf5': 'vocals', 'begin_sample': 4440870, 'end_sample': 4573170}]
154
+ # 'accompaniment': [
155
+ # {'hdf5_path': 'songE.h5', 'key_in_hdf5': 'vocals', 'begin_sample': 14579460, 'end_sample': 14711760},
156
+ # {'hdf5_path': 'songF.h5', 'key_in_hdf5': 'vocals', 'begin_sample': 3995460, 'end_sample': 4127760}]
157
+ # }
158
+ # ...
159
+ # ]
160
+
161
+ yield batch_meta_list
162
+
163
+ def __len__(self) -> int:
164
+ return self.steps_per_epoch
165
+
166
+ def state_dict(self) -> Dict:
167
+ state = {'pointers_dict': self.pointers_dict, 'indexes_dict': self.indexes_dict}
168
+ return state
169
+
170
+ def load_state_dict(self, state) -> NoReturn:
171
+ self.pointers_dict = state['pointers_dict']
172
+ self.indexes_dict = state['indexes_dict']
173
+
174
+
175
+ class DistributedSamplerWrapper:
176
+ def __init__(self, sampler):
177
+ r"""Distributed wrapper of sampler."""
178
+ self.sampler = sampler
179
+
180
+ def __iter__(self):
181
+ num_replicas = dist.get_world_size()
182
+ rank = dist.get_rank()
183
+
184
+ for indices in self.sampler:
185
+ yield indices[rank::num_replicas]
186
+
187
+ def __len__(self) -> int:
188
+ return len(self.sampler)
bytesep/dataset_creation/__init__.py ADDED
File without changes
bytesep/dataset_creation/create_evaluation_audios/__init__.py ADDED
File without changes
bytesep/dataset_creation/create_evaluation_audios/piano-symphony.py ADDED
@@ -0,0 +1,160 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import os
3
+ from typing import NoReturn
4
+
5
+ import librosa
6
+ import numpy as np
7
+ import soundfile
8
+
9
+ from bytesep.dataset_creation.pack_audios_to_hdf5s.instruments_solo import (
10
+ read_csv as read_instruments_solo_csv,
11
+ )
12
+ from bytesep.dataset_creation.pack_audios_to_hdf5s.maestro import (
13
+ read_csv as read_maestro_csv,
14
+ )
15
+ from bytesep.utils import load_random_segment
16
+
17
+
18
+ def create_evaluation(args) -> NoReturn:
19
+ r"""Random mix and write out audios for evaluation.
20
+
21
+ Args:
22
+ piano_dataset_dir: str, the directory of the piano dataset
23
+ symphony_dataset_dir: str, the directory of the symphony dataset
24
+ evaluation_audios_dir: str, the directory to write out randomly selected and mixed audio segments
25
+ sample_rate: int
26
+ channels: int, e.g., 1 | 2
27
+ evaluation_segments_num: int
28
+ mono: bool
29
+
30
+ Returns:
31
+ NoReturn
32
+ """
33
+
34
+ # arguments & parameters
35
+ piano_dataset_dir = args.piano_dataset_dir
36
+ symphony_dataset_dir = args.symphony_dataset_dir
37
+ evaluation_audios_dir = args.evaluation_audios_dir
38
+ sample_rate = args.sample_rate
39
+ channels = args.channels
40
+ evaluation_segments_num = args.evaluation_segments_num
41
+ mono = True if channels == 1 else False
42
+
43
+ split = 'test'
44
+ segment_seconds = 10.0
45
+
46
+ random_state = np.random.RandomState(1234)
47
+
48
+ piano_meta_csv = os.path.join(piano_dataset_dir, 'maestro-v2.0.0.csv')
49
+ piano_names_dict = read_maestro_csv(piano_meta_csv)
50
+ piano_audio_names = piano_names_dict[split]
51
+
52
+ symphony_meta_csv = os.path.join(symphony_dataset_dir, 'validation.csv')
53
+ symphony_names_dict = read_instruments_solo_csv(symphony_meta_csv)
54
+ symphony_audio_names = symphony_names_dict[split]
55
+
56
+ for source_type in ['piano', 'symphony', 'mixture']:
57
+ output_dir = os.path.join(evaluation_audios_dir, split, source_type)
58
+ os.makedirs(output_dir, exist_ok=True)
59
+
60
+ for n in range(evaluation_segments_num):
61
+
62
+ print('{} / {}'.format(n, evaluation_segments_num))
63
+
64
+ # Randomly select and write out a clean piano segment.
65
+ piano_audio_name = random_state.choice(piano_audio_names)
66
+ piano_audio_path = os.path.join(piano_dataset_dir, piano_audio_name)
67
+
68
+ piano_audio = load_random_segment(
69
+ audio_path=piano_audio_path,
70
+ random_state=random_state,
71
+ segment_seconds=segment_seconds,
72
+ mono=mono,
73
+ sample_rate=sample_rate,
74
+ )
75
+
76
+ output_piano_path = os.path.join(
77
+ evaluation_audios_dir, split, 'piano', '{:04d}.wav'.format(n)
78
+ )
79
+ soundfile.write(
80
+ file=output_piano_path, data=piano_audio.T, samplerate=sample_rate
81
+ )
82
+ print("Write out to {}".format(output_piano_path))
83
+
84
+ # Randomly select and write out a clean symphony segment.
85
+ symphony_audio_name = random_state.choice(symphony_audio_names)
86
+ symphony_audio_path = os.path.join(
87
+ symphony_dataset_dir, "mp3s", symphony_audio_name
88
+ )
89
+
90
+ symphony_audio = load_random_segment(
91
+ audio_path=symphony_audio_path,
92
+ random_state=random_state,
93
+ segment_seconds=segment_seconds,
94
+ mono=mono,
95
+ sample_rate=sample_rate,
96
+ )
97
+
98
+ output_symphony_path = os.path.join(
99
+ evaluation_audios_dir, split, 'symphony', '{:04d}.wav'.format(n)
100
+ )
101
+ soundfile.write(
102
+ file=output_symphony_path, data=symphony_audio.T, samplerate=sample_rate
103
+ )
104
+ print("Write out to {}".format(output_symphony_path))
105
+
106
+ # Mix piano and symphony segments and write out a mixture segment.
107
+ mixture_audio = symphony_audio + piano_audio
108
+ output_mixture_path = os.path.join(
109
+ evaluation_audios_dir, split, 'mixture', '{:04d}.wav'.format(n)
110
+ )
111
+ soundfile.write(
112
+ file=output_mixture_path, data=mixture_audio.T, samplerate=sample_rate
113
+ )
114
+ print("Write out to {}".format(output_mixture_path))
115
+
116
+
117
+ if __name__ == "__main__":
118
+ parser = argparse.ArgumentParser()
119
+
120
+ parser.add_argument(
121
+ "--piano_dataset_dir",
122
+ type=str,
123
+ required=True,
124
+ help="The directory of the piano dataset.",
125
+ )
126
+ parser.add_argument(
127
+ "--symphony_dataset_dir",
128
+ type=str,
129
+ required=True,
130
+ help="The directory of the symphony dataset.",
131
+ )
132
+ parser.add_argument(
133
+ "--evaluation_audios_dir",
134
+ type=str,
135
+ required=True,
136
+ help="The directory to write out randomly selected and mixed audio segments.",
137
+ )
138
+ parser.add_argument(
139
+ "--sample_rate",
140
+ type=int,
141
+ required=True,
142
+ help="Sample rate.",
143
+ )
144
+ parser.add_argument(
145
+ "--channels",
146
+ type=int,
147
+ required=True,
148
+ help="Audio channels, e.g, 1 or 2.",
149
+ )
150
+ parser.add_argument(
151
+ "--evaluation_segments_num",
152
+ type=int,
153
+ required=True,
154
+ help="The number of segments to create for evaluation.",
155
+ )
156
+
157
+ # Parse arguments.
158
+ args = parser.parse_args()
159
+
160
+ create_evaluation(args)
bytesep/dataset_creation/create_evaluation_audios/vctk-musdb18.py ADDED
@@ -0,0 +1,164 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import os
3
+ import soundfile
4
+ from typing import NoReturn
5
+
6
+ import musdb
7
+ import numpy as np
8
+
9
+ from bytesep.utils import load_audio
10
+
11
+
12
+ def create_evaluation(args) -> NoReturn:
13
+ r"""Random mix and write out audios for evaluation.
14
+
15
+ Args:
16
+ vctk_dataset_dir: str, the directory of the VCTK dataset
17
+ symphony_dataset_dir: str, the directory of the symphony dataset
18
+ evaluation_audios_dir: str, the directory to write out randomly selected and mixed audio segments
19
+ sample_rate: int
20
+ channels: int, e.g., 1 | 2
21
+ evaluation_segments_num: int
22
+ mono: bool
23
+
24
+ Returns:
25
+ NoReturn
26
+ """
27
+
28
+ # arguments & parameters
29
+ vctk_dataset_dir = args.vctk_dataset_dir
30
+ musdb18_dataset_dir = args.musdb18_dataset_dir
31
+ evaluation_audios_dir = args.evaluation_audios_dir
32
+ sample_rate = args.sample_rate
33
+ channels = args.channels
34
+ evaluation_segments_num = args.evaluation_segments_num
35
+ mono = True if channels == 1 else False
36
+
37
+ split = 'test'
38
+ random_state = np.random.RandomState(1234)
39
+
40
+ # paths
41
+ audios_dir = os.path.join(vctk_dataset_dir, "wav48", split)
42
+
43
+ for source_type in ['speech', 'music', 'mixture']:
44
+ output_dir = os.path.join(evaluation_audios_dir, split, source_type)
45
+ os.makedirs(output_dir, exist_ok=True)
46
+
47
+ # Get VCTK audio paths.
48
+ speech_audio_paths = []
49
+ speaker_ids = sorted(os.listdir(audios_dir))
50
+
51
+ for speaker_id in speaker_ids:
52
+ speaker_audios_dir = os.path.join(audios_dir, speaker_id)
53
+
54
+ audio_names = sorted(os.listdir(speaker_audios_dir))
55
+
56
+ for audio_name in audio_names:
57
+ speaker_audio_path = os.path.join(speaker_audios_dir, audio_name)
58
+ speech_audio_paths.append(speaker_audio_path)
59
+
60
+ # Get Musdb18 audio paths.
61
+ mus = musdb.DB(root=musdb18_dataset_dir, subsets=[split])
62
+ track_indexes = np.arange(len(mus.tracks))
63
+
64
+ for n in range(evaluation_segments_num):
65
+
66
+ print('{} / {}'.format(n, evaluation_segments_num))
67
+
68
+ # Randomly select and write out a clean speech segment.
69
+ speech_audio_path = random_state.choice(speech_audio_paths)
70
+
71
+ speech_audio = load_audio(
72
+ audio_path=speech_audio_path, mono=mono, sample_rate=sample_rate
73
+ )
74
+ # (channels_num, audio_samples)
75
+
76
+ if channels == 2:
77
+ speech_audio = np.tile(speech_audio, (2, 1))
78
+ # (channels_num, audio_samples)
79
+
80
+ output_speech_path = os.path.join(
81
+ evaluation_audios_dir, split, 'speech', '{:04d}.wav'.format(n)
82
+ )
83
+ soundfile.write(
84
+ file=output_speech_path, data=speech_audio.T, samplerate=sample_rate
85
+ )
86
+ print("Write out to {}".format(output_speech_path))
87
+
88
+ # Randomly select and write out a clean music segment.
89
+ track_index = random_state.choice(track_indexes)
90
+ track = mus[track_index]
91
+
92
+ segment_samples = speech_audio.shape[1]
93
+ start_sample = int(
94
+ random_state.uniform(0.0, segment_samples - speech_audio.shape[1])
95
+ )
96
+
97
+ music_audio = track.audio[start_sample : start_sample + segment_samples, :].T
98
+ # (channels_num, audio_samples)
99
+
100
+ output_music_path = os.path.join(
101
+ evaluation_audios_dir, split, 'music', '{:04d}.wav'.format(n)
102
+ )
103
+ soundfile.write(
104
+ file=output_music_path, data=music_audio.T, samplerate=sample_rate
105
+ )
106
+ print("Write out to {}".format(output_music_path))
107
+
108
+ # Mix speech and music segments and write out a mixture segment.
109
+ mixture_audio = speech_audio + music_audio
110
+ # (channels_num, audio_samples)
111
+
112
+ output_mixture_path = os.path.join(
113
+ evaluation_audios_dir, split, 'mixture', '{:04d}.wav'.format(n)
114
+ )
115
+ soundfile.write(
116
+ file=output_mixture_path, data=mixture_audio.T, samplerate=sample_rate
117
+ )
118
+ print("Write out to {}".format(output_mixture_path))
119
+
120
+
121
+ if __name__ == "__main__":
122
+ parser = argparse.ArgumentParser()
123
+
124
+ parser.add_argument(
125
+ "--vctk_dataset_dir",
126
+ type=str,
127
+ required=True,
128
+ help="The directory of the VCTK dataset.",
129
+ )
130
+ parser.add_argument(
131
+ "--musdb18_dataset_dir",
132
+ type=str,
133
+ required=True,
134
+ help="The directory of the MUSDB18 dataset.",
135
+ )
136
+ parser.add_argument(
137
+ "--evaluation_audios_dir",
138
+ type=str,
139
+ required=True,
140
+ help="The directory to write out randomly selected and mixed audio segments.",
141
+ )
142
+ parser.add_argument(
143
+ "--sample_rate",
144
+ type=int,
145
+ required=True,
146
+ help="Sample rate",
147
+ )
148
+ parser.add_argument(
149
+ "--channels",
150
+ type=int,
151
+ required=True,
152
+ help="Audio channels, e.g, 1 or 2.",
153
+ )
154
+ parser.add_argument(
155
+ "--evaluation_segments_num",
156
+ type=int,
157
+ required=True,
158
+ help="The number of segments to create for evaluation.",
159
+ )
160
+
161
+ # Parse arguments.
162
+ args = parser.parse_args()
163
+
164
+ create_evaluation(args)
bytesep/dataset_creation/create_evaluation_audios/violin-piano.py ADDED
@@ -0,0 +1,162 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import os
3
+ from typing import NoReturn
4
+
5
+ import librosa
6
+ import numpy as np
7
+ import soundfile
8
+
9
+ from bytesep.dataset_creation.pack_audios_to_hdf5s.instruments_solo import (
10
+ read_csv as read_instruments_solo_csv,
11
+ )
12
+ from bytesep.dataset_creation.pack_audios_to_hdf5s.maestro import (
13
+ read_csv as read_maestro_csv,
14
+ )
15
+ from bytesep.utils import load_random_segment
16
+
17
+
18
+ def create_evaluation(args) -> NoReturn:
19
+ r"""Random mix and write out audios for evaluation.
20
+
21
+ Args:
22
+ violin_dataset_dir: str, the directory of the violin dataset
23
+ piano_dataset_dir: str, the directory of the piano dataset
24
+ evaluation_audios_dir: str, the directory to write out randomly selected and mixed audio segments
25
+ sample_rate: int
26
+ channels: int, e.g., 1 | 2
27
+ evaluation_segments_num: int
28
+ mono: bool
29
+
30
+ Returns:
31
+ NoReturn
32
+ """
33
+
34
+ # arguments & parameters
35
+ violin_dataset_dir = args.violin_dataset_dir
36
+ piano_dataset_dir = args.piano_dataset_dir
37
+ evaluation_audios_dir = args.evaluation_audios_dir
38
+ sample_rate = args.sample_rate
39
+ channels = args.channels
40
+ evaluation_segments_num = args.evaluation_segments_num
41
+ mono = True if channels == 1 else False
42
+
43
+ split = 'test'
44
+ segment_seconds = 10.0
45
+
46
+ random_state = np.random.RandomState(1234)
47
+
48
+ violin_meta_csv = os.path.join(violin_dataset_dir, 'validation.csv')
49
+ violin_names_dict = read_instruments_solo_csv(violin_meta_csv)
50
+ violin_audio_names = violin_names_dict['{}'.format(split)]
51
+
52
+ piano_meta_csv = os.path.join(piano_dataset_dir, 'maestro-v2.0.0.csv')
53
+ piano_names_dict = read_maestro_csv(piano_meta_csv)
54
+ piano_audio_names = piano_names_dict['{}'.format(split)]
55
+
56
+ for source_type in ['violin', 'piano', 'mixture']:
57
+ output_dir = os.path.join(evaluation_audios_dir, split, source_type)
58
+ os.makedirs(output_dir, exist_ok=True)
59
+
60
+ for n in range(evaluation_segments_num):
61
+
62
+ print('{} / {}'.format(n, evaluation_segments_num))
63
+
64
+ # Randomly select and write out a clean violin segment.
65
+ violin_audio_name = random_state.choice(violin_audio_names)
66
+ violin_audio_path = os.path.join(violin_dataset_dir, "mp3s", violin_audio_name)
67
+
68
+ violin_audio = load_random_segment(
69
+ audio_path=violin_audio_path,
70
+ random_state=random_state,
71
+ segment_seconds=segment_seconds,
72
+ mono=mono,
73
+ sample_rate=sample_rate,
74
+ )
75
+ # (channels_num, audio_samples)
76
+
77
+ output_violin_path = os.path.join(
78
+ evaluation_audios_dir, split, 'violin', '{:04d}.wav'.format(n)
79
+ )
80
+ soundfile.write(
81
+ file=output_violin_path, data=violin_audio.T, samplerate=sample_rate
82
+ )
83
+ print("Write out to {}".format(output_violin_path))
84
+
85
+ # Randomly select and write out a clean piano segment.
86
+ piano_audio_name = random_state.choice(piano_audio_names)
87
+ piano_audio_path = os.path.join(piano_dataset_dir, piano_audio_name)
88
+
89
+ piano_audio = load_random_segment(
90
+ audio_path=piano_audio_path,
91
+ random_state=random_state,
92
+ segment_seconds=segment_seconds,
93
+ mono=mono,
94
+ sample_rate=sample_rate,
95
+ )
96
+ # (channels_num, audio_samples)
97
+
98
+ output_piano_path = os.path.join(
99
+ evaluation_audios_dir, split, 'piano', '{:04d}.wav'.format(n)
100
+ )
101
+ soundfile.write(
102
+ file=output_piano_path, data=piano_audio.T, samplerate=sample_rate
103
+ )
104
+ print("Write out to {}".format(output_piano_path))
105
+
106
+ # Mix violin and piano segments and write out a mixture segment.
107
+ mixture_audio = violin_audio + piano_audio
108
+ # (channels_num, audio_samples)
109
+
110
+ output_mixture_path = os.path.join(
111
+ evaluation_audios_dir, split, 'mixture', '{:04d}.wav'.format(n)
112
+ )
113
+ soundfile.write(
114
+ file=output_mixture_path, data=mixture_audio.T, samplerate=sample_rate
115
+ )
116
+ print("Write out to {}".format(output_mixture_path))
117
+
118
+
119
+ if __name__ == "__main__":
120
+ parser = argparse.ArgumentParser()
121
+
122
+ parser.add_argument(
123
+ "--violin_dataset_dir",
124
+ type=str,
125
+ required=True,
126
+ help="The directory of the violin dataset.",
127
+ )
128
+ parser.add_argument(
129
+ "--piano_dataset_dir",
130
+ type=str,
131
+ required=True,
132
+ help="The directory of the piano dataset.",
133
+ )
134
+ parser.add_argument(
135
+ "--evaluation_audios_dir",
136
+ type=str,
137
+ required=True,
138
+ help="The directory to write out randomly selected and mixed audio segments.",
139
+ )
140
+ parser.add_argument(
141
+ "--sample_rate",
142
+ type=int,
143
+ required=True,
144
+ help="Sample rate",
145
+ )
146
+ parser.add_argument(
147
+ "--channels",
148
+ type=int,
149
+ required=True,
150
+ help="Audio channels, e.g, 1 or 2.",
151
+ )
152
+ parser.add_argument(
153
+ "--evaluation_segments_num",
154
+ type=int,
155
+ required=True,
156
+ help="The number of segments to create for evaluation.",
157
+ )
158
+
159
+ # Parse arguments.
160
+ args = parser.parse_args()
161
+
162
+ create_evaluation(args)
bytesep/dataset_creation/create_indexes/__init__.py ADDED
File without changes
bytesep/dataset_creation/create_indexes/create_indexes.py ADDED
@@ -0,0 +1,142 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import os
3
+ import pickle
4
+ from typing import NoReturn
5
+
6
+ import h5py
7
+
8
+ from bytesep.utils import read_yaml
9
+
10
+
11
+ def create_indexes(args) -> NoReturn:
12
+ r"""Create and write out training indexes into disk. The indexes may contain
13
+ information from multiple datasets. During training, training indexes will
14
+ be shuffled and iterated for selecting segments to be mixed. E.g., the
15
+ training indexes_dict looks like: {
16
+ 'vocals': [
17
+ {'hdf5_path': '.../songA.h5', 'key_in_hdf5': 'vocals', 'begin_sample': 0, 'end_sample': 132300}
18
+ {'hdf5_path': '.../songB.h5', 'key_in_hdf5': 'vocals', 'begin_sample': 4410, 'end_sample': 136710}
19
+ ...
20
+ ]
21
+ 'accompaniment': [
22
+ {'hdf5_path': '.../songA.h5', 'key_in_hdf5': 'accompaniment', 'begin_sample': 0, 'end_sample': 132300}
23
+ {'hdf5_path': '.../songB.h5', 'key_in_hdf5': 'accompaniment', 'begin_sample': 4410, 'end_sample': 136710}
24
+ ...
25
+ ]
26
+ }
27
+ """
28
+
29
+ # Arugments & parameters
30
+ workspace = args.workspace
31
+ config_yaml = args.config_yaml
32
+
33
+ # Only create indexes for training, because evalution is on entire pieces.
34
+ split = "train"
35
+
36
+ # Read config file.
37
+ configs = read_yaml(config_yaml)
38
+
39
+ sample_rate = configs["sample_rate"]
40
+ segment_samples = int(configs["segment_seconds"] * sample_rate)
41
+
42
+ # Path to write out index.
43
+ indexes_path = os.path.join(workspace, configs[split]["indexes"])
44
+ os.makedirs(os.path.dirname(indexes_path), exist_ok=True)
45
+
46
+ source_types = configs[split]["source_types"].keys()
47
+ # E.g., ['vocals', 'accompaniment']
48
+
49
+ indexes_dict = {source_type: [] for source_type in source_types}
50
+ # E.g., indexes_dict will looks like: {
51
+ # 'vocals': [
52
+ # {'hdf5_path': '.../songA.h5', 'key_in_hdf5': 'vocals', 'begin_sample': 0, 'end_sample': 132300}
53
+ # {'hdf5_path': '.../songB.h5', 'key_in_hdf5': 'vocals', 'begin_sample': 4410, 'end_sample': 136710}
54
+ # ...
55
+ # ]
56
+ # 'accompaniment': [
57
+ # {'hdf5_path': '.../songA.h5', 'key_in_hdf5': 'accompaniment', 'begin_sample': 0, 'end_sample': 132300}
58
+ # {'hdf5_path': '.../songB.h5', 'key_in_hdf5': 'accompaniment', 'begin_sample': 4410, 'end_sample': 136710}
59
+ # ...
60
+ # ]
61
+ # }
62
+
63
+ # Get training indexes for each source type.
64
+ for source_type in source_types:
65
+ # E.g., ['vocals', 'bass', ...]
66
+
67
+ print("--- {} ---".format(source_type))
68
+
69
+ dataset_types = configs[split]["source_types"][source_type]
70
+ # E.g., ['musdb18', ...]
71
+
72
+ # Each source can come from mulitple datasets.
73
+ for dataset_type in dataset_types:
74
+
75
+ hdf5s_dir = os.path.join(
76
+ workspace, dataset_types[dataset_type]["hdf5s_directory"]
77
+ )
78
+
79
+ hop_samples = int(dataset_types[dataset_type]["hop_seconds"] * sample_rate)
80
+
81
+ key_in_hdf5 = dataset_types[dataset_type]["key_in_hdf5"]
82
+ # E.g., 'vocals'
83
+
84
+ hdf5_names = sorted(os.listdir(hdf5s_dir))
85
+ print("Hdf5 files num: {}".format(len(hdf5_names)))
86
+
87
+ # Traverse all packed hdf5 files of a dataset.
88
+ for n, hdf5_name in enumerate(hdf5_names):
89
+
90
+ print(n, hdf5_name)
91
+ hdf5_path = os.path.join(hdf5s_dir, hdf5_name)
92
+
93
+ with h5py.File(hdf5_path, "r") as hf:
94
+
95
+ bgn_sample = 0
96
+ while bgn_sample + segment_samples < hf[key_in_hdf5].shape[-1]:
97
+ meta = {
98
+ 'hdf5_path': hdf5_path,
99
+ 'key_in_hdf5': key_in_hdf5,
100
+ 'begin_sample': bgn_sample,
101
+ 'end_sample': bgn_sample + segment_samples,
102
+ }
103
+ indexes_dict[source_type].append(meta)
104
+
105
+ bgn_sample += hop_samples
106
+
107
+ # If the audio length is shorter than the segment length,
108
+ # then use the entire audio as a segment.
109
+ if bgn_sample == 0:
110
+ meta = {
111
+ 'hdf5_path': hdf5_path,
112
+ 'key_in_hdf5': key_in_hdf5,
113
+ 'begin_sample': 0,
114
+ 'end_sample': segment_samples,
115
+ }
116
+ indexes_dict[source_type].append(meta)
117
+
118
+ print(
119
+ "Total indexes for {}: {}".format(
120
+ source_type, len(indexes_dict[source_type])
121
+ )
122
+ )
123
+
124
+ pickle.dump(indexes_dict, open(indexes_path, "wb"))
125
+ print("Write index dict to {}".format(indexes_path))
126
+
127
+
128
+ if __name__ == "__main__":
129
+ parser = argparse.ArgumentParser()
130
+
131
+ parser.add_argument(
132
+ "--workspace", type=str, required=True, help="Directory of workspace."
133
+ )
134
+ parser.add_argument(
135
+ "--config_yaml", type=str, required=True, help="User defined config file."
136
+ )
137
+
138
+ # Parse arguments.
139
+ args = parser.parse_args()
140
+
141
+ # Create training indexes.
142
+ create_indexes(args)
bytesep/dataset_creation/pack_audios_to_hdf5s/__init__.py ADDED
File without changes
bytesep/dataset_creation/pack_audios_to_hdf5s/instruments_solo.py ADDED
@@ -0,0 +1,173 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import os
3
+ import pathlib
4
+ import time
5
+ from concurrent.futures import ProcessPoolExecutor
6
+ from typing import Dict, List, NoReturn
7
+
8
+ import h5py
9
+ import numpy as np
10
+ import pandas as pd
11
+
12
+ from bytesep.utils import float32_to_int16, load_audio
13
+
14
+
15
+ def read_csv(meta_csv) -> Dict:
16
+ r"""Get train & test names from csv.
17
+
18
+ Args:
19
+ meta_csv: str
20
+
21
+ Returns:
22
+ names_dict: dict, e.g., {
23
+ 'train', ['songA.mp3', 'songB.mp3', ...],
24
+ 'test': ['songE.mp3', 'songF.mp3', ...]
25
+ }
26
+ """
27
+ df = pd.read_csv(meta_csv, sep=',')
28
+
29
+ names_dict = {}
30
+
31
+ for split in ['train', 'test']:
32
+ audio_indexes = df['split'] == split
33
+ audio_names = list(df['audio_name'][audio_indexes])
34
+ audio_names = [
35
+ '{}.mp3'.format(pathlib.Path(audio_name).stem) for audio_name in audio_names
36
+ ]
37
+ names_dict[split] = audio_names
38
+
39
+ return names_dict
40
+
41
+
42
+ def pack_audios_to_hdf5s(args) -> NoReturn:
43
+ r"""Pack (resampled) audio files into hdf5 files to speed up loading.
44
+
45
+ Args:
46
+ dataset_dir: str
47
+ split: str, 'train' | 'test'
48
+ source_type: str
49
+ hdf5s_dir: str, directory to write out hdf5 files
50
+ sample_rate: int
51
+ channels_num: int
52
+ mono: bool
53
+
54
+ Returns:
55
+ NoReturn
56
+ """
57
+
58
+ # arguments & parameters
59
+ dataset_dir = args.dataset_dir
60
+ split = args.split
61
+ source_type = args.source_type
62
+ hdf5s_dir = args.hdf5s_dir
63
+ sample_rate = args.sample_rate
64
+ channels = args.channels
65
+ mono = True if channels == 1 else False
66
+
67
+ # Only pack data for training data.
68
+ assert split == "train"
69
+
70
+ # paths
71
+ audios_dir = os.path.join(dataset_dir, 'mp3s')
72
+ meta_csv = os.path.join(dataset_dir, 'validation.csv')
73
+
74
+ os.makedirs(hdf5s_dir, exist_ok=True)
75
+
76
+ # Read train & test names.
77
+ names_dict = read_csv(meta_csv)
78
+
79
+ audio_names = names_dict[split]
80
+
81
+ params = []
82
+
83
+ for audio_index, audio_name in enumerate(audio_names):
84
+
85
+ audio_path = os.path.join(audios_dir, audio_name)
86
+
87
+ hdf5_path = os.path.join(
88
+ hdf5s_dir, "{}.h5".format(pathlib.Path(audio_name).stem)
89
+ )
90
+
91
+ param = (
92
+ audio_index,
93
+ audio_name,
94
+ source_type,
95
+ audio_path,
96
+ mono,
97
+ sample_rate,
98
+ hdf5_path,
99
+ )
100
+ params.append(param)
101
+
102
+ # Uncomment for debug.
103
+ # write_single_audio_to_hdf5(params[0])
104
+ # os._exit()
105
+
106
+ pack_hdf5s_time = time.time()
107
+
108
+ with ProcessPoolExecutor(max_workers=None) as pool:
109
+ # Maximum works on the machine
110
+ pool.map(write_single_audio_to_hdf5, params)
111
+
112
+ print("Pack hdf5 time: {:.3f} s".format(time.time() - pack_hdf5s_time))
113
+
114
+
115
+ def write_single_audio_to_hdf5(param: List) -> NoReturn:
116
+ r"""Write single audio into hdf5 file."""
117
+
118
+ (
119
+ audio_index,
120
+ audio_name,
121
+ source_type,
122
+ audio_path,
123
+ mono,
124
+ sample_rate,
125
+ hdf5_path,
126
+ ) = param
127
+
128
+ with h5py.File(hdf5_path, "w") as hf:
129
+
130
+ hf.attrs.create("audio_name", data=audio_name, dtype="S100")
131
+ hf.attrs.create("sample_rate", data=sample_rate, dtype=np.int32)
132
+
133
+ audio = load_audio(audio_path=audio_path, mono=mono, sample_rate=sample_rate)
134
+ # audio: (channels_num, audio_samples)
135
+
136
+ hf.create_dataset(
137
+ name=source_type, data=float32_to_int16(audio), dtype=np.int16
138
+ )
139
+
140
+ print('{} Write hdf5 to {}'.format(audio_index, hdf5_path))
141
+
142
+
143
+ if __name__ == "__main__":
144
+ parser = argparse.ArgumentParser()
145
+
146
+ parser.add_argument(
147
+ "--dataset_dir",
148
+ type=str,
149
+ required=True,
150
+ help="Directory of the instruments solo dataset.",
151
+ )
152
+ parser.add_argument("--split", type=str, required=True, choices=["train", "test"])
153
+ parser.add_argument(
154
+ "--source_type",
155
+ type=str,
156
+ required=True,
157
+ )
158
+ parser.add_argument(
159
+ "--hdf5s_dir",
160
+ type=str,
161
+ required=True,
162
+ help="Directory to write out hdf5 files.",
163
+ )
164
+ parser.add_argument("--sample_rate", type=int, required=True, help="Sample rate.")
165
+ parser.add_argument(
166
+ "--channels", type=int, required=True, help="Use 1 for mono, 2 for stereo."
167
+ )
168
+
169
+ # Parse arguments.
170
+ args = parser.parse_args()
171
+
172
+ # Pack audios to hdf5 files.
173
+ pack_audios_to_hdf5s(args)
bytesep/dataset_creation/pack_audios_to_hdf5s/maestro.py ADDED
@@ -0,0 +1,136 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import os
3
+ import pathlib
4
+ import time
5
+ from concurrent.futures import ProcessPoolExecutor
6
+ from typing import Dict, NoReturn
7
+
8
+ import pandas as pd
9
+
10
+ from bytesep.dataset_creation.pack_audios_to_hdf5s.instruments_solo import (
11
+ write_single_audio_to_hdf5,
12
+ )
13
+
14
+
15
+ def read_csv(meta_csv) -> Dict:
16
+ r"""Get train & test names from csv.
17
+
18
+ Args:
19
+ meta_csv: str
20
+
21
+ Returns:
22
+ names_dict: dict, e.g., {
23
+ 'train', ['a1.mp3', 'a2.mp3'],
24
+ 'test': ['b1.mp3', 'b2.mp3']
25
+ }
26
+ """
27
+ df = pd.read_csv(meta_csv, sep=',')
28
+
29
+ names_dict = {}
30
+
31
+ for split in ['train', 'test']:
32
+ audio_indexes = df['split'] == split
33
+ audio_names = list(df['audio_filename'][audio_indexes])
34
+ names_dict[split] = audio_names
35
+
36
+ return names_dict
37
+
38
+
39
+ def pack_audios_to_hdf5s(args) -> NoReturn:
40
+ r"""Pack (resampled) audio files into hdf5 files to speed up loading.
41
+
42
+ Args:
43
+ dataset_dir: str
44
+ split: str, 'train' | 'test'
45
+ hdf5s_dir: str, directory to write out hdf5 files
46
+ sample_rate: int
47
+ channels_num: int
48
+ mono: bool
49
+
50
+ Returns:
51
+ NoReturn
52
+ """
53
+
54
+ # arguments & parameters
55
+ dataset_dir = args.dataset_dir
56
+ split = args.split
57
+ hdf5s_dir = args.hdf5s_dir
58
+ sample_rate = args.sample_rate
59
+ channels = args.channels
60
+ mono = True if channels == 1 else False
61
+
62
+ source_type = "piano"
63
+
64
+ # Only pack data for training data.
65
+ assert split == "train"
66
+
67
+ # paths
68
+ meta_csv = os.path.join(dataset_dir, 'maestro-v2.0.0.csv')
69
+
70
+ os.makedirs(hdf5s_dir, exist_ok=True)
71
+
72
+ # Read train & test names.
73
+ names_dict = read_csv(meta_csv)
74
+
75
+ audio_names = names_dict['{}'.format(split)]
76
+
77
+ params = []
78
+
79
+ for audio_index, audio_name in enumerate(audio_names):
80
+
81
+ audio_path = os.path.join(dataset_dir, audio_name)
82
+
83
+ hdf5_path = os.path.join(
84
+ hdf5s_dir, "{}.h5".format(pathlib.Path(audio_name).stem)
85
+ )
86
+
87
+ param = (
88
+ audio_index,
89
+ audio_name,
90
+ source_type,
91
+ audio_path,
92
+ mono,
93
+ sample_rate,
94
+ hdf5_path,
95
+ )
96
+ params.append(param)
97
+
98
+ # Uncomment for debug.
99
+ # write_single_audio_to_hdf5(params[0])
100
+ # os._exit(0)
101
+
102
+ pack_hdf5s_time = time.time()
103
+
104
+ with ProcessPoolExecutor(max_workers=None) as pool:
105
+ # Maximum works on the machine
106
+ pool.map(write_single_audio_to_hdf5, params)
107
+
108
+ print("Pack hdf5 time: {:.3f} s".format(time.time() - pack_hdf5s_time))
109
+
110
+
111
+ if __name__ == "__main__":
112
+ parser = argparse.ArgumentParser()
113
+
114
+ parser.add_argument(
115
+ "--dataset_dir",
116
+ type=str,
117
+ required=True,
118
+ help="Directory of the MAESTRO dataset.",
119
+ )
120
+ parser.add_argument("--split", type=str, required=True, choices=["train", "test"])
121
+ parser.add_argument(
122
+ "--hdf5s_dir",
123
+ type=str,
124
+ required=True,
125
+ help="Directory to write out hdf5 files.",
126
+ )
127
+ parser.add_argument("--sample_rate", type=int, required=True, help="Sample rate.")
128
+ parser.add_argument(
129
+ "--channels", type=int, required=True, help="Use 1 for mono, 2 for stereo."
130
+ )
131
+
132
+ # Parse arguments.
133
+ args = parser.parse_args()
134
+
135
+ # Pack audios to hdf5 files.
136
+ pack_audios_to_hdf5s(args)
bytesep/dataset_creation/pack_audios_to_hdf5s/musdb18.py ADDED
@@ -0,0 +1,207 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import os
3
+ import time
4
+ from concurrent.futures import ProcessPoolExecutor
5
+ from typing import NoReturn
6
+
7
+ import h5py
8
+ import librosa
9
+ import musdb
10
+ import numpy as np
11
+
12
+ from bytesep.utils import float32_to_int16
13
+
14
+ # Source types of the MUSDB18 dataset.
15
+ SOURCE_TYPES = ["vocals", "drums", "bass", "other", "accompaniment"]
16
+
17
+
18
+ def pack_audios_to_hdf5s(args) -> NoReturn:
19
+ r"""Pack (resampled) audio files into hdf5 files to speed up loading.
20
+
21
+ Args:
22
+ dataset_dir: str
23
+ subset: str, 'train' | 'test'
24
+ split: str, '' | 'train' | 'valid'
25
+ hdf5s_dir: str, directory to write out hdf5 files
26
+ sample_rate: int
27
+ channels_num: int
28
+ mono: bool
29
+
30
+ Returns:
31
+ NoReturn
32
+ """
33
+
34
+ # arguments & parameters
35
+ dataset_dir = args.dataset_dir
36
+ subset = args.subset
37
+ split = None if args.split == "" else args.split
38
+ hdf5s_dir = args.hdf5s_dir
39
+ sample_rate = args.sample_rate
40
+ channels = args.channels
41
+
42
+ mono = True if channels == 1 else False
43
+ source_types = SOURCE_TYPES
44
+ resample_type = "kaiser_fast"
45
+
46
+ # Paths
47
+ os.makedirs(hdf5s_dir, exist_ok=True)
48
+
49
+ # Dataset of corresponding subset and split.
50
+ mus = musdb.DB(root=dataset_dir, subsets=[subset], split=split)
51
+ print("Subset: {}, Split: {}, Total pieces: {}".format(subset, split, len(mus)))
52
+
53
+ params = [] # A list of params for multiple processing.
54
+
55
+ for track_index in range(len(mus.tracks)):
56
+
57
+ param = (
58
+ dataset_dir,
59
+ subset,
60
+ split,
61
+ track_index,
62
+ source_types,
63
+ mono,
64
+ sample_rate,
65
+ resample_type,
66
+ hdf5s_dir,
67
+ )
68
+
69
+ params.append(param)
70
+
71
+ # Uncomment for debug.
72
+ # write_single_audio_to_hdf5(params[0])
73
+ # os._exit(0)
74
+
75
+ pack_hdf5s_time = time.time()
76
+
77
+ with ProcessPoolExecutor(max_workers=None) as pool:
78
+ # Maximum works on the machine
79
+ pool.map(write_single_audio_to_hdf5, params)
80
+
81
+ print("Pack hdf5 time: {:.3f} s".format(time.time() - pack_hdf5s_time))
82
+
83
+
84
+ def write_single_audio_to_hdf5(param) -> NoReturn:
85
+ r"""Write single audio into hdf5 file."""
86
+ (
87
+ dataset_dir,
88
+ subset,
89
+ split,
90
+ track_index,
91
+ source_types,
92
+ mono,
93
+ sample_rate,
94
+ resample_type,
95
+ hdf5s_dir,
96
+ ) = param
97
+
98
+ # Dataset of corresponding subset and split.
99
+ mus = musdb.DB(root=dataset_dir, subsets=[subset], split=split)
100
+ track = mus.tracks[track_index]
101
+
102
+ # Path to write out hdf5 file.
103
+ hdf5_path = os.path.join(hdf5s_dir, "{}.h5".format(track.name))
104
+
105
+ with h5py.File(hdf5_path, "w") as hf:
106
+
107
+ hf.attrs.create("audio_name", data=track.name.encode(), dtype="S100")
108
+ hf.attrs.create("sample_rate", data=sample_rate, dtype=np.int32)
109
+
110
+ for source_type in source_types:
111
+
112
+ audio = track.targets[source_type].audio.T
113
+ # (channels_num, audio_samples)
114
+
115
+ # Preprocess audio to mono / stereo, and resample.
116
+ audio = preprocess_audio(
117
+ audio, mono, track.rate, sample_rate, resample_type
118
+ )
119
+ # audio = load_audio(audio_path=audio_path, mono=mono, sample_rate=sample_rate)
120
+ # (channels_num, audio_samples) | (audio_samples,)
121
+
122
+ hf.create_dataset(
123
+ name=source_type, data=float32_to_int16(audio), dtype=np.int16
124
+ )
125
+
126
+ # Mixture
127
+ audio = track.audio.T
128
+ # (channels_num, audio_samples)
129
+
130
+ # Preprocess audio to mono / stereo, and resample.
131
+ audio = preprocess_audio(audio, mono, track.rate, sample_rate, resample_type)
132
+ # (channels_num, audio_samples)
133
+
134
+ hf.create_dataset(name="mixture", data=float32_to_int16(audio), dtype=np.int16)
135
+
136
+ print("{} Write to {}, {}".format(track_index, hdf5_path, audio.shape))
137
+
138
+
139
+ def preprocess_audio(audio, mono, origin_sr, sr, resample_type) -> np.array:
140
+ r"""Preprocess audio to mono / stereo, and resample.
141
+
142
+ Args:
143
+ audio: (channels_num, audio_samples), input audio
144
+ mono: bool
145
+ origin_sr: float, original sample rate
146
+ sr: float, target sample rate
147
+ resample_type: str, e.g., 'kaiser_fast'
148
+
149
+ Returns:
150
+ output: ndarray, output audio
151
+ """
152
+ if mono:
153
+ audio = np.mean(audio, axis=0)
154
+ # (audio_samples,)
155
+
156
+ output = librosa.core.resample(
157
+ audio, orig_sr=origin_sr, target_sr=sr, res_type=resample_type
158
+ )
159
+ # (audio_samples,) | (channels_num, audio_samples)
160
+
161
+ if output.ndim == 1:
162
+ output = output[None, :]
163
+ # (1, audio_samples,)
164
+
165
+ return output
166
+
167
+
168
+ if __name__ == "__main__":
169
+ parser = argparse.ArgumentParser()
170
+
171
+ parser.add_argument(
172
+ "--dataset_dir",
173
+ type=str,
174
+ required=True,
175
+ help="Directory of the MUSDB18 dataset.",
176
+ )
177
+ parser.add_argument(
178
+ "--subset",
179
+ type=str,
180
+ required=True,
181
+ choices=["train", "test"],
182
+ help="Train subset: 100 pieces; test subset: 50 pieces.",
183
+ )
184
+ parser.add_argument(
185
+ "--split",
186
+ type=str,
187
+ required=True,
188
+ choices=["", "train", "valid"],
189
+ help="Use '' to use all 100 pieces to train. Use 'train' to use 86 \
190
+ pieces for train, and use 'test' to use 14 pieces for valid.",
191
+ )
192
+ parser.add_argument(
193
+ "--hdf5s_dir",
194
+ type=str,
195
+ required=True,
196
+ help="Directory to write out hdf5 files.",
197
+ )
198
+ parser.add_argument("--sample_rate", type=int, required=True, help="Sample rate.")
199
+ parser.add_argument(
200
+ "--channels", type=int, required=True, help="Use 1 for mono, 2 for stereo."
201
+ )
202
+
203
+ # Parse arguments.
204
+ args = parser.parse_args()
205
+
206
+ # Pack audios into hdf5 files.
207
+ pack_audios_to_hdf5s(args)
bytesep/dataset_creation/pack_audios_to_hdf5s/vctk.py ADDED
@@ -0,0 +1,114 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import os
3
+ import pathlib
4
+ import time
5
+ from concurrent.futures import ProcessPoolExecutor
6
+ from typing import NoReturn
7
+
8
+ from bytesep.dataset_creation.pack_audios_to_hdf5s.instruments_solo import (
9
+ write_single_audio_to_hdf5,
10
+ )
11
+
12
+
13
+ def pack_audios_to_hdf5s(args) -> NoReturn:
14
+ r"""Pack (resampled) audio files into hdf5 files to speed up loading.
15
+
16
+ Args:
17
+ dataset_dir: str
18
+ split: str, 'train' | 'test'
19
+ hdf5s_dir: str, directory to write out hdf5 files
20
+ sample_rate: int
21
+ channels_num: int
22
+ mono: bool
23
+
24
+ Returns:
25
+ NoReturn
26
+ """
27
+
28
+ # arguments & parameters
29
+ dataset_dir = args.dataset_dir
30
+ split = args.split
31
+ hdf5s_dir = args.hdf5s_dir
32
+ sample_rate = args.sample_rate
33
+ channels = args.channels
34
+ mono = True if channels == 1 else False
35
+
36
+ source_type = "speech"
37
+
38
+ # Only pack data for training data.
39
+ assert split == "train"
40
+
41
+ audios_dir = os.path.join(dataset_dir, 'wav48', split)
42
+ os.makedirs(hdf5s_dir, exist_ok=True)
43
+
44
+ speaker_ids = sorted(os.listdir(audios_dir))
45
+
46
+ params = []
47
+ audio_index = 0
48
+
49
+ for speaker_id in speaker_ids:
50
+
51
+ speaker_audios_dir = os.path.join(audios_dir, speaker_id)
52
+
53
+ audio_names = sorted(os.listdir(speaker_audios_dir))
54
+
55
+ for audio_name in audio_names:
56
+
57
+ audio_path = os.path.join(speaker_audios_dir, audio_name)
58
+
59
+ hdf5_path = os.path.join(
60
+ hdf5s_dir, "{}.h5".format(pathlib.Path(audio_name).stem)
61
+ )
62
+
63
+ param = (
64
+ audio_index,
65
+ audio_name,
66
+ source_type,
67
+ audio_path,
68
+ mono,
69
+ sample_rate,
70
+ hdf5_path,
71
+ )
72
+ params.append(param)
73
+
74
+ audio_index += 1
75
+
76
+ # Uncomment for debug.
77
+ # write_single_audio_to_hdf5(params[0])
78
+ # os._exit(0)
79
+
80
+ pack_hdf5s_time = time.time()
81
+
82
+ with ProcessPoolExecutor(max_workers=None) as pool:
83
+ # Maximum works on the machine
84
+ pool.map(write_single_audio_to_hdf5, params)
85
+
86
+ print("Pack hdf5 time: {:.3f} s".format(time.time() - pack_hdf5s_time))
87
+
88
+
89
+ if __name__ == "__main__":
90
+ parser = argparse.ArgumentParser()
91
+
92
+ parser.add_argument(
93
+ "--dataset_dir",
94
+ type=str,
95
+ required=True,
96
+ help="Directory of the VCTK dataset.",
97
+ )
98
+ parser.add_argument("--split", type=str, required=True, choices=["train", "test"])
99
+ parser.add_argument(
100
+ "--hdf5s_dir",
101
+ type=str,
102
+ required=True,
103
+ help="Directory to write out hdf5 files.",
104
+ )
105
+ parser.add_argument("--sample_rate", type=int, required=True, help="Sample rate.")
106
+ parser.add_argument(
107
+ "--channels", type=int, required=True, help="Use 1 for mono, 2 for stereo."
108
+ )
109
+
110
+ # Parse arguments.
111
+ args = parser.parse_args()
112
+
113
+ # Pack audios into hdf5 files.
114
+ pack_audios_to_hdf5s(args)
bytesep/dataset_creation/pack_audios_to_hdf5s/voicebank-demand.py ADDED
@@ -0,0 +1,143 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import os
3
+ import pathlib
4
+ import time
5
+ from concurrent.futures import ProcessPoolExecutor
6
+ from typing import List, NoReturn
7
+
8
+ import h5py
9
+ import numpy as np
10
+
11
+ from bytesep.utils import float32_to_int16, load_audio
12
+
13
+
14
+ def pack_audios_to_hdf5s(args) -> NoReturn:
15
+ r"""Pack (resampled) audio files into hdf5 files to speed up loading.
16
+
17
+ Args:
18
+ dataset_dir: str
19
+ split: str, 'train' | 'test'
20
+ hdf5s_dir: str, directory to write out hdf5 files
21
+ sample_rate: int
22
+ channels_num: int
23
+ mono: bool
24
+
25
+ Returns:
26
+ NoReturn
27
+ """
28
+
29
+ # arguments & parameters
30
+ dataset_dir = args.dataset_dir
31
+ split = args.split
32
+ hdf5s_dir = args.hdf5s_dir
33
+ sample_rate = args.sample_rate
34
+ channels = args.channels
35
+ mono = True if channels == 1 else False
36
+
37
+ # Only pack data for training data.
38
+ assert split == "train"
39
+
40
+ speech_dir = os.path.join(dataset_dir, "clean_{}set_wav".format(split))
41
+ mixture_dir = os.path.join(dataset_dir, "noisy_{}set_wav".format(split))
42
+
43
+ os.makedirs(hdf5s_dir, exist_ok=True)
44
+
45
+ # Read names.
46
+ audio_names = sorted(os.listdir(speech_dir))
47
+
48
+ params = []
49
+
50
+ for audio_index, audio_name in enumerate(audio_names):
51
+
52
+ speech_path = os.path.join(speech_dir, audio_name)
53
+ mixture_path = os.path.join(mixture_dir, audio_name)
54
+
55
+ hdf5_path = os.path.join(
56
+ hdf5s_dir, "{}.h5".format(pathlib.Path(audio_name).stem)
57
+ )
58
+
59
+ param = (
60
+ audio_index,
61
+ audio_name,
62
+ speech_path,
63
+ mixture_path,
64
+ mono,
65
+ sample_rate,
66
+ hdf5_path,
67
+ )
68
+ params.append(param)
69
+
70
+ # Uncomment for debug.
71
+ # write_single_audio_to_hdf5(params[0])
72
+ # os._exit(0)
73
+
74
+ pack_hdf5s_time = time.time()
75
+
76
+ with ProcessPoolExecutor(max_workers=None) as pool:
77
+ # Maximum works on the machine
78
+ pool.map(write_single_audio_to_hdf5, params)
79
+
80
+ print("Pack hdf5 time: {:.3f} s".format(time.time() - pack_hdf5s_time))
81
+
82
+
83
+ def write_single_audio_to_hdf5(param: List) -> NoReturn:
84
+ r"""Write single audio into hdf5 file."""
85
+
86
+ (
87
+ audio_index,
88
+ audio_name,
89
+ speech_path,
90
+ mixture_path,
91
+ mono,
92
+ sample_rate,
93
+ hdf5_path,
94
+ ) = param
95
+
96
+ with h5py.File(hdf5_path, "w") as hf:
97
+
98
+ hf.attrs.create("audio_name", data=audio_name, dtype="S100")
99
+ hf.attrs.create("sample_rate", data=sample_rate, dtype=np.int32)
100
+
101
+ speech = load_audio(audio_path=speech_path, mono=mono, sample_rate=sample_rate)
102
+ # speech: (channels_num, audio_samples)
103
+
104
+ mixture = load_audio(
105
+ audio_path=mixture_path, mono=mono, sample_rate=sample_rate
106
+ )
107
+ # mixture: (channels_num, audio_samples)
108
+
109
+ noise = mixture - speech
110
+ # noise: (channels_num, audio_samples)
111
+
112
+ hf.create_dataset(name='speech', data=float32_to_int16(speech), dtype=np.int16)
113
+ hf.create_dataset(name='noise', data=float32_to_int16(noise), dtype=np.int16)
114
+
115
+ print('{} Write hdf5 to {}'.format(audio_index, hdf5_path))
116
+
117
+
118
+ if __name__ == "__main__":
119
+ parser = argparse.ArgumentParser()
120
+
121
+ parser.add_argument(
122
+ "--dataset_dir",
123
+ type=str,
124
+ required=True,
125
+ help="Directory of the Voicebank-Demand dataset.",
126
+ )
127
+ parser.add_argument("--split", type=str, required=True, choices=["train", "test"])
128
+ parser.add_argument(
129
+ "--hdf5s_dir",
130
+ type=str,
131
+ required=True,
132
+ help="Directory to write out hdf5 files.",
133
+ )
134
+ parser.add_argument("--sample_rate", type=int, required=True, help="Sample rate.")
135
+ parser.add_argument(
136
+ "--channels", type=int, required=True, help="Use 1 for mono, 2 for stereo."
137
+ )
138
+
139
+ # Parse arguments.
140
+ args = parser.parse_args()
141
+
142
+ # Pack audios into hdf5 files.
143
+ pack_audios_to_hdf5s(args)
bytesep/inference.py ADDED
@@ -0,0 +1,402 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import os
3
+ import time
4
+ from typing import Dict
5
+ import pathlib
6
+
7
+ import librosa
8
+ import numpy as np
9
+ import soundfile
10
+ import torch
11
+ import torch.nn as nn
12
+
13
+ from bytesep.models.lightning_modules import get_model_class
14
+ from bytesep.utils import read_yaml
15
+
16
+
17
+ class Separator:
18
+ def __init__(
19
+ self, model: nn.Module, segment_samples: int, batch_size: int, device: str
20
+ ):
21
+ r"""Separate to separate an audio clip into a target source.
22
+
23
+ Args:
24
+ model: nn.Module, trained model
25
+ segment_samples: int, length of segments to be input to a model, e.g., 44100*30
26
+ batch_size, int, e.g., 12
27
+ device: str, e.g., 'cuda'
28
+ """
29
+ self.model = model
30
+ self.segment_samples = segment_samples
31
+ self.batch_size = batch_size
32
+ self.device = device
33
+
34
+ def separate(self, input_dict: Dict) -> np.array:
35
+ r"""Separate an audio clip into a target source.
36
+
37
+ Args:
38
+ input_dict: dict, e.g., {
39
+ waveform: (channels_num, audio_samples),
40
+ ...,
41
+ }
42
+
43
+ Returns:
44
+ sep_audio: (channels_num, audio_samples) | (target_sources_num, channels_num, audio_samples)
45
+ """
46
+ audio = input_dict['waveform']
47
+
48
+ audio_samples = audio.shape[-1]
49
+
50
+ # Pad the audio with zero in the end so that the length of audio can be
51
+ # evenly divided by segment_samples.
52
+ audio = self.pad_audio(audio)
53
+
54
+ # Enframe long audio into segments.
55
+ segments = self.enframe(audio, self.segment_samples)
56
+ # (segments_num, channels_num, segment_samples)
57
+
58
+ segments_input_dict = {'waveform': segments}
59
+
60
+ if 'condition' in input_dict.keys():
61
+ segments_num = len(segments)
62
+ segments_input_dict['condition'] = np.tile(
63
+ input_dict['condition'][None, :], (segments_num, 1)
64
+ )
65
+ # (batch_size, segments_num)
66
+
67
+ # Separate in mini-batches.
68
+ sep_segments = self._forward_in_mini_batches(
69
+ self.model, segments_input_dict, self.batch_size
70
+ )['waveform']
71
+ # (segments_num, channels_num, segment_samples)
72
+
73
+ # Deframe segments into long audio.
74
+ sep_audio = self.deframe(sep_segments)
75
+ # (channels_num, padded_audio_samples)
76
+
77
+ sep_audio = sep_audio[:, 0:audio_samples]
78
+ # (channels_num, audio_samples)
79
+
80
+ return sep_audio
81
+
82
+ def pad_audio(self, audio: np.array) -> np.array:
83
+ r"""Pad the audio with zero in the end so that the length of audio can
84
+ be evenly divided by segment_samples.
85
+
86
+ Args:
87
+ audio: (channels_num, audio_samples)
88
+
89
+ Returns:
90
+ padded_audio: (channels_num, audio_samples)
91
+ """
92
+ channels_num, audio_samples = audio.shape
93
+
94
+ # Number of segments
95
+ segments_num = int(np.ceil(audio_samples / self.segment_samples))
96
+
97
+ pad_samples = segments_num * self.segment_samples - audio_samples
98
+
99
+ padded_audio = np.concatenate(
100
+ (audio, np.zeros((channels_num, pad_samples))), axis=1
101
+ )
102
+ # (channels_num, padded_audio_samples)
103
+
104
+ return padded_audio
105
+
106
+ def enframe(self, audio: np.array, segment_samples: int) -> np.array:
107
+ r"""Enframe long audio into segments.
108
+
109
+ Args:
110
+ audio: (channels_num, audio_samples)
111
+ segment_samples: int
112
+
113
+ Returns:
114
+ segments: (segments_num, channels_num, segment_samples)
115
+ """
116
+ audio_samples = audio.shape[1]
117
+ assert audio_samples % segment_samples == 0
118
+
119
+ hop_samples = segment_samples // 2
120
+ segments = []
121
+
122
+ pointer = 0
123
+ while pointer + segment_samples <= audio_samples:
124
+ segments.append(audio[:, pointer : pointer + segment_samples])
125
+ pointer += hop_samples
126
+
127
+ segments = np.array(segments)
128
+
129
+ return segments
130
+
131
+ def deframe(self, segments: np.array) -> np.array:
132
+ r"""Deframe segments into long audio.
133
+
134
+ Args:
135
+ segments: (segments_num, channels_num, segment_samples)
136
+
137
+ Returns:
138
+ output: (channels_num, audio_samples)
139
+ """
140
+ (segments_num, _, segment_samples) = segments.shape
141
+
142
+ if segments_num == 1:
143
+ return segments[0]
144
+
145
+ assert self._is_integer(segment_samples * 0.25)
146
+ assert self._is_integer(segment_samples * 0.75)
147
+
148
+ output = []
149
+
150
+ output.append(segments[0, :, 0 : int(segment_samples * 0.75)])
151
+
152
+ for i in range(1, segments_num - 1):
153
+ output.append(
154
+ segments[
155
+ i, :, int(segment_samples * 0.25) : int(segment_samples * 0.75)
156
+ ]
157
+ )
158
+
159
+ output.append(segments[-1, :, int(segment_samples * 0.25) :])
160
+
161
+ output = np.concatenate(output, axis=-1)
162
+
163
+ return output
164
+
165
+ def _is_integer(self, x: float) -> bool:
166
+ if x - int(x) < 1e-10:
167
+ return True
168
+ else:
169
+ return False
170
+
171
+ def _forward_in_mini_batches(
172
+ self, model: nn.Module, segments_input_dict: Dict, batch_size: int
173
+ ) -> Dict:
174
+ r"""Forward data to model in mini-batch.
175
+
176
+ Args:
177
+ model: nn.Module
178
+ segments_input_dict: dict, e.g., {
179
+ 'waveform': (segments_num, channels_num, segment_samples),
180
+ ...,
181
+ }
182
+ batch_size: int
183
+
184
+ Returns:
185
+ output_dict: dict, e.g. {
186
+ 'waveform': (segments_num, channels_num, segment_samples),
187
+ }
188
+ """
189
+ output_dict = {}
190
+
191
+ pointer = 0
192
+ segments_num = len(segments_input_dict['waveform'])
193
+
194
+ while True:
195
+ if pointer >= segments_num:
196
+ break
197
+
198
+ batch_input_dict = {}
199
+
200
+ for key in segments_input_dict.keys():
201
+ batch_input_dict[key] = torch.Tensor(
202
+ segments_input_dict[key][pointer : pointer + batch_size]
203
+ ).to(self.device)
204
+
205
+ pointer += batch_size
206
+
207
+ with torch.no_grad():
208
+ model.eval()
209
+ batch_output_dict = model(batch_input_dict)
210
+
211
+ for key in batch_output_dict.keys():
212
+ self._append_to_dict(
213
+ output_dict, key, batch_output_dict[key].data.cpu().numpy()
214
+ )
215
+
216
+ for key in output_dict.keys():
217
+ output_dict[key] = np.concatenate(output_dict[key], axis=0)
218
+
219
+ return output_dict
220
+
221
+ def _append_to_dict(self, dict, key, value):
222
+ if key in dict.keys():
223
+ dict[key].append(value)
224
+ else:
225
+ dict[key] = [value]
226
+
227
+
228
+ class SeparatorWrapper:
229
+ def __init__(
230
+ self, source_type='vocals', model=None, checkpoint_path=None, device='cuda'
231
+ ):
232
+
233
+ input_channels = 2
234
+ target_sources_num = 1
235
+ model_type = "ResUNet143_Subbandtime"
236
+ segment_samples = 44100 * 10
237
+ batch_size = 1
238
+
239
+ self.checkpoint_path = self.download_checkpoints(checkpoint_path, source_type)
240
+
241
+ if device == 'cuda' and torch.cuda.is_available():
242
+ self.device = 'cuda'
243
+ else:
244
+ self.device = 'cpu'
245
+
246
+ # Get model class.
247
+ Model = get_model_class(model_type)
248
+
249
+ # Create model.
250
+ self.model = Model(
251
+ input_channels=input_channels, target_sources_num=target_sources_num
252
+ )
253
+
254
+ # Load checkpoint.
255
+ checkpoint = torch.load(self.checkpoint_path, map_location='cpu')
256
+ self.model.load_state_dict(checkpoint["model"])
257
+
258
+ # Move model to device.
259
+ self.model.to(self.device)
260
+
261
+ # Create separator.
262
+ self.separator = Separator(
263
+ model=self.model,
264
+ segment_samples=segment_samples,
265
+ batch_size=batch_size,
266
+ device=self.device,
267
+ )
268
+
269
+ def download_checkpoints(self, checkpoint_path, source_type):
270
+
271
+ if source_type == "vocals":
272
+ checkpoint_bare_name = "resunet143_subbtandtime_vocals_8.8dB_350k_steps"
273
+
274
+ elif source_type == "accompaniment":
275
+ checkpoint_bare_name = (
276
+ "resunet143_subbtandtime_accompaniment_16.4dB_350k_steps.pth"
277
+ )
278
+
279
+ else:
280
+ raise NotImplementedError
281
+
282
+ if not checkpoint_path:
283
+ checkpoint_path = '{}/bytesep_data/{}.pth'.format(
284
+ str(pathlib.Path.home()), checkpoint_bare_name
285
+ )
286
+
287
+ print('Checkpoint path: {}'.format(checkpoint_path))
288
+
289
+ if (
290
+ not os.path.exists(checkpoint_path)
291
+ or os.path.getsize(checkpoint_path) < 4e8
292
+ ):
293
+
294
+ os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)
295
+
296
+ zenodo_dir = "https://zenodo.org/record/5507029/files"
297
+ zenodo_path = os.path.join(
298
+ zenodo_dir, "{}?download=1".format(checkpoint_bare_name)
299
+ )
300
+
301
+ os.system('wget -O "{}" "{}"'.format(checkpoint_path, zenodo_path))
302
+
303
+ return checkpoint_path
304
+
305
+ def separate(self, audio):
306
+
307
+ input_dict = {'waveform': audio}
308
+
309
+ sep_wav = self.separator.separate(input_dict)
310
+
311
+ return sep_wav
312
+
313
+
314
+ def inference(args):
315
+
316
+ # Need to use torch.distributed if models contain inplace_abn.abn.InPlaceABNSync.
317
+ import torch.distributed as dist
318
+
319
+ dist.init_process_group(
320
+ 'gloo', init_method='file:///tmp/somefile', rank=0, world_size=1
321
+ )
322
+
323
+ # Arguments & parameters
324
+ config_yaml = args.config_yaml
325
+ checkpoint_path = args.checkpoint_path
326
+ audio_path = args.audio_path
327
+ output_path = args.output_path
328
+ device = (
329
+ torch.device('cuda')
330
+ if args.cuda and torch.cuda.is_available()
331
+ else torch.device('cpu')
332
+ )
333
+
334
+ configs = read_yaml(config_yaml)
335
+ sample_rate = configs['train']['sample_rate']
336
+ input_channels = configs['train']['channels']
337
+ target_source_types = configs['train']['target_source_types']
338
+ target_sources_num = len(target_source_types)
339
+ model_type = configs['train']['model_type']
340
+
341
+ segment_samples = int(30 * sample_rate)
342
+ batch_size = 1
343
+
344
+ print("Using {} for separating ..".format(device))
345
+
346
+ # paths
347
+ if os.path.dirname(output_path) != "":
348
+ os.makedirs(os.path.dirname(output_path), exist_ok=True)
349
+
350
+ # Get model class.
351
+ Model = get_model_class(model_type)
352
+
353
+ # Create model.
354
+ model = Model(input_channels=input_channels, target_sources_num=target_sources_num)
355
+
356
+ # Load checkpoint.
357
+ checkpoint = torch.load(checkpoint_path, map_location='cpu')
358
+ model.load_state_dict(checkpoint["model"])
359
+
360
+ # Move model to device.
361
+ model.to(device)
362
+
363
+ # Create separator.
364
+ separator = Separator(
365
+ model=model,
366
+ segment_samples=segment_samples,
367
+ batch_size=batch_size,
368
+ device=device,
369
+ )
370
+
371
+ # Load audio.
372
+ audio, _ = librosa.load(audio_path, sr=sample_rate, mono=False)
373
+
374
+ # audio = audio[None, :]
375
+
376
+ input_dict = {'waveform': audio}
377
+
378
+ # Separate
379
+ separate_time = time.time()
380
+
381
+ sep_wav = separator.separate(input_dict)
382
+ # (channels_num, audio_samples)
383
+
384
+ print('Separate time: {:.3f} s'.format(time.time() - separate_time))
385
+
386
+ # Write out separated audio.
387
+ soundfile.write(file='_zz.wav', data=sep_wav.T, samplerate=sample_rate)
388
+ os.system("ffmpeg -y -loglevel panic -i _zz.wav {}".format(output_path))
389
+ print('Write out to {}'.format(output_path))
390
+
391
+
392
+ if __name__ == "__main__":
393
+
394
+ parser = argparse.ArgumentParser(description="")
395
+ parser.add_argument("--config_yaml", type=str, required=True)
396
+ parser.add_argument("--checkpoint_path", type=str, required=True)
397
+ parser.add_argument("--audio_path", type=str, required=True)
398
+ parser.add_argument("--output_path", type=str, required=True)
399
+ parser.add_argument("--cuda", action='store_true', default=True)
400
+
401
+ args = parser.parse_args()
402
+ inference(args)
bytesep/inference_many.py ADDED
@@ -0,0 +1,163 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import os
3
+ import pathlib
4
+ import time
5
+ from typing import NoReturn
6
+
7
+ import librosa
8
+ import numpy as np
9
+ import soundfile
10
+ import torch
11
+
12
+ from bytesep.inference import Separator
13
+ from bytesep.models.lightning_modules import get_model_class
14
+ from bytesep.utils import read_yaml
15
+
16
+
17
+ def inference(args) -> NoReturn:
18
+ r"""Separate all audios in a directory.
19
+
20
+ Args:
21
+ config_yaml: str, the config file of a model being trained
22
+ checkpoint_path: str, the path of checkpoint to be loaded
23
+ audios_dir: str, the directory of audios to be separated
24
+ output_dir: str, the directory to write out separated audios
25
+ scale_volume: bool, if True then the volume is scaled to the maximum value of 1.
26
+
27
+ Returns:
28
+ NoReturn
29
+ """
30
+
31
+ # Arguments & parameters
32
+ config_yaml = args.config_yaml
33
+ checkpoint_path = args.checkpoint_path
34
+ audios_dir = args.audios_dir
35
+ output_dir = args.output_dir
36
+ scale_volume = args.scale_volume
37
+ device = (
38
+ torch.device('cuda')
39
+ if args.cuda and torch.cuda.is_available()
40
+ else torch.device('cpu')
41
+ )
42
+
43
+ configs = read_yaml(config_yaml)
44
+ sample_rate = configs['train']['sample_rate']
45
+ input_channels = configs['train']['channels']
46
+ target_source_types = configs['train']['target_source_types']
47
+ target_sources_num = len(target_source_types)
48
+ model_type = configs['train']['model_type']
49
+ mono = input_channels == 1
50
+
51
+ segment_samples = int(30 * sample_rate)
52
+ batch_size = 1
53
+ device = "cuda"
54
+
55
+ models_contains_inplaceabn = True
56
+
57
+ # Need to use torch.distributed if models contain inplace_abn.abn.InPlaceABNSync.
58
+ if models_contains_inplaceabn:
59
+
60
+ import torch.distributed as dist
61
+
62
+ dist.init_process_group(
63
+ 'gloo', init_method='file:///tmp/somefile', rank=0, world_size=1
64
+ )
65
+
66
+ print("Using {} for separating ..".format(device))
67
+
68
+ # paths
69
+ os.makedirs(output_dir, exist_ok=True)
70
+
71
+ # Get model class.
72
+ Model = get_model_class(model_type)
73
+
74
+ # Create model.
75
+ model = Model(input_channels=input_channels, target_sources_num=target_sources_num)
76
+
77
+ # Load checkpoint.
78
+ checkpoint = torch.load(checkpoint_path, map_location='cpu')
79
+ model.load_state_dict(checkpoint["model"])
80
+
81
+ # Move model to device.
82
+ model.to(device)
83
+
84
+ # Create separator.
85
+ separator = Separator(
86
+ model=model,
87
+ segment_samples=segment_samples,
88
+ batch_size=batch_size,
89
+ device=device,
90
+ )
91
+
92
+ audio_names = sorted(os.listdir(audios_dir))
93
+
94
+ for audio_name in audio_names:
95
+ audio_path = os.path.join(audios_dir, audio_name)
96
+
97
+ # Load audio.
98
+ audio, _ = librosa.load(audio_path, sr=sample_rate, mono=mono)
99
+
100
+ if audio.ndim == 1:
101
+ audio = audio[None, :]
102
+
103
+ input_dict = {'waveform': audio}
104
+
105
+ # Separate
106
+ separate_time = time.time()
107
+
108
+ sep_wav = separator.separate(input_dict)
109
+ # (channels_num, audio_samples)
110
+
111
+ print('Separate time: {:.3f} s'.format(time.time() - separate_time))
112
+
113
+ # Write out separated audio.
114
+ if scale_volume:
115
+ sep_wav /= np.max(np.abs(sep_wav))
116
+
117
+ soundfile.write(file='_zz.wav', data=sep_wav.T, samplerate=sample_rate)
118
+
119
+ output_path = os.path.join(
120
+ output_dir, '{}.mp3'.format(pathlib.Path(audio_name).stem)
121
+ )
122
+ os.system('ffmpeg -y -loglevel panic -i _zz.wav "{}"'.format(output_path))
123
+ print('Write out to {}'.format(output_path))
124
+
125
+
126
+ if __name__ == "__main__":
127
+
128
+ parser = argparse.ArgumentParser(description="")
129
+ parser.add_argument(
130
+ "--config_yaml",
131
+ type=str,
132
+ required=True,
133
+ help="The config file of a model being trained.",
134
+ )
135
+ parser.add_argument(
136
+ "--checkpoint_path",
137
+ type=str,
138
+ required=True,
139
+ help="The path of checkpoint to be loaded.",
140
+ )
141
+ parser.add_argument(
142
+ "--audios_dir",
143
+ type=str,
144
+ required=True,
145
+ help="The directory of audios to be separated.",
146
+ )
147
+ parser.add_argument(
148
+ "--output_dir",
149
+ type=str,
150
+ required=True,
151
+ help="The directory to write out separated audios.",
152
+ )
153
+ parser.add_argument(
154
+ '--scale_volume',
155
+ action='store_true',
156
+ default=False,
157
+ help="set to True if separated audios are scaled to the maximum value of 1.",
158
+ )
159
+ parser.add_argument("--cuda", action='store_true', default=True)
160
+
161
+ args = parser.parse_args()
162
+
163
+ inference(args)
bytesep/losses.py ADDED
@@ -0,0 +1,106 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ from typing import Callable
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ from torchlibrosa.stft import STFT
7
+
8
+ from bytesep.models.pytorch_modules import Base
9
+
10
+
11
+ def l1(output: torch.Tensor, target: torch.Tensor, **kwargs) -> torch.Tensor:
12
+ r"""L1 loss.
13
+
14
+ Args:
15
+ output: torch.Tensor
16
+ target: torch.Tensor
17
+
18
+ Returns:
19
+ loss: torch.float
20
+ """
21
+ return torch.mean(torch.abs(output - target))
22
+
23
+
24
+ def l1_wav(output: torch.Tensor, target: torch.Tensor, **kwargs) -> torch.Tensor:
25
+ r"""L1 loss in the time-domain.
26
+
27
+ Args:
28
+ output: torch.Tensor
29
+ target: torch.Tensor
30
+
31
+ Returns:
32
+ loss: torch.float
33
+ """
34
+ return l1(output, target)
35
+
36
+
37
+ class L1_Wav_L1_Sp(nn.Module, Base):
38
+ def __init__(self):
39
+ r"""L1 loss in the time-domain and L1 loss on the spectrogram."""
40
+ super(L1_Wav_L1_Sp, self).__init__()
41
+
42
+ self.window_size = 2048
43
+ hop_size = 441
44
+ center = True
45
+ pad_mode = "reflect"
46
+ window = "hann"
47
+
48
+ self.stft = STFT(
49
+ n_fft=self.window_size,
50
+ hop_length=hop_size,
51
+ win_length=self.window_size,
52
+ window=window,
53
+ center=center,
54
+ pad_mode=pad_mode,
55
+ freeze_parameters=True,
56
+ )
57
+
58
+ def __call__(
59
+ self, output: torch.Tensor, target: torch.Tensor, **kwargs
60
+ ) -> torch.Tensor:
61
+ r"""L1 loss in the time-domain and on the spectrogram.
62
+
63
+ Args:
64
+ output: torch.Tensor
65
+ target: torch.Tensor
66
+
67
+ Returns:
68
+ loss: torch.float
69
+ """
70
+
71
+ # L1 loss in the time-domain.
72
+ wav_loss = l1_wav(output, target)
73
+
74
+ # L1 loss on the spectrogram.
75
+ sp_loss = l1(
76
+ self.wav_to_spectrogram(output, eps=1e-8),
77
+ self.wav_to_spectrogram(target, eps=1e-8),
78
+ )
79
+
80
+ # sp_loss /= math.sqrt(self.window_size)
81
+ # sp_loss *= 1.
82
+
83
+ # Total loss.
84
+ return wav_loss + sp_loss
85
+
86
+ return sp_loss
87
+
88
+
89
+ def get_loss_function(loss_type: str) -> Callable:
90
+ r"""Get loss function.
91
+
92
+ Args:
93
+ loss_type: str
94
+
95
+ Returns:
96
+ loss function: Callable
97
+ """
98
+
99
+ if loss_type == "l1_wav":
100
+ return l1_wav
101
+
102
+ elif loss_type == "l1_wav_l1_sp":
103
+ return L1_Wav_L1_Sp()
104
+
105
+ else:
106
+ raise NotImplementedError
bytesep/models/__init__.py ADDED
File without changes
bytesep/models/conditional_unet.py ADDED
@@ -0,0 +1,496 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ from typing import List
3
+
4
+ import numpy as np
5
+ import matplotlib.pyplot as plt
6
+ import pytorch_lightning as pl
7
+ import torch
8
+ import torch.nn as nn
9
+ import torch.nn.functional as F
10
+ import torch.optim as optim
11
+ from torch.optim.lr_scheduler import LambdaLR
12
+ from torchlibrosa.stft import STFT, ISTFT, magphase
13
+
14
+ from bytesep.models.pytorch_modules import (
15
+ Base,
16
+ init_bn,
17
+ init_embedding,
18
+ init_layer,
19
+ act,
20
+ Subband,
21
+ )
22
+
23
+
24
+ class ConvBlock(nn.Module):
25
+ def __init__(
26
+ self,
27
+ in_channels,
28
+ out_channels,
29
+ condition_size,
30
+ kernel_size,
31
+ activation,
32
+ momentum,
33
+ ):
34
+ super(ConvBlock, self).__init__()
35
+
36
+ self.activation = activation
37
+ padding = (kernel_size[0] // 2, kernel_size[1] // 2)
38
+
39
+ self.conv1 = nn.Conv2d(
40
+ in_channels=in_channels,
41
+ out_channels=out_channels,
42
+ kernel_size=kernel_size,
43
+ stride=(1, 1),
44
+ dilation=(1, 1),
45
+ padding=padding,
46
+ bias=False,
47
+ )
48
+
49
+ self.bn1 = nn.BatchNorm2d(out_channels, momentum=momentum)
50
+
51
+ self.conv2 = nn.Conv2d(
52
+ in_channels=out_channels,
53
+ out_channels=out_channels,
54
+ kernel_size=kernel_size,
55
+ stride=(1, 1),
56
+ dilation=(1, 1),
57
+ padding=padding,
58
+ bias=False,
59
+ )
60
+
61
+ self.bn2 = nn.BatchNorm2d(out_channels, momentum=momentum)
62
+
63
+ self.beta1 = nn.Linear(condition_size, out_channels, bias=True)
64
+ self.beta2 = nn.Linear(condition_size, out_channels, bias=True)
65
+
66
+ self.init_weights()
67
+
68
+ def init_weights(self):
69
+ init_layer(self.conv1)
70
+ init_layer(self.conv2)
71
+ init_bn(self.bn1)
72
+ init_bn(self.bn2)
73
+ init_embedding(self.beta1)
74
+ init_embedding(self.beta2)
75
+
76
+ def forward(self, x, condition):
77
+
78
+ b1 = self.beta1(condition)[:, :, None, None]
79
+ b2 = self.beta2(condition)[:, :, None, None]
80
+
81
+ x = act(self.bn1(self.conv1(x)) + b1, self.activation)
82
+ x = act(self.bn2(self.conv2(x)) + b2, self.activation)
83
+ return x
84
+
85
+
86
+ class EncoderBlock(nn.Module):
87
+ def __init__(
88
+ self,
89
+ in_channels,
90
+ out_channels,
91
+ condition_size,
92
+ kernel_size,
93
+ downsample,
94
+ activation,
95
+ momentum,
96
+ ):
97
+ super(EncoderBlock, self).__init__()
98
+
99
+ self.conv_block = ConvBlock(
100
+ in_channels, out_channels, condition_size, kernel_size, activation, momentum
101
+ )
102
+ self.downsample = downsample
103
+
104
+ def forward(self, x, condition):
105
+ encoder = self.conv_block(x, condition)
106
+ encoder_pool = F.avg_pool2d(encoder, kernel_size=self.downsample)
107
+ return encoder_pool, encoder
108
+
109
+
110
+ class DecoderBlock(nn.Module):
111
+ def __init__(
112
+ self,
113
+ in_channels,
114
+ out_channels,
115
+ condition_size,
116
+ kernel_size,
117
+ upsample,
118
+ activation,
119
+ momentum,
120
+ ):
121
+ super(DecoderBlock, self).__init__()
122
+ self.kernel_size = kernel_size
123
+ self.stride = upsample
124
+ self.activation = activation
125
+
126
+ self.conv1 = torch.nn.ConvTranspose2d(
127
+ in_channels=in_channels,
128
+ out_channels=out_channels,
129
+ kernel_size=self.stride,
130
+ stride=self.stride,
131
+ padding=(0, 0),
132
+ bias=False,
133
+ dilation=(1, 1),
134
+ )
135
+
136
+ self.bn1 = nn.BatchNorm2d(out_channels, momentum=momentum)
137
+
138
+ self.conv_block2 = ConvBlock(
139
+ out_channels * 2,
140
+ out_channels,
141
+ condition_size,
142
+ kernel_size,
143
+ activation,
144
+ momentum,
145
+ )
146
+
147
+ self.beta1 = nn.Linear(condition_size, out_channels, bias=True)
148
+
149
+ self.init_weights()
150
+
151
+ def init_weights(self):
152
+ init_layer(self.conv1)
153
+ init_bn(self.bn1)
154
+ init_embedding(self.beta1)
155
+
156
+ def forward(self, input_tensor, concat_tensor, condition):
157
+ b1 = self.beta1(condition)[:, :, None, None]
158
+ x = act(self.bn1(self.conv1(input_tensor)) + b1, self.activation)
159
+ x = torch.cat((x, concat_tensor), dim=1)
160
+ x = self.conv_block2(x, condition)
161
+ return x
162
+
163
+
164
+ class ConditionalUNet(nn.Module, Base):
165
+ def __init__(self, input_channels, target_sources_num):
166
+ super(ConditionalUNet, self).__init__()
167
+
168
+ self.input_channels = input_channels
169
+ condition_size = target_sources_num
170
+ self.output_sources_num = 1
171
+
172
+ window_size = 2048
173
+ hop_size = 441
174
+ center = True
175
+ pad_mode = "reflect"
176
+ window = "hann"
177
+ activation = "relu"
178
+ momentum = 0.01
179
+
180
+ self.subbands_num = 4
181
+ self.K = 3 # outputs: |M|, cos∠M, sin∠M
182
+
183
+ self.downsample_ratio = 2 ** 6 # This number equals 2^{#encoder_blcoks}
184
+
185
+ self.stft = STFT(
186
+ n_fft=window_size,
187
+ hop_length=hop_size,
188
+ win_length=window_size,
189
+ window=window,
190
+ center=center,
191
+ pad_mode=pad_mode,
192
+ freeze_parameters=True,
193
+ )
194
+
195
+ self.istft = ISTFT(
196
+ n_fft=window_size,
197
+ hop_length=hop_size,
198
+ win_length=window_size,
199
+ window=window,
200
+ center=center,
201
+ pad_mode=pad_mode,
202
+ freeze_parameters=True,
203
+ )
204
+
205
+ self.bn0 = nn.BatchNorm2d(window_size // 2 + 1, momentum=momentum)
206
+
207
+ self.subband = Subband(subbands_num=self.subbands_num)
208
+
209
+ self.encoder_block1 = EncoderBlock(
210
+ in_channels=input_channels * self.subbands_num,
211
+ out_channels=32,
212
+ condition_size=condition_size,
213
+ kernel_size=(3, 3),
214
+ downsample=(2, 2),
215
+ activation=activation,
216
+ momentum=momentum,
217
+ )
218
+ self.encoder_block2 = EncoderBlock(
219
+ in_channels=32,
220
+ out_channels=64,
221
+ condition_size=condition_size,
222
+ kernel_size=(3, 3),
223
+ downsample=(2, 2),
224
+ activation=activation,
225
+ momentum=momentum,
226
+ )
227
+ self.encoder_block3 = EncoderBlock(
228
+ in_channels=64,
229
+ out_channels=128,
230
+ condition_size=condition_size,
231
+ kernel_size=(3, 3),
232
+ downsample=(2, 2),
233
+ activation=activation,
234
+ momentum=momentum,
235
+ )
236
+ self.encoder_block4 = EncoderBlock(
237
+ in_channels=128,
238
+ out_channels=256,
239
+ condition_size=condition_size,
240
+ kernel_size=(3, 3),
241
+ downsample=(2, 2),
242
+ activation=activation,
243
+ momentum=momentum,
244
+ )
245
+ self.encoder_block5 = EncoderBlock(
246
+ in_channels=256,
247
+ out_channels=384,
248
+ condition_size=condition_size,
249
+ kernel_size=(3, 3),
250
+ downsample=(2, 2),
251
+ activation=activation,
252
+ momentum=momentum,
253
+ )
254
+ self.encoder_block6 = EncoderBlock(
255
+ in_channels=384,
256
+ out_channels=384,
257
+ condition_size=condition_size,
258
+ kernel_size=(3, 3),
259
+ downsample=(2, 2),
260
+ activation=activation,
261
+ momentum=momentum,
262
+ )
263
+ self.conv_block7 = ConvBlock(
264
+ in_channels=384,
265
+ out_channels=384,
266
+ condition_size=condition_size,
267
+ kernel_size=(3, 3),
268
+ activation=activation,
269
+ momentum=momentum,
270
+ )
271
+ self.decoder_block1 = DecoderBlock(
272
+ in_channels=384,
273
+ out_channels=384,
274
+ condition_size=condition_size,
275
+ kernel_size=(3, 3),
276
+ upsample=(2, 2),
277
+ activation=activation,
278
+ momentum=momentum,
279
+ )
280
+ self.decoder_block2 = DecoderBlock(
281
+ in_channels=384,
282
+ out_channels=384,
283
+ condition_size=condition_size,
284
+ kernel_size=(3, 3),
285
+ upsample=(2, 2),
286
+ activation=activation,
287
+ momentum=momentum,
288
+ )
289
+ self.decoder_block3 = DecoderBlock(
290
+ in_channels=384,
291
+ out_channels=256,
292
+ condition_size=condition_size,
293
+ kernel_size=(3, 3),
294
+ upsample=(2, 2),
295
+ activation=activation,
296
+ momentum=momentum,
297
+ )
298
+ self.decoder_block4 = DecoderBlock(
299
+ in_channels=256,
300
+ out_channels=128,
301
+ condition_size=condition_size,
302
+ kernel_size=(3, 3),
303
+ upsample=(2, 2),
304
+ activation=activation,
305
+ momentum=momentum,
306
+ )
307
+ self.decoder_block5 = DecoderBlock(
308
+ in_channels=128,
309
+ out_channels=64,
310
+ condition_size=condition_size,
311
+ kernel_size=(3, 3),
312
+ upsample=(2, 2),
313
+ activation=activation,
314
+ momentum=momentum,
315
+ )
316
+ self.decoder_block6 = DecoderBlock(
317
+ in_channels=64,
318
+ out_channels=32,
319
+ condition_size=condition_size,
320
+ kernel_size=(3, 3),
321
+ upsample=(2, 2),
322
+ activation=activation,
323
+ momentum=momentum,
324
+ )
325
+
326
+ self.after_conv_block1 = ConvBlock(
327
+ in_channels=32,
328
+ out_channels=32,
329
+ condition_size=condition_size,
330
+ kernel_size=(3, 3),
331
+ activation=activation,
332
+ momentum=momentum,
333
+ )
334
+
335
+ self.after_conv2 = nn.Conv2d(
336
+ in_channels=32,
337
+ out_channels=input_channels
338
+ * self.subbands_num
339
+ * self.output_sources_num
340
+ * self.K,
341
+ kernel_size=(1, 1),
342
+ stride=(1, 1),
343
+ padding=(0, 0),
344
+ bias=True,
345
+ )
346
+
347
+ self.init_weights()
348
+
349
+ def init_weights(self):
350
+ init_bn(self.bn0)
351
+ init_layer(self.after_conv2)
352
+
353
+ def feature_maps_to_wav(self, x, sp, sin_in, cos_in, audio_length):
354
+
355
+ batch_size, _, time_steps, freq_bins = x.shape
356
+
357
+ x = x.reshape(
358
+ batch_size,
359
+ self.output_sources_num,
360
+ self.input_channels,
361
+ self.K,
362
+ time_steps,
363
+ freq_bins,
364
+ )
365
+ # x: (batch_size, output_sources_num, input_channles, K, time_steps, freq_bins)
366
+
367
+ mask_mag = torch.sigmoid(x[:, :, :, 0, :, :])
368
+ _mask_real = torch.tanh(x[:, :, :, 1, :, :])
369
+ _mask_imag = torch.tanh(x[:, :, :, 2, :, :])
370
+ _, mask_cos, mask_sin = magphase(_mask_real, _mask_imag)
371
+ # mask_cos, mask_sin: (batch_size, output_sources_num, input_channles, time_steps, freq_bins)
372
+
373
+ # Y = |Y|cos∠Y + j|Y|sin∠Y
374
+ # = |Y|cos(∠X + ∠M) + j|Y|sin(∠X + ∠M)
375
+ # = |Y|(cos∠X cos∠M - sin∠X sin∠M) + j|Y|(sin∠X cos∠M + cos∠X sin∠M)
376
+ out_cos = (
377
+ cos_in[:, None, :, :, :] * mask_cos - sin_in[:, None, :, :, :] * mask_sin
378
+ )
379
+ out_sin = (
380
+ sin_in[:, None, :, :, :] * mask_cos + cos_in[:, None, :, :, :] * mask_sin
381
+ )
382
+ # out_cos, out_sin: (batch_size, output_sources_num, input_channles, time_steps, freq_bins)
383
+
384
+ # Calculate |Y|.
385
+ out_mag = F.relu_(sp[:, None, :, :, :] * mask_mag)
386
+ # out_mag: (batch_size, output_sources_num, input_channles, time_steps, freq_bins)
387
+
388
+ # Calculate Y_{real} and Y_{imag} for ISTFT.
389
+ out_real = out_mag * out_cos
390
+ out_imag = out_mag * out_sin
391
+ # out_real, out_imag: (batch_size, output_sources_num, input_channles, time_steps, freq_bins)
392
+
393
+ # Reformat shape to (n, 1, time_steps, freq_bins) for ISTFT.
394
+ shape = (
395
+ batch_size * self.output_sources_num * self.input_channels,
396
+ 1,
397
+ time_steps,
398
+ freq_bins,
399
+ )
400
+ out_real = out_real.reshape(shape)
401
+ out_imag = out_imag.reshape(shape)
402
+
403
+ # ISTFT.
404
+ wav_out = self.istft(out_real, out_imag, audio_length)
405
+ # (batch_size * output_sources_num * input_channels, segments_num)
406
+
407
+ # Reshape.
408
+ wav_out = wav_out.reshape(
409
+ batch_size, self.output_sources_num * self.input_channels, audio_length
410
+ )
411
+ # (batch_size, output_sources_num * input_channels, segments_num)
412
+
413
+ return wav_out
414
+
415
+ def forward(self, input_dict):
416
+ """
417
+ Args:
418
+ input: (batch_size, segment_samples, channels_num)
419
+
420
+ Outputs:
421
+ output_dict: {
422
+ 'wav': (batch_size, segment_samples, channels_num),
423
+ 'sp': (batch_size, channels_num, time_steps, freq_bins)}
424
+ """
425
+
426
+ mixture = input_dict['waveform']
427
+ condition = input_dict['condition']
428
+
429
+ sp, cos_in, sin_in = self.wav_to_spectrogram_phase(mixture)
430
+ """(batch_size, channels_num, time_steps, freq_bins)"""
431
+
432
+ # Batch normalization
433
+ x = sp.transpose(1, 3)
434
+ x = self.bn0(x)
435
+ x = x.transpose(1, 3)
436
+ """(batch_size, chanenls, time_steps, freq_bins)"""
437
+
438
+ # Pad spectrogram to be evenly divided by downsample ratio.
439
+ origin_len = x.shape[2]
440
+ pad_len = (
441
+ int(np.ceil(x.shape[2] / self.downsample_ratio)) * self.downsample_ratio
442
+ - origin_len
443
+ )
444
+ x = F.pad(x, pad=(0, 0, 0, pad_len))
445
+ """(batch_size, channels, padded_time_steps, freq_bins)"""
446
+
447
+ # Let frequency bins be evenly divided by 2, e.g., 513 -> 512
448
+ x = x[..., 0 : x.shape[-1] - 1] # (bs, channels, T, F)
449
+
450
+ x = self.subband.analysis(x)
451
+
452
+ # UNet
453
+ (x1_pool, x1) = self.encoder_block1(
454
+ x, condition
455
+ ) # x1_pool: (bs, 32, T / 2, F / 2)
456
+ (x2_pool, x2) = self.encoder_block2(
457
+ x1_pool, condition
458
+ ) # x2_pool: (bs, 64, T / 4, F / 4)
459
+ (x3_pool, x3) = self.encoder_block3(
460
+ x2_pool, condition
461
+ ) # x3_pool: (bs, 128, T / 8, F / 8)
462
+ (x4_pool, x4) = self.encoder_block4(
463
+ x3_pool, condition
464
+ ) # x4_pool: (bs, 256, T / 16, F / 16)
465
+ (x5_pool, x5) = self.encoder_block5(
466
+ x4_pool, condition
467
+ ) # x5_pool: (bs, 512, T / 32, F / 32)
468
+ (x6_pool, x6) = self.encoder_block6(
469
+ x5_pool, condition
470
+ ) # x6_pool: (bs, 1024, T / 64, F / 64)
471
+ x_center = self.conv_block7(x6_pool, condition) # (bs, 2048, T / 64, F / 64)
472
+ x7 = self.decoder_block1(x_center, x6, condition) # (bs, 1024, T / 32, F / 32)
473
+ x8 = self.decoder_block2(x7, x5, condition) # (bs, 512, T / 16, F / 16)
474
+ x9 = self.decoder_block3(x8, x4, condition) # (bs, 256, T / 8, F / 8)
475
+ x10 = self.decoder_block4(x9, x3, condition) # (bs, 128, T / 4, F / 4)
476
+ x11 = self.decoder_block5(x10, x2, condition) # (bs, 64, T / 2, F / 2)
477
+ x12 = self.decoder_block6(x11, x1, condition) # (bs, 32, T, F)
478
+ x = self.after_conv_block1(x12, condition) # (bs, 32, T, F)
479
+ x = self.after_conv2(x)
480
+ # (batch_size, input_channles * subbands_num * targets_num * k, T, F // subbands_num)
481
+
482
+ x = self.subband.synthesis(x)
483
+ # (batch_size, input_channles * targets_num * K, T, F)
484
+
485
+ # Recover shape
486
+ x = F.pad(x, pad=(0, 1)) # Pad frequency, e.g., 1024 -> 1025.
487
+ x = x[:, :, 0:origin_len, :] # (bs, feature_maps, T, F)
488
+
489
+ audio_length = mixture.shape[2]
490
+
491
+ separated_audio = self.feature_maps_to_wav(x, sp, sin_in, cos_in, audio_length)
492
+ # separated_audio: (batch_size, output_sources_num * input_channels, segments_num)
493
+
494
+ output_dict = {'waveform': separated_audio}
495
+
496
+ return output_dict
bytesep/models/lightning_modules.py ADDED
@@ -0,0 +1,188 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Any, Callable, Dict
2
+
3
+ import pytorch_lightning as pl
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.optim as optim
7
+ from torch.optim.lr_scheduler import LambdaLR
8
+
9
+
10
+ class LitSourceSeparation(pl.LightningModule):
11
+ def __init__(
12
+ self,
13
+ batch_data_preprocessor,
14
+ model: nn.Module,
15
+ loss_function: Callable,
16
+ optimizer_type: str,
17
+ learning_rate: float,
18
+ lr_lambda: Callable,
19
+ ):
20
+ r"""Pytorch Lightning wrapper of PyTorch model, including forward,
21
+ optimization of model, etc.
22
+
23
+ Args:
24
+ batch_data_preprocessor: object, used for preparing inputs and
25
+ targets for training. E.g., BasicBatchDataPreprocessor is used
26
+ for preparing data in dictionary into tensor.
27
+ model: nn.Module
28
+ loss_function: function
29
+ learning_rate: float
30
+ lr_lambda: function
31
+ """
32
+ super().__init__()
33
+
34
+ self.batch_data_preprocessor = batch_data_preprocessor
35
+ self.model = model
36
+ self.optimizer_type = optimizer_type
37
+ self.loss_function = loss_function
38
+ self.learning_rate = learning_rate
39
+ self.lr_lambda = lr_lambda
40
+
41
+ def training_step(self, batch_data_dict: Dict, batch_idx: int) -> torch.float:
42
+ r"""Forward a mini-batch data to model, calculate loss function, and
43
+ train for one step. A mini-batch data is evenly distributed to multiple
44
+ devices (if there are) for parallel training.
45
+
46
+ Args:
47
+ batch_data_dict: e.g. {
48
+ 'vocals': (batch_size, channels_num, segment_samples),
49
+ 'accompaniment': (batch_size, channels_num, segment_samples),
50
+ 'mixture': (batch_size, channels_num, segment_samples)
51
+ }
52
+ batch_idx: int
53
+
54
+ Returns:
55
+ loss: float, loss function of this mini-batch
56
+ """
57
+ input_dict, target_dict = self.batch_data_preprocessor(batch_data_dict)
58
+ # input_dict: {
59
+ # 'waveform': (batch_size, channels_num, segment_samples),
60
+ # (if_exist) 'condition': (batch_size, channels_num),
61
+ # }
62
+ # target_dict: {
63
+ # 'waveform': (batch_size, target_sources_num * channels_num, segment_samples),
64
+ # }
65
+
66
+ # Forward.
67
+ self.model.train()
68
+
69
+ output_dict = self.model(input_dict)
70
+ # output_dict: {
71
+ # 'waveform': (batch_size, target_sources_num * channels_num, segment_samples),
72
+ # }
73
+
74
+ outputs = output_dict['waveform']
75
+ # outputs:, e.g, (batch_size, target_sources_num * channels_num, segment_samples)
76
+
77
+ # Calculate loss.
78
+ loss = self.loss_function(
79
+ output=outputs,
80
+ target=target_dict['waveform'],
81
+ mixture=input_dict['waveform'],
82
+ )
83
+
84
+ return loss
85
+
86
+ def configure_optimizers(self) -> Any:
87
+ r"""Configure optimizer."""
88
+
89
+ if self.optimizer_type == "Adam":
90
+ optimizer = optim.Adam(
91
+ self.model.parameters(),
92
+ lr=self.learning_rate,
93
+ betas=(0.9, 0.999),
94
+ eps=1e-08,
95
+ weight_decay=0.0,
96
+ amsgrad=True,
97
+ )
98
+
99
+ elif self.optimizer_type == "AdamW":
100
+ optimizer = optim.AdamW(
101
+ self.model.parameters(),
102
+ lr=self.learning_rate,
103
+ betas=(0.9, 0.999),
104
+ eps=1e-08,
105
+ weight_decay=0.0,
106
+ amsgrad=True,
107
+ )
108
+
109
+ else:
110
+ raise NotImplementedError
111
+
112
+ scheduler = {
113
+ 'scheduler': LambdaLR(optimizer, self.lr_lambda),
114
+ 'interval': 'step',
115
+ 'frequency': 1,
116
+ }
117
+
118
+ return [optimizer], [scheduler]
119
+
120
+
121
+ def get_model_class(model_type):
122
+ r"""Get model.
123
+
124
+ Args:
125
+ model_type: str, e.g., 'ResUNet143_DecouplePlusInplaceABN'
126
+
127
+ Returns:
128
+ nn.Module
129
+ """
130
+ if model_type == 'ResUNet143_DecouplePlusInplaceABN_ISMIR2021':
131
+ from bytesep.models.resunet_ismir2021 import (
132
+ ResUNet143_DecouplePlusInplaceABN_ISMIR2021,
133
+ )
134
+
135
+ return ResUNet143_DecouplePlusInplaceABN_ISMIR2021
136
+
137
+ elif model_type == 'UNet':
138
+ from bytesep.models.unet import UNet
139
+
140
+ return UNet
141
+
142
+ elif model_type == 'UNetSubbandTime':
143
+ from bytesep.models.unet_subbandtime import UNetSubbandTime
144
+
145
+ return UNetSubbandTime
146
+
147
+ elif model_type == 'ResUNet143_Subbandtime':
148
+ from bytesep.models.resunet_subbandtime import ResUNet143_Subbandtime
149
+
150
+ return ResUNet143_Subbandtime
151
+
152
+ elif model_type == 'ResUNet143_DecouplePlus':
153
+ from bytesep.models.resunet import ResUNet143_DecouplePlus
154
+
155
+ return ResUNet143_DecouplePlus
156
+
157
+ elif model_type == 'ConditionalUNet':
158
+ from bytesep.models.conditional_unet import ConditionalUNet
159
+
160
+ return ConditionalUNet
161
+
162
+ elif model_type == 'LevelRNN':
163
+ from bytesep.models.levelrnn import LevelRNN
164
+
165
+ return LevelRNN
166
+
167
+ elif model_type == 'WavUNet':
168
+ from bytesep.models.wavunet import WavUNet
169
+
170
+ return WavUNet
171
+
172
+ elif model_type == 'WavUNetLevelRNN':
173
+ from bytesep.models.wavunet_levelrnn import WavUNetLevelRNN
174
+
175
+ return WavUNetLevelRNN
176
+
177
+ elif model_type == 'TTnet':
178
+ from bytesep.models.ttnet import TTnet
179
+
180
+ return TTnet
181
+
182
+ elif model_type == 'TTnetNoTransformer':
183
+ from bytesep.models.ttnet_no_transformer import TTnetNoTransformer
184
+
185
+ return TTnetNoTransformer
186
+
187
+ else:
188
+ raise NotImplementedError
bytesep/models/pytorch_modules.py ADDED
@@ -0,0 +1,204 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List, NoReturn
2
+
3
+ import numpy as np
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+
8
+
9
+ def init_embedding(layer: nn.Module) -> NoReturn:
10
+ r"""Initialize a Linear or Convolutional layer."""
11
+ nn.init.uniform_(layer.weight, -1.0, 1.0)
12
+
13
+ if hasattr(layer, 'bias'):
14
+ if layer.bias is not None:
15
+ layer.bias.data.fill_(0.0)
16
+
17
+
18
+ def init_layer(layer: nn.Module) -> NoReturn:
19
+ r"""Initialize a Linear or Convolutional layer."""
20
+ nn.init.xavier_uniform_(layer.weight)
21
+
22
+ if hasattr(layer, "bias"):
23
+ if layer.bias is not None:
24
+ layer.bias.data.fill_(0.0)
25
+
26
+
27
+ def init_bn(bn: nn.Module) -> NoReturn:
28
+ r"""Initialize a Batchnorm layer."""
29
+ bn.bias.data.fill_(0.0)
30
+ bn.weight.data.fill_(1.0)
31
+ bn.running_mean.data.fill_(0.0)
32
+ bn.running_var.data.fill_(1.0)
33
+
34
+
35
+ def act(x: torch.Tensor, activation: str) -> torch.Tensor:
36
+
37
+ if activation == "relu":
38
+ return F.relu_(x)
39
+
40
+ elif activation == "leaky_relu":
41
+ return F.leaky_relu_(x, negative_slope=0.01)
42
+
43
+ elif activation == "swish":
44
+ return x * torch.sigmoid(x)
45
+
46
+ else:
47
+ raise Exception("Incorrect activation!")
48
+
49
+
50
+ class Base:
51
+ def __init__(self):
52
+ r"""Base function for extracting spectrogram, cos, and sin, etc."""
53
+ pass
54
+
55
+ def spectrogram(self, input: torch.Tensor, eps: float = 0.0) -> torch.Tensor:
56
+ r"""Calculate spectrogram.
57
+
58
+ Args:
59
+ input: (batch_size, segments_num)
60
+ eps: float
61
+
62
+ Returns:
63
+ spectrogram: (batch_size, time_steps, freq_bins)
64
+ """
65
+ (real, imag) = self.stft(input)
66
+ return torch.clamp(real ** 2 + imag ** 2, eps, np.inf) ** 0.5
67
+
68
+ def spectrogram_phase(
69
+ self, input: torch.Tensor, eps: float = 0.0
70
+ ) -> List[torch.Tensor]:
71
+ r"""Calculate the magnitude, cos, and sin of the STFT of input.
72
+
73
+ Args:
74
+ input: (batch_size, segments_num)
75
+ eps: float
76
+
77
+ Returns:
78
+ mag: (batch_size, time_steps, freq_bins)
79
+ cos: (batch_size, time_steps, freq_bins)
80
+ sin: (batch_size, time_steps, freq_bins)
81
+ """
82
+ (real, imag) = self.stft(input)
83
+ mag = torch.clamp(real ** 2 + imag ** 2, eps, np.inf) ** 0.5
84
+ cos = real / mag
85
+ sin = imag / mag
86
+ return mag, cos, sin
87
+
88
+ def wav_to_spectrogram_phase(
89
+ self, input: torch.Tensor, eps: float = 1e-10
90
+ ) -> List[torch.Tensor]:
91
+ r"""Convert waveforms to magnitude, cos, and sin of STFT.
92
+
93
+ Args:
94
+ input: (batch_size, channels_num, segment_samples)
95
+ eps: float
96
+
97
+ Outputs:
98
+ mag: (batch_size, channels_num, time_steps, freq_bins)
99
+ cos: (batch_size, channels_num, time_steps, freq_bins)
100
+ sin: (batch_size, channels_num, time_steps, freq_bins)
101
+ """
102
+ batch_size, channels_num, segment_samples = input.shape
103
+
104
+ # Reshape input with shapes of (n, segments_num) to meet the
105
+ # requirements of the stft function.
106
+ x = input.reshape(batch_size * channels_num, segment_samples)
107
+
108
+ mag, cos, sin = self.spectrogram_phase(x, eps=eps)
109
+ # mag, cos, sin: (batch_size * channels_num, 1, time_steps, freq_bins)
110
+
111
+ _, _, time_steps, freq_bins = mag.shape
112
+ mag = mag.reshape(batch_size, channels_num, time_steps, freq_bins)
113
+ cos = cos.reshape(batch_size, channels_num, time_steps, freq_bins)
114
+ sin = sin.reshape(batch_size, channels_num, time_steps, freq_bins)
115
+
116
+ return mag, cos, sin
117
+
118
+ def wav_to_spectrogram(
119
+ self, input: torch.Tensor, eps: float = 1e-10
120
+ ) -> List[torch.Tensor]:
121
+
122
+ mag, cos, sin = self.wav_to_spectrogram_phase(input, eps)
123
+ return mag
124
+
125
+
126
+ class Subband:
127
+ def __init__(self, subbands_num: int):
128
+ r"""Warning!! This class is not used!!
129
+
130
+ This class does not work as good as [1] which split subbands in the
131
+ time-domain. Please refere to [1] for formal implementation.
132
+
133
+ [1] Liu, Haohe, et al. "Channel-wise subband input for better voice and
134
+ accompaniment separation on high resolution music." arXiv preprint arXiv:2008.05216 (2020).
135
+
136
+ Args:
137
+ subbands_num: int, e.g., 4
138
+ """
139
+ self.subbands_num = subbands_num
140
+
141
+ def analysis(self, x: torch.Tensor) -> torch.Tensor:
142
+ r"""Analysis time-frequency representation into subbands. Stack the
143
+ subbands along the channel axis.
144
+
145
+ Args:
146
+ x: (batch_size, channels_num, time_steps, freq_bins)
147
+
148
+ Returns:
149
+ output: (batch_size, channels_num * subbands_num, time_steps, freq_bins // subbands_num)
150
+ """
151
+ batch_size, channels_num, time_steps, freq_bins = x.shape
152
+
153
+ x = x.reshape(
154
+ batch_size,
155
+ channels_num,
156
+ time_steps,
157
+ self.subbands_num,
158
+ freq_bins // self.subbands_num,
159
+ )
160
+ # x: (batch_size, channels_num, time_steps, subbands_num, freq_bins // subbands_num)
161
+
162
+ x = x.transpose(2, 3)
163
+
164
+ output = x.reshape(
165
+ batch_size,
166
+ channels_num * self.subbands_num,
167
+ time_steps,
168
+ freq_bins // self.subbands_num,
169
+ )
170
+ # output: (batch_size, channels_num * subbands_num, time_steps, freq_bins // subbands_num)
171
+
172
+ return output
173
+
174
+ def synthesis(self, x: torch.Tensor) -> torch.Tensor:
175
+ r"""Synthesis subband time-frequency representations into original
176
+ time-frequency representation.
177
+
178
+ Args:
179
+ x: (batch_size, channels_num * subbands_num, time_steps, freq_bins // subbands_num)
180
+
181
+ Returns:
182
+ output: (batch_size, channels_num, time_steps, freq_bins)
183
+ """
184
+ batch_size, subband_channels_num, time_steps, subband_freq_bins = x.shape
185
+
186
+ channels_num = subband_channels_num // self.subbands_num
187
+ freq_bins = subband_freq_bins * self.subbands_num
188
+
189
+ x = x.reshape(
190
+ batch_size,
191
+ channels_num,
192
+ self.subbands_num,
193
+ time_steps,
194
+ subband_freq_bins,
195
+ )
196
+ # x: (batch_size, channels_num, subbands_num, time_steps, freq_bins // subbands_num)
197
+
198
+ x = x.transpose(2, 3)
199
+ # x: (batch_size, channels_num, time_steps, subbands_num, freq_bins // subbands_num)
200
+
201
+ output = x.reshape(batch_size, channels_num, time_steps, freq_bins)
202
+ # x: (batch_size, channels_num, time_steps, freq_bins)
203
+
204
+ return output
bytesep/models/resunet.py ADDED
@@ -0,0 +1,516 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+ from torchlibrosa.stft import ISTFT, STFT, magphase
6
+
7
+ from bytesep.models.pytorch_modules import Base, Subband, act, init_bn, init_layer
8
+
9
+
10
+ class ConvBlockRes(nn.Module):
11
+ def __init__(self, in_channels, out_channels, kernel_size, activation, momentum):
12
+ r"""Residual block."""
13
+ super(ConvBlockRes, self).__init__()
14
+
15
+ self.activation = activation
16
+ padding = [kernel_size[0] // 2, kernel_size[1] // 2]
17
+
18
+ self.bn1 = nn.BatchNorm2d(in_channels, momentum=momentum)
19
+ self.bn2 = nn.BatchNorm2d(out_channels, momentum=momentum)
20
+
21
+ self.conv1 = nn.Conv2d(
22
+ in_channels=in_channels,
23
+ out_channels=out_channels,
24
+ kernel_size=kernel_size,
25
+ stride=(1, 1),
26
+ dilation=(1, 1),
27
+ padding=padding,
28
+ bias=False,
29
+ )
30
+
31
+ self.conv2 = nn.Conv2d(
32
+ in_channels=out_channels,
33
+ out_channels=out_channels,
34
+ kernel_size=kernel_size,
35
+ stride=(1, 1),
36
+ dilation=(1, 1),
37
+ padding=padding,
38
+ bias=False,
39
+ )
40
+
41
+ if in_channels != out_channels:
42
+ self.shortcut = nn.Conv2d(
43
+ in_channels=in_channels,
44
+ out_channels=out_channels,
45
+ kernel_size=(1, 1),
46
+ stride=(1, 1),
47
+ padding=(0, 0),
48
+ )
49
+
50
+ self.is_shortcut = True
51
+ else:
52
+ self.is_shortcut = False
53
+
54
+ self.init_weights()
55
+
56
+ def init_weights(self):
57
+ init_bn(self.bn1)
58
+ init_bn(self.bn2)
59
+ init_layer(self.conv1)
60
+ init_layer(self.conv2)
61
+
62
+ if self.is_shortcut:
63
+ init_layer(self.shortcut)
64
+
65
+ def forward(self, x):
66
+ origin = x
67
+ x = self.conv1(act(self.bn1(x), self.activation))
68
+ x = self.conv2(act(self.bn2(x), self.activation))
69
+
70
+ if self.is_shortcut:
71
+ return self.shortcut(origin) + x
72
+ else:
73
+ return origin + x
74
+
75
+
76
+ class EncoderBlockRes4B(nn.Module):
77
+ def __init__(
78
+ self, in_channels, out_channels, kernel_size, downsample, activation, momentum
79
+ ):
80
+ r"""Encoder block, contains 8 convolutional layers."""
81
+ super(EncoderBlockRes4B, self).__init__()
82
+
83
+ self.conv_block1 = ConvBlockRes(
84
+ in_channels, out_channels, kernel_size, activation, momentum
85
+ )
86
+ self.conv_block2 = ConvBlockRes(
87
+ out_channels, out_channels, kernel_size, activation, momentum
88
+ )
89
+ self.conv_block3 = ConvBlockRes(
90
+ out_channels, out_channels, kernel_size, activation, momentum
91
+ )
92
+ self.conv_block4 = ConvBlockRes(
93
+ out_channels, out_channels, kernel_size, activation, momentum
94
+ )
95
+ self.downsample = downsample
96
+
97
+ def forward(self, x):
98
+ encoder = self.conv_block1(x)
99
+ encoder = self.conv_block2(encoder)
100
+ encoder = self.conv_block3(encoder)
101
+ encoder = self.conv_block4(encoder)
102
+ encoder_pool = F.avg_pool2d(encoder, kernel_size=self.downsample)
103
+ return encoder_pool, encoder
104
+
105
+
106
+ class DecoderBlockRes4B(nn.Module):
107
+ def __init__(
108
+ self, in_channels, out_channels, kernel_size, upsample, activation, momentum
109
+ ):
110
+ r"""Decoder block, contains 1 transpose convolutional and 8 convolutional layers."""
111
+ super(DecoderBlockRes4B, self).__init__()
112
+ self.kernel_size = kernel_size
113
+ self.stride = upsample
114
+ self.activation = activation
115
+
116
+ self.conv1 = torch.nn.ConvTranspose2d(
117
+ in_channels=in_channels,
118
+ out_channels=out_channels,
119
+ kernel_size=self.stride,
120
+ stride=self.stride,
121
+ padding=(0, 0),
122
+ bias=False,
123
+ dilation=(1, 1),
124
+ )
125
+
126
+ self.bn1 = nn.BatchNorm2d(in_channels, momentum=momentum)
127
+ self.conv_block2 = ConvBlockRes(
128
+ out_channels * 2, out_channels, kernel_size, activation, momentum
129
+ )
130
+ self.conv_block3 = ConvBlockRes(
131
+ out_channels, out_channels, kernel_size, activation, momentum
132
+ )
133
+ self.conv_block4 = ConvBlockRes(
134
+ out_channels, out_channels, kernel_size, activation, momentum
135
+ )
136
+ self.conv_block5 = ConvBlockRes(
137
+ out_channels, out_channels, kernel_size, activation, momentum
138
+ )
139
+
140
+ self.init_weights()
141
+
142
+ def init_weights(self):
143
+ init_bn(self.bn1)
144
+ init_layer(self.conv1)
145
+
146
+ def forward(self, input_tensor, concat_tensor):
147
+ x = self.conv1(act(self.bn1(input_tensor), self.activation))
148
+ x = torch.cat((x, concat_tensor), dim=1)
149
+ x = self.conv_block2(x)
150
+ x = self.conv_block3(x)
151
+ x = self.conv_block4(x)
152
+ x = self.conv_block5(x)
153
+ return x
154
+
155
+
156
+ class ResUNet143_DecouplePlus(nn.Module, Base):
157
+ def __init__(self, input_channels, target_sources_num):
158
+ super(ResUNet143_DecouplePlus, self).__init__()
159
+
160
+ self.input_channels = input_channels
161
+ self.target_sources_num = target_sources_num
162
+
163
+ window_size = 2048
164
+ hop_size = 441
165
+ center = True
166
+ pad_mode = "reflect"
167
+ window = "hann"
168
+ activation = "relu"
169
+ momentum = 0.01
170
+
171
+ self.subbands_num = 4
172
+ self.K = 4 # outputs: |M|, cos∠M, sin∠M, |M2|
173
+
174
+ self.downsample_ratio = 2 ** 6 # This number equals 2^{#encoder_blcoks}
175
+
176
+ self.stft = STFT(
177
+ n_fft=window_size,
178
+ hop_length=hop_size,
179
+ win_length=window_size,
180
+ window=window,
181
+ center=center,
182
+ pad_mode=pad_mode,
183
+ freeze_parameters=True,
184
+ )
185
+
186
+ self.istft = ISTFT(
187
+ n_fft=window_size,
188
+ hop_length=hop_size,
189
+ win_length=window_size,
190
+ window=window,
191
+ center=center,
192
+ pad_mode=pad_mode,
193
+ freeze_parameters=True,
194
+ )
195
+
196
+ self.bn0 = nn.BatchNorm2d(window_size // 2 + 1, momentum=momentum)
197
+
198
+ self.subband = Subband(subbands_num=self.subbands_num)
199
+
200
+ self.encoder_block1 = EncoderBlockRes4B(
201
+ in_channels=input_channels * self.subbands_num,
202
+ out_channels=32,
203
+ kernel_size=(3, 3),
204
+ downsample=(2, 2),
205
+ activation=activation,
206
+ momentum=momentum,
207
+ )
208
+ self.encoder_block2 = EncoderBlockRes4B(
209
+ in_channels=32,
210
+ out_channels=64,
211
+ kernel_size=(3, 3),
212
+ downsample=(2, 2),
213
+ activation=activation,
214
+ momentum=momentum,
215
+ )
216
+ self.encoder_block3 = EncoderBlockRes4B(
217
+ in_channels=64,
218
+ out_channels=128,
219
+ kernel_size=(3, 3),
220
+ downsample=(2, 2),
221
+ activation=activation,
222
+ momentum=momentum,
223
+ )
224
+ self.encoder_block4 = EncoderBlockRes4B(
225
+ in_channels=128,
226
+ out_channels=256,
227
+ kernel_size=(3, 3),
228
+ downsample=(2, 2),
229
+ activation=activation,
230
+ momentum=momentum,
231
+ )
232
+ self.encoder_block5 = EncoderBlockRes4B(
233
+ in_channels=256,
234
+ out_channels=384,
235
+ kernel_size=(3, 3),
236
+ downsample=(2, 2),
237
+ activation=activation,
238
+ momentum=momentum,
239
+ )
240
+ self.encoder_block6 = EncoderBlockRes4B(
241
+ in_channels=384,
242
+ out_channels=384,
243
+ kernel_size=(3, 3),
244
+ downsample=(1, 2),
245
+ activation=activation,
246
+ momentum=momentum,
247
+ )
248
+ self.conv_block7a = EncoderBlockRes4B(
249
+ in_channels=384,
250
+ out_channels=384,
251
+ kernel_size=(3, 3),
252
+ downsample=(1, 1),
253
+ activation=activation,
254
+ momentum=momentum,
255
+ )
256
+ self.conv_block7b = EncoderBlockRes4B(
257
+ in_channels=384,
258
+ out_channels=384,
259
+ kernel_size=(3, 3),
260
+ downsample=(1, 1),
261
+ activation=activation,
262
+ momentum=momentum,
263
+ )
264
+ self.conv_block7c = EncoderBlockRes4B(
265
+ in_channels=384,
266
+ out_channels=384,
267
+ kernel_size=(3, 3),
268
+ downsample=(1, 1),
269
+ activation=activation,
270
+ momentum=momentum,
271
+ )
272
+ self.conv_block7d = EncoderBlockRes4B(
273
+ in_channels=384,
274
+ out_channels=384,
275
+ kernel_size=(3, 3),
276
+ downsample=(1, 1),
277
+ activation=activation,
278
+ momentum=momentum,
279
+ )
280
+ self.decoder_block1 = DecoderBlockRes4B(
281
+ in_channels=384,
282
+ out_channels=384,
283
+ kernel_size=(3, 3),
284
+ upsample=(1, 2),
285
+ activation=activation,
286
+ momentum=momentum,
287
+ )
288
+ self.decoder_block2 = DecoderBlockRes4B(
289
+ in_channels=384,
290
+ out_channels=384,
291
+ kernel_size=(3, 3),
292
+ upsample=(2, 2),
293
+ activation=activation,
294
+ momentum=momentum,
295
+ )
296
+ self.decoder_block3 = DecoderBlockRes4B(
297
+ in_channels=384,
298
+ out_channels=256,
299
+ kernel_size=(3, 3),
300
+ upsample=(2, 2),
301
+ activation=activation,
302
+ momentum=momentum,
303
+ )
304
+ self.decoder_block4 = DecoderBlockRes4B(
305
+ in_channels=256,
306
+ out_channels=128,
307
+ kernel_size=(3, 3),
308
+ upsample=(2, 2),
309
+ activation=activation,
310
+ momentum=momentum,
311
+ )
312
+ self.decoder_block5 = DecoderBlockRes4B(
313
+ in_channels=128,
314
+ out_channels=64,
315
+ kernel_size=(3, 3),
316
+ upsample=(2, 2),
317
+ activation=activation,
318
+ momentum=momentum,
319
+ )
320
+ self.decoder_block6 = DecoderBlockRes4B(
321
+ in_channels=64,
322
+ out_channels=32,
323
+ kernel_size=(3, 3),
324
+ upsample=(2, 2),
325
+ activation=activation,
326
+ momentum=momentum,
327
+ )
328
+
329
+ self.after_conv_block1 = EncoderBlockRes4B(
330
+ in_channels=32,
331
+ out_channels=32,
332
+ kernel_size=(3, 3),
333
+ downsample=(1, 1),
334
+ activation=activation,
335
+ momentum=momentum,
336
+ )
337
+
338
+ self.after_conv2 = nn.Conv2d(
339
+ in_channels=32,
340
+ out_channels=input_channels
341
+ * self.subbands_num
342
+ * target_sources_num
343
+ * self.K,
344
+ kernel_size=(1, 1),
345
+ stride=(1, 1),
346
+ padding=(0, 0),
347
+ bias=True,
348
+ )
349
+
350
+ self.init_weights()
351
+
352
+ def init_weights(self):
353
+ init_bn(self.bn0)
354
+ init_layer(self.after_conv2)
355
+
356
+ def feature_maps_to_wav(
357
+ self,
358
+ input_tensor: torch.Tensor,
359
+ sp: torch.Tensor,
360
+ sin_in: torch.Tensor,
361
+ cos_in: torch.Tensor,
362
+ audio_length: int,
363
+ ) -> torch.Tensor:
364
+ r"""Convert feature maps to waveform.
365
+
366
+ Args:
367
+ input_tensor: (batch_size, feature_maps, time_steps, freq_bins)
368
+ sp: (batch_size, feature_maps, time_steps, freq_bins)
369
+ sin_in: (batch_size, feature_maps, time_steps, freq_bins)
370
+ cos_in: (batch_size, feature_maps, time_steps, freq_bins)
371
+
372
+ Outputs:
373
+ waveform: (batch_size, target_sources_num * input_channels, segment_samples)
374
+ """
375
+ batch_size, _, time_steps, freq_bins = input_tensor.shape
376
+
377
+ x = input_tensor.reshape(
378
+ batch_size,
379
+ self.target_sources_num,
380
+ self.input_channels,
381
+ self.K,
382
+ time_steps,
383
+ freq_bins,
384
+ )
385
+ # x: (batch_size, target_sources_num, input_channles, K, time_steps, freq_bins)
386
+
387
+ mask_mag = torch.sigmoid(x[:, :, :, 0, :, :])
388
+ _mask_real = torch.tanh(x[:, :, :, 1, :, :])
389
+ _mask_imag = torch.tanh(x[:, :, :, 2, :, :])
390
+ linear_mag = x[:, :, :, 3, :, :]
391
+ _, mask_cos, mask_sin = magphase(_mask_real, _mask_imag)
392
+ # mask_cos, mask_sin: (batch_size, target_sources_num, input_channles, time_steps, freq_bins)
393
+
394
+ # Y = |Y|cos∠Y + j|Y|sin∠Y
395
+ # = |Y|cos(∠X + ∠M) + j|Y|sin(∠X + ∠M)
396
+ # = |Y|(cos∠X cos∠M - sin∠X sin∠M) + j|Y|(sin∠X cos∠M + cos∠X sin∠M)
397
+ out_cos = (
398
+ cos_in[:, None, :, :, :] * mask_cos - sin_in[:, None, :, :, :] * mask_sin
399
+ )
400
+ out_sin = (
401
+ sin_in[:, None, :, :, :] * mask_cos + cos_in[:, None, :, :, :] * mask_sin
402
+ )
403
+ # out_cos: (batch_size, target_sources_num, input_channles, time_steps, freq_bins)
404
+ # out_sin: (batch_size, target_sources_num, input_channles, time_steps, freq_bins)
405
+
406
+ # Calculate |Y|.
407
+ out_mag = F.relu_(sp[:, None, :, :, :] * mask_mag + linear_mag)
408
+ # out_mag: (batch_size, target_sources_num, input_channles, time_steps, freq_bins)
409
+
410
+ # Calculate Y_{real} and Y_{imag} for ISTFT.
411
+ out_real = out_mag * out_cos
412
+ out_imag = out_mag * out_sin
413
+ # out_real, out_imag: (batch_size, target_sources_num, input_channles, time_steps, freq_bins)
414
+
415
+ # Reformat shape to (n, 1, time_steps, freq_bins) for ISTFT.
416
+ shape = (
417
+ batch_size * self.target_sources_num * self.input_channels,
418
+ 1,
419
+ time_steps,
420
+ freq_bins,
421
+ )
422
+ out_real = out_real.reshape(shape)
423
+ out_imag = out_imag.reshape(shape)
424
+
425
+ # ISTFT.
426
+ x = self.istft(out_real, out_imag, audio_length)
427
+ # (batch_size * target_sources_num * input_channels, segments_num)
428
+
429
+ # Reshape.
430
+ waveform = x.reshape(
431
+ batch_size, self.target_sources_num * self.input_channels, audio_length
432
+ )
433
+ # (batch_size, target_sources_num * input_channels, segments_num)
434
+
435
+ return waveform
436
+
437
+ def forward(self, input_dict):
438
+ r"""
439
+ Args:
440
+ input: (batch_size, channels_num, segment_samples)
441
+
442
+ Outputs:
443
+ output_dict: {
444
+ 'wav': (batch_size, channels_num, segment_samples)
445
+ }
446
+ """
447
+ mixtures = input_dict['waveform']
448
+ # (batch_size, input_channels, segment_samples)
449
+
450
+ mag, cos_in, sin_in = self.wav_to_spectrogram_phase(mixtures)
451
+ # mag, cos_in, sin_in: (batch_size, input_channels, time_steps, freq_bins)
452
+
453
+ # Batch normalize on individual frequency bins.
454
+ x = mag.transpose(1, 3)
455
+ x = self.bn0(x)
456
+ x = x.transpose(1, 3)
457
+ """(batch_size, input_channels, time_steps, freq_bins)"""
458
+
459
+ # Pad spectrogram to be evenly divided by downsample ratio.
460
+ origin_len = x.shape[2]
461
+ pad_len = (
462
+ int(np.ceil(x.shape[2] / self.downsample_ratio)) * self.downsample_ratio
463
+ - origin_len
464
+ )
465
+ x = F.pad(x, pad=(0, 0, 0, pad_len))
466
+ """(batch_size, input_channels, padded_time_steps, freq_bins)"""
467
+
468
+ # Let frequency bins be evenly divided by 2, e.g., 1025 -> 1024
469
+ x = x[..., 0 : x.shape[-1] - 1] # (bs, input_channels, T, F)
470
+
471
+ x = self.subband.analysis(x)
472
+ # (bs, input_channels, T, F'), where F' = F // subbands_num
473
+
474
+ # UNet
475
+ (x1_pool, x1) = self.encoder_block1(x) # x1_pool: (bs, 32, T / 2, F / 2)
476
+ (x2_pool, x2) = self.encoder_block2(x1_pool) # x2_pool: (bs, 64, T / 4, F / 4)
477
+ (x3_pool, x3) = self.encoder_block3(x2_pool) # x3_pool: (bs, 128, T / 8, F / 8)
478
+ (x4_pool, x4) = self.encoder_block4(
479
+ x3_pool
480
+ ) # x4_pool: (bs, 256, T / 16, F / 16)
481
+ (x5_pool, x5) = self.encoder_block5(
482
+ x4_pool
483
+ ) # x5_pool: (bs, 384, T / 32, F / 32)
484
+ (x6_pool, x6) = self.encoder_block6(
485
+ x5_pool
486
+ ) # x6_pool: (bs, 384, T / 32, F / 64)
487
+ (x_center, _) = self.conv_block7a(x6_pool) # (bs, 384, T / 32, F / 64)
488
+ (x_center, _) = self.conv_block7b(x_center) # (bs, 384, T / 32, F / 64)
489
+ (x_center, _) = self.conv_block7c(x_center) # (bs, 384, T / 32, F / 64)
490
+ (x_center, _) = self.conv_block7d(x_center) # (bs, 384, T / 32, F / 64)
491
+ x7 = self.decoder_block1(x_center, x6) # (bs, 384, T / 32, F / 32)
492
+ x8 = self.decoder_block2(x7, x5) # (bs, 384, T / 16, F / 16)
493
+ x9 = self.decoder_block3(x8, x4) # (bs, 256, T / 8, F / 8)
494
+ x10 = self.decoder_block4(x9, x3) # (bs, 128, T / 4, F / 4)
495
+ x11 = self.decoder_block5(x10, x2) # (bs, 64, T / 2, F / 2)
496
+ x12 = self.decoder_block6(x11, x1) # (bs, 32, T, F)
497
+ (x, _) = self.after_conv_block1(x12) # (bs, 32, T, F)
498
+
499
+ x = self.after_conv2(x) # (bs, channels * 3, T, F)
500
+ # (batch_size, input_channles * subbands_num * targets_num * k, T, F')
501
+
502
+ x = self.subband.synthesis(x)
503
+ # (batch_size, input_channles * targets_num * K, T, F)
504
+
505
+ # Recover shape
506
+ x = F.pad(x, pad=(0, 1)) # Pad frequency, e.g., 1024 -> 1025.
507
+ x = x[:, :, 0:origin_len, :] # (bs, feature_maps, time_steps, freq_bins)
508
+
509
+ audio_length = mixtures.shape[2]
510
+
511
+ separated_audio = self.feature_maps_to_wav(x, mag, sin_in, cos_in, audio_length)
512
+ # separated_audio: (batch_size, target_sources_num * input_channels, segments_num)
513
+
514
+ output_dict = {'waveform': separated_audio}
515
+
516
+ return output_dict
bytesep/models/resunet_ismir2021.py ADDED
@@ -0,0 +1,534 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+ from inplace_abn.abn import InPlaceABNSync
6
+ from torchlibrosa.stft import ISTFT, STFT, magphase
7
+
8
+ from bytesep.models.pytorch_modules import Base, init_bn, init_layer
9
+
10
+
11
+ class ConvBlockRes(nn.Module):
12
+ def __init__(self, in_channels, out_channels, kernel_size, activation, momentum):
13
+ r"""Residual block."""
14
+ super(ConvBlockRes, self).__init__()
15
+
16
+ self.activation = activation
17
+ padding = [kernel_size[0] // 2, kernel_size[1] // 2]
18
+
19
+ # ABN is not used for bn1 because we found using abn1 will degrade performance.
20
+ self.bn1 = nn.BatchNorm2d(in_channels, momentum=momentum)
21
+
22
+ self.abn2 = InPlaceABNSync(
23
+ num_features=out_channels, momentum=momentum, activation='leaky_relu'
24
+ )
25
+
26
+ self.conv1 = nn.Conv2d(
27
+ in_channels=in_channels,
28
+ out_channels=out_channels,
29
+ kernel_size=kernel_size,
30
+ stride=(1, 1),
31
+ dilation=(1, 1),
32
+ padding=padding,
33
+ bias=False,
34
+ )
35
+
36
+ self.conv2 = nn.Conv2d(
37
+ in_channels=out_channels,
38
+ out_channels=out_channels,
39
+ kernel_size=kernel_size,
40
+ stride=(1, 1),
41
+ dilation=(1, 1),
42
+ padding=padding,
43
+ bias=False,
44
+ )
45
+
46
+ if in_channels != out_channels:
47
+ self.shortcut = nn.Conv2d(
48
+ in_channels=in_channels,
49
+ out_channels=out_channels,
50
+ kernel_size=(1, 1),
51
+ stride=(1, 1),
52
+ padding=(0, 0),
53
+ )
54
+ self.is_shortcut = True
55
+ else:
56
+ self.is_shortcut = False
57
+
58
+ self.init_weights()
59
+
60
+ def init_weights(self):
61
+ init_bn(self.bn1)
62
+ init_layer(self.conv1)
63
+ init_layer(self.conv2)
64
+
65
+ if self.is_shortcut:
66
+ init_layer(self.shortcut)
67
+
68
+ def forward(self, x):
69
+ origin = x
70
+ x = self.conv1(F.leaky_relu_(self.bn1(x), negative_slope=0.01))
71
+ x = self.conv2(self.abn2(x))
72
+
73
+ if self.is_shortcut:
74
+ return self.shortcut(origin) + x
75
+ else:
76
+ return origin + x
77
+
78
+
79
+ class EncoderBlockRes4B(nn.Module):
80
+ def __init__(
81
+ self, in_channels, out_channels, kernel_size, downsample, activation, momentum
82
+ ):
83
+ r"""Encoder block, contains 8 convolutional layers."""
84
+ super(EncoderBlockRes4B, self).__init__()
85
+
86
+ self.conv_block1 = ConvBlockRes(
87
+ in_channels, out_channels, kernel_size, activation, momentum
88
+ )
89
+ self.conv_block2 = ConvBlockRes(
90
+ out_channels, out_channels, kernel_size, activation, momentum
91
+ )
92
+ self.conv_block3 = ConvBlockRes(
93
+ out_channels, out_channels, kernel_size, activation, momentum
94
+ )
95
+ self.conv_block4 = ConvBlockRes(
96
+ out_channels, out_channels, kernel_size, activation, momentum
97
+ )
98
+ self.downsample = downsample
99
+
100
+ def forward(self, x):
101
+ encoder = self.conv_block1(x)
102
+ encoder = self.conv_block2(encoder)
103
+ encoder = self.conv_block3(encoder)
104
+ encoder = self.conv_block4(encoder)
105
+ encoder_pool = F.avg_pool2d(encoder, kernel_size=self.downsample)
106
+ return encoder_pool, encoder
107
+
108
+
109
+ class DecoderBlockRes4B(nn.Module):
110
+ def __init__(
111
+ self, in_channels, out_channels, kernel_size, upsample, activation, momentum
112
+ ):
113
+ r"""Decoder block, contains 1 transpose convolutional and 8 convolutional layers."""
114
+ super(DecoderBlockRes4B, self).__init__()
115
+ self.kernel_size = kernel_size
116
+ self.stride = upsample
117
+ self.activation = activation
118
+
119
+ self.conv1 = torch.nn.ConvTranspose2d(
120
+ in_channels=in_channels,
121
+ out_channels=out_channels,
122
+ kernel_size=self.stride,
123
+ stride=self.stride,
124
+ padding=(0, 0),
125
+ bias=False,
126
+ dilation=(1, 1),
127
+ )
128
+
129
+ self.bn1 = nn.BatchNorm2d(in_channels, momentum=momentum)
130
+ self.conv_block2 = ConvBlockRes(
131
+ out_channels * 2, out_channels, kernel_size, activation, momentum
132
+ )
133
+ self.conv_block3 = ConvBlockRes(
134
+ out_channels, out_channels, kernel_size, activation, momentum
135
+ )
136
+ self.conv_block4 = ConvBlockRes(
137
+ out_channels, out_channels, kernel_size, activation, momentum
138
+ )
139
+ self.conv_block5 = ConvBlockRes(
140
+ out_channels, out_channels, kernel_size, activation, momentum
141
+ )
142
+
143
+ self.init_weights()
144
+
145
+ def init_weights(self):
146
+ init_bn(self.bn1)
147
+ init_layer(self.conv1)
148
+
149
+ def forward(self, input_tensor, concat_tensor):
150
+ x = self.conv1(F.relu_(self.bn1(input_tensor)))
151
+ x = torch.cat((x, concat_tensor), dim=1)
152
+ x = self.conv_block2(x)
153
+ x = self.conv_block3(x)
154
+ x = self.conv_block4(x)
155
+ x = self.conv_block5(x)
156
+ return x
157
+
158
+
159
+ class ResUNet143_DecouplePlusInplaceABN_ISMIR2021(nn.Module, Base):
160
+ def __init__(self, input_channels, target_sources_num):
161
+ super(ResUNet143_DecouplePlusInplaceABN_ISMIR2021, self).__init__()
162
+
163
+ self.input_channels = input_channels
164
+ self.target_sources_num = target_sources_num
165
+
166
+ window_size = 2048
167
+ hop_size = 441
168
+ center = True
169
+ pad_mode = 'reflect'
170
+ window = 'hann'
171
+ activation = 'leaky_relu'
172
+ momentum = 0.01
173
+
174
+ self.subbands_num = 1
175
+
176
+ assert (
177
+ self.subbands_num == 1
178
+ ), "Using subbands_num > 1 on spectrogram \
179
+ will lead to unexpected performance sometimes. Suggest to use \
180
+ subband method on waveform."
181
+
182
+ # Downsample rate along the time axis.
183
+ self.K = 4 # outputs: |M|, cos∠M, sin∠M, Q
184
+ self.time_downsample_ratio = 2 ** 5 # This number equals 2^{#encoder_blcoks}
185
+
186
+ self.stft = STFT(
187
+ n_fft=window_size,
188
+ hop_length=hop_size,
189
+ win_length=window_size,
190
+ window=window,
191
+ center=center,
192
+ pad_mode=pad_mode,
193
+ freeze_parameters=True,
194
+ )
195
+
196
+ self.istft = ISTFT(
197
+ n_fft=window_size,
198
+ hop_length=hop_size,
199
+ win_length=window_size,
200
+ window=window,
201
+ center=center,
202
+ pad_mode=pad_mode,
203
+ freeze_parameters=True,
204
+ )
205
+
206
+ self.bn0 = nn.BatchNorm2d(window_size // 2 + 1, momentum=momentum)
207
+
208
+ self.encoder_block1 = EncoderBlockRes4B(
209
+ in_channels=input_channels * self.subbands_num,
210
+ out_channels=32,
211
+ kernel_size=(3, 3),
212
+ downsample=(2, 2),
213
+ activation=activation,
214
+ momentum=momentum,
215
+ )
216
+ self.encoder_block2 = EncoderBlockRes4B(
217
+ in_channels=32,
218
+ out_channels=64,
219
+ kernel_size=(3, 3),
220
+ downsample=(2, 2),
221
+ activation=activation,
222
+ momentum=momentum,
223
+ )
224
+ self.encoder_block3 = EncoderBlockRes4B(
225
+ in_channels=64,
226
+ out_channels=128,
227
+ kernel_size=(3, 3),
228
+ downsample=(2, 2),
229
+ activation=activation,
230
+ momentum=momentum,
231
+ )
232
+ self.encoder_block4 = EncoderBlockRes4B(
233
+ in_channels=128,
234
+ out_channels=256,
235
+ kernel_size=(3, 3),
236
+ downsample=(2, 2),
237
+ activation=activation,
238
+ momentum=momentum,
239
+ )
240
+ self.encoder_block5 = EncoderBlockRes4B(
241
+ in_channels=256,
242
+ out_channels=384,
243
+ kernel_size=(3, 3),
244
+ downsample=(2, 2),
245
+ activation=activation,
246
+ momentum=momentum,
247
+ )
248
+ self.encoder_block6 = EncoderBlockRes4B(
249
+ in_channels=384,
250
+ out_channels=384,
251
+ kernel_size=(3, 3),
252
+ downsample=(1, 2),
253
+ activation=activation,
254
+ momentum=momentum,
255
+ )
256
+ self.conv_block7a = EncoderBlockRes4B(
257
+ in_channels=384,
258
+ out_channels=384,
259
+ kernel_size=(3, 3),
260
+ downsample=(1, 1),
261
+ activation=activation,
262
+ momentum=momentum,
263
+ )
264
+ self.conv_block7b = EncoderBlockRes4B(
265
+ in_channels=384,
266
+ out_channels=384,
267
+ kernel_size=(3, 3),
268
+ downsample=(1, 1),
269
+ activation=activation,
270
+ momentum=momentum,
271
+ )
272
+ self.conv_block7c = EncoderBlockRes4B(
273
+ in_channels=384,
274
+ out_channels=384,
275
+ kernel_size=(3, 3),
276
+ downsample=(1, 1),
277
+ activation=activation,
278
+ momentum=momentum,
279
+ )
280
+ self.conv_block7d = EncoderBlockRes4B(
281
+ in_channels=384,
282
+ out_channels=384,
283
+ kernel_size=(3, 3),
284
+ downsample=(1, 1),
285
+ activation=activation,
286
+ momentum=momentum,
287
+ )
288
+ self.decoder_block1 = DecoderBlockRes4B(
289
+ in_channels=384,
290
+ out_channels=384,
291
+ kernel_size=(3, 3),
292
+ upsample=(1, 2),
293
+ activation=activation,
294
+ momentum=momentum,
295
+ )
296
+ self.decoder_block2 = DecoderBlockRes4B(
297
+ in_channels=384,
298
+ out_channels=384,
299
+ kernel_size=(3, 3),
300
+ upsample=(2, 2),
301
+ activation=activation,
302
+ momentum=momentum,
303
+ )
304
+ self.decoder_block3 = DecoderBlockRes4B(
305
+ in_channels=384,
306
+ out_channels=256,
307
+ kernel_size=(3, 3),
308
+ upsample=(2, 2),
309
+ activation=activation,
310
+ momentum=momentum,
311
+ )
312
+ self.decoder_block4 = DecoderBlockRes4B(
313
+ in_channels=256,
314
+ out_channels=128,
315
+ kernel_size=(3, 3),
316
+ upsample=(2, 2),
317
+ activation=activation,
318
+ momentum=momentum,
319
+ )
320
+ self.decoder_block5 = DecoderBlockRes4B(
321
+ in_channels=128,
322
+ out_channels=64,
323
+ kernel_size=(3, 3),
324
+ upsample=(2, 2),
325
+ activation=activation,
326
+ momentum=momentum,
327
+ )
328
+ self.decoder_block6 = DecoderBlockRes4B(
329
+ in_channels=64,
330
+ out_channels=32,
331
+ kernel_size=(3, 3),
332
+ upsample=(2, 2),
333
+ activation=activation,
334
+ momentum=momentum,
335
+ )
336
+
337
+ self.after_conv_block1 = EncoderBlockRes4B(
338
+ in_channels=32,
339
+ out_channels=32,
340
+ kernel_size=(3, 3),
341
+ downsample=(1, 1),
342
+ activation=activation,
343
+ momentum=momentum,
344
+ )
345
+
346
+ self.after_conv2 = nn.Conv2d(
347
+ in_channels=32,
348
+ out_channels=target_sources_num
349
+ * input_channels
350
+ * self.K
351
+ * self.subbands_num,
352
+ kernel_size=(1, 1),
353
+ stride=(1, 1),
354
+ padding=(0, 0),
355
+ bias=True,
356
+ )
357
+
358
+ self.init_weights()
359
+
360
+ def init_weights(self):
361
+ init_bn(self.bn0)
362
+ init_layer(self.after_conv2)
363
+
364
+ def feature_maps_to_wav(
365
+ self,
366
+ input_tensor: torch.Tensor,
367
+ sp: torch.Tensor,
368
+ sin_in: torch.Tensor,
369
+ cos_in: torch.Tensor,
370
+ audio_length: int,
371
+ ) -> torch.Tensor:
372
+ r"""Convert feature maps to waveform.
373
+
374
+ Args:
375
+ input_tensor: (batch_size, target_sources_num * input_channels * self.K, time_steps, freq_bins)
376
+ sp: (batch_size, target_sources_num * input_channels, time_steps, freq_bins)
377
+ sin_in: (batch_size, target_sources_num * input_channels, time_steps, freq_bins)
378
+ cos_in: (batch_size, target_sources_num * input_channels, time_steps, freq_bins)
379
+
380
+ Outputs:
381
+ waveform: (batch_size, target_sources_num * input_channels, segment_samples)
382
+ """
383
+ batch_size, _, time_steps, freq_bins = input_tensor.shape
384
+
385
+ x = input_tensor.reshape(
386
+ batch_size,
387
+ self.target_sources_num,
388
+ self.input_channels,
389
+ self.K,
390
+ time_steps,
391
+ freq_bins,
392
+ )
393
+ # x: (batch_size, target_sources_num, input_channles, K, time_steps, freq_bins)
394
+
395
+ mask_mag = torch.sigmoid(x[:, :, :, 0, :, :])
396
+ _mask_real = torch.tanh(x[:, :, :, 1, :, :])
397
+ _mask_imag = torch.tanh(x[:, :, :, 2, :, :])
398
+ _, mask_cos, mask_sin = magphase(_mask_real, _mask_imag)
399
+ linear_mag = x[:, :, :, 3, :, :]
400
+ # mask_cos, mask_sin: (batch_size, target_sources_num, input_channles, time_steps, freq_bins)
401
+
402
+ # Y = |Y|cos∠Y + j|Y|sin∠Y
403
+ # = |Y|cos(∠X + ∠M) + j|Y|sin(∠X + ∠M)
404
+ # = |Y|(cos∠X cos∠M - sin∠X sin∠M) + j|Y|(sin∠X cos∠M + cos∠X sin∠M)
405
+ out_cos = (
406
+ cos_in[:, None, :, :, :] * mask_cos - sin_in[:, None, :, :, :] * mask_sin
407
+ )
408
+ out_sin = (
409
+ sin_in[:, None, :, :, :] * mask_cos + cos_in[:, None, :, :, :] * mask_sin
410
+ )
411
+ # out_cos: (batch_size, target_sources_num, input_channles, time_steps, freq_bins)
412
+ # out_sin: (batch_size, target_sources_num, input_channles, time_steps, freq_bins)
413
+
414
+ # Calculate |Y|.
415
+ out_mag = F.relu_(sp[:, None, :, :, :] * mask_mag + linear_mag)
416
+ # out_mag: (batch_size, target_sources_num, input_channles, time_steps, freq_bins)
417
+
418
+ # Calculate Y_{real} and Y_{imag} for ISTFT.
419
+ out_real = out_mag * out_cos
420
+ out_imag = out_mag * out_sin
421
+ # out_real, out_imag: (batch_size, target_sources_num, input_channles, time_steps, freq_bins)
422
+
423
+ # Reformat shape to (n, 1, time_steps, freq_bins) for ISTFT.
424
+ shape = (
425
+ batch_size * self.target_sources_num * self.input_channels,
426
+ 1,
427
+ time_steps,
428
+ freq_bins,
429
+ )
430
+ out_real = out_real.reshape(shape)
431
+ out_imag = out_imag.reshape(shape)
432
+
433
+ # ISTFT.
434
+ x = self.istft(out_real, out_imag, audio_length)
435
+ # (batch_size * target_sources_num * input_channels, segments_num)
436
+
437
+ # Reshape.
438
+ waveform = x.reshape(
439
+ batch_size, self.target_sources_num * self.input_channels, audio_length
440
+ )
441
+ # (batch_size, target_sources_num * input_channels, segments_num)
442
+
443
+ return waveform
444
+
445
+ def forward(self, input_dict):
446
+ r"""Forward data into the module.
447
+
448
+ Args:
449
+ input_dict: dict, e.g., {
450
+ waveform: (batch_size, input_channels, segment_samples),
451
+ ...,
452
+ }
453
+
454
+ Outputs:
455
+ output_dict: dict, e.g., {
456
+ 'waveform': (batch_size, input_channels, segment_samples),
457
+ ...,
458
+ }
459
+ """
460
+ mixtures = input_dict['waveform']
461
+ # (batch_size, input_channels, segment_samples)
462
+
463
+ mag, cos_in, sin_in = self.wav_to_spectrogram_phase(mixtures)
464
+ # mag, cos_in, sin_in: (batch_size, input_channels, time_steps, freq_bins)
465
+
466
+ # Batch normalize on individual frequency bins.
467
+ x = mag.transpose(1, 3)
468
+ x = self.bn0(x)
469
+ x = x.transpose(1, 3)
470
+ # x: (batch_size, input_channels, time_steps, freq_bins)
471
+
472
+ # Pad spectrogram to be evenly divided by downsample ratio.
473
+ origin_len = x.shape[2]
474
+ pad_len = (
475
+ int(np.ceil(x.shape[2] / self.time_downsample_ratio))
476
+ * self.time_downsample_ratio
477
+ - origin_len
478
+ )
479
+ x = F.pad(x, pad=(0, 0, 0, pad_len))
480
+ # (batch_size, channels, padded_time_steps, freq_bins)
481
+
482
+ # Let frequency bins be evenly divided by 2, e.g., 1025 -> 1024.
483
+ x = x[..., 0 : x.shape[-1] - 1] # (bs, channels, T, F)
484
+
485
+ if self.subbands_num > 1:
486
+ x = self.subband.analysis(x)
487
+ # (bs, input_channels, T, F'), where F' = F // subbands_num
488
+
489
+ # UNet
490
+ (x1_pool, x1) = self.encoder_block1(x) # x1_pool: (bs, 32, T / 2, F / 2)
491
+ (x2_pool, x2) = self.encoder_block2(x1_pool) # x2_pool: (bs, 64, T / 4, F / 4)
492
+ (x3_pool, x3) = self.encoder_block3(x2_pool) # x3_pool: (bs, 128, T / 8, F / 8)
493
+ (x4_pool, x4) = self.encoder_block4(
494
+ x3_pool
495
+ ) # x4_pool: (bs, 256, T / 16, F / 16)
496
+ (x5_pool, x5) = self.encoder_block5(
497
+ x4_pool
498
+ ) # x5_pool: (bs, 384, T / 32, F / 32)
499
+ (x6_pool, x6) = self.encoder_block6(
500
+ x5_pool
501
+ ) # x6_pool: (bs, 384, T / 32, F / 64)
502
+ (x_center, _) = self.conv_block7a(x6_pool) # (bs, 384, T / 32, F / 64)
503
+ (x_center, _) = self.conv_block7b(x_center) # (bs, 384, T / 32, F / 64)
504
+ (x_center, _) = self.conv_block7c(x_center) # (bs, 384, T / 32, F / 64)
505
+ (x_center, _) = self.conv_block7d(x_center) # (bs, 384, T / 32, F / 64)
506
+ x7 = self.decoder_block1(x_center, x6) # (bs, 384, T / 32, F / 32)
507
+ x8 = self.decoder_block2(x7, x5) # (bs, 384, T / 16, F / 16)
508
+ x9 = self.decoder_block3(x8, x4) # (bs, 256, T / 8, F / 8)
509
+ x10 = self.decoder_block4(x9, x3) # (bs, 128, T / 4, F / 4)
510
+ x11 = self.decoder_block5(x10, x2) # (bs, 64, T / 2, F / 2)
511
+ x12 = self.decoder_block6(x11, x1) # (bs, 32, T, F)
512
+ (x, _) = self.after_conv_block1(x12) # (bs, 32, T, F)
513
+
514
+ x = self.after_conv2(x) # (bs, channels * 3, T, F)
515
+ # (batch_size, target_sources_num * input_channles * self.K * subbands_num, T, F')
516
+
517
+ if self.subbands_num > 1:
518
+ x = self.subband.synthesis(x)
519
+ # (batch_size, target_sources_num * input_channles * self.K, T, F)
520
+
521
+ # Recover shape
522
+ x = F.pad(x, pad=(0, 1)) # Pad frequency, e.g., 1024 -> 1025.
523
+
524
+ x = x[:, :, 0:origin_len, :]
525
+ # (batch_size, target_sources_num * input_channles * self.K, T, F)
526
+
527
+ audio_length = mixtures.shape[2]
528
+
529
+ separated_audio = self.feature_maps_to_wav(x, mag, sin_in, cos_in, audio_length)
530
+ # separated_audio: (batch_size, target_sources_num * input_channels, segments_num)
531
+
532
+ output_dict = {'waveform': separated_audio}
533
+
534
+ return output_dict
bytesep/models/resunet_subbandtime.py ADDED
@@ -0,0 +1,545 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+ from torchlibrosa.stft import ISTFT, STFT, magphase
6
+
7
+ from bytesep.models.pytorch_modules import Base, init_bn, init_layer
8
+ from bytesep.models.subband_tools.pqmf import PQMF
9
+
10
+
11
+ class ConvBlockRes(nn.Module):
12
+ def __init__(self, in_channels, out_channels, kernel_size, activation, momentum):
13
+ r"""Residual block."""
14
+ super(ConvBlockRes, self).__init__()
15
+
16
+ self.activation = activation
17
+ padding = [kernel_size[0] // 2, kernel_size[1] // 2]
18
+
19
+ self.bn1 = nn.BatchNorm2d(in_channels, momentum=momentum)
20
+ self.bn2 = nn.BatchNorm2d(out_channels, momentum=momentum)
21
+
22
+ self.conv1 = nn.Conv2d(
23
+ in_channels=in_channels,
24
+ out_channels=out_channels,
25
+ kernel_size=kernel_size,
26
+ stride=(1, 1),
27
+ dilation=(1, 1),
28
+ padding=padding,
29
+ bias=False,
30
+ )
31
+
32
+ self.conv2 = nn.Conv2d(
33
+ in_channels=out_channels,
34
+ out_channels=out_channels,
35
+ kernel_size=kernel_size,
36
+ stride=(1, 1),
37
+ dilation=(1, 1),
38
+ padding=padding,
39
+ bias=False,
40
+ )
41
+
42
+ if in_channels != out_channels:
43
+ self.shortcut = nn.Conv2d(
44
+ in_channels=in_channels,
45
+ out_channels=out_channels,
46
+ kernel_size=(1, 1),
47
+ stride=(1, 1),
48
+ padding=(0, 0),
49
+ )
50
+ self.is_shortcut = True
51
+ else:
52
+ self.is_shortcut = False
53
+
54
+ self.init_weights()
55
+
56
+ def init_weights(self):
57
+ init_bn(self.bn1)
58
+ init_bn(self.bn2)
59
+ init_layer(self.conv1)
60
+ init_layer(self.conv2)
61
+
62
+ if self.is_shortcut:
63
+ init_layer(self.shortcut)
64
+
65
+ def forward(self, x):
66
+ origin = x
67
+ x = self.conv1(F.leaky_relu_(self.bn1(x), negative_slope=0.01))
68
+ x = self.conv2(F.leaky_relu_(self.bn2(x), negative_slope=0.01))
69
+
70
+ if self.is_shortcut:
71
+ return self.shortcut(origin) + x
72
+ else:
73
+ return origin + x
74
+
75
+
76
+ class EncoderBlockRes4B(nn.Module):
77
+ def __init__(
78
+ self, in_channels, out_channels, kernel_size, downsample, activation, momentum
79
+ ):
80
+ r"""Encoder block, contains 8 convolutional layers."""
81
+ super(EncoderBlockRes4B, self).__init__()
82
+
83
+ self.conv_block1 = ConvBlockRes(
84
+ in_channels, out_channels, kernel_size, activation, momentum
85
+ )
86
+ self.conv_block2 = ConvBlockRes(
87
+ out_channels, out_channels, kernel_size, activation, momentum
88
+ )
89
+ self.conv_block3 = ConvBlockRes(
90
+ out_channels, out_channels, kernel_size, activation, momentum
91
+ )
92
+ self.conv_block4 = ConvBlockRes(
93
+ out_channels, out_channels, kernel_size, activation, momentum
94
+ )
95
+ self.downsample = downsample
96
+
97
+ def forward(self, x):
98
+ encoder = self.conv_block1(x)
99
+ encoder = self.conv_block2(encoder)
100
+ encoder = self.conv_block3(encoder)
101
+ encoder = self.conv_block4(encoder)
102
+ encoder_pool = F.avg_pool2d(encoder, kernel_size=self.downsample)
103
+ return encoder_pool, encoder
104
+
105
+
106
+ class DecoderBlockRes4B(nn.Module):
107
+ def __init__(
108
+ self, in_channels, out_channels, kernel_size, upsample, activation, momentum
109
+ ):
110
+ r"""Decoder block, contains 1 transpose convolutional and 8 convolutional layers."""
111
+ super(DecoderBlockRes4B, self).__init__()
112
+ self.kernel_size = kernel_size
113
+ self.stride = upsample
114
+ self.activation = activation
115
+
116
+ self.conv1 = torch.nn.ConvTranspose2d(
117
+ in_channels=in_channels,
118
+ out_channels=out_channels,
119
+ kernel_size=self.stride,
120
+ stride=self.stride,
121
+ padding=(0, 0),
122
+ bias=False,
123
+ dilation=(1, 1),
124
+ )
125
+
126
+ self.bn1 = nn.BatchNorm2d(in_channels, momentum=momentum)
127
+ self.conv_block2 = ConvBlockRes(
128
+ out_channels * 2, out_channels, kernel_size, activation, momentum
129
+ )
130
+ self.conv_block3 = ConvBlockRes(
131
+ out_channels, out_channels, kernel_size, activation, momentum
132
+ )
133
+ self.conv_block4 = ConvBlockRes(
134
+ out_channels, out_channels, kernel_size, activation, momentum
135
+ )
136
+ self.conv_block5 = ConvBlockRes(
137
+ out_channels, out_channels, kernel_size, activation, momentum
138
+ )
139
+
140
+ self.init_weights()
141
+
142
+ def init_weights(self):
143
+ init_bn(self.bn1)
144
+ init_layer(self.conv1)
145
+
146
+ def forward(self, input_tensor, concat_tensor):
147
+ x = self.conv1(F.relu_(self.bn1(input_tensor)))
148
+ x = torch.cat((x, concat_tensor), dim=1)
149
+ x = self.conv_block2(x)
150
+ x = self.conv_block3(x)
151
+ x = self.conv_block4(x)
152
+ x = self.conv_block5(x)
153
+ return x
154
+
155
+
156
+ class ResUNet143_Subbandtime(nn.Module, Base):
157
+ def __init__(self, input_channels, target_sources_num):
158
+ super(ResUNet143_Subbandtime, self).__init__()
159
+
160
+ self.input_channels = input_channels
161
+ self.target_sources_num = target_sources_num
162
+
163
+ window_size = 512
164
+ hop_size = 110
165
+ center = True
166
+ pad_mode = "reflect"
167
+ window = "hann"
168
+ activation = "leaky_relu"
169
+ momentum = 0.01
170
+
171
+ self.subbands_num = 4
172
+ self.K = 4 # outputs: |M|, cos∠M, sin∠M, Q
173
+
174
+ self.downsample_ratio = 2 ** 5 # This number equals 2^{#encoder_blcoks}
175
+
176
+ self.pqmf = PQMF(
177
+ N=self.subbands_num,
178
+ M=64,
179
+ project_root='bytesep/models/subband_tools/filters',
180
+ )
181
+
182
+ self.stft = STFT(
183
+ n_fft=window_size,
184
+ hop_length=hop_size,
185
+ win_length=window_size,
186
+ window=window,
187
+ center=center,
188
+ pad_mode=pad_mode,
189
+ freeze_parameters=True,
190
+ )
191
+
192
+ self.istft = ISTFT(
193
+ n_fft=window_size,
194
+ hop_length=hop_size,
195
+ win_length=window_size,
196
+ window=window,
197
+ center=center,
198
+ pad_mode=pad_mode,
199
+ freeze_parameters=True,
200
+ )
201
+
202
+ self.bn0 = nn.BatchNorm2d(window_size // 2 + 1, momentum=momentum)
203
+
204
+ self.encoder_block1 = EncoderBlockRes4B(
205
+ in_channels=input_channels * self.subbands_num,
206
+ out_channels=32,
207
+ kernel_size=(3, 3),
208
+ downsample=(2, 2),
209
+ activation=activation,
210
+ momentum=momentum,
211
+ )
212
+ self.encoder_block2 = EncoderBlockRes4B(
213
+ in_channels=32,
214
+ out_channels=64,
215
+ kernel_size=(3, 3),
216
+ downsample=(2, 2),
217
+ activation=activation,
218
+ momentum=momentum,
219
+ )
220
+ self.encoder_block3 = EncoderBlockRes4B(
221
+ in_channels=64,
222
+ out_channels=128,
223
+ kernel_size=(3, 3),
224
+ downsample=(2, 2),
225
+ activation=activation,
226
+ momentum=momentum,
227
+ )
228
+ self.encoder_block4 = EncoderBlockRes4B(
229
+ in_channels=128,
230
+ out_channels=256,
231
+ kernel_size=(3, 3),
232
+ downsample=(2, 2),
233
+ activation=activation,
234
+ momentum=momentum,
235
+ )
236
+ self.encoder_block5 = EncoderBlockRes4B(
237
+ in_channels=256,
238
+ out_channels=384,
239
+ kernel_size=(3, 3),
240
+ downsample=(2, 2),
241
+ activation=activation,
242
+ momentum=momentum,
243
+ )
244
+ self.encoder_block6 = EncoderBlockRes4B(
245
+ in_channels=384,
246
+ out_channels=384,
247
+ kernel_size=(3, 3),
248
+ downsample=(1, 2),
249
+ activation=activation,
250
+ momentum=momentum,
251
+ )
252
+ self.conv_block7a = EncoderBlockRes4B(
253
+ in_channels=384,
254
+ out_channels=384,
255
+ kernel_size=(3, 3),
256
+ downsample=(1, 1),
257
+ activation=activation,
258
+ momentum=momentum,
259
+ )
260
+ self.conv_block7b = EncoderBlockRes4B(
261
+ in_channels=384,
262
+ out_channels=384,
263
+ kernel_size=(3, 3),
264
+ downsample=(1, 1),
265
+ activation=activation,
266
+ momentum=momentum,
267
+ )
268
+ self.conv_block7c = EncoderBlockRes4B(
269
+ in_channels=384,
270
+ out_channels=384,
271
+ kernel_size=(3, 3),
272
+ downsample=(1, 1),
273
+ activation=activation,
274
+ momentum=momentum,
275
+ )
276
+ self.conv_block7d = EncoderBlockRes4B(
277
+ in_channels=384,
278
+ out_channels=384,
279
+ kernel_size=(3, 3),
280
+ downsample=(1, 1),
281
+ activation=activation,
282
+ momentum=momentum,
283
+ )
284
+ self.decoder_block1 = DecoderBlockRes4B(
285
+ in_channels=384,
286
+ out_channels=384,
287
+ kernel_size=(3, 3),
288
+ upsample=(1, 2),
289
+ activation=activation,
290
+ momentum=momentum,
291
+ )
292
+ self.decoder_block2 = DecoderBlockRes4B(
293
+ in_channels=384,
294
+ out_channels=384,
295
+ kernel_size=(3, 3),
296
+ upsample=(2, 2),
297
+ activation=activation,
298
+ momentum=momentum,
299
+ )
300
+ self.decoder_block3 = DecoderBlockRes4B(
301
+ in_channels=384,
302
+ out_channels=256,
303
+ kernel_size=(3, 3),
304
+ upsample=(2, 2),
305
+ activation=activation,
306
+ momentum=momentum,
307
+ )
308
+ self.decoder_block4 = DecoderBlockRes4B(
309
+ in_channels=256,
310
+ out_channels=128,
311
+ kernel_size=(3, 3),
312
+ upsample=(2, 2),
313
+ activation=activation,
314
+ momentum=momentum,
315
+ )
316
+ self.decoder_block5 = DecoderBlockRes4B(
317
+ in_channels=128,
318
+ out_channels=64,
319
+ kernel_size=(3, 3),
320
+ upsample=(2, 2),
321
+ activation=activation,
322
+ momentum=momentum,
323
+ )
324
+ self.decoder_block6 = DecoderBlockRes4B(
325
+ in_channels=64,
326
+ out_channels=32,
327
+ kernel_size=(3, 3),
328
+ upsample=(2, 2),
329
+ activation=activation,
330
+ momentum=momentum,
331
+ )
332
+
333
+ self.after_conv_block1 = EncoderBlockRes4B(
334
+ in_channels=32,
335
+ out_channels=32,
336
+ kernel_size=(3, 3),
337
+ downsample=(1, 1),
338
+ activation=activation,
339
+ momentum=momentum,
340
+ )
341
+
342
+ self.after_conv2 = nn.Conv2d(
343
+ in_channels=32,
344
+ out_channels=target_sources_num
345
+ * input_channels
346
+ * self.K
347
+ * self.subbands_num,
348
+ kernel_size=(1, 1),
349
+ stride=(1, 1),
350
+ padding=(0, 0),
351
+ bias=True,
352
+ )
353
+
354
+ self.init_weights()
355
+
356
+ def init_weights(self):
357
+ init_bn(self.bn0)
358
+ init_layer(self.after_conv2)
359
+
360
+ def feature_maps_to_wav(
361
+ self,
362
+ input_tensor: torch.Tensor,
363
+ sp: torch.Tensor,
364
+ sin_in: torch.Tensor,
365
+ cos_in: torch.Tensor,
366
+ audio_length: int,
367
+ ) -> torch.Tensor:
368
+ r"""Convert feature maps to waveform.
369
+
370
+ Args:
371
+ input_tensor: (batch_size, target_sources_num * input_channels * self.K, time_steps, freq_bins)
372
+ sp: (batch_size, target_sources_num * input_channels, time_steps, freq_bins)
373
+ sin_in: (batch_size, target_sources_num * input_channels, time_steps, freq_bins)
374
+ cos_in: (batch_size, target_sources_num * input_channels, time_steps, freq_bins)
375
+
376
+ Outputs:
377
+ waveform: (batch_size, target_sources_num * input_channels, segment_samples)
378
+ """
379
+ batch_size, _, time_steps, freq_bins = input_tensor.shape
380
+
381
+ x = input_tensor.reshape(
382
+ batch_size,
383
+ self.target_sources_num,
384
+ self.input_channels,
385
+ self.K,
386
+ time_steps,
387
+ freq_bins,
388
+ )
389
+ # x: (batch_size, target_sources_num, input_channles, K, time_steps, freq_bins)
390
+
391
+ mask_mag = torch.sigmoid(x[:, :, :, 0, :, :])
392
+ _mask_real = torch.tanh(x[:, :, :, 1, :, :])
393
+ _mask_imag = torch.tanh(x[:, :, :, 2, :, :])
394
+ linear_mag = torch.tanh(x[:, :, :, 3, :, :])
395
+ _, mask_cos, mask_sin = magphase(_mask_real, _mask_imag)
396
+ # mask_cos, mask_sin: (batch_size, target_sources_num, input_channles, time_steps, freq_bins)
397
+
398
+ # Y = |Y|cos∠Y + j|Y|sin∠Y
399
+ # = |Y|cos(∠X + ∠M) + j|Y|sin(∠X + ∠M)
400
+ # = |Y|(cos∠X cos∠M - sin∠X sin∠M) + j|Y|(sin∠X cos∠M + cos∠X sin∠M)
401
+ out_cos = (
402
+ cos_in[:, None, :, :, :] * mask_cos - sin_in[:, None, :, :, :] * mask_sin
403
+ )
404
+ out_sin = (
405
+ sin_in[:, None, :, :, :] * mask_cos + cos_in[:, None, :, :, :] * mask_sin
406
+ )
407
+ # out_cos: (batch_size, target_sources_num, input_channles, time_steps, freq_bins)
408
+ # out_sin: (batch_size, target_sources_num, input_channles, time_steps, freq_bins)
409
+
410
+ # Calculate |Y|.
411
+ out_mag = F.relu_(sp[:, None, :, :, :] * mask_mag + linear_mag)
412
+ # out_mag: (batch_size, target_sources_num, input_channles, time_steps, freq_bins)
413
+
414
+ # Calculate Y_{real} and Y_{imag} for ISTFT.
415
+ out_real = out_mag * out_cos
416
+ out_imag = out_mag * out_sin
417
+ # out_real, out_imag: (batch_size, target_sources_num, input_channles, time_steps, freq_bins)
418
+
419
+ # Reformat shape to (n, 1, time_steps, freq_bins) for ISTFT.
420
+ shape = (
421
+ batch_size * self.target_sources_num * self.input_channels,
422
+ 1,
423
+ time_steps,
424
+ freq_bins,
425
+ )
426
+ out_real = out_real.reshape(shape)
427
+ out_imag = out_imag.reshape(shape)
428
+
429
+ # ISTFT.
430
+ x = self.istft(out_real, out_imag, audio_length)
431
+ # (batch_size * target_sources_num * input_channels, segments_num)
432
+
433
+ # Reshape.
434
+ waveform = x.reshape(
435
+ batch_size, self.target_sources_num * self.input_channels, audio_length
436
+ )
437
+ # (batch_size, target_sources_num * input_channels, segments_num)
438
+
439
+ return waveform
440
+
441
+ def forward(self, input_dict):
442
+ r"""Forward data into the module.
443
+
444
+ Args:
445
+ input_dict: dict, e.g., {
446
+ waveform: (batch_size, input_channels, segment_samples),
447
+ ...,
448
+ }
449
+
450
+ Outputs:
451
+ output_dict: dict, e.g., {
452
+ 'waveform': (batch_size, input_channels, segment_samples),
453
+ ...,
454
+ }
455
+ """
456
+ mixtures = input_dict['waveform']
457
+ # (batch_size, input_channels, segment_samples)
458
+
459
+ subband_x = self.pqmf.analysis(mixtures)
460
+ # subband_x: (batch_size, input_channels * subbands_num, segment_samples)
461
+
462
+ mag, cos_in, sin_in = self.wav_to_spectrogram_phase(subband_x)
463
+ # mag, cos_in, sin_in: (batch_size, input_channels * subbands_num, time_steps, freq_bins)
464
+
465
+ # Batch normalize on individual frequency bins.
466
+ x = mag.transpose(1, 3)
467
+ x = self.bn0(x)
468
+ x = x.transpose(1, 3)
469
+ # (batch_size, input_channels * subbands_num, time_steps, freq_bins)
470
+
471
+ # Pad spectrogram to be evenly divided by downsample ratio.
472
+ origin_len = x.shape[2]
473
+ pad_len = (
474
+ int(np.ceil(x.shape[2] / self.downsample_ratio)) * self.downsample_ratio
475
+ - origin_len
476
+ )
477
+ x = F.pad(x, pad=(0, 0, 0, pad_len))
478
+ # x: (batch_size, input_channels * subbands_num, padded_time_steps, freq_bins)
479
+
480
+ # Let frequency bins be evenly divided by 2, e.g., 257 -> 256
481
+ x = x[..., 0 : x.shape[-1] - 1] # (bs, input_channels, T, F)
482
+ # x: (batch_size, input_channels * subbands_num, padded_time_steps, freq_bins)
483
+
484
+ # UNet
485
+ (x1_pool, x1) = self.encoder_block1(x) # x1_pool: (bs, 32, T / 2, F / 2)
486
+ (x2_pool, x2) = self.encoder_block2(x1_pool) # x2_pool: (bs, 64, T / 4, F / 4)
487
+ (x3_pool, x3) = self.encoder_block3(x2_pool) # x3_pool: (bs, 128, T / 8, F / 8)
488
+ (x4_pool, x4) = self.encoder_block4(
489
+ x3_pool
490
+ ) # x4_pool: (bs, 256, T / 16, F / 16)
491
+ (x5_pool, x5) = self.encoder_block5(
492
+ x4_pool
493
+ ) # x5_pool: (bs, 384, T / 32, F / 32)
494
+ (x6_pool, x6) = self.encoder_block6(
495
+ x5_pool
496
+ ) # x6_pool: (bs, 384, T / 32, F / 64)
497
+ (x_center, _) = self.conv_block7a(x6_pool) # (bs, 384, T / 32, F / 64)
498
+ (x_center, _) = self.conv_block7b(x_center) # (bs, 384, T / 32, F / 64)
499
+ (x_center, _) = self.conv_block7c(x_center) # (bs, 384, T / 32, F / 64)
500
+ (x_center, _) = self.conv_block7d(x_center) # (bs, 384, T / 32, F / 64)
501
+ x7 = self.decoder_block1(x_center, x6) # (bs, 384, T / 32, F / 32)
502
+ x8 = self.decoder_block2(x7, x5) # (bs, 384, T / 16, F / 16)
503
+ x9 = self.decoder_block3(x8, x4) # (bs, 256, T / 8, F / 8)
504
+ x10 = self.decoder_block4(x9, x3) # (bs, 128, T / 4, F / 4)
505
+ x11 = self.decoder_block5(x10, x2) # (bs, 64, T / 2, F / 2)
506
+ x12 = self.decoder_block6(x11, x1) # (bs, 32, T, F)
507
+ (x, _) = self.after_conv_block1(x12) # (bs, 32, T, F)
508
+
509
+ x = self.after_conv2(x)
510
+ # (batch_size, subbands_num * target_sources_num * input_channles * self.K, T, F')
511
+
512
+ # Recover shape
513
+ x = F.pad(x, pad=(0, 1)) # Pad frequency, e.g., 256 -> 257.
514
+
515
+ x = x[:, :, 0:origin_len, :]
516
+ # (batch_size, subbands_num * target_sources_num * input_channles * self.K, T, F')
517
+
518
+ audio_length = subband_x.shape[2]
519
+
520
+ # Recover each subband spectrograms to subband waveforms. Then synthesis
521
+ # the subband waveforms to a waveform.
522
+ C1 = x.shape[1] // self.subbands_num
523
+ C2 = mag.shape[1] // self.subbands_num
524
+
525
+ separated_subband_audio = torch.cat(
526
+ [
527
+ self.feature_maps_to_wav(
528
+ input_tensor=x[:, j * C1 : (j + 1) * C1, :, :],
529
+ sp=mag[:, j * C2 : (j + 1) * C2, :, :],
530
+ sin_in=sin_in[:, j * C2 : (j + 1) * C2, :, :],
531
+ cos_in=cos_in[:, j * C2 : (j + 1) * C2, :, :],
532
+ audio_length=audio_length,
533
+ )
534
+ for j in range(self.subbands_num)
535
+ ],
536
+ dim=1,
537
+ )
538
+ # (batch_size, subbands_num * target_sources_num * input_channles, segment_samples)
539
+
540
+ separated_audio = self.pqmf.synthesis(separated_subband_audio)
541
+ # (batch_size, input_channles, segment_samples)
542
+
543
+ output_dict = {'waveform': separated_audio}
544
+
545
+ return output_dict
bytesep/models/subband_tools/__init__.py ADDED
File without changes
bytesep/models/subband_tools/fDomainHelper.py ADDED
@@ -0,0 +1,255 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from torchlibrosa.stft import STFT, ISTFT, magphase
2
+ import torch
3
+ import torch.nn as nn
4
+ import numpy as np
5
+ from tools.pytorch.modules.pqmf import PQMF
6
+
7
+
8
+ class FDomainHelper(nn.Module):
9
+ def __init__(
10
+ self,
11
+ window_size=2048,
12
+ hop_size=441,
13
+ center=True,
14
+ pad_mode='reflect',
15
+ window='hann',
16
+ freeze_parameters=True,
17
+ subband=None,
18
+ root="/Users/admin/Documents/projects/",
19
+ ):
20
+ super(FDomainHelper, self).__init__()
21
+ self.subband = subband
22
+ if self.subband is None:
23
+ self.stft = STFT(
24
+ n_fft=window_size,
25
+ hop_length=hop_size,
26
+ win_length=window_size,
27
+ window=window,
28
+ center=center,
29
+ pad_mode=pad_mode,
30
+ freeze_parameters=freeze_parameters,
31
+ )
32
+
33
+ self.istft = ISTFT(
34
+ n_fft=window_size,
35
+ hop_length=hop_size,
36
+ win_length=window_size,
37
+ window=window,
38
+ center=center,
39
+ pad_mode=pad_mode,
40
+ freeze_parameters=freeze_parameters,
41
+ )
42
+ else:
43
+ self.stft = STFT(
44
+ n_fft=window_size // self.subband,
45
+ hop_length=hop_size // self.subband,
46
+ win_length=window_size // self.subband,
47
+ window=window,
48
+ center=center,
49
+ pad_mode=pad_mode,
50
+ freeze_parameters=freeze_parameters,
51
+ )
52
+
53
+ self.istft = ISTFT(
54
+ n_fft=window_size // self.subband,
55
+ hop_length=hop_size // self.subband,
56
+ win_length=window_size // self.subband,
57
+ window=window,
58
+ center=center,
59
+ pad_mode=pad_mode,
60
+ freeze_parameters=freeze_parameters,
61
+ )
62
+
63
+ if subband is not None and root is not None:
64
+ self.qmf = PQMF(subband, 64, root)
65
+
66
+ def complex_spectrogram(self, input, eps=0.0):
67
+ # [batchsize, samples]
68
+ # return [batchsize, 2, t-steps, f-bins]
69
+ real, imag = self.stft(input)
70
+ return torch.cat([real, imag], dim=1)
71
+
72
+ def reverse_complex_spectrogram(self, input, eps=0.0, length=None):
73
+ # [batchsize, 2[real,imag], t-steps, f-bins]
74
+ wav = self.istft(input[:, 0:1, ...], input[:, 1:2, ...], length=length)
75
+ return wav
76
+
77
+ def spectrogram(self, input, eps=0.0):
78
+ (real, imag) = self.stft(input.float())
79
+ return torch.clamp(real ** 2 + imag ** 2, eps, np.inf) ** 0.5
80
+
81
+ def spectrogram_phase(self, input, eps=0.0):
82
+ (real, imag) = self.stft(input.float())
83
+ mag = torch.clamp(real ** 2 + imag ** 2, eps, np.inf) ** 0.5
84
+ cos = real / mag
85
+ sin = imag / mag
86
+ return mag, cos, sin
87
+
88
+ def wav_to_spectrogram_phase(self, input, eps=1e-8):
89
+ """Waveform to spectrogram.
90
+
91
+ Args:
92
+ input: (batch_size, channels_num, segment_samples)
93
+
94
+ Outputs:
95
+ output: (batch_size, channels_num, time_steps, freq_bins)
96
+ """
97
+ sp_list = []
98
+ cos_list = []
99
+ sin_list = []
100
+ channels_num = input.shape[1]
101
+ for channel in range(channels_num):
102
+ mag, cos, sin = self.spectrogram_phase(input[:, channel, :], eps=eps)
103
+ sp_list.append(mag)
104
+ cos_list.append(cos)
105
+ sin_list.append(sin)
106
+
107
+ sps = torch.cat(sp_list, dim=1)
108
+ coss = torch.cat(cos_list, dim=1)
109
+ sins = torch.cat(sin_list, dim=1)
110
+ return sps, coss, sins
111
+
112
+ def spectrogram_phase_to_wav(self, sps, coss, sins, length):
113
+ channels_num = sps.size()[1]
114
+ res = []
115
+ for i in range(channels_num):
116
+ res.append(
117
+ self.istft(
118
+ sps[:, i : i + 1, ...] * coss[:, i : i + 1, ...],
119
+ sps[:, i : i + 1, ...] * sins[:, i : i + 1, ...],
120
+ length,
121
+ )
122
+ )
123
+ res[-1] = res[-1].unsqueeze(1)
124
+ return torch.cat(res, dim=1)
125
+
126
+ def wav_to_spectrogram(self, input, eps=1e-8):
127
+ """Waveform to spectrogram.
128
+
129
+ Args:
130
+ input: (batch_size,channels_num, segment_samples)
131
+
132
+ Outputs:
133
+ output: (batch_size, channels_num, time_steps, freq_bins)
134
+ """
135
+ sp_list = []
136
+ channels_num = input.shape[1]
137
+ for channel in range(channels_num):
138
+ sp_list.append(self.spectrogram(input[:, channel, :], eps=eps))
139
+ output = torch.cat(sp_list, dim=1)
140
+ return output
141
+
142
+ def spectrogram_to_wav(self, input, spectrogram, length=None):
143
+ """Spectrogram to waveform.
144
+ Args:
145
+ input: (batch_size, segment_samples, channels_num)
146
+ spectrogram: (batch_size, channels_num, time_steps, freq_bins)
147
+
148
+ Outputs:
149
+ output: (batch_size, segment_samples, channels_num)
150
+ """
151
+ channels_num = input.shape[1]
152
+ wav_list = []
153
+ for channel in range(channels_num):
154
+ (real, imag) = self.stft(input[:, channel, :])
155
+ (_, cos, sin) = magphase(real, imag)
156
+ wav_list.append(
157
+ self.istft(
158
+ spectrogram[:, channel : channel + 1, :, :] * cos,
159
+ spectrogram[:, channel : channel + 1, :, :] * sin,
160
+ length,
161
+ )
162
+ )
163
+
164
+ output = torch.stack(wav_list, dim=1)
165
+ return output
166
+
167
+ # todo the following code is not bug free!
168
+ def wav_to_complex_spectrogram(self, input, eps=0.0):
169
+ # [batchsize , channels, samples]
170
+ # [batchsize, 2[real,imag]*channels, t-steps, f-bins]
171
+ res = []
172
+ channels_num = input.shape[1]
173
+ for channel in range(channels_num):
174
+ res.append(self.complex_spectrogram(input[:, channel, :], eps=eps))
175
+ return torch.cat(res, dim=1)
176
+
177
+ def complex_spectrogram_to_wav(self, input, eps=0.0, length=None):
178
+ # [batchsize, 2[real,imag]*channels, t-steps, f-bins]
179
+ # return [batchsize, channels, samples]
180
+ channels = input.size()[1] // 2
181
+ wavs = []
182
+ for i in range(channels):
183
+ wavs.append(
184
+ self.reverse_complex_spectrogram(
185
+ input[:, 2 * i : 2 * i + 2, ...], eps=eps, length=length
186
+ )
187
+ )
188
+ wavs[-1] = wavs[-1].unsqueeze(1)
189
+ return torch.cat(wavs, dim=1)
190
+
191
+ def wav_to_complex_subband_spectrogram(self, input, eps=0.0):
192
+ # [batchsize, channels, samples]
193
+ # [batchsize, 2[real,imag]*subband*channels, t-steps, f-bins]
194
+ subwav = self.qmf.analysis(input) # [batchsize, subband*channels, samples]
195
+ subspec = self.wav_to_complex_spectrogram(subwav)
196
+ return subspec
197
+
198
+ def complex_subband_spectrogram_to_wav(self, input, eps=0.0):
199
+ # [batchsize, 2[real,imag]*subband*channels, t-steps, f-bins]
200
+ # [batchsize, channels, samples]
201
+ subwav = self.complex_spectrogram_to_wav(input)
202
+ data = self.qmf.synthesis(subwav)
203
+ return data
204
+
205
+ def wav_to_mag_phase_subband_spectrogram(self, input, eps=1e-8):
206
+ """
207
+ :param input:
208
+ :param eps:
209
+ :return:
210
+ loss = torch.nn.L1Loss()
211
+ model = FDomainHelper(subband=4)
212
+ data = torch.randn((3,1, 44100*3))
213
+
214
+ sps, coss, sins = model.wav_to_mag_phase_subband_spectrogram(data)
215
+ wav = model.mag_phase_subband_spectrogram_to_wav(sps,coss,sins,44100*3//4)
216
+
217
+ print(loss(data,wav))
218
+ print(torch.max(torch.abs(data-wav)))
219
+
220
+ """
221
+ # [batchsize, channels, samples]
222
+ # [batchsize, 2[real,imag]*subband*channels, t-steps, f-bins]
223
+ subwav = self.qmf.analysis(input) # [batchsize, subband*channels, samples]
224
+ sps, coss, sins = self.wav_to_spectrogram_phase(subwav, eps=eps)
225
+ return sps, coss, sins
226
+
227
+ def mag_phase_subband_spectrogram_to_wav(self, sps, coss, sins, length, eps=0.0):
228
+ # [batchsize, 2[real,imag]*subband*channels, t-steps, f-bins]
229
+ # [batchsize, channels, samples]
230
+ subwav = self.spectrogram_phase_to_wav(sps, coss, sins, length)
231
+ data = self.qmf.synthesis(subwav)
232
+ return data
233
+
234
+
235
+ if __name__ == "__main__":
236
+ # from thop import profile
237
+ # from thop import clever_format
238
+ # from tools.file.wav import *
239
+ # import time
240
+ #
241
+ # wav = torch.randn((1,2,44100))
242
+ # model = FDomainHelper()
243
+
244
+ from tools.file.wav import *
245
+
246
+ loss = torch.nn.L1Loss()
247
+ model = FDomainHelper()
248
+ data = torch.randn((3, 1, 44100 * 5))
249
+
250
+ sps = model.wav_to_complex_spectrogram(data)
251
+ print(sps.size())
252
+ wav = model.complex_spectrogram_to_wav(sps, 44100 * 5)
253
+
254
+ print(loss(data, wav))
255
+ print(torch.max(torch.abs(data - wav)))
bytesep/models/subband_tools/filters/f_4_64.mat ADDED
Binary file (2.19 kB). View file
 
bytesep/models/subband_tools/filters/h_4_64.mat ADDED
Binary file (2.19 kB). View file
 
bytesep/models/subband_tools/pqmf.py ADDED
@@ -0,0 +1,136 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ '''
2
+ @File : subband_util.py
3
+ @Contact : liu.8948@buckeyemail.osu.edu
4
+ @License : (C)Copyright 2020-2021
5
+ @Modify Time @Author @Version @Desciption
6
+ ------------ ------- -------- -----------
7
+ 2020/4/3 4:54 PM Haohe Liu 1.0 None
8
+ '''
9
+
10
+ import torch
11
+ import torch.nn.functional as F
12
+ import torch.nn as nn
13
+ import numpy as np
14
+ import os.path as op
15
+ from scipy.io import loadmat
16
+
17
+
18
+ def load_mat2numpy(fname=""):
19
+ '''
20
+ Args:
21
+ fname: pth to mat
22
+ type:
23
+ Returns: dic object
24
+ '''
25
+ if len(fname) == 0:
26
+ return None
27
+ else:
28
+ return loadmat(fname)
29
+
30
+
31
+ class PQMF(nn.Module):
32
+ def __init__(self, N, M, project_root):
33
+ super().__init__()
34
+ self.N = N # nsubband
35
+ self.M = M # nfilter
36
+ try:
37
+ assert (N, M) in [(8, 64), (4, 64), (2, 64)]
38
+ except:
39
+ print("Warning:", N, "subbandand ", M, " filter is not supported")
40
+ self.pad_samples = 64
41
+ self.name = str(N) + "_" + str(M) + ".mat"
42
+ self.ana_conv_filter = nn.Conv1d(
43
+ 1, out_channels=N, kernel_size=M, stride=N, bias=False
44
+ )
45
+ data = load_mat2numpy(op.join(project_root, "f_" + self.name))
46
+ data = data['f'].astype(np.float32) / N
47
+ data = np.flipud(data.T).T
48
+ data = np.reshape(data, (N, 1, M)).copy()
49
+ dict_new = self.ana_conv_filter.state_dict().copy()
50
+ dict_new['weight'] = torch.from_numpy(data)
51
+ self.ana_pad = nn.ConstantPad1d((M - N, 0), 0)
52
+ self.ana_conv_filter.load_state_dict(dict_new)
53
+
54
+ self.syn_pad = nn.ConstantPad1d((0, M // N - 1), 0)
55
+ self.syn_conv_filter = nn.Conv1d(
56
+ N, out_channels=N, kernel_size=M // N, stride=1, bias=False
57
+ )
58
+ gk = load_mat2numpy(op.join(project_root, "h_" + self.name))
59
+ gk = gk['h'].astype(np.float32)
60
+ gk = np.transpose(np.reshape(gk, (N, M // N, N)), (1, 0, 2)) * N
61
+ gk = np.transpose(gk[::-1, :, :], (2, 1, 0)).copy()
62
+ dict_new = self.syn_conv_filter.state_dict().copy()
63
+ dict_new['weight'] = torch.from_numpy(gk)
64
+ self.syn_conv_filter.load_state_dict(dict_new)
65
+
66
+ for param in self.parameters():
67
+ param.requires_grad = False
68
+
69
+ def __analysis_channel(self, inputs):
70
+ return self.ana_conv_filter(self.ana_pad(inputs))
71
+
72
+ def __systhesis_channel(self, inputs):
73
+ ret = self.syn_conv_filter(self.syn_pad(inputs)).permute(0, 2, 1)
74
+ return torch.reshape(ret, (ret.shape[0], 1, -1))
75
+
76
+ def analysis(self, inputs):
77
+ '''
78
+ :param inputs: [batchsize,channel,raw_wav],value:[0,1]
79
+ :return:
80
+ '''
81
+ inputs = F.pad(inputs, ((0, self.pad_samples)))
82
+ ret = None
83
+ for i in range(inputs.size()[1]): # channels
84
+ if ret is None:
85
+ ret = self.__analysis_channel(inputs[:, i : i + 1, :])
86
+ else:
87
+ ret = torch.cat(
88
+ (ret, self.__analysis_channel(inputs[:, i : i + 1, :])), dim=1
89
+ )
90
+ return ret
91
+
92
+ def synthesis(self, data):
93
+ '''
94
+ :param data: [batchsize,self.N*K,raw_wav_sub],value:[0,1]
95
+ :return:
96
+ '''
97
+ ret = None
98
+ # data = F.pad(data,((0,self.pad_samples//self.N)))
99
+ for i in range(data.size()[1]): # channels
100
+ if i % self.N == 0:
101
+ if ret is None:
102
+ ret = self.__systhesis_channel(data[:, i : i + self.N, :])
103
+ else:
104
+ new = self.__systhesis_channel(data[:, i : i + self.N, :])
105
+ ret = torch.cat((ret, new), dim=1)
106
+ ret = ret[..., : -self.pad_samples]
107
+ return ret
108
+
109
+ def forward(self, inputs):
110
+ return self.ana_conv_filter(self.ana_pad(inputs))
111
+
112
+
113
+ if __name__ == "__main__":
114
+ import torch
115
+ import numpy as np
116
+ import matplotlib.pyplot as plt
117
+ from tools.file.wav import *
118
+
119
+ pqmf = PQMF(N=4, M=64, project_root="/Users/admin/Documents/projects")
120
+
121
+ rs = np.random.RandomState(0)
122
+ x = torch.tensor(rs.rand(4, 2, 32000), dtype=torch.float32)
123
+
124
+ a1 = pqmf.analysis(x)
125
+ a2 = pqmf.synthesis(a1)
126
+
127
+ print(a2.size(), x.size())
128
+
129
+ plt.subplot(211)
130
+ plt.plot(x[0, 0, -500:])
131
+ plt.subplot(212)
132
+ plt.plot(a2[0, 0, -500:])
133
+ plt.plot(x[0, 0, -500:] - a2[0, 0, -500:])
134
+ plt.show()
135
+
136
+ print(torch.sum(torch.abs(x[...] - a2[...])))
bytesep/models/unet.py ADDED
@@ -0,0 +1,532 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ from typing import Dict, List, NoReturn, Tuple
3
+
4
+ import matplotlib.pyplot as plt
5
+ import numpy as np
6
+ import pytorch_lightning as pl
7
+ import torch
8
+ import torch.nn as nn
9
+ import torch.nn.functional as F
10
+ import torch.optim as optim
11
+ from torch.optim.lr_scheduler import LambdaLR
12
+ from torchlibrosa.stft import ISTFT, STFT, magphase
13
+
14
+ from bytesep.models.pytorch_modules import Base, Subband, act, init_bn, init_layer
15
+
16
+
17
+ class ConvBlock(nn.Module):
18
+ def __init__(
19
+ self,
20
+ in_channels: int,
21
+ out_channels: int,
22
+ kernel_size: Tuple,
23
+ activation: str,
24
+ momentum: float,
25
+ ):
26
+ r"""Convolutional block."""
27
+ super(ConvBlock, self).__init__()
28
+
29
+ self.activation = activation
30
+ padding = (kernel_size[0] // 2, kernel_size[1] // 2)
31
+
32
+ self.conv1 = nn.Conv2d(
33
+ in_channels=in_channels,
34
+ out_channels=out_channels,
35
+ kernel_size=kernel_size,
36
+ stride=(1, 1),
37
+ dilation=(1, 1),
38
+ padding=padding,
39
+ bias=False,
40
+ )
41
+
42
+ self.bn1 = nn.BatchNorm2d(out_channels, momentum=momentum)
43
+
44
+ self.conv2 = nn.Conv2d(
45
+ in_channels=out_channels,
46
+ out_channels=out_channels,
47
+ kernel_size=kernel_size,
48
+ stride=(1, 1),
49
+ dilation=(1, 1),
50
+ padding=padding,
51
+ bias=False,
52
+ )
53
+
54
+ self.bn2 = nn.BatchNorm2d(out_channels, momentum=momentum)
55
+
56
+ self.init_weights()
57
+
58
+ def init_weights(self) -> NoReturn:
59
+ r"""Initialize weights."""
60
+ init_layer(self.conv1)
61
+ init_layer(self.conv2)
62
+ init_bn(self.bn1)
63
+ init_bn(self.bn2)
64
+
65
+ def forward(self, input_tensor: torch.Tensor) -> torch.Tensor:
66
+ r"""Forward data into the module.
67
+
68
+ Args:
69
+ input_tensor: (batch_size, in_feature_maps, time_steps, freq_bins)
70
+
71
+ Returns:
72
+ output_tensor: (batch_size, out_feature_maps, time_steps, freq_bins)
73
+ """
74
+ x = act(self.bn1(self.conv1(input_tensor)), self.activation)
75
+ x = act(self.bn2(self.conv2(x)), self.activation)
76
+ output_tensor = x
77
+
78
+ return output_tensor
79
+
80
+
81
+ class EncoderBlock(nn.Module):
82
+ def __init__(
83
+ self,
84
+ in_channels: int,
85
+ out_channels: int,
86
+ kernel_size: Tuple,
87
+ downsample: Tuple,
88
+ activation: str,
89
+ momentum: float,
90
+ ):
91
+ r"""Encoder block."""
92
+ super(EncoderBlock, self).__init__()
93
+
94
+ self.conv_block = ConvBlock(
95
+ in_channels, out_channels, kernel_size, activation, momentum
96
+ )
97
+ self.downsample = downsample
98
+
99
+ def forward(self, input_tensor: torch.Tensor) -> torch.Tensor:
100
+ r"""Forward data into the module.
101
+
102
+ Args:
103
+ input_tensor: (batch_size, in_feature_maps, time_steps, freq_bins)
104
+
105
+ Returns:
106
+ encoder_pool: (batch_size, out_feature_maps, downsampled_time_steps, downsampled_freq_bins)
107
+ encoder: (batch_size, out_feature_maps, time_steps, freq_bins)
108
+ """
109
+ encoder_tensor = self.conv_block(input_tensor)
110
+ # encoder: (batch_size, out_feature_maps, time_steps, freq_bins)
111
+
112
+ encoder_pool = F.avg_pool2d(encoder_tensor, kernel_size=self.downsample)
113
+ # encoder_pool: (batch_size, out_feature_maps, downsampled_time_steps, downsampled_freq_bins)
114
+
115
+ return encoder_pool, encoder_tensor
116
+
117
+
118
+ class DecoderBlock(nn.Module):
119
+ def __init__(
120
+ self,
121
+ in_channels: int,
122
+ out_channels: int,
123
+ kernel_size: Tuple,
124
+ upsample: Tuple,
125
+ activation: str,
126
+ momentum: float,
127
+ ):
128
+ r"""Decoder block."""
129
+ super(DecoderBlock, self).__init__()
130
+
131
+ self.kernel_size = kernel_size
132
+ self.stride = upsample
133
+ self.activation = activation
134
+
135
+ self.conv1 = torch.nn.ConvTranspose2d(
136
+ in_channels=in_channels,
137
+ out_channels=out_channels,
138
+ kernel_size=self.stride,
139
+ stride=self.stride,
140
+ padding=(0, 0),
141
+ bias=False,
142
+ dilation=(1, 1),
143
+ )
144
+
145
+ self.bn1 = nn.BatchNorm2d(out_channels, momentum=momentum)
146
+
147
+ self.conv_block2 = ConvBlock(
148
+ out_channels * 2, out_channels, kernel_size, activation, momentum
149
+ )
150
+
151
+ self.init_weights()
152
+
153
+ def init_weights(self):
154
+ r"""Initialize weights."""
155
+ init_layer(self.conv1)
156
+ init_bn(self.bn1)
157
+
158
+ def forward(
159
+ self, input_tensor: torch.Tensor, concat_tensor: torch.Tensor
160
+ ) -> torch.Tensor:
161
+ r"""Forward data into the module.
162
+
163
+ Args:
164
+ torch_tensor: (batch_size, in_feature_maps, downsampled_time_steps, downsampled_freq_bins)
165
+ concat_tensor: (batch_size, in_feature_maps, time_steps, freq_bins)
166
+
167
+ Returns:
168
+ output_tensor: (batch_size, out_feature_maps, time_steps, freq_bins)
169
+ """
170
+ x = act(self.bn1(self.conv1(input_tensor)), self.activation)
171
+ # (batch_size, in_feature_maps, time_steps, freq_bins)
172
+
173
+ x = torch.cat((x, concat_tensor), dim=1)
174
+ # (batch_size, in_feature_maps * 2, time_steps, freq_bins)
175
+
176
+ output_tensor = self.conv_block2(x)
177
+ # output_tensor: (batch_size, out_feature_maps, time_steps, freq_bins)
178
+
179
+ return output_tensor
180
+
181
+
182
+ class UNet(nn.Module, Base):
183
+ def __init__(self, input_channels: int, target_sources_num: int):
184
+ r"""UNet."""
185
+ super(UNet, self).__init__()
186
+
187
+ self.input_channels = input_channels
188
+ self.target_sources_num = target_sources_num
189
+
190
+ window_size = 2048
191
+ hop_size = 441
192
+ center = True
193
+ pad_mode = "reflect"
194
+ window = "hann"
195
+ activation = "leaky_relu"
196
+ momentum = 0.01
197
+
198
+ self.subbands_num = 1
199
+
200
+ assert (
201
+ self.subbands_num == 1
202
+ ), "Using subbands_num > 1 on spectrogram \
203
+ will lead to unexpected performance sometimes. Suggest to use \
204
+ subband method on waveform."
205
+
206
+ self.K = 3 # outputs: |M|, cos∠M, sin∠M
207
+ self.downsample_ratio = 2 ** 6 # This number equals 2^{#encoder_blcoks}
208
+
209
+ self.stft = STFT(
210
+ n_fft=window_size,
211
+ hop_length=hop_size,
212
+ win_length=window_size,
213
+ window=window,
214
+ center=center,
215
+ pad_mode=pad_mode,
216
+ freeze_parameters=True,
217
+ )
218
+
219
+ self.istft = ISTFT(
220
+ n_fft=window_size,
221
+ hop_length=hop_size,
222
+ win_length=window_size,
223
+ window=window,
224
+ center=center,
225
+ pad_mode=pad_mode,
226
+ freeze_parameters=True,
227
+ )
228
+
229
+ self.bn0 = nn.BatchNorm2d(window_size // 2 + 1, momentum=momentum)
230
+
231
+ self.subband = Subband(subbands_num=self.subbands_num)
232
+
233
+ self.encoder_block1 = EncoderBlock(
234
+ in_channels=input_channels * self.subbands_num,
235
+ out_channels=32,
236
+ kernel_size=(3, 3),
237
+ downsample=(2, 2),
238
+ activation=activation,
239
+ momentum=momentum,
240
+ )
241
+ self.encoder_block2 = EncoderBlock(
242
+ in_channels=32,
243
+ out_channels=64,
244
+ kernel_size=(3, 3),
245
+ downsample=(2, 2),
246
+ activation=activation,
247
+ momentum=momentum,
248
+ )
249
+ self.encoder_block3 = EncoderBlock(
250
+ in_channels=64,
251
+ out_channels=128,
252
+ kernel_size=(3, 3),
253
+ downsample=(2, 2),
254
+ activation=activation,
255
+ momentum=momentum,
256
+ )
257
+ self.encoder_block4 = EncoderBlock(
258
+ in_channels=128,
259
+ out_channels=256,
260
+ kernel_size=(3, 3),
261
+ downsample=(2, 2),
262
+ activation=activation,
263
+ momentum=momentum,
264
+ )
265
+ self.encoder_block5 = EncoderBlock(
266
+ in_channels=256,
267
+ out_channels=384,
268
+ kernel_size=(3, 3),
269
+ downsample=(2, 2),
270
+ activation=activation,
271
+ momentum=momentum,
272
+ )
273
+ self.encoder_block6 = EncoderBlock(
274
+ in_channels=384,
275
+ out_channels=384,
276
+ kernel_size=(3, 3),
277
+ downsample=(2, 2),
278
+ activation=activation,
279
+ momentum=momentum,
280
+ )
281
+ self.conv_block7 = ConvBlock(
282
+ in_channels=384,
283
+ out_channels=384,
284
+ kernel_size=(3, 3),
285
+ activation=activation,
286
+ momentum=momentum,
287
+ )
288
+ self.decoder_block1 = DecoderBlock(
289
+ in_channels=384,
290
+ out_channels=384,
291
+ kernel_size=(3, 3),
292
+ upsample=(2, 2),
293
+ activation=activation,
294
+ momentum=momentum,
295
+ )
296
+ self.decoder_block2 = DecoderBlock(
297
+ in_channels=384,
298
+ out_channels=384,
299
+ kernel_size=(3, 3),
300
+ upsample=(2, 2),
301
+ activation=activation,
302
+ momentum=momentum,
303
+ )
304
+ self.decoder_block3 = DecoderBlock(
305
+ in_channels=384,
306
+ out_channels=256,
307
+ kernel_size=(3, 3),
308
+ upsample=(2, 2),
309
+ activation=activation,
310
+ momentum=momentum,
311
+ )
312
+ self.decoder_block4 = DecoderBlock(
313
+ in_channels=256,
314
+ out_channels=128,
315
+ kernel_size=(3, 3),
316
+ upsample=(2, 2),
317
+ activation=activation,
318
+ momentum=momentum,
319
+ )
320
+ self.decoder_block5 = DecoderBlock(
321
+ in_channels=128,
322
+ out_channels=64,
323
+ kernel_size=(3, 3),
324
+ upsample=(2, 2),
325
+ activation=activation,
326
+ momentum=momentum,
327
+ )
328
+
329
+ self.decoder_block6 = DecoderBlock(
330
+ in_channels=64,
331
+ out_channels=32,
332
+ kernel_size=(3, 3),
333
+ upsample=(2, 2),
334
+ activation=activation,
335
+ momentum=momentum,
336
+ )
337
+
338
+ self.after_conv_block1 = ConvBlock(
339
+ in_channels=32,
340
+ out_channels=32,
341
+ kernel_size=(3, 3),
342
+ activation=activation,
343
+ momentum=momentum,
344
+ )
345
+
346
+ self.after_conv2 = nn.Conv2d(
347
+ in_channels=32,
348
+ out_channels=target_sources_num
349
+ * input_channels
350
+ * self.K
351
+ * self.subbands_num,
352
+ kernel_size=(1, 1),
353
+ stride=(1, 1),
354
+ padding=(0, 0),
355
+ bias=True,
356
+ )
357
+
358
+ self.init_weights()
359
+
360
+ def init_weights(self):
361
+ r"""Initialize weights."""
362
+ init_bn(self.bn0)
363
+ init_layer(self.after_conv2)
364
+
365
+ def feature_maps_to_wav(
366
+ self,
367
+ input_tensor: torch.Tensor,
368
+ sp: torch.Tensor,
369
+ sin_in: torch.Tensor,
370
+ cos_in: torch.Tensor,
371
+ audio_length: int,
372
+ ) -> torch.Tensor:
373
+ r"""Convert feature maps to waveform.
374
+
375
+ Args:
376
+ input_tensor: (batch_size, target_sources_num * input_channels * self.K, time_steps, freq_bins)
377
+ sp: (batch_size, target_sources_num * input_channels, time_steps, freq_bins)
378
+ sin_in: (batch_size, target_sources_num * input_channels, time_steps, freq_bins)
379
+ cos_in: (batch_size, target_sources_num * input_channels, time_steps, freq_bins)
380
+
381
+ Outputs:
382
+ waveform: (batch_size, target_sources_num * input_channels, segment_samples)
383
+ """
384
+ batch_size, _, time_steps, freq_bins = input_tensor.shape
385
+
386
+ x = input_tensor.reshape(
387
+ batch_size,
388
+ self.target_sources_num,
389
+ self.input_channels,
390
+ self.K,
391
+ time_steps,
392
+ freq_bins,
393
+ )
394
+ # x: (batch_size, target_sources_num, input_channles, K, time_steps, freq_bins)
395
+
396
+ mask_mag = torch.sigmoid(x[:, :, :, 0, :, :])
397
+ _mask_real = torch.tanh(x[:, :, :, 1, :, :])
398
+ _mask_imag = torch.tanh(x[:, :, :, 2, :, :])
399
+ _, mask_cos, mask_sin = magphase(_mask_real, _mask_imag)
400
+ # mask_cos, mask_sin: (batch_size, target_sources_num, input_channles, time_steps, freq_bins)
401
+
402
+ # Y = |Y|cos∠Y + j|Y|sin∠Y
403
+ # = |Y|cos(∠X + ∠M) + j|Y|sin(∠X + ∠M)
404
+ # = |Y|(cos∠X cos∠M - sin∠X sin∠M) + j|Y|(sin∠X cos∠M + cos∠X sin∠M)
405
+ out_cos = (
406
+ cos_in[:, None, :, :, :] * mask_cos - sin_in[:, None, :, :, :] * mask_sin
407
+ )
408
+ out_sin = (
409
+ sin_in[:, None, :, :, :] * mask_cos + cos_in[:, None, :, :, :] * mask_sin
410
+ )
411
+ # out_cos: (batch_size, target_sources_num, input_channles, time_steps, freq_bins)
412
+ # out_sin: (batch_size, target_sources_num, input_channles, time_steps, freq_bins)
413
+
414
+ # Calculate |Y|.
415
+ out_mag = F.relu_(sp[:, None, :, :, :] * mask_mag)
416
+ # out_mag: (batch_size, target_sources_num, input_channles, time_steps, freq_bins)
417
+
418
+ # Calculate Y_{real} and Y_{imag} for ISTFT.
419
+ out_real = out_mag * out_cos
420
+ out_imag = out_mag * out_sin
421
+ # out_real, out_imag: (batch_size, target_sources_num, input_channles, time_steps, freq_bins)
422
+
423
+ # Reformat shape to (n, 1, time_steps, freq_bins) for ISTFT.
424
+ shape = (
425
+ batch_size * self.target_sources_num * self.input_channels,
426
+ 1,
427
+ time_steps,
428
+ freq_bins,
429
+ )
430
+ out_real = out_real.reshape(shape)
431
+ out_imag = out_imag.reshape(shape)
432
+
433
+ # ISTFT.
434
+ x = self.istft(out_real, out_imag, audio_length)
435
+ # (batch_size * target_sources_num * input_channels, segments_num)
436
+
437
+ # Reshape.
438
+ waveform = x.reshape(
439
+ batch_size, self.target_sources_num * self.input_channels, audio_length
440
+ )
441
+ # (batch_size, target_sources_num * input_channels, segments_num)
442
+
443
+ return waveform
444
+
445
+ def forward(self, input_dict: Dict) -> Dict:
446
+ r"""Forward data into the module.
447
+
448
+ Args:
449
+ input_dict: dict, e.g., {
450
+ waveform: (batch_size, input_channels, segment_samples),
451
+ ...,
452
+ }
453
+
454
+ Outputs:
455
+ output_dict: dict, e.g., {
456
+ 'waveform': (batch_size, input_channels, segment_samples),
457
+ ...,
458
+ }
459
+ """
460
+ mixtures = input_dict['waveform']
461
+ # (batch_size, input_channels, segment_samples)
462
+
463
+ mag, cos_in, sin_in = self.wav_to_spectrogram_phase(mixtures)
464
+ # mag, cos_in, sin_in: (batch_size, input_channels, time_steps, freq_bins)
465
+
466
+ # Batch normalize on individual frequency bins.
467
+ x = mag.transpose(1, 3)
468
+ x = self.bn0(x)
469
+ x = x.transpose(1, 3)
470
+ # x: (batch_size, input_channels, time_steps, freq_bins)
471
+
472
+ # Pad spectrogram to be evenly divided by downsample ratio.
473
+ origin_len = x.shape[2]
474
+ pad_len = (
475
+ int(np.ceil(x.shape[2] / self.downsample_ratio)) * self.downsample_ratio
476
+ - origin_len
477
+ )
478
+ x = F.pad(x, pad=(0, 0, 0, pad_len))
479
+ # x: (batch_size, input_channels, padded_time_steps, freq_bins)
480
+
481
+ # Let frequency bins be evenly divided by 2, e.g., 1025 -> 1024
482
+ x = x[..., 0 : x.shape[-1] - 1] # (bs, input_channels, T, F)
483
+
484
+ if self.subbands_num > 1:
485
+ x = self.subband.analysis(x)
486
+ # (bs, input_channels, T, F'), where F' = F // subbands_num
487
+
488
+ # UNet
489
+ (x1_pool, x1) = self.encoder_block1(x) # x1_pool: (bs, 32, T / 2, F' / 2)
490
+ (x2_pool, x2) = self.encoder_block2(x1_pool) # x2_pool: (bs, 64, T / 4, F' / 4)
491
+ (x3_pool, x3) = self.encoder_block3(
492
+ x2_pool
493
+ ) # x3_pool: (bs, 128, T / 8, F' / 8)
494
+ (x4_pool, x4) = self.encoder_block4(
495
+ x3_pool
496
+ ) # x4_pool: (bs, 256, T / 16, F' / 16)
497
+ (x5_pool, x5) = self.encoder_block5(
498
+ x4_pool
499
+ ) # x5_pool: (bs, 384, T / 32, F' / 32)
500
+ (x6_pool, x6) = self.encoder_block6(
501
+ x5_pool
502
+ ) # x6_pool: (bs, 384, T / 64, F' / 64)
503
+ x_center = self.conv_block7(x6_pool) # (bs, 384, T / 64, F' / 64)
504
+ x7 = self.decoder_block1(x_center, x6) # (bs, 384, T / 32, F' / 32)
505
+ x8 = self.decoder_block2(x7, x5) # (bs, 384, T / 16, F' / 16)
506
+ x9 = self.decoder_block3(x8, x4) # (bs, 256, T / 8, F' / 8)
507
+ x10 = self.decoder_block4(x9, x3) # (bs, 128, T / 4, F' / 4)
508
+ x11 = self.decoder_block5(x10, x2) # (bs, 64, T / 2, F' / 2)
509
+ x12 = self.decoder_block6(x11, x1) # (bs, 32, T, F')
510
+ x = self.after_conv_block1(x12) # (bs, 32, T, F')
511
+
512
+ x = self.after_conv2(x)
513
+ # (batch_size, target_sources_num * input_channles * self.K * subbands_num, T, F')
514
+
515
+ if self.subbands_num > 1:
516
+ x = self.subband.synthesis(x)
517
+ # (batch_size, target_sources_num * input_channles * self.K, T, F)
518
+
519
+ # Recover shape
520
+ x = F.pad(x, pad=(0, 1)) # Pad frequency, e.g., 1024 -> 1025.
521
+
522
+ x = x[:, :, 0:origin_len, :]
523
+ # (batch_size, target_sources_num * input_channles * self.K, T, F)
524
+
525
+ audio_length = mixtures.shape[2]
526
+
527
+ separated_audio = self.feature_maps_to_wav(x, mag, sin_in, cos_in, audio_length)
528
+ # separated_audio: (batch_size, target_sources_num * input_channels, segments_num)
529
+
530
+ output_dict = {'waveform': separated_audio}
531
+
532
+ return output_dict
bytesep/models/unet_subbandtime.py ADDED
@@ -0,0 +1,389 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict
2
+
3
+ import numpy as np
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+ from torchlibrosa.stft import ISTFT, STFT, magphase
8
+
9
+ from bytesep.models.pytorch_modules import Base, init_bn, init_layer
10
+ from bytesep.models.subband_tools.pqmf import PQMF
11
+ from bytesep.models.unet import ConvBlock, DecoderBlock, EncoderBlock
12
+
13
+
14
+ class UNetSubbandTime(nn.Module, Base):
15
+ def __init__(self, input_channels: int, target_sources_num: int):
16
+ r"""Subband waveform UNet."""
17
+ super(UNetSubbandTime, self).__init__()
18
+
19
+ self.input_channels = input_channels
20
+ self.target_sources_num = target_sources_num
21
+
22
+ window_size = 512 # 2048 // 4
23
+ hop_size = 110 # 441 // 4
24
+ center = True
25
+ pad_mode = "reflect"
26
+ window = "hann"
27
+ activation = "leaky_relu"
28
+ momentum = 0.01
29
+
30
+ self.subbands_num = 4
31
+ self.K = 3 # outputs: |M|, cos∠M, sin∠M
32
+
33
+ self.downsample_ratio = 2 ** 6 # This number equals 2^{#encoder_blcoks}
34
+
35
+ self.pqmf = PQMF(
36
+ N=self.subbands_num,
37
+ M=64,
38
+ project_root='bytesep/models/subband_tools/filters',
39
+ )
40
+
41
+ self.stft = STFT(
42
+ n_fft=window_size,
43
+ hop_length=hop_size,
44
+ win_length=window_size,
45
+ window=window,
46
+ center=center,
47
+ pad_mode=pad_mode,
48
+ freeze_parameters=True,
49
+ )
50
+
51
+ self.istft = ISTFT(
52
+ n_fft=window_size,
53
+ hop_length=hop_size,
54
+ win_length=window_size,
55
+ window=window,
56
+ center=center,
57
+ pad_mode=pad_mode,
58
+ freeze_parameters=True,
59
+ )
60
+
61
+ self.bn0 = nn.BatchNorm2d(window_size // 2 + 1, momentum=momentum)
62
+
63
+ self.encoder_block1 = EncoderBlock(
64
+ in_channels=input_channels * self.subbands_num,
65
+ out_channels=32,
66
+ kernel_size=(3, 3),
67
+ downsample=(2, 2),
68
+ activation=activation,
69
+ momentum=momentum,
70
+ )
71
+ self.encoder_block2 = EncoderBlock(
72
+ in_channels=32,
73
+ out_channels=64,
74
+ kernel_size=(3, 3),
75
+ downsample=(2, 2),
76
+ activation=activation,
77
+ momentum=momentum,
78
+ )
79
+ self.encoder_block3 = EncoderBlock(
80
+ in_channels=64,
81
+ out_channels=128,
82
+ kernel_size=(3, 3),
83
+ downsample=(2, 2),
84
+ activation=activation,
85
+ momentum=momentum,
86
+ )
87
+ self.encoder_block4 = EncoderBlock(
88
+ in_channels=128,
89
+ out_channels=256,
90
+ kernel_size=(3, 3),
91
+ downsample=(2, 2),
92
+ activation=activation,
93
+ momentum=momentum,
94
+ )
95
+ self.encoder_block5 = EncoderBlock(
96
+ in_channels=256,
97
+ out_channels=384,
98
+ kernel_size=(3, 3),
99
+ downsample=(2, 2),
100
+ activation=activation,
101
+ momentum=momentum,
102
+ )
103
+ self.encoder_block6 = EncoderBlock(
104
+ in_channels=384,
105
+ out_channels=384,
106
+ kernel_size=(3, 3),
107
+ downsample=(2, 2),
108
+ activation=activation,
109
+ momentum=momentum,
110
+ )
111
+ self.conv_block7 = ConvBlock(
112
+ in_channels=384,
113
+ out_channels=384,
114
+ kernel_size=(3, 3),
115
+ activation=activation,
116
+ momentum=momentum,
117
+ )
118
+ self.decoder_block1 = DecoderBlock(
119
+ in_channels=384,
120
+ out_channels=384,
121
+ kernel_size=(3, 3),
122
+ upsample=(2, 2),
123
+ activation=activation,
124
+ momentum=momentum,
125
+ )
126
+ self.decoder_block2 = DecoderBlock(
127
+ in_channels=384,
128
+ out_channels=384,
129
+ kernel_size=(3, 3),
130
+ upsample=(2, 2),
131
+ activation=activation,
132
+ momentum=momentum,
133
+ )
134
+ self.decoder_block3 = DecoderBlock(
135
+ in_channels=384,
136
+ out_channels=256,
137
+ kernel_size=(3, 3),
138
+ upsample=(2, 2),
139
+ activation=activation,
140
+ momentum=momentum,
141
+ )
142
+ self.decoder_block4 = DecoderBlock(
143
+ in_channels=256,
144
+ out_channels=128,
145
+ kernel_size=(3, 3),
146
+ upsample=(2, 2),
147
+ activation=activation,
148
+ momentum=momentum,
149
+ )
150
+ self.decoder_block5 = DecoderBlock(
151
+ in_channels=128,
152
+ out_channels=64,
153
+ kernel_size=(3, 3),
154
+ upsample=(2, 2),
155
+ activation=activation,
156
+ momentum=momentum,
157
+ )
158
+
159
+ self.decoder_block6 = DecoderBlock(
160
+ in_channels=64,
161
+ out_channels=32,
162
+ kernel_size=(3, 3),
163
+ upsample=(2, 2),
164
+ activation=activation,
165
+ momentum=momentum,
166
+ )
167
+
168
+ self.after_conv_block1 = ConvBlock(
169
+ in_channels=32,
170
+ out_channels=32,
171
+ kernel_size=(3, 3),
172
+ activation=activation,
173
+ momentum=momentum,
174
+ )
175
+
176
+ self.after_conv2 = nn.Conv2d(
177
+ in_channels=32,
178
+ out_channels=target_sources_num
179
+ * input_channels
180
+ * self.K
181
+ * self.subbands_num,
182
+ kernel_size=(1, 1),
183
+ stride=(1, 1),
184
+ padding=(0, 0),
185
+ bias=True,
186
+ )
187
+
188
+ self.init_weights()
189
+
190
+ def init_weights(self):
191
+ r"""Initialize weights."""
192
+ init_bn(self.bn0)
193
+ init_layer(self.after_conv2)
194
+
195
+ def feature_maps_to_wav(
196
+ self,
197
+ input_tensor: torch.Tensor,
198
+ sp: torch.Tensor,
199
+ sin_in: torch.Tensor,
200
+ cos_in: torch.Tensor,
201
+ audio_length: int,
202
+ ) -> torch.Tensor:
203
+ r"""Convert feature maps to waveform.
204
+
205
+ Args:
206
+ input_tensor: (batch_size, target_sources_num * input_channels * self.K, time_steps, freq_bins)
207
+ sp: (batch_size, target_sources_num * input_channels, time_steps, freq_bins)
208
+ sin_in: (batch_size, target_sources_num * input_channels, time_steps, freq_bins)
209
+ cos_in: (batch_size, target_sources_num * input_channels, time_steps, freq_bins)
210
+
211
+ Outputs:
212
+ waveform: (batch_size, target_sources_num * input_channels, segment_samples)
213
+ """
214
+ batch_size, _, time_steps, freq_bins = input_tensor.shape
215
+
216
+ x = input_tensor.reshape(
217
+ batch_size,
218
+ self.target_sources_num,
219
+ self.input_channels,
220
+ self.K,
221
+ time_steps,
222
+ freq_bins,
223
+ )
224
+ # x: (batch_size, target_sources_num, input_channles, K, time_steps, freq_bins)
225
+
226
+ mask_mag = torch.sigmoid(x[:, :, :, 0, :, :])
227
+ _mask_real = torch.tanh(x[:, :, :, 1, :, :])
228
+ _mask_imag = torch.tanh(x[:, :, :, 2, :, :])
229
+ _, mask_cos, mask_sin = magphase(_mask_real, _mask_imag)
230
+ # mask_cos, mask_sin: (batch_size, target_sources_num, input_channles, time_steps, freq_bins)
231
+
232
+ # Y = |Y|cos∠Y + j|Y|sin∠Y
233
+ # = |Y|cos(∠X + ∠M) + j|Y|sin(∠X + ∠M)
234
+ # = |Y|(cos∠X cos∠M - sin∠X sin∠M) + j|Y|(sin∠X cos∠M + cos∠X sin∠M)
235
+ out_cos = (
236
+ cos_in[:, None, :, :, :] * mask_cos - sin_in[:, None, :, :, :] * mask_sin
237
+ )
238
+ out_sin = (
239
+ sin_in[:, None, :, :, :] * mask_cos + cos_in[:, None, :, :, :] * mask_sin
240
+ )
241
+ # out_cos: (batch_size, target_sources_num, input_channles, time_steps, freq_bins)
242
+ # out_sin: (batch_size, target_sources_num, input_channles, time_steps, freq_bins)
243
+
244
+ # Calculate |Y|.
245
+ out_mag = F.relu_(sp[:, None, :, :, :] * mask_mag)
246
+ # out_mag: (batch_size, target_sources_num, input_channles, time_steps, freq_bins)
247
+
248
+ # Calculate Y_{real} and Y_{imag} for ISTFT.
249
+ out_real = out_mag * out_cos
250
+ out_imag = out_mag * out_sin
251
+ # out_real, out_imag: (batch_size, target_sources_num, input_channles, time_steps, freq_bins)
252
+
253
+ # Reformat shape to (n, 1, time_steps, freq_bins) for ISTFT.
254
+ shape = (
255
+ batch_size * self.target_sources_num * self.input_channels,
256
+ 1,
257
+ time_steps,
258
+ freq_bins,
259
+ )
260
+ out_real = out_real.reshape(shape)
261
+ out_imag = out_imag.reshape(shape)
262
+
263
+ # ISTFT.
264
+ x = self.istft(out_real, out_imag, audio_length)
265
+ # (batch_size * target_sources_num * input_channels, segments_num)
266
+
267
+ # Reshape.
268
+ waveform = x.reshape(
269
+ batch_size, self.target_sources_num * self.input_channels, audio_length
270
+ )
271
+ # (batch_size, target_sources_num * input_channels, segments_num)
272
+
273
+ return waveform
274
+
275
+ def forward(self, input_dict: Dict) -> Dict:
276
+ """Forward data into the module.
277
+
278
+ Args:
279
+ input_dict: dict, e.g., {
280
+ waveform: (batch_size, input_channels, segment_samples),
281
+ ...,
282
+ }
283
+
284
+ Outputs:
285
+ output_dict: dict, e.g., {
286
+ 'waveform': (batch_size, input_channels, segment_samples),
287
+ ...,
288
+ }
289
+ """
290
+ mixtures = input_dict['waveform']
291
+ # (batch_size, input_channels, segment_samples)
292
+
293
+ if self.subbands_num > 1:
294
+ subband_x = self.pqmf.analysis(mixtures)
295
+ # -- subband_x: (batch_size, input_channels * subbands_num, segment_samples)
296
+ # -- subband_x: (batch_size, subbands_num * input_channels, segment_samples)
297
+ else:
298
+ subband_x = mixtures
299
+
300
+ # from IPython import embed; embed(using=False); os._exit(0)
301
+ # import soundfile
302
+ # soundfile.write(file='_zz.wav', data=subband_x.data.cpu().numpy()[0, 2], samplerate=11025)
303
+
304
+ mag, cos_in, sin_in = self.wav_to_spectrogram_phase(subband_x)
305
+ # mag, cos_in, sin_in: (batch_size, input_channels * subbands_num, time_steps, freq_bins)
306
+
307
+ # Batch normalize on individual frequency bins.
308
+ x = mag.transpose(1, 3)
309
+ x = self.bn0(x)
310
+ x = x.transpose(1, 3)
311
+ # (batch_size, input_channels * subbands_num, time_steps, freq_bins)
312
+
313
+ # Pad spectrogram to be evenly divided by downsample ratio.
314
+ origin_len = x.shape[2]
315
+ pad_len = (
316
+ int(np.ceil(x.shape[2] / self.downsample_ratio)) * self.downsample_ratio
317
+ - origin_len
318
+ )
319
+ x = F.pad(x, pad=(0, 0, 0, pad_len))
320
+ # x: (batch_size, input_channels * subbands_num, padded_time_steps, freq_bins)
321
+
322
+ # Let frequency bins be evenly divided by 2, e.g., 257 -> 256
323
+ x = x[..., 0 : x.shape[-1] - 1] # (bs, input_channels, T, F)
324
+ # x: (batch_size, input_channels * subbands_num, padded_time_steps, freq_bins)
325
+
326
+ # UNet
327
+ (x1_pool, x1) = self.encoder_block1(x) # x1_pool: (bs, 32, T / 2, F' / 2)
328
+ (x2_pool, x2) = self.encoder_block2(x1_pool) # x2_pool: (bs, 64, T / 4, F' / 4)
329
+ (x3_pool, x3) = self.encoder_block3(
330
+ x2_pool
331
+ ) # x3_pool: (bs, 128, T / 8, F' / 8)
332
+ (x4_pool, x4) = self.encoder_block4(
333
+ x3_pool
334
+ ) # x4_pool: (bs, 256, T / 16, F' / 16)
335
+ (x5_pool, x5) = self.encoder_block5(
336
+ x4_pool
337
+ ) # x5_pool: (bs, 384, T / 32, F' / 32)
338
+ (x6_pool, x6) = self.encoder_block6(
339
+ x5_pool
340
+ ) # x6_pool: (bs, 384, T / 64, F' / 64)
341
+ x_center = self.conv_block7(x6_pool) # (bs, 384, T / 64, F' / 64)
342
+ x7 = self.decoder_block1(x_center, x6) # (bs, 384, T / 32, F' / 32)
343
+ x8 = self.decoder_block2(x7, x5) # (bs, 384, T / 16, F' / 16)
344
+ x9 = self.decoder_block3(x8, x4) # (bs, 256, T / 8, F' / 8)
345
+ x10 = self.decoder_block4(x9, x3) # (bs, 128, T / 4, F' / 4)
346
+ x11 = self.decoder_block5(x10, x2) # (bs, 64, T / 2, F' / 2)
347
+ x12 = self.decoder_block6(x11, x1) # (bs, 32, T, F')
348
+ x = self.after_conv_block1(x12) # (bs, 32, T, F')
349
+
350
+ x = self.after_conv2(x)
351
+ # (batch_size, subbands_num * target_sources_num * input_channles * self.K, T, F')
352
+
353
+ # Recover shape
354
+ x = F.pad(x, pad=(0, 1)) # Pad frequency, e.g., 256 -> 257.
355
+
356
+ x = x[:, :, 0:origin_len, :]
357
+ # (batch_size, subbands_num * target_sources_num * input_channles * self.K, T, F')
358
+
359
+ audio_length = subband_x.shape[2]
360
+
361
+ # Recover each subband spectrograms to subband waveforms. Then synthesis
362
+ # the subband waveforms to a waveform.
363
+ C1 = x.shape[1] // self.subbands_num
364
+ C2 = mag.shape[1] // self.subbands_num
365
+
366
+ separated_subband_audio = torch.cat(
367
+ [
368
+ self.feature_maps_to_wav(
369
+ input_tensor=x[:, j * C1 : (j + 1) * C1, :, :],
370
+ sp=mag[:, j * C2 : (j + 1) * C2, :, :],
371
+ sin_in=sin_in[:, j * C2 : (j + 1) * C2, :, :],
372
+ cos_in=cos_in[:, j * C2 : (j + 1) * C2, :, :],
373
+ audio_length=audio_length,
374
+ )
375
+ for j in range(self.subbands_num)
376
+ ],
377
+ dim=1,
378
+ )
379
+ # (batch_size, subbands_num * target_sources_num * input_channles, segment_samples)
380
+
381
+ if self.subbands_num > 1:
382
+ separated_audio = self.pqmf.synthesis(separated_subband_audio)
383
+ # (batch_size, target_sources_num * input_channles, segment_samples)
384
+ else:
385
+ separated_audio = separated_subband_audio
386
+
387
+ output_dict = {'waveform': separated_audio}
388
+
389
+ return output_dict
bytesep/optimizers/__init__.py ADDED
File without changes
bytesep/optimizers/lr_schedulers.py ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ def get_lr_lambda(step, warm_up_steps: int, reduce_lr_steps: int):
2
+ r"""Get lr_lambda for LambdaLR. E.g.,
3
+
4
+ .. code-block: python
5
+ lr_lambda = lambda step: get_lr_lambda(step, warm_up_steps=1000, reduce_lr_steps=10000)
6
+
7
+ from torch.optim.lr_scheduler import LambdaLR
8
+ LambdaLR(optimizer, lr_lambda)
9
+
10
+ Args:
11
+ warm_up_steps: int, steps for warm up
12
+ reduce_lr_steps: int, reduce learning rate by 0.9 every #reduce_lr_steps steps
13
+
14
+ Returns:
15
+ learning rate: float
16
+ """
17
+ if step <= warm_up_steps:
18
+ return step / warm_up_steps
19
+ else:
20
+ return 0.9 ** (step // reduce_lr_steps)
bytesep/plot_results/__init__.py ADDED
File without changes
bytesep/plot_results/musdb18.py ADDED
@@ -0,0 +1,198 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import os
3
+ import pickle
4
+
5
+ import matplotlib.pyplot as plt
6
+ import numpy as np
7
+
8
+
9
+ def load_sdrs(workspace, task_name, filename, config, gpus, source_type):
10
+
11
+ stat_path = os.path.join(
12
+ workspace,
13
+ "statistics",
14
+ task_name,
15
+ filename,
16
+ "config={},gpus={}".format(config, gpus),
17
+ "statistics.pkl",
18
+ )
19
+
20
+ stat_dict = pickle.load(open(stat_path, 'rb'))
21
+
22
+ median_sdrs = [e['median_sdr_dict'][source_type] for e in stat_dict['test']]
23
+
24
+ return median_sdrs
25
+
26
+
27
+ def plot_statistics(args):
28
+
29
+ # arguments & parameters
30
+ workspace = args.workspace
31
+ select = args.select
32
+ task_name = "musdb18"
33
+ filename = "train"
34
+
35
+ # paths
36
+ fig_path = os.path.join('results', task_name, "sdr_{}.pdf".format(select))
37
+ os.makedirs(os.path.dirname(fig_path), exist_ok=True)
38
+
39
+ linewidth = 1
40
+ lines = []
41
+ fig, ax = plt.subplots(1, 1, figsize=(8, 6))
42
+
43
+ if select == '1a':
44
+ sdrs = load_sdrs(
45
+ workspace,
46
+ task_name,
47
+ filename,
48
+ config='vocals-accompaniment,unet',
49
+ gpus=1,
50
+ source_type="vocals",
51
+ )
52
+ (line,) = ax.plot(sdrs, label='UNet,l1_wav', linewidth=linewidth)
53
+ lines.append(line)
54
+ ylim = 15
55
+
56
+ elif select == '1b':
57
+ sdrs = load_sdrs(
58
+ workspace,
59
+ task_name,
60
+ filename,
61
+ config='accompaniment-vocals,unet',
62
+ gpus=1,
63
+ source_type="accompaniment",
64
+ )
65
+ (line,) = ax.plot(sdrs, label='UNet,l1_wav', linewidth=linewidth)
66
+ lines.append(line)
67
+ ylim = 20
68
+
69
+ if select == '1c':
70
+ sdrs = load_sdrs(
71
+ workspace,
72
+ task_name,
73
+ filename,
74
+ config='vocals-accompaniment,unet',
75
+ gpus=1,
76
+ source_type="vocals",
77
+ )
78
+ (line,) = ax.plot(sdrs, label='UNet,l1_wav', linewidth=linewidth)
79
+ lines.append(line)
80
+
81
+ sdrs = load_sdrs(
82
+ workspace,
83
+ task_name,
84
+ filename,
85
+ config='vocals-accompaniment,resunet',
86
+ gpus=2,
87
+ source_type="vocals",
88
+ )
89
+ (line,) = ax.plot(sdrs, label='ResUNet_ISMIR2021,l1_wav', linewidth=linewidth)
90
+ lines.append(line)
91
+
92
+ sdrs = load_sdrs(
93
+ workspace,
94
+ task_name,
95
+ filename,
96
+ config='vocals-accompaniment,unet_subbandtime',
97
+ gpus=1,
98
+ source_type="vocals",
99
+ )
100
+ (line,) = ax.plot(sdrs, label='unet_subband,l1_wav', linewidth=linewidth)
101
+ lines.append(line)
102
+
103
+ sdrs = load_sdrs(
104
+ workspace,
105
+ task_name,
106
+ filename,
107
+ config='vocals-accompaniment,resunet_subbandtime',
108
+ gpus=1,
109
+ source_type="vocals",
110
+ )
111
+ (line,) = ax.plot(sdrs, label='resunet_subband,l1_wav', linewidth=linewidth)
112
+ lines.append(line)
113
+
114
+ ylim = 15
115
+
116
+ elif select == '1d':
117
+ sdrs = load_sdrs(
118
+ workspace,
119
+ task_name,
120
+ filename,
121
+ config='accompaniment-vocals,unet',
122
+ gpus=1,
123
+ source_type="accompaniment",
124
+ )
125
+ (line,) = ax.plot(sdrs, label='UNet,l1_wav', linewidth=linewidth)
126
+ lines.append(line)
127
+
128
+ sdrs = load_sdrs(
129
+ workspace,
130
+ task_name,
131
+ filename,
132
+ config='accompaniment-vocals,resunet',
133
+ gpus=2,
134
+ source_type="accompaniment",
135
+ )
136
+ (line,) = ax.plot(sdrs, label='ResUNet_ISMIR2021,l1_wav', linewidth=linewidth)
137
+ lines.append(line)
138
+
139
+ # sdrs = load_sdrs(
140
+ # workspace,
141
+ # task_name,
142
+ # filename,
143
+ # config='accompaniment-vocals,unet_subbandtime',
144
+ # gpus=1,
145
+ # source_type="accompaniment",
146
+ # )
147
+ # (line,) = ax.plot(sdrs, label='UNet_subbtandtime,l1_wav', linewidth=linewidth)
148
+ # lines.append(line)
149
+
150
+ sdrs = load_sdrs(
151
+ workspace,
152
+ task_name,
153
+ filename,
154
+ config='accompaniment-vocals,resunet_subbandtime',
155
+ gpus=1,
156
+ source_type="accompaniment",
157
+ )
158
+ (line,) = ax.plot(
159
+ sdrs, label='ResUNet_subbtandtime,l1_wav', linewidth=linewidth
160
+ )
161
+ lines.append(line)
162
+
163
+ ylim = 20
164
+
165
+ else:
166
+ raise Exception('Error!')
167
+
168
+ eval_every_iterations = 10000
169
+ total_ticks = 50
170
+ ticks_freq = 10
171
+
172
+ ax.set_ylim(0, ylim)
173
+ ax.set_xlim(0, total_ticks)
174
+ ax.xaxis.set_ticks(np.arange(0, total_ticks + 1, ticks_freq))
175
+ ax.xaxis.set_ticklabels(
176
+ np.arange(
177
+ 0,
178
+ total_ticks * eval_every_iterations + 1,
179
+ ticks_freq * eval_every_iterations,
180
+ )
181
+ )
182
+ ax.yaxis.set_ticks(np.arange(ylim + 1))
183
+ ax.yaxis.set_ticklabels(np.arange(ylim + 1))
184
+ ax.grid(color='b', linestyle='solid', linewidth=0.3)
185
+ plt.legend(handles=lines, loc=4)
186
+
187
+ plt.savefig(fig_path)
188
+ print('Save figure to {}'.format(fig_path))
189
+
190
+
191
+ if __name__ == '__main__':
192
+ parser = argparse.ArgumentParser()
193
+ parser.add_argument('--workspace', type=str, required=True)
194
+ parser.add_argument('--select', type=str, required=True)
195
+
196
+ args = parser.parse_args()
197
+
198
+ plot_statistics(args)
bytesep/plot_results/plot_vctk-musdb18.py ADDED
@@ -0,0 +1,87 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+ import numpy as np
4
+ import argparse
5
+ import h5py
6
+ import math
7
+ import time
8
+ import logging
9
+ import pickle
10
+ import matplotlib.pyplot as plt
11
+
12
+
13
+ def load_sdrs(workspace, task_name, filename, config, gpus):
14
+
15
+ stat_path = os.path.join(
16
+ workspace,
17
+ "statistics",
18
+ task_name,
19
+ filename,
20
+ "config={},gpus={}".format(config, gpus),
21
+ "statistics.pkl",
22
+ )
23
+
24
+ stat_dict = pickle.load(open(stat_path, 'rb'))
25
+
26
+ median_sdrs = [e['sdr'] for e in stat_dict['test']]
27
+
28
+ return median_sdrs
29
+
30
+
31
+ def plot_statistics(args):
32
+
33
+ # arguments & parameters
34
+ workspace = args.workspace
35
+ select = args.select
36
+ task_name = "vctk-musdb18"
37
+ filename = "train"
38
+
39
+ # paths
40
+ fig_path = os.path.join('results', task_name, "sdr_{}.pdf".format(select))
41
+ os.makedirs(os.path.dirname(fig_path), exist_ok=True)
42
+
43
+ linewidth = 1
44
+ lines = []
45
+ fig, ax = plt.subplots(1, 1, figsize=(8, 6))
46
+ ylim = 30
47
+ expand = 1
48
+
49
+ if select == '1a':
50
+ sdrs = load_sdrs(workspace, task_name, filename, config='unet', gpus=1)
51
+ (line,) = ax.plot(sdrs, label='UNet,l1_wav', linewidth=linewidth)
52
+ lines.append(line)
53
+
54
+ else:
55
+ raise Exception('Error!')
56
+
57
+ eval_every_iterations = 10000
58
+ total_ticks = 50
59
+ ticks_freq = 10
60
+
61
+ ax.set_ylim(0, ylim)
62
+ ax.set_xlim(0, total_ticks)
63
+ ax.xaxis.set_ticks(np.arange(0, total_ticks + 1, ticks_freq))
64
+ ax.xaxis.set_ticklabels(
65
+ np.arange(
66
+ 0,
67
+ total_ticks * eval_every_iterations + 1,
68
+ ticks_freq * eval_every_iterations,
69
+ )
70
+ )
71
+ ax.yaxis.set_ticks(np.arange(ylim + 1))
72
+ ax.yaxis.set_ticklabels(np.arange(ylim + 1))
73
+ ax.grid(color='b', linestyle='solid', linewidth=0.3)
74
+ plt.legend(handles=lines, loc=4)
75
+
76
+ plt.savefig(fig_path)
77
+ print('Save figure to {}'.format(fig_path))
78
+
79
+
80
+ if __name__ == '__main__':
81
+ parser = argparse.ArgumentParser()
82
+ parser.add_argument('--workspace', type=str, required=True)
83
+ parser.add_argument('--select', type=str, required=True)
84
+
85
+ args = parser.parse_args()
86
+
87
+ plot_statistics(args)
bytesep/train.py ADDED
@@ -0,0 +1,299 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import logging
3
+ import os
4
+ import pathlib
5
+ from functools import partial
6
+ from typing import List, NoReturn
7
+
8
+ import pytorch_lightning as pl
9
+ from pytorch_lightning.plugins import DDPPlugin
10
+
11
+ from bytesep.callbacks import get_callbacks
12
+ from bytesep.data.augmentors import Augmentor
13
+ from bytesep.data.batch_data_preprocessors import (
14
+ get_batch_data_preprocessor_class,
15
+ )
16
+ from bytesep.data.data_modules import DataModule, Dataset
17
+ from bytesep.data.samplers import SegmentSampler
18
+ from bytesep.losses import get_loss_function
19
+ from bytesep.models.lightning_modules import (
20
+ LitSourceSeparation,
21
+ get_model_class,
22
+ )
23
+ from bytesep.optimizers.lr_schedulers import get_lr_lambda
24
+ from bytesep.utils import (
25
+ create_logging,
26
+ get_pitch_shift_factor,
27
+ read_yaml,
28
+ check_configs_gramma,
29
+ )
30
+
31
+
32
+ def get_dirs(
33
+ workspace: str, task_name: str, filename: str, config_yaml: str, gpus: int
34
+ ) -> List[str]:
35
+ r"""Get directories.
36
+
37
+ Args:
38
+ workspace: str
39
+ task_name, str, e.g., 'musdb18'
40
+ filenmae: str
41
+ config_yaml: str
42
+ gpus: int, e.g., 0 for cpu and 8 for training with 8 gpu cards
43
+
44
+ Returns:
45
+ checkpoints_dir: str
46
+ logs_dir: str
47
+ logger: pl.loggers.TensorBoardLogger
48
+ statistics_path: str
49
+ """
50
+
51
+ # save checkpoints dir
52
+ checkpoints_dir = os.path.join(
53
+ workspace,
54
+ "checkpoints",
55
+ task_name,
56
+ filename,
57
+ "config={},gpus={}".format(pathlib.Path(config_yaml).stem, gpus),
58
+ )
59
+ os.makedirs(checkpoints_dir, exist_ok=True)
60
+
61
+ # logs dir
62
+ logs_dir = os.path.join(
63
+ workspace,
64
+ "logs",
65
+ task_name,
66
+ filename,
67
+ "config={},gpus={}".format(pathlib.Path(config_yaml).stem, gpus),
68
+ )
69
+ os.makedirs(logs_dir, exist_ok=True)
70
+
71
+ # loggings
72
+ create_logging(logs_dir, filemode='w')
73
+ logging.info(args)
74
+
75
+ # tensorboard logs dir
76
+ tb_logs_dir = os.path.join(workspace, "tensorboard_logs")
77
+ os.makedirs(tb_logs_dir, exist_ok=True)
78
+
79
+ experiment_name = os.path.join(task_name, filename, pathlib.Path(config_yaml).stem)
80
+ logger = pl.loggers.TensorBoardLogger(save_dir=tb_logs_dir, name=experiment_name)
81
+
82
+ # statistics path
83
+ statistics_path = os.path.join(
84
+ workspace,
85
+ "statistics",
86
+ task_name,
87
+ filename,
88
+ "config={},gpus={}".format(pathlib.Path(config_yaml).stem, gpus),
89
+ "statistics.pkl",
90
+ )
91
+ os.makedirs(os.path.dirname(statistics_path), exist_ok=True)
92
+
93
+ return checkpoints_dir, logs_dir, logger, statistics_path
94
+
95
+
96
+ def _get_data_module(
97
+ workspace: str, config_yaml: str, num_workers: int, distributed: bool
98
+ ) -> DataModule:
99
+ r"""Create data_module. Mini-batch data can be obtained by:
100
+
101
+ code-block:: python
102
+
103
+ data_module.setup()
104
+ for batch_data_dict in data_module.train_dataloader():
105
+ print(batch_data_dict.keys())
106
+ break
107
+
108
+ Args:
109
+ workspace: str
110
+ config_yaml: str
111
+ num_workers: int, e.g., 0 for non-parallel and 8 for using cpu cores
112
+ for preparing data in parallel
113
+ distributed: bool
114
+
115
+ Returns:
116
+ data_module: DataModule
117
+ """
118
+
119
+ configs = read_yaml(config_yaml)
120
+ input_source_types = configs['train']['input_source_types']
121
+ indexes_path = os.path.join(workspace, configs['train']['indexes_dict'])
122
+ sample_rate = configs['train']['sample_rate']
123
+ segment_seconds = configs['train']['segment_seconds']
124
+ mixaudio_dict = configs['train']['augmentations']['mixaudio']
125
+ augmentations = configs['train']['augmentations']
126
+ max_pitch_shift = max(
127
+ [
128
+ augmentations['pitch_shift'][source_type]
129
+ for source_type in input_source_types
130
+ ]
131
+ )
132
+ batch_size = configs['train']['batch_size']
133
+ steps_per_epoch = configs['train']['steps_per_epoch']
134
+
135
+ segment_samples = int(segment_seconds * sample_rate)
136
+ ex_segment_samples = int(segment_samples * get_pitch_shift_factor(max_pitch_shift))
137
+
138
+ # sampler
139
+ train_sampler = SegmentSampler(
140
+ indexes_path=indexes_path,
141
+ segment_samples=ex_segment_samples,
142
+ mixaudio_dict=mixaudio_dict,
143
+ batch_size=batch_size,
144
+ steps_per_epoch=steps_per_epoch,
145
+ )
146
+
147
+ # augmentor
148
+ augmentor = Augmentor(augmentations=augmentations)
149
+
150
+ # dataset
151
+ train_dataset = Dataset(augmentor, segment_samples)
152
+
153
+ # data module
154
+ data_module = DataModule(
155
+ train_sampler=train_sampler,
156
+ train_dataset=train_dataset,
157
+ num_workers=num_workers,
158
+ distributed=distributed,
159
+ )
160
+
161
+ return data_module
162
+
163
+
164
+ def train(args) -> NoReturn:
165
+ r"""Train & evaluate and save checkpoints.
166
+
167
+ Args:
168
+ workspace: str, directory of workspace
169
+ gpus: int
170
+ config_yaml: str, path of config file for training
171
+ """
172
+
173
+ # arugments & parameters
174
+ workspace = args.workspace
175
+ gpus = args.gpus
176
+ config_yaml = args.config_yaml
177
+ filename = args.filename
178
+
179
+ num_workers = 8
180
+ distributed = True if gpus > 1 else False
181
+ evaluate_device = "cuda" if gpus > 0 else "cpu"
182
+
183
+ # Read config file.
184
+ configs = read_yaml(config_yaml)
185
+ check_configs_gramma(configs)
186
+ task_name = configs['task_name']
187
+ target_source_types = configs['train']['target_source_types']
188
+ target_sources_num = len(target_source_types)
189
+ channels = configs['train']['channels']
190
+ batch_data_preprocessor_type = configs['train']['batch_data_preprocessor']
191
+ model_type = configs['train']['model_type']
192
+ loss_type = configs['train']['loss_type']
193
+ optimizer_type = configs['train']['optimizer_type']
194
+ learning_rate = float(configs['train']['learning_rate'])
195
+ precision = configs['train']['precision']
196
+ early_stop_steps = configs['train']['early_stop_steps']
197
+ warm_up_steps = configs['train']['warm_up_steps']
198
+ reduce_lr_steps = configs['train']['reduce_lr_steps']
199
+
200
+ # paths
201
+ checkpoints_dir, logs_dir, logger, statistics_path = get_dirs(
202
+ workspace, task_name, filename, config_yaml, gpus
203
+ )
204
+
205
+ # training data module
206
+ data_module = _get_data_module(
207
+ workspace=workspace,
208
+ config_yaml=config_yaml,
209
+ num_workers=num_workers,
210
+ distributed=distributed,
211
+ )
212
+
213
+ # batch data preprocessor
214
+ BatchDataPreprocessor = get_batch_data_preprocessor_class(
215
+ batch_data_preprocessor_type=batch_data_preprocessor_type
216
+ )
217
+
218
+ batch_data_preprocessor = BatchDataPreprocessor(
219
+ target_source_types=target_source_types
220
+ )
221
+
222
+ # model
223
+ Model = get_model_class(model_type=model_type)
224
+ model = Model(input_channels=channels, target_sources_num=target_sources_num)
225
+
226
+ # loss function
227
+ loss_function = get_loss_function(loss_type=loss_type)
228
+
229
+ # callbacks
230
+ callbacks = get_callbacks(
231
+ task_name=task_name,
232
+ config_yaml=config_yaml,
233
+ workspace=workspace,
234
+ checkpoints_dir=checkpoints_dir,
235
+ statistics_path=statistics_path,
236
+ logger=logger,
237
+ model=model,
238
+ evaluate_device=evaluate_device,
239
+ )
240
+ # callbacks = []
241
+
242
+ # learning rate reduce function
243
+ lr_lambda = partial(
244
+ get_lr_lambda, warm_up_steps=warm_up_steps, reduce_lr_steps=reduce_lr_steps
245
+ )
246
+
247
+ # pytorch-lightning model
248
+ pl_model = LitSourceSeparation(
249
+ batch_data_preprocessor=batch_data_preprocessor,
250
+ model=model,
251
+ optimizer_type=optimizer_type,
252
+ loss_function=loss_function,
253
+ learning_rate=learning_rate,
254
+ lr_lambda=lr_lambda,
255
+ )
256
+
257
+ # trainer
258
+ trainer = pl.Trainer(
259
+ checkpoint_callback=False,
260
+ gpus=gpus,
261
+ callbacks=callbacks,
262
+ max_steps=early_stop_steps,
263
+ accelerator="ddp",
264
+ sync_batchnorm=True,
265
+ precision=precision,
266
+ replace_sampler_ddp=False,
267
+ plugins=[DDPPlugin(find_unused_parameters=True)],
268
+ profiler='simple',
269
+ )
270
+
271
+ # Fit, evaluate, and save checkpoints.
272
+ trainer.fit(pl_model, data_module)
273
+
274
+
275
+ if __name__ == "__main__":
276
+
277
+ parser = argparse.ArgumentParser(description="")
278
+ subparsers = parser.add_subparsers(dest="mode")
279
+
280
+ parser_train = subparsers.add_parser("train")
281
+ parser_train.add_argument(
282
+ "--workspace", type=str, required=True, help="Directory of workspace."
283
+ )
284
+ parser_train.add_argument("--gpus", type=int, required=True)
285
+ parser_train.add_argument(
286
+ "--config_yaml",
287
+ type=str,
288
+ required=True,
289
+ help="Path of config file for training.",
290
+ )
291
+
292
+ args = parser.parse_args()
293
+ args.filename = pathlib.Path(__file__).stem
294
+
295
+ if args.mode == "train":
296
+ train(args)
297
+
298
+ else:
299
+ raise Exception("Error argument!")
bytesep/utils.py ADDED
@@ -0,0 +1,189 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import datetime
2
+ import logging
3
+ import os
4
+ import pickle
5
+ from typing import Dict, NoReturn
6
+
7
+ import librosa
8
+ import numpy as np
9
+ import yaml
10
+
11
+
12
+ def create_logging(log_dir: str, filemode: str) -> logging:
13
+ r"""Create logging to write out log files.
14
+
15
+ Args:
16
+ logs_dir, str, directory to write out logs
17
+ filemode: str, e.g., "w"
18
+
19
+ Returns:
20
+ logging
21
+ """
22
+ os.makedirs(log_dir, exist_ok=True)
23
+ i1 = 0
24
+
25
+ while os.path.isfile(os.path.join(log_dir, "{:04d}.log".format(i1))):
26
+ i1 += 1
27
+
28
+ log_path = os.path.join(log_dir, "{:04d}.log".format(i1))
29
+ logging.basicConfig(
30
+ level=logging.DEBUG,
31
+ format="%(asctime)s %(filename)s[line:%(lineno)d] %(levelname)s %(message)s",
32
+ datefmt="%a, %d %b %Y %H:%M:%S",
33
+ filename=log_path,
34
+ filemode=filemode,
35
+ )
36
+
37
+ # Print to console
38
+ console = logging.StreamHandler()
39
+ console.setLevel(logging.INFO)
40
+ formatter = logging.Formatter("%(name)-12s: %(levelname)-8s %(message)s")
41
+ console.setFormatter(formatter)
42
+ logging.getLogger("").addHandler(console)
43
+
44
+ return logging
45
+
46
+
47
+ def load_audio(
48
+ audio_path: str,
49
+ mono: bool,
50
+ sample_rate: float,
51
+ offset: float = 0.0,
52
+ duration: float = None,
53
+ ) -> np.array:
54
+ r"""Load audio.
55
+
56
+ Args:
57
+ audio_path: str
58
+ mono: bool
59
+ sample_rate: float
60
+ """
61
+ audio, _ = librosa.core.load(
62
+ audio_path, sr=sample_rate, mono=mono, offset=offset, duration=duration
63
+ )
64
+ # (audio_samples,) | (channels_num, audio_samples)
65
+
66
+ if audio.ndim == 1:
67
+ audio = audio[None, :]
68
+ # (1, audio_samples,)
69
+
70
+ return audio
71
+
72
+
73
+ def load_random_segment(
74
+ audio_path: str, random_state, segment_seconds: float, mono: bool, sample_rate: int
75
+ ) -> np.array:
76
+ r"""Randomly select an audio segment from a recording."""
77
+
78
+ duration = librosa.get_duration(filename=audio_path)
79
+
80
+ start_time = random_state.uniform(0.0, duration - segment_seconds)
81
+
82
+ audio = load_audio(
83
+ audio_path=audio_path,
84
+ mono=mono,
85
+ sample_rate=sample_rate,
86
+ offset=start_time,
87
+ duration=segment_seconds,
88
+ )
89
+ # (channels_num, audio_samples)
90
+
91
+ return audio
92
+
93
+
94
+ def float32_to_int16(x: np.float32) -> np.int16:
95
+
96
+ x = np.clip(x, a_min=-1, a_max=1)
97
+
98
+ return (x * 32767.0).astype(np.int16)
99
+
100
+
101
+ def int16_to_float32(x: np.int16) -> np.float32:
102
+
103
+ return (x / 32767.0).astype(np.float32)
104
+
105
+
106
+ def read_yaml(config_yaml: str):
107
+
108
+ with open(config_yaml, "r") as fr:
109
+ configs = yaml.load(fr, Loader=yaml.FullLoader)
110
+
111
+ return configs
112
+
113
+
114
+ def check_configs_gramma(configs: Dict) -> NoReturn:
115
+ r"""Check if the gramma of the config dictionary for training is legal."""
116
+ input_source_types = configs['train']['input_source_types']
117
+
118
+ for augmentation_type in configs['train']['augmentations'].keys():
119
+ augmentation_dict = configs['train']['augmentations'][augmentation_type]
120
+
121
+ for source_type in augmentation_dict.keys():
122
+ if source_type not in input_source_types:
123
+ error_msg = (
124
+ "The source type '{}'' in configs['train']['augmentations']['{}'] "
125
+ "must be one of input_source_types {}".format(
126
+ source_type, augmentation_type, input_source_types
127
+ )
128
+ )
129
+ raise Exception(error_msg)
130
+
131
+
132
+ def magnitude_to_db(x: float) -> float:
133
+ eps = 1e-10
134
+ return 20.0 * np.log10(max(x, eps))
135
+
136
+
137
+ def db_to_magnitude(x: float) -> float:
138
+ return 10.0 ** (x / 20)
139
+
140
+
141
+ def get_pitch_shift_factor(shift_pitch: float) -> float:
142
+ r"""The factor of the audio length to be scaled."""
143
+ return 2 ** (shift_pitch / 12)
144
+
145
+
146
+ class StatisticsContainer(object):
147
+ def __init__(self, statistics_path):
148
+ self.statistics_path = statistics_path
149
+
150
+ self.backup_statistics_path = "{}_{}.pkl".format(
151
+ os.path.splitext(self.statistics_path)[0],
152
+ datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S"),
153
+ )
154
+
155
+ self.statistics_dict = {"train": [], "test": []}
156
+
157
+ def append(self, steps, statistics, split):
158
+ statistics["steps"] = steps
159
+ self.statistics_dict[split].append(statistics)
160
+
161
+ def dump(self):
162
+ pickle.dump(self.statistics_dict, open(self.statistics_path, "wb"))
163
+ pickle.dump(self.statistics_dict, open(self.backup_statistics_path, "wb"))
164
+ logging.info(" Dump statistics to {}".format(self.statistics_path))
165
+ logging.info(" Dump statistics to {}".format(self.backup_statistics_path))
166
+
167
+ '''
168
+ def load_state_dict(self, resume_steps):
169
+ self.statistics_dict = pickle.load(open(self.statistics_path, "rb"))
170
+
171
+ resume_statistics_dict = {"train": [], "test": []}
172
+
173
+ for key in self.statistics_dict.keys():
174
+ for statistics in self.statistics_dict[key]:
175
+ if statistics["steps"] <= resume_steps:
176
+ resume_statistics_dict[key].append(statistics)
177
+
178
+ self.statistics_dict = resume_statistics_dict
179
+ '''
180
+
181
+
182
+ def calculate_sdr(ref: np.array, est: np.array) -> float:
183
+ s_true = ref
184
+ s_artif = est - ref
185
+ sdr = 10.0 * (
186
+ np.log10(np.clip(np.mean(s_true ** 2), 1e-8, np.inf))
187
+ - np.log10(np.clip(np.mean(s_artif ** 2), 1e-8, np.inf))
188
+ )
189
+ return sdr
pyproject.toml ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [tool.black]
2
+ line-length = 88
3
+ target-version = ['py37']
4
+ skip-string-normalization = true
5
+ include = '\.pyi?$'
6
+ exclude = '''
7
+ (
8
+ /(
9
+ \.eggs # exclude a few common directories in the
10
+ | \.git # root of the project
11
+ | \.hg
12
+ | \.mypy_cache
13
+ | \.tox
14
+ | \.venv
15
+ | _build
16
+ | buck-out
17
+ | build
18
+ | dist
19
+ )/
20
+ )
21
+ '''