|
""" |
|
========================================================== |
|
Gradio demo that plots multi-class SGD on the iris dataset
|
========================================================== |
|
|
|
Plot decision surface of multi-class SGD on iris dataset. |
|
The hyperplanes corresponding to the three one-versus-all (OVA) classifiers |
|
are represented by the dashed lines. |
|
|
|
Created by Syed Affan <saffand03@gmail.com> |
|
|
|
""" |
|
import gradio as gr |
|
import numpy as np |
|
import matplotlib.pyplot as plt |
|
from sklearn import datasets |
|
from sklearn.linear_model import SGDClassifier |
|
from sklearn.inspection import DecisionBoundaryDisplay |
|
import matplotlib.cm |
|
|
|
def make_plot(alpha, max_iter, Standardize):
    """Fit a multi-class SGDClassifier on two iris features and plot the result.

    Parameters
    ----------
    alpha : float
        Regularization strength forwarded to ``SGDClassifier``.
    max_iter : int or float
        Maximum number of passes over the training data. Coerced to ``int``
        because UI sliders may deliver a float while ``SGDClassifier``
        validates this parameter as an integer.
    Standardize : bool
        If true, each feature is shifted to zero mean and scaled to unit
        variance before fitting.

    Returns
    -------
    fig : matplotlib.figure.Figure
        Decision surface, the training points coloured by class, and one dashed
        line per one-versus-all hyperplane.
    acc : str
        Markdown heading reporting the accuracy on the (training) data.
    """
    iris = datasets.load_iris()

    # Only the first two features are used so the decision surface is 2-D.
    X = iris.data[:, :2]
    y = iris.target
    colors = "bry"

    # Shuffle with a private RNG rather than reseeding NumPy's global state,
    # which is shared by everything else running in the process.
    rng = np.random.RandomState(13)
    idx = np.arange(X.shape[0])
    rng.shuffle(idx)
    X = X[idx]
    y = y[idx]

    if Standardize:
        mean = X.mean(axis=0)
        std = X.std(axis=0)
        X = (X - mean) / std

    # SGD reshuffles the data each epoch; an explicit seed keeps repeated
    # submissions with identical settings reproducible.
    clf = SGDClassifier(
        alpha=alpha, max_iter=int(max_iter), random_state=13
    ).fit(X, y)
    accuracy = clf.score(X, y)
    acc = f'### The Accuracy on the entire dataset: {accuracy}'

    # Draw on an explicit Figure/Axes pair instead of pyplot's process-global
    # "current figure".
    fig, ax = plt.subplots()
    DecisionBoundaryDisplay.from_estimator(
        clf,
        X,
        cmap=matplotlib.cm.Paired,
        ax=ax,
        response_method="predict",
        xlabel=iris.feature_names[0],
        ylabel=iris.feature_names[1],
    )
    ax.axis("tight")

    # Training points; colours are given explicitly, so no colormap is needed.
    for i, color in zip(clf.classes_, colors):
        mask = y == i
        ax.scatter(
            X[mask, 0],
            X[mask, 1],
            c=color,
            label=iris.target_names[i],
            edgecolor="black",
            s=20,
        )
    ax.set_title("Decision surface of multi-class SGD")
    ax.axis("tight")

    xmin, xmax = ax.get_xlim()
    coef = clf.coef_
    intercept = clf.intercept_

    def line(c, x0):
        # Solve coef[c] . (x0, x1) + intercept[c] = 0 for x1.
        return (-(x0 * coef[c, 0]) - intercept[c]) / coef[c, 1]

    # One dashed line per one-versus-all classifier.
    for i, color in zip(clf.classes_, colors):
        ax.plot(
            [xmin, xmax], [line(i, xmin), line(i, xmax)], ls="--", color=color
        )
    ax.legend()

    # Unregister the figure from pyplot; otherwise every call leaves another
    # open figure behind. The Figure object itself stays usable for rendering.
    plt.close(fig)
    return fig, acc
|
|
|
title = "Plot multi-class SGD on the iris dataset"

# Markdown description rendered above the controls. A plain string: nothing is
# interpolated, so no f-prefix is needed.
model_card = """
## Description

This interactive demo is based on the [Plot multi-class SGD on the iris dataset](https://scikit-learn.org/stable/auto_examples/linear_model/plot_sgd_iris.html#sphx-glr-auto-examples-linear-model-plot-sgd-iris-py) example from the popular [scikit-learn](https://scikit-learn.org/stable/) library, which is a widely-used library for machine learning in Python.

This demo plots the decision surface of multi-class SGD on the iris dataset. The hyperplanes corresponding to the three one-versus-all (OVA) classifiers are represented by the dashed lines.

You can play with the following hyperparameters:

`alpha` is a constant that multiplies the regularization term. The higher the value, the stronger the regularization.

`max_iter` is the maximum number of passes over the training data (aka epochs).

`Standardize` rescales each feature to zero mean and unit variance before training.

## Dataset

[Iris Dataset](https://en.wikipedia.org/wiki/Iris_flower_data_set)

## Model

Module: [sklearn.linear_model](https://scikit-learn.org/stable/modules/classes.html#module-sklearn.linear_model)

`SGDClassifier` is the estimator used in this example.
"""
|
|
|
# Build the UI. The output components are created after the button, in the
# same order as before, so the page layout is unchanged.
with gr.Blocks(title=title) as demo:
    gr.Markdown('''
<div>
<h1 style='text-align: center'>Plot multi-class SGD on iris dataset</h1>
</div>
''')

    gr.Markdown(model_card)
    gr.Markdown("Author: <a href=\"https://huggingface.co/sulpha\">sulpha</a>")

    d0 = gr.Slider(0.001, 5, step=0.001, value=0.001, label='alpha')
    d1 = gr.Slider(1, 1001, step=10, value=100, label='max_iter')
    d2 = gr.Checkbox(value=True, label='Standardize')

    btn = gr.Button(value='Submit')
    plot_out = gr.Plot()
    acc_out = gr.Markdown()
    btn.click(make_plot, inputs=[d0, d1, d2], outputs=[plot_out, acc_out])

# Only start the server when run as a script, not when the module is imported.
if __name__ == "__main__":
    demo.launch()
|
|