import json import os import shutil import sys from collections import defaultdict from statistics import mean import pandas as pd import requests from text_normalizer import text_normalizer from utils import compute_average_wer, download_dataset def fetch_evaluation_data(url): """ Fetches evaluation data from the given URL. :param url: The URL to fetch the evaluation data from. :returns: The evaluation data as a dictionary. :rauses: sys.exit if the request fails """ response = requests.get(url) if response.status_code == 200: return json.loads(response.text) else: sys.exit(f"Failed to fetch WhisperKit evals: {response.text}") def get_device_name(device): """ Gets the device name from the device map if it exists. :param device: String representing the device name. :returns: The device name from the device map if it exists, otherwise the input device name. """ with open("dashboard_data/device_map.json", "r") as f: device_map = json.load(f) return device_map.get(device, device).replace(" ", "_") def process_quality_file(file_path, dataset_dfs, quality_results): """ Processes a single quality file and updates the quality_results dictionary. :param file_path: Path to the quality JSON file. :param dataset_dfs: Dictionary of DataFrames containing dataset information. :param quality_results: Dictionary to store the processed quality results. This function reads a quality JSON file, extracts relevant information, and updates the quality_results dictionary with various metrics including WER and Quality of Inference (QoI) for different datasets. """ with open(file_path, "r") as file: test_results = json.load(file) if len(test_results) == 0: return metadata = test_results["metadata"] test_results = test_results["results"] model = file_path.split("/")[-3].replace("_", "/") device = metadata["inference_context"]["device_spec"]["product_name"] device = get_device_name(device) timestamp = file_path.split("/")[-1].split(".")[0] key = model dataset_name = metadata["dataset_name"] for test_result in test_results: audio_file_name = test_result["file"] dataset_key = "Earnings-22" if "earnings22" in dataset_name else "LibriSpeech" dataset_df = dataset_dfs[dataset_key] wer_entry = { "prediction": text_normalizer(test_result["prediction"]), "reference": text_normalizer(test_result["reference"]), } quality_results[key]["timestamp"] = timestamp quality_results[key]["dataset_wer"][dataset_name].append(wer_entry) audio = audio_file_name.split(".")[0] dataset_row = dataset_df.loc[dataset_df["file"].str.contains(audio)].iloc[0] reference_wer = dataset_row["wer"] prediction_wer = test_result["wer"] quality_results[key]["qoi"].append(1 if prediction_wer <= reference_wer else 0) def calculate_and_save_quality_results(quality_results, quality_output_path): """ Calculates final quality metrics and saves them to a JSON file. :param quality_results: Dictionary containing raw quality data. :param quality_output_path: Path to save the processed quality results. This function processes the raw quality data, calculates average metrics, and writes the final results to a JSON file, with each entry representing a unique model's quality metrics across different datasets, including Word Error Rate (WER) and Quality of Inference (QoI). """ with open(quality_output_path, "w") as quality_file: for key, data in quality_results.items(): model = key dataset_wers = { dataset: compute_average_wer(wer) for dataset, wer in data["dataset_wer"].items() } average_wer = ( sum(dataset_wers.values()) / len(dataset_wers) if len(dataset_wers) != 0 else 0 ) quality_entry = { "model": model.replace("_", "/"), "timestamp": data["timestamp"], "average_wer": round(average_wer, 2), "dataset_wer": dataset_wers, "qoi": round(mean(data["qoi"]), 2), } json.dump(quality_entry, quality_file) quality_file.write("\n") def main(): """ Main function to orchestrate the quality data generation process. This function performs the following steps: 1. Downloads quality data if requested. 2. Fetches evaluation data for various datasets. 3. Processes quality files for specific datasets. 4. Calculates and saves quality results, including WER and QoI metrics. """ if len(sys.argv) > 1 and sys.argv[1] == "download": try: shutil.rmtree("english") except: print("Nothing to remove.") download_dataset("argmaxinc/whisperkit-evals", "english", "WhisperKit") datasets = { "Earnings-22": "https://huggingface.co/datasets/argmaxinc/whisperkit-evals/resolve/main/WhisperOpenAIAPI/openai_whisper-large-v2/earnings22/2024-03-04_13%3A39%3A42_GMT-0800.json", "LibriSpeech": "https://huggingface.co/datasets/argmaxinc/whisperkit-evals/resolve/main/WhisperOpenAIAPI/openai_whisper-large-v2/librispeech/2024-02-28_18%3A45%3A02_GMT-0800.json?download=true", "earnings22-10mins": "https://huggingface.co/datasets/argmaxinc/whisperkit-evals/resolve/main/WhisperOpenAIAPI/openai_whisper-large-v2/earnings22/2024-03-04_13%3A39%3A42_GMT-0800.json", "librispeech-10mins": "https://huggingface.co/datasets/argmaxinc/whisperkit-evals/resolve/main/WhisperOpenAIAPI/openai_whisper-large-v2/librispeech/2024-02-28_18%3A45%3A02_GMT-0800.json?download=true", "earnings22-12hours": "https://huggingface.co/datasets/argmaxinc/whisperkit-evals/resolve/main/WhisperOpenAIAPI/openai_whisper-large-v2/earnings22/2024-03-04_13%3A39%3A42_GMT-0800.json", "librispeech": "https://huggingface.co/datasets/argmaxinc/whisperkit-evals/resolve/main/WhisperOpenAIAPI/openai_whisper-large-v2/librispeech/2024-02-28_18%3A45%3A02_GMT-0800.json?download=true", } dataset_dfs = {} for dataset_name, url in datasets.items(): evals = fetch_evaluation_data(url) dataset_dfs[dataset_name] = pd.json_normalize(evals["results"]) source_quality_directory = "argmaxinc/english/WhisperKit/" quality_results = defaultdict( lambda: { "average_wer": [], "dataset_wer": defaultdict(list), "qoi": [], "timestamp": None, } ) for subdir, _, files in os.walk(source_quality_directory): dataset = subdir.split("/")[-1] if dataset not in ["earnings22-12hours", "librispeech"]: continue for filename in files: if not filename.endswith(".json"): continue file_path = os.path.join(subdir, filename) process_quality_file(file_path, dataset_dfs, quality_results) calculate_and_save_quality_results( quality_results, "dashboard_data/quality_data.json" ) if __name__ == "__main__": main()