"""Gradio app that evaluates a Hugging Face model on a dataset slice.

The app loads a model and tokenizer, runs the model over the first N rows of
a dataset split, and reports overall accuracy, a pairwise "fairness" score
derived from per-class confusion matrices, per-class accuracy/F1, and a
Plotly figure.
"""

import importlib
import json

import gradio as gr
import requests
from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
    AutoModelForTokenClassification,
    AutoModelForQuestionAnswering,
)
from datasets import load_dataset
import datasets
import plotly.io as pio
import plotly.graph_objects as go
import plotly.express as px
from plotly.subplots import make_subplots
import pandas as pd
from sklearn.metrics import confusion_matrix
import torch
from dash import Dash, html, dcc
import numpy as np
from sklearn.metrics import accuracy_score
from sklearn.metrics import f1_score


def load_model(model_type: str, model_name_or_path: str, dataset_name: str,
               config_name: str):
    """Load a tokenizer and a task-specific model.

    For the two classification tasks the dataset is loaded so that the number
    of labels can be read from its ``train`` split and passed to
    ``from_pretrained``.

    Args:
        model_type: One of ``"text_classification"``,
            ``"token_classification"`` or ``"question_answering"``.
        model_name_or_path: Identifier handed to ``from_pretrained``.
        dataset_name: Dataset name handed to ``load_dataset``.
        config_name: Dataset configuration handed to ``load_dataset``.

    Returns:
        A ``(tokenizer, model)`` tuple.

    Raises:
        ValueError: If ``model_type`` is not one of the supported values.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)

    if model_type == "text_classification":
        dataset = load_dataset(dataset_name, config_name)
        num_labels = len(dataset["train"].features["label"].names)
        if "roberta" in model_name_or_path.lower():
            from transformers import RobertaForSequenceClassification
            model = RobertaForSequenceClassification.from_pretrained(
                model_name_or_path, num_labels=num_labels)
        else:
            model = AutoModelForSequenceClassification.from_pretrained(
                model_name_or_path, num_labels=num_labels)
    elif model_type == "token_classification":
        dataset = load_dataset(dataset_name, config_name)
        num_labels = len(
            dataset["train"].features["ner_tags"].feature.names)
        model = AutoModelForTokenClassification.from_pretrained(
            model_name_or_path, num_labels=num_labels)
    elif model_type == "question_answering":
        model = AutoModelForQuestionAnswering.from_pretrained(
            model_name_or_path)
    else:
        raise ValueError(f"Invalid model type: {model_type}")

    return tokenizer, model


def test_model(tokenizer, model, test_data: list, label_map: dict):
    """Run the model over ``test_data`` one example at a time.

    Args:
        tokenizer: Callable producing model inputs from a text string.
        model: Model whose output exposes ``logits``.
        test_data: Iterable of ``(text, context, true_label)`` triples; the
            middle element is ignored.
        label_map: Maps the arg-max logit index to a label name.

    Returns:
        A list of ``(text, true_label, pred_label)`` tuples.
    """
    results = []
    # Inference only: skip autograd bookkeeping for every forward pass.
    with torch.no_grad():
        for text, _, true_label in test_data:
            inputs = tokenizer(text, return_tensors="pt", truncation=True,
                               padding=True)
            outputs = model(**inputs)
            pred_label = label_map[int(outputs.logits.argmax(dim=-1))]
            results.append((text, true_label, pred_label))
    return results


def generate_label_map(dataset):
    """Build an ``{index: label}`` mapping from a dataset's ``label`` feature.

    Returns an empty dict when the dataset has no ``label`` feature. For a
    ``ClassLabel`` feature the mapping follows the feature's own names;
    otherwise the distinct values of the ``label`` column are enumerated in
    sorted order so the mapping is the same on every run.
    """
    if "label" not in dataset.features or dataset.features["label"] is None:
        return {}

    label_feature = dataset.features["label"]
    if isinstance(label_feature, datasets.ClassLabel):
        return dict(enumerate(label_feature.names))

    # Iterating a bare set has no guaranteed order; sort for determinism.
    return dict(enumerate(sorted(set(dataset["label"]))))


def calculate_fairness_score(results, label_map):
    """Compute overall accuracy and a pairwise confusion-matrix disparity.

    For every class, a confusion matrix is built from the samples whose true
    label is that class. The score sums, over each unordered pair of classes,
    half the absolute element-wise difference between the two matrices divided
    by the sample count of the first class in the pair.

    Args:
        results: ``(text, true_label, pred_label)`` tuples.
        label_map: ``{index: label}`` mapping; its values define the classes.

    Returns:
        An ``(accuracy, score)`` tuple.
    """
    true_labels = [r[1] for r in results]
    pred_labels = [r[2] for r in results]

    accuracy = accuracy_score(true_labels, pred_labels)

    group_names = list(label_map.values())
    group_cms = {}
    for group in group_names:
        group_preds = [pred for true, pred in zip(true_labels, pred_labels)
                       if true == group]
        group_trues = [group] * len(group_preds)
        group_cms[group] = confusion_matrix(
            group_trues, group_preds, labels=group_names)

    score = 0
    for i, group1 in enumerate(group_names):
        cm1 = group_cms[group1]
        group_size = cm1.sum()
        if group_size == 0:
            # No samples of this class in the slice: the ratio is undefined,
            # so the pair contributes nothing instead of producing nan/inf.
            continue
        for group2 in group_names[i + 1:]:
            diff = np.abs(cm1 - group_cms[group2])
            score += (diff.sum() / 2) / group_size

    return accuracy, score


def _metric_label_order(label_map):
    """Return label names in the order per-class metrics are reported."""
    return sorted(label_map.values())


def calculate_per_class_metrics(true_labels, pred_labels, label_map,
                                metric='accuracy'):
    """Compute one metric value per class.

    Args:
        true_labels: Ground-truth label names.
        pred_labels: Predicted label names, aligned with ``true_labels``.
        label_map: ``{index: label}`` mapping; its values define the classes.
        metric: ``'accuracy'`` or ``'f1'``.

    Returns:
        A list of metric values ordered by the *sorted* label names (not by
        ``label_map`` insertion order). Per-class accuracy is ``nan`` for a
        class with no true samples.

    Raises:
        ValueError: If ``metric`` is not ``'accuracy'`` or ``'f1'``.
    """
    unique_labels = _metric_label_order(label_map)

    if metric == 'accuracy':
        metrics = []
        for label in unique_labels:
            preds_for_label = [pred for true, pred
                               in zip(true_labels, pred_labels)
                               if true == label]
            if not preds_for_label:
                # Accuracy over zero samples is undefined.
                metrics.append(float("nan"))
                continue
            trues_for_label = [label] * len(preds_for_label)
            metrics.append(accuracy_score(trues_for_label, preds_for_label))
        return metrics

    if metric == 'f1':
        f1_scores = f1_score(true_labels, pred_labels, labels=unique_labels,
                             average=None)
        return f1_scores.tolist()

    raise ValueError(f"Invalid metric: {metric}")


# visualization_type -> (metric name, figure title, y-axis title)
_PER_CLASS_CHARTS = {
    "per_class_accuracy": ("accuracy", "Per-Class Accuracy", "Accuracy"),
    "per_class_f1": ("f1", "Per-Class F1-Score", "F1-Score"),
}


def _per_class_bar_figure(labels, values, title, yaxis_title):
    """Build a bar chart with one coloured bar per class."""
    colors = px.colors.qualitative.Plotly
    fig = go.Figure()
    for i, (label, value) in enumerate(zip(labels, values)):
        fig.add_trace(go.Bar(
            x=[label],
            y=[value],
            name=label,
            marker_color=colors[i % len(colors)]
        ))
    fig.update_layout(title=title, xaxis_title='Class',
                      yaxis_title=yaxis_title)
    return fig


def generate_visualization(visualization_type, results, label_map):
    """Return the Plotly figure for the requested visualization.

    Args:
        visualization_type: ``"confusion_matrix"``, ``"per_class_accuracy"``
            or ``"per_class_f1"``.
        results: ``(text, true_label, pred_label)`` tuples.
        label_map: ``{index: label}`` mapping.

    Raises:
        ValueError: If ``visualization_type`` is not recognised.
    """
    if visualization_type == "confusion_matrix":
        return generate_report_card(results, label_map)["fig"]

    if visualization_type not in _PER_CLASS_CHARTS:
        raise ValueError(f"Invalid visualization type: {visualization_type}")

    true_labels = [r[1] for r in results]
    pred_labels = [r[2] for r in results]
    metric, title, yaxis_title = _PER_CLASS_CHARTS[visualization_type]
    values = calculate_per_class_metrics(
        true_labels, pred_labels, label_map, metric=metric)
    # Pair values with labels in the same (sorted) order the metrics use;
    # label_map insertion order can differ and would mislabel the bars.
    return _per_class_bar_figure(
        _metric_label_order(label_map), values, title, yaxis_title)


def generate_report_card(results, label_map):
    """Assemble the confusion-matrix figure and summary metrics.

    Args:
        results: ``(text, true_label, pred_label)`` tuples.
        label_map: ``{index: label}`` mapping.

    Returns:
        A dict with keys ``"fig"`` (confusion-matrix heatmap), ``"accuracy"``
        (fraction of correct predictions), ``"fairness_score"`` (the
        ``(accuracy, score)`` tuple from ``calculate_fairness_score``),
        ``"per_class_accuracy"`` and ``"per_class_f1"`` (lists ordered by
        sorted label name).
    """
    true_labels = [r[1] for r in results]
    pred_labels = [r[2] for r in results]
    label_names = list(label_map.values())

    cm = confusion_matrix(true_labels, pred_labels, labels=label_names)

    fig = make_subplots(rows=1, cols=1)
    fig.add_trace(go.Heatmap(
        z=cm,
        x=label_names,
        y=label_names,
        colorscale='RdYlGn',
        colorbar=dict(title='# of Samples')
    ))
    fig.update_layout(
        height=500,
        width=600,
        title='Confusion Matrix',
        xaxis=dict(title='Predicted Labels'),
        yaxis=dict(title='True Labels', autorange='reversed')
    )

    # Report the fraction of correct predictions; normalize=False would give
    # the raw count of correct samples, which is not an accuracy.
    accuracy = accuracy_score(true_labels, pred_labels)
    fairness_score = calculate_fairness_score(results, label_map)
    per_class_accuracy = calculate_per_class_metrics(
        true_labels, pred_labels, label_map, metric='accuracy')
    per_class_f1 = calculate_per_class_metrics(
        true_labels, pred_labels, label_map, metric='f1')

    return {
        "fig": fig,
        "accuracy": accuracy,
        "fairness_score": fairness_score,
        "per_class_accuracy": per_class_accuracy,
        "per_class_f1": per_class_f1
    }


def app(model_type: str, model_name_or_path: str, dataset_name: str,
        config_name: str, dataset_split: str, num_samples: int,
        visualization_type: str):
    """Gradio callback: evaluate a model and return a summary plus a figure.

    Args:
        model_type: Task type passed to ``load_model``.
        model_name_or_path: Model identifier passed to ``load_model``.
        dataset_name: Dataset name passed to ``load_dataset``.
        config_name: Dataset configuration passed to ``load_dataset``.
        dataset_split: Split name; only its first ``num_samples`` rows are
            loaded.
        num_samples: Number of rows to evaluate (cast to ``int``).
        visualization_type: Figure type passed to ``generate_visualization``.

    Returns:
        A ``(summary_text, figure)`` tuple.
    """
    tokenizer, model = load_model(
        model_type, model_name_or_path, dataset_name, config_name)

    # The Number widget may deliver a float; the slice syntax needs an int.
    num_samples = int(num_samples)
    dataset = load_dataset(
        dataset_name, config_name, split=f"{dataset_split}[:{num_samples}]")

    # tweet_eval stores its input under "text"; everything else is read
    # from "sentence".
    text_column = "text" if dataset_name == "tweet_eval" else "sentence"
    label_names = dataset.features["label"].names
    test_data = [(item[text_column], None, label_names[item["label"]])
                 for item in dataset]

    label_map = generate_label_map(dataset)
    results = test_model(tokenizer, model, test_data, label_map)

    report_card = generate_report_card(results, label_map)
    visualization = generate_visualization(
        visualization_type, results, label_map)

    # Per-class metric lists are ordered by sorted label name, so the labels
    # must be taken in that same order.
    per_class_metrics_str = "\n".join(
        f"{label}: Acc {acc:.2f}, F1 {f1:.2f}"
        for label, acc, f1 in zip(
            _metric_label_order(label_map),
            report_card['per_class_accuracy'],
            report_card['per_class_f1']))

    summary = (
        f"Accuracy: {report_card['accuracy']:.2f}, "
        f"Fairness Score: {report_card['fairness_score'][1]:.2f}\n\n"
        f"Per-Class Metrics:\n{per_class_metrics_str}")
    return summary, visualization


interface = gr.Interface(
    fn=app,
    inputs=[
        gr.inputs.Radio(
            ["text_classification", "token_classification",
             "question_answering"],
            label="Model Type", default="text_classification"),
        gr.inputs.Textbox(
            lines=1, label="Model Name or Path",
            placeholder="ex: distilbert-base-uncased-finetuned-sst-2-english",
            default="distilbert-base-uncased-finetuned-sst-2-english"),
        gr.inputs.Textbox(lines=1, label="Dataset Name",
                          placeholder="ex: glue", default="glue"),
        gr.inputs.Textbox(lines=1, label="Config Name",
                          placeholder="ex: sst2", default="cola"),
        gr.inputs.Dropdown(
            choices=["train", "validation", "test"],
            label="Dataset Split", default="validation"),
        gr.inputs.Number(default=100, label="Number of Samples"),
        gr.inputs.Dropdown(
            choices=["confusion_matrix", "per_class_accuracy",
                     "per_class_f1"],
            label="Visualization Type", default="confusion_matrix"
        ),
    ],
    outputs=[
        gr.outputs.Textbox(label="Fairness and Bias Metrics"),
        gr.Plot(label="Graph")
    ],
    title="Fairness and Bias Testing",
    description="Enter a model and dataset to test for fairness and bias.",
)

# Module-level default label map. app() builds its own map from the dataset
# and does not read this one.
label_map = {0: "negative", 1: "positive"}

if __name__ == "__main__":
    interface.launch()