|
""" |
|
TODO: |
|
+ [x] Load Configuration |
|
+ [ ] Checking |
|
+ [ ] Better saving directory |
|
""" |
|
import numpy as np |
|
from pathlib import Path |
|
import jiwer |
|
import pdb |
|
import torch.nn as nn |
|
import torch |
|
import torchaudio |
|
import gradio as gr |
|
from logging import PlaceHolder |
|
from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC |
|
import yaml |
|
from transformers import pipeline |
|
import librosa |
|
import librosa.display |
|
import matplotlib.pyplot as plt |
|
|
|
|
|
|
|
import sys |
|
|
|
sys.path.append("src") |
|
import lightning_module |
|
|
|
|
|
|
|
config_yaml = "Arthur.yaml" |
|
|
|
with open(config_yaml, "r") as f: |
|
|
|
try: |
|
config = yaml.safe_load(f) |
|
except FileExistsError: |
|
print("Config file Loading Error") |
|
exit() |
|
|
|
|
|
|
|
with open(config["ref_txt"], "r") as f: |
|
refs = f.readlines() |
|
refs_ids = [x.split()[0] for x in refs] |
|
refs_txt = [" ".join(x.split()[1:]) for x in refs] |
|
ref_feature = np.loadtxt(config["ref_feature"], delimiter=",", dtype="str") |
|
ref_wavs = [str(x) for x in sorted(Path(config["ref_wavs"]).glob("**/*.wav"))] |
|
|
|
dummy_wavs = [None for x in np.arange(len(ref_wavs))] |
|
|
|
refs_ppm = np.array(ref_feature[:, -1][1:], dtype="str") |
|
|
|
reference_id = gr.Textbox(value="ID", placeholder="Utter ID", label="Reference_ID") |
|
|
|
reference_textbox = gr.Textbox( |
|
value="Input reference here", |
|
placeholder="Input reference here", |
|
label="Reference", |
|
) |
|
reference_PPM = gr.Textbox(placeholder="Pneumatic Voice's PPM", label="Ref PPM") |
|
|
|
|
|
|
|
print("Preparing Examples") |
|
examples = [ |
|
[w, w_, i, x, y] for w, w_, i, x, y in zip(ref_wavs, ref_wavs, refs_ids, refs_txt, refs_ppm) |
|
] |
|
|
|
p = pipeline( |
|
"automatic-speech-recognition", |
|
model="KevinGeng/whipser_medium_en_PAL300_step25", |
|
device=0, |
|
) |
|
|
|
|
|
transformation = jiwer.Compose( |
|
[ |
|
jiwer.RemovePunctuation(), |
|
jiwer.ToLowerCase(), |
|
jiwer.RemoveWhiteSpace(replace_by_space=True), |
|
jiwer.RemoveMultipleSpaces(), |
|
jiwer.ReduceToListOfListOfWords(word_delimiter=" "), |
|
] |
|
) |
|
|
|
|
|
processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-xlsr-53-espeak-cv-ft") |
|
phoneme_model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-xlsr-53-espeak-cv-ft") |
|
|
|
|
|
class ChangeSampleRate(nn.Module): |
|
def __init__(self, input_rate: int, output_rate: int): |
|
super().__init__() |
|
self.output_rate = output_rate |
|
self.input_rate = input_rate |
|
|
|
def forward(self, wav: torch.tensor) -> torch.tensor: |
|
|
|
wav = wav.view(wav.size(0), -1) |
|
new_length = wav.size(-1) * self.output_rate // self.input_rate |
|
indices = torch.arange(new_length) * (self.input_rate / self.output_rate) |
|
round_down = wav[:, indices.long()] |
|
round_up = wav[:, (indices.long() + 1).clamp(max=wav.size(-1) - 1)] |
|
output = round_down * (1.0 - indices.fmod(1.0)).unsqueeze(0) + ( |
|
round_up * indices.fmod(1.0).unsqueeze(0) |
|
) |
|
return output |
|
|
|
|
|
|
|
model = lightning_module.BaselineLightningModule.load_from_checkpoint( |
|
"src/epoch=3-step=7459.ckpt" |
|
).eval() |
|
|
|
|
|
|
|
def get_speech_interval(signal, db): |
|
audio_interv = librosa.effects.split(signal, top_db=db) |
|
pause_end = [x[0] for x in audio_interv[1:]] |
|
pause_start = [x[1] for x in audio_interv[0:-1]] |
|
pause_interv = [[x, y] for x, y in zip(pause_start, pause_end)] |
|
return audio_interv, pause_interv |
|
|
|
|
|
|
|
|
|
def plot_UV(signal, audio_interv, sr): |
|
fig, ax = plt.subplots(nrows=2, sharex=True) |
|
librosa.display.waveshow(signal, sr=sr, ax=ax[0]) |
|
uv_flag = np.zeros(len(signal)) |
|
for i in audio_interv: |
|
uv_flag[i[0] : i[1]] = 1 |
|
|
|
ax[1].plot(np.arange(len(signal)) / sr, uv_flag, "r") |
|
ax[1].set_ylim([-0.1, 1.1]) |
|
return fig |
|
|
|
def calc_mos(_, audio_path, id, ref, pre_ppm, fig=None): |
|
if audio_path == None: |
|
audio_path = _ |
|
print("using ref audio as eval audio since it's empty") |
|
|
|
wav, sr = torchaudio.load(audio_path) |
|
if wav.shape[0] != 1: |
|
wav = wav[0, :] |
|
print(wav.shape) |
|
|
|
osr = 16000 |
|
batch = wav.unsqueeze(0).repeat(10, 1, 1) |
|
csr = ChangeSampleRate(sr, osr) |
|
out_wavs = csr(wav) |
|
|
|
|
|
trans = jiwer.ToLowerCase()(p(audio_path)["text"]) |
|
|
|
|
|
wer = jiwer.wer( |
|
ref, |
|
trans, |
|
truth_transform=transformation, |
|
hypothesis_transform=transformation, |
|
) |
|
|
|
batch = { |
|
"wav": out_wavs, |
|
"domains": torch.tensor([0]), |
|
"judge_id": torch.tensor([288]), |
|
} |
|
with torch.no_grad(): |
|
output = model(batch) |
|
predic_mos = output.mean(dim=1).squeeze().detach().numpy() * 2 + 3 |
|
|
|
|
|
with torch.no_grad(): |
|
logits = phoneme_model(out_wavs).logits |
|
phone_predicted_ids = torch.argmax(logits, dim=-1) |
|
phone_transcription = processor.batch_decode(phone_predicted_ids) |
|
lst_phonemes = phone_transcription[0].split(" ") |
|
|
|
|
|
wav_vad = torchaudio.functional.vad(wav, sample_rate=sr) |
|
|
|
a_h, p_h = get_speech_interval(wav_vad.numpy(), db=40) |
|
|
|
|
|
fig_h = plot_UV(wav_vad.numpy().squeeze(), a_h, sr=sr) |
|
ppm = len(lst_phonemes) / (wav_vad.shape[-1] / sr) * 60 |
|
|
|
error_msg = "!!! ERROR MESSAGE !!!\n" |
|
if audio_path == _ or audio_path == None: |
|
error_msg += "ERROR: Fail recording, Please start from the beginning again." |
|
return ( |
|
fig_h, |
|
predic_mos, |
|
trans, |
|
wer, |
|
phone_transcription, |
|
ppm, |
|
error_msg, |
|
) |
|
if ppm >= float(pre_ppm) + float(config["thre"]["maxppm"]): |
|
error_msg += "ERROR: Please speak slower.\n" |
|
elif ppm <= float(pre_ppm) - float(config["thre"]["minppm"]): |
|
error_msg += "ERROR: Please speak faster.\n" |
|
elif predic_mos <= float(config["thre"]["AUTOMOS"]): |
|
error_msg += "ERROR: Naturalness is too low, Please try again.\n" |
|
elif wer >= float(config["thre"]["WER"]): |
|
error_msg += "ERROR: Intelligibility is too low, Please try again\n" |
|
else: |
|
error_msg = ( |
|
"GOOD JOB! Please 【Save the Recording】.\nYou can start recording the next sample." |
|
) |
|
|
|
return ( |
|
fig_h, |
|
predic_mos, |
|
trans, |
|
wer, |
|
phone_transcription, |
|
ppm, |
|
error_msg, |
|
) |
|
|
|
with open("src/description.html", "r", encoding="utf-8") as f: |
|
description = f.read() |
|
|
|
|
|
refs_ppm = np.array(ref_feature[:, -1][1:], dtype="str") |
|
|
|
reference_id = gr.Textbox(value="ID", placeholder="Utter ID", label="Reference_ID", visible=False) |
|
reference_textbox = gr.Textbox( |
|
value="Input reference here", |
|
placeholder="Input reference here", |
|
label="Reference", |
|
) |
|
reference_PPM = gr.Textbox(placeholder="Pneumatic Voice's PPM", label="Ref PPM", visible=False) |
|
|
|
|
|
|
|
|
|
|
|
def record_part_info(name, gender, first_lng): |
|
message = "Participant information is successfully collected." |
|
id_str = "%s_%s_%s" % (name, gender[0], first_lng[0]) |
|
|
|
if name == None: |
|
message = "ERROR: Name Information incomplete!" |
|
id_str = "ERROR" |
|
|
|
if gender == None: |
|
message = "ERROR: Please select gender" |
|
id_str = "ERROR" |
|
|
|
if len(gender) > 1: |
|
message = "ERROR: Please select one gender only" |
|
id_str = "ERROR" |
|
|
|
if first_lng == None: |
|
message = "ERROR: Please select your english proficiency" |
|
id_str = "ERROR" |
|
|
|
if len(first_lng) > 1: |
|
message = "ERROR: Please select one english proficiency only" |
|
id_str = "ERROR" |
|
|
|
return message, id_str |
|
|
|
|
|
|
|
name = gr.Textbox(placeholder="Name", label="Name") |
|
gender = gr.CheckboxGroup(["Male", "Female"], label="gender") |
|
first_lng = gr.CheckboxGroup( |
|
[ |
|
"B1 Intermediate", |
|
"B2: Upper Intermediate", |
|
"C1: Advanced", |
|
"C2: Proficient", |
|
], |
|
label="English Proficiency (CEFR)", |
|
) |
|
|
|
msg = gr.Textbox(placeholder="Evaluation for valid participant", label="message") |
|
id_str = gr.Textbox(placeholder="participant id", label="participant_id") |
|
|
|
info = gr.Interface( |
|
fn=record_part_info, |
|
inputs=[name, gender, first_lng], |
|
outputs=[msg, id_str], |
|
title="Participant Information Page", |
|
allow_flagging="never", |
|
css="body {background-color: blue}", |
|
) |
|
|
|
if config["exp_id"] == None: |
|
config["exp_id"] = Path(config_yaml).stem |
|
|
|
|
|
css = """ |
|
.ref_text textarea {font-size: 40px !important} |
|
.message textarea {font-size: 40px !important} |
|
""" |
|
|
|
my_theme = gr.themes.Default().set( |
|
button_primary_background_fill="#75DA99", |
|
button_primary_background_fill_dark="#DEF2D7", |
|
button_primary_text_color="black", |
|
button_secondary_text_color="black", |
|
) |
|
|
|
|
|
callback = gr.CSVLogger() |
|
|
|
with gr.Blocks(css=css, theme=my_theme) as demo: |
|
with gr.Column(): |
|
with gr.Row(): |
|
ref_audio = gr.Audio( |
|
source="microphone", |
|
type="filepath", |
|
label="Reference_Audio", |
|
container=True, |
|
interactive=False, |
|
visible=False, |
|
) |
|
with gr.Row(): |
|
eval_audio = gr.Audio( |
|
source="microphone", |
|
type="filepath", |
|
container=True, |
|
label="Audio_to_Evaluate", |
|
) |
|
b_redo = gr.ClearButton( |
|
value="Redo", variant="stop", components=[eval_audio], size="sm" |
|
) |
|
reference_textbox = gr.Textbox( |
|
value="Input reference here", |
|
placeholder="Input reference here", |
|
label="Reference", |
|
interactive=True, |
|
elem_classes="ref_text", |
|
) |
|
with gr.Accordion("Input for Development", open=False): |
|
reference_id = gr.Textbox( |
|
value="ID", |
|
placeholder="Utter ID", |
|
label="Reference_ID", |
|
visible=True, |
|
) |
|
reference_PPM = gr.Textbox( |
|
placeholder="Pneumatic Voice's PPM", |
|
label="Ref PPM", |
|
visible=True, |
|
) |
|
with gr.Row(): |
|
b = gr.Button(value="1.Submit", variant="primary", elem_classes="submit") |
|
|
|
|
|
|
|
with gr.Row(): |
|
inputs = [ |
|
ref_audio, |
|
eval_audio, |
|
reference_id, |
|
reference_textbox, |
|
reference_PPM, |
|
] |
|
e = gr.Examples(examples, inputs, examples_per_page=5) |
|
|
|
with gr.Column(): |
|
with gr.Row(): |
|
|
|
msg = gr.Textbox( |
|
placeholder="Recording Feedback", |
|
label="Message", |
|
interactive=False, |
|
elem_classes="message", |
|
) |
|
with gr.Accordion("Output for Development", open=False): |
|
wav_plot = gr.Plot(PlaceHolder="Wav/Pause Plot", label="wav_pause_plot", visible=True) |
|
|
|
predict_mos = gr.Textbox( |
|
placeholder="Predicted MOS", |
|
label="Predicted MOS", |
|
visible=True, |
|
) |
|
|
|
hyp = gr.Textbox(placeholder="Hypothesis", label="Hypothesis", visible=True) |
|
|
|
wer = gr.Textbox(placeholder="Word Error Rate", label="WER", visible=True) |
|
|
|
predict_pho = gr.Textbox( |
|
placeholder="Predicted Phonemes", |
|
label="Predicted Phonemes", |
|
visible=True, |
|
) |
|
|
|
ppm = gr.Textbox( |
|
placeholder="Phonemes per minutes", |
|
label="PPM", |
|
visible=True, |
|
) |
|
outputs = [ |
|
wav_plot, |
|
predict_mos, |
|
hyp, |
|
wer, |
|
predict_pho, |
|
ppm, |
|
msg, |
|
] |
|
|
|
|
|
b.click(fn=calc_mos, inputs=inputs, outputs=outputs, api_name="Submit") |
|
|
|
|
|
callback.setup( |
|
components=[ |
|
eval_audio, |
|
reference_id, |
|
reference_textbox, |
|
reference_PPM, |
|
predict_mos, |
|
hyp, |
|
wer, |
|
ppm, |
|
msg], |
|
flagging_dir="./exp/%s" % config["exp_id"], |
|
) |
|
|
|
with gr.Row(): |
|
b2 = gr.Button("2. Save the Recording", variant="primary", elem_id="save") |
|
js_confirmed_saving = "(x) => confirm('Recording Saved!')" |
|
|
|
b2.click( |
|
lambda *args: callback.flag(args), |
|
inputs=[ |
|
eval_audio, |
|
reference_id, |
|
reference_textbox, |
|
reference_PPM, |
|
predict_mos, |
|
hyp, |
|
wer, |
|
ppm, |
|
msg, |
|
], |
|
outputs=None, |
|
preprocess=False, |
|
api_name="flagging", |
|
) |
|
with gr.Row(): |
|
b3 = gr.ClearButton( |
|
[ |
|
ref_audio, |
|
eval_audio, |
|
reference_id, |
|
reference_textbox, |
|
reference_PPM, |
|
predict_mos, |
|
hyp, |
|
wer, |
|
ppm, |
|
msg, |
|
], |
|
value="3.Clear All", |
|
elem_id="clear", |
|
) |
|
|
|
demo.launch(share=True) |