Spaces:
Running
Running
File size: 7,312 Bytes
1543414 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 |
import json
import os
import shutil
import sys
from collections import defaultdict
from statistics import mean
import pandas as pd
import requests
from text_normalizer import text_normalizer
from utils import compute_average_wer, download_dataset
def fetch_evaluation_data(url):
    """
    Fetches evaluation data from the given URL.

    :param url: The URL to fetch the evaluation data from.
    :returns: The evaluation data as a dictionary.
    :raises: SystemExit (via sys.exit) if the request fails.
    """
    # Bound the request so a stalled server cannot hang the pipeline forever.
    response = requests.get(url, timeout=30)
    if response.status_code == 200:
        return json.loads(response.text)
    else:
        sys.exit(f"Failed to fetch WhisperKit evals: {response.text}")
def get_device_name(device):
    """
    Looks up the human-friendly device name in the device map.

    :param device: String representing the device name.
    :returns: The mapped name when present in dashboard_data/device_map.json,
        otherwise the input name; spaces are replaced with underscores either way.
    """
    with open("dashboard_data/device_map.json", "r") as map_file:
        device_map = json.load(map_file)
    resolved = device_map.get(device, device)
    return resolved.replace(" ", "_")
def process_quality_file(file_path, dataset_dfs, quality_results):
    """
    Processes a single quality file and updates the quality_results dictionary.

    :param file_path: Path to the quality JSON file.
    :param dataset_dfs: Dictionary of DataFrames containing dataset information.
    :param quality_results: Dictionary to store the processed quality results.

    Reads one quality JSON file and accumulates, per model key, normalized
    prediction/reference WER entries and a Quality of Inference (QoI) flag
    (1 when the prediction WER is at or below the reference WER, else 0).
    """
    with open(file_path, "r") as fh:
        payload = json.load(fh)
    if not payload:
        return

    metadata = payload["metadata"]
    results = payload["results"]

    path_parts = file_path.split("/")
    # Model identifier is encoded in the directory name three levels up.
    key = path_parts[-3].replace("_", "/")
    # NOTE(review): resolved device name is currently unused in the result
    # key; kept to preserve the device-map read side effect.
    device = get_device_name(
        metadata["inference_context"]["device_spec"]["product_name"]
    )
    # Filename stem doubles as the run timestamp.
    timestamp = path_parts[-1].split(".")[0]
    dataset_name = metadata["dataset_name"]
    reference_key = "Earnings-22" if "earnings22" in dataset_name else "LibriSpeech"

    for entry in results:
        reference_df = dataset_dfs[reference_key]
        quality_results[key]["timestamp"] = timestamp
        quality_results[key]["dataset_wer"][dataset_name].append(
            {
                "prediction": text_normalizer(entry["prediction"]),
                "reference": text_normalizer(entry["reference"]),
            }
        )
        audio_stem = entry["file"].split(".")[0]
        baseline_row = reference_df.loc[
            reference_df["file"].str.contains(audio_stem)
        ].iloc[0]
        quality_results[key]["qoi"].append(
            1 if entry["wer"] <= baseline_row["wer"] else 0
        )
def calculate_and_save_quality_results(quality_results, quality_output_path):
    """
    Calculates final quality metrics and saves them as JSON lines.

    :param quality_results: Dictionary containing raw quality data keyed by model.
    :param quality_output_path: Path to save the processed quality results.

    For each model, averages the per-dataset WER and the QoI flags, then
    writes one JSON object per model to quality_output_path (one per line),
    including Word Error Rate (WER) and Quality of Inference (QoI).
    """
    with open(quality_output_path, "w") as quality_file:
        for model, data in quality_results.items():
            dataset_wers = {
                dataset: compute_average_wer(wer_entries)
                for dataset, wer_entries in data["dataset_wer"].items()
            }
            # Average across datasets; 0 when no WER data was collected.
            average_wer = (
                sum(dataset_wers.values()) / len(dataset_wers)
                if dataset_wers
                else 0
            )
            # Guard the empty case: statistics.mean raises StatisticsError on [].
            qoi = round(mean(data["qoi"]), 2) if data["qoi"] else 0
            quality_entry = {
                "model": model.replace("_", "/"),
                "timestamp": data["timestamp"],
                "average_wer": round(average_wer, 2),
                "dataset_wer": dataset_wers,
                "qoi": qoi,
            }
            json.dump(quality_entry, quality_file)
            quality_file.write("\n")
def main():
    """
    Main function to orchestrate the quality data generation process.

    This function performs the following steps:
    1. Downloads quality data if requested (first CLI argument == "download").
    2. Fetches reference evaluation data for various datasets.
    3. Processes quality files for specific datasets.
    4. Calculates and saves quality results, including WER and QoI metrics.
    """
    if len(sys.argv) > 1 and sys.argv[1] == "download":
        try:
            shutil.rmtree("english")
        except OSError:
            # No previous download to clear (or it is not removable) -- a
            # bare except here would also swallow KeyboardInterrupt/SystemExit.
            print("Nothing to remove.")
        download_dataset("argmaxinc/whisperkit-evals", "english", "WhisperKit")

    # Reference (baseline) evaluation results from the OpenAI Whisper API runs.
    datasets = {
        "Earnings-22": "https://huggingface.co/datasets/argmaxinc/whisperkit-evals/resolve/main/WhisperOpenAIAPI/openai_whisper-large-v2/earnings22/2024-03-04_13%3A39%3A42_GMT-0800.json",
        "LibriSpeech": "https://huggingface.co/datasets/argmaxinc/whisperkit-evals/resolve/main/WhisperOpenAIAPI/openai_whisper-large-v2/librispeech/2024-02-28_18%3A45%3A02_GMT-0800.json?download=true",
        "earnings22-10mins": "https://huggingface.co/datasets/argmaxinc/whisperkit-evals/resolve/main/WhisperOpenAIAPI/openai_whisper-large-v2/earnings22/2024-03-04_13%3A39%3A42_GMT-0800.json",
        "librispeech-10mins": "https://huggingface.co/datasets/argmaxinc/whisperkit-evals/resolve/main/WhisperOpenAIAPI/openai_whisper-large-v2/librispeech/2024-02-28_18%3A45%3A02_GMT-0800.json?download=true",
        "earnings22-12hours": "https://huggingface.co/datasets/argmaxinc/whisperkit-evals/resolve/main/WhisperOpenAIAPI/openai_whisper-large-v2/earnings22/2024-03-04_13%3A39%3A42_GMT-0800.json",
        "librispeech": "https://huggingface.co/datasets/argmaxinc/whisperkit-evals/resolve/main/WhisperOpenAIAPI/openai_whisper-large-v2/librispeech/2024-02-28_18%3A45%3A02_GMT-0800.json?download=true",
    }

    dataset_dfs = {}
    for dataset_name, url in datasets.items():
        evals = fetch_evaluation_data(url)
        dataset_dfs[dataset_name] = pd.json_normalize(evals["results"])

    source_quality_directory = "argmaxinc/english/WhisperKit/"
    quality_results = defaultdict(
        lambda: {
            "average_wer": [],
            "dataset_wer": defaultdict(list),
            "qoi": [],
            "timestamp": None,
        }
    )

    for subdir, _, files in os.walk(source_quality_directory):
        dataset = subdir.split("/")[-1]
        # Only these two dataset directories feed the dashboard.
        if dataset not in ["earnings22-12hours", "librispeech"]:
            continue
        for filename in files:
            if not filename.endswith(".json"):
                continue
            file_path = os.path.join(subdir, filename)
            process_quality_file(file_path, dataset_dfs, quality_results)

    calculate_and_save_quality_results(
        quality_results, "dashboard_data/quality_data.json"
    )


if __name__ == "__main__":
    main()
|