import json
import os
import shutil
import sys
from collections import defaultdict
from statistics import mean

import pandas as pd
import requests

from text_normalizer import text_normalizer
from utils import compute_average_wer, download_dataset


def fetch_evaluation_data(url):
    """
    Fetches evaluation data from the given URL.
    :param url: The URL to fetch the evaluation data from.
    :returns: The evaluation data as a dictionary.
    :raises: SystemExit if the request fails.
    """
    response = requests.get(url)
    if response.status_code == 200:
        return json.loads(response.text)
    else:
        sys.exit(f"Failed to fetch WhisperKit evals: {response.text}")


def get_device_name(device):
    """
    Gets the device name from the device map if it exists.
    :param device: String representing the device name.
    :returns: The device name from the device map if it exists, otherwise the input device name.
    """
    with open("dashboard_data/device_map.json", "r") as f:
        device_map = json.load(f)
    return device_map.get(device, device).replace(" ", "_")
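
# Example (hypothetical mapping; dashboard_data/device_map.json may or may not contain it):
#   with {"iPhone15,2": "iPhone 14 Pro"} in the map, get_device_name("iPhone15,2") -> "iPhone_14_Pro";
#   identifiers missing from the map are passed through unchanged (spaces replaced by underscores).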


def process_quality_file(file_path, dataset_dfs, quality_results):
    """
    Processes a single quality file and updates the quality_results dictionary.

    :param file_path: Path to the quality JSON file.
    :param dataset_dfs: Dictionary of DataFrames containing dataset information.
    :param quality_results: Dictionary to store the processed quality results.

    This function reads a quality JSON file, extracts relevant information,
    and updates the quality_results dictionary with various metrics including WER
    and Quality of Inference (QoI) for different datasets.
    """
    with open(file_path, "r") as file:
        test_results = json.load(file)

    if len(test_results) == 0:
        return

    metadata = test_results["metadata"]
    test_results = test_results["results"]
    model = file_path.split("/")[-3].replace("_", "/")
    device = metadata["inference_context"]["device_spec"]["product_name"]
    device = get_device_name(device)
    timestamp = file_path.split("/")[-1].split(".")[0]
    key = model
    dataset_name = metadata["dataset_name"]

    for test_result in test_results:
        audio_file_name = test_result["file"]

        dataset_key = "Earnings-22" if "earnings22" in dataset_name else "LibriSpeech"
        dataset_df = dataset_dfs[dataset_key]

        wer_entry = {
            "prediction": text_normalizer(test_result["prediction"]),
            "reference": text_normalizer(test_result["reference"]),
        }
        quality_results[key]["timestamp"] = timestamp
        quality_results[key]["dataset_wer"][dataset_name].append(wer_entry)

        audio = audio_file_name.split(".")[0]
        dataset_row = dataset_df.loc[dataset_df["file"].str.contains(audio)].iloc[0]
        reference_wer = dataset_row["wer"]
        prediction_wer = test_result["wer"]

        quality_results[key]["qoi"].append(1 if prediction_wer <= reference_wer else 0)


def calculate_and_save_quality_results(quality_results, quality_output_path):
    """
    Calculates final quality metrics and saves them to a JSON file.

    :param quality_results: Dictionary containing raw quality data.
    :param quality_output_path: Path to save the processed quality results.

    This function processes the raw quality data, calculates average metrics,
    and writes the final results to a JSON file, with each entry representing
    a unique model's quality metrics across different datasets, including
    Word Error Rate (WER) and Quality of Inference (QoI).
    """
    with open(quality_output_path, "w") as quality_file:
        for key, data in quality_results.items():
            model = key

            dataset_wers = {
                dataset: compute_average_wer(wer)
                for dataset, wer in data["dataset_wer"].items()
            }
            average_wer = (
                sum(dataset_wers.values()) / len(dataset_wers)
                if len(dataset_wers) != 0
                else 0
            )

            quality_entry = {
                "model": model.replace("_", "/"),
                "timestamp": data["timestamp"],
                "average_wer": round(average_wer, 2),
                "dataset_wer": dataset_wers,
                "qoi": round(mean(data["qoi"]), 2),
            }

            json.dump(quality_entry, quality_file)
            quality_file.write("\n")


def main():
    """
    Main function to orchestrate the quality data generation process.

    This function performs the following steps:
    1. Downloads quality data if requested.
    2. Fetches evaluation data for various datasets.
    3. Processes quality files for specific datasets.
    4. Calculates and saves quality results, including WER and QoI metrics.
    """
    if len(sys.argv) > 1 and sys.argv[1] == "download":
        try:
            shutil.rmtree("english")
        except FileNotFoundError:
            print("Nothing to remove.")
        download_dataset("argmaxinc/whisperkit-evals", "english", "WhisperKit")

    datasets = {
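        # Note: the lowercase/*-10mins/*-12hours keys reuse the same reference eval files
        # (openai_whisper-large-v2 runs via WhisperOpenAIAPI) as the capitalized entries;
        # only the key names differ.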
        "Earnings-22": "https://huggingface.co/datasets/argmaxinc/whisperkit-evals/resolve/main/WhisperOpenAIAPI/openai_whisper-large-v2/earnings22/2024-03-04_13%3A39%3A42_GMT-0800.json",
        "LibriSpeech": "https://huggingface.co/datasets/argmaxinc/whisperkit-evals/resolve/main/WhisperOpenAIAPI/openai_whisper-large-v2/librispeech/2024-02-28_18%3A45%3A02_GMT-0800.json?download=true",
        "earnings22-10mins": "https://huggingface.co/datasets/argmaxinc/whisperkit-evals/resolve/main/WhisperOpenAIAPI/openai_whisper-large-v2/earnings22/2024-03-04_13%3A39%3A42_GMT-0800.json",
        "librispeech-10mins": "https://huggingface.co/datasets/argmaxinc/whisperkit-evals/resolve/main/WhisperOpenAIAPI/openai_whisper-large-v2/librispeech/2024-02-28_18%3A45%3A02_GMT-0800.json?download=true",
        "earnings22-12hours": "https://huggingface.co/datasets/argmaxinc/whisperkit-evals/resolve/main/WhisperOpenAIAPI/openai_whisper-large-v2/earnings22/2024-03-04_13%3A39%3A42_GMT-0800.json",
        "librispeech": "https://huggingface.co/datasets/argmaxinc/whisperkit-evals/resolve/main/WhisperOpenAIAPI/openai_whisper-large-v2/librispeech/2024-02-28_18%3A45%3A02_GMT-0800.json?download=true",
    }

    dataset_dfs = {}
    for dataset_name, url in datasets.items():
        evals = fetch_evaluation_data(url)
        dataset_dfs[dataset_name] = pd.json_normalize(evals["results"])

    source_quality_directory = "argmaxinc/english/WhisperKit/"

    quality_results = defaultdict(
        lambda: {
            "average_wer": [],
            "dataset_wer": defaultdict(list),
            "qoi": [],
            "timestamp": None,
        }
    )

    for subdir, _, files in os.walk(source_quality_directory):
        dataset = subdir.split("/")[-1]
        if dataset not in ["earnings22-12hours", "librispeech"]:
            continue

        for filename in files:
            if not filename.endswith(".json"):
                continue

            file_path = os.path.join(subdir, filename)
            process_quality_file(file_path, dataset_dfs, quality_results)

    calculate_and_save_quality_results(
        quality_results, "dashboard_data/quality_data.json"
    )


if __name__ == "__main__":
    main()
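
# Typical invocations (script name shown here is illustrative):
#   python generate_quality_data.py            # reuse previously downloaded WhisperKit results
#   python generate_quality_data.py download   # remove the local "english" dir and re-download
#                                              # argmaxinc/whisperkit-evals before processing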