import os
import numpy as np
import unicodedata
import diff_match_patch as dmp_module
from enum import Enum
import gradio as gr
from datasets import load_dataset
import pandas as pd
from jiwer import process_words, wer_default
from nltk import ngrams
class Action(Enum):
    """Integer op codes that tag each ``(action, text)`` segment of a diff.

    ``style_text`` compares the first element of every diff tuple against
    these values.
    """

    # segment marked as an insertion
    INSERTION = 1
    # segment marked as a deletion
    DELETION = -1
    # segment shared by both texts
    EQUAL = 0
def compare_string(text1: str, text2: str) -> list:
    """Diff two strings after NFKC Unicode normalisation.

    Both inputs are normalised with ``unicodedata.normalize("NFKC", ...)``
    and then diffed with diff_match_patch's ``diff_main``. The result is
    tidied in place with ``diff_cleanupSemantic``.

    Args:
        text1: The first (reference) string.
        text2: The second (compared) string.

    Returns:
        The list of diff segments produced by diff_match_patch.
    """
    left, right = (unicodedata.normalize("NFKC", s) for s in (text1, text2))
    differ = dmp_module.diff_match_patch()
    segments = differ.diff_main(left, right)
    # diff_cleanupSemantic mutates ``segments`` in place; its return is unused.
    differ.diff_cleanupSemantic(segments)
    return segments
def style_text(diff):
    """Render a diff (sequence of ``(action, text)`` pairs) as one string.

    Args:
        diff: Iterable of ``(action, text)`` pairs, e.g. from
            ``compare_string``. ``action`` must equal one of the ``Action``
            values.

    Returns:
        The concatenated text of every segment. Each ``](`` is escaped to
        ``]\\(`` and each ``~`` to ``\\~``. This is presumably so the result is
        not interpreted as Markdown link / strikethrough syntax when rendered;
        confirm against the display component.

    Raises:
        Exception: if a segment's action is not an ``Action`` value.
    """
    known_actions = {member.value for member in Action}
    pieces = []
    for action, text in diff:
        if action not in known_actions:
            raise Exception("Not Implemented")
        # NOTE(review): INSERTION, DELETION and EQUAL segments are all emitted
        # verbatim, so the rendered output does not distinguish inserted from
        # deleted text. Confirm whether per-action highlighting was intended.
        pieces.append(f"{text}")
    full_text = "".join(pieces)
    # Write the backslash explicitly. The original literals "]\(" and "\~"
    # are invalid escape sequences (SyntaxWarning on Python 3.12+).
    return full_text.replace("](", "]\\(").replace("~", "\\~")
# Long-form TED-LIUM validation split. The per-sample views index this
# dataset with the same position as the prediction lists below, so its row
# order is assumed to match the CSVs (TODO confirm).
dataset = load_dataset(
    "distil-whisper/tedlium-long-form", split="validation", num_proc=os.cpu_count()
)

# Normalised references and Whisper large-v2 predictions.
csv_v2 = pd.read_csv("assets/large-v2.csv")
norm_target = csv_v2["Norm Target"].tolist()
norm_pred_v2 = csv_v2["Norm Pred"].tolist()

# Distil-Whisper (large-32-2) predictions. NOTE: this rebinds ``csv_v2`` to
# the large-32-2 table; only its "Norm Pred" column is used.
csv_v2 = pd.read_csv("assets/large-32-2.csv")
norm_pred_32_2 = csv_v2["Norm Pred"].tolist()

# Scale factor used to convert float audio to 16-bit integer samples.
target_dtype = np.int16
max_range = np.iinfo(target_dtype).max
def get_statistics(model="large-v2", round_dp=2, ngram_degree=5):
    """Compute corpus-level error statistics for one model's predictions.

    Args:
        model: Which prediction set to score: ``"large-v2"`` or
            ``"large-32-2"``.
        round_dp: Number of decimal places to round the percentages to.
        ngram_degree: Length of the word n-grams used to count repetitions.

    Returns:
        ``(wer_percentage, ier_percentage, repeated_ngrams)``, where:
        - ``wer_percentage`` is jiwer's WER over all samples, in percent.
        - ``ier_percentage`` is the insertion count as a percentage of the
          total reference length.
        - ``repeated_ngrams`` is the number of n-grams in the concatenated
          predictions that duplicate another n-gram.

    Raises:
        ValueError: if ``model`` is not one of the supported names.
    """
    text1 = norm_target
    if model == "large-v2":
        text2 = norm_pred_v2
    elif model == "large-32-2":
        text2 = norm_pred_32_2
    else:
        raise ValueError(
            f"Got unknown model {model}, should be one of `'large-v2'` or `'large-32-2'`."
        )

    wer_output = process_words(text1, text2, wer_default, wer_default)
    wer_percentage = round(100 * wer_output.wer, round_dp)
    # Total reference length = sum of len() over jiwer's per-sample
    # references (presumably word lists after ``wer_default``; TODO confirm).
    total_reference_length = sum(len(ref) for ref in wer_output.references)
    ier_percentage = round(
        100 * wer_output.insertions / total_reference_length, round_dp
    )

    # All predictions are joined into one word stream, so an n-gram may
    # straddle the boundary between two consecutive samples.
    all_ngrams = list(ngrams(" ".join(text2).split(), ngram_degree))
    # A set gives the distinct-n-gram count in O(n). This replaces the original
    # O(n^2) list-membership scan with the same result.
    repeated_ngrams = len(all_ngrams) - len(set(all_ngrams))
    return wer_percentage, ier_percentage, repeated_ngrams
def get_overall_table():
    """Return the corpus-level comparison table, one row per model.

    Each row is ``[display_name, wer_percentage, ier_percentage,
    repeated_ngrams]``, as returned by ``get_statistics``. Rows are ordered
    Whisper (large-v2) first, then Distil-Whisper (large-32-2).
    """
    models = (("Whisper", "large-v2"), ("Distil-Whisper", "large-32-2"))
    return [
        [display_name, *get_statistics(model=model_id)]
        for display_name, model_id in models
    ]
def get_visualisation(idx, model="large-v2", round_dp=2, ngram_degree=5):
    """Build the per-sample view: audio, error statistics and a text diff.

    Args:
        idx: 1-based index of the sample in ``dataset`` and in the normalised
            reference/prediction lists.
        model: Which prediction set to compare against the reference:
            ``"large-v2"`` or ``"large-32-2"``.
        round_dp: Number of decimal places to round the percentages to.
        ngram_degree: Length of the word n-grams used to count repetitions.

    Returns:
        A 5-tuple ``((sampling_rate, array), wer_percentage, ier_percentage,
        repeated_ngrams, full_text)``. The elements are:
        - the sample audio scaled to ``target_dtype``;
        - the sample WER in percent;
        - insertions as a percentage of the reference length;
        - the number of repeated n-grams in the prediction;
        - the rendered reference-vs-prediction diff.

    Raises:
        ValueError: if ``idx`` is outside ``1..len(norm_target)``, or if
            ``model`` is not a supported name.
    """
    # Validate before the 1-based -> 0-based shift. Otherwise idx=0 would
    # become -1 and silently select the last sample.
    if not 1 <= idx <= len(norm_target):
        raise ValueError(
            f"Got sample index {idx}, should be between 1 and {len(norm_target)}."
        )
    idx -= 1

    audio = dataset[idx]["audio"]
    # Assumes float audio nominally in [-1, 1] (TODO confirm for this dataset).
    # Clip first so that out-of-range samples saturate instead of hitting
    # numpy's undefined out-of-range float -> int cast. Cast to ``target_dtype``
    # so the dtype matches the one ``max_range`` was derived from.
    array = (np.clip(audio["array"], -1.0, 1.0) * max_range).astype(target_dtype)
    sampling_rate = audio["sampling_rate"]

    text1 = norm_target[idx]
    if model == "large-v2":
        text2 = norm_pred_v2[idx]
    elif model == "large-32-2":
        text2 = norm_pred_32_2[idx]
    else:
        raise ValueError(
            f"Got unknown model {model}, should be one of `'large-v2'` or `'large-32-2'`."
        )

    wer_output = process_words(text1, text2, wer_default, wer_default)
    wer_percentage = round(100 * wer_output.wer, round_dp)
    # A single reference was scored, so normalise by its length alone.
    ier_percentage = round(
        100 * wer_output.insertions / len(wer_output.references[0]), round_dp
    )

    all_ngrams = list(ngrams(text2.split(), ngram_degree))
    # Set-based distinct count: O(n) instead of an O(n^2) list-membership scan.
    repeated_ngrams = len(all_ngrams) - len(set(all_ngrams))

    diff = compare_string(text1, text2)
    full_text = style_text(diff)
    return (
        (sampling_rate, array),
        wer_percentage,
        ier_percentage,
        repeated_ngrams,
        full_text,
    )
def get_side_by_side_visualisation(idx):
    """Compare both models on one sample.

    Args:
        idx: 1-based sample index, forwarded to ``get_visualisation``.

    Returns:
        ``(audio, table, whisper_diff, distil_diff)``, where:
        - ``audio`` is the ``(sampling_rate, array)`` pair from the large-v2
          call.
        - ``table`` holds one row per model:
          ``[display_name, wer_percentage, ier_percentage, repeated_ngrams]``.
        - The last two items are the rendered diffs for large-v2 and
          large-32-2 respectively.
    """
    whisper_result = get_visualisation(idx, model="large-v2")
    distil_result = get_visualisation(idx, model="large-32-2")

    audio, *whisper_stats, whisper_diff = whisper_result
    _, *distil_stats, distil_diff = distil_result

    stats_table = [
        ["Whisper", *whisper_stats],
        ["Distil-Whisper", *distil_stats],
    ]
    return audio, stats_table, whisper_diff, distil_diff
if __name__ == "__main__":
with gr.Blocks() as demo:
gr.HTML(
"""