import inspect
import sys
import warnings
from collections import defaultdict
from collections.abc import Sequence

import numpy as np
import torch
import torch.nn.functional as F
import tqdm
from numpy import interp
from sklearn.metrics import confusion_matrix, roc_curve, roc_auc_score, auc
from sklearn.metrics import precision_recall_curve, average_precision_score
from sklearn.metrics import balanced_accuracy_score, precision_score


def _depth(obj):
    """Return the nesting depth of *obj*.

    Scalars have depth 0, a flat sequence depth 1, a list of lists depth 2,
    and so on (the deepest branch wins). Strings/bytes are treated as scalars:
    a 1-character string is a sequence containing itself, so recursing into
    it would never terminate. An empty sequence has depth 1.
    """
    if isinstance(obj, (str, bytes)):
        return 0
    if isinstance(obj, np.ndarray) and obj.ndim == 0:
        # 0-d arrays are not iterable; treat them like scalars.
        return 0
    if not isinstance(obj, (Sequence, np.ndarray)):
        return 0
    return max((_depth(item) for item in obj), default=0) + 1


def get_metrics(y_true, y_pred, scores, mask=None):
    """Compute binary-classification metrics on the masked-in samples.

    Args:
        y_true: Ground-truth binary labels (indexable by ``np.where`` output
            when *mask* is given).
        y_pred: Predicted binary labels, aligned with *y_true*.
        scores: Continuous scores used for the ROC / PR AUCs.
        mask: Samples whose mask value equals 1 are evaluated; all others are
            ignored. ``None`` (the default) evaluates every sample.

    Returns:
        dict mapping metric name to value. If the prediction-based metrics
        cannot be computed (e.g. the confusion matrix is not 2x2 because only
        one class is present), every entry is -1. If only an AUC cannot be
        computed, that AUC is 0. Undefined ratios (0/0) come out as nan.
    """
    with warnings.catch_warnings():
        # Degenerate confusion matrices yield 0/0 divisions (numpy warnings
        # -> nan) and sklearn warns on ill-defined scores; both are expected.
        warnings.simplefilter("ignore")

        if mask is None:
            masked_y_true, masked_y_pred, masked_scores = y_true, y_pred, scores
        else:
            keep = np.where(mask == 1)
            masked_y_true = y_true[keep]
            masked_y_pred = y_pred[keep]
            masked_scores = scores[keep]

        try:
            # --- metrics based on hard predictions ---
            cnf = confusion_matrix(masked_y_true, masked_y_pred)
            # Unpacking raises ValueError unless the matrix is exactly 2x2.
            TN, FP, FN, TP = cnf.ravel()
            N = TN + TP + FN + FP
            S = (TP + FN) / N  # fraction of actual positives
            P = (TP + FP) / N  # fraction predicted positive
            acc = (TN + TP) / N
            sen = TP / (TP + FN)
            spc = TN / (TN + FP)
            prc = TP / (TP + FP)
            f1s = 2 * (prc * sen) / (prc + sen)
            mcc = (TP / N - S * P) / np.sqrt(P * S * (1 - S) * (1 - P))

            # --- metrics based on continuous scores (best effort) ---
            try:
                auc_roc = roc_auc_score(masked_y_true, masked_scores)
            except Exception:
                auc_roc = 0
            try:
                auc_pr = average_precision_score(masked_y_true, masked_scores)
            except Exception:
                auc_pr = 0

            bal_acc = balanced_accuracy_score(masked_y_true, masked_y_pred)
        except Exception:
            # Best effort: any failure above marks every metric unavailable.
            cnf = acc = bal_acc = prc = sen = spc = f1s = mcc = auc_roc = auc_pr = -1

    return {
        'Confusion Matrix': cnf,
        'Accuracy': acc,
        'Balanced Accuracy': bal_acc,
        'Precision': prc,
        'Sensitivity/Recall': sen,
        'Specificity': spc,
        'F1 score': f1s,
        'MCC': mcc,
        'AUC (ROC)': auc_roc,
        'AUC (PR)': auc_pr,
    }


def get_metrics_multitask(y_true, y_pred, scores, mask=None):
    """Run :func:`get_metrics` independently for each task.

    Args:
        y_true, y_pred, scores, mask: Either dicts keyed by task name (one
            1-D entry per task), or 2-D arrays whose column ``i`` belongs to
            task ``i``. *mask* may be ``None`` to evaluate every sample.

    Returns:
        A dict ``{task: metrics}`` when *y_true* is a dict, otherwise a list
        of metric dicts ordered by column.
    """
    if isinstance(y_true, dict):
        return {
            k: get_metrics(y_true[k], y_pred[k], scores[k],
                           None if mask is None else mask[k])
            for k in y_true
        }
    return [
        get_metrics(y_true[:, i], y_pred[:, i], scores[:, i],
                    None if mask is None else mask[:, i])
        for i in range(len(y_true[0]))
    ]


def print_metrics(met):
    """Print each scalar metric of *met* on its own line (confusion matrix skipped)."""
    for k, v in met.items():
        if k != 'Confusion Matrix':
            print('{}:\t{:.4f}'.format(k, v).expandtabs(20))


def print_metrics_multitask(met):
    """Print one line per metric with one column per task.

    Args:
        met: Output of :func:`get_metrics_multitask` -- a dict of per-task
            metric dicts or a list of them. nan values are shown as
            ``------``. Nothing is printed for an empty input.
    """
    per_task = list(met.values()) if isinstance(met, dict) else list(met)
    if not per_task:
        return
    for met_k in per_task[0]:
        if met_k == 'Confusion Matrix':
            continue
        vals = [task_met[met_k] for task_met in per_task]
        msg = ('{}:\t' + '{:.4f} ' * len(per_task)).format(met_k, *vals)
        msg = msg.replace('nan', '------')
        print(msg.expandtabs(20))


def pr_interp(rc_, rc, pr):
    """Interpolate precision at recall points *rc_* from a PR curve.

    Between neighbouring curve points (r1, p1) and (r2, p2), the quantity
    f(r) = r * (1 - p) / p (false positives normalised by the positive count)
    is interpolated linearly in r, and precision is recovered as
    p = r / (r + f(r)) = r / (a * r + b). Beyond the curve's ends the
    anchors (0, 1) on the left and (1, 0) on the right are used. Tiny
    denominators are replaced by 1e-16 to avoid division by zero.

    Args:
        rc_: Recall values to evaluate at.
        rc: Curve recall values, ascending (required by ``np.searchsorted``).
        pr: Curve precision values aligned with *rc*.

    Returns:
        Array of interpolated precisions, same shape as *rc_*.
    """
    pr_ = np.zeros_like(rc_)
    locs = np.searchsorted(rc, rc_)
    for idx, loc in enumerate(locs):
        l = loc - 1
        r = loc
        r1 = rc[l] if l > -1 else 0
        r2 = rc[r] if r < len(rc) else 1
        p1 = pr[l] if l > -1 else 1
        p2 = pr[r] if r < len(rc) else 0
        t1 = (1 - p2) * r2 / p2 / (r2 - r1) if p2 * (r2 - r1) > 1e-16 else (1 - p2) * r2 / 1e-16
        t2 = (1 - p1) * r1 / p1 / (r2 - r1) if p1 * (r2 - r1) > 1e-16 else (1 - p1) * r1 / 1e-16
        t3 = (1 - p1) * r1 / p1 if p1 > 1e-16 else (1 - p1) * r1 / 1e-16
        a = 1 + t1 - t2
        b = t3 - t1 * r1 + t2 * r1
        pr_[idx] = rc_[idx] / (a * rc_[idx] + b)
    return pr_


def get_roc_info(y_true_all, scores_all):
    """Aggregate several ROC curves (e.g. one per run) on a common FPR grid.

    Returns:
        dict with the 1001-point FPR grid ``xs``, the mean TPR ``ys_mean``,
        mean +/- one std clipped to [0, 1], ``auc_mean`` (AUC of the mean
        curve) and ``auc_std`` (std of per-curve AUCs, clipped so that
        ``auc_mean + auc_std <= 1``).
    """
    fpr_pt = np.linspace(0, 1, 1001)
    tprs, aucs = [], []
    for y_true, scores in zip(y_true_all, scores_all):
        fpr, tpr, _ = roc_curve(y_true=y_true, y_score=scores, drop_intermediate=True)
        tprs.append(interp(fpr_pt, fpr, tpr))
        tprs[-1][0] = 0.0  # force every curve through the origin
        aucs.append(auc(fpr, tpr))
    tprs_mean = np.mean(tprs, axis=0)
    tprs_std = np.std(tprs, axis=0)
    tprs_upper = np.minimum(tprs_mean + tprs_std, 1)
    tprs_lower = np.maximum(tprs_mean - tprs_std, 0)
    auc_mean = auc(fpr_pt, tprs_mean)
    auc_std = np.std(aucs)
    auc_std = 1 - auc_mean if auc_mean + auc_std > 1 else auc_std
    return {
        'xs': fpr_pt,
        'ys_mean': tprs_mean,
        'ys_upper': tprs_upper,
        'ys_lower': tprs_lower,
        'auc_mean': auc_mean,
        'auc_std': auc_std,
    }


def get_pr_info(y_true_all, scores_all):
    """Aggregate several precision-recall curves on a common recall grid.

    Returns:
        dict with the 1001-point recall grid ``xs`` (first point 1e-16
        instead of 0 to avoid 0/0 in :func:`pr_interp`), the mean
        interpolated precision ``ys_mean``, mean +/- one std clipped to
        [0, 1], and the mean / std of the per-curve average precision as
        ``auc_mean`` / ``auc_std`` (std clipped so the sum is <= 1).
    """
    rc_pt = np.linspace(0, 1, 1001)
    rc_pt[0] = 1e-16
    prs = []
    aps = []
    for y_true, scores in zip(y_true_all, scores_all):
        # Positional scores: the keyword was renamed across sklearn versions.
        pr, rc, _ = precision_recall_curve(y_true, scores)
        aps.append(average_precision_score(y_true=y_true, y_score=scores))
        # sklearn returns recall in decreasing order; pr_interp needs ascending.
        pr, rc = pr[::-1], rc[::-1]
        prs.append(pr_interp(rc_pt, rc, pr))
    prs_mean = np.mean(prs, axis=0)
    prs_std = np.std(prs, axis=0)
    prs_upper = np.minimum(prs_mean + prs_std, 1)
    prs_lower = np.maximum(prs_mean - prs_std, 0)
    aps_mean = np.mean(aps)
    aps_std = np.std(aps)
    aps_std = 1 - aps_mean if aps_mean + aps_std > 1 else aps_std
    return {
        'xs': rc_pt,
        'ys_mean': prs_mean,
        'ys_upper': prs_upper,
        'ys_lower': prs_lower,
        'auc_mean': aps_mean,
        'auc_std': aps_std,
    }


def get_and_print_metrics(mdl, dat):
    """Predict on ``dat.x`` with *mdl*, score against ``dat.y`` and print.

    Every sample is evaluated (no mask).
    """
    y_pred = mdl.predict(dat.x)
    y_prob = mdl.predict_proba(dat.x)
    # NOTE(review): if predict_proba returns one column per class, the AUC
    # computations may fail and be reported as 0 -- confirm the model API.
    met_all = get_metrics(dat.y, y_pred, y_prob)
    print_metrics(met_all)


def get_and_print_metrics_multitask(mdl, dat):
    """Multitask version of :func:`get_and_print_metrics` (no mask)."""
    y_pred = mdl.predict(dat.x)
    y_prob = mdl.predict_proba(dat.x)
    met = get_metrics_multitask(dat.y, y_pred, y_prob)
    print_metrics_multitask(met)


def split_dataset(dat, ratio=.8, seed=0):
    """Randomly split *dat* into training and validation subsets.

    Args:
        dat: A sized dataset accepted by ``torch.utils.data.random_split``.
        ratio: Fraction of samples (rounded) assigned to the training subset.
        seed: Seed for the generator driving the split, for reproducibility.

    Returns:
        ``(dat_trn, dat_vld)``.
    """
    len_trn = int(np.round(len(dat) * ratio))
    len_vld = len(dat) - len_trn
    dat_trn, dat_vld = torch.utils.data.random_split(
        dat, (len_trn, len_vld), generator=torch.Generator().manual_seed(seed)
    )
    return dat_trn, dat_vld


def l1_regularizer(model, lambda_l1=0.01):
    """LASSO penalty: ``lambda_l1`` times the summed |w| of every ``*weight`` parameter.

    Parameters whose names do not end in ``weight`` (e.g. biases) are excluded.
    """
    lossl1 = 0
    for model_param_name, model_param_value in model.named_parameters():
        if model_param_name.endswith('weight'):
            lossl1 += lambda_l1 * model_param_value.abs().sum()
    return lossl1


class ProgressBar(tqdm.tqdm):
    """ASCII tqdm bar that shows a dict of running values as its postfix."""

    def __init__(self, total, desc, file=sys.stdout):
        super().__init__(total=total, desc=desc, ascii=True,
                         bar_format='{l_bar}{r_bar}', file=file)

    def update(self, batch_size=1, to_disp=None):
        """Advance the bar by *batch_size* and refresh the postfix.

        Args:
            batch_size: Increment passed to ``tqdm.update``.
            to_disp: Mapping of display name to value. The ``'cnf'`` entry
                is shown via its repr with newlines removed; every other
                value must be convertible with ``float()`` (numbers or
                single-element tensors) and is shown with 6 decimals. When
                ``None`` the postfix is left unchanged.
        """
        if to_disp is not None:
            postfix = {}
            for k, v in to_disp.items():
                if k == 'cnf':
                    postfix[k] = repr(v).replace('\n', '')
                else:
                    # float() also handles tensors that require grad or live
                    # on the GPU (via .item()).
                    postfix[k] = '{:.6f}'.format(float(v))
            self.set_postfix(postfix)
        return super().update(batch_size)


def _get_gt_mask(logits, target):
    """Boolean mask shaped like *logits*, True at ``[b, target[b]]`` (dim 1 indexes classes)."""
    target = target.reshape(-1)
    mask = torch.zeros_like(logits).scatter_(1, target.unsqueeze(1), 1).bool()
    return mask


def _get_other_mask(logits, target):
    """Complement of :func:`_get_gt_mask`: False at ``[b, target[b]]``, True elsewhere."""
    target = target.reshape(-1)
    mask = torch.ones_like(logits).scatter_(1, target.unsqueeze(1), 0).bool()
    return mask


def cat_mask(t, mask1, mask2):
    """Collapse dim 1 of *t* into two columns: the sums over *mask1* and over *mask2*."""
    t1 = (t * mask1).sum(dim=1, keepdim=True)
    t2 = (t * mask2).sum(dim=1, keepdim=True)
    return torch.cat([t1, t2], dim=1)


def dkd_loss(logits_student, logits_teacher, target, alpha, beta, temperature):
    """Distillation loss split into target-class and non-target-class terms.

    Appears to implement Decoupled Knowledge Distillation (Zhao et al., 2022);
    NOTE(review): confirm against the reference implementation.

    Returns:
        ``alpha * tckd_loss + beta * nckd_loss``. Both KL terms are summed
        over the batch, scaled by ``temperature**2`` and divided by the batch
        size ``target.shape[0]``.
    """
    gt_mask = _get_gt_mask(logits_student, target)
    other_mask = _get_other_mask(logits_student, target)
    pred_student = F.softmax(logits_student / temperature, dim=1)
    pred_teacher = F.softmax(logits_teacher / temperature, dim=1)
    # Binary distributions: [p(target class), p(all other classes)].
    pred_student = cat_mask(pred_student, gt_mask, other_mask)
    pred_teacher = cat_mask(pred_teacher, gt_mask, other_mask)
    log_pred_student = torch.log(pred_student)
    tckd_loss = (
        F.kl_div(log_pred_student, pred_teacher, reduction='sum')
        * (temperature**2)
        / target.shape[0]
    )
    # Subtracting a large constant from the target-class logit drives its
    # softmax probability to ~0, leaving a distribution over the other classes.
    pred_teacher_part2 = F.softmax(
        logits_teacher / temperature - 1000.0 * gt_mask, dim=1
    )
    log_pred_student_part2 = F.log_softmax(
        logits_student / temperature - 1000.0 * gt_mask, dim=1
    )
    nckd_loss = (
        F.kl_div(log_pred_student_part2, pred_teacher_part2, reduction='sum')
        * (temperature**2)
        / target.shape[0]
    )
    return alpha * tckd_loss + beta * nckd_loss


def convert_args_kwargs_to_kwargs(func, args, kwargs):
    """Map a call's positional and keyword arguments onto *func*'s parameter names.

    Parameters not supplied are filled with their declared defaults.

    Raises:
        TypeError: If *args* / *kwargs* do not match *func*'s signature.
    """
    signature = inspect.signature(func)
    bound_args = signature.bind(*args, **kwargs)
    bound_args.apply_defaults()
    return bound_args.arguments