"""score.py: quick-start scorer for the UTMOS demo, predicting a MOS for each audio clip."""

import unittest

import numpy as np
import torch
import torchaudio

import lightning_module


class Score:
"""Predicting score for each audio clip."""
def __init__(
self,
ckpt_path: str = "epoch=3-step=7459.ckpt",
input_sample_rate: int = 16000,
device: str = "cpu"):
"""
Args:
ckpt_path: path to pretrained checkpoint of UTMOS strong learner.
input_sample_rate: sampling rate of input audio tensor. The input audio tensor
is automatically downsampled to 16kHz.
"""
print(f"Using device: {device}")
self.device = device
self.model = lightning_module.BaselineLightningModule.load_from_checkpoint(
ckpt_path).eval().to(device)
self.in_sr = input_sample_rate
self.resampler = torchaudio.transforms.Resample(
orig_freq=input_sample_rate,
new_freq=16000,
resampling_method="sinc_interpolation",
lowpass_filter_width=6,
dtype=torch.float32,
).to(device)
    def score(self, wavs: torch.Tensor) -> np.ndarray:
"""
Args:
wavs: audio waveform to be evaluated. When len(wavs) == 1 or 2,
the model processes the input as a single audio clip. The model
performs batch processing when len(wavs) == 3.
"""
if len(wavs.shape) == 1:
out_wavs = wavs.unsqueeze(0).unsqueeze(0)
elif len(wavs.shape) == 2:
out_wavs = wavs.unsqueeze(0)
elif len(wavs.shape) == 3:
out_wavs = wavs
else:
raise ValueError('Dimension of input tensor needs to be <= 3.')
if self.in_sr != 16000:
out_wavs = self.resampler(out_wavs)
        bs = out_wavs.shape[0]
        # Constant domain and judge ids fed to the model alongside the waveform.
        batch = {
            'wav': out_wavs,
            'domains': torch.zeros(bs, dtype=torch.int).to(self.device),
            'judge_id': torch.ones(bs, dtype=torch.int).to(self.device) * 288
        }
        with torch.no_grad():
            output = self.model(batch)
        # Average over frames, then map the normalized output back to the MOS scale.
        return output.mean(dim=1).squeeze(1).cpu().detach().numpy() * 2 + 3
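
# A minimal usage sketch (the wav path below is a placeholder for a local mono
# audio file; the checkpoint name is this demo's default):
#
#   wav, sr = torchaudio.load("sample.wav")
#   scorer = Score(ckpt_path="epoch=3-step=7459.ckpt", input_sample_rate=sr)
#   mos = scorer.score(wav)  # numpy array with one predicted MOS per clip
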
class TestFunc(unittest.TestCase):
"""Test class."""
def test_1dim_0(self):
scorer = Score(input_sample_rate=16000)
seq_len = 10000
inp_audio = torch.ones(seq_len)
pred = scorer.score(inp_audio)
self.assertGreaterEqual(pred, 0.)
self.assertLessEqual(pred, 5.)
def test_1dim_1(self):
scorer = Score(input_sample_rate=24000)
seq_len = 10000
inp_audio = torch.ones(seq_len)
pred = scorer.score(inp_audio)
self.assertGreaterEqual(pred, 0.)
self.assertLessEqual(pred, 5.)
def test_2dim_0(self):
scorer = Score(input_sample_rate=16000)
seq_len = 10000
inp_audio = torch.ones(1, seq_len)
pred = scorer.score(inp_audio)
self.assertGreaterEqual(pred, 0.)
self.assertLessEqual(pred, 5.)
def test_2dim_1(self):
scorer = Score(input_sample_rate=24000)
seq_len = 10000
inp_audio = torch.ones(1, seq_len)
pred = scorer.score(inp_audio)
print(pred)
print(pred.shape)
self.assertGreaterEqual(pred, 0.)
self.assertLessEqual(pred, 5.)
def test_3dim_0(self):
scorer = Score(input_sample_rate=16000)
seq_len = 10000
batch = 8
inp_audio = torch.ones(batch, 1, seq_len)
pred = scorer.score(inp_audio)
for p in pred:
self.assertGreaterEqual(p, 0.)
self.assertLessEqual(p, 5.)
def test_3dim_1(self):
scorer = Score(input_sample_rate=24000)
seq_len = 10000
batch = 8
inp_audio = torch.ones(batch, 1, seq_len)
pred = scorer.score(inp_audio)
for p in pred:
self.assertGreaterEqual(p, 0.)
self.assertLessEqual(p, 5.)
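
# Running this file directly executes the self-tests below; each test constructs
# a Score instance, so the pretrained checkpoint "epoch=3-step=7459.ckpt" must be
# present in the working directory.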
if __name__ == '__main__':
unittest.main()