File size: 6,969 Bytes
3531ac7 4f27632 3531ac7 9f37055 3531ac7 9f37055 3531ac7 9f37055 3531ac7 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 |
import plotly.express as px
import numpy as np
from sklearn.datasets import fetch_openml
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.utils import check_random_state
import gradio as gr
# Load data from https://www.openml.org/d/554 (MNIST). Each row of X is one
# flattened image (784 pixels, shown as 28x28 below); y holds the digit labels
# as strings such as "5".
X, y = fetch_openml(
    "mnist_784", version=1, return_X_y=True, as_frame=False, parser="pandas"
)
print("Data loaded")

# Shuffle once with a fixed seed so every restart of the demo sees the same order.
random_state = check_random_state(0)
permutation = random_state.permutation(X.shape[0])
X = X[permutation]
y = y[permutation]
# Defensive: guarantee exactly one flattened row per image.
X = X.reshape((X.shape[0], -1))

# Module-level model state, replaced by train_model(). Until it runs, the
# scaler is unfitted and there is no classifier.
scaler = StandardScaler()
clf = None
def dataset_display(digit, count_per_digit, binary_image):
    """Plot a random sample of dataset images for one digit class.

    Args:
        digit: digit class to show (0-9). It may arrive from a Gradio slider as
            an integral float such as 5.0.
        count_per_digit: number of images to draw, sampled without replacement.
        binary_image: 1 to render with plotly's ``binary_string`` mode,
            anything else for the default rendering.

    Returns:
        A plotly figure with one facet per sampled image (5 per row). If
        ``digit`` is not a class 0-9, it instead returns a blank 28x28 figure
        whose title says so.
    """
    # The labels in `y` are strings like "5", so a float such as 5.0 must be
    # normalized to int before building the label string; otherwise it would
    # match nothing and np.random.choice would fail on an empty pool.
    try:
        is_integral = float(digit).is_integer()
    except (TypeError, ValueError):
        is_integral = False
    if not is_integral or int(digit) not in range(10):
        # return a figure displaying an error message
        return px.imshow(
            np.zeros((28, 28)),
            labels=dict(x="Pixel columns", y="Pixel rows"),
            title=f"Digit {digit} is not in the data",
        )
    digit = int(digit)

    digit_idxs = np.flatnonzero(y == str(digit))
    random_idxs = np.random.choice(
        digit_idxs, size=int(count_per_digit), replace=False
    )
    # One 28x28 image per sampled row; facet_col=0 puts each image in its own panel.
    images = X[random_idxs].reshape(-1, 28, 28)
    return px.imshow(
        images,
        labels=dict(x="Pixel columns", y="Pixel rows"),
        title=f"Examples of Digit {digit} in Data",
        facet_col=0,
        facet_col_wrap=5,
        binary_string=binary_image == 1,
    )
def predict(img):
    """Classify a sketchpad drawing with the most recently trained model.

    Args:
        img: the drawing from the ``gr.Sketchpad`` input, or ``None`` when
            nothing was drawn. It is assumed to hold as many pixels as a
            training row (784, i.e. 28x28) -- TODO confirm for the installed
            Gradio version.

    Returns:
        The predicted label. If there is no drawing or no model has been
        trained yet, it returns a user-facing hint string instead.
    """
    if img is None:
        return "Show Your Drawing Skills"
    features = np.asarray(img).reshape(1, -1)

    # `clf` is absent (or None) until train_model() has run successfully.
    model = globals().get("clf")
    if model is None:
        return "Train the model first"

    # Imported lazily; only this narrow exception means "not trained yet".
    # Anything else is a genuine error and should surface, not be swallowed.
    from sklearn.exceptions import NotFittedError

    try:
        features = scaler.transform(features)
    except NotFittedError:
        return "Train the model first"
    return model.predict(features)[0]
def train_model(
    train_sample=5000, c=0.1, tol=0.1, solver="saga", penalty="l1", max_iter=100
):
    """Fit a regularised logistic regression on MNIST and visualise its weights.

    On success, the fitted scaler and classifier replace the module-level
    ``scaler`` and ``clf``, which ``predict`` uses.

    Args:
        train_sample: number of training rows. A further 10,000 rows are held
            out for scoring.
        c: inverse regularisation strength ``C``.
        tol: stopping tolerance passed to ``LogisticRegression``.
        solver: scikit-learn solver name; it must be compatible with ``penalty``.
        penalty: ``"l1"``, ``"l2"`` or ``"elasticnet"``.
        max_iter: maximum solver iterations (matches the UI slider default).

    Returns:
        ``(score, sparsity, fig)``:
            - ``score``: the test-set accuracy.
            - ``sparsity``: the percentage of coefficients that are exactly zero.
            - ``fig``: a plotly figure with one 28x28 coefficient image per class.

        For an unsupported solver/penalty pair, it returns two message strings
        and ``None`` instead.
    """
    global clf, scaler

    penalty_dict = {
        "l2": ["lbfgs", "newton-cg", "newton-cholesky", "sag", "saga"],
        "l1": ["liblinear", "saga"],
        "elasticnet": ["saga"],
    }
    # `.get` so an unknown penalty is reported instead of raising KeyError.
    if solver not in penalty_dict.get(penalty, []):
        return (
            "Solver not supported for the selected penalty",
            "Change the Combination",
            None,
        )

    # Sliders may deliver floats. train_test_split reads a float train_size as
    # a fraction in (0, 1), and LogisticRegression requires an integer max_iter.
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, train_size=int(train_sample), test_size=10000
    )
    new_scaler = StandardScaler()
    X_train = new_scaler.fit_transform(X_train)
    X_test = new_scaler.transform(X_test)
    model = LogisticRegression(
        C=c, penalty=penalty, solver=solver, tol=tol, max_iter=int(max_iter)
    )
    model.fit(X_train, y_train)
    # Publish scaler and model together, only after a successful fit, so that
    # predict() never pairs a new scaler with a stale (or missing) model.
    scaler, clf = new_scaler, model

    sparsity = np.mean(model.coef_ == 0) * 100
    score = model.score(X_test, y_test)
    coef = model.coef_
    # Symmetric colour range centred on zero. Fall back to 1 if every
    # coefficient was shrunk to exactly zero, to avoid an empty range.
    scale = np.abs(coef).max() or 1.0
    fig = px.imshow(
        coef.reshape(-1, 28, 28),
        labels=dict(x="Pixel columns", y="Pixel rows"),
        title="Classification vector for each digit",
        # range_color also sets zmin/zmax, so they are not passed separately.
        range_color=[-scale, scale],
        facet_col=0,
        facet_col_wrap=5,
        facet_col_spacing=0.01,
        color_continuous_scale="RdBu",
    )
    return score, sparsity, fig


with gr.Blocks() as demo:
    gr.Markdown("# MNIST classification using multinomial logistic + L1 ")
    gr.Markdown(
        """This interactive demo is based on the [MNIST classification using multinomial logistic + L1](https://scikit-learn.org/stable/auto_examples/linear_model/plot_sparse_logistic_regression_mnist.html#sphx-glr-auto-examples-linear-model-plot-sparse-logistic-regression-mnist-py) example from the popular [scikit-learn](https://scikit-learn.org/stable/) library, which is a widely-used library for machine learning in Python. The primary goal of this demo is to showcase the use of logistic regression in classifying handwritten digits from the [MNIST](https://en.wikipedia.org/wiki/MNIST_database) dataset, which is a well-known benchmark dataset in computer vision. The dataset is loaded from [OpenML](https://www.openml.org/d/554), which is an open platform for machine learning research that provides easy access to a large number of datasets.
The model is trained using the scikit-learn library, which provides a range of tools for machine learning, including classification, regression, and clustering algorithms, as well as tools for data preprocessing and model evaluation. The demo reports two metrics: the score is the model's accuracy on 10,000 held-out test images, and the sparsity is the percentage of model coefficients that are exactly zero. A higher sparsity means fewer pixels influence the prediction, which can make the model easier to interpret and reduce its complexity.
"""
    )

    with gr.Tab("Explore the Data"):
        gr.Markdown("## ")
        with gr.Row():
            digit = gr.Slider(0, 9, label="Select the Digit", value=5, step=1)
            count_per_digit = gr.Slider(
                1, 10, label="Number of Images", value=10, step=1
            )
            binary_image = gr.Slider(0, 1, label="Binary Image", value=0, step=1)
        gen_btn = gr.Button("Show Me ")
        gen_btn.click(
            dataset_display,
            inputs=[digit, count_per_digit, binary_image],
            outputs=gr.Plot(),
        )

    with gr.Tab("Train Your Model"):
        gr.Markdown("# Play with the parameters to see how the model changes")
        gr.Markdown("## Solver and penalty")
        gr.Markdown(
            """
Penalty | Solver
-------|---------------
l1 | saga
l2 | saga
"""
        )
        with gr.Row():
            train_sample = gr.Slider(
                1000, 60000, label="Train Sample", value=5000, step=1
            )
            c = gr.Slider(0.1, 1, label="C", value=0.1, step=0.1)
            tol = gr.Slider(
                0.1, 1, label="Tolerance for stopping criteria.", value=0.1, step=0.1
            )
            max_iter = gr.Slider(100, 1000, label="Max Iter", value=100, step=1)
            penalty = gr.Dropdown(["l1", "l2"], label="Penalty", value="l1")
            solver = gr.Dropdown(["saga"], label="Solver", value="saga")
        train_btn = gr.Button("Train")
        train_btn.click(
            train_model,
            # max_iter was previously shown but never passed to the model.
            inputs=[train_sample, c, tol, solver, penalty, max_iter],
            outputs=[
                gr.Textbox(label="Score"),
                gr.Textbox(label="Sparsity"),
                gr.Plot(),
            ],
        )

    with gr.Tab("Predict the Digit"):
        gr.Markdown("## Draw a digit and see the model's prediction")
        inputs = gr.Sketchpad(brush_radius=1.0)
        outputs = gr.Textbox(label="Predicted Label", lines=1)
        sketch_btn = gr.Button("Classify the Sketch")
        sketch_btn.click(predict, inputs, outputs)

if __name__ == "__main__":
    demo.launch()
|