File size: 1,891 Bytes
a2f1844
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f58dfa9
a2f1844
 
 
f58dfa9
a2f1844
 
 
f58dfa9
a2f1844
f58dfa9
a2f1844
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.figure import Figure
from sklearn.datasets import make_blobs
from sklearn.linear_model import SGDClassifier

def plot_max_margin_hyperplane():
    """Fit a linear SVM with SGD on a toy dataset and plot its decision boundary.

    Fits ``SGDClassifier(loss="hinge")`` to 50 points drawn from two blobs
    (``make_blobs`` with ``random_state=0``), then draws, over the square
    [-1, 5] x [-1, 5], the contour where the decision function equals 0
    (solid, the separating hyperplane) and where it equals -1 and +1
    (dashed, the margins), together with the training points coloured by
    class.

    Returns:
        matplotlib.figure.Figure: the rendered figure (consumed by ``gr.Plot``).
    """
    # 50 separable points; random_state=0 makes the data identical on every
    # call. The SGD fit is not seeded, so the line may shift slightly per click.
    X, Y = make_blobs(n_samples=50, centers=2, random_state=0, cluster_std=0.60)

    clf = SGDClassifier(loss="hinge", alpha=0.01, max_iter=200)
    clf.fit(X, Y)

    # Evaluate the decision function on a 10x10 grid with one vectorized call
    # (rows of `grid` are (x1, x2) pairs) rather than one call per grid point.
    xx = np.linspace(-1, 5, 10)
    yy = np.linspace(-1, 5, 10)
    X1, X2 = np.meshgrid(xx, yy)
    grid = np.column_stack([X1.ravel(), X2.ravel()])
    Z = clf.decision_function(grid).reshape(X1.shape)

    # Build the figure without pyplot: pyplot keeps a global reference to every
    # figure it creates, so creating one per button click in a long-running
    # server would accumulate figures that are never released.
    fig = Figure()
    ax = fig.add_subplot(111)
    ax.contour(
        X1,
        X2,
        Z,
        levels=[-1.0, 0.0, 1.0],
        colors="k",
        linestyles=["dashed", "solid", "dashed"],
    )
    ax.scatter(X[:, 0], X[:, 1], c=Y, cmap=plt.cm.Paired, edgecolor="black", s=20)
    ax.axis("tight")
    return fig

heading = 'πŸ€—πŸ§‘πŸ€πŸ’™ SGD: Maximum Margin Separating Hyperplane'

with gr.Blocks(title = heading, theme = 'snehilsanyal/scikit-learn') as demo:
    gr.Markdown("# {}".format(heading))
    gr.Markdown(
        """
        ### This demo visualizes the maximum margin hyperplane that seperates\
        a two-class separable dataset using a linear SVM classifier trained using SGD.
        """
    )
    gr.Markdown('Demo is based on [this script from scikit-learn documentation](https://scikit-learn.org/stable/auto_examples/linear_model/plot_sgd_separating_hyperplane.html#sphx-glr-auto-examples-linear-model-plot-sgd-separating-hyperplane-py)')

    button = gr.Button(value = 'Visualize Maximum Margin Hyperplane')
    button.click(plot_max_margin_hyperplane, outputs = gr.Plot())

demo.launch()