|
|
|
|
|
|
|
|
|
|
|
|
|
import gzip |
|
import sys |
|
from concurrent import futures |
|
|
|
import musdb |
|
import museval |
|
import torch as th |
|
import tqdm |
|
from scipy.io import wavfile |
|
from torch import distributed |
|
|
|
from .audio import convert_audio |
|
from .utils import apply_model |
|
|
|
|
|
def evaluate(model,
             musdb_path,
             eval_folder,
             workers=2,
             device="cpu",
             rank=0,
             save=False,
             shifts=0,
             split=False,
             overlap=0.25,
             is_wav=False,
             world_size=1):
    """
    Evaluate `model` on the MusDB test subset using museval.

    The model is run track by track on `device`; the bottleneck is the call
    to `museval.evaluate`, which is therefore dispatched to a process pool.
    One `<track name>.json.gz` file is written per track under
    `eval_folder / "results/test"`. Tracks whose result file already exists
    are skipped, so an interrupted evaluation can be resumed.

    Args:
        model: separation model. Must expose `parameters()`, `samplerate`,
            `audio_channels` and `sources` (the ordered target names).
        musdb_path: dataset root, forwarded to `musdb.DB`.
        eval_folder: output root; must support `/` and `mkdir`
            (e.g. a `pathlib.Path`).
        workers: number of processes used for museval. When falsy, museval
            is called synchronously in the current process.
        device: device the mixture is moved to before calling the model.
        rank: index of the first track handled by this process.
        save: if true, also write each estimate as a wav file under
            `eval_folder / "wav/test" / <track name>`.
        shifts, split, overlap: forwarded unchanged to `apply_model`.
        is_wav: forwarded to `musdb.DB`.
        world_size: stride between the tracks handled by this process;
            when greater than 1, `distributed.barrier()` is called at the end.
    """
    # `results/test` is nested in `results`; both are created explicitly,
    # as in the original code.
    output_dir = eval_folder / "results"
    output_dir.mkdir(exist_ok=True, parents=True)
    json_folder = eval_folder / "results/test"
    json_folder.mkdir(exist_ok=True, parents=True)

    test_set = musdb.DB(musdb_path, subsets=["test"], is_wav=is_wav)
    # Sample rate the dataset audio is assumed to be stored at.
    src_rate = 44100

    # Evaluation only: no gradients are needed, drop any that are stored.
    for p in model.parameters():
        p.requires_grad = False
        p.grad = None

    # museval window and hop, in samples at the model sample rate
    # (one second each). Loop invariant, so computed once.
    win = int(1. * model.samplerate)
    hop = int(1. * model.samplerate)

    pendings = []
    with futures.ProcessPoolExecutor(workers or 1) as pool:
        # Each process handles tracks rank, rank + world_size, ...
        for index in tqdm.tqdm(range(rank, len(test_set), world_size), file=sys.stdout):
            track = test_set.tracks[index]

            out = json_folder / f"{track.name}.json.gz"
            if out.exists():
                # Already evaluated in a previous run.
                continue

            # presumably track.audio is (samples, channels); the transpose
            # puts channels first -- TODO confirm against musdb.
            mix = th.from_numpy(track.audio).t().float()
            # Standardize the mixture with the statistics of its channel
            # average, and undo the scaling on the model output.
            ref = mix.mean(dim=0)
            ref_mean = ref.mean()
            ref_std = ref.std()
            mix = (mix - ref_mean) / ref_std
            mix = convert_audio(mix, src_rate, model.samplerate, model.audio_channels)
            estimates = apply_model(model, mix.to(device),
                                    shifts=shifts, split=split, overlap=overlap)
            estimates = estimates * ref_std + ref_mean

            # Swap the last two axes of estimates and references so both
            # are laid out the same way for museval and wavfile.
            estimates = estimates.transpose(1, 2)
            references = th.stack(
                [th.from_numpy(track.targets[name].audio).t() for name in model.sources])
            references = convert_audio(references, src_rate,
                                       model.samplerate, model.audio_channels)
            references = references.transpose(1, 2).numpy()
            estimates = estimates.cpu().numpy()

            if save:
                folder = eval_folder / "wav/test" / track.name
                folder.mkdir(exist_ok=True, parents=True)
                for name, estimate in zip(model.sources, estimates):
                    # The estimates are at the model sample rate (the mix was
                    # resampled above), so that is the rate to write in the
                    # wav header rather than a hard-coded 44100.
                    wavfile.write(str(folder / (name + ".wav")), model.samplerate, estimate)

            if workers:
                pendings.append((track.name, pool.submit(
                    museval.evaluate, references, estimates, win=win, hop=hop)))
            else:
                pendings.append((track.name, museval.evaluate(
                    references, estimates, win=win, hop=hop)))
            # Release the large per-track arrays before loading the next one.
            del references, mix, estimates, track

        for track_name, pending in tqdm.tqdm(pendings, file=sys.stdout):
            if workers:
                # Block until the worker process is done with this track.
                pending = pending.result()
            sdr, isr, sir, sar = pending
            # NOTE(review): win/hop are hard-coded to 44100 here while the
            # metrics above use model.samplerate; kept as is so that the
            # produced JSON stays identical to existing results -- confirm
            # the unit expected by museval.TrackStore.
            track_store = museval.TrackStore(win=44100, hop=44100, track_name=track_name)
            for idx, target in enumerate(model.sources):
                values = {
                    "SDR": sdr[idx].tolist(),
                    "SIR": sir[idx].tolist(),
                    "ISR": isr[idx].tolist(),
                    "SAR": sar[idx].tolist()
                }

                track_store.add_target(target_name=target, values=values)
            json_path = json_folder / f"{track_name}.json.gz"
            # Context manager so the gzip stream is flushed and closed
            # deterministically instead of relying on garbage collection.
            with gzip.open(json_path, "wb") as json_file:
                json_file.write(track_store.json.encode('utf-8'))
    if world_size > 1:
        # Wait for the other processes to finish their share of the tracks.
        distributed.barrier()
|
|