# Gael Le Lan
# Initial commit
# 9d0d223
# raw
# history blame
# 1.55 kB
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import julius
import pesq
import torch
import torchmetrics
class PesqMetric(torchmetrics.Metric):
    """Metric for Perceptual Evaluation of Speech Quality.
    (https://doi.org/10.5281/zenodo.6549559)

    Accumulates per-item scores from ``pesq.pesq`` (always evaluated at 16 kHz)
    across ``update`` calls and reports their mean from ``compute``. Audio at any
    other sample rate is first resampled to 16 kHz with ``julius.resample_frac``.
    Only channel 0 of each batch item is scored.

    Args:
        sample_rate: Sample rate of the ``preds``/``targets`` given to ``update``.
    """

    # Running sum of PESQ scores of all scored items (summed across processes).
    sum_pesq: torch.Tensor
    # Number of items actually scored; items for which pesq raises
    # ``pesq.NoUtterancesError`` are skipped and not counted.
    total: torch.Tensor

    def __init__(self, sample_rate: int):
        super().__init__()
        self.sr = sample_rate
        self.add_state("sum_pesq", default=torch.tensor(0.0), dist_reduce_fx="sum")
        self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum")

    def update(self, preds: torch.Tensor, targets: torch.Tensor):
        """Score one batch and add the results to the running state.

        Args:
            preds: Estimated/degraded audio, indexed as ``preds[item, channel, ...]``
                (presumably ``(batch, channels, time)`` — TODO confirm); only
                channel 0 is scored.
            targets: Reference audio with the same layout as ``preds``.

        Raises:
            ValueError: If ``preds`` and ``targets`` have different batch sizes.
        """
        if preds.size(0) != targets.size(0):
            raise ValueError(
                "preds and targets must have the same batch size, "
                f"got {preds.size(0)} and {targets.size(0)}"
            )
        if self.sr != 16000:
            # pesq.pesq is always called with fs=16000 below.
            preds = julius.resample_frac(preds, self.sr, 16000)
            targets = julius.resample_frac(targets, self.sr, 16000)
        # One device->host transfer per tensor per batch instead of two per item.
        ref_batch = targets[:, 0].detach().cpu().numpy()
        deg_batch = preds[:, 0].detach().cpu().numpy()
        for ref, deg in zip(ref_batch, deg_batch):
            try:
                score = pesq.pesq(16000, ref, deg)
            except pesq.NoUtterancesError:
                # This error can happen when a sample does not contain speech;
                # skip it so it affects neither the sum nor the count.
                continue
            self.sum_pesq += score
            self.total += 1

    def compute(self) -> torch.Tensor:
        """Return the mean PESQ over all scored items, or 0.0 if none were scored."""
        if self.total.item() == 0:
            # zeros_like keeps the result on the same device/dtype as the state.
            return torch.zeros_like(self.sum_pesq)
        return self.sum_pesq / self.total