# svoice_demo / svoice / separate.py
# NOTE: the following Hugging Face file-page metadata was captured with the
# source and is kept here as a comment so the file parses as Python:
#   ahmedghani's picture / added samples / commit 3d7e2e4
#   raw / history blame / 5.65 kB
import argparse
import logging
import os
import sys
import librosa
import torch
import tqdm
from .data.data import EvalDataLoader, EvalDataset
from . import distrib
from .utils import remove_pad
from .utils import bold, deserialize_model, LogProgress
logger = logging.getLogger(__name__)
def load_model(checkpoint_path='checkpoint.th'):
    """Load the svoice separation model into module-level globals.

    Populates the globals ``device``, ``model`` and ``pkg`` so that
    ``separate_demo`` can run without arguments afterwards.

    Args:
        checkpoint_path: path to the serialized checkpoint. Defaults to
            ``'checkpoint.th'`` (the value that was previously hard-coded),
            so existing callers are unaffected.
    """
    global device
    global model
    global pkg
    print("Loading svoice model if available...")
    # Prefer GPU when present; map_location lets a checkpoint saved on a
    # GPU machine load on a CPU-only host.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    pkg = torch.load(checkpoint_path, map_location=device)
    # Checkpoints may either wrap the model in a dict or be the model itself.
    if 'model' in pkg:
        model = pkg['model']
    else:
        model = pkg
    model = deserialize_model(model)
    logger.debug(model)
    model.eval()
    model.to(device)
    print("svoice model loaded.")
    print("Device: {}".format(device))
# Command-line interface (used when this module runs as a script).
parser = argparse.ArgumentParser("Speech separation using MulCat blocks")
parser.add_argument("model_path", type=str, help="Model name")
# nargs="?" makes the default usable: argparse ignores `default` on a
# required positional, so the original "exp/result" could never apply.
parser.add_argument("out_dir", type=str, nargs="?", default="exp/result",
                    help="Directory putting enhanced wav files")
parser.add_argument("--mix_dir", type=str, default=None,
                    help="Directory including mix wav files")
parser.add_argument("--mix_json", type=str, default=None,
                    help="Json file including mix wav files")
parser.add_argument('--device', default="cuda")
parser.add_argument("--sample_rate", default=8000,
                    type=int, help="Sample rate")
parser.add_argument("--batch_size", default=1, type=int, help="Batch size")
parser.add_argument('-v', '--verbose', action='store_const', const=logging.DEBUG,
                    default=logging.INFO, help="More loggging")
def save_wavs(estimate_source, mix_sig, lengths, filenames, out_dir, sr=16000):
    """Write the un-padded mixture and every estimated source to wav files.

    For an input file ``name.wav`` this writes ``out_dir/name.wav`` (the
    mixture) plus ``out_dir/name_s1.wav`` ... ``name_sC.wav``, one file per
    estimated speaker.

    Args:
        estimate_source: batch of separated sources, one (C, T) entry per item.
        mix_sig: batch of mixture signals.
        lengths: valid (un-padded) length of each batch item.
        filenames: original mixture paths; only the basename is reused.
        out_dir: destination directory (must already exist).
        sr: output sample rate passed through to ``write``.
    """
    # Remove padding and flat
    flat_estimate = remove_pad(estimate_source, lengths)
    mix_sig = remove_pad(mix_sig, lengths)
    # Write result
    for i, filename in enumerate(filenames):
        # BUG FIX: str.strip(".wav") removes any leading/trailing '.', 'w',
        # 'a', 'v' characters (e.g. "wave.wav" -> "e"); splitext drops only
        # the extension, which is what was intended.
        base, _ = os.path.splitext(os.path.basename(filename))
        filename = os.path.join(out_dir, base)
        write(mix_sig[i], filename + ".wav", sr=sr)
        C = flat_estimate[i].shape[0]  # number of estimated sources
        # future support for wave playing
        for c in range(C):
            write(flat_estimate[i][c], filename + f"_s{c + 1}.wav", sr=sr)
def write(inputs, filename, sr=8000):
    """Peak-normalize ``inputs`` and write it to ``filename`` as a wav file.

    ``librosa.output.write_wav`` was removed in librosa 0.8, so this writes
    via scipy while reproducing the old ``norm=True`` behavior (scale the
    signal so its absolute peak is 1.0).

    Args:
        inputs: 1-D array-like audio signal on the CPU
            (assumed: ``remove_pad`` yields numpy arrays — TODO confirm).
        filename: destination path.
        sr: sample rate of the written file.
    """
    # Local imports keep the module importable without touching the
    # top-of-file import block.
    import numpy as np
    from scipy.io import wavfile

    data = np.asarray(inputs, dtype=np.float32)
    peak = np.max(np.abs(data)) if data.size else 0.0
    if peak > 0:
        data = data / peak
    wavfile.write(filename, sr, data)
def separate_demo(mix_dir='mix/', batch_size=1, sample_rate=16000):
    """Separate every mixture wav under ``mix_dir`` into 'separated/'.

    Uses the module-level ``model`` and ``device`` globals (set by
    ``load_model``) and returns the absolute paths of the produced files,
    excluding any '*original.wav'.
    """
    mix_json = None
    out_dir = 'separated'
    # Build the evaluation dataset and its distributed loader.
    eval_dataset = EvalDataset(
        mix_dir,
        mix_json,
        batch_size=batch_size,
        sample_rate=sample_rate,
    )
    eval_loader = distrib.loader(
        eval_dataset, batch_size=1, klass=EvalDataLoader)
    # Only rank 0 creates the output directory; every rank waits for it.
    if distrib.rank == 0:
        os.makedirs(out_dir, exist_ok=True)
    distrib.barrier()
    with torch.no_grad():
        for batch in tqdm.tqdm(eval_loader, ncols=120):
            mixture, lengths, filenames = batch
            mixture = mixture.to(device)
            lengths = lengths.to(device)
            # The model returns one estimate per layer; keep the final one.
            estimate_sources = model(mixture)[-1]
            save_wavs(estimate_sources, mixture, lengths,
                      filenames, out_dir, sr=sample_rate)
    # Collect absolute output paths, skipping any '*original.wav'.
    produced = []
    for name in os.listdir(out_dir):
        full = os.path.abspath(os.path.join(out_dir, name))
        if not full.endswith('original.wav'):
            produced.append(full)
    return produced
def get_mix_paths(args):
    """Return ``(mix_dir, mix_json)``, preferring ``args.dset`` when present.

    Falls back to the top-level ``args.mix_dir`` / ``args.mix_json`` when
    ``args`` has no ``dset`` group (or the group lacks the field). A present
    but falsy ``dset`` value yields ``None``, matching the original logic.
    """
    # Narrowed from bare `except:` so unrelated errors (KeyboardInterrupt,
    # SystemExit, real bugs) are no longer silently swallowed.
    # fix mix dir
    try:
        mix_dir = args.dset.mix_dir or None
    except AttributeError:
        mix_dir = args.mix_dir
    # fix mix json
    try:
        mix_json = args.dset.mix_json or None
    except AttributeError:
        mix_json = args.mix_json
    return mix_dir, mix_json
def separate(args, model=None, local_out_dir=None):
    """Run source separation over a directory or json list of mixture wavs.

    Args:
        args: parsed namespace with ``model_path``, ``out_dir``,
            ``mix_dir``/``mix_json`` (optionally under a ``dset`` group),
            ``device``, ``sample_rate`` and ``batch_size``.
        model: an already-loaded model; when ``None`` it is deserialized
            from ``args.model_path``.
        local_out_dir: overrides ``args.out_dir`` when given.
    """
    mix_dir, mix_json = get_mix_paths(args)
    if not mix_json and not mix_dir:
        logger.error("Must provide mix_dir or mix_json! "
                     "When providing mix_dir, mix_json is ignored.")
    # Load model
    if not model:
        # map_location lets a GPU-saved checkpoint load on a CPU-only host;
        # without it torch.load raises when CUDA is unavailable.
        pkg = torch.load(args.model_path, map_location=args.device)
        if 'model' in pkg:
            model = pkg['model']
        else:
            model = pkg
        model = deserialize_model(model)
        logger.debug(model)
    model.eval()
    model.to(args.device)
    if local_out_dir:
        out_dir = local_out_dir
    else:
        out_dir = args.out_dir
    # Load data
    eval_dataset = EvalDataset(
        mix_dir,
        mix_json,
        batch_size=args.batch_size,
        sample_rate=args.sample_rate,
    )
    eval_loader = distrib.loader(
        eval_dataset, batch_size=1, klass=EvalDataLoader)
    # Rank 0 creates the output directory; all ranks sync before writing.
    if distrib.rank == 0:
        os.makedirs(out_dir, exist_ok=True)
    distrib.barrier()
    with torch.no_grad():
        for i, data in enumerate(tqdm.tqdm(eval_loader, ncols=120)):
            # Get batch data
            mixture, lengths, filenames = data
            mixture = mixture.to(args.device)
            lengths = lengths.to(args.device)
            # Forward: the model returns per-layer estimates; use the last.
            estimate_sources = model(mixture)[-1]
            # save wav files
            save_wavs(estimate_sources, mixture, lengths,
                      filenames, out_dir, sr=args.sample_rate)
# CLI entry point: parse arguments, configure logging verbosity
# (-v switches the level to DEBUG), and separate into args.out_dir.
if __name__ == "__main__":
    args = parser.parse_args()
    logging.basicConfig(stream=sys.stderr, level=args.verbose)
    logger.debug(args)
    separate(args, local_out_dir=args.out_dir)