Upload 2 files
- app.py +121 -0
- arabic_normalizer.py +87 -0
app.py
ADDED
@@ -0,0 +1,121 @@
import gradio as gr
from datasets import Audio
from datasets import load_dataset
from jiwer import wer, cer
from transformers import pipeline

from arabic_normalizer import ArabicTextNormalizer

# Load the Arabic split of Common Voice 11.0 and keep only the columns we need
common_voice = load_dataset("mozilla-foundation/common_voice_11_0", name="ar", split="train")
common_voice = common_voice.select_columns(["audio", "sentence"])

generate_kwargs = {
    "language": "arabic",
    "task": "transcribe"
}

# Initialize the two ASR pipelines being compared
asr_whisper_large = pipeline("automatic-speech-recognition", model="openai/whisper-large-v3",
                             device=0, generate_kwargs=generate_kwargs)
asr_whisper_large_turbo = pipeline("automatic-speech-recognition", model="openai/whisper-large-v3-turbo",
                                   device=0, generate_kwargs=generate_kwargs)
normalizer = ArabicTextNormalizer()


def generate_audio(index=None):
    """Select an audio sample, resample if needed, and transcribe it with both models."""
    # Resample the audio column to 16 kHz, the rate Whisper expects
    global common_voice
    common_voice = common_voice.cast_column("audio", Audio(sampling_rate=16000))

    # Randomly shuffle the dataset and pick the first sample
    example = common_voice.shuffle()[0]
    audio = example["audio"]

    # Ground-truth transcription (normalized for the WER/CER calculations)
    reference_text = normalizer(example["sentence"])

    # Prepare the audio data for the ASR pipelines
    audio_data = {
        "array": audio["array"],
        "sampling_rate": audio["sampling_rate"]
    }
    audio_data_turbo = {
        "raw": audio["array"],
        "sampling_rate": audio["sampling_rate"]
    }

    # Run automatic speech recognition directly on the resampled audio array
    asr_output = asr_whisper_large(audio_data)
    asr_output_turbo = asr_whisper_large_turbo(audio_data_turbo)

    # Extract and normalize the transcriptions from the model outputs
    predicted_text = normalizer(asr_output["text"])
    predicted_text_turbo = normalizer(asr_output_turbo["text"])

    # Compute WER and CER for both models
    wer_score = wer(reference_text, predicted_text)
    cer_score = cer(reference_text, predicted_text)

    wer_score_turbo = wer(reference_text, predicted_text_turbo)
    cer_score_turbo = cer(reference_text, predicted_text_turbo)

    # Display label: reference sentence plus the sampling rate
    sentence_info = "-".join([reference_text, str(audio["sampling_rate"])])

    return ((audio["sampling_rate"], audio["array"]),
            sentence_info, predicted_text, wer_score, cer_score,
            predicted_text_turbo, wer_score_turbo, cer_score_turbo)


def update_ui():
    res = []
    for i in range(4):
        res.append(gr.Textbox(label=f"Label {i}"))
    return res


with gr.Blocks() as demo:
    gr.HTML("""<h1>Whisper Arabic: ASR Comparison (large and large turbo)</h1>""")
    gr.Markdown("""
    A demo comparing the outputs, WER, and CER of two ASR models (Whisper large-v3 and large-v3-turbo)
    on the Arabic subset of mozilla-foundation/common_voice_11_0.
    """)
    num_samples_input = gr.Slider(minimum=1, maximum=10, step=1, value=4, label="Number of audio samples")
    generate_button = gr.Button("Generate Samples")

    @gr.render(inputs=num_samples_input, triggers=[generate_button.click])
    def render(num_samples):
        with gr.Column():
            for i in range(num_samples):
                # Generate an audio sample plus its transcriptions and metrics
                (_audio, label, asr_text, wer_score, cer_score,
                 asr_text_turbo, wer_score_turbo, cer_score_turbo) = generate_audio()

                # Display the audio, transcriptions, and metrics
                gr.Audio(_audio, label=label)
                with gr.Row():
                    with gr.Column():
                        gr.Textbox(value=asr_text, label="Whisper large output")
                        gr.Textbox(value=f"WER: {wer_score:.2f}", label="Word Error Rate")
                        gr.Textbox(value=f"CER: {cer_score:.2f}", label="Character Error Rate")
                    with gr.Column():
                        gr.Textbox(value=asr_text_turbo, label="Whisper large turbo output")
                        gr.Textbox(value=f"WER: {wer_score_turbo:.2f}", label="Word Error Rate - TURBO")
                        gr.Textbox(value=f"CER: {cer_score_turbo:.2f}", label="Character Error Rate - TURBO")


if __name__ == '__main__':
    demo.launch(show_error=True)
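The commit does not include a requirements file. Based only on the imports in app.py, a Space running this demo would presumably need something along these lines (the package list and the gr.render note are assumptions, not part of the diff):

# requirements.txt (sketch, not part of this commit)
gradio            # a release recent enough to provide gr.render
transformers
datasets[audio]
jiwer
torch

Note that mozilla-foundation/common_voice_11_0 is a gated dataset on the Hub, so loading it from a Space will likely also require accepting its terms and providing an access token.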
arabic_normalizer.py
ADDED
@@ -0,0 +1,87 @@
# author : Mohammed BOUSHABA
# date : 02/10/2024

import re
import unicodedata


class ArabicTextNormalizer:
    def __init__(self):
        # Arabic-Indic digits mapped to Western digits
        self.arabic_numerals = {
            '٠': '0', '١': '1', '٢': '2', '٣': '3', '٤': '4',
            '٥': '5', '٦': '6', '٧': '7', '٨': '8', '٩': '9'
        }

        # Arabic punctuation mapped to Western punctuation
        self.arabic_punctuation = {
            '،': ',', '؛': ';', '؟': '?', '«': '"', '»': '"'
        }

        # Harakat, tanween, and the dagger alif
        self.removable_diacritics = re.compile(r'[\u064B-\u065F\u0670]')

        self.replacers = {
            # Common Arabic expressions and their normalized spellings
            r'\bإن شاء الله\b': 'ان شاء الله',
            r'\bبإذن الله\b': 'باذن الله',
            r'\bالسلام عليكم\b': 'السلام عليكم',
            # Add more Arabic-specific replacements here
        }

    def remove_diacritics(self, text):
        return self.removable_diacritics.sub('', text)

    def normalize_numerals(self, text):
        for arabic, western in self.arabic_numerals.items():
            text = text.replace(arabic, western)
        return text

    def normalize_punctuation(self, text):
        for arabic, western in self.arabic_punctuation.items():
            text = text.replace(arabic, western)
        return text

    def remove_tatweel(self, text):
        return text.replace('\u0640', '')  # Remove tatweel (kashida)

    def remove_dots(self, text):
        return text.replace('.', '')

    def remove_non_arabic(self, text):
        return ''.join(c for c in text if '\u0600' <= c <= '\u06FF' or c.isascii())

    def __call__(self, text):
        # Convert to NFC form for consistent Unicode representation
        text = unicodedata.normalize('NFC', text)

        # Apply replacements for common expressions
        for pattern, replacement in self.replacers.items():
            text = re.sub(pattern, replacement, text)

        # Normalize Arabic-specific elements
        text = self.remove_diacritics(text)
        text = self.normalize_numerals(text)
        # text = self.normalize_punctuation(text)
        text = self.remove_tatweel(text)
        text = self.remove_dots(text)

        # Remove non-Arabic characters (except ASCII)
        text = self.remove_non_arabic(text)

        # Collapse extra whitespace
        text = re.sub(r'\s+', ' ', text).strip()

        return text


# Example usage
if __name__ == "__main__":
    normalizer = ArabicTextNormalizer()

    test_texts = [
        "السَّلَامُ عَلَيْكُمْ وَرَحْمَةُ اللهِ وَبَرَكَاتُهُ",
        "إن شاء الله سنلتقي في الساعة ٣:٣٠ مساءً",
        "كَانَ هُنَاكَ ١٢٣٤٥ شَخْصًا فِي الْمَلْعَبِ",
    ]

    for text in test_texts:
        normalized = normalizer(text)
        print(f"Original: {text}")
        print(f"Normalized: {normalized}")
        print()
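app.py passes both the reference sentence and each model's transcription through ArabicTextNormalizer before computing WER and CER. The sketch below, with made-up example strings, illustrates why: a fully diacritized reference and an undiacritized prediction of the same words would otherwise be scored as errors by jiwer.

# Illustrative sketch only; the example strings are hypothetical.
from jiwer import wer

from arabic_normalizer import ArabicTextNormalizer

normalizer = ArabicTextNormalizer()

reference = "السَّلَامُ عَلَيْكُمْ وَرَحْمَةُ اللهِ"   # diacritized ground truth
prediction = "السلام عليكم ورحمة الله"                 # same words, no diacritics

# Raw strings: every word differs at the surface level, so WER is non-zero.
print(wer(reference, prediction))

# Normalized strings: diacritics are stripped, the texts match, and WER drops to 0.0.
print(wer(normalizer(reference), normalizer(prediction)))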