# Generates multilingual evaluation results for the dashboard: per-model
# Word Error Rate (WER) tables and language-detection confusion matrices.
import json | |
import os | |
import shutil | |
import sys | |
from collections import defaultdict | |
import numpy as np | |
import pandas as pd | |
from sklearn.metrics import confusion_matrix | |
from utils import compute_average_wer, download_dataset | |
def main():
    """
    Orchestrate the multilingual data generation process.

    Steps:
    1. Optionally re-downloads the multilingual evaluation dataset when the
       script is invoked with a "download" argument.
    2. Walks the downloaded evaluation JSON files and groups per-item results
       by model and forced/not-forced decoding mode.
    3. Delegates metric calculation and persistence (WER tables and
       language-detection confusion matrices) to calculate_and_save_results.
    """
    source_repo = "argmaxinc/whisperkit-evals-multilingual"
    source_subfolder = "WhisperKit"
    source_directory = f"{source_repo}/{source_subfolder}"

    if len(sys.argv) > 1 and sys.argv[1] == "download":
        # Remove any stale local copy before re-downloading; a missing
        # directory is the expected first-run case, not an error.
        try:
            shutil.rmtree(source_repo)
        except FileNotFoundError:
            print("Nothing to remove.")
        download_dataset(source_repo, source_repo, source_subfolder)

    # key: "<model>/<forced|not_forced>" -> accumulated raw evaluation data
    results = defaultdict(
        lambda: {
            "average_wer": [],
            "language_wer": defaultdict(list),
            "language_detection": [],
        }
    )
    confusion_matrices = {}

    for subdir, _, files in os.walk(source_directory):
        for filename in files:
            # Only per-run result JSONs; skip summary files and anything else.
            if not filename.endswith(".json") or "summary" in filename:
                continue
            file_path = os.path.join(subdir, filename)
            with open(file_path, "r") as f:
                data = json.load(f)

            subdir_components = subdir.split(os.path.sep)
            is_forced = "forced" in subdir_components
            # The model name sits one directory level higher when a
            # "forced" path component is present.
            model = subdir_components[-4] if is_forced else subdir_components[-3]
            key = f"{model}/{'forced' if is_forced else 'not_forced'}"

            for item in data["results"]:
                # Items without a reference language cannot contribute to
                # language-keyed metrics.
                if "reference_language" not in item:
                    continue
                reference_language = item["reference_language"]
                detected_language = item["predicted_language"]
                result = {
                    "reference": item["reference"],
                    "prediction": item["prediction"],
                }
                results[key]["average_wer"].append(result)
                results[key]["language_wer"][reference_language].append(result)
                results[key]["language_detection"].append(
                    (reference_language, detected_language)
                )

    calculate_and_save_results(results, confusion_matrices)
def calculate_and_save_results(results, confusion_matrices):
    """
    Calculate final multilingual metrics and save them to CSV and JSON files.

    :param results: Dictionary mapping "<model>/<forced|not_forced>" keys to
        raw multilingual evaluation data: WER sample dicts and
        (reference_language, detected_language) pairs.
    :param confusion_matrices: Dictionary populated in place with
        row-normalized language-detection confusion matrices, keyed by
        model and then by forced/not_forced mode.

    This function processes the raw multilingual data, calculates average
    metrics, creates confusion matrices for language detection, and saves:
    1. dashboard_data/multilingual_results.csv - WER per model and language.
    2. dashboard_data/multilingual_confusion_matrices.json - confusion
       matrices with their label order.
    """
    wer_data = []
    for key, data in results.items():
        model, forced = key.rsplit("/", 1)
        model = model.replace("_", "/")
        row = {
            "Model": model,
            "Forced Tokens": forced == "forced",
            "Average WER": compute_average_wer(data["average_wer"]),
        }
        for lang, wers in data["language_wer"].items():
            row[f"WER_{lang}"] = compute_average_wer(wers)
        wer_data.append(row)

        if not data["language_detection"]:
            # No (reference, detected) pairs for this key; zip(*[]) would
            # raise ValueError, so skip the confusion matrix entirely.
            continue
        true_languages, detected_languages = zip(*data["language_detection"])
        # Include detected-only languages in the label set: with labels
        # limited to reference languages, sklearn drops any sample whose
        # detected language falls outside that set, undercounting row sums
        # and inflating the normalized detection rates.
        unique_languages = sorted(set(true_languages) | set(detected_languages))
        cm = confusion_matrix(
            true_languages, detected_languages, labels=unique_languages
        )

        # Row-normalize; rows that never appear as a reference language stay
        # all-zero instead of producing a divide-by-zero.
        row_sums = cm.sum(axis=1)
        cm_normalized = np.zeros_like(cm, dtype=float)
        non_zero_rows = row_sums != 0
        cm_normalized[non_zero_rows] = (
            cm[non_zero_rows] / row_sums[non_zero_rows, np.newaxis]
        )

        confusion_matrices.setdefault(model, {})[forced] = {
            "matrix": cm_normalized.tolist(),
            "labels": unique_languages,
        }

    # Ensure the output directory exists before writing either artifact.
    os.makedirs("dashboard_data", exist_ok=True)
    df = pd.DataFrame(wer_data)
    df.to_csv("dashboard_data/multilingual_results.csv", index=False)
    with open("dashboard_data/multilingual_confusion_matrices.json", "w") as f:
        json.dump(confusion_matrices, f, indent=2)
# Script entry point: run the full multilingual results pipeline.
if __name__ == "__main__":
    main()