initial commit
- .gitattributes +5 -0
- data/models/UVR-MDX-NET-Inst_HQ_3.onnx +3 -0
- data/samples/result.mp4 +3 -0
- data/samples/temp.mp3 +3 -0
- data/samples/temp.mp4 +3 -0
- data/samples/temp_no_vocals.wav +3 -0
- data/samples/temp_vocals.wav +3 -0
- demo.py +16 -0
- model.py +123 -0
- packages.txt +2 -0
- requirements.txt +11 -0

.gitattributes
CHANGED
@@ -33,3 +33,8 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+data/samples/result.mp4 filter=lfs diff=lfs merge=lfs -text
+data/samples/temp_no_vocals.wav filter=lfs diff=lfs merge=lfs -text
+data/samples/temp_vocals.wav filter=lfs diff=lfs merge=lfs -text
+data/samples/temp.mp3 filter=lfs diff=lfs merge=lfs -text
+data/samples/temp.mp4 filter=lfs diff=lfs merge=lfs -text

data/models/UVR-MDX-NET-Inst_HQ_3.onnx
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:317554b07fe1ea5279a77f2b1520a41ea4b93432560c4ffd08792c30fddf9adc
+size 66759214

data/samples/result.mp4
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:f3a5c839f552d27b110e7db77ac74cb41a5c51c6c8376a75814aa4fc5a0c5921
+size 16601916

data/samples/temp.mp3
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:1322f661bf6c9b22a6e30283933f223358ad68fab06d73017cb80363e6e3ff50
+size 4749941

data/samples/temp.mp4
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:302dd0780f1420599fa5bc179eb766981aac39883b4b79f8f0273f94d11d2542
+size 14761845

data/samples/temp_no_vocals.wav
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:96ea44ba19641369a63e5ab8ec403e204b88e7aab35b7670f6af2b6811d912de
+size 26179568

data/samples/temp_vocals.wav
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:087b4afcc655ab2b0c0e25e196ee559bb661c996438ff897e5ef671cd51f4564
+size 26179568

demo.py
ADDED
@@ -0,0 +1,16 @@
+import gradio as gr
+
+from model import get_karaoke
+
+with gr.Blocks() as demo:
+    with gr.Row():
+        with gr.Column(), gr.Row():
+            url = gr.Textbox(placeholder="YouTube video URL", label="URL")
+
+        with gr.Column():
+            outputs = gr.PlayableVideo()
+
+    transcribe_btn = gr.Button("YouTube Karaoke")
+    transcribe_btn.click(get_karaoke, inputs=url, outputs=outputs)
+
+demo.launch(debug=True)

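The app wiring above is a thin layer over the pipeline: the button forwards the URL to get_karaoke and plays back whatever file it returns. As a minimal sketch, the same pipeline can be driven headlessly from the repository root, assuming the LFS model checkpoint has been pulled (the URL is the example referenced in model.py):

    from model import get_karaoke

    # Runs download -> audio extraction -> source separation -> remux and
    # returns the rendered file path (data/samples/result.mp4).
    result_path = get_karaoke("https://www.youtube.com/watch?v=1jZEyU_eO1s")
    print(result_path)
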
model.py
ADDED
@@ -0,0 +1,123 @@
+import numpy as np
+import soundfile as sf
+
+# import torch
+from moviepy import AudioFileClip, VideoFileClip
+from pydub import AudioSegment
+from pytubefix import YouTube
+from pytubefix.cli import on_progress
+
+# from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline
+from source_separation import Predictor
+
+
+def download_from_youtube(url, folder_path):
+    yt = YouTube(url, on_progress_callback=on_progress)
+    print(yt.title)
+
+    ys = yt.streams.get_highest_resolution()
+    ys.download(output_path=folder_path, filename="temp.mp4")
+
+
+def separate_video_and_audio(video_path, audio_path):
+    # Load the video clip
+    video_clip = VideoFileClip(video_path)
+
+    # Extract the audio from the video clip
+    audio_clip = video_clip.audio
+
+    # Write the audio to a separate file
+    audio_clip.write_audiofile(audio_path)
+
+
+def load_audio(audio_path, sample_rate=44_100):
+    audio = AudioSegment.from_file(audio_path)
+
+    print("Entering the preprocessing of audio")
+
+    # Standardize the audio format
+    audio = audio.set_frame_rate(sample_rate)
+    audio = audio.set_sample_width(2)  # Set bit depth to 16-bit
+    audio = audio.set_channels(1)  # Set to mono
+
+    print("Audio file converted to WAV format")
+
+    # Calculate the gain to be applied
+    target_dBFS = -20
+    gain = target_dBFS - audio.dBFS
+    print(f"Calculating the gain needed for the audio: {gain} dB")
+
+    # Normalize volume and limit gain range to between -3 and 3
+    normalized_audio = audio.apply_gain(min(max(gain, -3), 3))
+
+    waveform = np.array(normalized_audio.get_array_of_samples(), dtype=np.float32)
+    max_amplitude = np.max(np.abs(waveform))
+    waveform /= max_amplitude  # Peak-normalize to [-1, 1]
+
+    print(f"waveform shape: {waveform.shape}")
+    print("waveform in np ndarray, dtype=" + str(waveform.dtype))
+
+    return waveform, sample_rate
+
+
+args = {
+    "model_path": "data/models/UVR-MDX-NET-Inst_HQ_3.onnx",
+    "denoise": True,
+    "margin": 44100,
+    "chunks": 15,
+    "n_fft": 6144,
+    "dim_t": 8,
+    "dim_f": 3072,
+}
+
+separate_predictor = Predictor(args=args, device="cpu")
+
+
+def source_separation(waveform):
+    """
+    Separate the audio into vocals and non-vocals using the module-level predictor.
+
+    Args:
+        waveform (np.ndarray): Mono audio waveform at 44.1 kHz.
+
+    Returns
+    -------
+    tuple: The separated vocals and no-vocals waveforms, one channel each.
+    """
+    vocals, no_vocals = separate_predictor.predict(waveform)
+
+    vocals = vocals[:, 0]  # vocals is stereo, only use one channel
+    no_vocals = no_vocals[:, 0]  # no_vocals is stereo, only use one channel
+
+    return vocals, no_vocals
+
+
+def export_to_wav(vocals, no_vocals, sample_rate, folder_path):
+    """Export separated audio to WAV files."""
+    sf.write(folder_path + "temp_vocals.wav", vocals, sample_rate)
+    sf.write(folder_path + "temp_no_vocals.wav", no_vocals, sample_rate)
+
+
+def combine_video_and_audio(video_path, no_vocals_path, output_path):
+    my_clip = VideoFileClip(video_path, audio=False)
+    audio_background = AudioFileClip(no_vocals_path)
+    my_clip.audio = audio_background
+    my_clip.write_videofile(output_path)
+
+
+# Example: https://www.youtube.com/watch?v=1jZEyU_eO1s
+def get_karaoke(url):
+    folder_path = "data/samples/"
+    video_path = folder_path + "temp.mp4"
+    audio_path = folder_path + "temp.mp3"
+    no_vocals_path = folder_path + "temp_no_vocals.wav"
+    output_path = folder_path + "result.mp4"
+
+    download_from_youtube(url, folder_path)
+    separate_video_and_audio(video_path, audio_path)
+    waveform, sample_rate = load_audio(audio_path)
+    vocals, no_vocals = source_separation(waveform)
+    export_to_wav(vocals, no_vocals, sample_rate, folder_path)
+    combine_video_and_audio(video_path, no_vocals_path, output_path)
+    return output_path

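For reference, the STFT windowing these args imply in source_separation.py can be worked out directly. A quick sketch, with the constants copied from the args dict above and from ConvTDFNet's hop default:

    hop, dim_t, n_fft = 1024, 8, 6144   # hop default; dim_t and n_fft from args
    chunk_size = hop * (2**dim_t - 1)   # samples per STFT pass: 261120 (~5.9 s at 44.1 kHz)
    trim = n_fft // 2                   # samples trimmed from each edge: 3072
    gen_size = chunk_size - 2 * trim    # net samples produced per pass: 254976
    print(chunk_size, trim, gen_size)

On top of this, demix() cuts the waveform into 15-second chunks (chunks * 44100 samples) with a one-second margin (44100 samples) of overlap on each side, which is trimmed away when the chunks are stitched back together.
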
packages.txt
ADDED
@@ -0,0 +1,2 @@
+ffmpeg
+libsndfile1

requirements.txt
ADDED
@@ -0,0 +1,11 @@
+onnxruntime==1.20.1
+torch==2.5.1
+tqdm==4.67.1
+llvmlite==0.43.0
+librosa==0.10.2.post1
+pydub==0.25.1
+transformers==4.47.0
+pytubefix==8.8.1
+accelerate==1.2.0
+moviepy==2.1.1
+gradio==5.8.0

source_separation.py
ADDED
@@ -0,0 +1,291 @@
+# Copyright (c) 2023 seanghay
+#
+# This code is from an unlicensed repository.
+#
+# Note: This code has been modified to fit the context of this repository.
+# This code is included in an MIT-licensed repository.
+# The repository's MIT license does not apply to this code.
+
+# This code is modified from https://github.com/seanghay/uvr-mdx-infer/blob/main/separate.py
+
+import numpy as np
+import onnxruntime as ort
+import torch
+from tqdm import tqdm
+
+
+class ConvTDFNet:
+    """
+    ConvTDFNet - Convolutional Temporal Frequency Domain Network.
+    """
+
+    def __init__(self, target_name, L, dim_f, dim_t, n_fft, hop=1024):
+        """
+        Initialize ConvTDFNet.
+
+        Args:
+            target_name (str): The target name for separation.
+            L (int): Number of layers.
+            dim_f (int): Dimension in the frequency domain.
+            dim_t (int): Dimension in the time domain (log2).
+            n_fft (int): FFT size.
+            hop (int, optional): Hop size. Defaults to 1024.
+
+        Returns
+        -------
+        None
+        """
+        super(ConvTDFNet, self).__init__()
+        self.dim_c = 4
+        self.dim_f = dim_f
+        self.dim_t = 2**dim_t
+        self.n_fft = n_fft
+        self.hop = hop
+        self.n_bins = self.n_fft // 2 + 1
+        self.chunk_size = hop * (self.dim_t - 1)
+        self.window = torch.hann_window(window_length=self.n_fft, periodic=True)
+        self.target_name = target_name
+
+        out_c = self.dim_c * 4 if target_name == "*" else self.dim_c
+
+        self.freq_pad = torch.zeros([1, out_c, self.n_bins - self.dim_f, self.dim_t])
+        self.n = L // 2
+
+    def stft(self, x):
+        """
+        Perform Short-Time Fourier Transform (STFT).
+
+        Args:
+            x (torch.Tensor): Input waveform.
+
+        Returns
+        -------
+        torch.Tensor: STFT of the input waveform.
+        """
+        x = x.reshape([-1, self.chunk_size])
+        x = torch.stft(
+            x,
+            n_fft=self.n_fft,
+            hop_length=self.hop,
+            window=self.window,
+            center=True,
+            return_complex=True,
+        )
+        x = torch.view_as_real(x)
+        x = x.permute([0, 3, 1, 2])
+        x = x.reshape([-1, 2, 2, self.n_bins, self.dim_t]).reshape(
+            [-1, self.dim_c, self.n_bins, self.dim_t]
+        )
+        return x[:, :, : self.dim_f]
+
+    def istft(self, x, freq_pad=None):
+        """
+        Perform Inverse Short-Time Fourier Transform (ISTFT).
+
+        Args:
+            x (torch.Tensor): Input STFT.
+            freq_pad (torch.Tensor, optional): Frequency padding. Defaults to None.
+
+        Returns
+        -------
+        torch.Tensor: Inverse STFT of the input.
+        """
+        freq_pad = self.freq_pad.repeat([x.shape[0], 1, 1, 1]) if freq_pad is None else freq_pad
+        x = torch.cat([x, freq_pad], -2)
+        c = 4 * 2 if self.target_name == "*" else 2
+        x = x.reshape([-1, c, 2, self.n_bins, self.dim_t]).reshape([-1, 2, self.n_bins, self.dim_t])
+        x = x.permute([0, 2, 3, 1])
+        x = x.contiguous()
+        x = torch.view_as_complex(x)
+        x = torch.istft(x, n_fft=self.n_fft, hop_length=self.hop, window=self.window, center=True)
+        return x.reshape([-1, c, self.chunk_size])
+
+
+class Predictor:
+    """
+    Predictor class for source separation using ConvTDFNet and ONNX Runtime.
+    """
+
+    def __init__(self, args, device):
+        """
+        Initialize the Predictor.
+
+        Args:
+            args (dict): Configuration arguments.
+            device (str): Device to run the model ('cuda' or 'cpu').
+
+        Returns
+        -------
+        None
+
+        Raises
+        ------
+        ValueError: If the provided device is not 'cuda' or 'cpu'.
+        """
+        self.args = args
+        self.model_ = ConvTDFNet(
+            target_name="vocals",
+            L=11,
+            dim_f=args["dim_f"],
+            dim_t=args["dim_t"],
+            n_fft=args["n_fft"],
+        )
+
+        if device == "cuda":
+            self.model = ort.InferenceSession(
+                args["model_path"], providers=["CUDAExecutionProvider"]
+            )
+        elif device == "cpu":
+            self.model = ort.InferenceSession(
+                args["model_path"], providers=["CPUExecutionProvider"]
+            )
+        else:
+            raise ValueError("Device must be either 'cuda' or 'cpu'")
+
+    def demix(self, mix):
+        """
+        Separate the sources from the input mix.
+
+        Args:
+            mix (np.ndarray): Input mixture signal.
+
+        Returns
+        -------
+        np.ndarray: Separated sources.
+
+        Raises
+        ------
+        AssertionError: If margin is zero.
+        """
+        samples = mix.shape[-1]
+        margin = self.args["margin"]
+        chunk_size = self.args["chunks"] * 44100
+
+        assert margin != 0, "Margin cannot be zero!"
+
+        margin = min(margin, chunk_size)
+
+        segmented_mix = {}
+
+        if self.args["chunks"] == 0 or samples < chunk_size:
+            chunk_size = samples
+
+        counter = -1
+        for skip in range(0, samples, chunk_size):
+            counter += 1
+            s_margin = 0 if counter == 0 else margin
+            end = min(skip + chunk_size + margin, samples)
+            start = skip - s_margin
+            segmented_mix[skip] = mix[:, start:end].copy()
+            if end == samples:
+                break
+
+        sources = self.demix_base(segmented_mix, margin_size=margin)
+        return sources
+
+    def demix_base(self, mixes, margin_size):
+        """
+        Base function for source separation.
+
+        Args:
+            mixes (dict): Dictionary of segmented mixtures.
+            margin_size (int): Size of the margin.
+
+        Returns
+        -------
+        np.ndarray: Separated sources.
+        """
+        chunked_sources = []
+        progress_bar = tqdm(total=len(mixes))
+        progress_bar.set_description("Source separation")
+
+        for mix in mixes:
+            cmix = mixes[mix]
+            sources = []
+            n_sample = cmix.shape[1]
+            model = self.model_
+            trim = model.n_fft // 2
+            gen_size = model.chunk_size - 2 * trim
+            pad = gen_size - n_sample % gen_size
+            mix_p = np.concatenate(
+                (np.zeros((2, trim)), cmix, np.zeros((2, pad)), np.zeros((2, trim))), 1
+            )
+            mix_waves = []
+            i = 0
+            while i < n_sample + pad:
+                waves = np.array(mix_p[:, i : i + model.chunk_size])
+                mix_waves.append(waves)
+                i += gen_size
+
+            mix_waves = torch.tensor(np.array(mix_waves), dtype=torch.float32)
+
+            with torch.no_grad():
+                _ort = self.model
+                spek = model.stft(mix_waves)
+                if self.args["denoise"]:
+                    spec_pred = (
+                        -_ort.run(None, {"input": -spek.cpu().numpy()})[0] * 0.5
+                        + _ort.run(None, {"input": spek.cpu().numpy()})[0] * 0.5
+                    )
+                    tar_waves = model.istft(torch.tensor(spec_pred))
+                else:
+                    tar_waves = model.istft(
+                        torch.tensor(_ort.run(None, {"input": spek.cpu().numpy()})[0])
+                    )
+                tar_signal = (
+                    tar_waves[:, :, trim:-trim].transpose(0, 1).reshape(2, -1).numpy()[:, :-pad]
+                )
+
+                start = 0 if mix == 0 else margin_size
+                end = None if mix == list(mixes.keys())[::-1][0] else -margin_size
+
+                if margin_size == 0:
+                    end = None
+
+                sources.append(tar_signal[:, start:end])
+
+                progress_bar.update(1)
+
+            chunked_sources.append(sources)
+        _sources = np.concatenate(chunked_sources, axis=-1)
+
+        progress_bar.close()
+        return _sources
+
+    def predict(self, mix):
+        """
+        Predict the separated sources from the input mix.
+
+        Args:
+            mix (np.ndarray): Input mixture signal.
+
+        Returns
+        -------
+        tuple: Tuple containing the mixture minus the separated sources and the separated sources.
+        """
+        if mix.ndim == 1:
+            mix = np.asfortranarray([mix, mix])
+
+        tail = mix.shape[1] % (self.args["chunks"] * 44100)
+        if mix.shape[1] % (self.args["chunks"] * 44100) != 0:
+            mix = np.pad(
+                mix,
+                (
+                    (0, 0),
+                    (
+                        0,
+                        self.args["chunks"] * 44100 - mix.shape[1] % (self.args["chunks"] * 44100),
+                    ),
+                ),
+            )
+
+        mix = mix.T
+        sources = self.demix(mix.T)
+        opt = sources[0].T
+
+        if tail != 0:
+            return (
+                (mix - opt)[: -(self.args["chunks"] * 44100 - tail), :],
+                opt[: -(self.args["chunks"] * 44100 - tail), :],
+            )
+        return ((mix - opt), opt)

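Since Predictor only needs its args dict and the ONNX checkpoint, it can also be driven outside the Gradio app. A minimal sketch mirroring how model.py invokes it; song.wav is a hypothetical 44.1 kHz input file, and the args are copied from model.py:

    import soundfile as sf

    from source_separation import Predictor

    # Same configuration model.py uses for the bundled UVR-MDX-NET checkpoint.
    args = {
        "model_path": "data/models/UVR-MDX-NET-Inst_HQ_3.onnx",
        "denoise": True,
        "margin": 44100,
        "chunks": 15,
        "n_fft": 6144,
        "dim_t": 8,
        "dim_f": 3072,
    }
    predictor = Predictor(args=args, device="cpu")

    # song.wav is a placeholder path. Downmix to mono; predict() duplicates
    # a mono signal into the stereo layout the model expects.
    waveform, sr = sf.read("song.wav", dtype="float32")
    if waveform.ndim > 1:
        waveform = waveform.mean(axis=1)

    vocals, no_vocals = predictor.predict(waveform)  # both come back stereo
    sf.write("song_instrumental.wav", no_vocals[:, 0], sr)  # keep one channel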