ttettheu commited on
Commit
75754ee
·
verified ·
1 Parent(s): f2ba2a7

Create mdx.py

Browse files
Files changed (1) hide show
  1. mdx.py +220 -0
mdx.py ADDED
@@ -0,0 +1,220 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import onnxruntime as ort
3
+ from tqdm import tqdm
4
+ import warnings
5
+ import numpy as np
6
+ import hashlib
7
+ import queue
8
+ import threading
9
+
10
+ warnings.filterwarnings("ignore")
11
+
12
+ class MDX_Model:
13
+ def __init__(self, device, dim_f, dim_t, n_fft, hop=1024, stem_name=None, compensation=1.000):
14
+ self.dim_f = dim_f
15
+ self.dim_t = dim_t
16
+ self.dim_c = 4
17
+ self.n_fft = n_fft
18
+ self.hop = hop
19
+ self.stem_name = stem_name
20
+ self.compensation = compensation
21
+
22
+ self.n_bins = self.n_fft//2+1
23
+ self.chunk_size = hop * (self.dim_t-1)
24
+ self.window = torch.hann_window(window_length=self.n_fft, periodic=True).to(device)
25
+
26
+ out_c = self.dim_c
27
+
28
+ self.freq_pad = torch.zeros([1, out_c, self.n_bins-self.dim_f, self.dim_t]).to(device)
29
+
30
+ def stft(self, x):
31
+ x = x.reshape([-1, self.chunk_size])
32
+ x = torch.stft(x, n_fft=self.n_fft, hop_length=self.hop, window=self.window, center=True, return_complex=True)
33
+ x = torch.view_as_real(x)
34
+ x = x.permute([0,3,1,2])
35
+ x = x.reshape([-1,2,2,self.n_bins,self.dim_t]).reshape([-1,4,self.n_bins,self.dim_t])
36
+ return x[:,:,:self.dim_f]
37
+
38
+ def istft(self, x, freq_pad=None):
39
+ freq_pad = self.freq_pad.repeat([x.shape[0],1,1,1]) if freq_pad is None else freq_pad
40
+ x = torch.cat([x, freq_pad], -2)
41
+ # c = 4*2 if self.target_name=='*' else 2
42
+ x = x.reshape([-1,2,2,self.n_bins,self.dim_t]).reshape([-1,2,self.n_bins,self.dim_t])
43
+ x = x.permute([0,2,3,1])
44
+ x = x.contiguous()
45
+ x = torch.view_as_complex(x)
46
+ x = torch.istft(x, n_fft=self.n_fft, hop_length=self.hop, window=self.window, center=True)
47
+ return x.reshape([-1,2,self.chunk_size])
48
+
49
+
50
+ class MDX:
51
+
52
+ DEFAULT_SR = 44100
53
+ # Unit: seconds
54
+ DEFAULT_CHUNK_SIZE = 0 * DEFAULT_SR
55
+ DEFAULT_MARGIN_SIZE = 1 * DEFAULT_SR
56
+
57
+ DEFAULT_PROCESSOR = 0
58
+
59
+ def __init__(self, model_path:str, params:MDX_Model, processor=DEFAULT_PROCESSOR):
60
+
61
+ # Set the device and the provider (CPU or CUDA)
62
+ self.device = torch.device(f'cuda:{processor}') if processor >= 0 else torch.device('cpu')
63
+ self.provider = ['CUDAExecutionProvider'] if processor >= 0 else ['CPUExecutionProvider']
64
+
65
+ self.model = params
66
+
67
+ # Load the ONNX model using ONNX Runtime
68
+ self.ort = ort.InferenceSession(model_path, providers=self.provider)
69
+ # Preload the model for faster performance
70
+ self.ort.run(None, {'input':torch.rand(1, 4, params.dim_f, params.dim_t).numpy()})
71
+ self.process = lambda spec:self.ort.run(None, {'input': spec.cpu().numpy()})[0]
72
+
73
+ self.prog = None
74
+
75
+ @staticmethod
76
+ def get_hash(model_path):
77
+ try:
78
+ with open(model_path, 'rb') as f:
79
+ f.seek(- 10000 * 1024, 2)
80
+ model_hash = hashlib.md5(f.read()).hexdigest()
81
+ except:
82
+ model_hash = hashlib.md5(open(model_path,'rb').read()).hexdigest()
83
+
84
+ return model_hash
85
+
86
+ @staticmethod
87
+ def segment(wave, combine=True, chunk_size=DEFAULT_CHUNK_SIZE, margin_size=DEFAULT_MARGIN_SIZE):
88
+ """
89
+ Segment or join segmented wave array
90
+ Args:
91
+ wave: (np.array) Wave array to be segmented or joined
92
+ combine: (bool) If True, combines segmented wave array. If False, segments wave array.
93
+ chunk_size: (int) Size of each segment (in samples)
94
+ margin_size: (int) Size of margin between segments (in samples)
95
+ Returns:
96
+ numpy array: Segmented or joined wave array
97
+ """
98
+
99
+ if combine:
100
+ processed_wave = None # Initializing as None instead of [] for later numpy array concatenation
101
+ for segment_count, segment in enumerate(wave):
102
+ start = 0 if segment_count == 0 else margin_size
103
+ end = None if segment_count == len(wave)-1 else -margin_size
104
+ if margin_size == 0:
105
+ end = None
106
+ if processed_wave is None: # Create array for first segment
107
+ processed_wave = segment[:, start:end]
108
+ else: # Concatenate to existing array for subsequent segments
109
+ processed_wave = np.concatenate((processed_wave, segment[:, start:end]), axis=-1)
110
+
111
+ else:
112
+ processed_wave = []
113
+ sample_count = wave.shape[-1]
114
+
115
+ if chunk_size <= 0 or chunk_size > sample_count:
116
+ chunk_size = sample_count
117
+
118
+ if margin_size > chunk_size:
119
+ margin_size = chunk_size
120
+
121
+ for segment_count, skip in enumerate(range(0, sample_count, chunk_size)):
122
+
123
+ margin = 0 if segment_count == 0 else margin_size
124
+ end = min(skip+chunk_size+margin_size, sample_count)
125
+ start = skip-margin
126
+
127
+ cut = wave[:,start:end].copy()
128
+ processed_wave.append(cut)
129
+
130
+ if end == sample_count:
131
+ break
132
+
133
+ return processed_wave
134
+
135
+ def pad_wave(self, wave):
136
+ """
137
+ Pad the wave array to match the required chunk size
138
+ Args:
139
+ wave: (np.array) Wave array to be padded
140
+ Returns:
141
+ tuple: (padded_wave, pad, trim)
142
+ - padded_wave: Padded wave array
143
+ - pad: Number of samples that were padded
144
+ - trim: Number of samples that were trimmed
145
+ """
146
+ n_sample = wave.shape[1]
147
+ trim = self.model.n_fft//2
148
+ gen_size = self.model.chunk_size-2*trim
149
+ pad = gen_size - n_sample%gen_size
150
+
151
+ # Padded wave
152
+ wave_p = np.concatenate((np.zeros((2,trim)), wave, np.zeros((2,pad)), np.zeros((2,trim))), 1)
153
+
154
+ mix_waves = []
155
+ for i in range(0, n_sample+pad, gen_size):
156
+ waves = np.array(wave_p[:, i:i+self.model.chunk_size])
157
+ mix_waves.append(waves)
158
+
159
+ mix_waves = torch.tensor(mix_waves, dtype=torch.float32).to(self.device)
160
+
161
+ return mix_waves, pad, trim
162
+
163
+ def _process_wave(self, mix_waves, trim, pad, q:queue.Queue, _id:int):
164
+ """
165
+ Process each wave segment in a multi-threaded environment
166
+ Args:
167
+ mix_waves: (torch.Tensor) Wave segments to be processed
168
+ trim: (int) Number of samples trimmed during padding
169
+ pad: (int) Number of samples padded during padding
170
+ q: (queue.Queue) Queue to hold the processed wave segments
171
+ _id: (int) Identifier of the processed wave segment
172
+ Returns:
173
+ numpy array: Processed wave segment
174
+ """
175
+ mix_waves = mix_waves.split(1)
176
+ with torch.no_grad():
177
+ pw = []
178
+ for mix_wave in mix_waves:
179
+ self.prog.update()
180
+ spec = self.model.stft(mix_wave)
181
+ processed_spec = torch.tensor(self.process(spec))
182
+ processed_wav = self.model.istft(processed_spec.to(self.device))
183
+ processed_wav = processed_wav[:,:,trim:-trim].transpose(0,1).reshape(2, -1).cpu().numpy()
184
+ pw.append(processed_wav)
185
+ processed_signal = np.concatenate(pw, axis=-1)[:, :-pad]
186
+ q.put({_id:processed_signal})
187
+ return processed_signal
188
+
189
+ def process_wave(self, wave:np.array, mt_threads=1):
190
+ """
191
+ Process the wave array in a multi-threaded environment
192
+ Args:
193
+ wave: (np.array) Wave array to be processed
194
+ mt_threads: (int) Number of threads to be used for processing
195
+ Returns:
196
+ numpy array: Processed wave array
197
+ """
198
+ self.prog = tqdm(total=0)
199
+ chunk = wave.shape[-1]//mt_threads
200
+ waves = self.segment(wave, False, chunk)
201
+
202
+ # Create a queue to hold the processed wave segments
203
+ q = queue.Queue()
204
+ threads = []
205
+ for c, batch in enumerate(waves):
206
+ mix_waves, pad, trim = self.pad_wave(batch)
207
+ self.prog.total = len(mix_waves)*mt_threads
208
+ thread = threading.Thread(target=self._process_wave, args=(mix_waves, trim, pad, q, c))
209
+ thread.start()
210
+ threads.append(thread)
211
+ for thread in threads:
212
+ thread.join()
213
+ self.prog.close()
214
+
215
+ processed_batches = []
216
+ while not q.empty():
217
+ processed_batches.append(q.get())
218
+ processed_batches = [list(wave.values())[0] for wave in sorted(processed_batches, key=lambda d: list(d.keys())[0])]
219
+ assert len(processed_batches) == len(waves), 'Incomplete processed batches, please reduce batch size!'
220
+ return self.segment(processed_batches, True, chunk)