Spaces:
Running
Running
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import gzip | |
import sys | |
from concurrent import futures | |
import musdb | |
import museval | |
import torch as th | |
import tqdm | |
from scipy.io import wavfile | |
from torch import distributed | |
from .audio import convert_audio | |
from .utils import apply_model | |
def evaluate(model,
             musdb_path,
             eval_folder,
             workers=2,
             device="cpu",
             rank=0,
             save=False,
             shifts=0,
             split=False,
             overlap=0.25,
             is_wav=False,
             world_size=1):
    """
    Evaluate model using museval. Run the model
    on a single GPU, the bottleneck being the call to museval.

    One gzip-compressed JSON file per track is written to
    `eval_folder / "results/test"`. Tracks whose JSON file already exists
    are skipped, so an interrupted evaluation can be resumed.

    Args:
        model: separation model. Must expose `parameters()`, `samplerate`,
            `audio_channels` and `sources`.
        musdb_path: root of the musdb dataset, forwarded to `musdb.DB`.
        eval_folder: output folder; must support the `/` operator
            (e.g. `pathlib.Path`).
        workers: size of the process pool running `museval.evaluate`.
            When falsy, museval is called inline in the current process.
        device: device the mixture is moved to before calling `apply_model`.
        rank: index of this process; it handles tracks
            `rank, rank + world_size, rank + 2 * world_size, ...`.
        save: if true, also write each estimated source as a wav file under
            `eval_folder / "wav/test" / <track name>`.
        shifts: forwarded to `apply_model`.
        split: forwarded to `apply_model`.
        overlap: forwarded to `apply_model`.
        is_wav: forwarded to `musdb.DB`.
        world_size: total number of processes sharing the evaluation. When
            greater than 1, `distributed.barrier()` is called at the end.
    """
    output_dir = eval_folder / "results"
    output_dir.mkdir(exist_ok=True, parents=True)
    json_folder = eval_folder / "results/test"
    json_folder.mkdir(exist_ok=True, parents=True)

    # we load tracks from the original musdb set
    test_set = musdb.DB(musdb_path, subsets=["test"], is_wav=is_wav)
    src_rate = 44100  # hardcoded for now...

    # Evaluation only: no gradients are needed, drop any that are stored.
    for p in model.parameters():
        p.requires_grad = False
        p.grad = None

    # (track name, museval result or future of it), in submission order.
    pendings = []
    with futures.ProcessPoolExecutor(workers or 1) as pool:
        for index in tqdm.tqdm(range(rank, len(test_set), world_size), file=sys.stdout):
            track = test_set.tracks[index]

            out = json_folder / f"{track.name}.json.gz"
            if out.exists():
                # Already evaluated in a previous run.
                continue

            # presumably track.audio is (time, channels) -- TODO confirm
            mix = th.from_numpy(track.audio).t().float()
            ref = mix.mean(dim=0)  # mono mixture
            # Normalize with the statistics of the mono mixture; the same
            # statistics are used to undo the normalization on the estimates.
            mean = ref.mean()
            std = ref.std()
            mix = (mix - mean) / std
            mix = convert_audio(mix, src_rate, model.samplerate, model.audio_channels)
            estimates = apply_model(model, mix.to(device),
                                    shifts=shifts, split=split, overlap=overlap)
            estimates = estimates * std + mean

            # Swap the last two axes of the estimates and references, the
            # layout handed to museval and wavfile below.
            estimates = estimates.transpose(1, 2)
            references = th.stack(
                [th.from_numpy(track.targets[name].audio).t() for name in model.sources])
            references = convert_audio(references, src_rate,
                                       model.samplerate, model.audio_channels)
            references = references.transpose(1, 2).numpy()
            estimates = estimates.cpu().numpy()

            # One second worth of samples at the model sample rate.
            win = int(1. * model.samplerate)
            hop = int(1. * model.samplerate)

            if save:
                folder = eval_folder / "wav/test" / track.name
                folder.mkdir(exist_ok=True, parents=True)
                for name, estimate in zip(model.sources, estimates):
                    # The estimates were produced at the model sample rate,
                    # so that is the rate that must go into the wav header.
                    wavfile.write(str(folder / (name + ".wav")), model.samplerate, estimate)

            if workers:
                pendings.append((track.name, pool.submit(
                    museval.evaluate, references, estimates, win=win, hop=hop)))
            else:
                pendings.append((track.name, museval.evaluate(
                    references, estimates, win=win, hop=hop)))
            # Release the large per-track arrays before loading the next track.
            del references, mix, estimates, track

        for track_name, pending in tqdm.tqdm(pendings, file=sys.stdout):
            if workers:
                pending = pending.result()
            sdr, isr, sir, sar = pending
            # NOTE(review): museval.evaluate above receives win/hop in samples,
            # while TrackStore may expect them in seconds; 44100 is kept as-is
            # to leave the produced JSON unchanged -- confirm the intended unit.
            track_store = museval.TrackStore(win=44100, hop=44100, track_name=track_name)
            for idx, target in enumerate(model.sources):
                values = {
                    "SDR": sdr[idx].tolist(),
                    "SIR": sir[idx].tolist(),
                    "ISR": isr[idx].tolist(),
                    "SAR": sar[idx].tolist()
                }
                track_store.add_target(target_name=target, values=values)

            json_path = json_folder / f"{track_name}.json.gz"
            # Close the file explicitly so the compressed stream is flushed and
            # finalized right away; an unclosed handle can leave a truncated
            # file that the `out.exists()` check above would treat as done.
            with gzip.open(json_path, "w") as json_file:
                json_file.write(track_store.json.encode('utf-8'))

    if world_size > 1:
        # Wait for every process to finish its share of the tracks.
        distributed.barrier()