# MagnataMusic / mdx.py
# Committed by ttettheu — "Create mdx.py" (commit 75754ee, verified)
import torch
import onnxruntime as ort
from tqdm import tqdm
import warnings
import numpy as np
import hashlib
import queue
import threading
# Suppress every Python warning for the whole process. This runs as an import
# side effect and affects all libraries, not just this module.
warnings.filterwarnings("ignore")
class MDX_Model:
    """STFT geometry shared by the MDX separator.

    Converts fixed-length waveform chunks into the 4-channel real spectrogram
    layout the network consumes (``stft``) and converts such spectrograms
    back into waveforms (``istft``).
    """

    def __init__(self, device, dim_f, dim_t, n_fft, hop=1024, stem_name=None, compensation=1.000):
        """Store the transform parameters and precompute the window and padding.

        Args:
            device: torch device the window and frequency padding are moved to.
            dim_f: Number of (lowest) frequency bins kept by ``stft``.
            dim_t: Number of STFT frames per chunk.
            n_fft: FFT size; ``n_fft // 2 + 1`` bins are produced.
            hop: Hop length in samples.
            stem_name: Stored only; not used inside this class.
            compensation: Stored only; not used inside this class.
        """
        # Plain hyper-parameters (dim_c = 2 channels x real/imag).
        self.dim_f, self.dim_t, self.dim_c = dim_f, dim_t, 4
        self.n_fft, self.hop = n_fft, hop
        self.stem_name, self.compensation = stem_name, compensation

        # Derived geometry: with center=True, dim_t frames span hop * (dim_t - 1) samples.
        self.n_bins = n_fft // 2 + 1
        self.chunk_size = hop * (dim_t - 1)

        self.window = torch.hann_window(window_length=n_fft, periodic=True).to(device)
        # Zero rows for the bins above dim_f, re-appended before the inverse transform.
        pad_shape = [1, self.dim_c, self.n_bins - dim_f, dim_t]
        self.freq_pad = torch.zeros(pad_shape).to(device)

    def stft(self, x):
        """Turn waveform chunks into a real spectrogram of shape (-1, 4, dim_f, dim_t).

        ``x`` is flattened into rows of ``chunk_size`` samples. Each pair of
        consecutive rows (presumably the two stereo channels) becomes four
        channels: [row0 real, row0 imag, row1 real, row1 imag].
        """
        rows = x.reshape([-1, self.chunk_size])
        spec = torch.stft(rows, n_fft=self.n_fft, hop_length=self.hop, window=self.window,
                          center=True, return_complex=True)
        # (R, bins, T) complex -> (R, [re, im], bins, T) real
        spec = torch.view_as_real(spec).permute([0, 3, 1, 2])
        # Fold row pairs into the channel axis.
        spec = spec.reshape([-1, 4, self.n_bins, self.dim_t])
        return spec[:, :, :self.dim_f]

    def istft(self, x, freq_pad=None):
        """Invert ``stft``: a (-1, 4, dim_f, dim_t) spectrogram becomes a (-1, 2, chunk_size) waveform.

        Args:
            x: Spectrogram in the ``stft`` channel layout.
            freq_pad: Optional tensor appended along the frequency axis to
                restore ``n_bins`` rows. Defaults to ``self.freq_pad`` (zeros)
                repeated over the batch.
        """
        if freq_pad is None:
            freq_pad = self.freq_pad.repeat([x.shape[0], 1, 1, 1])
        spec = torch.cat([x, freq_pad], -2)
        # (-1, 4, F, T) -> (-1, 2, 2, F, T) -> (-1 * 2, [re, im], F, T)
        spec = spec.reshape([-1, 2, 2, self.n_bins, self.dim_t]).reshape([-1, 2, self.n_bins, self.dim_t])
        # Move re/im last and make the tensor contiguous so view_as_complex accepts it.
        spec = spec.permute([0, 2, 3, 1]).contiguous()
        wav = torch.istft(torch.view_as_complex(spec), n_fft=self.n_fft, hop_length=self.hop,
                          window=self.window, center=True)
        return wav.reshape([-1, 2, self.chunk_size])
class MDX:
    """Run an ONNX MDX-Net model over a whole waveform.

    The waveform is split into overlapping segments (one worker thread per
    segment). Each segment is cut into fixed-size chunks and pushed through
    ``MDX_Model.stft`` -> ONNX session -> ``MDX_Model.istft``. The results are
    then stitched back together.
    """

    DEFAULT_SR = 44100
    # Sizes below are in samples, written as seconds * DEFAULT_SR.
    # A chunk size of 0 means "do not split" (see ``segment``).
    DEFAULT_CHUNK_SIZE = 0 * DEFAULT_SR
    DEFAULT_MARGIN_SIZE = 1 * DEFAULT_SR
    # CUDA device index; any negative value selects the CPU.
    DEFAULT_PROCESSOR = 0

    def __init__(self, model_path: str, params: "MDX_Model", processor=DEFAULT_PROCESSOR):
        """Open the ONNX model and run it once on random input.

        Args:
            model_path: Path to the ONNX model file.
            params: Spectrogram geometry (``dim_f``, ``dim_t``, ``stft``/``istft``).
            processor: CUDA device index, or a negative value for CPU.
        """
        on_gpu = processor >= 0
        self.device = torch.device(f'cuda:{processor}') if on_gpu else torch.device('cpu')
        self.provider = ['CUDAExecutionProvider'] if on_gpu else ['CPUExecutionProvider']
        self.model = params
        self.ort = ort.InferenceSession(model_path, providers=self.provider)
        # Warm-up inference ("preload"), presumably so one-time initialisation
        # is not paid on the first real chunk.
        self.ort.run(None, {'input': torch.rand(1, 4, params.dim_f, params.dim_t).numpy()})
        # Spectrogram tensor in, first ONNX output (numpy array) out.
        self.process = lambda spec: self.ort.run(None, {'input': spec.cpu().numpy()})[0]
        # tqdm progress bar, (re)created by process_wave().
        self.prog = None

    @staticmethod
    def get_hash(model_path):
        """Return an MD5 hex digest identifying the model file.

        Only the last 10000 KiB are hashed. Files shorter than that are hashed
        in full.

        Raises:
            OSError: if the file cannot be opened.
        """
        with open(model_path, 'rb') as f:
            try:
                f.seek(-10000 * 1024, 2)
            except OSError:
                # Seeking before the start of a short file fails: hash all of it.
                f.seek(0)
            return hashlib.md5(f.read()).hexdigest()

    @staticmethod
    def segment(wave, combine=True, chunk_size=DEFAULT_CHUNK_SIZE, margin_size=DEFAULT_MARGIN_SIZE):
        """Split a waveform into overlapping segments, or join such segments back.

        Args:
            wave: When splitting, an array whose last axis is time (sliced as
                ``wave[:, a:b]``). When joining, the list produced by a split call.
            combine: If True, join ``wave``; if False, split it.
            chunk_size: Distance between segment starts, in samples. Values
                <= 0 or larger than the signal produce a single segment.
            margin_size: Extra context samples added on each interior side of
                a segment, clamped to ``chunk_size``. Use the same
                ``chunk_size``/``margin_size`` for the split and the join.

        Returns:
            When joining: the joined array, or ``None`` if ``wave`` is empty.
            When splitting: a list of segment copies.
        """
        if combine:
            # The split path clamps the margin to the chunk size; mirror that here
            # so the trimmed overlap equals what was actually added.
            if 0 < chunk_size < margin_size:
                margin_size = chunk_size
            last = len(wave) - 1
            pieces = []
            for segment_count, segment in enumerate(wave):
                start = 0 if segment_count == 0 else margin_size
                # margin 0 would give `-0` == 0, i.e. an empty slice.
                end = None if segment_count == last or margin_size == 0 else -margin_size
                pieces.append(segment[:, start:end])
            # Concatenate once instead of re-growing the array per segment.
            return np.concatenate(pieces, axis=-1) if pieces else None

        sample_count = wave.shape[-1]
        if sample_count == 0:
            # Nothing to split (and range() would reject a step of 0).
            return [wave.copy()]
        if chunk_size <= 0 or chunk_size > sample_count:
            chunk_size = sample_count
        if margin_size > chunk_size:
            margin_size = chunk_size
        processed_wave = []
        for segment_count, skip in enumerate(range(0, sample_count, chunk_size)):
            margin = 0 if segment_count == 0 else margin_size
            end = min(skip + chunk_size + margin_size, sample_count)
            start = skip - margin
            processed_wave.append(wave[:, start:end].copy())
            if end == sample_count:
                break
        return processed_wave

    def pad_wave(self, wave):
        """Pad a 2-channel wave and cut it into overlapping ``chunk_size`` windows.

        Args:
            wave: (np.array) Array of shape (2, n_samples).

        Returns:
            tuple: (mix_waves, pad, trim)
                - mix_waves: float32 tensor (n_chunks, 2, chunk_size) on ``self.device``
                - pad: Zero samples appended so the length is a multiple of the hop
                - trim: Zero samples added at each end (``n_fft // 2``)
        """
        n_sample = wave.shape[1]
        trim = self.model.n_fft // 2
        gen_size = self.model.chunk_size - 2 * trim
        # pad is in 1..gen_size, never 0: a 0 would make the later `[:, :-pad]` slice empty.
        pad = gen_size - n_sample % gen_size
        wave_p = np.concatenate((np.zeros((2, trim)), wave, np.zeros((2, pad)), np.zeros((2, trim))), 1)
        # Windows of chunk_size advancing by gen_size, so neighbours overlap by 2 * trim.
        windows = np.array([wave_p[:, i:i + self.model.chunk_size]
                            for i in range(0, n_sample + pad, gen_size)])
        # Stack in numpy first; torch.tensor on a list of ndarrays is very slow.
        mix_waves = torch.tensor(windows, dtype=torch.float32).to(self.device)
        return mix_waves, pad, trim

    def _process_wave(self, mix_waves, trim, pad, q: queue.Queue, _id: int):
        """Run the model over every chunk of one segment (thread target).

        Args:
            mix_waves: (torch.Tensor) Chunks produced by ``pad_wave``.
            trim: (int) Overlap samples to drop from both ends of each chunk.
            pad: (int) Zero samples to drop from the end of the joined result.
            q: (queue.Queue) Receives ``{_id: result}``.
            _id: (int) Segment index used to restore order.

        Returns:
            numpy array: Processed segment of shape (2, n_samples).
        """
        pw = []
        with torch.no_grad():
            for mix_wave in mix_waves.split(1):
                self.prog.update()
                spec = self.model.stft(mix_wave)
                processed_spec = torch.tensor(self.process(spec))
                processed_wav = self.model.istft(processed_spec.to(self.device))
                # Drop the overlap on both sides and lay the chunk out as (2, samples).
                processed_wav = processed_wav[:, :, trim:-trim].transpose(0, 1).reshape(2, -1).cpu().numpy()
                pw.append(processed_wav)
        processed_signal = np.concatenate(pw, axis=-1)[:, :-pad]
        q.put({_id: processed_signal})
        return processed_signal

    def process_wave(self, wave: np.array, mt_threads=1):
        """Separate a whole waveform, using about ``mt_threads`` worker threads.

        Args:
            wave: (np.array) Array of shape (2, n_samples).
            mt_threads: (int) Target number of segments/threads.

        Returns:
            numpy array: Processed wave of the same length as ``wave``.

        Raises:
            AssertionError: if a worker did not deliver its segment.
        """
        self.prog = tqdm(total=0)
        chunk = wave.shape[-1] // mt_threads
        waves = self.segment(wave, False, chunk)
        q = queue.Queue()
        threads = []
        try:
            for c, batch in enumerate(waves):
                mix_waves, pad, trim = self.pad_wave(batch)
                # Segments differ in length (margins), so count their real chunks.
                self.prog.total += len(mix_waves)
                thread = threading.Thread(target=self._process_wave, args=(mix_waves, trim, pad, q, c))
                thread.start()
                threads.append(thread)
        finally:
            for thread in threads:
                thread.join()
            self.prog.close()

        results = {}
        while not q.empty():
            results.update(q.get())
        # Explicit raise instead of `assert` so the check survives `python -O`;
        # the exception type is kept for existing callers.
        if len(results) != len(waves):
            raise AssertionError('Incomplete processed batches, please reduce batch size!')
        processed_batches = [results[c] for c in range(len(waves))]
        return self.segment(processed_batches, True, chunk)