File size: 4,604 Bytes
1543414
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
import json
import os
import shutil
import sys
from collections import defaultdict

import numpy as np
import pandas as pd
from sklearn.metrics import confusion_matrix

from utils import compute_average_wer, download_dataset


def main():
    """
    Main function to orchestrate the multilingual data generation process.

    This function performs the following steps:
    1. If the first command-line argument is ``download``, removes the local
       copy of the dataset (if any) and downloads it again.
    2. Walks the evaluation directory and collects, from every non-summary
       JSON file, the reference/prediction pairs and the
       (reference language, predicted language) pairs, grouped by model and
       by whether the run lives under a ``forced`` directory.
    3. Passes the collected data to ``calculate_and_save_results``, which
       computes the metrics and writes the output files.
    """
    source_repo = "argmaxinc/whisperkit-evals-multilingual"
    source_subfolder = "WhisperKit"
    source_directory = f"{source_repo}/{source_subfolder}"
    if len(sys.argv) > 1 and sys.argv[1] == "download":
        try:
            shutil.rmtree(source_repo)
        except FileNotFoundError:
            print("Nothing to remove.")
        except OSError as err:
            # Removal stays best-effort, but the real reason is reported and
            # KeyboardInterrupt/SystemExit are no longer swallowed.
            print(f"Could not remove {source_repo}: {err}")
        download_dataset(source_repo, source_repo, source_subfolder)

    # One bucket per "<model>/<forced|not_forced>" key, created on first use.
    results = defaultdict(
        lambda: {
            "average_wer": [],
            "language_wer": defaultdict(list),
            "language_detection": [],
        }
    )

    confusion_matrices = {}

    for subdir, _, files in os.walk(source_directory):
        for filename in files:
            if not filename.endswith(".json") or "summary" in filename:
                continue

            file_path = os.path.join(subdir, filename)
            # The files hold multilingual text; do not depend on the platform
            # default encoding.
            with open(file_path, "r", encoding="utf-8") as f:
                data = json.load(f)

            # Normalise first so the "/" used to build source_directory is
            # split correctly on platforms whose separator is not "/".
            subdir_components = os.path.normpath(subdir).split(os.path.sep)
            is_forced = "forced" in subdir_components
            # The model name is taken from a fixed position counted from the
            # end of the directory path; a "forced" run is one level deeper.
            # NOTE(review): this relies on the dataset's directory layout.
            model = subdir_components[-4] if is_forced else subdir_components[-3]

            key = f"{model}/{'forced' if is_forced else 'not_forced'}"

            for item in data["results"]:
                # Entries without a reference language cannot contribute to
                # per-language metrics or to language detection.
                if "reference_language" not in item:
                    continue
                reference_language = item["reference_language"]
                detected_language = item["predicted_language"]

                result = {
                    "reference": item["reference"],
                    "prediction": item["prediction"],
                }

                results[key]["average_wer"].append(result)
                results[key]["language_wer"][reference_language].append(result)
                results[key]["language_detection"].append(
                    (reference_language, detected_language)
                )

    calculate_and_save_results(results, confusion_matrices)


def calculate_and_save_results(results, confusion_matrices):
    """
    Calculates final multilingual metrics and saves them to CSV and JSON files.

    :param results: Dictionary keyed by ``"<model>/<forced|not_forced>"``. Each
        value holds ``"average_wer"`` (list of reference/prediction records),
        ``"language_wer"`` (the same records grouped by reference language) and
        ``"language_detection"`` (list of (reference, detected) language pairs).
    :param confusion_matrices: Dictionary that is filled in place with one
        row-normalized confusion matrix per model and forced/not_forced mode.

    For every key this function builds one row of WER values (overall and per
    language) and one language-detection confusion matrix, then writes:
    1. ``dashboard_data/multilingual_results.csv`` with the WER rows.
    2. ``dashboard_data/multilingual_confusion_matrices.json`` with the
       confusion matrices and their labels.
    The ``dashboard_data`` directory is created if it does not exist.
    """
    wer_data = []
    for key, data in results.items():
        model, forced = key.rsplit("/", 1)
        row = {
            "Model": model,
            "Forced Tokens": forced == "forced",
            "Average WER": compute_average_wer(data["average_wer"]),
        }
        for lang, wers in data["language_wer"].items():
            row[f"WER_{lang}"] = compute_average_wer(wers)
        wer_data.append(row)

        detection_pairs = data["language_detection"]
        if not detection_pairs:
            # Unpacking zip(*[]) would raise ValueError; with no pairs there
            # is no matrix to build for this key.
            continue

        true_languages, detected_languages = zip(*detection_pairs)
        # Labels are the reference languages only.
        # NOTE(review): predictions in a language that never appears as a
        # reference are therefore not represented in the matrix; confirm
        # that this is intended.
        unique_languages = sorted(set(true_languages))
        cm = confusion_matrix(
            true_languages, detected_languages, labels=unique_languages
        )

        # Normalize each row to fractions, leaving all-zero rows at 0.0
        # instead of dividing by zero.
        row_sums = cm.sum(axis=1)
        cm_normalized = np.zeros_like(cm, dtype=float)
        non_zero_rows = row_sums != 0
        cm_normalized[non_zero_rows] = (
            cm[non_zero_rows] / row_sums[non_zero_rows, np.newaxis]
        )

        confusion_matrices.setdefault(model, {})[forced] = {
            "matrix": cm_normalized.tolist(),
            "labels": unique_languages,
        }

    # A fresh checkout may not have the output directory yet.
    os.makedirs("dashboard_data", exist_ok=True)

    df = pd.DataFrame(wer_data)
    df.to_csv("dashboard_data/multilingual_results.csv", index=False)

    with open(
        "dashboard_data/multilingual_confusion_matrices.json",
        "w",
        encoding="utf-8",
    ) as f:
        json.dump(confusion_matrices, f, indent=2)


# Run the pipeline only when this file is executed as a script, so importing
# the module does not trigger any downloads or file writes.
if __name__ == "__main__":
    main()