File size: 4,306 Bytes
4ed06b9 b74e3c1 4ed06b9 83750f3 4ed06b9 31bfb1f 83750f3 4ed06b9 83750f3 4ed06b9 83750f3 4ed06b9 b74e3c1 4ed06b9 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 |
"""
==========================================================
Gradio demo to Plot multi-class SGD on the iris dataset
==========================================================
Plot decision surface of multi-class SGD on iris dataset.
The hyperplanes corresponding to the three one-versus-all (OVA) classifiers
are represented by the dashed lines.
Created by Syed Affan <saffand03@gmail.com>
"""
import gradio as gr
import numpy as np
import matplotlib.pyplot as plt
from sklearn import datasets
from sklearn.linear_model import SGDClassifier
from sklearn.inspection import DecisionBoundaryDisplay
import matplotlib.cm
def _load_iris_2d(standardize, rng):
    """Load iris, keep its first two features, and shuffle the samples.

    Parameters
    ----------
    standardize : bool
        If true, rescale each feature to zero mean and unit variance.
    rng : numpy.random.RandomState
        Random source used for the shuffle.

    Returns
    -------
    (iris, X, y)
        The loaded iris dataset object (used for feature/target names), the
        shuffled (optionally standardized) 2-column feature matrix, and the
        matching label vector.
    """
    iris = datasets.load_iris()
    # Only the first two features are used so the decision surface can be
    # drawn in 2-D.
    X = iris.data[:, :2]
    y = iris.target
    order = rng.permutation(X.shape[0])
    X, y = X[order], y[order]
    if standardize:
        X = (X - X.mean(axis=0)) / X.std(axis=0)
    return iris, X, y


def _plot_hyperplane(ax, weights, bias, xlim, color):
    """Draw ``weights[0] * x + weights[1] * y + bias = 0`` as a dashed line on ``ax``."""
    w0, w1 = weights
    if w1 != 0:
        xs = np.asarray(xlim, dtype=float)
        ax.plot(xs, -(xs * w0 + bias) / w1, ls="--", color=color)
    elif w0 != 0:
        # y has zero weight: the boundary is the vertical line x = -bias / w0.
        ax.axvline(-bias / w0, ls="--", color=color)
    # With both weights zero the classifier is constant and has no boundary.


def make_plot(alpha, max_iter, Standardize):
    """Fit a multi-class ``SGDClassifier`` on two iris features and plot it.

    Parameters
    ----------
    alpha : float
        Regularization strength passed to ``SGDClassifier``.
    max_iter : int or float
        Maximum number of epochs. It is cast to ``int`` because the Gradio
        slider may deliver a float while ``SGDClassifier`` needs an integer.
    Standardize : bool
        Whether to rescale the features to zero mean / unit variance first.

    Returns
    -------
    (Figure, str)
        The decision-surface figure and a Markdown string reporting the
        accuracy on the (training) data.
    """
    # Build the figure without pyplot: pyplot keeps every figure it creates
    # alive until plt.close(), so a long-running demo would leak one figure
    # per click, and its "current figure" state is shared between
    # concurrently running handlers.
    from matplotlib.figure import Figure

    colors = "bry"

    # A private seeded RNG drives both the shuffle and SGD, so results are
    # reproducible without mutating NumPy's process-wide random state.
    rng = np.random.RandomState(13)
    iris, X, y = _load_iris_2d(Standardize, rng)

    clf = SGDClassifier(
        alpha=alpha, max_iter=int(round(max_iter)), random_state=rng
    ).fit(X, y)
    accuracy = clf.score(X, y)
    acc = f'### The Accuracy on the entire dataset: {accuracy}'

    fig = Figure()
    ax = fig.subplots()
    DecisionBoundaryDisplay.from_estimator(
        clf,
        X,
        cmap=matplotlib.cm.Paired,
        ax=ax,
        response_method="predict",
        xlabel=iris.feature_names[0],
        ylabel=iris.feature_names[1],
    )
    ax.axis("tight")

    # Training points, one literal colour per class. No ``cmap`` here:
    # matplotlib ignores (and warns about) a colormap when ``c`` is a colour.
    for label, color in zip(clf.classes_, colors):
        mask = y == label
        ax.scatter(
            X[mask, 0],
            X[mask, 1],
            c=color,
            label=iris.target_names[label],
            edgecolor="black",
            s=20,
        )
    ax.set_title("Decision surface of multi-class SGD")
    ax.axis("tight")

    # One dashed line per one-versus-all classifier, spanning the x-range.
    xlim = ax.get_xlim()
    ylim = ax.get_ylim()
    for weights, bias, color in zip(clf.coef_, clf.intercept_, colors):
        _plot_hyperplane(ax, weights, bias, xlim, color)
    # Steep hyperplanes (e.g. with a large alpha) would otherwise autoscale
    # the axes far beyond the data; keep the view on the data region.
    ax.set_xlim(xlim)
    ax.set_ylim(ylim)
    ax.legend()
    return fig, acc
# Title passed to the Gradio Blocks app, and the Markdown "model card"
# describing the demo, its controls, the dataset and the estimator.
title = "Plot multi-class SGD on the iris dataset"
model_card = """
## Description
This interactive demo is based on the [Plot multi-class SGD on the iris dataset](https://scikit-learn.org/stable/auto_examples/linear_model/plot_sgd_iris.html#sphx-glr-auto-examples-linear-model-plot-sgd-iris-py) example from [scikit-learn](https://scikit-learn.org/stable/), a widely used machine-learning library for Python.

This demo plots the decision surface of multi-class SGD on the iris dataset. The hyperplanes corresponding to the three one-versus-all (OVA) classifiers are represented by the dashed lines. The model is trained on the first two iris features, and the reported accuracy is measured on that same training data.

You can play with the following hyperparameters, then press **Submit** to refit the model and redraw the plot:

- `alpha` is a constant that multiplies the regularization term. The higher the value, the stronger the regularization.
- `max_iter` is the maximum number of passes over the training data (aka epochs).
- `Standardize` rescales each feature to zero mean and unit variance before training.

## Dataset
[Iris Dataset](https://en.wikipedia.org/wiki/Iris_flower_data_set)

## Model
Module: [sklearn.linear_model](https://scikit-learn.org/stable/modules/classes.html#module-sklearn.linear_model)

[`SGDClassifier`](https://scikit-learn.org/stable/modules/generated/sklearn.linear_model.SGDClassifier.html) is the estimator used in this example.
"""
# Gradio UI: a header, the model card, three hyperparameter controls, a
# Submit button, and the two outputs (plot + accuracy Markdown) that
# ``make_plot`` fills in when the button is clicked.
with gr.Blocks(title=title) as demo:
    gr.Markdown('''
<div>
<h1 style='text-align: center'>Plot multi-class SGD on iris dataset</h1>
</div>
''')
    gr.Markdown(model_card)
    gr.Markdown("Author: <a href=\"https://huggingface.co/sulpha\">sulpha</a>")
    d0 = gr.Slider(0.001, 5, step=0.001, value=0.001, label='alpha')
    # This slider's values are 1, 11, 21, ..., 1001; start on that grid so
    # the slider handle and the submitted value agree.
    d1 = gr.Slider(1, 1001, step=10, value=101, label='max_iter')
    d2 = gr.Checkbox(value=True, label='Standardize')
    btn = gr.Button(value='Submit')
    # Output components are created after the button, so they render below it.
    plot_output = gr.Plot()
    accuracy_output = gr.Markdown()
    btn.click(make_plot, inputs=[d0, d1, d2], outputs=[plot_output, accuracy_output])

# Launch only when run as a script, so importing the module (e.g. for
# reload mode or tests) does not start a server.
if __name__ == "__main__":
    demo.launch()
|