ArianatorQualquer committed · verified
Commit 1277194 · Parent(s): 7918631

Create utils.py

Files changed (1)
  1. utils.py +223 -0
utils.py ADDED
# coding: utf-8
__author__ = 'Roman Solovyev (ZFTurbo): https://github.com/ZFTurbo/'

import time
import numpy as np
import torch
import torch.nn as nn
import yaml
from ml_collections import ConfigDict
from omegaconf import OmegaConf
from tqdm import tqdm


def get_model_from_config(model_type, config_path):
    # htdemucs configs are loaded with OmegaConf; all other models use plain YAML -> ConfigDict
    if model_type == 'htdemucs':
        config = OmegaConf.load(config_path)
    else:
        with open(config_path) as f:
            config = ConfigDict(yaml.load(f, Loader=yaml.FullLoader))

    if model_type == 'mdx23c':
        from models.mdx23c_tfc_tdf_v3 import TFC_TDF_net
        model = TFC_TDF_net(config)
    elif model_type == 'htdemucs':
        from models.demucs4ht import get_model
        model = get_model(config)
    elif model_type == 'segm_models':
        from models.segm_models import Segm_Models_Net
        model = Segm_Models_Net(config)
    elif model_type == 'torchseg':
        from models.torchseg_models import Torchseg_Net
        model = Torchseg_Net(config)
    elif model_type == 'mel_band_roformer':
        from models.bs_roformer import MelBandRoformer
        model = MelBandRoformer(**dict(config.model))
    elif model_type == 'bs_roformer':
        from models.bs_roformer import BSRoformer
        model = BSRoformer(**dict(config.model))
    elif model_type == 'swin_upernet':
        from models.upernet_swin_transformers import Swin_UperNet_Model
        model = Swin_UperNet_Model(config)
    elif model_type == 'bandit':
        from models.bandit.core.model import MultiMaskMultiSourceBandSplitRNNSimple
        model = MultiMaskMultiSourceBandSplitRNNSimple(**config.model)
    elif model_type == 'bandit_v2':
        from models.bandit_v2.bandit import Bandit
        model = Bandit(**config.kwargs)
    elif model_type == 'scnet_unofficial':
        from models.scnet_unofficial import SCNet
        model = SCNet(**config.model)
    elif model_type == 'scnet':
        from models.scnet import SCNet
        model = SCNet(**config.model)
    else:
        print('Unknown model: {}'.format(model_type))
        model = None

    return model, config

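# Example usage (illustrative sketch only; the model type, config path and checkpoint
# file below are hypothetical, not part of this repository's documented workflow):
#   model, config = get_model_from_config('mdx23c', 'configs/config_mdx23c.yaml')
#   if model is not None:
#       model.load_state_dict(torch.load('checkpoint.ckpt', map_location='cpu'))
#       model = model.to('cuda').eval()
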
def _getWindowingArray(window_size, fade_size):
    # Flat window of length window_size with linear fade-in/fade-out ramps of length fade_size
    fadein = torch.linspace(0, 1, fade_size)
    fadeout = torch.linspace(1, 0, fade_size)
    window = torch.ones(window_size)
    window[-fade_size:] *= fadeout
    window[:fade_size] *= fadein
    return window

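# Toy example (not in the original file) showing the shape of the window:
#   _getWindowingArray(8, 2) -> tensor([0., 1., 1., 1., 1., 1., 1., 0.])
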
def demix_track(config, model, mix, device, pbar=False):
    # Check whether 'use_amp' is present in the config and fall back to False if it is not
    use_amp = getattr(config.training, 'use_amp', False)

    C = config.audio.chunk_size
    N = config.inference.num_overlap
    fade_size = C // 10
    step = int(C // N)
    border = C - step
    batch_size = config.inference.batch_size

    length_init = mix.shape[-1]

    # Reflect-pad the beginning and end so chunks near the borders are windowed as well as interior ones
    if length_init > 2 * border and (border > 0):
        mix = nn.functional.pad(mix, (border, border), mode='reflect')

    # windowingArray crossfades at segment boundaries to mitigate clicking artifacts
    windowingArray = _getWindowingArray(C, fade_size)

    with torch.cuda.amp.autocast(enabled=use_amp):
        with torch.inference_mode():
            if config.training.target_instrument is not None:
                req_shape = (1, ) + tuple(mix.shape)
            else:
                req_shape = (len(config.training.instruments),) + tuple(mix.shape)

            result = torch.zeros(req_shape, dtype=torch.float32)
            counter = torch.zeros(req_shape, dtype=torch.float32)
            i = 0
            batch_data = []
            batch_locations = []
            progress_bar = tqdm(total=mix.shape[1], desc="Processing audio chunks", leave=False) if pbar else None

            while i < mix.shape[1]:
                part = mix[:, i:i + C].to(device)
                length = part.shape[-1]
                if length < C:
                    if length > C // 2 + 1:
                        part = nn.functional.pad(input=part, pad=(0, C - length), mode='reflect')
                    else:
                        part = nn.functional.pad(input=part, pad=(0, C - length, 0, 0), mode='constant', value=0)
                batch_data.append(part)
                batch_locations.append((i, length))
                i += step

                if len(batch_data) >= batch_size or (i >= mix.shape[1]):
                    arr = torch.stack(batch_data, dim=0)
                    x = model(arr)

                    # Clone so the edge-case tweaks below do not permanently modify the shared window
                    window = windowingArray.clone()
                    if i - step == 0:  # First audio chunk, no fadein
                        window[:fade_size] = 1
                    elif i >= mix.shape[1]:  # Last audio chunk, no fadeout
                        window[-fade_size:] = 1

                    for j in range(len(batch_locations)):
                        start, l = batch_locations[j]
                        result[..., start:start+l] += x[j][..., :l].cpu() * window[..., :l]
                        counter[..., start:start+l] += window[..., :l]

                    batch_data = []
                    batch_locations = []

                if progress_bar:
                    progress_bar.update(step)

            if progress_bar:
                progress_bar.close()

            estimated_sources = result / counter
            estimated_sources = estimated_sources.cpu().numpy()
            np.nan_to_num(estimated_sources, copy=False, nan=0.0)

            if length_init > 2 * border and (border > 0):
                # Remove the reflect padding added above
                estimated_sources = estimated_sources[..., border:-border]

            if config.training.target_instrument is None:
                return {k: v for k, v in zip(config.training.instruments, estimated_sources)}
            else:
                return {k: v for k, v in zip([config.training.target_instrument], estimated_sources)}

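# Example usage (illustrative sketch only; the audio file, sample rate and device below
# are hypothetical, and loading via librosa is just one convenient option):
#   import librosa
#   mix, sr = librosa.load('mixture.wav', sr=44100, mono=False)   # stereo -> (channels, samples)
#   mix = torch.tensor(mix, dtype=torch.float32)
#   sources = demix_track(config, model, mix, device='cuda', pbar=True)
#   # 'sources' maps instrument names (e.g. 'vocals') to numpy arrays of shape (channels, samples)
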
def demix_track_demucs(config, model, mix, device, pbar=False):
    # Check whether 'use_amp' is present in the config and fall back to False if it is not
    use_amp = getattr(config.training, 'use_amp', False)

    S = len(config.training.instruments)
    C = config.training.samplerate * config.training.segment
    N = config.inference.num_overlap
    batch_size = config.inference.batch_size
    step = C // N

    with torch.cuda.amp.autocast(enabled=use_amp):
        with torch.inference_mode():
            req_shape = (S, ) + tuple(mix.shape)
            result = torch.zeros(req_shape, dtype=torch.float32)
            counter = torch.zeros(req_shape, dtype=torch.float32)
            i = 0
            batch_data = []
            batch_locations = []
            progress_bar = tqdm(total=mix.shape[1], desc="Processing audio chunks", leave=False) if pbar else None

            while i < mix.shape[1]:
                part = mix[:, i:i + C].to(device)
                length = part.shape[-1]
                if length < C:
                    part = nn.functional.pad(input=part, pad=(0, C - length, 0, 0), mode='constant', value=0)
                batch_data.append(part)
                batch_locations.append((i, length))
                i += step

                if len(batch_data) >= batch_size or (i >= mix.shape[1]):
                    arr = torch.stack(batch_data, dim=0)
                    x = model(arr)
                    for j in range(len(batch_locations)):
                        start, l = batch_locations[j]
                        result[..., start:start+l] += x[j][..., :l].cpu()
                        counter[..., start:start+l] += 1.
                    batch_data = []
                    batch_locations = []

                if progress_bar:
                    progress_bar.update(step)

            if progress_bar:
                progress_bar.close()

            estimated_sources = result / counter
            estimated_sources = estimated_sources.cpu().numpy()
            np.nan_to_num(estimated_sources, copy=False, nan=0.0)

            if S > 1:
                return {k: v for k, v in zip(config.training.instruments, estimated_sources)}
            else:
                return estimated_sources

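# Note: unlike demix_track above, this Demucs-style variant averages overlapping chunks
# uniformly (counter += 1) instead of crossfading them with a window. Illustrative call,
# assuming 'config' and 'model' come from get_model_from_config('htdemucs', ...):
#   sources = demix_track_demucs(config, model, mix, device='cuda')
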
def sdr(references, estimates):
    # Compute SDR for one song; source index is on axis 0, the remaining two axes are summed out
    delta = 1e-7  # avoid numerical errors
    num = np.sum(np.square(references), axis=(1, 2))
    den = np.sum(np.square(references - estimates), axis=(1, 2))
    num += delta
    den += delta
    return 10 * np.log10(num / den)
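
# Minimal smoke test (illustrative addition, not part of the upstream code): exercises the
# windowing helper and the SDR metric on random data. Run with `python utils.py`.
if __name__ == '__main__':
    win = _getWindowingArray(window_size=100, fade_size=10)
    assert win.shape == (100,) and win[0] == 0.0 and win[50] == 1.0

    rng = np.random.default_rng(0)
    refs = rng.standard_normal((2, 2, 44100)).astype(np.float32)  # (num_sources, channels, samples)
    ests = refs + 0.01 * rng.standard_normal(refs.shape).astype(np.float32)
    print('SDR per source:', sdr(refs, ests))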