"""Evaluate a wav2vec2 CTC model on a slice of Common Voice 13 (Esperanto).

Prints a Markdown table with the normalized reference sentence, the model's
greedy transcription, and the character error rate for each clip.
"""

import dataclasses
import os
import os.path
import re

from datasets import load_dataset
from datasets import Audio
import jiwer
import torch
from transformers import AutoProcessor, Wav2Vec2ForCTC
from transformers.models.wav2vec2.processing_wav2vec2 import Wav2Vec2Processor

MODEL = "xekri/wav2vec2-common_voice_13_0-eo-10"
DATA = "validation[:10]"

# Punctuation and symbols stripped from reference sentences before scoring.
chars_to_ignore_regex = "[-!\"'(),.:;=?_`¨«¸»ʼ‑–—‘’“”„…‹›♫?]"

# Substring rewrites applied to (already lowercased) reference sentences.
# Order matters: dicts iterate in insertion order, so multi-character keys
# such as "przy" and "byn" are rewritten before the single "y" -> "j" rule.
chars_to_substitute = {
    "przy": "pŝe",
    "byn": "bin",
    "cx": "ĉ",
    "sx": "ŝ",
    # NOTE(review): the next two entries map a string to itself and have no
    # effect; possibly intended as the ligatures "\ufb01" / "\ufb02" -- confirm.
    "fi": "fi",
    "fl": "fl",
    "ǔ": "ŭ",
    "ñ": "nj",
    "á": "a",
    "é": "e",
    "ü": "ŭ",
    "y": "j",
    "qu": "ku",
}


def remove_special_characters(text: str) -> str:
    """Strip characters matched by ``chars_to_ignore_regex`` and lowercase."""
    text = re.sub(chars_to_ignore_regex, "", text)
    return text.lower()


def substitute_characters(text: str) -> str:
    """Lowercase ``text`` and apply every rewrite in ``chars_to_substitute``.

    Lowercasing happens first so that uppercase inputs (e.g. "Cx") still match
    the all-lowercase keys.
    """
    text = text.lower()
    for old, new in chars_to_substitute.items():
        # str.replace returns a new string; it must be reassigned to take effect.
        text = text.replace(old, new)
    return text


@dataclasses.dataclass
class EvalResult:
    """Per-clip evaluation outcome."""

    filename: str  # base name of the clip's audio file
    cer: float  # character error rate of `predicted` against `actual`
    loss: float  # CTC loss of the model against the tokenized `actual`
    actual: str  # normalized reference sentence
    predicted: str  # model's greedy-decoded transcription

    def print(self) -> None:
        """Print every field on its own labeled line."""
        # Inside a method, `print` resolves to the builtin, not this method.
        print(f"FILE {self.filename}")
        print(f"CERR {self.cer}")
        print(f"LOSS {self.loss}")
        print(f"ACTU {self.actual}")
        print(f"PRED {self.predicted}")


def evaluate(processor: Wav2Vec2Processor, model, example) -> EvalResult:
    """Evaluate a single example.

    Args:
        processor: Featurizes the audio and tokenizes the reference text.
        model: A ``Wav2Vec2ForCTC`` model.
        example: One dataset row; its ``path``, ``audio`` (``array`` and
            ``sampling_rate``) and ``sentence`` fields are read.

    Returns:
        An ``EvalResult`` holding the clip's file name, the CER and CTC loss,
        the normalized reference, and the greedy transcription.
    """
    audio = example["audio"]
    inputs = processor(
        audio=audio["array"],
        sampling_rate=audio["sampling_rate"],
        return_tensors="pt",
    )

    actual = substitute_characters(remove_special_characters(example["sentence"]))
    labels = processor(text=actual, return_tensors="pt").input_ids

    # One forward pass yields both the logits (for decoding) and the loss;
    # no_grad avoids building an autograd graph during evaluation.
    with torch.no_grad():
        outputs = model(**inputs, labels=labels)

    predicted_ids = outputs.logits.argmax(dim=-1)
    predict = processor.batch_decode(predicted_ids)[0]

    cer = jiwer.cer(actual, predict)
    filename = os.path.basename(example["path"])
    return EvalResult(filename, cer, outputs.loss.item(), actual, predict)


def run() -> None:
    """Evaluate MODEL on the DATA split and print a Markdown results table."""
    cv13 = load_dataset(
        "mozilla-foundation/common_voice_13_0",
        "eo",
        split=DATA,
    )
    # Decode/resample audio at 16 kHz on access.
    cv13 = cv13.cast_column("audio", Audio(sampling_rate=16000))
    processor: Wav2Vec2Processor = AutoProcessor.from_pretrained(MODEL)
    model = Wav2Vec2ForCTC.from_pretrained(MODEL)

    # The delimiter row must have as many cells as the header row (3).
    print("| Actual | Predicted | CER |")
    print("|:--------------------|:--------------------|:----|")
    for example in cv13:
        result = evaluate(processor, model, example)
        print(f"| `{result.actual}` | `{result.predicted}` | {result.cer} |")


if __name__ == "__main__":
    run()