whisperkit-benchmarks / quality_generate.py
ardaatahan's picture
initial commit
1543414
raw
history blame
7.31 kB
import json
import os
import shutil
import sys
from collections import defaultdict
from statistics import mean
import pandas as pd
import requests
from text_normalizer import text_normalizer
from utils import compute_average_wer, download_dataset
def fetch_evaluation_data(url, timeout=60):
    """
    Fetches evaluation data from the given URL.

    :param url: The URL to fetch the evaluation data from.
    :param timeout: Seconds to wait for the server before giving up
        (passed straight to ``requests.get``).
    :returns: The evaluation data as a dictionary (the JSON-decoded body).
    :raises: SystemExit if the request fails: a network error, a timeout,
        or any status code other than 200.
    """
    try:
        response = requests.get(url, timeout=timeout)
    except requests.RequestException as err:
        # Without this, connection errors and timeouts escape as raw
        # tracebacks instead of following the documented sys.exit path.
        sys.exit(f"Failed to fetch WhisperKit evals: {err}")
    if response.status_code != 200:
        sys.exit(f"Failed to fetch WhisperKit evals: {response.text}")
    return json.loads(response.text)
def get_device_name(device):
    """
    Translates a raw device identifier into its display name.

    The lookup table is read from ``dashboard_data/device_map.json``
    (relative to the current working directory) on every call.

    :param device: Device identifier string to look up.
    :returns: The mapped name when ``device`` is a key of the map, otherwise
        ``device`` itself, with every space replaced by an underscore.
    """
    map_path = "dashboard_data/device_map.json"
    with open(map_path, "r") as map_file:
        known_devices = json.load(map_file)
    display_name = known_devices.get(device, device)
    return "_".join(display_name.split(" "))
def process_quality_file(file_path, dataset_dfs, quality_results):
    """
    Processes a single quality file and updates the quality_results dictionary.

    :param file_path: Path to the quality JSON file. The model name is taken
        from the third-to-last "/"-separated path component (with "_" mapped
        to "/"), and the timestamp from the file name up to its first dot.
    :param dataset_dfs: Dictionary of DataFrames containing reference results.
        Only the "Earnings-22" and "LibriSpeech" entries are read here, and
        each needs "file" and "wer" columns.
    :param quality_results: Dictionary keyed by model whose values provide
        "timestamp", "dataset_wer" (dataset name -> list) and "qoi" (list)
        fields, e.g. the defaultdict built in main().
    :raises IndexError: if an audio file in the results has no matching row
        in the reference DataFrame.

    Returns without touching quality_results if the file or its "results"
    list is empty. Otherwise, for each result, this function does two things:
    it appends the normalized prediction/reference pair under the file's
    dataset name for later WER computation, and it records a Quality of
    Inference (QoI) flag. The flag is 1 if this result's WER is no worse than
    the reference WER for the same audio file, and 0 otherwise.
    """
    with open(file_path, "r", encoding="utf-8") as file:
        test_results = json.load(file)

    if not test_results:
        return

    metadata = test_results["metadata"]
    test_results = test_results["results"]
    # Nothing to record; returning here also avoids creating a model entry
    # whose "qoi" list is empty.
    if not test_results:
        return

    # Results are aggregated per model only; the device is not part of the key.
    key = file_path.split("/")[-3].replace("_", "/")
    timestamp = file_path.split("/")[-1].split(".")[0]
    dataset_name = metadata["dataset_name"]

    # The reference set depends only on the file's metadata, so resolve it once.
    dataset_key = "Earnings-22" if "earnings22" in dataset_name else "LibriSpeech"
    dataset_df = dataset_dfs[dataset_key]
    reference_files = dataset_df["file"]

    entry = quality_results[key]
    entry["timestamp"] = timestamp
    for test_result in test_results:
        entry["dataset_wer"][dataset_name].append(
            {
                "prediction": text_normalizer(test_result["prediction"]),
                "reference": text_normalizer(test_result["reference"]),
            }
        )

        # Strip only the final extension so names containing dots still
        # match precisely.
        audio = os.path.splitext(test_result["file"])[0]
        # regex=False: file names are literal text, so characters such as
        # "+" or "(" must not be parsed as regex syntax. na=False keeps
        # missing file names from poisoning the boolean mask.
        matches = dataset_df.loc[
            reference_files.str.contains(audio, regex=False, na=False)
        ]
        if matches.empty:
            raise IndexError(
                f"No reference result matching {audio!r} in {dataset_key} evals"
            )
        reference_wer = matches.iloc[0]["wer"]
        entry["qoi"].append(1 if test_result["wer"] <= reference_wer else 0)
def calculate_and_save_quality_results(quality_results, quality_output_path):
    """
    Calculates final quality metrics and saves them to a JSON Lines file.

    :param quality_results: Dictionary keyed by model. Each value holds
        "timestamp", "dataset_wer" (dataset name -> list of
        prediction/reference pairs) and "qoi" (list of 0/1 flags).
    :param quality_output_path: Path to save the processed quality results.
        The file is overwritten.

    One JSON object is written per line, one per model. It contains:

    - "dataset_wer": the per-dataset value returned by compute_average_wer.
    - "average_wer": the unweighted mean of those values, rounded to 2
      places; 0 when there are no datasets.
    - "qoi": the mean of the QoI flags, rounded to 2 places; 0 when there
      are no flags.

    All entries are computed before the output file is opened, so a failure
    part-way through leaves any existing file intact instead of truncated.
    """
    entries = []
    for model, data in quality_results.items():
        dataset_wers = {
            dataset: compute_average_wer(wer)
            for dataset, wer in data["dataset_wer"].items()
        }
        average_wer = (
            sum(dataset_wers.values()) / len(dataset_wers) if dataset_wers else 0
        )
        # statistics.mean raises on empty input; mirror the empty-dataset
        # fallback used for average_wer above.
        qoi = mean(data["qoi"]) if data["qoi"] else 0
        entries.append(
            {
                "model": model.replace("_", "/"),
                "timestamp": data["timestamp"],
                "average_wer": round(average_wer, 2),
                "dataset_wer": dataset_wers,
                "qoi": round(qoi, 2),
            }
        )

    with open(quality_output_path, "w") as quality_file:
        for quality_entry in entries:
            json.dump(quality_entry, quality_file)
            quality_file.write("\n")
def main():
    """
    Main function to orchestrate the quality data generation process.

    This function performs the following steps:
    1. If the first CLI argument is "download", removes the local "english"
       directory (if present) and re-downloads the WhisperKit evals.
    2. Fetches the reference (OpenAI Whisper large-v2) evaluation results and
       loads each into a DataFrame, downloading every distinct URL only once.
    3. Processes the quality files found under the "earnings22-12hours" and
       "librispeech" dataset directories.
    4. Calculates and saves quality results, including WER and QoI metrics,
       to dashboard_data/quality_data.json.
    """
    if len(sys.argv) > 1 and sys.argv[1] == "download":
        # NOTE(review): this removes "english", but quality files are read
        # from "argmaxinc/english/WhisperKit/" below; confirm where
        # download_dataset writes its output.
        try:
            shutil.rmtree("english")
        except FileNotFoundError:
            # Only a missing directory is expected; permission or other OS
            # errors should surface instead of being silently ignored.
            print("Nothing to remove.")
        download_dataset("argmaxinc/whisperkit-evals", "english", "WhisperKit")

    datasets = {
        "Earnings-22": "https://huggingface.co/datasets/argmaxinc/whisperkit-evals/resolve/main/WhisperOpenAIAPI/openai_whisper-large-v2/earnings22/2024-03-04_13%3A39%3A42_GMT-0800.json",
        "LibriSpeech": "https://huggingface.co/datasets/argmaxinc/whisperkit-evals/resolve/main/WhisperOpenAIAPI/openai_whisper-large-v2/librispeech/2024-02-28_18%3A45%3A02_GMT-0800.json?download=true",
        "earnings22-10mins": "https://huggingface.co/datasets/argmaxinc/whisperkit-evals/resolve/main/WhisperOpenAIAPI/openai_whisper-large-v2/earnings22/2024-03-04_13%3A39%3A42_GMT-0800.json",
        "librispeech-10mins": "https://huggingface.co/datasets/argmaxinc/whisperkit-evals/resolve/main/WhisperOpenAIAPI/openai_whisper-large-v2/librispeech/2024-02-28_18%3A45%3A02_GMT-0800.json?download=true",
        "earnings22-12hours": "https://huggingface.co/datasets/argmaxinc/whisperkit-evals/resolve/main/WhisperOpenAIAPI/openai_whisper-large-v2/earnings22/2024-03-04_13%3A39%3A42_GMT-0800.json",
        "librispeech": "https://huggingface.co/datasets/argmaxinc/whisperkit-evals/resolve/main/WhisperOpenAIAPI/openai_whisper-large-v2/librispeech/2024-02-28_18%3A45%3A02_GMT-0800.json?download=true",
    }

    # Several keys share the same reference file; fetch each URL once but
    # still build a separate DataFrame per key, as before.
    evals_by_url = {}
    dataset_dfs = {}
    for dataset_name, url in datasets.items():
        if url not in evals_by_url:
            evals_by_url[url] = fetch_evaluation_data(url)
        dataset_dfs[dataset_name] = pd.json_normalize(evals_by_url[url]["results"])

    source_quality_directory = "argmaxinc/english/WhisperKit/"
    quality_results = defaultdict(
        lambda: {
            "average_wer": [],
            "dataset_wer": defaultdict(list),
            "qoi": [],
            "timestamp": None,
        }
    )

    wanted_datasets = {"earnings22-12hours", "librispeech"}
    for subdir, _, files in os.walk(source_quality_directory):
        if subdir.split("/")[-1] not in wanted_datasets:
            continue
        for filename in files:
            if not filename.endswith(".json"):
                continue
            file_path = os.path.join(subdir, filename)
            process_quality_file(file_path, dataset_dfs, quality_results)

    calculate_and_save_quality_results(
        quality_results, "dashboard_data/quality_data.json"
    )


if __name__ == "__main__":
    main()