whisper-small-cy / run_eval_whisper_streaming.py
pplantinga's picture
Model save
00d2136
import argparse
from transformers import pipeline
from transformers.models.whisper.english_normalizer import BasicTextNormalizer
from datasets import load_dataset, Audio
import evaluate
wer_metric = evaluate.load("wer")
def is_target_text_in_range(ref):
if ref.strip() == "ignore time segment in scoring":
return False
else:
return ref.strip() != ""
def get_text(sample):
if "text" in sample:
return sample["text"]
elif "sentence" in sample:
return sample["sentence"]
elif "normalized_text" in sample:
return sample["normalized_text"]
elif "transcript" in sample:
return sample["transcript"]
elif "transcription" in sample:
return sample["transcription"]
else:
raise ValueError(
f"Expected transcript column of either 'text', 'sentence', 'normalized_text' or 'transcript'. Got sample of "
".join{sample.keys()}. Ensure a text column name is present in the dataset."
)
whisper_norm = BasicTextNormalizer()
def normalise(batch):
batch["norm_text"] = whisper_norm(get_text(batch))
return batch
def data(dataset):
for i, item in enumerate(dataset):
yield {**item["audio"], "reference": item["norm_text"]}
def main(args):
batch_size = args.batch_size
whisper_asr = pipeline(
"automatic-speech-recognition", model=args.model_id, device=args.device
)
whisper_asr.model.config.forced_decoder_ids = (
whisper_asr.tokenizer.get_decoder_prompt_ids(
language=args.language, task="transcribe"
)
)
dataset = load_dataset(
args.dataset,
args.config,
split=args.split,
streaming=args.streaming,
use_auth_token=True,
)
# Only uncomment for debugging
dataset = dataset.take(args.max_eval_samples)
dataset = dataset.cast_column("audio", Audio(sampling_rate=16000))
dataset = dataset.map(normalise)
dataset = dataset.filter(is_target_text_in_range, input_columns=["norm_text"])
predictions = []
references = []
# run streamed inference
for out in whisper_asr(data(dataset), batch_size=batch_size):
predictions.append(whisper_norm(out["text"]))
references.append(out["reference"][0])
wer = wer_metric.compute(references=references, predictions=predictions)
wer = round(100 * wer, 2)
print("WER:", wer)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--model_id",
type=str,
required=True,
help="Model identifier. Should be loadable with 🤗 Transformers",
)
parser.add_argument(
"--dataset",
type=str,
default="mozilla-foundation/common_voice_11_0",
help="Dataset name to evaluate the `model_id`. Should be loadable with 🤗 Datasets",
)
parser.add_argument(
"--config",
type=str,
required=True,
help="Config of the dataset. *E.g.* `'en'` for the English split of Common Voice",
)
parser.add_argument(
"--split",
type=str,
default="test",
help="Split of the dataset. *E.g.* `'test'`",
)
parser.add_argument(
"--device",
type=int,
default=-1,
help="The device to run the pipeline on. -1 for CPU (default), 0 for the first GPU and so on.",
)
parser.add_argument(
"--batch_size",
type=int,
default=16,
help="Number of samples to go through each streamed batch.",
)
parser.add_argument(
"--max_eval_samples",
type=int,
default=None,
help="Number of samples to be evaluated. Put a lower number e.g. 64 for testing this script.",
)
parser.add_argument(
"--streaming",
type=bool,
default=True,
help="Choose whether you'd like to download the entire dataset or stream it during the evaluation.",
)
parser.add_argument(
"--language",
type=str,
required=True,
help="Two letter language code for the transcription language, e.g. use 'en' for English.",
)
args = parser.parse_args()
main(args)