|
import argparse |
|
import pyarabic.araby as araby |
|
from transformers import pipeline |
|
from transformers.models.whisper.english_normalizer import BasicTextNormalizer |
|
from datasets import load_dataset, Audio |
|
import evaluate |
|
|
|
# Word error rate metric from 🤗 Evaluate; loaded once at import time and
# used by main() to score the normalised predictions against the references.
wer_metric = evaluate.load("wer")
|
|
|
|
|
def is_target_text_in_range(ref):
    """Return True if *ref* is a transcript worth scoring.

    Blank/whitespace-only transcripts and the placeholder
    "ignore time segment in scoring" (compared after stripping surrounding
    whitespace) are rejected; everything else is kept.
    """
    stripped = ref.strip()
    if not stripped:
        return False
    return stripped != "ignore time segment in scoring"
|
|
|
|
|
def get_text(sample):
    """Return the transcript string from a dataset *sample*.

    Datasets name their transcript column differently; the first of
    ``text``, ``sentence``, ``normalized_text``, ``transcript`` and
    ``transcription`` present in *sample* is returned (checked in that order).

    Raises:
        ValueError: if none of those columns is present. The message lists
            the columns the sample actually has.
    """
    candidate_columns = (
        "text",
        "sentence",
        "normalized_text",
        "transcript",
        "transcription",
    )
    for column in candidate_columns:
        if column in sample:
            return sample[column]
    # Build the column list with a real expression: the previous message
    # concatenated a plain (non-f) string, so "{sample.keys()}" was printed
    # literally instead of the sample's column names.
    available = ", ".join(map(str, sample.keys()))
    raise ValueError(
        "Expected transcript column of either 'text', 'sentence', 'normalized_text', "
        f"'transcript' or 'transcription'. Got sample with columns {available}. "
        "Ensure a text column name is present in the dataset."
    )
|
|
|
|
|
# Shared Whisper text normaliser: applied to reference transcripts in
# normalise() and to model predictions in main().
whisper_norm = BasicTextNormalizer()
|
|
|
|
|
def normalise(batch):
    """Store the Whisper-normalised transcript of *batch* under ``"norm_text"``.

    The transcript is located with ``get_text`` and passed through
    ``whisper_norm``. *batch* is mutated in place and returned, so this can
    be passed directly to ``Dataset.map``.
    """
    raw_text = get_text(batch)
    batch["norm_text"] = whisper_norm(raw_text)
    return batch
|
|
|
|
|
def remove_diacritics(batch):
    """Store the diacritic-free, Whisper-normalised transcript as ``"norm_text"``.

    Diacritics are removed from the raw transcript with pyarabic's
    ``araby.strip_diacritics``. The result is then passed through
    ``whisper_norm``, the same normaliser ``main`` applies to model
    predictions, so both sides of the WER comparison are normalised alike.
    *batch* is mutated in place and returned for use with ``Dataset.map``.
    """
    # Strip first, then normalise: storing the stripped text un-normalised
    # left references and predictions under different normalisations.
    stripped = araby.strip_diacritics(get_text(batch))
    batch["norm_text"] = whisper_norm(stripped)
    return batch
|
|
|
|
|
def data(dataset):
    """Yield one ASR-pipeline input per sample of *dataset*.

    Each yielded dict is the sample's ``"audio"`` mapping with the sample's
    ``"norm_text"`` attached under ``"reference"``. ``main`` reads that key
    back from the pipeline output to pair predictions with references.
    """
    # The enumerate() index was never used; iterate the samples directly.
    for item in dataset:
        yield {**item["audio"], "reference": item["norm_text"]}
|
|
|
|
|
def main(args):
    """Evaluate a Whisper checkpoint on a speech dataset and print its WER.

    Loads ``args.model_id`` into an automatic-speech-recognition pipeline,
    transcribes ``args.dataset`` / ``args.config`` / ``args.split``, and prints
    the word error rate as a percentage rounded to two decimals.

    Args:
        args: the ``argparse.Namespace`` built in the ``__main__`` block. This
            function reads ``model_id``, ``device``, ``language``,
            ``dataset``, ``config``, ``split``, ``streaming``,
            ``max_eval_samples``, ``remove_diacritics`` and ``batch_size``.

    Raises:
        ValueError: if no samples are left to score after filtering.
    """
    whisper_asr = pipeline(
        "automatic-speech-recognition", model=args.model_id, device=args.device
    )

    # Pin the decoder prompt to the requested language and the "transcribe" task.
    whisper_asr.model.config.forced_decoder_ids = (
        whisper_asr.tokenizer.get_decoder_prompt_ids(
            language=args.language, task="transcribe"
        )
    )

    dataset = load_dataset(
        args.dataset,
        args.config,
        split=args.split,
        streaming=args.streaming,
        use_auth_token=True,
    )

    # max_eval_samples defaults to None, meaning "evaluate everything". Only
    # cap the dataset when a limit was actually given, rather than calling
    # take(None).
    if args.max_eval_samples is not None:
        dataset = dataset.take(args.max_eval_samples)

    dataset = dataset.cast_column("audio", Audio(sampling_rate=16000))

    # Both normalisers compute "norm_text" from the raw transcript, so
    # remove_diacritics fully overwrites normalise's output. Run only the one
    # that is needed.
    if args.remove_diacritics:
        print("stripping diacritics")
        dataset = dataset.map(remove_diacritics)
    else:
        dataset = dataset.map(normalise)

    dataset = dataset.filter(is_target_text_in_range, input_columns=["norm_text"])

    predictions = []
    references = []

    for out in whisper_asr(data(dataset), batch_size=args.batch_size):
        prediction = out["text"]
        if args.remove_diacritics:
            # Strip diacritics from predictions too, matching the references.
            prediction = araby.strip_diacritics(prediction)
        predictions.append(whisper_norm(prediction))
        # NOTE(review): the pass-through "reference" field comes back wrapped
        # in a list (presumably by the pipeline's batching), hence the [0].
        references.append(out["reference"][0])

    if not references:
        raise ValueError("No samples left to evaluate after filtering.")

    wer = wer_metric.compute(references=references, predictions=predictions)
    wer = round(100 * wer, 2)

    print("WER:", wer)
|
|
|
|
|
if __name__ == "__main__": |
|
parser = argparse.ArgumentParser() |
|
|
|
parser.add_argument( |
|
"--model_id", |
|
type=str, |
|
required=True, |
|
help="Model identifier. Should be loadable with 🤗 Transformers", |
|
) |
|
parser.add_argument( |
|
"--dataset", |
|
type=str, |
|
default="mozilla-foundation/common_voice_11_0", |
|
help="Dataset name to evaluate the `model_id`. Should be loadable with 🤗 Datasets", |
|
) |
|
parser.add_argument( |
|
"--config", |
|
type=str, |
|
required=True, |
|
help="Config of the dataset. *E.g.* `'en'` for the English split of Common Voice", |
|
) |
|
parser.add_argument( |
|
"--split", |
|
type=str, |
|
default="test", |
|
help="Split of the dataset. *E.g.* `'test'`", |
|
) |
|
|
|
parser.add_argument( |
|
"--device", |
|
type=int, |
|
default=-1, |
|
help="The device to run the pipeline on. -1 for CPU (default), 0 for the first GPU and so on.", |
|
) |
|
parser.add_argument( |
|
"--batch_size", |
|
type=int, |
|
default=16, |
|
help="Number of samples to go through each streamed batch.", |
|
) |
|
parser.add_argument( |
|
"--max_eval_samples", |
|
type=int, |
|
default=None, |
|
help="Number of samples to be evaluated. Put a lower number e.g. 64 for testing this script.", |
|
) |
|
parser.add_argument( |
|
"--streaming", |
|
type=bool, |
|
default=True, |
|
help="Choose whether you'd like to download the entire dataset or stream it during the evaluation.", |
|
) |
|
parser.add_argument( |
|
"--language", |
|
type=str, |
|
required=True, |
|
help="Two letter language code for the transcription language, e.g. use 'en' for English.", |
|
) |
|
|
|
parser.add_argument( |
|
"--remove_diacritics", |
|
type=bool, |
|
default=False, |
|
help="Choose whether you'd like remove_diacritics", |
|
) |
|
|
|
args = parser.parse_args() |
|
|
|
main(args) |
|
|