Spaces:
Runtime error
Runtime error
File size: 4,180 Bytes
260b46d 93b48cb 260b46d 9fbfaa6 260b46d 93b48cb 260b46d 93b48cb 260b46d 93b48cb 260b46d 84d4ed6 260b46d 93b48cb 260b46d 84d4ed6 260b46d 9fbfaa6 260b46d 84d4ed6 260b46d ac059f4 c1b9ba0 260b46d c1b9ba0 260b46d 3815be3 93b48cb ac059f4 93b48cb c1b9ba0 260b46d ac059f4 260b46d ac059f4 260b46d 275afd0 260b46d 275afd0 260b46d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 |
from pathlib import Path
import os
from functools import partial
from frechet_audio_distance import FrechetAudioDistance
import pandas
import argbind
import torch
from tqdm import tqdm
import audiotools
from audiotools import AudioSignal
def _numbered_audio_files(directory, audio_ext):
    """Return files in ``directory`` ending with ``audio_ext``, sorted numerically by stem.

    Stems are parsed with ``int``, so a file whose stem is not an integer
    raises ``ValueError``.
    """
    return sorted(directory.glob(f"*{audio_ext}"), key=lambda p: int(p.stem))


def _compare_pair(baseline_file, cond_file, condition, mel_loss, frechet_score):
    """Compute the per-file metric row for one baseline/condition file pair.

    Args:
        baseline_file: Path to the reference audio file.
        cond_file: Path to the condition's audio file; must share the baseline's stem.
        condition: Name of the condition directory. If it contains "inpaint", its
            last "_"-separated token is parsed as seconds of context to trim.
        mel_loss: Callable taking (baseline_sig, cond_sig) and returning a tensor.
        frechet_score: Directory-level FAD score, repeated on every row of the condition.

    Returns:
        dict with keys "mel", "frechet", "condition", "file".
    """
    # make sure the files match (same name)
    assert baseline_file.stem == cond_file.stem, f"baseline file {baseline_file} and cond file {cond_file} do not match"

    # load the files
    baseline_sig = AudioSignal(str(baseline_file))
    cond_sig = AudioSignal(str(cond_file))
    # bring the condition onto the baseline's sample rate and length before comparing
    cond_sig.resample(baseline_sig.sample_rate)
    cond_sig.truncate_samples(baseline_sig.length)

    # if our condition is inpainting, we need to trim the conditioning off
    if "inpaint" in condition:
        ctx_amt = float(condition.split("_")[-1])
        ctx_samples = int(ctx_amt * baseline_sig.sample_rate)
        print(f"found inpainting condition. trimming off {ctx_samples} samples from {cond_file} and {baseline_file}")
        cond_sig.trim(ctx_samples, ctx_samples)
        baseline_sig.trim(ctx_samples, ctx_samples)

    return {
        # "sisdr": -sisdr_loss(baseline_sig, cond_sig).item(),
        # "stft": stft_loss(baseline_sig, cond_sig).item(),
        "mel": mel_loss(baseline_sig, cond_sig).item(),
        "frechet": frechet_score,
        # "visqol": vsq,
        "condition": condition,
        "file": baseline_file.stem,
    }


def _write_stats(metrics, exp_dir):
    """Write ``stats-<metric>.csv`` (mean/count/std per condition) and ``metrics-all.csv``."""
    df = pandas.DataFrame(metrics)
    # every column except the identifying ones is a metric to summarize
    metric_keys = [k for k in df.columns if k not in ("condition", "file")]
    for mk in metric_keys:
        stat = df.groupby(["condition"])[mk].agg(["mean", "count", "std"])
        stat.to_csv(exp_dir / f"stats-{mk}.csv")
    df.to_csv(exp_dir / "metrics-all.csv", index=False)


@argbind.bind(without_prefix=True)
def eval(
    exp_dir: str = None,
    baseline_key: str = "baseline",
    audio_ext: str = ".wav",
):
    """Compare every condition subdirectory of an experiment against a baseline.

    Parameters
    ----------
    exp_dir : str
        Experiment directory containing one subdirectory per condition.
    baseline_key : str, optional
        Name of the subdirectory holding the reference audio, by default "baseline"
    audio_ext : str, optional
        Extension of the audio files to compare, by default ".wav"
    """
    assert exp_dir is not None
    exp_dir = Path(exp_dir)
    assert exp_dir.exists(), f"exp_dir {exp_dir} does not exist"

    # set up our metrics
    # sisdr_loss = audiotools.metrics.distance.SISDRLoss()
    # stft_loss = audiotools.metrics.spectral.MultiScaleSTFTLoss()
    mel_loss = audiotools.metrics.spectral.MelSpectrogramLoss()
    frechet = FrechetAudioDistance(
        use_pca=False,
        use_activation=False,
        verbose=True,
        audio_load_worker=4,
    )
    frechet.model.to("cuda" if torch.cuda.is_available() else "cpu")

    # figure out what conditions we have
    conditions = [d.name for d in exp_dir.iterdir() if d.is_dir()]
    assert baseline_key in conditions, f"baseline_key {baseline_key} not found in {exp_dir}"
    conditions.remove(baseline_key)
    print(f"Found {len(conditions)} conditions in {exp_dir}")
    print(f"conditions: {conditions}")

    baseline_dir = exp_dir / baseline_key
    all_baseline_files = _numbered_audio_files(baseline_dir, audio_ext)

    metrics = []
    for condition in tqdm(conditions):
        cond_dir = exp_dir / condition
        cond_files = _numbered_audio_files(cond_dir, audio_ext)
        print(f"computing fad for {baseline_dir} and {cond_dir}")
        frechet_score = frechet.score(baseline_dir, cond_dir)

        # make sure we have the same number of files. Slice into per-condition
        # lists instead of rebinding the full baseline list, so a condition with
        # fewer files does not shrink the baseline set for later conditions.
        num_files = min(len(all_baseline_files), len(cond_files))
        baseline_files = all_baseline_files[:num_files]
        cond_files = cond_files[:num_files]

        process = partial(
            _compare_pair,
            condition=condition,
            mel_loss=mel_loss,
            frechet_score=frechet_score,
        )
        print(f"processing {len(baseline_files)} files in {baseline_dir} and {cond_dir}")
        metrics.extend(tqdm(map(process, baseline_files, cond_files), total=len(baseline_files)))

    # no conditions or no files: nothing to summarize (previously crashed on metrics[0])
    if not metrics:
        print(f"no metrics computed for {exp_dir}; nothing to write")
        return

    _write_stats(metrics, exp_dir)
if __name__ == "__main__":
    # Parse command-line flags for all argbind-bound functions, then run eval()
    # with the parsed arguments applied via argbind.scope.
    args = argbind.parse_args()
    with argbind.scope(args):
        eval()