<html lang="en-US">
<head>
<meta charset="UTF-8">
<!-- Begin Jekyll SEO tag v2.8.0 -->
<title>NCTV | Neural Clamping Toolkit and Visualization for Neural Network Calibration</title>
<meta property="og:title" content="NCTV" />
<meta property="og:locale" content="en_US" />
<meta name="description" content="Neural Clamping Toolkit and Visualization for Neural Network Calibration" />
<meta property="og:description" content="Neural Clamping Toolkit and Visualization for Neural Network Calibration" />
<script type="application/ld+json">
{"@context":"https://schema.org","@type":"WebSite","description":"Neural Clamping Toolkit and Visualization for Neural Network Calibration","headline":"NCTV","name":"NCTV","url":"https://huggingface.co/spaces/hsiung/NCTV"}</script>
<!-- End Jekyll SEO tag -->
<link rel="preconnect" href="https://fonts.gstatic.com">
<link rel="preload" href="https://fonts.googleapis.com/css?family=Open+Sans:400,700&display=swap" as="style" type="text/css" crossorigin>
<meta name="viewport" content="width=device-width, initial-scale=1">
<meta name="theme-color" content="#157878">
<meta name="apple-mobile-web-app-status-bar-style" content="black-translucent">
<link rel="stylesheet" href="assets/css/bootstrap/bootstrap.min.css?v=90447f115a006bc45b738d9592069468b20e2551">
<link rel="stylesheet" href="assets/css/style.css?v=90447f115a006bc45b738d9592069468b20e2551">
<!-- start custom head snippets, customize with your own _includes/head-custom.html file -->
<link rel="stylesheet" href="assets/css/custom_style.css?v=90447f115a006bc45b738d9592069468b20e2551">
<script src="https://ajax.googleapis.com/ajax/libs/jquery/3.5.1/jquery.min.js"></script>
<link rel="stylesheet" href="https://ajax.googleapis.com/ajax/libs/jqueryui/1.12.1/themes/smoothness/jquery-ui.css">
<script src="https://ajax.googleapis.com/ajax/libs/jqueryui/1.12.1/jquery-ui.min.js"></script>
<script src="https://cdnjs.cloudflare.com/ajax/libs/Chart.js/2.9.4/Chart.js"></script>
<script src="assets/js/calibration.js?v=90447f115a006bc45b738d9592069468b20e2551"></script>
<!-- for mathjax support -->
<script src="https://cdnjs.cloudflare.com/polyfill/v3/polyfill.min.js?features=es6"></script>
<script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
<!-- end custom head snippets -->
</head>
<body>
<a id="skip-to-content" href="#content">Skip to the content.</a>
<header class="page-header" role="banner">
<h1 class="project-name">NCTV</h1>
<h2 class="project-tagline">Neural Clamping Toolkit and Visualization for Neural Network Calibration</h2>
</header>
<main id="content" class="main-content" role="main">
<h2 id="introduction">Introduction</h2>
<p>Neural network calibration is an essential task in deep learning: it ensures that the confidence of a model's
prediction is consistent with the true likelihood that the prediction is correct. In this
demonstration, we first visualize the idea of neural network calibration on a binary
classifier and show the model characteristics that reflect its calibration. Second, we introduce
our proposed framework <strong>Neural Clamping</strong>, which applies a simple joint input-output
transformation to a pre-trained classifier. We also provide other calibration approaches
(e.g., temperature scaling) for comparison with Neural Clamping.</p>
<h2 id="what-is-calibration">What is Calibration?</h2>
<p>Neural network calibration seeks to align a model's prediction confidence with its true correctness likelihood.
A well-calibrated model should provide both accurate predictions and reliable confidence estimates at inference time. In
contrast, a poorly calibrated model shows a wide gap between its accuracy and its average confidence.
Such miscalibration can be harmful in scenarios that require accurate uncertainty estimation, such as safety-critical tasks
(e.g., autonomous driving systems and medical diagnosis).</p>
<div class="container">
<div id="calibration-intro" class="row align-items-center calibration-intro-sec">
<img id="calibration-intro-img" src="images/conf_acc_demo.gif" />
</div>
</div>
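<h3 id="calibration-metrics">Calibration Metrics</h3>
<p>To measure a model's calibration error objectively, researchers use <strong>Calibration Metrics</strong> such as
Expected Calibration Error (ECE), Static Calibration Error (SCE), and Adaptive Calibration Error (ACE).
In the formulas below, \(n\) is the number of samples and predictions are grouped into \(M\) confidence bins \(B_i\);
\(\text{acc}(\cdot)\) and \(\text{conf}(\cdot)\) denote the accuracy and the average confidence of the samples in a bin.
SCE applies this binning separately to each of the \(K\) classes (bin \(B_i^k\) for class \(k\)) and averages the results, while ACE
replaces the fixed-width bins with \(R\) adaptive ranges per class, each containing an equal number of predictions.</p>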
<div class="container calibration-intro-sec">
<div><img id="calibration-intro-img" src="images/metrics/intro-metric-example.png" /></div>
</div>
<div id="calibration-metrics-formula" class="container">
<div id="calibration-metrics-formula-list" class="row align-items-center formula-list">
<a href="#ECE-formula" class="selected">ECE</a>
<a href="#SCE-formula">SCE</a>
<a href="#ACE-formula">ACE</a>
<div style="clear: both"></div>
</div>
<div id="calibration-metrics-formula-content" class="row align-items-center">
<span id="ECE-formula" class="formula" style="">$$\displaystyle \text{ECE}=\sum_{i=1}^{M}\frac{|B_i|}{n}|\text{acc}(B_i)-\text{conf}(B_i)|$$</span>
<span id="SCE-formula" class="formula" style="display: none;">$$\displaystyle \text{SCE}=\frac{1}{K}\sum_{k=1}^{K}\sum_{i=1}^{M}\frac{|B_i^k|}{n}|\text{acc}(i, k)-\text{conf}(i,k)|$$</span>
<span id="ACE-formula" class="formula" style="display: none;">$$\displaystyle \text{ACE}=\frac{1}{KR}\sum_{k=1}^{K}\sum_{r=1}^{R}|\text{acc}(r,k)-\text{conf}(r,k)|$$</span>
</div>
</div>
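<p>To make the three formulas concrete, the following minimal sketch computes ECE, SCE, and ACE in plain Python. It is an illustration under stated assumptions, not the NCToolkit implementation: the function names <code>ece</code>, <code>sce</code>, and <code>ace</code> are hypothetical, the inputs are assumed to be per-sample class-probability lists plus integer labels, ECE and SCE use equal-width bins over [0, 1], and ACE splits each class's sorted probabilities into ranges of (almost) equal size.</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code># Hypothetical helpers (not NCToolkit API): probs is a list of per-sample
# class-probability lists, labels is a list of integer class labels.

def _bin_index(conf, n_bins):
    # Equal-width bins over [0, 1]; a confidence of exactly 1.0 goes in the last bin.
    return min(int(conf * n_bins), n_bins - 1)


def ece(probs, labels, n_bins=15):
    """Expected Calibration Error over top-1 predictions."""
    conf_sum = [0.0] * n_bins  # summed top-1 confidence per bin
    correct = [0] * n_bins     # number of correct predictions per bin
    for p, y in zip(probs, labels):
        pred = max(range(len(p)), key=p.__getitem__)
        b = _bin_index(p[pred], n_bins)
        conf_sum[b] += p[pred]
        correct[b] += int(pred == y)
    # |B_i|/n * |acc(B_i) - conf(B_i)| equals |correct_i - conf_sum_i| / n
    return sum(abs(c - s) for c, s in zip(correct, conf_sum)) / len(labels)


def sce(probs, labels, n_bins=15):
    """Static Calibration Error: ECE-style binning of every class probability."""
    num_classes = len(probs[0])
    total = 0.0
    for k in range(num_classes):
        conf_sum = [0.0] * n_bins
        hits = [0] * n_bins  # samples in the bin whose true label is k
        for p, y in zip(probs, labels):
            b = _bin_index(p[k], n_bins)
            conf_sum[b] += p[k]
            hits[b] += int(y == k)
        total += sum(abs(h - s) for h, s in zip(hits, conf_sum)) / len(labels)
    return total / num_classes


def ace(probs, labels, n_ranges=15):
    """Adaptive Calibration Error: per class, R ranges of (almost) equal size."""
    num_classes, n = len(probs[0]), len(labels)
    total = 0.0
    for k in range(num_classes):
        order = sorted(range(n), key=lambda i: probs[i][k])
        for r in range(n_ranges):
            chunk = order[r * n // n_ranges:(r + 1) * n // n_ranges]
            if not chunk:  # only happens when there are fewer samples than ranges
                continue
            acc = sum(int(labels[i] == k) for i in chunk) / len(chunk)
            conf = sum(probs[i][k] for i in chunk) / len(chunk)
            total += abs(acc - conf)
    return total / (num_classes * n_ranges)


# Toy usage: a model that always predicts class 0 with confidence 0.9
# but is correct on only half of the samples.
probs = [[0.9, 0.1]] * 10
labels = [0] * 5 + [1] * 5
print("ECE:", ece(probs, labels, n_bins=10))
print("SCE:", sce(probs, labels, n_bins=10))
print("ACE:", ace(probs, labels, n_ranges=2))
</code></pre></div></div>
<p>In this toy case ECE and SCE come out at about 0.4 (a confidence of 0.9 against an accuracy of 0.5), and ACE with two equal-size ranges is about 0.5. Because \(\frac{|B_i|}{n}|\text{acc}(B_i)-\text{conf}(B_i)|\) equals the absolute difference between the number of correct predictions and the summed confidence in bin \(B_i\), divided by \(n\), the ECE and SCE sketches never divide by a bin size, so empty bins simply contribute zero.</p>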
<h2 id="proposed-approach-neural-clamping">Proposed Approach: Neural Clamping</h2>
<div class="container"><img id="calibration-header" src="images/header.png" /></div>
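<p>Neural Clamping applies a joint input-output transformation to a frozen, pre-trained classifier; as the title of the accompanying paper indicates, it combines a learnable input perturbation with temperature scaling. The sketch below illustrates only that idea on a toy linear classifier in plain Python, and everything in it is an assumption for illustration: the helper names (<code>fit_clamping</code>, <code>logits</code>, <code>nll</code>) are hypothetical, a single perturbation \(\delta\) is shared by all inputs, the temperature is parameterized as \(T = e^{s}\) to keep it positive, and \(\delta\) and \(s\) are fitted by full-batch gradient descent on the validation negative log-likelihood. The actual NCToolkit training objective, optimizer, and hyperparameters may differ.</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code>import math


def softmax(z):
    m = max(z)  # subtract the max for numerical stability
    e = [math.exp(v - m) for v in z]
    s = sum(e)
    return [v / s for v in e]


def logits(W, b, x):
    # Frozen toy "pre-trained" linear classifier: one weight row per class.
    return [sum(wj * xj for wj, xj in zip(w, x)) + bk for w, bk in zip(W, b)]


def nll(W, b, data, delta, log_t):
    """Mean negative log-likelihood of softmax(f(x + delta) / T), T = exp(log_t)."""
    t = math.exp(log_t)
    loss = 0.0
    for x, y in data:
        z = logits(W, b, [xi + di for xi, di in zip(x, delta)])
        loss -= math.log(softmax([v / t for v in z])[y])
    return loss / len(data)


def fit_clamping(W, b, val_data, lr=0.05, epochs=200):
    """Fit a shared input perturbation delta and a temperature T = exp(log_t)
    by full-batch gradient descent on the validation NLL; W and b stay frozen."""
    n_features, n = len(val_data[0][0]), len(val_data)
    delta, log_t = [0.0] * n_features, 0.0
    for _ in range(epochs):
        t = math.exp(log_t)
        g_delta, g_log_t = [0.0] * n_features, 0.0
        for x, y in val_data:
            z = logits(W, b, [xi + di for xi, di in zip(x, delta)])
            p = softmax([v / t for v in z])
            # Gradient of the NLL with respect to the scaled logits z / T.
            err = [pk - (1.0 if k == y else 0.0) for k, pk in enumerate(p)]
            # Chain rule through z = W (x + delta) + b and the division by T.
            for j in range(n_features):
                g_delta[j] += sum(err[k] * W[k][j] for k in range(len(W))) / t
            # d NLL / d log_t = -(sum_k err_k * z_k) / T
            g_log_t -= sum(e * v for e, v in zip(err, z)) / t
        delta = [d - lr * g / n for d, g in zip(delta, g_delta)]
        log_t -= lr * g_log_t / n
    return delta, math.exp(log_t)


# Toy usage: a classifier that is over 99.99% confident on every input,
# while only 70% of the validation labels agree with its predictions.
W, b = [[5.0], [-5.0]], [0.0, 0.0]
val_data = ([([1.0], 0)] * 7 + [([1.0], 1)] * 3
            + [([-1.0], 1)] * 7 + [([-1.0], 0)] * 3)
delta, T = fit_clamping(W, b, val_data)
print("NLL before:", nll(W, b, val_data, [0.0], 0.0))
print("NLL after: ", nll(W, b, val_data, delta, math.log(T)), "T =", T)
</code></pre></div></div>
<p>On this toy data the fitted temperature grows well above 1, softening the over-confident predictions and lowering the validation NLL; \(\delta\) stays near zero only because the toy data is symmetric. Note that for a linear model, \(W(x+\delta)+b = Wx + (W\delta + b)\), so the input perturbation reduces to a bias shift; it becomes more expressive only for non-linear networks.</p>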
<h2 id="demonstration">Demonstration</h2>
<p>In current research, reliability diagrams are commonly drawn to show the calibration performance of a model. However,
because reliability diagrams are usually static, fixed bar charts, the insight they offer is limited. In
this demonstration, we make reliability diagrams interactive so that researchers and
developers can gain more insight from them. Specifically, we provide three CIFAR-100 classification models
(ResNet110, DenseNet121, and WideResNet40-10) in this demonstration, and multiple bin numbers are also supported.</p>
<p>We hope this tool can also facilitate the model development process.</p>
<div id="calibration-demo" class="container">
<div class="row align-items-center">
<div class="row" style="display: none;">
<div class="datasets-list">
<span style="margin-right: 1em;">Datasets</span>
<span class="radio-group"><input type="radio" id="CIFAR-100" class="options" name="datasets" value="cifar100" checked="" /><label for="CIFAR-100" class="option-label">CIFAR-100</label></span>
<span class="radio-group"><input type="radio" id="ImageNet" class="options" name="datasets" value="imagenet" /><label for="ImageNet" class="option-label">ImageNet</label></span>
</div>
</div>
<div class="row" style="margin: 10px 0 0">
<div class="models-list">
<span style="margin-right: 1em;">Models</span>
<span class="radio-group"><input type="radio" id="ResNet110" class="options" name="models" value="resnet110" checked="" /><label for="ResNet110" class="option-label">ResNet110</label></span>
<span class="radio-group"><input type="radio" id="DenseNet121" class="options" name="models" value="densenet121" /><label for="DenseNet121" class="option-label">DenseNet121</label></span>
<span class="radio-group"><input type="radio" id="WideResNet40-10" class="options" name="models" value="wideresnet40_10" /><label for="WideResNet40-10" class="option-label">WideResNet40-10</label></span>
</div>
</div>
</div>
<div class="row align-items-center">
<div class="col-4">
<div id="toolbox">
<div class="row align-items-center" style="margin-top: 2em;"><input type="radio" id="tool_none" class="options" name="calibration_tool" value="none" checked="" /><label for="tool_none" class="calibrate-tool">None</label></div>
<div class="row align-items-center"><input type="radio" id="tool_ts" class="options" name="calibration_tool" value="ts" /><label for="tool_ts" class="calibrate-tool">Temp. Scaling</label></div>
<div class="row align-items-center"><input type="radio" id="tool_delta" class="options" name="calibration_tool" value="delta" /><label for="tool_delta" class="calibrate-tool">Univ. Perturbation</label></div>
<div class="row align-items-center"><input type="radio" id="tool_neural_clamping" class="options" name="calibration_tool" value="neural_clamping" /><label for="tool_neural_clamping" class="calibrate-tool"><span style="font-weight: bold;">Neural Clamping</span></label></div>
</div>
<div class="row align-items-center">
<div class="legend"><img src="images/demo-legend.png" alt="legend" /></div>
<div class="figure-option"><label class="container" for="ActualOnly">Actual Only<input id="ActualOnly" type="checkbox" name="ActualOnly" value="Actual Only" onchange="figureOption()" /><span class="checkmark"></span></label></div>
</div>
<div class="row align-items-center">
<div class="calibration-error"><span class="calibration-metric">Expected Calibration Error</span><span class="calibration-error-value" id="ece-value">0.10731</span></div>
</div>
</div>
<div class="col-8">
<figure class="figure">
<img id="reliability-diagram" src="images/cifar100/resnet110/none/bin15.png" alt="CIFAR-100 Calibrated Reliability Diagram (Full)" />
<div class="slider-container">
<div class="slider-label"><span>Bin Number</span></div>
<div class="slider-content" id="bin-slider"><div id="bin-num" class="ui-slider-handle"></div></div>
</div>
<div class="slider-container">
<div class="slider-label"><span>Temp Scaling</span></div>
<div class="slider-content" id="ts-slider"><div id="temp-scale" class="slider-value ui-slider-handle"></div></div>
</div>
<figcaption class="figure-caption">
</figcaption>
</figure>
</div>
</div>
</div>
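<p>In a reliability diagram, each bar summarizes one confidence bin, which is why changing the bin number changes the picture. The following minimal sketch, in plain Python, shows the per-bin quantities such a diagram plots. The function <code>reliability_bins</code> is hypothetical and not part of NCToolkit; it assumes equal-width bins over [0, 1], takes each prediction's top-1 confidence and whether the prediction was correct, and returns the mean confidence, accuracy, and sample count of every bin (or <code>None</code> for an empty bin).</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code>def reliability_bins(confidences, correct, n_bins=15):
    """Per-bin (mean confidence, accuracy, count) for a reliability diagram.

    confidences: top-1 confidence of each prediction (floats in [0, 1])
    correct:     1 if the prediction was right, else 0
    Empty bins are reported as None so that a plot can skip them.
    """
    sums = [[0.0, 0, 0] for _ in range(n_bins)]  # [conf_sum, n_correct, count]
    for c, ok in zip(confidences, correct):
        # Equal-width bins; a confidence of exactly 1.0 goes in the last bin.
        b = min(int(c * n_bins), n_bins - 1)
        sums[b][0] += c
        sums[b][1] += ok
        sums[b][2] += 1
    result = []
    for conf_sum, n_correct, count in sums:
        if count:
            result.append((conf_sum / count, n_correct / count, count))
        else:
            result.append(None)
    return result


# Toy usage: the same predictions summarized with a coarse and a finer binning.
confidences = [0.95, 0.91, 0.72, 0.55, 0.15]
correct = [1, 0, 1, 1, 0]
for n_bins in (5, 10):
    print(n_bins, reliability_bins(confidences, correct, n_bins))
</code></pre></div></div>
<p>With more bins, each bar averages fewer predictions, so the diagram shows finer detail but noisier per-bin accuracy; with fewer bins, nearby confidence levels are merged into a single bar.</p>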
<h2 id="use-nctookit-to-calibrate-your-own-models">Use NCToolkit to Calibrate Your Own Models</h2>
<p>Get started quickly by running the following code, or <a href="https://colab.research.google.com/drive/1HosL29iJxK7Z8wNR9X3aWCgvbdu1ZgFu"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab" /></a>.
With this code, users can calibrate their own models using our proposed package, \(\texttt{NCToolkit}\).</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="c1"># !pip install -q git+https://github.com/yungchentang/NCToolkit.git
</span><span class="kn">from</span> <span class="nn">neural_clamping.nc_wrapper</span> <span class="kn">import</span> <span class="n">NCWrapper</span>
<span class="kn">from</span> <span class="nn">neural_clamping.utils</span> <span class="kn">import</span> <span class="n">load_model</span><span class="p">,</span> <span class="n">load_dataset</span><span class="p">,</span> <span class="n">model_classes</span><span class="p">,</span> <span class="n">plot_reliability_diagram</span>

<span class="c1"># Placeholders: replace these with your own settings
</span><span class="n">ARCHITECTURE</span> <span class="o">=</span> <span class="s">'ARCHITECTURE'</span>  <span class="c1"># a model name accepted by load_model
</span><span class="n">DATASET</span> <span class="o">=</span> <span class="s">'DATASET'</span>  <span class="c1"># a dataset name accepted by the loaders
</span><span class="n">CHECKPOINT_PATH</span> <span class="o">=</span> <span class="s">'CHECKPOINT_PATH'</span>  <span class="c1"># path to the pre-trained weights
</span><span class="n">BATCH_SIZE</span> <span class="o">=</span> <span class="mi">128</span>  <span class="c1"># example value; pass an integer, not a string
</span><span class="n">EPOCH</span> <span class="o">=</span> <span class="mi">100</span>  <span class="c1"># example value for Neural Clamping training
</span>
<span class="c1"># Load the pre-trained model and its number of classes
</span><span class="n">model</span> <span class="o">=</span> <span class="n">load_model</span><span class="p">(</span><span class="n">name</span><span class="o">=</span><span class="n">ARCHITECTURE</span><span class="p">,</span> <span class="n">data</span><span class="o">=</span><span class="n">DATASET</span><span class="p">,</span> <span class="n">checkpoint_path</span><span class="o">=</span><span class="n">CHECKPOINT_PATH</span><span class="p">)</span>
<span class="n">num_classes</span> <span class="o">=</span> <span class="n">model_classes</span><span class="p">(</span><span class="n">data</span><span class="o">=</span><span class="n">DATASET</span><span class="p">)</span>
<span class="c1"># Dataset loaders: the validation split fits Neural Clamping, the test split evaluates it
</span><span class="n">valloader</span> <span class="o">=</span> <span class="n">load_dataset</span><span class="p">(</span><span class="n">data</span><span class="o">=</span><span class="n">DATASET</span><span class="p">,</span> <span class="n">split</span><span class="o">=</span><span class="s">'val'</span><span class="p">,</span> <span class="n">batch_size</span><span class="o">=</span><span class="n">BATCH_SIZE</span><span class="p">)</span>
<span class="n">testloader</span> <span class="o">=</span> <span class="n">load_dataset</span><span class="p">(</span><span class="n">data</span><span class="o">=</span><span class="n">DATASET</span><span class="p">,</span> <span class="n">split</span><span class="o">=</span><span class="s">'test'</span><span class="p">,</span> <span class="n">batch_size</span><span class="o">=</span><span class="n">BATCH_SIZE</span><span class="p">)</span>
<span class="c1"># Build the Neural Clamping framework around the pre-trained model
</span><span class="n">nc</span> <span class="o">=</span> <span class="n">NCWrapper</span><span class="p">(</span><span class="n">model</span><span class="o">=</span><span class="n">model</span><span class="p">,</span> <span class="n">num_classes</span><span class="o">=</span><span class="n">num_classes</span><span class="p">)</span>  <span class="c1"># ... further optional arguments
</span><span class="c1"># Calibrate using Neural Clamping on the validation set
</span><span class="n">nc</span><span class="p">.</span><span class="n">train_NC</span><span class="p">(</span><span class="n">val_loader</span><span class="o">=</span><span class="n">valloader</span><span class="p">,</span> <span class="n">epoch</span><span class="o">=</span><span class="n">EPOCH</span><span class="p">)</span>  <span class="c1"># ... further optional arguments
</span><span class="c1"># General evaluation on the test set
</span><span class="n">nc</span><span class="p">.</span><span class="n">test_with_NC</span><span class="p">(</span><span class="n">test_loader</span><span class="o">=</span><span class="n">testloader</span><span class="p">)</span>
<span class="c1"># Visualization: reliability diagram with 30 bins
</span><span class="n">bin_acc</span><span class="p">,</span> <span class="n">conf_axis</span><span class="p">,</span> <span class="n">ece_score</span> <span class="o">=</span> <span class="n">nc</span><span class="p">.</span><span class="n">reliability_diagram</span><span class="p">(</span><span class="n">test_loader</span><span class="o">=</span><span class="n">testloader</span><span class="p">,</span> <span class="n">rd_criterion</span><span class="o">=</span><span class="s">"ECE"</span><span class="p">,</span> <span class="n">n_bins</span><span class="o">=</span><span class="mi">30</span><span class="p">)</span>
<span class="n">plot_reliability_diagram</span><span class="p">(</span><span class="n">conf_axis</span><span class="p">,</span> <span class="n">bin_acc</span><span class="p">)</span>
</code></pre></div></div>
<h2 id="citations">Citations</h2>
<p>If you find Neural Clamping helpful for your research, please cite our papers as follows:</p>
<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>@article{tang2024neural,
title={{Neural Clamping: Joint Input Perturbation and Temperature Scaling for Neural Network Calibration}},
author={Yung-Chen Tang and Pin-Yu Chen and Tsung-Yi Ho},
journal={Transactions on Machine Learning Research},
issn={2835-8856},
year={2024},
url={https://openreview.net/forum?id=qSFToMqLcq},
}
@inproceedings{hsiung2023nctv,
title={{NCTV: Neural Clamping Toolkit and Visualization for Neural Network Calibration}},
author={Lei Hsiung and Yung-Chen Tang and Pin-Yu Chen and Tsung-Yi Ho},
booktitle={Proceedings of the Thirty-Seventh AAAI Conference on Artificial Intelligence},
publisher={Association for the Advancement of Artificial Intelligence},
year={2023},
month={February}
}
</code></pre></div></div>
<footer class="site-footer">
<span class="site-footer-owner">NCTV is maintained by <a href="https://hsiung.cc">Lei Hsiung</a> and <a href="https://github.com/yungchentang">Yung-Chen Tang</a>.</span>
</footer>
</main>
</body>
</html>