Upload inference.py
Browse files- inference.py +161 -0
inference.py
ADDED
@@ -0,0 +1,161 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding: utf-8
|
2 |
+
__author__ = 'Roman Solovyev (ZFTurbo): https://github.com/ZFTurbo/'
|
3 |
+
|
4 |
+
import argparse
|
5 |
+
import time
|
6 |
+
import librosa
|
7 |
+
from tqdm import tqdm
|
8 |
+
import sys
|
9 |
+
import os
|
10 |
+
import glob
|
11 |
+
import torch
|
12 |
+
import numpy as np
|
13 |
+
import soundfile as sf
|
14 |
+
import torch.nn as nn
|
15 |
+
|
16 |
+
# Add this file's directory to sys.path so the utils module can be imported even when running under an embedded Python distribution.
|
17 |
+
current_dir = os.path.dirname(os.path.abspath(__file__))
|
18 |
+
sys.path.append(current_dir)
|
19 |
+
from utils import demix_track, demix_track_demucs, get_model_from_config
|
20 |
+
|
21 |
+
import warnings
|
22 |
+
warnings.filterwarnings("ignore")
|
23 |
+
|
24 |
+
|
25 |
+
def run_folder(model, args, config, device, verbose=False):
    """Separate every file in ``args.input_folder`` and write the stems as WAV files.

    Each file matching ``*.*`` in the input folder is loaded with librosa at
    44100 Hz, optionally normalized, passed to ``demix_track_demucs`` (when
    ``args.model_type == 'htdemucs'``) or ``demix_track`` (otherwise), and every
    requested stem is written to ``args.store_dir`` as
    ``<name>_<instrument>.wav`` with soundfile subtype ``FLOAT``.

    Args:
        model: Model object; ``model.eval()`` is called and the object is
            forwarded to the demix helper.
        args: Parsed arguments. Reads ``input_folder``, ``store_dir``,
            ``model_type``, ``disable_detailed_pbar`` and
            ``extract_instrumental``.
        config: Model configuration. Reads ``training.instruments``,
            ``training.target_instrument`` and ``inference`` (the optional
            ``normalize`` key).
        device: Forwarded unchanged to the demix helper.
        verbose: When False, the file list is wrapped in a tqdm progress bar
            whose postfix shows the current track name.

    Returns:
        None. Results are written to disk and progress is printed.
    """
    start_time = time.time()
    model.eval()
    all_mixtures_path = sorted(glob.glob(os.path.join(args.input_folder, '*.*')))
    print('Total files found: {}'.format(len(all_mixtures_path)))

    # A configured target instrument overrides the full instrument list.
    instruments = config.training.instruments
    if config.training.target_instrument is not None:
        instruments = [config.training.target_instrument]

    # makedirs creates missing parent directories too; an empty store_dir
    # (the CLI default) means "current directory" and needs no creation.
    if args.store_dir:
        os.makedirs(args.store_dir, exist_ok=True)

    if not verbose:
        all_mixtures_path = tqdm(all_mixtures_path, desc="Total progress")

    detailed_pbar = not args.disable_detailed_pbar

    # Only an explicit boolean True in the config enables normalization.
    normalize = 'normalize' in config.inference and config.inference['normalize'] is True

    for path in all_mixtures_path:
        print("Starting processing track: ", path)
        if not verbose:
            all_mixtures_path.set_postfix({'track': os.path.basename(path)})

        # Best effort: an unreadable file is reported and skipped so the
        # remaining tracks are still processed.
        try:
            mix, sr = librosa.load(path, sr=44100, mono=False)
        except Exception as e:
            print('Cannot read track: {}'.format(path))
            print('Error message: {}'.format(str(e)))
            continue

        # Convert mono to stereo if needed
        if len(mix.shape) == 1:
            mix = np.stack([mix, mix], axis=0)

        # Untouched copy of the input, used to derive the instrumental below.
        mix_orig = mix.copy()

        track_normalized = False
        if normalize:
            mono = mix.mean(0)
            mean = mono.mean()
            std = mono.std()
            # A constant (e.g. silent) track has zero std; dividing by it
            # would fill the mixture with NaN/inf, so leave such tracks as-is.
            if std > 0:
                mix = (mix - mean) / std
                track_normalized = True

        mixture = torch.tensor(mix, dtype=torch.float32)
        if args.model_type == 'htdemucs':
            res = demix_track_demucs(config, model, mixture, device, pbar=detailed_pbar)
        else:
            res = demix_track(config, model, mixture, device, pbar=detailed_pbar)

        file_name, _ = os.path.splitext(os.path.basename(path))

        for instr in instruments:
            estimates = res[instr].T
            if track_normalized:
                # Undo the input normalization on the separated stem.
                estimates = estimates * std + mean
            output_file = os.path.join(args.store_dir, f"{file_name}_{instr}.wav")
            sf.write(output_file, estimates, sr, subtype='FLOAT')

        # Output "instrumental", which is an inverse of 'vocals' (or first stem in list if 'vocals' absent)
        if args.extract_instrumental:
            instrum_file_name = os.path.join(args.store_dir, f"{file_name}_instrumental.wav")
            reference = 'vocals' if 'vocals' in instruments else instruments[0]
            estimates = res[reference].T
            if track_normalized:
                estimates = estimates * std + mean
            sf.write(instrum_file_name, mix_orig.T - estimates, sr, subtype='FLOAT')

    # NOTE(review): presumably this pause lets progress-bar output flush
    # before the summary line is printed - confirm before removing.
    time.sleep(1)
    print("Elapsed time: {:.2f} sec".format(time.time() - start_time))
|
101 |
+
|
102 |
+
|
103 |
+
def proc_folder(args):
    """Parse command-line options, load the model and run inference on a folder.

    Args:
        args: List of argument strings to parse, or ``None`` to parse the
            process command line (``sys.argv``).

    Returns:
        None. The work is done by ``run_folder``, which writes the results.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_type", type=str, default='mdx23c',
                        help="One of bandit, bandit_v2, bs_roformer, htdemucs, mdx23c, mel_band_roformer, scnet, scnet_unofficial, segm_models, swin_upernet, torchseg")
    parser.add_argument("--config_path", type=str, help="path to config file")
    parser.add_argument("--start_check_point", type=str, default='', help="Initial checkpoint to valid weights")
    parser.add_argument("--input_folder", type=str, help="folder with mixtures to process")
    parser.add_argument("--store_dir", default="", type=str, help="path to store results as wav file")
    parser.add_argument("--device_ids", nargs='+', type=int, default=0, help='list of gpu ids')
    parser.add_argument("--extract_instrumental", action='store_true', help="invert vocals to get instrumental if provided")
    parser.add_argument("--disable_detailed_pbar", action='store_true', help="disable detailed progress bar")
    parser.add_argument("--force_cpu", action='store_true', help="Force the use of CPU even if CUDA is available")
    if args is None:
        args = parser.parse_args()
    else:
        args = parser.parse_args(args)

    # --device_ids is the int default (0) when omitted and a list of ints
    # when given on the command line (nargs='+').
    explicit_device_ids = not isinstance(args.device_ids, int)
    primary_gpu = args.device_ids[0] if explicit_device_ids else args.device_ids

    # Older torch builds have no torch.backends.mps attribute at all.
    mps_backend = getattr(torch.backends, "mps", None)

    if args.force_cpu:
        device = "cpu"
    elif torch.cuda.is_available():
        print('CUDA is available, use --force_cpu to disable it.')
        device = f'cuda:{primary_gpu}'
    elif mps_backend is not None and mps_backend.is_available():
        device = "mps"
    else:
        device = "cpu"

    print("Using device: ", device)

    model_load_start_time = time.time()
    torch.backends.cudnn.benchmark = True

    model, config = get_model_from_config(args.model_type, args.config_path)
    if args.start_check_point != '':
        print('Start from checkpoint: {}'.format(args.start_check_point))
        if args.model_type == 'htdemucs':
            # SECURITY NOTE: weights_only=False unpickles arbitrary Python
            # objects - only load htdemucs checkpoints from trusted sources.
            state_dict = torch.load(args.start_check_point, map_location=device, weights_only=False)
            # Fix for htdemucs pretrained models
            if 'state' in state_dict:
                state_dict = state_dict['state']
        else:
            state_dict = torch.load(args.start_check_point, map_location=device, weights_only=True)
        model.load_state_dict(state_dict)
    print("Instruments: {}".format(config.training.instruments))

    # In case multiple CUDA GPUs are used and --device_ids arg is passed.
    # Only wrap when actually running on CUDA: with --force_cpu on a CUDA
    # machine the wrapper would expect the model on device_ids[0] while
    # it is moved to the CPU below.
    if explicit_device_ids and device.startswith("cuda"):
        model = nn.DataParallel(model, device_ids=args.device_ids)

    model = model.to(device)

    print("Model load time: {:.2f} sec".format(time.time() - model_load_start_time))

    run_folder(model, args, config, device, verbose=True)
|
158 |
+
|
159 |
+
|
160 |
+
if __name__ == "__main__":
    # Script entry point: passing None makes proc_folder parse the
    # process command line instead of an explicit argument list.
    proc_folder(args=None)
|