""" ========================================================= Gradio Demo to plot multi-class SGD on the iris dataset ========================================================= Plot decision surface of multi-class SGD on iris dataset. The hyperplanes corresponding to the three one-versus-all (OVA) classifiers are represented by the dashed lines. Created by Syed Affan """ import gradio as gr import numpy as np import matplotlib.pyplot as plt from sklearn import datasets from sklearn.linear_model import SGDClassifier from sklearn.inspection import DecisionBoundaryDisplay import matplotlib.cm def make_plot(alpha,max_iter,Standardize): # import some data to play with iris = datasets.load_iris() fig = plt.figure() # we only take the first two features. We could # avoid this ugly slicing by using a two-dim dataset X = iris.data[:, :2] y = iris.target colors = "bry" # shuffle idx = np.arange(X.shape[0]) np.random.seed(13) np.random.shuffle(idx) X = X[idx] y = y[idx] # standardize if Standardize: mean = X.mean(axis=0) std = X.std(axis=0) X = (X - mean) / std clf = SGDClassifier(alpha=alpha, max_iter=max_iter).fit(X, y) accuracy = clf.score(X,y) acc = f'## The Accuracy on the entire dataset: {accuracy}' #fig,ax = subplots() ax = plt.gca() DecisionBoundaryDisplay.from_estimator( clf, X, cmap=matplotlib.cm.Paired, ax=ax, response_method="predict", xlabel=iris.feature_names[0], ylabel=iris.feature_names[1], ) plt.axis("tight") # Plot also the training points for i, color in zip(clf.classes_, colors): idx = np.where(y == i) plt.scatter( X[idx, 0], X[idx, 1], c=color, label=iris.target_names[i], cmap=matplotlib.cm.Paired, edgecolor="black", s=20, ) plt.title("Decision surface of multi-class SGD") plt.axis("tight") # Plot the three one-against-all classifiers xmin, xmax = plt.xlim() ymin, ymax = plt.ylim() coef = clf.coef_ intercept = clf.intercept_ def plot_hyperplane(c, color): def line(x0): return (-(x0 * coef[c, 0]) - intercept[c]) / coef[c, 1] plt.plot([xmin, xmax], [line(xmin), line(xmax)], ls="--", color=color) for i, color in zip(clf.classes_, colors): plot_hyperplane(i, color) plt.legend() return fig,acc demo = gr.Interface( title = 'Plot multi-class SGD on the iris dataset', fn = make_plot, inputs = [gr.Slider(0.0001,5,step = 0.001,value = 0.001), gr.Slider(1,1000,step=10,value=100), gr.Checkbox(value=True)], outputs = [gr.Plot(),gr.Markdown()] ).launch()