# ASR_Model_Comparison / processing.py
# j-tobias
# added new model
# db6e0bb
# Import Libraries to load Models
from transformers import WhisperProcessor, WhisperForConditionalGeneration
from transformers import Speech2TextForConditionalGeneration, Speech2TextProcessor
from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC
from transformers import AutoProcessor, SeamlessM4TModel
# Import Libraries to access Datasets
from datasets import load_dataset
from datasets import Audio
# Helper Libraries
import plotly.graph_objs as go
import evaluate
import librosa
import torch
import numpy as np
import pandas as pd
import time
# Number of samples the Models are evaluated on: each dataset loader keeps only
# the first N_SAMPLES samples of its split (via dataset.take(N_SAMPLES)).
N_SAMPLES = 50
# Load the WER Metric once at import time; compute_wer() reuses this instance.
wer_metric = evaluate.load("wer")
def _to_mono_float(audio):
    """Convert a raw gr.Audio numpy array into a mono float32 waveform scaled to [-1, 1]."""
    audio = np.asarray(audio)
    if np.issubdtype(audio.dtype, np.integer):
        # Integer PCM (e.g. int16) would otherwise reach the feature extractors with
        # values up to ~32767 instead of the [-1, 1] floats the dataset loaders deliver.
        info = np.iinfo(audio.dtype)
        audio = audio.astype(np.float32) / float(max(abs(info.min), info.max))
    else:
        audio = audio.astype(np.float32)
    if audio.ndim > 1:
        # NOTE(review): assumes a (samples, channels) layout for multi-channel input
        # (as gr.Audio returns it) -- confirm; the channels are averaged down to mono.
        audio = audio.mean(axis=1)
    return audio


def _timed_compute(model, processor, sample, model_id):
    """Run model_compute() on one sample and return (transcription, elapsed wall-clock seconds)."""
    start = time.perf_counter()
    transcription = model_compute(model, processor, sample, model_id)
    return transcription, time.perf_counter() - start


def _render_results(model_1, model_2, wer1, wer2, inference_times1, inference_times2, progress=None):
    """Build the markdown summary and the WER bar plot comparing both models."""
    lines = [""]  # the summary starts with a blank line
    if progress is not None:
        lines.append(progress)
    for name, wer, times in ((model_1, wer1, inference_times1), (model_2, wer2, inference_times2)):
        lines += [
            f"#### {name}",
            f"- WER Score: {wer}",
            f"- Avg. Inference Duration: {round(sum(times) / len(times), 4)}s",
        ]
    results_md = "\n".join(lines)

    # Create the bar plot
    fig = go.Figure(
        data=[
            go.Bar(x=[model_1], y=[wer1], showlegend=False),
            go.Bar(x=[model_2], y=[wer2], showlegend=False),
        ]
    )
    # Update the layout for better visualization
    fig.update_layout(
        title="Comparison of Two Models",
        xaxis_title="Models",
        yaxis_title="WER",
        barmode="group",
    )
    return results_md, fig


def run(data_subset:str, model_1:str, model_2:str, own_audio, own_transcription:str):
    """
    Main Function running an entire evaluation cycle.

    This is a generator: for an own recording it yields one result; for a
    dataset it yields updated results after every evaluated sample so the UI
    can refresh progressively.

    Params:
    - data_subset (str) :The name of a valid Dataset to choose from ["Common Voice", "Librispeech ASR clean", "Librispeech ASR other", "OWN Recording/Sample"]; any other value (including None) falls back to "Common Voice"
    - model_1 (str) :The name of a valid model to choose from ["openai/whisper-tiny.en", "facebook/s2t-medium-librispeech-asr", "facebook/wav2vec2-base-960h", "openai/whisper-large-v2", "facebook/hf-seamless-m4t-medium"]
    - model_2 (str) :The name of a valid model to choose from (same list as model_1)
    - own_audio (gr.Audio) :The return value of a gr.Audio component (sr, audio (as numpy array))
    - own_transcription (str) :The paired transcription to the own_audio

    Yields:
    - (results_md, fig, df): markdown summary, plotly WER bar chart, and a
      per-sample table of references, transcriptions and WERs

    Raises:
    - ValueError: no dataset/model selected, or "OWN Recording/Sample" chosen
      without a recording or without a transcription
    """
    # A little bit of Error Handling
    if data_subset is None and own_audio is None and own_transcription is None:
        raise ValueError("No Dataset selected")
    if model_1 is None:
        raise ValueError("No Model 1 selected")
    if model_2 is None:
        raise ValueError("No Model 2 selected")

    use_own_sample = data_subset == "OWN Recording/Sample"

    # Load the selected Dataset (only N_SAMPLES of it) or prepare the own recording
    if use_own_sample:
        if own_audio is None:
            raise ValueError("No own recording provided")
        if not own_transcription:
            raise ValueError("No transcription for the own recording provided")
        sr, audio = own_audio
        audio = _to_mono_float(audio)
        print("AUDIO: ", type(audio), audio.shape, sr)
        audio = librosa.resample(audio, orig_sr=sr, target_sr=16000)
    elif data_subset == "Librispeech ASR clean":
        dataset, text_column = load_Librispeech_ASR_clean()
    elif data_subset == "Librispeech ASR other":
        dataset, text_column = load_Librispeech_ASR_other()
    else:
        # "Common Voice" -- and the fallback when no (or an unknown) subset is selected
        dataset, text_column = load_Common_Voice()
    # Print statements are kept on purpose: users can read the Space logs to follow progress
    print("Dataset Loaded")

    # Load the selected Models; an identical checkpoint is loaded only once
    model1, processor1 = load_model(model_1)
    if model_2 == model_1:
        model2, processor2 = model1, processor1
    else:
        model2, processor2 = load_model(model_2)
    print("Models Loaded")

    if use_own_sample:
        # A single own recording: evaluate it once and yield a single result
        sample = {"audio": {"array": audio, "sampling_rate": 16000}}
        transcription1, duration1 = _timed_compute(model1, processor1, sample, model_1)
        transcription2, duration2 = _timed_compute(model2, processor2, sample, model_2)
        references = [own_transcription.lower()]
        # Plain WER of this one sample, rounded like the per-sample WERs of the dataset path
        wer1 = round(compute_wer(references, [transcription1]), 4)
        wer2 = round(compute_wer(references, [transcription2]), 4)
        results_md, fig = _render_results(model_1, model_2, wer1, wer2, [duration1], [duration2])
        df = pd.DataFrame({"references": references, "transcriptions 1": [transcription1], "WER 1": [wer1], "transcriptions 2": [transcription2], "WER 2": [wer2]})
        yield results_md, fig, df
        return

    # A dataset has been selected: evaluate sample by sample, yielding after each one
    references, transcriptions1, transcriptions2 = [], [], []
    WER1s, WER2s = [], []
    inference_times1, inference_times2 = [], []
    total = len(dataset)
    for i, sample in enumerate(dataset, start=1):
        print(f"Sample {i}/{total}")
        reference = sample[text_column]
        references.append(reference)

        transcription1, duration1 = _timed_compute(model1, processor1, sample, model_1)
        if model_2 == model_1:
            # Same checkpoint on both sides: run inference once and reuse it for model 2
            transcription2, duration2 = transcription1, duration1
        else:
            transcription2, duration2 = _timed_compute(model2, processor2, sample, model_2)
        transcriptions1.append(transcription1)
        transcriptions2.append(transcription2)
        inference_times1.append(duration1)
        inference_times2.append(duration2)

        WER1s.append(round(compute_wer([reference], [transcription1]), 4))
        WER2s.append(round(compute_wer([reference], [transcription2]), 4))
        # Running mean of the per-sample WERs seen so far
        wer1 = round(sum(WER1s) / len(WER1s), 4)
        wer2 = round(sum(WER2s) / len(WER2s), 4)

        # Text progress bar sized by the number of samples actually loaded
        progress = f"{i}/{total}-{'#' * i}{'_' * (total - i)}"
        results_md, fig = _render_results(model_1, model_2, wer1, wer2, inference_times1, inference_times2, progress)
        df = pd.DataFrame({"references": references, f"{model_1}": transcriptions1, "WER 1": WER1s, f"{model_2}": transcriptions2, "WER 2": WER2s})
        yield results_md, fig, df
# DATASET LOADERS
def load_Common_Voice():
    """
    Load the first N_SAMPLES samples of the English Common Voice 11 test split.

    Returns:
    - dataset (list[dict]): the samples, with the "audio" column cast to 16 kHz
      and the transcript column lowercased
    - text_column (str): name of the transcript column ("sentence")
    """
    dataset = load_dataset("mozilla-foundation/common_voice_11_0", "en", revision="streaming", split="test", streaming=True, token=True, trust_remote_code=True)
    text_column = "sentence"
    dataset = dataset.take(N_SAMPLES)
    dataset = dataset.cast_column("audio", Audio(sampling_rate=16000))
    dataset = list(dataset)
    for sample in dataset:
        # Lowercase the column run() reads as the reference; this dataset's transcript
        # column is text_column ("sentence"), not "text".
        sample[text_column] = sample[text_column].lower()
    return dataset, text_column
def load_Librispeech_ASR_clean():
    """
    Load the first N_SAMPLES samples of the LibriSpeech "clean" test split (streamed).

    Returns:
    - dataset (list[dict]): the samples, with the "audio" column cast to 16 kHz
      and the transcript column lowercased
    - text_column (str): name of the transcript column ("text")
    """
    dataset = load_dataset("librispeech_asr", "clean", split="test", streaming=True, token=True, trust_remote_code=True)
    text_column = "text"
    dataset = dataset.take(N_SAMPLES)
    dataset = dataset.cast_column("audio", Audio(sampling_rate=16000))
    dataset = list(dataset)
    # Short log line instead of printing a whole (separately streamed) raw sample
    print(f"Loaded {len(dataset)} samples from librispeech_asr/clean")
    for sample in dataset:
        sample[text_column] = sample[text_column].lower()
    return dataset, text_column
def load_Librispeech_ASR_other():
    """
    Load the first N_SAMPLES samples of the LibriSpeech "other" test split (streamed).

    Returns:
    - dataset (list[dict]): the samples, with the "audio" column cast to 16 kHz
      and the transcript column lowercased
    - text_column (str): name of the transcript column ("text")
    """
    dataset = load_dataset("librispeech_asr", "other", split="test", streaming=True, token=True, trust_remote_code=True)
    text_column = "text"
    dataset = dataset.take(N_SAMPLES)
    dataset = dataset.cast_column("audio", Audio(sampling_rate=16000))
    dataset = list(dataset)
    # Short log line instead of printing a whole (separately streamed) raw sample
    print(f"Loaded {len(dataset)} samples from librispeech_asr/other")
    for sample in dataset:
        sample[text_column] = sample[text_column].lower()
    return dataset, text_column
# MODEL LOADERS
def load_model(model_id:str):
    """
    Instantiate the model and its processor for one of the supported checkpoints.

    Any id that is not in the registry below (including None) falls back to
    "openai/whisper-tiny.en", so the caller always gets a usable pair.

    Returns:
    - (model, processor)
    """
    registry = {
        "openai/whisper-tiny.en": (WhisperForConditionalGeneration, WhisperProcessor),
        "facebook/s2t-medium-librispeech-asr": (Speech2TextForConditionalGeneration, Speech2TextProcessor),
        "facebook/wav2vec2-base-960h": (Wav2Vec2ForCTC, Wav2Vec2Processor),
        "openai/whisper-large-v2": (WhisperForConditionalGeneration, WhisperProcessor),
        "facebook/hf-seamless-m4t-medium": (SeamlessM4TModel, AutoProcessor),
    }
    checkpoint = model_id if model_id in registry else "openai/whisper-tiny.en"
    model_cls, processor_cls = registry[checkpoint]

    processor = processor_cls.from_pretrained(checkpoint)
    model = model_cls.from_pretrained(checkpoint)
    if checkpoint == "openai/whisper-large-v2":
        # clear the checkpoint's forced decoder prompt ids
        model.config.forced_decoder_ids = None
    return model, processor
# MODEL INFERENCE
def model_compute(model, processor, sample, model_id):
    """
    Transcribe a single audio sample with a model/processor pair from load_model().

    Params:
    - model: the model returned by load_model(model_id)
    - processor: the processor returned by load_model(model_id)
    - sample (dict): needs sample["audio"]["array"] and sample["audio"]["sampling_rate"]
    - model_id (str): selects the model-specific pre/post-processing. Ids without a
      dedicated branch are handled like "openai/whisper-tiny.en", matching
      load_model()'s fallback.

    Returns:
    - str: the first (only) transcription of the batch. Whisper output is passed
      through the tokenizer's normalizer; wav2vec2 and SeamlessM4T output is
      lowercased to match the lowercased references.
    """
    audio = sample["audio"]
    array, sampling_rate = audio["array"], audio["sampling_rate"]

    if model_id == "facebook/s2t-medium-librispeech-asr":
        features = processor(array, sampling_rate=sampling_rate, padding=True, return_tensors="pt")
        gen_tokens = model.generate(input_features=features.input_features, attention_mask=features.attention_mask)
        return processor.batch_decode(gen_tokens, skip_special_tokens=True)[0]

    if model_id == "facebook/wav2vec2-base-960h":
        input_values = processor(array, sampling_rate=sampling_rate, return_tensors="pt", padding="longest").input_values  # Batch size 1
        # A plain forward pass (unlike generate()) records autograd history; not needed for inference
        with torch.no_grad():
            logits = model(input_values).logits
        predicted_ids = torch.argmax(logits, dim=-1)
        return processor.batch_decode(predicted_ids)[0].lower()

    if model_id == "facebook/hf-seamless-m4t-medium":
        input_data = processor(audios=array, sampling_rate=sampling_rate, return_tensors="pt")
        output_tokens = model.generate(**input_data, tgt_lang="eng", generate_speech=False)
        transcription = processor.decode(output_tokens[0].tolist()[0], skip_special_tokens=True)
        return transcription.lower()

    # Whisper checkpoints (tiny.en, large-v2) and the Whisper-Tiny.En fallback
    input_features = processor(array, sampling_rate=sampling_rate, return_tensors="pt").input_features
    predicted_ids = model.generate(input_features)
    transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
    transcription = processor.tokenizer.normalize(transcription)
    if model_id == "openai/whisper-large-v2":
        print("TRANSCRIPTION Whisper Large v2: ", transcription)
    return transcription
# UTILS
def compute_wer(references, predictions):
    """Score `predictions` against `references` with the shared "wer" metric and return the result."""
    return wer_metric.compute(predictions=predictions, references=references)