Spaces:
Runtime error
Runtime error
# Copyright (c) 2023 Amphion. | |
# | |
# This source code is licensed under the MIT license found in the | |
# LICENSE file in the root directory of this source tree. | |
import whisper | |
from torchmetrics import WordErrorRate | |
def extract_wer( | |
content_gt=None, | |
audio_ref=None, | |
audio_deg=None, | |
fs=None, | |
language="chinese", | |
remove_space=True, | |
remove_punctuation=True, | |
mode="gt_audio", | |
): | |
"""Compute Word Error Rate (WER) between the predicted and the ground truth audio. | |
content_gt: the ground truth content. | |
audio_ref: path to the ground truth audio. | |
audio_deg: path to the predicted audio. | |
mode: "gt_content" computes the WER between the predicted content obtained from the whisper model and the ground truth content. | |
both content_gt and audio_deg are needed. | |
"gt_audio" computes the WER between the extracted ground truth and predicted contents obtained from the whisper model. | |
both audio_ref and audio_deg are needed. | |
""" | |
# Get ground truth content | |
if mode == "gt_content": | |
assert content_gt != None | |
if language == "chinese": | |
prompt = "以下是普通话的句子" | |
model = whisper.load_model("large").cuda() | |
result_deg = model.transcribe( | |
audio_deg, language="zh", verbose=True, initial_prompt=prompt | |
) | |
elif language == "english": | |
model = whisper.load_model("large").cuda() | |
result_deg = model.transcribe(audio_deg, language="en", verbose=True) | |
elif mode == "gt_audio": | |
assert audio_ref != None | |
if language == "chinese": | |
prompt = "以下是普通话的句子" | |
model = whisper.load_model("large").cuda() | |
result_ref = model.transcribe( | |
audio_ref, language="zh", verbose=True, initial_prompt=prompt | |
) | |
result_deg = model.transcribe( | |
audio_deg, language="zh", verbose=True, initial_prompt=prompt | |
) | |
elif language == "english": | |
model = whisper.load_model("large").cuda() | |
result_ref = model.transcribe(audio_deg, language="en", verbose=True) | |
result_deg = model.transcribe(audio_deg, language="en", verbose=True) | |
content_gt = result_ref["text"] | |
if remove_space: | |
content_gt = content_gt.replace(" ", "") | |
if remove_punctuation: | |
content_gt = content_gt.replace(".", "") | |
content_gt = content_gt.replace("'", "") | |
content_gt = content_gt.replace("-", "") | |
content_gt = content_gt.replace(",", "") | |
content_gt = content_gt.replace("!", "") | |
content_gt = content_gt.lower() | |
# Get predicted content | |
content_pred = result_deg["text"] | |
if remove_space: | |
content_pred = content_pred.replace(" ", "") | |
if remove_punctuation: | |
content_pred = content_pred.replace(".", "") | |
content_pred = content_pred.replace("'", "") | |
content_pred = content_pred.replace("-", "") | |
content_pred = content_pred.replace(",", "") | |
content_pred = content_pred.replace("!", "") | |
content_pred = content_pred.lower() | |
wer = WordErrorRate() | |
return wer(content_pred, content_gt).numpy().tolist() | |