import gradio as gr
from datasets import Audio, load_dataset
from jiwer import wer, cer
from transformers import pipeline

from arabic_normalizer import ArabicTextNormalizer

# Load the Arabic split of Common Voice 11.0 and keep only the columns that are used
common_voice = load_dataset("mozilla-foundation/common_voice_11_0", trust_remote_code = True, name = "ar",
                            split = "train")
common_voice = common_voice.select_columns(["audio", "sentence"])
# Resample audio to 16 kHz, the sampling rate Whisper expects (done once at load time)
common_voice = common_voice.cast_column("audio", Audio(sampling_rate = 16000))
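# Note: the full "ar" train split is large to download; a slice such as split = "train[:500]"
# (an arbitrary example size) works with the rest of this script and keeps startup light.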

generate_kwargs = {
    "language": "arabic",
    "task": "transcribe"
}
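# Forcing the language and task keeps Whisper from auto-detecting the language or translating
# to English, so all three models are evaluated on the same transcription task.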
# Initialize the three ASR pipelines: Whisper large v3, large v3 turbo, and an Arabic fine-tune of the turbo model
asr_whisper_large = pipeline("automatic-speech-recognition", model = "openai/whisper-large-v3",
                             generate_kwargs = generate_kwargs)
asr_whisper_large_turbo = pipeline("automatic-speech-recognition", model = "openai/whisper-large-v3-turbo",
                                   generate_kwargs = generate_kwargs)
asr_whisper_large_turbo_mboushaba = pipeline("automatic-speech-recognition",
                                             model = "mboushaba/whisper-large-v3-turbo-arabic",
                                             generate_kwargs = generate_kwargs)
normalizer = ArabicTextNormalizer()
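# ArabicTextNormalizer is a local helper (see arabic_normalizer.py); it is assumed to apply
# standard Arabic normalization (e.g. stripping diacritics and punctuation, unifying letter
# variants) so that WER/CER reflect recognition errors rather than orthographic differences.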


def generate_audio():
    """Pick a random sample from the dataset and transcribe it with the three ASR pipelines."""
    # shuffle() returns a reshuffled view of the dataset; taking index 0 gives a random example
    example = common_voice.shuffle()[0]
    audio = example["audio"]

    # Ground truth transcription (for WER/CER calculations)
    reference_text = normalizer(example["sentence"])

    # Prepare audio data for ASR
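    # The pipeline accepts dict inputs with a "raw" (or "array") key plus "sampling_rate";
    # a separate dict is built for each call because the pipeline consumes (pops) these keys
    # from its input, so a single dict could not be reused across the three models.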
    audio_data = {
        "array": audio["array"],
        "sampling_rate": audio["sampling_rate"]
    }

    audio_data_turbo = {
        "raw": audio["array"],
        "sampling_rate": audio["sampling_rate"]
    }

    audio_data_turbo_mboushaba = {
        "raw": audio["array"],
        "sampling_rate": audio["sampling_rate"]
    }

    # Perform automatic speech recognition (ASR) directly on the resampled audio array
    asr_output = asr_whisper_large(audio_data)
    asr_output_turbo = asr_whisper_large_turbo(audio_data_turbo)
    asr_output_turbo_mboushaba = asr_whisper_large_turbo_mboushaba(audio_data_turbo_mboushaba)

    # Extract the transcription from the ASR model output
    predicted_text = normalizer(asr_output["text"])
    predicted_text_turbo = normalizer(asr_output_turbo["text"])
    predicted_text_turbo_mboushaba = normalizer(asr_output_turbo_mboushaba["text"])

    # Compute WER and CER for each model against the normalized reference
    wer_score = wer(reference_text, predicted_text)
    cer_score = cer(reference_text, predicted_text)

    wer_score_turbo = wer(reference_text, predicted_text_turbo)
    cer_score_turbo = cer(reference_text, predicted_text_turbo)

    wer_score_turbo_mboushaba = wer(reference_text, predicted_text_turbo_mboushaba)
    cer_score_turbo_mboushaba = cer(reference_text, predicted_text_turbo_mboushaba)
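    # WER = (substitutions + deletions + insertions) / number of reference words; CER is the
    # same edit-distance ratio computed over characters. Lower is better for both.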

    # Label for the audio player: normalized reference sentence and its sampling rate
    sentence_info = " - ".join([reference_text, str(audio["sampling_rate"])])

    return {
        "audio": (
            audio["sampling_rate"],
            audio["array"]
        ),
        "sentence_info": sentence_info,
        "predicted_text": predicted_text,
        "wer_score": wer_score,
        "cer_score": cer_score,
        "predicted_text_turbo": predicted_text_turbo,
        "wer_score_turbo": wer_score_turbo,
        "cer_score_turbo": cer_score_turbo,
        "predicted_text_turbo_mboushaba": predicted_text_turbo_mboushaba,
        "wer_score_turbo_mboushaba": wer_score_turbo_mboushaba,
        "cer_score_turbo_mboushaba": cer_score_turbo_mboushaba
    }


# Note: this helper is not wired to any event and is currently unused.
def update_ui():
    res = []
    for i in range(4):
        res.append(gr.Textbox(label = f"Label {i}"))
    return res


with gr.Blocks() as demo:
    gr.HTML("""
        <h1>Whisper Arabic: ASR Comparison (large, large turbo, and a fine-tuned turbo)</h1>""")
    gr.Markdown("""
        This demo compares the transcriptions, WER, and CER of three ASR models (Whisper large v3,
        Whisper large v3 turbo, and mboushaba/whisper-large-v3-turbo-arabic) on random samples from
        the Arabic split of mozilla-foundation/common_voice_11_0.
    """)
    num_samples_input = gr.Slider(minimum = 1, maximum = 10, step = 1, value = 4, label = "Number of audio samples")
    generate_button = gr.Button("Generate Samples")


    @gr.render(inputs = num_samples_input, triggers = [generate_button.click])
    def render(num_samples):
        # Re-rendered on each button click: one audio player and three result columns per sample
        with gr.Column():
            for _ in range(num_samples):
                # Generate audio and associated data
                data = generate_audio()

                # Create Gradio components to display the audio, transcriptions, and metrics
                gr.Audio(data["audio"], label = data["sentence_info"])
                with gr.Row():
                    with gr.Column():
                        gr.Textbox(value = data["predicted_text"], label = "Whisper large output")
                        gr.Textbox(value = f"WER: {data['wer_score']:.2f}", label = "Word Error Rate")
                        gr.Textbox(value = f"CER: {data['cer_score']:.2f}", label = "Character Error Rate")
                    with gr.Column():
                        gr.Textbox(value = data["predicted_text_turbo"], label = "Whisper large turbo output")
                        gr.Textbox(value = f"WER: {data['wer_score_turbo']:.2f}", label = "Word Error Rate - TURBO")
                        gr.Textbox(value = f"CER: {data['cer_score_turbo']:.2f}", label = "Character Error Rate - TURBO")
                    with gr.Column():
                        gr.Textbox(value = data["predicted_text_turbo_mboushaba"],
                                   label = "Whisper large turbo mboushaba output")
                        gr.Textbox(value = f"WER: {data['wer_score_turbo_mboushaba']:.2f}",
                                   label = "Word Error Rate - mboushaba TURBO")
                        gr.Textbox(value = f"CER: {data['cer_score_turbo_mboushaba']:.2f}",
                                   label = "Character Error Rate - mboushaba TURBO")

if __name__ == "__main__":
    demo.launch(show_error = True)