whisperkit-benchmarks / multilingual_generate.py
ardaatahan's picture
initial commit
1543414
raw
history blame
4.6 kB
import json
import os
import shutil
import sys
from collections import Counter, defaultdict

import numpy as np
import pandas as pd
from sklearn.metrics import confusion_matrix

from utils import compute_average_wer, download_dataset
def main():
    """
    Main function to orchestrate the multilingual data generation process.

    This function performs the following steps:
    1. Downloads multilingual evaluation data if ``download`` is passed as the
       first command-line argument, deleting any previous local copy first.
    2. Processes the multilingual evaluation JSON files found under
       ``argmaxinc/whisperkit-evals-multilingual/WhisperKit``.
    3. Calculates and saves results, including Word Error Rate (WER) and
       language detection confusion matrices.
    """
    source_repo = "argmaxinc/whisperkit-evals-multilingual"
    source_subfolder = "WhisperKit"
    source_directory = f"{source_repo}/{source_subfolder}"

    if len(sys.argv) > 1 and sys.argv[1] == "download":
        try:
            shutil.rmtree(source_repo)
        except FileNotFoundError:
            # Only a missing directory is expected here. Other failures (e.g.
            # permission errors) must surface instead of silently leaving stale
            # data in place before the fresh download.
            print("Nothing to remove.")
        download_dataset(source_repo, source_repo, source_subfolder)

    results = _collect_results(source_directory)
    confusion_matrices = {}
    calculate_and_save_results(results, confusion_matrices)


def _model_key(subdir):
    """Return the ``"<model>/<forced|not_forced>"`` key for a results directory."""
    # normpath makes the separators uniform, so splitting on os.sep also works
    # when source_directory (built with "/") is walked on Windows.
    components = os.path.normpath(subdir).split(os.sep)
    is_forced = "forced" in components
    # For forced runs the model name is taken one directory level further up.
    model = components[-4] if is_forced else components[-3]
    return f"{model}/{'forced' if is_forced else 'not_forced'}"


def _collect_results(source_directory):
    """
    Aggregate per-utterance results from every evaluation JSON under a directory.

    Files whose name contains ``summary`` and non-JSON files are skipped, as
    are result items without a ``reference_language``.

    :param source_directory: Root directory to walk recursively.
    :return: defaultdict keyed by ``"<model>/<forced|not_forced>"``. Each value
        holds ``"average_wer"`` (list of reference/prediction dicts),
        ``"language_wer"`` (reference language -> list of such dicts) and
        ``"language_detection"`` (list of ``(reference, predicted)`` language
        pairs).
    """
    results = defaultdict(
        lambda: {
            "average_wer": [],
            "language_wer": defaultdict(list),
            "language_detection": [],
        }
    )
    for subdir, dirs, files in os.walk(source_directory):
        # Sorting in place makes os.walk's traversal order deterministic, so
        # the generated CSV/JSON are stable across filesystems.
        dirs.sort()
        for filename in sorted(files):
            if not filename.endswith(".json") or "summary" in filename:
                continue
            file_path = os.path.join(subdir, filename)
            # JSON is UTF-8 by spec; the locale default can break non-ASCII text.
            with open(file_path, "r", encoding="utf-8") as f:
                data = json.load(f)
            key = _model_key(subdir)
            for item in data["results"]:
                if "reference_language" not in item:
                    continue
                reference_language = item["reference_language"]
                detected_language = item["predicted_language"]
                result = {
                    "reference": item["reference"],
                    "prediction": item["prediction"],
                }
                entry = results[key]
                entry["average_wer"].append(result)
                entry["language_wer"][reference_language].append(result)
                entry["language_detection"].append(
                    (reference_language, detected_language)
                )
    return results
def calculate_and_save_results(results, confusion_matrices):
    """
    Calculates final multilingual metrics and saves them to CSV and JSON files.

    :param results: Mapping from ``"<model>/<forced|not_forced>"`` to a dict with
        ``"average_wer"`` (list of reference/prediction dicts), ``"language_wer"``
        (reference language -> list of such dicts) and ``"language_detection"``
        (list of ``(reference_language, predicted_language)`` pairs).
    :param confusion_matrices: Dictionary filled in place as
        ``{model: {forced: {"matrix": [[...]], "labels": [...]}}}``.

    Writes, creating the ``dashboard_data`` directory if needed:
    1. ``dashboard_data/multilingual_results.csv`` with the average WER and one
       ``WER_<lang>`` column per reference language for each model/forced pair.
    2. ``dashboard_data/multilingual_confusion_matrices.json`` with the
       row-normalized language detection confusion matrices.
    """
    wer_data = []
    for key, data in results.items():
        model, forced = key.rsplit("/", 1)
        row = {
            "Model": model,
            "Forced Tokens": forced == "forced",
            "Average WER": compute_average_wer(data["average_wer"]),
        }
        for lang, wers in data["language_wer"].items():
            row[f"WER_{lang}"] = compute_average_wer(wers)
        wer_data.append(row)

        if not data["language_detection"]:
            # zip(*[]) cannot be unpacked; there is nothing to build a matrix from.
            continue
        matrix, labels = _language_confusion_matrix(data["language_detection"])
        confusion_matrices.setdefault(model, {})[forced] = {
            "matrix": matrix,
            "labels": labels,
        }

    os.makedirs("dashboard_data", exist_ok=True)
    df = pd.DataFrame(wer_data)
    df.to_csv("dashboard_data/multilingual_results.csv", index=False)
    with open(
        "dashboard_data/multilingual_confusion_matrices.json", "w", encoding="utf-8"
    ) as f:
        json.dump(confusion_matrices, f, indent=2)


def _language_confusion_matrix(detections):
    """
    Build a row-normalized language detection confusion matrix.

    :param detections: Non-empty sequence of
        ``(reference_language, predicted_language)`` pairs.
    :return: ``(matrix, labels)``. ``labels`` is the sorted list of reference
        languages, and ``matrix[i][j]`` is the fraction of ``labels[i]`` samples
        detected as ``labels[j]``. A row sums to less than 1 when some of its
        samples were detected as a language outside ``labels``.
    """
    true_languages, detected_languages = zip(*detections)
    labels = sorted(set(true_languages))
    cm = confusion_matrix(true_languages, detected_languages, labels=labels)
    # confusion_matrix() discards samples whose prediction is not in `labels`.
    # Normalizing by cm.sum(axis=1) would therefore hide those misdetections and
    # inflate accuracy. Divide by the real sample count per reference language
    # instead; every label comes from true_languages, so no count is zero.
    totals = Counter(true_languages)
    row_totals = np.array([totals[lang] for lang in labels], dtype=float)
    return (cm / row_totals[:, np.newaxis]).tolist(), labels
if __name__ == "__main__":
    # Script entry point. Pass `download` as the first argument to re-fetch the
    # evaluation dataset before processing.
    main()