File size: 3,889 Bytes
67beffe
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102

from random import sample
import gradio as gr
import torchaudio
import torch
import torch.nn as nn
import lightning_module
import pdb
import jiwer

# ASR part
# Build the speech-recognition pipeline once, at import time. No model name is
# passed, so `transformers` picks its default checkpoint for this task.
# `calc_mos` calls `p(audio_path)["text"]` to obtain the hypothesis transcript.
from transformers import pipeline
p = pipeline("automatic-speech-recognition")

# WER part
# Text normalisation that `calc_mos` applies to both the reference and the ASR
# hypothesis before scoring: lowercase, replace whitespace characters by
# spaces, collapse repeated spaces, then split each sentence into words on " ".
transformation = jiwer.Compose([
    jiwer.ToLowerCase(),
    jiwer.RemoveWhiteSpace(replace_by_space=True),
    jiwer.RemoveMultipleSpaces(),
    jiwer.ReduceToListOfListOfWords(word_delimiter=" ")
])

# Phoneme recognition part (feeds the phonemes-per-minute figure in `calc_mos`)
# The processor and the CTC model are both loaded from the
# "facebook/wav2vec2-xlsr-53-espeak-cv-ft" checkpoint. `calc_mos` takes the
# argmax of the model's logits and decodes it with `processor.batch_decode`.
from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC
processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-xlsr-53-espeak-cv-ft")
phoneme_model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-xlsr-53-espeak-cv-ft")
# Unused alternative, kept for reference: load the same checkpoint via `pipeline`.
# phoneme_model =  pipeline(model="facebook/wav2vec2-xlsr-53-espeak-cv-ft")
class ChangeSampleRate(nn.Module):
    """Resample a waveform by linear interpolation between neighbouring samples.

    Args:
        input_rate: Sampling rate of the waveform passed to ``forward``.
        output_rate: Sampling rate of the waveform returned by ``forward``
            (same unit as ``input_rate``; only the ratio matters).
    """

    def __init__(self, input_rate: int, output_rate: int):
        super().__init__()
        self.output_rate = output_rate
        self.input_rate = input_rate

    def forward(self, wav: torch.Tensor) -> torch.Tensor:
        """Return ``wav`` resampled from ``input_rate`` to ``output_rate``.

        Args:
            wav: Waveform tensor. It is viewed as ``(wav.size(0), -1)``, so
                everything after the first dimension is treated as one
                time axis.

        Returns:
            Tensor of shape ``(wav.size(0), new_length)`` with
            ``new_length = num_samples * output_rate // input_rate``, on the
            same device as ``wav``.
        """
        # Only accepts 1-channel waveform input
        wav = wav.view(wav.size(0), -1)
        new_length = wav.size(-1) * self.output_rate // self.input_rate
        # Fractional source position of every output sample. Created on the
        # input's device so the interpolation weights below can be multiplied
        # with the gathered samples without a device mismatch.
        positions = torch.arange(new_length, device=wav.device) * (self.input_rate / self.output_rate)
        lower = positions.long()
        # The right-hand neighbour of the last sample is the last sample itself.
        upper = (lower + 1).clamp(max=wav.size(-1) - 1)
        frac = positions.fmod(1.).unsqueeze(0)
        return wav[:, lower] * (1. - frac) + wav[:, upper] * frac

# Load the MOS prediction model from the "epoch=3-step=7459.ckpt" checkpoint and
# switch it to evaluation mode; `calc_mos` runs it under `torch.no_grad()`.
model = lightning_module.BaselineLightningModule.load_from_checkpoint("epoch=3-step=7459.ckpt").eval()

def calc_mos(audio_path, ref):
    """Score one recording: predicted MOS, ASR transcript, WER and speaking rate.

    Args:
        audio_path: Path of the audio file to evaluate. It is read with
            ``torchaudio.load`` and also handed to the ASR pipeline ``p``.
        ref: Reference text that the ASR transcript is compared against.

    Returns:
        A tuple ``(predic_mos, trans, wer, phone_transcription, ppm)``:

        * ``predic_mos`` - model output averaged over dim 1, times 2 plus 3,
          as a NumPy value (presumably a rescaling onto the MOS scale -
          confirm against the model's training setup).
        * ``trans`` - text returned by the ASR pipeline.
        * ``wer`` - ``jiwer.wer`` of ``trans`` against ``ref`` after both went
          through ``transformation``.
        * ``phone_transcription`` - list returned by
          ``processor.batch_decode`` for the phoneme model's predictions.
        * ``ppm`` - phonemes per minute, measured over the VAD-trimmed audio;
          ``0.0`` when the VAD leaves no audio.
    """
    wav, sr = torchaudio.load(audio_path)
    # Everything below assumes a single channel (batch-size-1 `domains` /
    # `judge_id`, `phone_transcription[0]`), so downmix multi-channel input.
    if wav.size(0) > 1:
        wav = wav.mean(dim=0, keepdim=True)
    osr = 16_000
    out_wavs = ChangeSampleRate(sr, osr)(wav)
    # ASR
    trans = p(audio_path)["text"]
    # WER
    wer = jiwer.wer(ref, trans, truth_transform=transformation, hypothesis_transform=transformation)
    # MOS
    batch = {
        'wav': out_wavs,
        'domains': torch.tensor([0]),
        'judge_id': torch.tensor([288])
    }
    with torch.no_grad():
        output = model(batch)
    predic_mos = output.mean(dim=1).squeeze().detach().numpy()*2 + 3
    # Phonemes per minute (PPM)
    # NOTE(review): the raw resampled waveform is fed to the phoneme model
    # without going through `processor` first - confirm this is intended.
    with torch.no_grad():
        logits = phoneme_model(out_wavs).logits
    phone_predicted_ids = torch.argmax(logits, dim=-1)
    phone_transcription = processor.batch_decode(phone_predicted_ids)
    # split() without an argument ignores repeated spaces and yields an empty
    # list (not ['']) for an empty transcription, so silence counts as zero.
    lst_phonemes = phone_transcription[0].split()
    wav_vad = torchaudio.functional.vad(wav, sample_rate=sr)
    speech_seconds = wav_vad.shape[-1] / sr
    # The VAD can trim the whole recording when it detects no speech; report a
    # rate of zero instead of dividing by a zero duration.
    ppm = len(lst_phonemes) / speech_seconds * 60 if speech_seconds > 0 else 0.0

    return predic_mos, trans, wer, phone_transcription, ppm

# Text passed to `gr.Interface(description=...)` below; it contains a Markdown
# link to the paper.
description ="""
MOS prediction demo using UTMOS-strong w/o phoneme encoder model, which is trained on the main track dataset.
This demo only accepts .wav format. Best at 16 kHz sampling rate.

Paper is available [here](https://arxiv.org/abs/2204.02152)

Add ASR based on wav2vec-960, currently only English available.
Add WER interface.
"""


# Gradio front end: one microphone recording plus a reference sentence go into
# `calc_mos`; its five return values are shown in five text boxes, in order.
# NOTE(review): the default reference says "Author" - confirm the intended
# wording, since it is scored against the ASR transcript.
_audio_input = gr.Audio(source="microphone", type='filepath', label="Audio to evaluate")
_reference_input = gr.Textbox(
    value="Once upon a time there was a young rat named Author who couldn’t make up his mind.",
    placeholder="Input referance here",
    label="Referance",
)

# (placeholder, label) for each output box, in the order `calc_mos` returns.
_result_boxes = [
    gr.Textbox(placeholder=hint, label=caption)
    for hint, caption in (
        ("Predicted MOS", "Predicted MOS"),
        ("Hypothesis", "Hypothesis"),
        ("Word Error Rate", "WER"),
        ("Predicted Phonemes", "Predicted Phonemes"),
        ("Phonemes per minutes", "PPM"),
    )
]

iface = gr.Interface(
    fn=calc_mos,
    inputs=[_audio_input, _reference_input],
    outputs=_result_boxes,
    title="Laronix's Voice Quality Checking System Demo",
    description=description,
    allow_flagging="auto",
)
iface.launch()