import argparse
import logging
import os
import sys
import librosa
import torch
import tqdm
from .data.data import EvalDataLoader, EvalDataset
from . import distrib
from .utils import remove_pad
from .utils import bold, deserialize_model, LogProgress
logger = logging.getLogger(__name__)
def load_model(checkpoint_path='checkpoint.th'):
    """Load the svoice separation model into module-level globals.

    Populates the globals ``device``, ``model`` and ``pkg`` so that
    ``separate_demo`` (which reads them) can run without arguments.

    Args:
        checkpoint_path: path to a serialized checkpoint. The file may hold
            either a dict with a ``'model'`` entry or the serialized model
            itself. Defaults to ``'checkpoint.th'`` for backward
            compatibility.
    """
    global device
    global model
    global pkg
    print("Loading svoice model if available...")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # map_location ensures a GPU-saved checkpoint still loads on CPU-only hosts.
    pkg = torch.load(checkpoint_path, map_location=device)
    if 'model' in pkg:
        model = pkg['model']
    else:
        model = pkg
    model = deserialize_model(model)
    logger.debug(model)
    model.eval()
    model.to(device)
    print("svoice model loaded.")
    print("Device: {}".format(device))
# Command-line interface for running separation from a checkpoint.
parser = argparse.ArgumentParser("Speech separation using MulCat blocks")
# Positional arguments: checkpoint to load and where to write results.
parser.add_argument("model_path", type=str, help="Model name")
parser.add_argument("out_dir", type=str, default="exp/result",
                    help="Directory putting enhanced wav files")
# Exactly one of --mix_dir / --mix_json is expected to describe the inputs.
parser.add_argument("--mix_dir", type=str, default=None,
                    help="Directory including mix wav files")
parser.add_argument("--mix_json", type=str, default=None,
                    help="Json file including mix wav files")
parser.add_argument('--device', default="cuda")
parser.add_argument("--sample_rate", default=8000,
                    type=int, help="Sample rate")
parser.add_argument("--batch_size", default=1, type=int, help="Batch size")
# -v flips the log level from INFO to DEBUG.
parser.add_argument('-v', '--verbose', action='store_const', const=logging.DEBUG,
                    default=logging.INFO, help="More loggging")
def save_wavs(estimate_source, mix_sig, lengths, filenames, out_dir, sr=16000):
    """Write the input mixture and every estimated source to ``out_dir``.

    For each input file ``name.wav`` this writes ``name.wav`` (the mixture)
    and ``name_s1.wav`` ... ``name_sC.wav`` (one per estimated source).

    Args:
        estimate_source: batch of separated sources, one (C, T) item per file.
        mix_sig: batch of input mixtures.
        lengths: valid (unpadded) length of each item in the batch.
        filenames: original input paths; only the basename is reused.
        out_dir: destination directory (must already exist).
        sr: output sample rate passed through to ``write``.
    """
    # Remove padding and flat
    flat_estimate = remove_pad(estimate_source, lengths)
    mix_sig = remove_pad(mix_sig, lengths)
    # Write result
    for i, filename in enumerate(filenames):
        # BUG FIX: str.strip(".wav") removes any of the characters
        # '.', 'w', 'a', 'v' from BOTH ends (e.g. "wave.wav" -> "e"),
        # so use splitext to drop only the extension.
        base, _ = os.path.splitext(os.path.basename(filename))
        filename = os.path.join(out_dir, base)
        write(mix_sig[i], filename + ".wav", sr=sr)
        C = flat_estimate[i].shape[0]
        # future support for wave playing
        for c in range(C):
            write(flat_estimate[i][c], filename + f"_s{c + 1}.wav", sr=sr)
def write(inputs, filename, sr=8000):
    """Peak-normalize a mono signal and write it to ``filename`` as wav.

    NOTE(review): the original called ``librosa.output.write_wav(..., norm=True)``,
    which was removed in librosa 0.8.0 and crashes on modern installs. This
    reproduces its contract (scale so the peak magnitude is 1, write float wav)
    with scipy instead.

    Args:
        inputs: 1-D array-like audio signal (numpy array or CPU tensor).
        filename: output path.
        sr: sample rate of the written file.
    """
    # Local imports keep the module importable even if scipy/numpy are
    # only needed for writing.
    import numpy as np
    from scipy.io import wavfile
    data = np.asarray(inputs, dtype=np.float32)
    peak = float(np.max(np.abs(data))) if data.size else 0.0
    if peak > 0:
        data = data / peak  # norm=True equivalent: scale peak to |1.0|
    wavfile.write(filename, sr, data)
def separate_demo(mix_dir='mix/', batch_size=1, sample_rate=16000):
    """Separate every mixture wav under ``mix_dir`` into 'separated/'.

    Relies on the module globals ``model`` and ``device`` populated by
    ``load_model``. Returns the absolute paths of the written files,
    excluding any ending in 'original.wav'.
    """
    mix_json = None
    out_dir = 'separated'
    # Build the evaluation dataset/loader over the mixture directory.
    dataset = EvalDataset(
        mix_dir,
        mix_json,
        batch_size=batch_size,
        sample_rate=sample_rate,
    )
    loader = distrib.loader(dataset, batch_size=1, klass=EvalDataLoader)
    # Only rank 0 creates the output directory; everyone waits for it.
    if distrib.rank == 0:
        os.makedirs(out_dir, exist_ok=True)
    distrib.barrier()
    with torch.no_grad():
        for batch in tqdm.tqdm(loader, ncols=120):
            mixture, lengths, filenames = batch
            mixture = mixture.to(device)
            lengths = lengths.to(device)
            # The model returns outputs per stage; keep the final estimate.
            estimate_sources = model(mixture)[-1]
            save_wavs(estimate_sources, mixture, lengths,
                      filenames, out_dir, sr=sample_rate)
    # Collect what we produced, skipping copies of the original mixture.
    return [
        path
        for path in (os.path.abspath(os.path.join(out_dir, name))
                     for name in os.listdir(out_dir))
        if not path.endswith('original.wav')
    ]
def get_mix_paths(args):
    """Resolve mixture locations from either ``args.dset.*`` or ``args.*``.

    Prefers the nested ``args.dset`` config (used by the training pipeline);
    falls back to the flat CLI attributes only when ``args.dset`` (or the
    nested attribute) is absent. Note the deliberate quirk preserved from the
    original: if ``args.dset.mix_dir`` exists but is falsy, the result stays
    ``None`` rather than falling back.

    Returns:
        (mix_dir, mix_json) — either may be None.
    """
    mix_dir = None
    mix_json = None
    # fix mix dir
    try:
        if args.dset.mix_dir:
            mix_dir = args.dset.mix_dir
    except AttributeError:
        # No args.dset (or no mix_dir on it): use the flat CLI attribute.
        mix_dir = args.mix_dir
    # fix mix json
    try:
        if args.dset.mix_json:
            mix_json = args.dset.mix_json
    except AttributeError:
        mix_json = args.mix_json
    return mix_dir, mix_json
def separate(args, model=None, local_out_dir=None):
    """Run separation over the dataset described by ``args`` and save wavs.

    Args:
        args: namespace with ``model_path``, ``device``, ``batch_size``,
            ``sample_rate``, ``out_dir`` and mixture location attributes
            (see ``get_mix_paths``).
        model: optional pre-loaded model; when None it is loaded from
            ``args.model_path``.
        local_out_dir: optional override for ``args.out_dir``.

    Raises:
        ValueError: if neither mix_dir nor mix_json is provided.
    """
    mix_dir, mix_json = get_mix_paths(args)
    if not mix_json and not mix_dir:
        logger.error("Must provide mix_dir or mix_json! "
                     "When providing mix_dir, mix_json is ignored.")
        # BUG FIX: the original fell through and crashed later inside
        # EvalDataset(None, None, ...); fail fast with a clear error.
        raise ValueError("Must provide mix_dir or mix_json")
    # Load model
    if not model:
        # map_location keeps GPU-saved checkpoints loadable on any device,
        # matching load_model's behavior.
        pkg = torch.load(args.model_path, map_location=args.device)
        if 'model' in pkg:
            model = pkg['model']
        else:
            model = pkg
        model = deserialize_model(model)
        logger.debug(model)
    model.eval()
    model.to(args.device)
    if local_out_dir:
        out_dir = local_out_dir
    else:
        out_dir = args.out_dir
    # Load data
    eval_dataset = EvalDataset(
        mix_dir,
        mix_json,
        batch_size=args.batch_size,
        sample_rate=args.sample_rate,
    )
    eval_loader = distrib.loader(
        eval_dataset, batch_size=1, klass=EvalDataLoader)
    # Only rank 0 creates the output dir; all ranks wait before writing.
    if distrib.rank == 0:
        os.makedirs(out_dir, exist_ok=True)
    distrib.barrier()
    with torch.no_grad():
        for i, data in enumerate(tqdm.tqdm(eval_loader, ncols=120)):
            # Get batch data
            mixture, lengths, filenames = data
            mixture = mixture.to(args.device)
            lengths = lengths.to(args.device)
            # Forward — the model returns per-stage outputs; keep the last.
            estimate_sources = model(mixture)[-1]
            # save wav files
            save_wavs(estimate_sources, mixture, lengths,
                      filenames, out_dir, sr=args.sample_rate)
if __name__ == "__main__":
    # Parse CLI args, configure logging to stderr, then run separation.
    cli_args = parser.parse_args()
    logging.basicConfig(stream=sys.stderr, level=cli_args.verbose)
    logger.debug(cli_args)
    separate(cli_args, local_out_dir=cli_args.out_dir)