import colorsys
import json
import os
import random
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass, make_dataclass
from datetime import datetime
from io import BytesIO

import aiohttp
import evaluate
import numpy as np
import pandas as pd
import plotly.graph_objects as go
from huggingface_hub import hf_hub_download, list_repo_files
from pydub import AudioSegment

from constants import WHISPER_OPEN_AI_LINK

# Load the Word Error Rate (WER) metric from the evaluate library
wer_metric = evaluate.load("wer")

def compute_average_wer(results):
    """
    Compute the average Word Error Rate (WER) for a list of transcription results.

    :param results: List of dictionaries, each containing 'reference' and 'prediction' keys
    :return: Average WER as a percentage, rounded to 2 decimal places

    This function computes the corpus-level WER over all reference-prediction pairs.
    If no predictions are provided, it returns 100% WER.
    """
    references = [result["reference"] for result in results]
    predictions = [result["prediction"] for result in results]
    if len(predictions) == 0:
        return 100.0
    return round(
        wer_metric.compute(references=references, predictions=predictions) * 100.0,
        2,
    )

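# Usage sketch (hypothetical values): one substitution over four reference words
# yields a corpus-level WER of 25%.
# compute_average_wer([
#     {"reference": "hello world", "prediction": "hello word"},
#     {"reference": "good morning", "prediction": "good morning"},
# ])  # -> 25.0
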
def read_json_line_by_line(file_path):
    """
    Read a JSON file line by line, parsing each line as a separate JSON object.

    :param file_path: Path to the JSON file
    :return: List of parsed JSON objects

    This function is useful for reading large JSON files that contain one JSON object
    per line. It handles JSON parsing errors gracefully, skipping invalid lines.
    """
    data = []
    with open(file_path, "r") as f:
        for line in f:
            try:
                item = json.loads(line.strip())
                data.append(item)
            except json.JSONDecodeError:
                print(f"Skipping invalid JSON in {file_path}: {line}")
    return data

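# Expected input: one JSON object per line (JSON Lines). Hypothetical example file:
#   {"reference": "hello world", "prediction": "hello word"}
#   {"reference": "good morning", "prediction": "good morning"}
# read_json_line_by_line("results.jsonl") would return both objects as dicts.
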
def group_wer(group):
    """
    Calculate the Word Error Rate (WER) for a group of transcriptions.

    :param group: DataFrame group containing 'normalized_reference' and 'normalized_prediction' columns
    :return: Average WER for the group

    This function is typically used with DataFrame groupby operations to calculate
    WER for specific groups of transcriptions.
    """
    return compute_average_wer(
        group[["normalized_reference", "normalized_prediction"]]
        .rename(
            columns={
                "normalized_reference": "reference",
                "normalized_prediction": "prediction",
            }
        )
        .to_dict("records")
    )

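# Usage sketch (assumes df has 'model', 'normalized_reference' and
# 'normalized_prediction' columns): per-model WER via groupby.
# wer_by_model = df.groupby("model").apply(group_wer)
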
def load_multilingual_results(csv_file):
    """
    Load multilingual results into a pandas DataFrame.

    :param csv_file: Multilingual results data (parsed JSON records) to normalize
    :return: DataFrame with the loaded results, or None if the file is not found

    This function flattens the results into a DataFrame with pandas.json_normalize,
    handling potential FileNotFoundError exceptions.
    """
    try:
        df = pd.json_normalize(csv_file)
        return df
    except FileNotFoundError:
        return None

def download_dataset(repo_id, local_dir, remote_dir, path_includes=""):
    """
    Download benchmark result files from a specified Hugging Face repository to a local directory.

    :param repo_id: ID of the Hugging Face repository
    :param local_dir: Local directory where downloaded files will be saved
    :param remote_dir: Remote directory within the repository to download from
    :param path_includes: Optional substring that downloaded file paths must contain

    This function uses the Hugging Face Hub API to list and download files from a
    specific directory in a repository. It forces the download to ensure up-to-date files.
    """
    files = list_repo_files(repo_id, repo_type="dataset")
    directory_files = [
        file for file in files if file.startswith(remote_dir) and path_includes in file
    ]
    with ThreadPoolExecutor() as executor:
        executor.map(
            lambda file: hf_hub_download(
                repo_id=repo_id,
                repo_type="dataset",
                filename=file,
                local_dir=local_dir,
                force_download=True,
            ),
            directory_files,
        )

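# Usage sketch (repo ID and directory names are illustrative, not the project's actual values):
# download_dataset(
#     repo_id="argmaxinc/whisperkit-evals",
#     local_dir="./benchmark_data",
#     remote_dir="benchmark_data",
#     path_includes="2024",
# )
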
def process_file(file_path):
    """
    Process a file containing JSON objects delimited by new lines.

    :param file_path: Path to the file to be processed
    :return: List of dictionaries, each representing a parsed JSON object

    This function reads the file line by line, parsing each line as a JSON object.
    It handles potential JSON decoding errors, printing error messages for invalid lines.
    """
    data = []
    with open(file_path, "r") as file:
        for line in file:
            line = line.strip()
            if not line:
                continue
            try:
                json_obj = json.loads(line)
                data.append(json_obj)
            except json.JSONDecodeError as e:
                print(f"Error decoding JSON in line: {line}")
                print(f"Error message: {str(e)}")
    return data

def dir_to_json(root_dir, output_file):
    """
    Convert a directory of benchmark result files to a single JSON file.

    :param root_dir: Root directory containing the benchmark result files
    :param output_file: Output file where the JSON data will be saved

    This function walks through the directory structure, processes each file,
    and writes the combined data to a single JSON file. It extracts metadata
    from the file path and includes it in the JSON output.
    """
    with open(output_file, "w") as outfile:
        for subdir, _, files in os.walk(root_dir):
            for file in files:
                file_path = os.path.join(subdir, file)
                # Ignore .DS_Store and summary files
                if file_path.endswith(".DS_Store") or "summary" in file_path:
                    continue
                parts = file_path.split(os.sep)
                model_version = parts[2]
                device_name = parts[3].replace("_", " ")
                os_type_version = parts[4]
                dataset_name = parts[5]
                timestamp_commit = parts[6].replace(".json", "")
                timestamp, commit_hash, commit_timestamp = timestamp_commit.split("_")
                data_list = process_file(file_path)
                for data in data_list:
                    original_entry = {
                        "model": model_version.replace("_", "/"),
                        "device": device_name,
                        "os": os_type_version.replace("_", " "),
                        "wer": data["wer"],
                        "dataset_name": dataset_name,
                        "reference_transcription": data["reference_transcription"],
                        "prediction_transcription": data["prediction_transcription"],
                        "difference_transcription": data["difference_transcription"],
                        "audio_file_url": data["audio_file_url"],
                        "timestamp": timestamp.replace("-", ":").replace(":", "-", 2),
                        "commit_hash": commit_hash,
                        "commit_timestamp": commit_timestamp,
                    }
                    outfile.write(json.dumps(original_entry) + "\n")

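# Directory layout assumed by the parts[...] indexing above (hypothetical example,
# with root_dir two path components deep, e.g. "./benchmark_data"):
#   ./benchmark_data/<model>/<device>/<os>/<dataset>/<timestamp>_<commit_hash>_<commit_timestamp>.json
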
async def download_audio_to_ndarray(url):
    """
    Downloads an audio file from a URL and converts it to a NumPy array.

    :param url: The URL of the audio file to download
    :return: A tuple containing the sample rate and audio data as a NumPy array

    This asynchronous function uses aiohttp to download the audio file,
    converts it to an AudioSegment, and then to a NumPy array. It handles
    both mono and stereo audio files.
    """
    async with aiohttp.ClientSession() as session:
        async with session.get(url) as response:
            if response.status == 200:
                audio_bytes = BytesIO(await response.read())
                audio = AudioSegment.from_file(audio_bytes, format="mp3")
                audio_data = np.array(audio.get_array_of_samples())
                if audio.channels == 2:
                    audio_data = audio_data.reshape((-1, 2))
                return audio.frame_rate, audio_data
            else:
                return None, None

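# Usage sketch from synchronous code (hypothetical URL):
# import asyncio
# sample_rate, audio_data = asyncio.run(
#     download_audio_to_ndarray("https://example.com/clip.mp3")
# )
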
async def play_audio(url):
    """
    Wrapper function for Gradio to play audio from a URL.

    :param url: The URL of the audio file to play
    :return: A tuple of sample rate and audio data, or an error message

    This function uses download_audio_to_ndarray to get the audio data
    and returns it in a format suitable for Gradio's audio player.
    """
    sample_rate, audio_data = await download_audio_to_ndarray(url)
    if audio_data is None:
        return "Error downloading the file"
    else:
        return sample_rate, audio_data

def get_filter_cond(df, model, device, os, dataset, timestamp=None):
    """
    Creates a filter condition for a DataFrame based on specified parameters.

    :param df: DataFrame containing the transcription data
    :param model: String representing the model name
    :param device: String representing the device name
    :param os: String representing the OS name
    :param dataset: String representing the dataset name
    :param timestamp: Optional timestamp for filtering (default: None)
    :return: A boolean mask for filtering the DataFrame

    This function constructs a boolean condition for filtering
    the DataFrame based on the provided parameters.
    """
    filter_cond = (
        (df["model"] == model)
        & (df["device"] == device)
        & (df["os"] == os)
        & (df["dataset_name"] == dataset)
    )
    return filter_cond & (df["timestamp"] == timestamp) if timestamp else filter_cond

def get_filtered_transcript(df, model, device, os, dataset, timestamp):
    """
    Retrieves filtered transcription data from a DataFrame.

    :param df: DataFrame containing the transcription data
    :param model: String representing the model name
    :param device: String representing the device name
    :param os: String representing the OS name
    :param dataset: String representing the dataset name
    :param timestamp: String representing the timestamp
    :return: A filtered DataFrame with transcription data

    This function applies a filter to the input DataFrame and returns
    relevant columns for transcription analysis.
    """
    filter_cond = get_filter_cond(df, model, device, os, dataset, timestamp)
    df = df[filter_cond][
        [
            "reference_transcription",
            "prediction_transcription",
            "difference_transcription",
            "audio_file_url",
        ]
    ]
    return df

def get_filtered_timestamps(df, model, device, os, dataset):
    """
    Retrieves unique timestamps for a specific model, device, OS, and dataset combination.

    :param df: DataFrame containing the transcription data
    :param model: String representing the model name
    :param device: String representing the device name
    :param os: String representing the OS name
    :param dataset: String representing the dataset name
    :return: A filtered DataFrame containing unique timestamps

    This function is useful for getting a list of available timestamps
    for a specific configuration, which can be used for further analysis or UI elements.
    """
    filter_cond = get_filter_cond(df, model, device, os, dataset)
    df = df[filter_cond][["timestamp"]].drop_duplicates()
    return df

def make_model_name_clickable_link(model):
    """
    Creates an HTML link to the Hugging Face model page.

    :param model: String representing the model name
    :return: An HTML string containing a clickable link to the model page

    This function generates a formatted HTML link that can be used in
    web interfaces to provide direct access to the model's page on Hugging Face.
    """
    return f"""<a style="color: #3B82F6; text-decoration: underline; text-decoration-style: dotted;" href="https://huggingface.co/argmaxinc/whisperkit-coreml/tree/main/{model.replace('/', '_')}" target="_blank">{model}</a>"""

def make_dataset_wer_clickable_link(row, dataset):
    """
    Creates a clickable link for the WER value of a dataset.

    :param row: Row containing the dataset WER value
    :param dataset: String representing the dataset name
    :return: An HTML string containing a clickable link to the dataset's WER details

    This function generates a formatted HTML link that can be used in
    web interfaces to provide access to detailed WER information for a specific dataset.
    """
    dataset_column = f"{dataset}"
    href = WHISPER_OPEN_AI_LINK.format(
        row["Model"].replace("/", "_"),
        dataset,
    )
    return f'<a style="color: #3B82F6; text-decoration: underline; text-decoration-style: dotted;" href="{href}">{row[dataset_column]}</a>'

def make_timestamp_clickable_link(model, dataset, timestamp):
    """
    Creates a clickable link for a timestamp.

    :param model: String representing the model name
    :param dataset: String representing the dataset name
    :param timestamp: Timestamp to be displayed and used in the link
    :return: An HTML string containing a clickable div for the timestamp

    This function generates a formatted HTML div that can be used as a clickable
    element in web interfaces, typically for displaying and interacting with specific timestamps.
    """
    elem_id = (
        f"{dataset}-{model}-{timestamp}".replace(" ", "_")
        .replace('"', "")
        .replace("'", "")
        .replace(",", "")
    )
    onclick = f"onclick=\"document.getElementById('{elem_id}').click();\""
    return f'<div style="color: #3B82F6; text-decoration: underline; text-decoration-style: dotted;" {onclick} href="#">{timestamp}</div>'

def make_multilingual_model_clickable_link(model):
    """
    Creates a clickable link for a multilingual model name.

    :param model: String representing the model name
    :return: An HTML string containing a clickable div for the model name

    This function generates a formatted HTML div that can be used as a clickable
    element in web interfaces, typically for displaying and interacting with multilingual model names.
    """
    elem_id = (
        f"{model}".replace(" ", "_").replace('"', "").replace("'", "").replace(",", "")
    )
    onclick = f"onclick=\"document.getElementById('{elem_id}').click();\""
    return f'<div style="color: #3B82F6; text-decoration: underline; text-decoration-style: dotted;" {onclick} href="#">{model}</div>'

def plot_metric(
    df, y_axis_col, y_axis_title, fig_title, filter_input=None, exclude_input=None
):
    """
    Plots a metric for each model-device-OS group in a DataFrame.

    :param df: DataFrame containing the benchmark data
    :param y_axis_col: DataFrame column to use as the y-axis
    :param y_axis_title: Display name for the y-axis
    :param fig_title: Display title for the figure
    :param filter_input: Optional semicolon-separated string to filter the model-device-OS combinations
    :param exclude_input: Optional semicolon-separated string to exclude model-device-OS combinations
    :return: A Plotly figure object
    """
    grouped = df.groupby(["model", "device", "os"])
    sorted_groups = [group.sort_values("commit_timestamp") for _, group in grouped]
    if filter_input:
        filters = [f.strip().lower() for f in filter_input.split(";")]
        sorted_groups = [
            group
            for group in sorted_groups
            if any(
                f
                in f"{group['model'].iloc[0]}-{group['device'].iloc[0]}-{group['os'].iloc[0]}".lower()
                for f in filters
            )
        ]
    if exclude_input:
        excludes = [e.strip().lower() for e in exclude_input.split(";")]
        sorted_groups = [
            group
            for group in sorted_groups
            if not any(
                e
                in f"{group['model'].iloc[0]}-{group['device'].iloc[0]}-{group['os'].iloc[0]}".lower()
                for e in excludes
            )
        ]
    base_colors = ["#4542f4", "#0e0c06", "#ccf0a7", "#ff7f4e", "#ffd15a"]
    num_colors = len(sorted_groups)
    random_colors = generate_random_colors(base_colors, num_colors)
    fig = go.Figure()
    for i, group in enumerate(sorted_groups):
        model_device_os = (
            f"{group['model'].iloc[0]}-{group['device'].iloc[0]}-{group['os'].iloc[0]}"
        )
        fig.add_trace(
            go.Scatter(
                x=group["commit_timestamp"].apply(
                    lambda x: datetime.strptime(x, "%Y-%m-%dT%H%M%S").strftime(
                        "%Y-%m-%d %H:%M:%S"
                    )
                ),
                y=group[y_axis_col],
                mode="lines+markers",
                name=model_device_os,
                line=dict(color=random_colors[i % len(random_colors)]),
                marker=dict(color=random_colors[i % len(random_colors)]),
                hovertemplate=(
                    f"<b>{model_device_os}</b><br>"
                    "Timestamp: %{x}<br>"
                    f"{y_axis_title}: %{{y:.2f}}<br>"
                    "<extra></extra>"
                ),
            )
        )
    fig.update_layout(
        title=fig_title,
        xaxis_title="Commit Timestamp",
        yaxis_title=y_axis_title,
        legend_title="Model-Device-OS",
        width=1100,
        height=600,
        plot_bgcolor="rgb(250,249,244)",
    )
    return fig

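# Usage sketch (assumes df has 'model', 'device', 'os', 'commit_timestamp' and a
# 'speed' column; the filter string is illustrative):
# fig = plot_metric(df, "speed", "Speed", "Speed over time", filter_input="tiny; iPhone")
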
def fields(raw_class):
    """
    Returns the fields of a dataclass.

    :param raw_class: The dataclass to inspect
    :return: List of fields in the dataclass

    This utility function extracts and returns all the fields defined in a dataclass,
    excluding special methods and attributes.
    """
    return [
        v for k, v in raw_class.__dict__.items() if k[:2] != "__" and k[-2:] != "__"
    ]

def get_os_name_and_version(os_string):
    """
    Extracts the OS name and major version from a string.

    :param os_string: String representing the OS name and version
    :return: Formatted string with OS name and major version

    This function splits the input string into OS name and version,
    then returns a formatted string with just the major version number.
    """
    os_name, os_version = os_string.split()
    os_version = os_version.split(".")[0]
    return f"{os_name} {os_version}"

def create_initial_quality_column_dict():
    """
    Creates the initial column dictionary for the quality table.

    :return: A list of column dictionaries

    This function defines the basic structure of the quality table,
    including columns for model, average WER, and QoI (Quality of Inference).
    """
    return [
        [
            "model",
            ColumnContent,
            ColumnContent("Model", "html", True, never_hidden=True),
        ],
        ["average_wer", ColumnContent, ColumnContent("Average WER", "html", True)],
        ["qoi", ColumnContent, ColumnContent("QoI", "html", True)],
    ]

def calculate_parity(m2_ultra_wer, row):
    """
    Calculates the WER parity between M2 Ultra and the current model.

    :param m2_ultra_wer: Series of average WER values for M2 Ultra, indexed by model name
    :param row: Current row being processed
    :return: WER difference between M2 Ultra and the current model, or None if not applicable

    This function computes the difference in WER (in percentage points) between the
    M2 Ultra results and the current row, providing a measure of relative performance.
    """
    if row["Model"] in m2_ultra_wer.index:
        return round(m2_ultra_wer[row["Model"]] - row["Average WER"], 2)
    return None

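# Usage sketch (hypothetical values, with m2_ultra_wer as a pandas Series indexed by model name):
# m2_ultra_wer = pd.Series({"whisper-tiny": 7.5})
# calculate_parity(m2_ultra_wer, {"Model": "whisper-tiny", "Average WER": 7.9})  # -> -0.4
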
def create_initial_performance_column_dict():
    """
    Creates the initial column dictionary for the performance table.

    :return: A list of column dictionaries

    This function defines the basic structure of the performance table,
    including columns for model, device, OS, parity, average WER, QoI, speed, and tokens per second.
    """
    return [
        [
            "model",
            ColumnContent,
            ColumnContent("Model", "html", True, never_hidden=True),
        ],
        [
            "device",
            ColumnContent,
            ColumnContent("Device", "html", True, never_hidden=True),
        ],
        ["os", ColumnContent, ColumnContent("OS", "html", True, never_hidden=True)],
        ["parity", ColumnContent, ColumnContent("Parity %", "html", False)],
        ["average_wer", ColumnContent, ColumnContent("Average WER", "html", False)],
        ["qoi", ColumnContent, ColumnContent("QoI", "html", False)],
        ["speed", ColumnContent, ColumnContent("Speed", "html", False)],
        ["toks", ColumnContent, ColumnContent("Tok / s", "html", False)],
    ]

def add_datasets_to_quality_columns(column_dict, datasets):
    """
    Adds dataset-specific columns to the quality table column dictionary.

    :param column_dict: The initial column dictionary
    :param datasets: List of dataset names to add
    :return: A dictionary containing the updated column dictionary and related metadata

    This function extends the quality table structure with columns for each dataset,
    and creates a dataclass to represent the table structure. It also generates
    metadata about the columns for use in the UI.
    """
    updated_column_dict = column_dict.copy()
    for dataset in datasets:
        field_name = dataset.replace("-", "")
        updated_column_dict.append(
            [field_name, ColumnContent, ColumnContent(dataset, "html", True)]
        )
    AutoEvalColumn = make_dataclass("AutoEvalColumn", updated_column_dict, frozen=True)
    COLS = [c.name for c in fields(AutoEvalColumn) if not c.hidden]
    TYPES = [c.type for c in fields(AutoEvalColumn) if not c.hidden]
    ALWAYS_HERE_COLS = [c.name for c in fields(AutoEvalColumn) if c.never_hidden]
    TOGGLE_COLS = [c.name for c in fields(AutoEvalColumn) if not c.never_hidden]
    SELECTED_COLS = [
        c.name
        for c in fields(AutoEvalColumn)
        if not c.never_hidden and c.displayed_by_default
    ]
    return {
        "column_dict": updated_column_dict,
        "AutoEvalColumn": AutoEvalColumn,
        "COLS": COLS,
        "TYPES": TYPES,
        "ALWAYS_HERE_COLS": ALWAYS_HERE_COLS,
        "TOGGLE_COLS": TOGGLE_COLS,
        "SELECTED_COLS": SELECTED_COLS,
    }

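# Usage sketch (dataset names are illustrative):
# quality_meta = add_datasets_to_quality_columns(
#     create_initial_quality_column_dict(), ["librispeech", "earnings22-10mins"]
# )
# quality_meta["COLS"]  # -> ['Model', 'Average WER', 'QoI', 'librispeech', 'earnings22-10mins']
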
def add_datasets_to_performance_columns(column_dict, datasets):
    """
    Adds dataset-specific columns to the performance table column dictionary.

    :param column_dict: The initial column dictionary
    :param datasets: List of dataset names to add
    :return: A dictionary containing the updated column dictionary and related metadata

    This function extends the performance table structure with columns for each dataset,
    adding both speed and tokens per second metrics. It also creates a dataclass to
    represent the table structure and generates metadata about the columns for use in the UI.
    """
    updated_column_dict = column_dict.copy()
    for dataset in datasets:
        field_name = dataset.replace("-", "")
        updated_column_dict.append(
            [
                f"{field_name}_speed",
                ColumnContent,
                ColumnContent(
                    f"{'Short-Form' if dataset == 'librispeech-10mins' else 'Long-Form'} Speed",
                    "html",
                    True,
                ),
            ]
        )
        updated_column_dict.append(
            [
                f"{field_name}_toks",
                ColumnContent,
                ColumnContent(
                    f"{'Short-Form' if dataset == 'librispeech-10mins' else 'Long-Form'} Tok/s",
                    "html",
                    True,
                ),
            ]
        )
    AutoEvalColumn = make_dataclass("AutoEvalColumn", updated_column_dict, frozen=True)
    COLS = [c.name for c in fields(AutoEvalColumn) if not c.hidden]
    TYPES = [c.type for c in fields(AutoEvalColumn) if not c.hidden]
    ALWAYS_HERE_COLS = [c.name for c in fields(AutoEvalColumn) if c.never_hidden]
    TOGGLE_COLS = [c.name for c in fields(AutoEvalColumn) if not c.never_hidden]
    SELECTED_COLS = [
        c.name
        for c in fields(AutoEvalColumn)
        if not c.never_hidden and c.displayed_by_default
    ]
    return {
        "column_dict": updated_column_dict,
        "AutoEvalColumn": AutoEvalColumn,
        "COLS": COLS,
        "TYPES": TYPES,
        "ALWAYS_HERE_COLS": ALWAYS_HERE_COLS,
        "TOGGLE_COLS": TOGGLE_COLS,
        "SELECTED_COLS": SELECTED_COLS,
    }

def create_confusion_matrix_plot(matrix, labels, is_forced):
    """
    Creates a confusion matrix plot for language detection.

    :param matrix: 2D numpy array representing the confusion matrix
    :param labels: List of language labels
    :param is_forced: Boolean indicating whether language hint was used
    :return: A Plotly figure object representing the confusion matrix

    This function generates a heatmap visualization of the confusion matrix
    for language detection, with customized layout and hover information.
    """
    fig = go.Figure(
        data=go.Heatmap(
            z=matrix,
            x=labels,
            y=labels,
            colorscale=[
                [0, "rgb(250,249,244)"],
                [0.5, "rgb(69,66,244)"],
                [1.0, "rgb(14,12,6)"],
            ],
            hoverongaps=False,
            hovertemplate="True: %{y}<br>Predicted: %{x}<br>Value: %{z}<extra></extra>",
        )
    )
    fig.update_layout(
        title=f'Language Detection Confusion Matrix with {"Language Hint" if is_forced else "Language Prediction by Model"}',
        xaxis_title="Predicted Language",
        yaxis_title="True Language",
        xaxis=dict(tickangle=-45),
        width=600,
        height=600,
        margin=dict(l=50, r=50, t=50, b=50),
    )
    return fig

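# Usage sketch (hypothetical 2x2 matrix for two languages):
# fig = create_confusion_matrix_plot(
#     np.array([[10, 2], [1, 12]]), ["en", "de"], is_forced=False
# )
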
def hex_to_rgb(hex_color):
    """
    Converts a hexadecimal color code to RGB values.

    :param hex_color: String representing a color in hexadecimal format
    :return: Tuple of three integers representing RGB values

    This function takes a hex color code and returns the corresponding
    RGB values as a tuple of integers.
    """
    hex_color = hex_color.lstrip("#")
    return tuple(int(hex_color[i : i + 2], 16) for i in (0, 2, 4))

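# Example, using one of the base plot colors above: hex_to_rgb("#4542f4") -> (69, 66, 244)
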
def rgb_to_hex(rgb):
    """
    Converts RGB values to a hexadecimal color code.

    :param rgb: Tuple of three integers representing RGB values
    :return: String representing the color in hexadecimal format

    This function takes RGB values as a tuple and returns the corresponding
    hex color code as a string.
    """
    return "#{:02x}{:02x}{:02x}".format(*rgb)

def interpolate_colors(color1, color2, factor):
    """
    Interpolates between two colors in HSV space.

    :param color1: First color in hexadecimal format
    :param color2: Second color in hexadecimal format
    :param factor: Float between 0 and 1, representing the interpolation factor
    :return: Interpolated color in hexadecimal format

    This function performs color interpolation in HSV color space, which can
    produce more visually pleasing results than simple RGB interpolation.
    """
    rgb1 = hex_to_rgb(color1)
    rgb2 = hex_to_rgb(color2)
    hsv1 = colorsys.rgb_to_hsv(*[x / 255.0 for x in rgb1])
    hsv2 = colorsys.rgb_to_hsv(*[x / 255.0 for x in rgb2])
    h = (hsv1[0] + factor * (hsv2[0] - hsv1[0])) % 1.0
    s = hsv1[1] + factor * (hsv2[1] - hsv1[1])
    v = hsv1[2] + factor * (hsv2[2] - hsv1[2])
    rgb = colorsys.hsv_to_rgb(h, s, v)
    return rgb_to_hex(tuple(int(x * 255) for x in rgb))

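# Example (midpoint between black and white, chosen for illustration):
# interpolate_colors("#000000", "#ffffff", 0.5)  # -> "#7f7f7f"
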
def color_distance(color1, color2):
    """
    Calculates the Euclidean distance between two colors in RGB space.

    :param color1: First color in hexadecimal format
    :param color2: Second color in hexadecimal format
    :return: Float representing the distance between the two colors

    This function computes the Euclidean distance between two colors in RGB space,
    which can be used as a measure of color similarity.
    """
    rgb1 = hex_to_rgb(color1)
    rgb2 = hex_to_rgb(color2)
    return sum((a - b) ** 2 for a, b in zip(rgb1, rgb2)) ** 0.5

def generate_random_colors(base_colors, num_colors, min_distance=30):
    """
    Generates a list of random colors based on a set of base colors.

    :param base_colors: List of base colors in hexadecimal format
    :param num_colors: Number of colors to generate
    :param min_distance: Minimum distance between generated colors (default: 30)
    :return: List of generated colors in hexadecimal format

    This function creates a list of random colors by interpolating between
    the provided base colors. It attempts to maintain a minimum distance
    between colors to ensure visual distinctiveness.
    """
    generated_colors = []
    attempts = 0
    max_attempts = 1000
    while len(generated_colors) < num_colors and attempts < max_attempts:
        color1, color2 = random.sample(base_colors, 2)
        factor = random.random()
        new_color = interpolate_colors(color1, color2, factor)
        if all(color_distance(new_color, c) >= min_distance for c in generated_colors):
            generated_colors.append(new_color)
            attempts = 0
        else:
            attempts += 1
            if attempts > 100:
                if random.random() < 0.1:
                    generated_colors.append(new_color)
                    attempts = 0
    return generated_colors

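# Usage sketch (the base palette mirrors the one hard-coded in plot_metric):
# palette = generate_random_colors(
#     ["#4542f4", "#0e0c06", "#ccf0a7", "#ff7f4e", "#ffd15a"], num_colors=8
# )
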
@dataclass
class Task:
    """
    Dataclass representing a benchmark task.

    :param benchmark: String representing the benchmark name
    :param metric: String representing the metric used for evaluation
    :param col_name: String representing the column name in the results DataFrame
    """

    benchmark: str
    metric: str
    col_name: str

@dataclass
class ColumnContent:
    """
    Dataclass representing a column in the results table.

    :param name: String representing the column name
    :param type: String representing the data type of the column
    :param displayed_by_default: Boolean indicating if the column should be displayed by default
    :param hidden: Boolean indicating if the column should be hidden (default: False)
    :param never_hidden: Boolean indicating if the column should never be hidden (default: False)
    :param dummy: Boolean indicating if this is a dummy column (default: False)
    """

    name: str
    type: str
    displayed_by_default: bool
    hidden: bool = False
    never_hidden: bool = False
    dummy: bool = False

css = """ | |
@font-face { | |
font-family: 'Zwizz Regular'; | |
font-style: normal; | |
font-weight: normal; | |
src: local('Zwizz Regular'), url('static/Zwizz-Regular.woff') format('woff'); | |
} | |
@font-face { | |
font-family: 'Zwizz Medium'; | |
font-style: normal; | |
font-weight: normal; | |
src: local('Zwizz Medium'), url('static/Zwizz-Medium.woff') format('woff'); | |
} | |
@font-face { | |
font-family: 'Zwizz SemiBold'; | |
font-style: normal; | |
font-weight: normal; | |
src: local('Zwizz SemiBold'), url('static/Zwizz-SemiBold.woff') format('woff'); | |
} | |
@import url('https://fonts.googleapis.com/css2?family=Noto+Color+Emoji&display=swap'); | |
@import url('https://fonts.googleapis.com/css2?family=Sora:wght@300..400&display=swap'); | |
/* Typography Scale */
h1, .h1 {
    font-family: 'Sora', sans-serif;
    font-weight: 300;
    font-size: 2em;
    letter-spacing: -0.05em;
}
h2, .h2 {
    font-family: 'Sora', sans-serif;
    font-weight: 400;
    letter-spacing: -0.05em;
}
h3, h4, h5, .h3, .h4, .h5 {
    font-family: 'Sora', sans-serif;
    font-weight: 400;
    letter-spacing: -0.05em;
}
h6, .h6, pre, code, .monospace {
    font-family: 'IBM Plex Mono', monospace;
    font-weight: 400;
    letter-spacing: 0.01em;
}

/* Strong tag styling */
strong, b {
    font-family: 'Zwizz SemiBold', -apple-system, BlinkMacSystemFont, system-ui, sans-serif;
    letter-spacing: -0.02em;
}

/* Global Zwizz styles */
:root {
    --zwizz-spacing: -0.02em;
}

/* All Gradio elements should have Zwizz spacing */
.gradio-container * {
    letter-spacing: var(--zwizz-spacing);
    line-height: 1.7;
}

/* UI Elements */
.tab-buttons button, #models-to-add-text, .gradio-button {
    font-family: 'Sora', sans-serif;
    font-weight: 400;
    letter-spacing: -0.05em;
}

/* Specific Table Styling */
table, .table, th, td {
    font-family: 'IBM Plex Mono', 'Noto Color Emoji', sans-serif, monospace !important;
    font-weight: 400;
    letter-spacing: 0.01em;
}

/* Technical/Code Elements */
.code-block, .technical-text {
    font-family: 'IBM Plex Mono', monospace;
    font-weight: 400;
    letter-spacing: 0.01em;
}

/* Additional Elements */
#methodology-text p, #methodology-text li, .markdown-text {
    font-family: 'Zwizz Regular', -apple-system, BlinkMacSystemFont, system-ui, sans-serif;
    font-size: 16px !important;
    letter-spacing: var(--zwizz-spacing);
    line-height: 1.7;
}

/* Font weight utilities */
.zwizz-medium {
    font-family: 'Zwizz Medium', -apple-system, BlinkMacSystemFont, system-ui, sans-serif;
}
.zwizz-semibold {
    font-family: 'Zwizz SemiBold', -apple-system, BlinkMacSystemFont, system-ui, sans-serif;
}

/* Maintaining Original Layout Rules */
.gradio-container {
    max-width: 95% !important;
}

/* Table Layouts */
.large-table,
.large-table .table-wrap,
#multilingual-model-table .table-wrap,
#lookup-table .table-wrap {
    height: 35em !important;
    overflow-y: scroll !important;
}

/* SVG Container Rules */
.svg-container,
.main-svg {
    width: 100% !important;
}
.left-side-table .table-wrap {
    height: 15em !important;
    overflow-y: scroll !important;
}
#average-wer-table .table-wrap {
    height: 8em !important;
    overflow-y: scroll !important;
}
#general-wer-table .table-wrap {
    height: 35em !important;
    overflow-y: scroll !important;
}
"""