ArianatorQualquer committed on
Commit
a8a84ee
·
verified ·
1 Parent(s): 7100ed1

Upload inference.py

Browse files
Files changed (1) hide show
  1. inference.py +161 -0
inference.py ADDED
@@ -0,0 +1,161 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding: utf-8
2
+ __author__ = 'Roman Solovyev (ZFTurbo): https://github.com/ZFTurbo/'
3
+
4
+ import argparse
5
+ import time
6
+ import librosa
7
+ from tqdm import tqdm
8
+ import sys
9
+ import os
10
+ import glob
11
+ import torch
12
+ import numpy as np
13
+ import soundfile as sf
14
+ import torch.nn as nn
15
+
16
+ # Add this script's directory to sys.path so that the utils module can be imported even when running under the embedded version of Python.
17
+ current_dir = os.path.dirname(os.path.abspath(__file__))
18
+ sys.path.append(current_dir)
19
+ from utils import demix_track, demix_track_demucs, get_model_from_config
20
+
21
+ import warnings
22
+ warnings.filterwarnings("ignore")
23
+
24
+
25
def run_folder(model, args, config, device, verbose=False):
    """Separate every file in ``args.input_folder`` and write the stems as WAV files.

    For each file matched by ``<input_folder>/*.*`` the audio is loaded at
    44100 Hz, optionally normalized, passed through the demixing helper, and
    each requested stem is written to ``<store_dir>/<track>_<instrument>.wav``
    as 32-bit float WAV.

    Args:
        model: Separation model; switched to eval mode and handed to
            ``demix_track`` / ``demix_track_demucs``.
        args: Parsed options. Reads ``input_folder``, ``store_dir``,
            ``disable_detailed_pbar``, ``model_type`` and
            ``extract_instrumental``.
        config: Model configuration. Reads ``training.instruments``,
            ``training.target_instrument`` and the optional
            ``inference['normalize']`` flag.
        device: Device identifier forwarded unchanged to the demixing helper.
        verbose: When False, the file list is wrapped in a tqdm progress bar
            whose postfix shows the current track name.

    Returns:
        None. Results are written to disk and progress is printed.
    """
    start_time = time.time()
    model.eval()
    all_mixtures_path = glob.glob(args.input_folder + '/*.*')
    all_mixtures_path.sort()
    print('Total files found: {}'.format(len(all_mixtures_path)))

    # A configured target instrument overrides the full instrument list.
    instruments = config.training.instruments
    if config.training.target_instrument is not None:
        instruments = [config.training.target_instrument]

    # makedirs (unlike mkdir) also creates missing parent directories.
    # An empty store_dir means "current directory", so there is nothing to create.
    if args.store_dir:
        os.makedirs(args.store_dir, exist_ok=True)

    if not verbose:
        all_mixtures_path = tqdm(all_mixtures_path, desc="Total progress")

    detailed_pbar = not args.disable_detailed_pbar

    # The config does not change while the folder is processed, so look the
    # flag up once instead of once per stem.
    normalize = 'normalize' in config.inference and config.inference['normalize'] is True

    for path in all_mixtures_path:
        print("Starting processing track: ", path)
        if not verbose:
            all_mixtures_path.set_postfix({'track': os.path.basename(path)})
        try:
            mix, sr = librosa.load(path, sr=44100, mono=False)
        except Exception as e:
            # Best effort: report unreadable files and move on to the next one.
            print('Cannot read track: {}'.format(path))
            print('Error message: {}'.format(str(e)))
            continue

        # Convert mono to stereo if needed
        if len(mix.shape) == 1:
            mix = np.stack([mix, mix], axis=0)

        mix_orig = mix.copy()

        # Normalize with the statistics of the mono downmix. A silent (or empty)
        # track has std == 0 (or NaN); dividing by it would fill the signal with
        # NaN/inf, so such tracks are processed without normalization.
        track_normalized = False
        if normalize:
            mono = mix.mean(0)
            mean = mono.mean()
            std = mono.std()
            if std > 0:
                mix = (mix - mean) / std
                track_normalized = True

        mixture = torch.tensor(mix, dtype=torch.float32)
        if args.model_type == 'htdemucs':
            res = demix_track_demucs(config, model, mixture, device, pbar=detailed_pbar)
        else:
            res = demix_track(config, model, mixture, device, pbar=detailed_pbar)

        file_name, _ = os.path.splitext(os.path.basename(path))

        # De-normalize each stem once and keep it for the instrumental below.
        stems = {}
        for instr in instruments:
            estimates = res[instr].T
            if track_normalized:
                estimates = estimates * std + mean
            stems[instr] = estimates
            output_file = os.path.join(args.store_dir, f"{file_name}_{instr}.wav")
            sf.write(output_file, estimates, sr, subtype='FLOAT')

        # Output "instrumental", which is an inverse of 'vocals' (or first stem in list if 'vocals' absent)
        if args.extract_instrumental:
            instrum_file_name = os.path.join(args.store_dir, f"{file_name}_instrumental.wav")
            reference = 'vocals' if 'vocals' in instruments else instruments[0]
            sf.write(instrum_file_name, mix_orig.T - stems[reference], sr, subtype='FLOAT')

    time.sleep(1)
    print("Elapsed time: {:.2f} sec".format(time.time() - start_time))
101
+
102
+
103
def proc_folder(args):
    """Parse options, load the model and run separation over a folder.

    Args:
        args: List of command-line style strings to parse, or None to read
            the options from ``sys.argv``.

    Returns:
        None. Builds the model via ``get_model_from_config``, optionally loads
        a checkpoint, moves the model to the selected device and calls
        ``run_folder``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_type", type=str, default='mdx23c',
                        help="One of bandit, bandit_v2, bs_roformer, htdemucs, mdx23c, mel_band_roformer, scnet, scnet_unofficial, segm_models, swin_upernet, torchseg")
    parser.add_argument("--config_path", type=str, help="path to config file")
    parser.add_argument("--start_check_point", type=str, default='', help="Initial checkpoint to valid weights")
    parser.add_argument("--input_folder", type=str, help="folder with mixtures to process")
    parser.add_argument("--store_dir", default="", type=str, help="path to store results as wav file")
    parser.add_argument("--device_ids", nargs='+', type=int, default=0, help='list of gpu ids')
    parser.add_argument("--extract_instrumental", action='store_true', help="invert vocals to get instrumental if provided")
    parser.add_argument("--disable_detailed_pbar", action='store_true', help="disable detailed progress bar")
    parser.add_argument("--force_cpu", action='store_true', help="Force the use of CPU even if CUDA is available")
    # parse_args(None) reads sys.argv[1:], so both call styles share one path.
    args = parser.parse_args(args)

    # --device_ids is the int default (0) when omitted and a list when passed.
    device_ids = args.device_ids
    ids_given_as_list = not isinstance(device_ids, int)

    device = "cpu"
    if args.force_cpu:
        device = "cpu"
    elif torch.cuda.is_available():
        print('CUDA is available, use --force_cpu to disable it.')
        primary_id = device_ids[0] if ids_given_as_list else device_ids
        device = f'cuda:{primary_id}'
    elif torch.backends.mps.is_available():
        device = "mps"

    print("Using device: ", device)

    model_load_start_time = time.time()
    torch.backends.cudnn.benchmark = True

    model, config = get_model_from_config(args.model_type, args.config_path)
    if args.start_check_point != '':
        print('Start from checkpoint: {}'.format(args.start_check_point))
        if args.model_type == 'htdemucs':
            # SECURITY NOTE: weights_only=False unpickles arbitrary Python objects,
            # so only load htdemucs checkpoints that come from a trusted source.
            state_dict = torch.load(args.start_check_point, map_location=device, weights_only=False)
            # Fix for htdemucs pretrained models
            if 'state' in state_dict:
                state_dict = state_dict['state']
        else:
            state_dict = torch.load(args.start_check_point, map_location=device, weights_only=True)
        model.load_state_dict(state_dict)
    print("Instruments: {}".format(config.training.instruments))

    # in case multiple CUDA GPUs are used and --device_ids arg is passed.
    # Only wrap when actually running on CUDA: with --force_cpu (or on MPS)
    # DataParallel would expect the parameters on cuda:<device_ids[0]> and
    # fail in the forward pass because the model lives on another device.
    if ids_given_as_list and device.startswith('cuda'):
        model = nn.DataParallel(model, device_ids=device_ids)

    model = model.to(device)

    print("Model load time: {:.2f} sec".format(time.time() - model_load_start_time))

    run_folder(model, args, config, device, verbose=True)
158
+
159
+
160
if __name__ == "__main__":
    # Script entry point: passing None makes proc_folder parse sys.argv.
    proc_folder(args=None)