Add ROC-AUC score for each feature
Browse files- app.py +18 -3
- requirements.txt +1 -0
app.py
CHANGED
@@ -1,7 +1,9 @@
|
|
1 |
import gradio as gr
|
|
|
2 |
import matplotlib.pyplot as plt
|
3 |
import seaborn as sns
|
4 |
|
|
|
5 |
from datasets import load_dataset
|
6 |
|
7 |
import histos
|
@@ -16,9 +18,22 @@ def get_plot(features, n_bins):
|
|
16 |
plotting_df = dataset_df.copy()
|
17 |
if len(features) == 1:
|
18 |
fig, ax = plt.subplots()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
19 |
values = [
|
20 |
-
|
21 |
-
|
22 |
]
|
23 |
labels = ["spin-ON", "spin-OFF"]
|
24 |
fig = histos.ratio_hist(
|
@@ -27,7 +42,7 @@ def get_plot(features, n_bins):
|
|
27 |
reference_label=labels[1],
|
28 |
n_bins=n_bins,
|
29 |
hist_range=None,
|
30 |
-
title=features[0],
|
31 |
)
|
32 |
return fig
|
33 |
if len(features) == 2:
|
|
|
1 |
import gradio as gr
|
2 |
+
import numpy as np
|
3 |
import matplotlib.pyplot as plt
|
4 |
import seaborn as sns
|
5 |
|
6 |
+
from sklearn import metrics
|
7 |
from datasets import load_dataset
|
8 |
|
9 |
import histos
|
|
|
18 |
plotting_df = dataset_df.copy()
|
19 |
if len(features) == 1:
|
20 |
fig, ax = plt.subplots()
|
21 |
+
pos_samples = plotting_df[plotting_df["target"] == "spin-ON"][features[0]]
|
22 |
+
neg_samples = plotting_df[plotting_df["target"] == "spin-OFF"][features[0]]
|
23 |
+
y_score = np.concatenate([pos_samples, neg_samples], axis=0)
|
24 |
+
if pos_samples.mean() >= neg_samples.mean():
|
25 |
+
y_true = np.concatenate(
|
26 |
+
[np.ones_like(pos_samples), np.zeros_like(neg_samples)], axis=0
|
27 |
+
)
|
28 |
+
roc_auc_score = metrics.roc_auc_score(y_true, y_score)
|
29 |
+
else:
|
30 |
+
y_true = np.concatenate(
|
31 |
+
[np.zeros_like(pos_samples), np.ones_like(neg_samples)], axis=0
|
32 |
+
)
|
33 |
+
roc_auc_score = metrics.roc_auc_score(y_true, y_score)
|
34 |
values = [
|
35 |
+
pos_samples,
|
36 |
+
neg_samples,
|
37 |
]
|
38 |
labels = ["spin-ON", "spin-OFF"]
|
39 |
fig = histos.ratio_hist(
|
|
|
42 |
reference_label=labels[1],
|
43 |
n_bins=n_bins,
|
44 |
hist_range=None,
|
45 |
+
title=f"{features[0]} (ROC AUC: {roc_auc_score:.3f})",
|
46 |
)
|
47 |
return fig
|
48 |
if len(features) == 2:
|
requirements.txt
CHANGED
@@ -1,2 +1,3 @@
|
|
1 |
matplotlib==3.7.1
|
|
|
2 |
seaborn==0.12.2
|
|
|
1 |
matplotlib==3.7.1
|
2 |
+
scikit-learn==1.2.2
|
3 |
seaborn==0.12.2
|