mboushaba commited on
Commit
022d425
ยท
verified ยท
1 Parent(s): b99be23

Upload 2 files

Browse files
Files changed (2) hide show
  1. app.py +121 -0
  2. arabic_normalizer.py +87 -0
app.py ADDED
@@ -0,0 +1,121 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ import gradio as gr
4
+ from datasets import Audio
5
+ from datasets import load_dataset
6
+ from jiwer import wer, cer
7
+ from transformers import pipeline
8
+
9
+ from arabic_normalizer import ArabicTextNormalizer
10
+
11
+ # Load dataset
12
+ common_voice = load_dataset("mozilla-foundation/common_voice_11_0", name = "ar", split = "train")
13
+ # select column that will be used
14
+ common_voice = common_voice.select_columns(["audio", "sentence"])
15
+
16
+ generate_kwargs = {
17
+ "language": "arabic",
18
+ "task": "transcribe"
19
+ }
20
+ # Initialize ASR pipeline
21
+ asr_whisper_large = pipeline("automatic-speech-recognition", model = "openai/whisper-large-v3", device = 0,
22
+ generate_kwargs = generate_kwargs)
23
+ asr_whisper_large_turbo = pipeline("automatic-speech-recognition", model = "openai/whisper-large-v3-turbo",
24
+ device = 0, generate_kwargs = generate_kwargs)
25
+ normalizer = ArabicTextNormalizer()
26
+
27
+
28
+ def generate_audio(index = None):
29
+ """Select an audio sample, resample if needed, and transcribe using ASR."""
30
+ # inspect dataset
31
+ # print(common_voice)
32
+ # print(common_voice.features)
33
+
34
+ # resample audio using dataset function
35
+ global common_voice
36
+ common_voice = common_voice.cast_column("audio", Audio(sampling_rate = 16000))
37
+ # print(common_voice.features)
38
+
39
+ # Randomly shuffle the dataset and pick the first sample
40
+ example = common_voice.shuffle()[0]
41
+ audio = example["audio"]
42
+
43
+ # Ground truth transcription (for WER/CER calculations)
44
+ reference_text = normalizer(example["sentence"])
45
+
46
+ # Prepare audio data for ASR
47
+ audio_data = {
48
+ "array": audio["array"],
49
+ "sampling_rate": audio["sampling_rate"]
50
+ }
51
+
52
+ audio_data_turbo = {
53
+ "raw": audio["array"],
54
+ "sampling_rate": audio["sampling_rate"]
55
+ }
56
+
57
+ # Perform automatic speech recognition (ASR) directly on the resampled audio array
58
+ asr_output = asr_whisper_large(audio_data)
59
+
60
+ asr_output_turbo = asr_whisper_large_turbo(audio_data_turbo)
61
+
62
+ # Extract the transcription from the ASR model output
63
+ predicted_text = normalizer(asr_output["text"])
64
+ predicted_text_turbo = normalizer(asr_output_turbo["text"])
65
+
66
+ # Compute WER, Word Accuracy, and CER
67
+ wer_score = wer(reference_text, predicted_text)
68
+ cer_score = cer(reference_text, predicted_text)
69
+
70
+ wer_score_turbo = wer(reference_text, predicted_text_turbo)
71
+ cer_score_turbo = cer(reference_text, predicted_text_turbo)
72
+
73
+ # Prepare display data: original sentence, sampling rate, ASR transcription, and metrics
74
+ sentence_info = "-".join([reference_text, str(audio["sampling_rate"])])
75
+
76
+ return ((
77
+ audio["sampling_rate"],
78
+ audio["array"]
79
+ ), sentence_info, predicted_text, wer_score, cer_score, predicted_text_turbo,
80
+ wer_score_turbo, cer_score_turbo)
81
+
82
+ def update_ui():
83
+ res = []
84
+ for i in range(4):
85
+ res.append(gr.Textbox(label=f"Label {i}"))
86
+ return res
87
+
88
+ with (gr.Blocks() as demo):
89
+ gr.HTML("""
90
+ <h1>Whisper Arabic: ASR Comparison (large and large turbo)</h1>""")
91
+ gr.Markdown("""
92
+ This is a demo to compare the outputs, WER & CER of two ASR models (Whisper large and large turbo) using
93
+ arabic dataset from mozilla-foundation/common_voice_11_0
94
+ """)
95
+ num_samples_input = gr.Slider(minimum=1, maximum=10, step=1, value=4, label="Number of audio samples")
96
+ generate_button = gr.Button("Generate Samples")
97
+
98
+
99
+ @gr.render(inputs=num_samples_input, triggers=[generate_button.click])
100
+ def render(num_samples):
101
+ with gr.Column():
102
+ for i in range(num_samples):
103
+ # Generate audio and associated data
104
+ _audio, label, asr_text, wer_score, cer_score, asr_text_turbo, wer_score_turbo, cer_score_turbo =generate_audio()
105
+
106
+ # Create Gradio components to display the audio, transcription, and metrics
107
+ gr.Audio(_audio, label = label)
108
+ with gr.Row():
109
+ with gr.Column():
110
+ gr.Textbox(value = asr_text, label = "Whisper large output"),
111
+ gr.Textbox(value = f"WER: {wer_score:.2f}", label = "Word Error Rate"),
112
+ gr.Textbox(value = f"CER: {cer_score:.2f}", label = "Character Error Rate"),
113
+ with gr.Column():
114
+ gr.Textbox(value = asr_text_turbo, label = "Whisper large turbo output"),
115
+ gr.Textbox(value = f"WER: {wer_score_turbo:.2f}", label = "Word Error Rate - "
116
+ "TURBO "),
117
+ gr.Textbox(value = f"CER: {cer_score_turbo:.2f}", label = "Character Error "
118
+ "Rate - TURBO")
119
+
120
+ if __name__ == '__main__':
121
+ demo.launch(show_error = True)
arabic_normalizer.py ADDED
@@ -0,0 +1,87 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # author : Mohammed BOUSHABA
2
+ # date : 02/10/2024
3
+
4
+ import re
5
+ import unicodedata
6
+
7
+ class ArabicTextNormalizer:
8
+ def __init__(self):
9
+ self.arabic_numerals = {
10
+ 'ู ': '0', 'ูก': '1', 'ูข': '2', 'ูฃ': '3', 'ูค': '4',
11
+ 'ูฅ': '5', 'ูฆ': '6', 'ูง': '7', 'ูจ': '8', 'ูฉ': '9'
12
+ }
13
+
14
+ self.arabic_punctuation = {
15
+ 'ุŒ': ',', 'ุ›': ';', 'ุŸ': '?', 'ยซ': '"', 'ยป': '"'
16
+ }
17
+
18
+ self.removable_diacritics = re.compile(r'[\u064B-\u065F\u0670]')
19
+
20
+ self.replacers = {
21
+ # Common Arabic contractions and their expansions
22
+ r'\bุฅู† ุดุงุก ุงู„ู„ู‡\b': 'ุงู† ุดุงุก ุงู„ู„ู‡',
23
+ r'\bุจุฅุฐู† ุงู„ู„ู‡\b': 'ุจุงุฐู† ุงู„ู„ู‡',
24
+ r'\bุงู„ุณู„ุงู… ุนู„ูŠูƒู…\b': 'ุงู„ุณู„ุงู… ุนู„ูŠูƒู…',
25
+ # Add more Arabic-specific contractions here
26
+ }
27
+
28
+ def remove_diacritics(self, text):
29
+ return self.removable_diacritics.sub('', text)
30
+
31
+ def normalize_numerals(self, text):
32
+ for arabic, western in self.arabic_numerals.items():
33
+ text = text.replace(arabic, western)
34
+ return text
35
+
36
+ def normalize_punctuation(self, text):
37
+ for arabic, western in self.arabic_punctuation.items():
38
+ text = text.replace(arabic, western)
39
+ return text
40
+
41
+ def remove_tatweel(self, text):
42
+ return text.replace('\u0640', '') # Remove tatweel (kashida)
43
+
44
+ def remove_dots(self, text):
45
+ return text.replace('.', '')
46
+
47
+ def remove_non_arabic(self, text):
48
+ return ''.join(c for c in text if '\u0600' <= c <= '\u06FF' or c.isascii())
49
+
50
+ def __call__(self, text):
51
+ # Convert to NFC form for consistent Unicode representation
52
+ text = unicodedata.normalize('NFC', text)
53
+
54
+ # Apply replacements for common contractions
55
+ for pattern, replacement in self.replacers.items():
56
+ text = re.sub(pattern, replacement, text)
57
+
58
+ # Normalize Arabic-specific elements
59
+ text = self.remove_diacritics(text)
60
+ text = self.normalize_numerals(text)
61
+ #text = self.normalize_punctuation(text)
62
+ text = self.remove_tatweel(text)
63
+ text = self.remove_dots(text)
64
+
65
+ # Remove non-Arabic characters (except ASCII)
66
+ text = self.remove_non_arabic(text)
67
+
68
+ # Remove extra whitespace
69
+ text = re.sub(r'\s+', ' ', text).strip()
70
+
71
+ return text
72
+
73
+ # Example usage
74
+ if __name__ == "__main__":
75
+ normalizer = ArabicTextNormalizer()
76
+
77
+ test_texts = [
78
+ "ุงู„ุณูŽู‘ู„ูŽุงู…ู ุนูŽู„ูŽูŠู’ูƒูู…ู’ ูˆูŽุฑูŽุญู’ู…ูŽุฉู ุงู„ู„ู‡ู ูˆูŽุจูŽุฑูŽูƒูŽุงุชูู‡ู",
79
+ "ุฅู† ุดู€ู€ู€ู€ุงุก ุงู„ู„ู‡ ุณู†ู„ุชู‚ูŠ ููŠ ุงู„ุณุงุนุฉ ูฃ:ูฃู  ู…ุณุงุกู‹",
80
+ "ูƒูŽุงู†ูŽ ู‡ูู†ูŽุงูƒูŽ ูกูขูฃูคูฅ ุดูŽุฎู’ุตู‹ุง ูููŠ ุงู„ู’ู…ูŽู„ู’ุนูŽุจู",
81
+ ]
82
+
83
+ for text in test_texts:
84
+ normalized = normalizer(text)
85
+ print(f"Original: {text}")
86
+ print(f"Normalized: {normalized}")
87
+ print()