---
language: ar
datasets:
  - common_voice
metrics:
  - wer
tags:
  - audio
  - automatic-speech-recognition
  - speech
  - xlsr-fine-tuning-week
license: apache-2.0
model-index:
  - name: Sinai Voice Arabic Speech Recognition Model
    results:
      - task:
          name: Speech Recognition
          type: automatic-speech-recognition
        dataset:
          name: Common Voice ar
          type: common_voice
          args: ar
        metrics:
          - name: Test WER
            type: wer
            value: 23.8
---

# Sinai Voice Arabic Speech Recognition Model

Sinai Voice: a model for recognizing Standard Arabic (Fusha) speech and converting it to text.

Fine-tuned facebook/wav2vec2-large-xlsr-53 on Arabic using the Common Voice dataset.

Most of the evaluation code in this documentation is inspired by elgeish/wav2vec2-large-xlsr-53-arabic.

Please install:

- PyTorch
- `pip3 install jiwer lang_trans torchaudio datasets transformers pandas tqdm`

## Benchmark

We evaluated the model against other Arabic speech-to-text (STT) wav2vec 2.0 models on the Arabic test split of Common Voice.

**WER (Word Error Rate):** the lower the score, the better the model.
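WER counts the word substitutions (S), deletions (D) and insertions (I) needed to turn a predicted transcript into its reference, divided by the number of reference words N: WER = (S + D + I) / N. Because insertions are counted, WER can exceed 1. The sketch below computes it for a single sentence pair with a plain word-level edit distance; `word_error_rate` is a hypothetical helper for illustration only, while the benchmark itself relies on jiwer.

```python
def word_error_rate(reference: str, hypothesis: str) -> float:
    """Word-level Levenshtein distance divided by the number of reference words."""
    ref = reference.split()
    hyp = hypothesis.split()
    # dist[i][j] = minimum number of edits turning ref[:i] into hyp[:j]
    dist = [[0] * (len(hyp) + 1) for _ in range(len(ref) + 1)]
    for i in range(len(ref) + 1):
        dist[i][0] = i  # delete all i reference words
    for j in range(len(hyp) + 1):
        dist[0][j] = j  # insert all j hypothesis words
    for i in range(1, len(ref) + 1):
        for j in range(1, len(hyp) + 1):
            cost = 0 if ref[i - 1] == hyp[j - 1] else 1
            dist[i][j] = min(
                dist[i - 1][j] + 1,         # deletion
                dist[i][j - 1] + 1,         # insertion
                dist[i - 1][j - 1] + cost,  # substitution (or match when cost is 0)
            )
    return dist[len(ref)][len(hyp)] / len(ref)


# One pair from the usage output (punctuation already stripped):
# 3 of the 5 reference words are substituted, so WER = 3 / 5
print(word_error_rate("سيسعدني مساعدتك أي وقت تحب", "سيسعدن مساعثتك أي وقد تحب"))
```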

| # | Model | Using transliteration | WER | Training datasets |
|---|-------|-----------------------|-----|-------------------|
| 1 | bakrianoo/sinai-voice-ar-stt | True | 0.238001 | Common Voice 6 |
| 2 | elgeish/wav2vec2-large-xlsr-53-arabic | True | 0.266527 | Common Voice 6 + Arabic Speech Corpus |
| 3 | othrif/wav2vec2-large-xlsr-arabic | True | 0.298122 | Common Voice 6 |
| 4 | bakrianoo/sinai-voice-ar-stt | False | 0.448987 | Common Voice 6 |
| 5 | othrif/wav2vec2-large-xlsr-arabic | False | 0.464004 | Common Voice 6 |
| 6 | anas/wav2vec2-large-xlsr-arabic | True | 0.506191 | Common Voice 4 |
| 7 | anas/wav2vec2-large-xlsr-arabic | False | 0.622288 | Common Voice 4 |
We used the following code to generate the results above:

```python
import jiwer
import pandas as pd
import torch
import torchaudio
from datasets import load_dataset
from lang_trans.arabic import buckwalter
from tqdm.auto import tqdm
from transformers import set_seed, Wav2Vec2ForCTC, Wav2Vec2Processor

device = "cuda" if torch.cuda.is_available() else "cpu"

# load test dataset
set_seed(42)
test_split = load_dataset("common_voice", "ar", split="test")

# init resamplers to the 16 kHz input rate the models expect
resamplers = {  # all three sampling rates exist in test split
    48000: torchaudio.transforms.Resample(48000, 16000),
    44100: torchaudio.transforms.Resample(44100, 16000),
    32000: torchaudio.transforms.Resample(32000, 16000),
}

# WER text normalization (written for the jiwer 2.x API)
transformation = jiwer.Compose([
    # normalize some diacritics, remove punctuation, and replace Persian letters with Arabic ones
    jiwer.SubstituteRegexes({
        r'[auiFNKo\`~_،؟»\?;:\-,\.؛«!"]': "", "\u06D6": "",
        r"[\|\{]": "A", "p": "h", "ک": "k", "ی": "y"}),
    # default transformation below
    jiwer.RemoveMultipleSpaces(),
    jiwer.Strip(),
    jiwer.SentencesToListOfWords(),
    jiwer.RemoveEmptyStrings(),
])

def prepare_example(example):
    speech, sampling_rate = torchaudio.load(example["path"])
    if sampling_rate not in resamplers:
        # build a resampler for any unexpected rate instead of failing with a KeyError
        resamplers[sampling_rate] = torchaudio.transforms.Resample(sampling_rate, 16000)
    example["speech"] = resamplers[sampling_rate](speech).squeeze().numpy()
    return example

def predict(batch):
    inputs = processor(batch["speech"], sampling_rate=16000, return_tensors="pt", padding=True)
    with torch.no_grad():
        predicted = torch.argmax(model(inputs.input_values.to(device)).logits, dim=-1)
    predicted[predicted == -100] = processor.tokenizer.pad_token_id  # see fine-tuning script
    batch["predicted"] = processor.batch_decode(predicted)
    return batch

# prepare the test dataset
test_split = test_split.map(prepare_example)

stt_models = [
    "elgeish/wav2vec2-large-xlsr-53-arabic",
    "othrif/wav2vec2-large-xlsr-arabic",
    "anas/wav2vec2-large-xlsr-arabic",
    "bakrianoo/sinai-voice-ar-stt",
]

stt_results = []

for model_path in tqdm(stt_models):
    processor = Wav2Vec2Processor.from_pretrained(model_path)
    model = Wav2Vec2ForCTC.from_pretrained(model_path).to(device).eval()

    test_split_preds = test_split.map(predict, batched=True, batch_size=56, remove_columns=["speech"])

    # WER on the original Arabic-script text
    orig_metrics = jiwer.compute_measures(
        truth=list(test_split_preds["sentence"]),
        hypothesis=list(test_split_preds["predicted"]),
        truth_transform=transformation,
        hypothesis_transform=transformation,
    )

    # WER after Buckwalter transliteration of both sides
    trans_metrics = jiwer.compute_measures(
        truth=[buckwalter.trans(s) for s in test_split_preds["sentence"]],
        hypothesis=[buckwalter.trans(s) for s in test_split_preds["predicted"]],
        truth_transform=transformation,
        hypothesis_transform=transformation,
    )

    stt_results.append({"model": model_path, "using_transliteration": True, "WER": trans_metrics["wer"]})
    stt_results.append({"model": model_path, "using_transliteration": False, "WER": orig_metrics["wer"]})

    # release the current model before loading the next one
    del model
    del processor

stt_results_df = pd.DataFrame(stt_results)
stt_results_df = stt_results_df.sort_values("WER", axis=0, ascending=True)
print(stt_results_df.head(n=50))
```
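The `transformation` regex is written against Buckwalter symbols: `a u i F N K o` and the backtick are diacritics, `~` is shadda, `_` is tatweel, `p` is ta marbuta and `|`/`{` are alef variants. These characters only appear after `buckwalter.trans`, so in the non-transliterated rows diacritics and letter variants are left untouched, which likely explains much of the gap between the `True` and `False` rows (some references, such as the Quranic verse in the usage output, are fully diacritized while predictions are not). As a hedged sketch, the hypothetical `normalize_arabic` below applies roughly the same normalization directly to Arabic script; the Unicode code points are our assumed equivalents and are not part of the original evaluation.

```python
import re

# Harakat U+064B..U+0652 (tanween, fatha, damma, kasra, shadda, sukun), superscript
# alef U+0670, tatweel U+0640 and U+06D6: roughly "auiFNKo`~_" plus "\u06D6" above
DIACRITICS = re.compile(r"[\u064B-\u0652\u0670\u0640\u06D6]")
# Arabic and Latin punctuation removed by the original regex
PUNCTUATION = re.compile(r'[،؟»«؛?;:\-,.!"]')
# Letter substitutions mirroring "|{" -> "A", "p" -> "h" and the Persian letters
LETTER_MAP = str.maketrans({
    "\u0622": "\u0627",  # alef madda -> bare alef
    "\u0671": "\u0627",  # alef wasla -> bare alef
    "\u0629": "\u0647",  # ta marbuta -> ha
    "\u06A9": "\u0643",  # Persian keheh -> Arabic kaf
    "\u06CC": "\u064A",  # Persian yeh -> Arabic yeh
})


def normalize_arabic(text: str) -> str:
    """Hypothetical Arabic-script normalization, approximating the Buckwalter-based rules."""
    text = DIACRITICS.sub("", text)
    text = PUNCTUATION.sub("", text)
    text = text.translate(LETTER_MAP)
    return " ".join(text.split())  # collapse repeated spaces and strip the ends


# A diacritized reference word now matches the undiacritized prediction
print(normalize_arabic("أَحَبُّ") == normalize_arabic("أحب"))
```

Passing `normalize_arabic(s)` in place of `buckwalter.trans(s)` in the loop above is one way to check this hypothesis; no result for it is reported in this card.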

## Usage

The model can be used directly (without a language model) as follows:

```python
import torch
import torchaudio
from datasets import load_dataset
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor

dataset = load_dataset("common_voice", "ar", split="test[:10]")

# resample everything to the 16 kHz input rate the model expects
resamplers = {  # all three sampling rates exist in test split
    48000: torchaudio.transforms.Resample(48000, 16000),
    44100: torchaudio.transforms.Resample(44100, 16000),
    32000: torchaudio.transforms.Resample(32000, 16000),
}

def prepare_example(example):
    speech, sampling_rate = torchaudio.load(example["path"])
    if sampling_rate not in resamplers:
        # build a resampler for any unexpected rate instead of failing with a KeyError
        resamplers[sampling_rate] = torchaudio.transforms.Resample(sampling_rate, 16000)
    example["speech"] = resamplers[sampling_rate](speech).squeeze().numpy()
    return example

dataset = dataset.map(prepare_example)
processor = Wav2Vec2Processor.from_pretrained("bakrianoo/sinai-voice-ar-stt")
model = Wav2Vec2ForCTC.from_pretrained("bakrianoo/sinai-voice-ar-stt").eval()

def predict(batch):
    inputs = processor(batch["speech"], sampling_rate=16000, return_tensors="pt", padding=True)
    with torch.no_grad():
        predicted = torch.argmax(model(inputs.input_values).logits, dim=-1)
    predicted[predicted == -100] = processor.tokenizer.pad_token_id  # see fine-tuning script
    batch["predicted"] = processor.tokenizer.batch_decode(predicted)
    return batch

dataset = dataset.map(predict, batched=True, batch_size=1, remove_columns=["speech"])
for reference, predicted in zip(dataset["sentence"], dataset["predicted"]):
    print("reference:", reference)
    print("predicted:", predicted)
    print("--")
```

Here's the output:

```
reference: ألديك قلم ؟
predicted: ألديك قلم
--
reference: ليست هناك مسافة على هذه الأرض أبعد من يوم أمس.
predicted: ليست نارك مسافة على هذه الأرض أبعد من يوم أمس
--
reference: إنك تكبر المشكلة.
predicted: إنك تكبر المشكلة
--
reference: يرغب أن يلتقي بك.
predicted: يرغب أن يلتقي بك
--
reference: إنهم لا يعرفون لماذا حتى.
predicted: إنهم لا يعرفون لماذا حتى
--
reference: سيسعدني مساعدتك أي وقت تحب.
predicted: سيسعدن مساعثتك أي وقد تحب
--
reference: أَحَبُّ نظريّة علمية إليّ هي أن حلقات زحل مكونة بالكامل من الأمتعة المفقودة.
predicted: أحب نظرية علمية إلي هي أن أحلقتز حلم كوينا بالكامل من الأمت عن المفقودة
--
reference: سأشتري له قلماً.
predicted: سأشتري له قلما
--
reference: أين المشكلة ؟
predicted: أين المشكل
--
reference: وَلِلَّهِ يَسْجُدُ مَا فِي السَّمَاوَاتِ وَمَا فِي الْأَرْضِ مِنْ دَابَّةٍ وَالْمَلَائِكَةُ وَهُمْ لَا يَسْتَكْبِرُونَ
predicted: ولله يسجد ما في السماوات وما في الأرض من دابة والملائكة وهم لا يستكبرون
--
```
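Without a language model, `predict` performs plain greedy CTC decoding: `torch.argmax` picks the most likely token for each audio frame, and `batch_decode` then merges consecutive repeats, drops the padding token that serves as the CTC blank, and turns the `|` word delimiter into a space. The pure-Python sketch below mirrors those steps on a toy vocabulary; `ctc_greedy_decode` and the vocabulary are hypothetical stand-ins for the tokenizer's own logic, not the Transformers implementation.

```python
from itertools import groupby


def ctc_greedy_decode(frame_ids, id_to_token, blank_id, word_delimiter="|"):
    """Collapse a per-frame argmax sequence into text (greedy CTC, no language model)."""
    # 1) merge consecutive repeats: [4, 4, 0, 4] -> [4, 0, 4]
    merged = [token_id for token_id, _ in groupby(frame_ids)]
    # 2) drop the blank; a blank between two equal tokens keeps them as separate letters
    tokens = [id_to_token[t] for t in merged if t != blank_id]
    # 3) the word-delimiter token becomes a space
    text = "".join(" " if t == word_delimiter else t for t in tokens)
    return " ".join(text.split())


# Toy vocabulary: id 0 is the padding/blank token, id 1 is the word delimiter
vocab = {0: "<pad>", 1: "|", 2: "a", 3: "b", 4: "l"}
print(ctc_greedy_decode([3, 3, 0, 2, 4, 0, 4, 1, 1, 2], vocab, blank_id=0))
```

Because repeated frames collapse unless a blank separates them, a double letter such as the `ll` above survives only when the model emits a blank in between; dropped or merged letters of this kind are one source of the word errors visible in the output.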


## Evaluation

The model can be evaluated on the Arabic test split of Common Voice as follows:
```python
import jiwer
import torch
import torchaudio
from datasets import load_dataset
from lang_trans.arabic import buckwalter
from transformers import set_seed, Wav2Vec2ForCTC, Wav2Vec2Processor

device = "cuda" if torch.cuda.is_available() else "cpu"

set_seed(42)
test_split = load_dataset("common_voice", "ar", split="test")

# resample everything to the 16 kHz input rate the model expects
resamplers = {  # all three sampling rates exist in test split
    48000: torchaudio.transforms.Resample(48000, 16000),
    44100: torchaudio.transforms.Resample(44100, 16000),
    32000: torchaudio.transforms.Resample(32000, 16000),
}

def prepare_example(example):
    speech, sampling_rate = torchaudio.load(example["path"])
    if sampling_rate not in resamplers:
        # build a resampler for any unexpected rate instead of failing with a KeyError
        resamplers[sampling_rate] = torchaudio.transforms.Resample(sampling_rate, 16000)
    example["speech"] = resamplers[sampling_rate](speech).squeeze().numpy()
    return example

test_split = test_split.map(prepare_example)
processor = Wav2Vec2Processor.from_pretrained("bakrianoo/sinai-voice-ar-stt")
model = Wav2Vec2ForCTC.from_pretrained("bakrianoo/sinai-voice-ar-stt").to(device).eval()

def predict(batch):
    inputs = processor(batch["speech"], sampling_rate=16000, return_tensors="pt", padding=True)
    with torch.no_grad():
        predicted = torch.argmax(model(inputs.input_values.to(device)).logits, dim=-1)
    predicted[predicted == -100] = processor.tokenizer.pad_token_id  # see fine-tuning script
    batch["predicted"] = processor.batch_decode(predicted)
    return batch

test_split = test_split.map(predict, batched=True, batch_size=16, remove_columns=["speech"])

# WER text normalization (written for the jiwer 2.x API)
transformation = jiwer.Compose([
    # normalize some diacritics, remove punctuation, and replace Persian letters with Arabic ones
    jiwer.SubstituteRegexes({
        r'[auiFNKo\`~_،؟»\?;:\-,\.؛«!"]': "", "\u06D6": "",
        r"[\|\{]": "A", "p": "h", "ک": "k", "ی": "y"}),
    # default transformation below
    jiwer.RemoveMultipleSpaces(),
    jiwer.Strip(),
    jiwer.SentencesToListOfWords(),
    jiwer.RemoveEmptyStrings(),
])

metrics = jiwer.compute_measures(
    truth=[buckwalter.trans(s) for s in test_split["sentence"]],  # Buckwalter transliteration
    hypothesis=[buckwalter.trans(s) for s in test_split["predicted"]],
    truth_transform=transformation,
    hypothesis_transform=transformation,
)
print(f"WER: {metrics['wer']:.2%}")
```

**Test result:** 23.80% WER


## Other Arabic Voice Recognition Models

Words are not enough to thank those who believe that there is hope and strive for it.