metadata
language: fa
datasets:
  - common_voice_6_1
tags:
  - audio
  - automatic-speech-recognition
license: mit
widget:
  - example_title: Common Voice Sample 1
    src: >-
      https://datasets-server.huggingface.co/assets/common_voice/--/fa/train/0/audio/audio.mp3
  - example_title: Common Voice Sample 2
    src: >-
      https://datasets-server.huggingface.co/assets/common_voice/--/fa/train/1/audio/audio.mp3
model-index:
  - name: Sharif-wav2vec2
    results:
      - task:
          name: Automatic Speech Recognition
          type: automatic-speech-recognition
        dataset:
          name: Common Voice Corpus 6.1 (clean)
          type: common_voice_6_1
          config: clean
          split: test
          args:
            language: fa
        metrics:
          - name: Test WER
            type: wer
            value: 6

Sharif-wav2vec2

This is a fine-tuned version of Sharif Wav2vec2 for Farsi. The base model was fine-tuned on 108 hours of Common Voice Farsi samples at a sampling rate of 16 kHz. When using the model, make sure that your speech input is also sampled at 16 kHz. Before use, you may need to install the dependencies below:

pip -q install pyctcdecode
python -m pip -q install pypi-kenlm

For testing, you can use the hosted inference API on Hugging Face (examples from Common Voice are provided); it may take a while to transcribe the given audio. Alternatively, you can use the code below to run the model locally:

import torch
import torchaudio

from transformers import AutoProcessor, AutoModelForCTC

# Load the processor and the fine-tuned CTC model from the Hub
processor = AutoProcessor.from_pretrained("SLPL/Sharif-wav2vec2")
model = AutoModelForCTC.from_pretrained("SLPL/Sharif-wav2vec2")

# Load the audio and resample it if it is not at the rate the model expects (16 kHz)
speech_array, sampling_rate = torchaudio.load("path/to/your.wav")
target_rate = processor.feature_extractor.sampling_rate
if sampling_rate != target_rate:
    speech_array = torchaudio.functional.resample(speech_array, sampling_rate, target_rate)
speech_array = speech_array.squeeze().numpy()

features = processor(
    speech_array,
    sampling_rate=target_rate,
    return_tensors="pt",
    padding=True)

with torch.no_grad():
    logits = model(
        features.input_values,
        attention_mask=features.attention_mask).logits
    # batch_decode takes the raw logits; .text holds the decoded transcriptions
    prediction = processor.batch_decode(logits.numpy()).text

print(prediction[0])
# تست
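Because the model expects 16 kHz input, it can help to inspect a file before transcribing it. The sketch below uses only Python's standard `wave` module, so it assumes an uncompressed PCM WAV file; the helper name `check_wav` and the mono requirement are illustrative assumptions, not part of this model's API.

```python
import wave

EXPECTED_RATE = 16_000  # sampling rate stated in this model card


def check_wav(path, expected_rate=EXPECTED_RATE):
    """Report the sampling rate and channel count of a PCM WAV file.

    Hypothetical helper: "ok" is True only for mono audio at the expected rate.
    """
    with wave.open(path, "rb") as wav_file:
        rate = wav_file.getframerate()
        channels = wav_file.getnchannels()
    return {
        "rate": rate,
        "channels": channels,
        "ok": rate == expected_rate and channels == 1,
    }
```

If `ok` is false, resample or downmix the audio before passing it to the processor.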

Authors of the original wav2vec 2.0 model: Alexei Baevski, Henry Zhou, Abdelrahman Mohamed, Michael Auli

Abstract

The original wav2vec 2.0 model can be found at https://github.com/pytorch/fairseq/tree/master/examples/wav2vec#wav2vec-20.

Usage

To transcribe Persian audio files, the model can be used as a standalone acoustic model (greedy decoding, without a language model) as follows:

from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC
from datasets import load_dataset
import torch

# load model and processor
processor = Wav2Vec2Processor.from_pretrained("SLPL/Sharif-wav2vec2")
model = Wav2Vec2ForCTC.from_pretrained("SLPL/Sharif-wav2vec2")

# load dummy dataset and read soundfiles
# (these are English LibriSpeech samples used as a placeholder; substitute 16 kHz Persian audio)
ds = load_dataset("patrickvonplaten/librispeech_asr_dummy", "clean", split="validation")

# extract input features
input_values = processor(ds[0]["audio"]["array"], sampling_rate=16_000, return_tensors="pt", padding="longest").input_values  # Batch size 1

# retrieve logits (no gradients needed for inference)
with torch.no_grad():
    logits = model(input_values).logits

# take argmax and decode
predicted_ids = torch.argmax(logits, dim=-1)
transcription = processor.batch_decode(predicted_ids)
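The "take argmax and decode" step above is greedy CTC decoding: consecutive repeated ids are merged, then blank ids are dropped. The following is a minimal sketch of that rule in plain Python; the function name, the toy vocabulary, and the choice of `0` as the blank id are assumptions for illustration and do not reflect this model's actual tokenizer.

```python
from itertools import groupby


def ctc_greedy_decode(frame_ids, id_to_token, blank_id=0):
    """Collapse per-frame argmax ids into text using the CTC rule."""
    # 1) merge runs of identical consecutive ids
    collapsed = [token_id for token_id, _ in groupby(frame_ids)]
    # 2) drop blanks and map the remaining ids to tokens
    return "".join(id_to_token[i] for i in collapsed if i != blank_id)


# Hypothetical vocabulary; id 0 plays the role of the CTC blank
toy_vocab = {0: "<blank>", 1: "h", 2: "e", 3: "l", 4: "o"}
frames = [1, 1, 0, 2, 3, 3, 0, 3, 4, 4]
print(ctc_greedy_decode(frames, toy_vocab))  # hello
```

The blank between the two runs of `3` is what lets the doubled "l" survive the merge step.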

Evaluation

This code snippet shows how to evaluate facebook/wav2vec2-base-960h on LibriSpeech's "clean" and "other" test data.

from datasets import load_dataset
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
import torch
from jiwer import wer

# use "other" instead of "clean" to evaluate on the second test set
librispeech_eval = load_dataset("librispeech_asr", "clean", split="test")

model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h").to("cuda")
processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")

def map_to_pred(batch):
    # with batched=True and batch_size=1, each column is a one-element list
    input_values = processor(batch["audio"][0]["array"], sampling_rate=16_000, return_tensors="pt", padding="longest").input_values
    with torch.no_grad():
        logits = model(input_values.to("cuda")).logits

    predicted_ids = torch.argmax(logits, dim=-1)
    transcription = processor.batch_decode(predicted_ids)
    batch["transcription"] = transcription
    return batch

result = librispeech_eval.map(map_to_pred, batched=True, batch_size=1, remove_columns=["audio"])

print("WER:", wer(result["text"], result["transcription"]))

Result (WER, %):

| "clean" | "other" |
|---|---|
| 3.4 | 8.6 |
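For readers unfamiliar with the metric, word error rate is the number of word-level substitutions, deletions, and insertions needed to turn the hypothesis into the reference, divided by the number of reference words. The sketch below implements that standard definition in plain Python; it is an illustration, not the `jiwer` implementation, and `word_error_rate` is a hypothetical name. It returns a fraction, so multiply by 100 to compare with the percentages above.

```python
def word_error_rate(references, hypotheses):
    """Corpus-level WER: total word edits divided by total reference words."""
    total_edits = 0
    total_words = 0
    for ref, hyp in zip(references, hypotheses):
        ref_words, hyp_words = ref.split(), hyp.split()
        # Levenshtein distance over words, keeping one rolling row
        prev = list(range(len(hyp_words) + 1))
        for i, ref_word in enumerate(ref_words, start=1):
            curr = [i]
            for j, hyp_word in enumerate(hyp_words, start=1):
                cost = 0 if ref_word == hyp_word else 1
                curr.append(min(
                    prev[j] + 1,         # reference word missing from hypothesis
                    curr[j - 1] + 1,     # extra word in hypothesis
                    prev[j - 1] + cost,  # match or substitution
                ))
            prev = curr
        total_edits += prev[-1]
        total_words += len(ref_words)
    if total_words == 0:
        raise ValueError("references contain no words")
    return total_edits / total_words
```

For example, `word_error_rate(["a b c d"], ["a x c"])` counts one substitution and one deletion over four reference words.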