from PIL import Image
import gradio as gr
import numpy as np
from datasets import load_dataset
import os
import tempfile

# Ground-truth annotations are loaded once at startup from a gated dataset repo.
dataset = load_dataset("erceguder/histocan-test", token=os.environ["HF_TOKEN"])

# RGB colors expected in the uploaded prediction masks, one per class.
COLOR_PALETTE = {
    'others': (0, 0, 0),
    't-g1': (0, 192, 0),
    't-g2': (255, 224, 32),
    't-g3': (255, 0, 0),
    'normal-mucosa': (0, 32, 255),
}


def files_uploaded(paths):
    # Validate the upload: exactly 16 masks named test0000.png .. test0015.png.
    if len(paths) != 16:
        raise gr.Error("16 segmentation masks are needed.")

    uploaded_file_names = [paths[i].name.split('/')[-1] for i in range(16)]

    for i in range(16):
        if f"test{i:04d}.png" not in uploaded_file_names:
            raise gr.Error("Uploaded file names are not recognized.")


def evaluate(paths):
    if len(dataset["test"]) != 16:
        raise gr.Error("Could not read ground truth data from the server.")

    if paths is None:
        raise gr.Error("Upload segmentation masks first!")

    # Init dicts for accumulating image metrics and calculating per-class scores
    metrics = {}
    for class_ in COLOR_PALETTE.keys():
        metrics[class_] = {"tp": 0.0, "fp": 0.0, "tn": 0.0, "fn": 0.0}

    scores = {}
    for class_ in COLOR_PALETTE.keys():
        scores[class_] = {"recall": 0.0, "precision": 0.0, "f1": 0.0}

    # Move the uploaded masks into a temporary directory under their original names
    tmpdir = tempfile.TemporaryDirectory()
    for path in paths:
        os.rename(path.name, os.path.join(tmpdir.name, path.name.split('/')[-1]))

    for item in dataset["test"]:
        pred_path = os.path.join(tmpdir.name, item["name"])

        pred = np.array(Image.open(pred_path))
        gt = np.array(item["annotation"])

        # Ground truth is a 2D index map; predictions are RGB color-coded masks.
        assert gt.ndim == 2
        assert pred.ndim == 3 and pred.shape[-1] == 3
        assert gt.shape == pred.shape[:-1]

        # Get boolean prediction maps for all classes
        out = [(pred == color).all(axis=-1) for color in COLOR_PALETTE.values()]
        maps = np.stack(out)

        # Calculate confusion matrix entries per class
        for i, class_ in enumerate(COLOR_PALETTE.keys()):
            class_pred = maps[i]
            class_gt = (gt == i)

            tp = np.sum(class_pred[class_gt == True])
            fp = np.sum(class_pred[class_gt == False])
            tn = np.sum(np.logical_not(class_pred)[class_gt == False])
            fn = np.sum(np.logical_not(class_pred)[class_gt == True])

            # Accumulate metrics for each class over all images
            metrics[class_]['tp'] += tp
            metrics[class_]['fp'] += fp
            metrics[class_]['tn'] += tn
            metrics[class_]['fn'] += fn

    # Init mean recall, precision and F1 score
    mRecall = 0.0
    mPrecision = 0.0
    mF1 = 0.0

    # Calculate recall, precision and F1 scores for each class
    for class_ in COLOR_PALETTE.keys():
        scores[class_]['recall'] = metrics[class_]['tp'] / (metrics[class_]['tp'] + metrics[class_]['fn']) if metrics[class_]['tp'] > 0 else 0.0
        scores[class_]['precision'] = metrics[class_]['tp'] / (metrics[class_]['tp'] + metrics[class_]['fp']) if metrics[class_]['tp'] > 0 else 0.0
        scores[class_]['f1'] = (
            2 * scores[class_]['precision'] * scores[class_]['recall'] / (scores[class_]['precision'] + scores[class_]['recall'])
            if (scores[class_]['precision'] != 0 and scores[class_]['recall'] != 0)
            else 0.0
        )

        mRecall += scores[class_]['recall']
        mPrecision += scores[class_]['precision']
        mF1 += scores[class_]['f1']

    # Calculate mean recall, precision and F1 score over all classes
    class_count = len(COLOR_PALETTE)
    mRecall /= class_count
    mPrecision /= class_count
    mF1 /= class_count

    tmpdir.cleanup()

    result = """