import argparse
import logging
import os
import sys

import librosa
import torch
import tqdm

from .data.data import EvalDataLoader, EvalDataset
from . import distrib
from .utils import remove_pad
from .utils import bold, deserialize_model, LogProgress

logger = logging.getLogger(__name__)

def load_model():
    global device
    global model
    global pkg
    print("Loading svoice model if available...")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    pkg = torch.load('checkpoint.th', map_location=device)
    if 'model' in pkg:
        model = pkg['model']
    else:
        model = pkg
    model = deserialize_model(model)
    logger.debug(model)
    model.eval()
    model.to(device)
    print("svoice model loaded.")
    print("Device: {}".format(device))

parser = argparse.ArgumentParser("Speech separation using MulCat blocks")
parser.add_argument("model_path", type=str, help="Model name")
parser.add_argument("out_dir", type=str, default="exp/result",
                    help="Directory to put the separated wav files")
parser.add_argument("--mix_dir", type=str, default=None,
                    help="Directory including mix wav files")
parser.add_argument("--mix_json", type=str, default=None,
                    help="Json file including mix wav files")
parser.add_argument('--device', default="cuda")
parser.add_argument("--sample_rate", default=8000,
                    type=int, help="Sample rate")
parser.add_argument("--batch_size", default=1, type=int, help="Batch size")
parser.add_argument('-v', '--verbose', action='store_const', const=logging.DEBUG,
                    default=logging.INFO, help="More logging")

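# Example invocation (a sketch, not taken from this file: the module path
# "svoice.separate" and the checkpoint/mixture paths are assumptions and depend
# on how this file is packaged and on your data layout):
#
#   python -m svoice.separate checkpoint.th exp/result --mix_dir mix/ --sample_rate 8000
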
def save_wavs(estimate_source, mix_sig, lengths, filenames, out_dir, sr=16000):
    # Remove padding and flatten
    flat_estimate = remove_pad(estimate_source, lengths)
    mix_sig = remove_pad(mix_sig, lengths)
    # Write results
    for i, filename in enumerate(filenames):
        # Drop the extension properly (str.strip(".wav") would also eat
        # leading/trailing 'w', 'a', 'v' characters from the stem)
        filename = os.path.join(
            out_dir, os.path.splitext(os.path.basename(filename))[0])
        write(mix_sig[i], filename + ".wav", sr=sr)
        C = flat_estimate[i].shape[0]
        # future support for wave playing
        for c in range(C):
            write(flat_estimate[i][c], filename + f"_s{c + 1}.wav", sr=sr)

def write(inputs, filename, sr=8000):
    librosa.output.write_wav(filename, inputs, sr, norm=True)

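# Note: librosa.output.write_wav was removed in librosa 0.8, so write() above
# requires librosa < 0.8. A minimal alternative sketch using the soundfile
# package (an assumption; soundfile is not imported in this file) would be:
#
#   import soundfile as sf
#   sf.write(filename, inputs, sr)
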
def separate_demo(mix_dir='mix/', batch_size=1, sample_rate=16000):
    mix_json = None
    out_dir = 'separated'
    # Load data
    eval_dataset = EvalDataset(
        mix_dir,
        mix_json,
        batch_size=batch_size,
        sample_rate=sample_rate,
    )
    eval_loader = distrib.loader(
        eval_dataset, batch_size=1, klass=EvalDataLoader)
    if distrib.rank == 0:
        os.makedirs(out_dir, exist_ok=True)
    distrib.barrier()

    with torch.no_grad():
        for i, data in enumerate(tqdm.tqdm(eval_loader, ncols=120)):
            # Get batch data
            mixture, lengths, filenames = data
            mixture = mixture.to(device)
            lengths = lengths.to(device)
            # Forward
            estimate_sources = model(mixture)[-1]
            # Save wav files
            save_wavs(estimate_sources, mixture, lengths,
                      filenames, out_dir, sr=sample_rate)

    separated_files = [os.path.join(out_dir, f) for f in os.listdir(out_dir)]
    separated_files = [os.path.abspath(f) for f in separated_files]
    separated_files = [f for f in separated_files if not f.endswith('original.wav')]
    return separated_files

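# Example usage of the demo path (a sketch; it assumes load_model() has already
# populated the module-level `model` and `device`, that 'checkpoint.th' sits
# next to this file, and that 'mix/' contains the mixture wav files):
#
#   load_model()
#   separated = separate_demo(mix_dir='mix/', sample_rate=8000)
#   print(separated)
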
def get_mix_paths(args):
    mix_dir = None
    mix_json = None
    # Resolve mix dir: prefer a dataset config (args.dset.mix_dir) if present,
    # otherwise fall back to the plain --mix_dir argument.
    try:
        if args.dset.mix_dir:
            mix_dir = args.dset.mix_dir
    except AttributeError:
        mix_dir = args.mix_dir
    # Resolve mix json the same way.
    try:
        if args.dset.mix_json:
            mix_json = args.dset.mix_json
    except AttributeError:
        mix_json = args.mix_json
    return mix_dir, mix_json

def separate(args, model=None, local_out_dir=None):
    mix_dir, mix_json = get_mix_paths(args)
    if not mix_json and not mix_dir:
        logger.error("Must provide mix_dir or mix_json! "
                     "When providing mix_dir, mix_json is ignored.")
        return
    # Load model
    if not model:
        pkg = torch.load(args.model_path)
        if 'model' in pkg:
            model = pkg['model']
        else:
            model = pkg
        model = deserialize_model(model)
        logger.debug(model)
    model.eval()
    model.to(args.device)
    if local_out_dir:
        out_dir = local_out_dir
    else:
        out_dir = args.out_dir

    # Load data
    eval_dataset = EvalDataset(
        mix_dir,
        mix_json,
        batch_size=args.batch_size,
        sample_rate=args.sample_rate,
    )
    eval_loader = distrib.loader(
        eval_dataset, batch_size=1, klass=EvalDataLoader)
    if distrib.rank == 0:
        os.makedirs(out_dir, exist_ok=True)
    distrib.barrier()

    with torch.no_grad():
        for i, data in enumerate(tqdm.tqdm(eval_loader, ncols=120)):
            # Get batch data
            mixture, lengths, filenames = data
            mixture = mixture.to(args.device)
            lengths = lengths.to(args.device)
            # Forward
            estimate_sources = model(mixture)[-1]
            # Save wav files
            save_wavs(estimate_sources, mixture, lengths,
                      filenames, out_dir, sr=args.sample_rate)

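# separate() can also be called programmatically with a plain namespace whose
# fields mirror the CLI flags above (a sketch; the checkpoint and mixture paths
# are assumptions):
#
#   from argparse import Namespace
#   args = Namespace(model_path='checkpoint.th', out_dir='exp/result',
#                    mix_dir='mix/', mix_json=None, device='cpu',
#                    sample_rate=8000, batch_size=1)
#   separate(args)
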
if __name__ == "__main__":
    args = parser.parse_args()
    logging.basicConfig(stream=sys.stderr, level=args.verbose)
    logger.debug(args)
    separate(args, local_out_dir=args.out_dir)