unpairedelectron07
commited on
Upload 6 files
Browse files- audiocraft/metrics/chroma_cosinesim.py +72 -0
- audiocraft/metrics/clap_consistency.py +84 -0
- audiocraft/metrics/fad.py +329 -0
- audiocraft/metrics/kld.py +220 -0
- audiocraft/metrics/rvm.py +110 -0
- audiocraft/metrics/visqol.py +216 -0
audiocraft/metrics/chroma_cosinesim.py
ADDED
@@ -0,0 +1,72 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
import torch
|
8 |
+
import torchmetrics
|
9 |
+
|
10 |
+
from ..data.audio_utils import convert_audio
|
11 |
+
from ..modules.chroma import ChromaExtractor
|
12 |
+
|
13 |
+
|
14 |
+
class ChromaCosineSimilarityMetric(torchmetrics.Metric):
|
15 |
+
"""Chroma cosine similarity metric.
|
16 |
+
|
17 |
+
This metric extracts a chromagram for a reference waveform and
|
18 |
+
a generated waveform and compares each frame using the cosine similarity
|
19 |
+
function. The output is the mean cosine similarity.
|
20 |
+
|
21 |
+
Args:
|
22 |
+
sample_rate (int): Sample rate used by the chroma extractor.
|
23 |
+
n_chroma (int): Number of chroma used by the chroma extractor.
|
24 |
+
radix2_exp (int): Exponent for the chroma extractor.
|
25 |
+
argmax (bool): Whether the chroma extractor uses argmax.
|
26 |
+
eps (float): Epsilon for cosine similarity computation.
|
27 |
+
"""
|
28 |
+
def __init__(self, sample_rate: int, n_chroma: int, radix2_exp: int, argmax: bool, eps: float = 1e-8):
|
29 |
+
super().__init__()
|
30 |
+
self.chroma_sample_rate = sample_rate
|
31 |
+
self.n_chroma = n_chroma
|
32 |
+
self.eps = eps
|
33 |
+
self.chroma_extractor = ChromaExtractor(sample_rate=self.chroma_sample_rate, n_chroma=self.n_chroma,
|
34 |
+
radix2_exp=radix2_exp, argmax=argmax)
|
35 |
+
self.add_state("cosine_sum", default=torch.tensor(0.), dist_reduce_fx="sum")
|
36 |
+
self.add_state("weight", default=torch.tensor(0.), dist_reduce_fx="sum")
|
37 |
+
|
38 |
+
def update(self, preds: torch.Tensor, targets: torch.Tensor,
|
39 |
+
sizes: torch.Tensor, sample_rates: torch.Tensor) -> None:
|
40 |
+
"""Compute cosine similarity between chromagrams and accumulate scores over the dataset."""
|
41 |
+
if preds.size(0) == 0:
|
42 |
+
return
|
43 |
+
|
44 |
+
assert preds.shape == targets.shape, (
|
45 |
+
f"Preds and target shapes mismatch: preds={preds.shape}, targets={targets.shape}")
|
46 |
+
assert preds.size(0) == sizes.size(0), (
|
47 |
+
f"Number of items in preds ({preds.shape}) mismatch ",
|
48 |
+
f"with sizes ({sizes.shape})")
|
49 |
+
assert preds.size(0) == sample_rates.size(0), (
|
50 |
+
f"Number of items in preds ({preds.shape}) mismatch ",
|
51 |
+
f"with sample_rates ({sample_rates.shape})")
|
52 |
+
assert torch.all(sample_rates == sample_rates[0].item()), "All sample rates are not the same in the batch"
|
53 |
+
|
54 |
+
device = self.weight.device
|
55 |
+
preds, targets = preds.to(device), targets.to(device) # type: ignore
|
56 |
+
sample_rate = sample_rates[0].item()
|
57 |
+
preds = convert_audio(preds, from_rate=sample_rate, to_rate=self.chroma_sample_rate, to_channels=1)
|
58 |
+
targets = convert_audio(targets, from_rate=sample_rate, to_rate=self.chroma_sample_rate, to_channels=1)
|
59 |
+
gt_chroma = self.chroma_extractor(targets)
|
60 |
+
gen_chroma = self.chroma_extractor(preds)
|
61 |
+
chroma_lens = (sizes / self.chroma_extractor.winhop).ceil().int()
|
62 |
+
for i in range(len(gt_chroma)):
|
63 |
+
t = int(chroma_lens[i].item())
|
64 |
+
cosine_sim = torch.nn.functional.cosine_similarity(
|
65 |
+
gt_chroma[i, :t], gen_chroma[i, :t], dim=1, eps=self.eps)
|
66 |
+
self.cosine_sum += cosine_sim.sum(dim=0) # type: ignore
|
67 |
+
self.weight += torch.tensor(t) # type: ignore
|
68 |
+
|
69 |
+
def compute(self) -> float:
|
70 |
+
"""Computes the average cosine similarty across all generated/target chromagrams pairs."""
|
71 |
+
assert self.weight.item() > 0, "Unable to compute with total number of comparisons <= 0" # type: ignore
|
72 |
+
return (self.cosine_sum / self.weight).item() # type: ignore
|
audiocraft/metrics/clap_consistency.py
ADDED
@@ -0,0 +1,84 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
from pathlib import Path
|
8 |
+
import typing as tp
|
9 |
+
|
10 |
+
import torch
|
11 |
+
import torchmetrics
|
12 |
+
from transformers import RobertaTokenizer # type: ignore
|
13 |
+
|
14 |
+
from ..data.audio_utils import convert_audio
|
15 |
+
from ..environment import AudioCraftEnvironment
|
16 |
+
from ..utils.utils import load_clap_state_dict
|
17 |
+
|
18 |
+
try:
|
19 |
+
import laion_clap # type: ignore
|
20 |
+
except ImportError:
|
21 |
+
laion_clap = None
|
22 |
+
|
23 |
+
|
24 |
+
class TextConsistencyMetric(torchmetrics.Metric):
|
25 |
+
"""Text consistency metric measuring consistency between audio and text pairs."""
|
26 |
+
|
27 |
+
def update(self, audio: torch.Tensor, text: tp.List[str], sizes: torch.Tensor, sample_rates: torch.Tensor) -> None:
|
28 |
+
raise NotImplementedError("implement how to update the metric from the audio and text pairs.")
|
29 |
+
|
30 |
+
def compute(self):
|
31 |
+
raise NotImplementedError("implement how to compute the final metric score.")
|
32 |
+
|
33 |
+
|
34 |
+
class CLAPTextConsistencyMetric(TextConsistencyMetric):
|
35 |
+
"""Text consistency metric relying on Contrastive Language-Audio Pretraining (CLAP).
|
36 |
+
|
37 |
+
This metric is similar to the MuLan Cycle Consistency from MusicLM (https://arxiv.org/pdf/2301.11325.pdf)
|
38 |
+
or the CLAP score used in Make-An-Audio (https://arxiv.org/pdf/2301.12661v1.pdf).
|
39 |
+
|
40 |
+
As a joint audio-text embedding model, a pretrained CLAP model can be used to quantify the
|
41 |
+
similarity between audio-text pairs. We compute the CLAP embeddings from the text descriptions as
|
42 |
+
well as the generated audio based on them, and define the MCC metric as the average cosine similarity
|
43 |
+
between these embeddings.
|
44 |
+
|
45 |
+
Model implementation & pre-trained checkpoints: https://github.com/LAION-AI/CLAP
|
46 |
+
"""
|
47 |
+
def __init__(self, model_path: tp.Union[str, Path], model_arch: str = 'HTSAT-tiny', enable_fusion: bool = False):
|
48 |
+
super().__init__()
|
49 |
+
if laion_clap is None:
|
50 |
+
raise ImportError("Please install CLAP to compute text consistency: 'pip install laion_clap'")
|
51 |
+
self.add_state("cosine_sum", default=torch.tensor(0.), dist_reduce_fx="sum")
|
52 |
+
self.add_state("weight", default=torch.tensor(0.), dist_reduce_fx="sum")
|
53 |
+
self._initialize_model(model_path, model_arch, enable_fusion)
|
54 |
+
|
55 |
+
def _initialize_model(self, model_path: tp.Union[str, Path], model_arch: str, enable_fusion: bool):
|
56 |
+
model_path = AudioCraftEnvironment.resolve_reference_path(model_path)
|
57 |
+
self.tokenize = RobertaTokenizer.from_pretrained('roberta-base')
|
58 |
+
self.model = laion_clap.CLAP_Module(enable_fusion=enable_fusion, amodel=model_arch)
|
59 |
+
self.model_sample_rate = 48_000
|
60 |
+
load_clap_state_dict(self.model, model_path)
|
61 |
+
self.model.eval()
|
62 |
+
|
63 |
+
def _tokenizer(self, texts: tp.Union[str, tp.List[str]]) -> dict:
|
64 |
+
# we use the default params from CLAP module here as well
|
65 |
+
return self.tokenize(texts, padding="max_length", truncation=True, max_length=77, return_tensors="pt")
|
66 |
+
|
67 |
+
def update(self, audio: torch.Tensor, text: tp.List[str], sizes: torch.Tensor, sample_rates: torch.Tensor) -> None:
|
68 |
+
"""Compute cosine similarity between audio and text pairs and accumulate scores over the dataset."""
|
69 |
+
assert audio.size(0) == len(text), "Number of audio and text samples should match"
|
70 |
+
assert torch.all(sample_rates == sample_rates[0].item()), "All items in batch should have the same sample rate"
|
71 |
+
sample_rate = int(sample_rates[0].item())
|
72 |
+
# convert audio batch to 48kHz monophonic audio with no channel dimension: [B, C, T] -> [B, T]
|
73 |
+
audio = convert_audio(audio, from_rate=sample_rate, to_rate=self.model_sample_rate, to_channels=1).mean(dim=1)
|
74 |
+
audio_embeddings = self.model.get_audio_embedding_from_data(audio, use_tensor=True)
|
75 |
+
text_embeddings = self.model.get_text_embedding(text, tokenizer=self._tokenizer, use_tensor=True)
|
76 |
+
# cosine similarity between the text and the audio embedding
|
77 |
+
cosine_sim = torch.nn.functional.cosine_similarity(audio_embeddings, text_embeddings, dim=1, eps=1e-8)
|
78 |
+
self.cosine_sum += cosine_sim.sum(dim=0)
|
79 |
+
self.weight += torch.tensor(cosine_sim.size(0))
|
80 |
+
|
81 |
+
def compute(self):
|
82 |
+
"""Computes the average cosine similarty across all audio/text pairs."""
|
83 |
+
assert self.weight.item() > 0, "Unable to compute with total number of comparisons <= 0" # type: ignore
|
84 |
+
return (self.cosine_sum / self.weight).item() # type: ignore
|
audiocraft/metrics/fad.py
ADDED
@@ -0,0 +1,329 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
import logging
|
8 |
+
from pathlib import Path
|
9 |
+
import os
|
10 |
+
import subprocess
|
11 |
+
import tempfile
|
12 |
+
import typing as tp
|
13 |
+
|
14 |
+
from audiocraft.data.audio import audio_write
|
15 |
+
from audiocraft.data.audio_utils import convert_audio
|
16 |
+
import flashy
|
17 |
+
import torch
|
18 |
+
import torchmetrics
|
19 |
+
|
20 |
+
from ..environment import AudioCraftEnvironment
|
21 |
+
|
22 |
+
|
23 |
+
logger = logging.getLogger(__name__)
|
24 |
+
|
25 |
+
VGGISH_SAMPLE_RATE = 16_000
|
26 |
+
VGGISH_CHANNELS = 1
|
27 |
+
|
28 |
+
|
29 |
+
class FrechetAudioDistanceMetric(torchmetrics.Metric):
|
30 |
+
"""Fréchet Audio Distance computation based on official TensorFlow implementation from Google Research.
|
31 |
+
|
32 |
+
From: D.C. Dowson & B.V. Landau The Fréchet distance between
|
33 |
+
multivariate normal distributions
|
34 |
+
https://doi.org/10.1016/0047-259X(82)90077-X
|
35 |
+
The Fréchet distance between two multivariate gaussians,
|
36 |
+
`X ~ N(mu_x, sigma_x)` and `Y ~ N(mu_y, sigma_y)`, is `d^2`.
|
37 |
+
d^2 = (mu_x - mu_y)^2 + Tr(sigma_x + sigma_y - 2 * sqrt(sigma_x*sigma_y))
|
38 |
+
= (mu_x - mu_y)^2 + Tr(sigma_x) + Tr(sigma_y)
|
39 |
+
- 2 * Tr(sqrt(sigma_x*sigma_y)))
|
40 |
+
|
41 |
+
To use this FAD computation metric, you need to have the proper Frechet Audio Distance tool setup
|
42 |
+
from: https://github.com/google-research/google-research/tree/master/frechet_audio_distance
|
43 |
+
We provide the below instructions as reference but we do not guarantee for further support
|
44 |
+
in frechet_audio_distance installation. This was tested with python 3.10, cuda 11.8, tensorflow 2.12.0.
|
45 |
+
|
46 |
+
We recommend installing the frechet_audio_distance library in a dedicated env (e.g. conda).
|
47 |
+
|
48 |
+
1. Get the code and models following the repository instructions. We used the steps below:
|
49 |
+
git clone git@github.com:google-research/google-research.git
|
50 |
+
git clone git@github.com:tensorflow/models.git
|
51 |
+
mkdir google-research/tensorflow_models
|
52 |
+
touch google-research/tensorflow_models/__init__.py
|
53 |
+
cp -r models/research/audioset google-research/tensorflow_models/
|
54 |
+
touch google-research/tensorflow_models/audioset/__init__.py
|
55 |
+
echo "from .vggish import mel_features, vggish_params, vggish_slim" > \
|
56 |
+
google-research/tensorflow_models/audioset/__init__.py
|
57 |
+
# we can now remove the tensorflow models repository
|
58 |
+
# rm -r models
|
59 |
+
cd google-research
|
60 |
+
Follow the instructions to download the vggish checkpoint. AudioCraft base configuration
|
61 |
+
assumes it is placed in the AudioCraft reference dir.
|
62 |
+
|
63 |
+
Note that we operate the following changes for the code to work with TensorFlow 2.X and python 3:
|
64 |
+
- Update xrange for range in:
|
65 |
+
https://github.com/google-research/google-research/blob/master/frechet_audio_distance/audioset_model.py
|
66 |
+
- Update `tf_record = tf.python_io.tf_record_iterator(filename).next()` to
|
67 |
+
`tf_record = tf.python_io.tf_record_iterator(filename).__next__()` in
|
68 |
+
https://github.com/google-research/google-research/blob/master/frechet_audio_distance/fad_utils.py
|
69 |
+
- Update `import vggish_params as params` to `from . import vggish_params as params` in:
|
70 |
+
https://github.com/tensorflow/models/blob/master/research/audioset/vggish/vggish_slim.py
|
71 |
+
- Add flag to provide a given batch size for running the AudioSet model in:
|
72 |
+
https://github.com/google-research/google-research/blob/master/frechet_audio_distance/create_embeddings_main.py
|
73 |
+
```
|
74 |
+
flags.DEFINE_integer('batch_size', 64,
|
75 |
+
'Number of samples in the batch for AudioSet model.')
|
76 |
+
```
|
77 |
+
Ensure you pass the flag to the create_embeddings_beam.create_pipeline function, adding:
|
78 |
+
`batch_size=FLAGS.batch_size` to the provided parameters.
|
79 |
+
|
80 |
+
2. Follow instructions for the library installation and a valid TensorFlow installation
|
81 |
+
```
|
82 |
+
# e.g. instructions from: https://www.tensorflow.org/install/pip
|
83 |
+
conda install -c conda-forge cudatoolkit=11.8.0
|
84 |
+
python3 -m pip install nvidia-cudnn-cu11==8.6.0.163 tensorflow==2.12.*
|
85 |
+
mkdir -p $CONDA_PREFIX/etc/conda/activate.d
|
86 |
+
echo 'CUDNN_PATH=$(dirname $(python -c "import nvidia.cudnn;print(nvidia.cudnn.__file__)"))' \
|
87 |
+
>> $CONDA_PREFIX/etc/conda/activate.d/env_vars.sh
|
88 |
+
echo 'export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$CONDA_PREFIX/lib/:$CUDNN_PATH/lib' \
|
89 |
+
>> $CONDA_PREFIX/etc/conda/activate.d/env_vars.sh
|
90 |
+
source $CONDA_PREFIX/etc/conda/activate.d/env_vars.sh
|
91 |
+
# Verify install: on a machine with GPU device
|
92 |
+
python3 -c "import tensorflow as tf; print(tf.config.list_physical_devices('GPU'))"
|
93 |
+
```
|
94 |
+
|
95 |
+
Now install frechet_audio_distance required dependencies:
|
96 |
+
```
|
97 |
+
# We assume we already have TensorFlow installed from the above steps
|
98 |
+
pip install apache-beam numpy scipy tf_slim
|
99 |
+
```
|
100 |
+
|
101 |
+
Finally, follow remaining library instructions to ensure you have a working frechet_audio_distance setup
|
102 |
+
(you may want to specify --model_ckpt flag pointing to the model's path).
|
103 |
+
|
104 |
+
3. AudioCraft's FrechetAudioDistanceMetric requires 2 environment variables pointing to the python executable
|
105 |
+
and Tensorflow library path from the above installation steps:
|
106 |
+
export TF_PYTHON_EXE="<PATH_TO_THE_ENV_PYTHON_BINARY>"
|
107 |
+
export TF_LIBRARY_PATH="<PATH_TO_THE_ENV_CUDNN_LIBRARY>"
|
108 |
+
|
109 |
+
e.g. assuming we have installed everything in a dedicated conda env
|
110 |
+
with python 3.10 that is currently active:
|
111 |
+
export TF_PYTHON_EXE="$CONDA_PREFIX/bin/python"
|
112 |
+
export TF_LIBRARY_PATH="$CONDA_PREFIX/lib/python3.10/site-packages/nvidia/cudnn/lib"
|
113 |
+
|
114 |
+
Finally you may want to export the following variable:
|
115 |
+
export TF_FORCE_GPU_ALLOW_GROWTH=true
|
116 |
+
See: https://www.tensorflow.org/guide/gpu#limiting_gpu_memory_growth
|
117 |
+
|
118 |
+
You can save those environment variables in your training conda env, when currently active:
|
119 |
+
`$CONDA_PREFIX/etc/conda/activate.d/env_vars.sh`
|
120 |
+
e.g. assuming the env with TensorFlow and frechet_audio_distance install is named ac_eval,
|
121 |
+
and the training conda env is named audiocraft:
|
122 |
+
```
|
123 |
+
# activate training env
|
124 |
+
conda activate audiocraft
|
125 |
+
# get path to all envs
|
126 |
+
CONDA_ENV_DIR=$(dirname $CONDA_PREFIX)
|
127 |
+
# export pointers to evaluation env for using TensorFlow in FrechetAudioDistanceMetric
|
128 |
+
touch $CONDA_PREFIX/etc/conda/activate.d/env_vars.sh
|
129 |
+
echo 'export TF_PYTHON_EXE="$CONDA_ENV_DIR/ac_eval/bin/python"' >> \
|
130 |
+
$CONDA_PREFIX/etc/conda/activate.d/env_vars.sh
|
131 |
+
echo 'export TF_LIBRARY_PATH="$CONDA_ENV_DIR/ac_eval/lib/python3.10/site-packages/nvidia/cudnn/lib"' >> \
|
132 |
+
$CONDA_PREFIX/etc/conda/activate.d/env_vars.sh
|
133 |
+
# optionally:
|
134 |
+
echo 'export TF_FORCE_GPU_ALLOW_GROWTH=true' >> $CONDA_PREFIX/etc/conda/activate.d/env_vars.sh
|
135 |
+
# you may need to reactivate the audiocraft env for this to take effect
|
136 |
+
```
|
137 |
+
|
138 |
+
Args:
|
139 |
+
bin (Path or str): Path to installed frechet audio distance code.
|
140 |
+
model_path (Path or str): Path to Tensorflow checkpoint for the model
|
141 |
+
used to compute statistics over the embedding beams.
|
142 |
+
format (str): Audio format used to save files.
|
143 |
+
log_folder (Path or str, optional): Path where to write process logs.
|
144 |
+
"""
|
145 |
+
def __init__(self, bin: tp.Union[Path, str], model_path: tp.Union[Path, str],
|
146 |
+
format: str = "wav", batch_size: tp.Optional[int] = None,
|
147 |
+
log_folder: tp.Optional[tp.Union[Path, str]] = None):
|
148 |
+
super().__init__()
|
149 |
+
self.model_sample_rate = VGGISH_SAMPLE_RATE
|
150 |
+
self.model_channels = VGGISH_CHANNELS
|
151 |
+
self.model_path = AudioCraftEnvironment.resolve_reference_path(model_path)
|
152 |
+
assert Path(self.model_path).exists(), f"Could not find provided model checkpoint path at: {self.model_path}"
|
153 |
+
self.format = format
|
154 |
+
self.batch_size = batch_size
|
155 |
+
self.bin = bin
|
156 |
+
self.tf_env = {"PYTHONPATH": str(self.bin)}
|
157 |
+
self.python_path = os.environ.get('TF_PYTHON_EXE') or 'python'
|
158 |
+
logger.info("Python exe for TF is %s", self.python_path)
|
159 |
+
if 'TF_LIBRARY_PATH' in os.environ:
|
160 |
+
self.tf_env['LD_LIBRARY_PATH'] = os.environ['TF_LIBRARY_PATH']
|
161 |
+
if 'TF_FORCE_GPU_ALLOW_GROWTH' in os.environ:
|
162 |
+
self.tf_env['TF_FORCE_GPU_ALLOW_GROWTH'] = os.environ['TF_FORCE_GPU_ALLOW_GROWTH']
|
163 |
+
logger.info("Env for TF is %r", self.tf_env)
|
164 |
+
self.reset(log_folder)
|
165 |
+
self.add_state("total_files", default=torch.tensor(0.), dist_reduce_fx="sum")
|
166 |
+
|
167 |
+
def reset(self, log_folder: tp.Optional[tp.Union[Path, str]] = None):
|
168 |
+
"""Reset torchmetrics.Metrics state."""
|
169 |
+
log_folder = Path(log_folder or tempfile.mkdtemp())
|
170 |
+
self.tmp_dir = log_folder / 'fad'
|
171 |
+
self.tmp_dir.mkdir(exist_ok=True)
|
172 |
+
self.samples_tests_dir = self.tmp_dir / 'tests'
|
173 |
+
self.samples_tests_dir.mkdir(exist_ok=True)
|
174 |
+
self.samples_background_dir = self.tmp_dir / 'background'
|
175 |
+
self.samples_background_dir.mkdir(exist_ok=True)
|
176 |
+
self.manifest_tests = self.tmp_dir / 'files_tests.cvs'
|
177 |
+
self.manifest_background = self.tmp_dir / 'files_background.cvs'
|
178 |
+
self.stats_tests_dir = self.tmp_dir / 'stats_tests'
|
179 |
+
self.stats_background_dir = self.tmp_dir / 'stats_background'
|
180 |
+
self.counter = 0
|
181 |
+
|
182 |
+
def update(self, preds: torch.Tensor, targets: torch.Tensor,
|
183 |
+
sizes: torch.Tensor, sample_rates: torch.Tensor,
|
184 |
+
stems: tp.Optional[tp.List[str]] = None):
|
185 |
+
"""Update torchmetrics.Metrics by saving the audio and updating the manifest file."""
|
186 |
+
assert preds.shape == targets.shape, f"preds={preds.shape} != targets={targets.shape}"
|
187 |
+
num_samples = preds.shape[0]
|
188 |
+
assert num_samples == sizes.size(0) and num_samples == sample_rates.size(0)
|
189 |
+
assert stems is None or num_samples == len(set(stems))
|
190 |
+
for i in range(num_samples):
|
191 |
+
self.total_files += 1 # type: ignore
|
192 |
+
self.counter += 1
|
193 |
+
wav_len = int(sizes[i].item())
|
194 |
+
sample_rate = int(sample_rates[i].item())
|
195 |
+
pred_wav = preds[i]
|
196 |
+
target_wav = targets[i]
|
197 |
+
pred_wav = pred_wav[..., :wav_len]
|
198 |
+
target_wav = target_wav[..., :wav_len]
|
199 |
+
stem_name = stems[i] if stems is not None else f'sample_{self.counter}_{flashy.distrib.rank()}'
|
200 |
+
# dump audio files
|
201 |
+
try:
|
202 |
+
pred_wav = convert_audio(
|
203 |
+
pred_wav.unsqueeze(0), from_rate=sample_rate,
|
204 |
+
to_rate=self.model_sample_rate, to_channels=1).squeeze(0)
|
205 |
+
audio_write(
|
206 |
+
self.samples_tests_dir / stem_name, pred_wav, sample_rate=self.model_sample_rate,
|
207 |
+
format=self.format, strategy="peak")
|
208 |
+
except Exception as e:
|
209 |
+
logger.error(f"Exception occured when saving tests files for FAD computation: {repr(e)} - {e}")
|
210 |
+
try:
|
211 |
+
# for the ground truth audio, we enforce the 'peak' strategy to avoid modifying
|
212 |
+
# the original audio when writing it
|
213 |
+
target_wav = convert_audio(
|
214 |
+
target_wav.unsqueeze(0), from_rate=sample_rate,
|
215 |
+
to_rate=self.model_sample_rate, to_channels=1).squeeze(0)
|
216 |
+
audio_write(
|
217 |
+
self.samples_background_dir / stem_name, target_wav, sample_rate=self.model_sample_rate,
|
218 |
+
format=self.format, strategy="peak")
|
219 |
+
except Exception as e:
|
220 |
+
logger.error(f"Exception occured when saving background files for FAD computation: {repr(e)} - {e}")
|
221 |
+
|
222 |
+
def _get_samples_name(self, is_background: bool):
|
223 |
+
return 'background' if is_background else 'tests'
|
224 |
+
|
225 |
+
def _create_embedding_beams(self, is_background: bool, gpu_index: tp.Optional[int] = None):
|
226 |
+
if is_background:
|
227 |
+
input_samples_dir = self.samples_background_dir
|
228 |
+
input_filename = self.manifest_background
|
229 |
+
stats_name = self.stats_background_dir
|
230 |
+
else:
|
231 |
+
input_samples_dir = self.samples_tests_dir
|
232 |
+
input_filename = self.manifest_tests
|
233 |
+
stats_name = self.stats_tests_dir
|
234 |
+
beams_name = self._get_samples_name(is_background)
|
235 |
+
log_file = self.tmp_dir / f'fad_logs_create_beams_{beams_name}.log'
|
236 |
+
|
237 |
+
logger.info(f"Scanning samples folder to fetch list of files: {input_samples_dir}")
|
238 |
+
with open(input_filename, "w") as fout:
|
239 |
+
for path in Path(input_samples_dir).glob(f"*.{self.format}"):
|
240 |
+
fout.write(f"{str(path)}\n")
|
241 |
+
|
242 |
+
cmd = [
|
243 |
+
self.python_path, "-m",
|
244 |
+
"frechet_audio_distance.create_embeddings_main",
|
245 |
+
"--model_ckpt", f"{self.model_path}",
|
246 |
+
"--input_files", f"{str(input_filename)}",
|
247 |
+
"--stats", f"{str(stats_name)}",
|
248 |
+
]
|
249 |
+
if self.batch_size is not None:
|
250 |
+
cmd += ["--batch_size", str(self.batch_size)]
|
251 |
+
logger.info(f"Launching frechet_audio_distance embeddings main method: {' '.join(cmd)} on {beams_name}")
|
252 |
+
env = os.environ
|
253 |
+
if gpu_index is not None:
|
254 |
+
env["CUDA_VISIBLE_DEVICES"] = str(gpu_index)
|
255 |
+
process = subprocess.Popen(
|
256 |
+
cmd, stdout=open(log_file, "w"), env={**env, **self.tf_env}, stderr=subprocess.STDOUT)
|
257 |
+
return process, log_file
|
258 |
+
|
259 |
+
def _compute_fad_score(self, gpu_index: tp.Optional[int] = None):
|
260 |
+
cmd = [
|
261 |
+
self.python_path, "-m", "frechet_audio_distance.compute_fad",
|
262 |
+
"--test_stats", f"{str(self.stats_tests_dir)}",
|
263 |
+
"--background_stats", f"{str(self.stats_background_dir)}",
|
264 |
+
]
|
265 |
+
logger.info(f"Launching frechet_audio_distance compute fad method: {' '.join(cmd)}")
|
266 |
+
env = os.environ
|
267 |
+
if gpu_index is not None:
|
268 |
+
env["CUDA_VISIBLE_DEVICES"] = str(gpu_index)
|
269 |
+
result = subprocess.run(cmd, env={**env, **self.tf_env}, capture_output=True)
|
270 |
+
if result.returncode:
|
271 |
+
logger.error(
|
272 |
+
"Error with FAD computation from stats: \n %s \n %s",
|
273 |
+
result.stdout.decode(), result.stderr.decode()
|
274 |
+
)
|
275 |
+
raise RuntimeError("Error while executing FAD computation from stats")
|
276 |
+
try:
|
277 |
+
# result is "FAD: (d+).(d+)" hence we remove the prefix with (d+) being one digit or more
|
278 |
+
fad_score = float(result.stdout[4:])
|
279 |
+
return fad_score
|
280 |
+
except Exception as e:
|
281 |
+
raise RuntimeError(f"Error parsing FAD score from command stdout: {e}")
|
282 |
+
|
283 |
+
def _log_process_result(self, returncode: int, log_file: tp.Union[Path, str], is_background: bool) -> None:
|
284 |
+
beams_name = self._get_samples_name(is_background)
|
285 |
+
if returncode:
|
286 |
+
with open(log_file, "r") as f:
|
287 |
+
error_log = f.read()
|
288 |
+
logger.error(error_log)
|
289 |
+
os._exit(1)
|
290 |
+
else:
|
291 |
+
logger.info(f"Successfully computed embedding beams on {beams_name} samples.")
|
292 |
+
|
293 |
+
def _parallel_create_embedding_beams(self, num_of_gpus: int):
|
294 |
+
assert num_of_gpus > 0
|
295 |
+
logger.info("Creating embeddings beams in a parallel manner on different GPUs")
|
296 |
+
tests_beams_process, tests_beams_log_file = self._create_embedding_beams(is_background=False, gpu_index=0)
|
297 |
+
bg_beams_process, bg_beams_log_file = self._create_embedding_beams(is_background=True, gpu_index=1)
|
298 |
+
tests_beams_code = tests_beams_process.wait()
|
299 |
+
bg_beams_code = bg_beams_process.wait()
|
300 |
+
self._log_process_result(tests_beams_code, tests_beams_log_file, is_background=False)
|
301 |
+
self._log_process_result(bg_beams_code, bg_beams_log_file, is_background=True)
|
302 |
+
|
303 |
+
def _sequential_create_embedding_beams(self):
|
304 |
+
logger.info("Creating embeddings beams in a sequential manner")
|
305 |
+
tests_beams_process, tests_beams_log_file = self._create_embedding_beams(is_background=False)
|
306 |
+
tests_beams_code = tests_beams_process.wait()
|
307 |
+
self._log_process_result(tests_beams_code, tests_beams_log_file, is_background=False)
|
308 |
+
bg_beams_process, bg_beams_log_file = self._create_embedding_beams(is_background=True)
|
309 |
+
bg_beams_code = bg_beams_process.wait()
|
310 |
+
self._log_process_result(bg_beams_code, bg_beams_log_file, is_background=True)
|
311 |
+
|
312 |
+
@flashy.distrib.rank_zero_only
|
313 |
+
def _local_compute_frechet_audio_distance(self):
|
314 |
+
"""Compute Frechet Audio Distance score calling TensorFlow API."""
|
315 |
+
num_of_gpus = torch.cuda.device_count() if torch.cuda.is_available() else 0
|
316 |
+
if num_of_gpus > 1:
|
317 |
+
self._parallel_create_embedding_beams(num_of_gpus)
|
318 |
+
else:
|
319 |
+
self._sequential_create_embedding_beams()
|
320 |
+
fad_score = self._compute_fad_score(gpu_index=0)
|
321 |
+
return fad_score
|
322 |
+
|
323 |
+
def compute(self) -> float:
|
324 |
+
"""Compute metrics."""
|
325 |
+
assert self.total_files.item() > 0, "No files dumped for FAD computation!" # type: ignore
|
326 |
+
fad_score = self._local_compute_frechet_audio_distance()
|
327 |
+
logger.warning(f"FAD score = {fad_score}")
|
328 |
+
fad_score = flashy.distrib.broadcast_object(fad_score, src=0)
|
329 |
+
return fad_score
|
audiocraft/metrics/kld.py
ADDED
@@ -0,0 +1,220 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
import contextlib
|
8 |
+
from functools import partial
|
9 |
+
import logging
|
10 |
+
import os
|
11 |
+
import typing as tp
|
12 |
+
|
13 |
+
import torch
|
14 |
+
import torchmetrics
|
15 |
+
|
16 |
+
from ..data.audio_utils import convert_audio
|
17 |
+
|
18 |
+
|
19 |
+
logger = logging.getLogger(__name__)
|
20 |
+
|
21 |
+
|
22 |
+
class _patch_passt_stft:
|
23 |
+
"""Decorator to patch torch.stft in PaSST."""
|
24 |
+
def __init__(self):
|
25 |
+
self.old_stft = torch.stft
|
26 |
+
|
27 |
+
def __enter__(self):
|
28 |
+
# return_complex is a mandatory parameter in latest torch versions
|
29 |
+
# torch is throwing RuntimeErrors when not set
|
30 |
+
torch.stft = partial(torch.stft, return_complex=False)
|
31 |
+
|
32 |
+
def __exit__(self, *exc):
|
33 |
+
torch.stft = self.old_stft
|
34 |
+
|
35 |
+
|
36 |
+
def kl_divergence(pred_probs: torch.Tensor, target_probs: torch.Tensor, epsilon: float = 1e-6) -> torch.Tensor:
|
37 |
+
"""Computes the elementwise KL-Divergence loss between probability distributions
|
38 |
+
from generated samples and target samples.
|
39 |
+
|
40 |
+
Args:
|
41 |
+
pred_probs (torch.Tensor): Probabilities for each label obtained
|
42 |
+
from a classifier on generated audio. Expected shape is [B, num_classes].
|
43 |
+
target_probs (torch.Tensor): Probabilities for each label obtained
|
44 |
+
from a classifier on target audio. Expected shape is [B, num_classes].
|
45 |
+
epsilon (float): Epsilon value.
|
46 |
+
Returns:
|
47 |
+
kld (torch.Tensor): KLD loss between each generated sample and target pair.
|
48 |
+
"""
|
49 |
+
kl_div = torch.nn.functional.kl_div((pred_probs + epsilon).log(), target_probs, reduction="none")
|
50 |
+
return kl_div.sum(-1)
|
51 |
+
|
52 |
+
|
53 |
+
class KLDivergenceMetric(torchmetrics.Metric):
|
54 |
+
"""Base implementation for KL Divergence metric.
|
55 |
+
|
56 |
+
The KL divergence is measured between probability distributions
|
57 |
+
of class predictions returned by a pre-trained audio classification model.
|
58 |
+
When the KL-divergence is low, the generated audio is expected to
|
59 |
+
have similar acoustic characteristics as the reference audio,
|
60 |
+
according to the classifier.
|
61 |
+
"""
|
62 |
+
def __init__(self):
|
63 |
+
super().__init__()
|
64 |
+
self.add_state("kld_pq_sum", default=torch.tensor(0.), dist_reduce_fx="sum")
|
65 |
+
self.add_state("kld_qp_sum", default=torch.tensor(0.), dist_reduce_fx="sum")
|
66 |
+
self.add_state("kld_all_sum", default=torch.tensor(0.), dist_reduce_fx="sum")
|
67 |
+
self.add_state("weight", default=torch.tensor(0), dist_reduce_fx="sum")
|
68 |
+
|
69 |
+
def _get_label_distribution(self, x: torch.Tensor, sizes: torch.Tensor,
|
70 |
+
sample_rates: torch.Tensor) -> tp.Optional[torch.Tensor]:
|
71 |
+
"""Get model output given provided input tensor.
|
72 |
+
|
73 |
+
Args:
|
74 |
+
x (torch.Tensor): Input audio tensor of shape [B, C, T].
|
75 |
+
sizes (torch.Tensor): Actual audio sample length, of shape [B].
|
76 |
+
sample_rates (torch.Tensor): Actual audio sample rate, of shape [B].
|
77 |
+
Returns:
|
78 |
+
probs (torch.Tensor): Probabilities over labels, of shape [B, num_classes].
|
79 |
+
"""
|
80 |
+
raise NotImplementedError("implement method to extract label distributions from the model.")
|
81 |
+
|
82 |
+
def update(self, preds: torch.Tensor, targets: torch.Tensor,
|
83 |
+
sizes: torch.Tensor, sample_rates: torch.Tensor) -> None:
|
84 |
+
"""Calculates running KL-Divergence loss between batches of audio
|
85 |
+
preds (generated) and target (ground-truth)
|
86 |
+
Args:
|
87 |
+
preds (torch.Tensor): Audio samples to evaluate, of shape [B, C, T].
|
88 |
+
targets (torch.Tensor): Target samples to compare against, of shape [B, C, T].
|
89 |
+
sizes (torch.Tensor): Actual audio sample length, of shape [B].
|
90 |
+
sample_rates (torch.Tensor): Actual audio sample rate, of shape [B].
|
91 |
+
"""
|
92 |
+
assert preds.shape == targets.shape
|
93 |
+
assert preds.size(0) > 0, "Cannot update the loss with empty tensors"
|
94 |
+
preds_probs = self._get_label_distribution(preds, sizes, sample_rates)
|
95 |
+
targets_probs = self._get_label_distribution(targets, sizes, sample_rates)
|
96 |
+
if preds_probs is not None and targets_probs is not None:
|
97 |
+
assert preds_probs.shape == targets_probs.shape
|
98 |
+
kld_scores = kl_divergence(preds_probs, targets_probs)
|
99 |
+
assert not torch.isnan(kld_scores).any(), "kld_scores contains NaN value(s)!"
|
100 |
+
self.kld_pq_sum += torch.sum(kld_scores)
|
101 |
+
kld_qp_scores = kl_divergence(targets_probs, preds_probs)
|
102 |
+
self.kld_qp_sum += torch.sum(kld_qp_scores)
|
103 |
+
self.weight += torch.tensor(kld_scores.size(0))
|
104 |
+
|
105 |
+
def compute(self) -> dict:
|
106 |
+
"""Computes KL-Divergence across all evaluated pred/target pairs."""
|
107 |
+
weight: float = float(self.weight.item()) # type: ignore
|
108 |
+
assert weight > 0, "Unable to compute with total number of comparisons <= 0"
|
109 |
+
logger.info(f"Computing KL divergence on a total of {weight} samples")
|
110 |
+
kld_pq = self.kld_pq_sum.item() / weight # type: ignore
|
111 |
+
kld_qp = self.kld_qp_sum.item() / weight # type: ignore
|
112 |
+
kld_both = kld_pq + kld_qp
|
113 |
+
return {'kld': kld_pq, 'kld_pq': kld_pq, 'kld_qp': kld_qp, 'kld_both': kld_both}
|
114 |
+
|
115 |
+
|
116 |
+
class PasstKLDivergenceMetric(KLDivergenceMetric):
|
117 |
+
"""KL-Divergence metric based on pre-trained PASST classifier on AudioSet.
|
118 |
+
|
119 |
+
From: PaSST: Efficient Training of Audio Transformers with Patchout
|
120 |
+
Paper: https://arxiv.org/abs/2110.05069
|
121 |
+
Implementation: https://github.com/kkoutini/PaSST
|
122 |
+
|
123 |
+
Follow instructions from the github repo:
|
124 |
+
```
|
125 |
+
pip install 'git+https://github.com/kkoutini/passt_hear21@0.0.19#egg=hear21passt'
|
126 |
+
```
|
127 |
+
|
128 |
+
Args:
|
129 |
+
pretrained_length (float, optional): Audio duration used for the pretrained model.
|
130 |
+
"""
|
131 |
+
def __init__(self, pretrained_length: tp.Optional[float] = None):
|
132 |
+
super().__init__()
|
133 |
+
self._initialize_model(pretrained_length)
|
134 |
+
|
135 |
+
def _initialize_model(self, pretrained_length: tp.Optional[float] = None):
|
136 |
+
"""Initialize underlying PaSST audio classifier."""
|
137 |
+
model, sr, max_frames, min_frames = self._load_base_model(pretrained_length)
|
138 |
+
self.min_input_frames = min_frames
|
139 |
+
self.max_input_frames = max_frames
|
140 |
+
self.model_sample_rate = sr
|
141 |
+
self.model = model
|
142 |
+
self.model.eval()
|
143 |
+
self.model.to(self.device)
|
144 |
+
|
145 |
+
def _load_base_model(self, pretrained_length: tp.Optional[float]):
|
146 |
+
"""Load pretrained model from PaSST."""
|
147 |
+
try:
|
148 |
+
if pretrained_length == 30:
|
149 |
+
from hear21passt.base30sec import get_basic_model # type: ignore
|
150 |
+
max_duration = 30
|
151 |
+
elif pretrained_length == 20:
|
152 |
+
from hear21passt.base20sec import get_basic_model # type: ignore
|
153 |
+
max_duration = 20
|
154 |
+
else:
|
155 |
+
from hear21passt.base import get_basic_model # type: ignore
|
156 |
+
# Original PASST was trained on AudioSet with 10s-long audio samples
|
157 |
+
max_duration = 10
|
158 |
+
min_duration = 0.15
|
159 |
+
min_duration = 0.15
|
160 |
+
except ModuleNotFoundError:
|
161 |
+
raise ModuleNotFoundError(
|
162 |
+
"Please install hear21passt to compute KL divergence: ",
|
163 |
+
"pip install 'git+https://github.com/kkoutini/passt_hear21@0.0.19#egg=hear21passt'"
|
164 |
+
)
|
165 |
+
model_sample_rate = 32_000
|
166 |
+
max_input_frames = int(max_duration * model_sample_rate)
|
167 |
+
min_input_frames = int(min_duration * model_sample_rate)
|
168 |
+
with open(os.devnull, 'w') as f, contextlib.redirect_stdout(f):
|
169 |
+
model = get_basic_model(mode='logits')
|
170 |
+
return model, model_sample_rate, max_input_frames, min_input_frames
|
171 |
+
|
172 |
+
def _process_audio(self, wav: torch.Tensor, sample_rate: int, wav_len: int) -> tp.List[torch.Tensor]:
|
173 |
+
"""Process audio to feed to the pretrained model."""
|
174 |
+
wav = wav.unsqueeze(0)
|
175 |
+
wav = wav[..., :wav_len]
|
176 |
+
wav = convert_audio(wav, from_rate=sample_rate, to_rate=self.model_sample_rate, to_channels=1)
|
177 |
+
wav = wav.squeeze(0)
|
178 |
+
# we don't pad but return a list of audio segments as this otherwise affects the KLD computation
|
179 |
+
segments = torch.split(wav, self.max_input_frames, dim=-1)
|
180 |
+
valid_segments = []
|
181 |
+
for s in segments:
|
182 |
+
# ignoring too small segments that are breaking the model inference
|
183 |
+
if s.size(-1) > self.min_input_frames:
|
184 |
+
valid_segments.append(s)
|
185 |
+
return [s[None] for s in valid_segments]
|
186 |
+
|
187 |
+
def _get_model_preds(self, wav: torch.Tensor) -> torch.Tensor:
|
188 |
+
"""Run the pretrained model and get the predictions."""
|
189 |
+
assert wav.dim() == 3, f"Unexpected number of dims for preprocessed wav: {wav.shape}"
|
190 |
+
wav = wav.mean(dim=1)
|
191 |
+
# PaSST is printing a lot of garbage that we are not interested in
|
192 |
+
with open(os.devnull, "w") as f, contextlib.redirect_stdout(f):
|
193 |
+
with torch.no_grad(), _patch_passt_stft():
|
194 |
+
logits = self.model(wav.to(self.device))
|
195 |
+
probs = torch.softmax(logits, dim=-1)
|
196 |
+
return probs
|
197 |
+
|
198 |
+
def _get_label_distribution(self, x: torch.Tensor, sizes: torch.Tensor,
|
199 |
+
sample_rates: torch.Tensor) -> tp.Optional[torch.Tensor]:
|
200 |
+
"""Get model output given provided input tensor.
|
201 |
+
|
202 |
+
Args:
|
203 |
+
x (torch.Tensor): Input audio tensor of shape [B, C, T].
|
204 |
+
sizes (torch.Tensor): Actual audio sample length, of shape [B].
|
205 |
+
sample_rates (torch.Tensor): Actual audio sample rate, of shape [B].
|
206 |
+
Returns:
|
207 |
+
probs (torch.Tensor, optional): Probabilities over labels, of shape [B, num_classes].
|
208 |
+
"""
|
209 |
+
all_probs: tp.List[torch.Tensor] = []
|
210 |
+
for i, wav in enumerate(x):
|
211 |
+
sample_rate = int(sample_rates[i].item())
|
212 |
+
wav_len = int(sizes[i].item())
|
213 |
+
wav_segments = self._process_audio(wav, sample_rate, wav_len)
|
214 |
+
for segment in wav_segments:
|
215 |
+
probs = self._get_model_preds(segment).mean(dim=0)
|
216 |
+
all_probs.append(probs)
|
217 |
+
if len(all_probs) > 0:
|
218 |
+
return torch.stack(all_probs, dim=0)
|
219 |
+
else:
|
220 |
+
return None
|
audiocraft/metrics/rvm.py
ADDED
@@ -0,0 +1,110 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
import typing as tp
|
8 |
+
import torch
|
9 |
+
from torch import nn
|
10 |
+
import torchaudio
|
11 |
+
|
12 |
+
|
13 |
+
def db_to_scale(volume: tp.Union[float, torch.Tensor]):
|
14 |
+
return 10 ** (volume / 20)
|
15 |
+
|
16 |
+
|
17 |
+
def scale_to_db(scale: torch.Tensor, min_volume: float = -120):
|
18 |
+
min_scale = db_to_scale(min_volume)
|
19 |
+
return 20 * torch.log10(scale.clamp(min=min_scale))
|
20 |
+
|
21 |
+
|
22 |
+
class RelativeVolumeMel(nn.Module):
|
23 |
+
"""Relative volume melspectrogram measure.
|
24 |
+
|
25 |
+
Computes a measure of distance over two mel spectrogram that is interpretable in terms
|
26 |
+
of decibels. Given `x_ref` and `x_est` two waveforms of shape `[*, T]`, it will
|
27 |
+
first renormalize both by the ground truth of `x_ref`.
|
28 |
+
|
29 |
+
..Warning:: This class returns the volume of the distortion at the spectrogram level,
|
30 |
+
e.g. low negative values reflects lower distortion levels. For a SNR (like reported
|
31 |
+
in the MultiBandDiffusion paper), just take `-rvm`.
|
32 |
+
|
33 |
+
Then it computes the mel spectrogram `z_ref` and `z_est` and compute volume of the difference
|
34 |
+
relative to the volume of `z_ref` for each time-frequency bin. It further adds some limits, e.g.
|
35 |
+
clamping the values between -25 and 25 dB (controlled by `min_relative_volume` and `max_relative_volume`)
|
36 |
+
with the goal of avoiding the loss being dominated by parts where the reference is almost silent.
|
37 |
+
Indeed, volumes in dB can take unbounded values both towards -oo and +oo, which can make the final
|
38 |
+
average metric harder to interpret. Besides, anything below -30 dB of attenuation would sound extremely
|
39 |
+
good (for a neural network output, although sound engineers typically aim for much lower attenuations).
|
40 |
+
Similarly, anything above +30 dB would just be completely missing the target, and there is no point
|
41 |
+
in measuring by exactly how much it missed it. -25, 25 is a more conservative range, but also more
|
42 |
+
in line with what neural nets currently can achieve.
|
43 |
+
|
44 |
+
For instance, a Relative Volume Mel (RVM) score of -10 dB means that on average, the delta between
|
45 |
+
the target and reference mel-spec is 10 dB lower than the reference mel-spec value.
|
46 |
+
|
47 |
+
The metric can be aggregated over a given frequency band in order have different insights for
|
48 |
+
different region of the spectrum. `num_aggregated_bands` controls the number of bands.
|
49 |
+
|
50 |
+
..Warning:: While this function is optimized for interpretability, nothing was done to ensure it
|
51 |
+
is numerically stable when computing its gradient. We thus advise against using it as a training loss.
|
52 |
+
|
53 |
+
Args:
|
54 |
+
sample_rate (int): Sample rate of the input audio.
|
55 |
+
n_mels (int): Number of mel bands to use.
|
56 |
+
n_fft (int): Number of frequency bins for the STFT.
|
57 |
+
hop_length (int): Hop length of the STFT and the mel-spectrogram.
|
58 |
+
min_relative_volume (float): The error `z_ref - z_est` volume is given relative to
|
59 |
+
the volume of `z_ref`. If error is smaller than -25 dB of `z_ref`, then it is clamped.
|
60 |
+
max_relative_volume (float): Same as `min_relative_volume` but clamping if the error is larger than that.
|
61 |
+
max_initial_gain (float): When rescaling the audio at the very beginning, we will limit the gain
|
62 |
+
to that amount, to avoid rescaling near silence. Given in dB.
|
63 |
+
min_activity_volume (float): When computing the reference level from `z_ref`, will clamp low volume
|
64 |
+
bins to that amount. This is effectively our "zero" level for the reference mel-spectrogram,
|
65 |
+
and anything below that will be considered equally.
|
66 |
+
num_aggregated_bands (int): Number of bands to keep when computing the average RVM value.
|
67 |
+
For instance, a value of 3 would give 3 scores, roughly for low, mid and high freqs.
|
68 |
+
"""
|
69 |
+
def __init__(self, sample_rate: int = 24000, n_mels: int = 80, n_fft: int = 512,
|
70 |
+
hop_length: int = 128, min_relative_volume: float = -25,
|
71 |
+
max_relative_volume: float = 25, max_initial_gain: float = 25,
|
72 |
+
min_activity_volume: float = -25,
|
73 |
+
num_aggregated_bands: int = 4) -> None:
|
74 |
+
super().__init__()
|
75 |
+
self.melspec = torchaudio.transforms.MelSpectrogram(
|
76 |
+
n_mels=n_mels, n_fft=n_fft, hop_length=hop_length,
|
77 |
+
normalized=True, sample_rate=sample_rate, power=2)
|
78 |
+
self.min_relative_volume = min_relative_volume
|
79 |
+
self.max_relative_volume = max_relative_volume
|
80 |
+
self.max_initial_gain = max_initial_gain
|
81 |
+
self.min_activity_volume = min_activity_volume
|
82 |
+
self.num_aggregated_bands = num_aggregated_bands
|
83 |
+
|
84 |
+
def forward(self, estimate: torch.Tensor, ground_truth: torch.Tensor) -> tp.Dict[str, torch.Tensor]:
|
85 |
+
"""Compute RVM metric between estimate and reference samples.
|
86 |
+
|
87 |
+
Args:
|
88 |
+
estimate (torch.Tensor): Estimate sample.
|
89 |
+
ground_truth (torch.Tensor): Reference sample.
|
90 |
+
|
91 |
+
Returns:
|
92 |
+
dict[str, torch.Tensor]: Metrics with keys `rvm` for the overall average, and `rvm_{k}`
|
93 |
+
for the RVM over the k-th band (k=0..num_aggregated_bands - 1).
|
94 |
+
"""
|
95 |
+
min_scale = db_to_scale(-self.max_initial_gain)
|
96 |
+
std = ground_truth.pow(2).mean().sqrt().clamp(min=min_scale)
|
97 |
+
z_gt = self.melspec(ground_truth / std).sqrt()
|
98 |
+
z_est = self.melspec(estimate / std).sqrt()
|
99 |
+
|
100 |
+
delta = z_gt - z_est
|
101 |
+
ref_db = scale_to_db(z_gt, self.min_activity_volume)
|
102 |
+
delta_db = scale_to_db(delta.abs(), min_volume=-120)
|
103 |
+
relative_db = (delta_db - ref_db).clamp(self.min_relative_volume, self.max_relative_volume)
|
104 |
+
dims = list(range(relative_db.dim()))
|
105 |
+
dims.remove(dims[-2])
|
106 |
+
losses_per_band = relative_db.mean(dim=dims)
|
107 |
+
aggregated = [chunk.mean() for chunk in losses_per_band.chunk(self.num_aggregated_bands, dim=0)]
|
108 |
+
metrics = {f'rvm_{index}': value for index, value in enumerate(aggregated)}
|
109 |
+
metrics['rvm'] = losses_per_band.mean()
|
110 |
+
return metrics
|
audiocraft/metrics/visqol.py
ADDED
@@ -0,0 +1,216 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
import csv
|
8 |
+
import json
|
9 |
+
import logging
|
10 |
+
from pathlib import Path
|
11 |
+
import tempfile
|
12 |
+
import typing as tp
|
13 |
+
import subprocess
|
14 |
+
import shutil
|
15 |
+
|
16 |
+
import torch
|
17 |
+
import torchaudio
|
18 |
+
|
19 |
+
logger = logging.getLogger(__name__)
|
20 |
+
|
21 |
+
|
22 |
+
class ViSQOL:
|
23 |
+
"""ViSQOL wrapper to run ViSQOL from Python using a pre-installed binary.
|
24 |
+
|
25 |
+
To learn more about ViSQOL and how to build ViSQOL binary using bazel, please refer to the
|
26 |
+
instructions available in the open source repository: https://github.com/google/visqol
|
27 |
+
|
28 |
+
ViSQOL is capable of running in two modes:
|
29 |
+
|
30 |
+
Audio Mode:
|
31 |
+
When running in audio mode, input signals must have a 48kHz sample rate. Input should be resampled to 48kHz.
|
32 |
+
Input signals can be multi-channel, but they will be down-mixed to mono for performing the comparison.
|
33 |
+
Audio mode uses support vector regression, with the maximum range at ~4.75.
|
34 |
+
|
35 |
+
Speech Mode:
|
36 |
+
When running in speech mode, ViSQOL uses a wideband model. It therefore expects input sample rates of 16kHz.
|
37 |
+
Input should be resampled to 16kHz.
|
38 |
+
As part of the speech mode processing, a root mean square implementation for voice activity detection
|
39 |
+
is performed on the reference signal to determine what parts of the signal have voice activity and
|
40 |
+
should therefore be included in the comparison. The signal is normalized before performing the voice
|
41 |
+
activity detection.
|
42 |
+
Input signals can be multi-channel, but they will be down-mixed to mono for performing the comparison.
|
43 |
+
Speech mode is scaled to have a maximum MOS of 5.0 to match previous version behavior.
|
44 |
+
|
45 |
+
For more details, check the guidelines: https://github.com/google/visqol#general-guidelines-for-input
|
46 |
+
|
47 |
+
Args:
|
48 |
+
visqol_bin (str): Path to the ViSQOL binary.
|
49 |
+
mode (str): ViSQOL computation mode, expecting "audio" or "speech".
|
50 |
+
model (str): Name of the model to use for similarity to quality model.
|
51 |
+
debug (bool): Whether to also get debug metrics from ViSQOL or not.
|
52 |
+
"""
|
53 |
+
SAMPLE_RATES_MODES = {"audio": 48_000, "speech": 16_000}
|
54 |
+
ALLOWED_SAMPLE_RATES = frozenset(SAMPLE_RATES_MODES.values())
|
55 |
+
|
56 |
+
def __init__(self, bin: tp.Union[Path, str], mode: str = "audio",
|
57 |
+
model: str = "libsvm_nu_svr_model.txt", debug: bool = False):
|
58 |
+
assert bin is not None and Path(bin).exists(), f"Could not find ViSQOL binary in specified path: {bin}"
|
59 |
+
self.visqol_bin = str(bin)
|
60 |
+
self.visqol_mode = mode
|
61 |
+
self.target_sr = self._get_target_sr(self.visqol_mode)
|
62 |
+
self.model = model
|
63 |
+
self.debug = debug
|
64 |
+
assert Path(self.visqol_model).exists(), \
|
65 |
+
f"Could not find the specified model in ViSQOL install: {self.visqol_model}"
|
66 |
+
|
67 |
+
def _get_target_sr(self, mode: str) -> int:
|
68 |
+
# returns target sampling rate for the corresponding ViSQOL mode.
|
69 |
+
if mode not in ViSQOL.SAMPLE_RATES_MODES:
|
70 |
+
raise ValueError(
|
71 |
+
f"Unsupported mode! Allowed are: {', '.join(ViSQOL.SAMPLE_RATES_MODES.keys())}"
|
72 |
+
)
|
73 |
+
return ViSQOL.SAMPLE_RATES_MODES[mode]
|
74 |
+
|
75 |
+
def _prepare_files(
|
76 |
+
self, ref_sig: torch.Tensor, deg_sig: torch.Tensor, sr: int, target_sr: int, pad_with_silence: bool = False
|
77 |
+
):
|
78 |
+
# prepare files for ViSQOL evaluation.
|
79 |
+
assert target_sr in ViSQOL.ALLOWED_SAMPLE_RATES
|
80 |
+
assert len(ref_sig) == len(deg_sig), (
|
81 |
+
"Expects same number of ref and degraded inputs",
|
82 |
+
f" but ref len {len(ref_sig)} != deg len {len(deg_sig)}"
|
83 |
+
)
|
84 |
+
# resample audio if needed
|
85 |
+
if sr != target_sr:
|
86 |
+
transform = torchaudio.transforms.Resample(sr, target_sr)
|
87 |
+
pad = int(0.5 * target_sr)
|
88 |
+
rs_ref = []
|
89 |
+
rs_deg = []
|
90 |
+
for i in range(len(ref_sig)):
|
91 |
+
rs_ref_i = transform(ref_sig[i])
|
92 |
+
rs_deg_i = transform(deg_sig[i])
|
93 |
+
if pad_with_silence:
|
94 |
+
rs_ref_i = torch.nn.functional.pad(rs_ref_i, (pad, pad), mode='constant', value=0)
|
95 |
+
rs_deg_i = torch.nn.functional.pad(rs_deg_i, (pad, pad), mode='constant', value=0)
|
96 |
+
rs_ref.append(rs_ref_i)
|
97 |
+
rs_deg.append(rs_deg_i)
|
98 |
+
ref_sig = torch.stack(rs_ref)
|
99 |
+
deg_sig = torch.stack(rs_deg)
|
100 |
+
# save audio chunks to tmp dir and create csv
|
101 |
+
tmp_dir = Path(tempfile.mkdtemp())
|
102 |
+
try:
|
103 |
+
tmp_input_csv_path = tmp_dir / "input.csv"
|
104 |
+
tmp_results_csv_path = tmp_dir / "results.csv"
|
105 |
+
tmp_debug_json_path = tmp_dir / "debug.json"
|
106 |
+
with open(tmp_input_csv_path, "w") as csv_file:
|
107 |
+
csv_writer = csv.writer(csv_file)
|
108 |
+
csv_writer.writerow(["reference", "degraded"])
|
109 |
+
for i in range(len(ref_sig)):
|
110 |
+
tmp_ref_filename = tmp_dir / f"ref_{i}.wav"
|
111 |
+
tmp_deg_filename = tmp_dir / f"deg_{i}.wav"
|
112 |
+
torchaudio.save(
|
113 |
+
tmp_ref_filename,
|
114 |
+
torch.clamp(ref_sig[i], min=-0.99, max=0.99),
|
115 |
+
sample_rate=target_sr,
|
116 |
+
bits_per_sample=16,
|
117 |
+
encoding="PCM_S"
|
118 |
+
)
|
119 |
+
torchaudio.save(
|
120 |
+
tmp_deg_filename,
|
121 |
+
torch.clamp(deg_sig[i], min=-0.99, max=0.99),
|
122 |
+
sample_rate=target_sr,
|
123 |
+
bits_per_sample=16,
|
124 |
+
encoding="PCM_S"
|
125 |
+
)
|
126 |
+
csv_writer.writerow([str(tmp_ref_filename), str(tmp_deg_filename)])
|
127 |
+
return tmp_dir, tmp_input_csv_path, tmp_results_csv_path, tmp_debug_json_path
|
128 |
+
except Exception as e:
|
129 |
+
logger.error("Exception occurred when preparing files for ViSQOL: %s", e)
|
130 |
+
return tmp_dir, None, None, None
|
131 |
+
|
132 |
+
def _flush_files(self, tmp_dir: tp.Union[Path, str]):
|
133 |
+
# flush tmp files used to compute ViSQOL.
|
134 |
+
shutil.rmtree(str(tmp_dir))
|
135 |
+
|
136 |
+
def _collect_moslqo_score(self, results_csv_path: tp.Union[Path, str]) -> float:
|
137 |
+
# collect results for each evaluated pair and return averaged moslqo score.
|
138 |
+
with open(results_csv_path, "r") as csv_file:
|
139 |
+
reader = csv.DictReader(csv_file)
|
140 |
+
moslqo_scores = [float(row["moslqo"]) for row in reader]
|
141 |
+
if len(moslqo_scores) > 0:
|
142 |
+
return sum(moslqo_scores) / len(moslqo_scores)
|
143 |
+
else:
|
144 |
+
return 0.0
|
145 |
+
|
146 |
+
def _collect_debug_data(self, debug_json_path: tp.Union[Path, str]) -> dict:
|
147 |
+
# collect debug data for the visqol inference.
|
148 |
+
with open(debug_json_path, "r") as f:
|
149 |
+
data = json.load(f)
|
150 |
+
return data
|
151 |
+
|
152 |
+
@property
|
153 |
+
def visqol_model(self):
|
154 |
+
return f'{self.visqol_bin}/model/{self.model}'
|
155 |
+
|
156 |
+
def _run_visqol(
|
157 |
+
self,
|
158 |
+
input_csv_path: tp.Union[Path, str],
|
159 |
+
results_csv_path: tp.Union[Path, str],
|
160 |
+
debug_csv_path: tp.Optional[tp.Union[Path, str]],
|
161 |
+
):
|
162 |
+
input_csv_path = str(input_csv_path)
|
163 |
+
results_csv_path = str(results_csv_path)
|
164 |
+
debug_csv_path = str(debug_csv_path)
|
165 |
+
cmd = [
|
166 |
+
f'{self.visqol_bin}/bazel-bin/visqol',
|
167 |
+
'--batch_input_csv', f'{input_csv_path}',
|
168 |
+
'--results_csv', f'{results_csv_path}'
|
169 |
+
]
|
170 |
+
if debug_csv_path is not None:
|
171 |
+
cmd += ['--output_debug', f'{debug_csv_path}']
|
172 |
+
if self.visqol_mode == "speech":
|
173 |
+
cmd += ['--use_speech_mode']
|
174 |
+
cmd += ['--similarity_to_quality_model', f'{self.visqol_model}']
|
175 |
+
result = subprocess.run(cmd, capture_output=True)
|
176 |
+
if result.returncode:
|
177 |
+
logger.error("Error with visqol: \n %s \n %s", result.stdout.decode(), result.stderr.decode())
|
178 |
+
raise RuntimeError("Error while executing visqol")
|
179 |
+
result.check_returncode()
|
180 |
+
|
181 |
+
def __call__(
|
182 |
+
self,
|
183 |
+
ref_sig: torch.Tensor,
|
184 |
+
deg_sig: torch.Tensor,
|
185 |
+
sr: int,
|
186 |
+
pad_with_silence: bool = False,
|
187 |
+
):
|
188 |
+
"""Calculate the ViSQOL metric for a pair of audio signals at a given sample rate.
|
189 |
+
Args:
|
190 |
+
ref_sig (torch.Tensor): Reference signals as [B, C, T].
|
191 |
+
deg_sig (torch.Tensor): Degraded signals as [B, C, T].
|
192 |
+
sr (int): Sample rate of the two audio signals.
|
193 |
+
pad_with_silence (bool): Whether to pad the file with silences as recommended
|
194 |
+
in visqol guidelines (see: https://github.com/google/visqol#general-guidelines-for-input).
|
195 |
+
Returns:
|
196 |
+
float: The ViSQOL score or mean score for the batch.
|
197 |
+
"""
|
198 |
+
logger.debug(f"Calculating visqol with mode={self.visqol_mode} on {len(ref_sig)} samples")
|
199 |
+
tmp_dir, input_csv, results_csv, debug_json = self._prepare_files(
|
200 |
+
ref_sig, deg_sig, sr, self.target_sr, pad_with_silence
|
201 |
+
)
|
202 |
+
try:
|
203 |
+
if input_csv and results_csv:
|
204 |
+
self._run_visqol(
|
205 |
+
input_csv,
|
206 |
+
results_csv,
|
207 |
+
debug_json if self.debug else None,
|
208 |
+
)
|
209 |
+
mosqol = self._collect_moslqo_score(results_csv)
|
210 |
+
return mosqol
|
211 |
+
else:
|
212 |
+
raise RuntimeError("Something unexpected happened when running VISQOL!")
|
213 |
+
except Exception as e:
|
214 |
+
logger.error("Exception occurred when running ViSQOL: %s", e)
|
215 |
+
finally:
|
216 |
+
self._flush_files(tmp_dir)
|