snehilsanyal commited on
Commit
a2f1844
β€’
1 Parent(s): 043ae0e

Add application file

Browse files
Files changed (1) hide show
  1. app.py +50 -0
app.py ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import matplotlib.pyplot as plt
3
+ from sklearn.linear_model import SGDClassifier
4
+ from sklearn.datasets import make_blobs
5
+ import gradio as gr
6
+
7
+ def plot_max_margin_hyperplane():
8
+ # we create 50 separable points
9
+ X, Y = make_blobs(n_samples=50, centers=2, random_state=0, cluster_std=0.60)
10
+ # fit the model
11
+ clf = SGDClassifier(loss="hinge", alpha=0.01, max_iter=200)
12
+ clf.fit(X, Y)
13
+ # plot the line, the points, and the nearest vectors to the plane
14
+ xx = np.linspace(-1, 5, 10)
15
+ yy = np.linspace(-1, 5, 10)
16
+
17
+ X1, X2 = np.meshgrid(xx, yy)
18
+ Z = np.empty(X1.shape)
19
+ for (i, j), val in np.ndenumerate(X1):
20
+ x1 = val
21
+ x2 = X2[i, j]
22
+ p = clf.decision_function([[x1, x2]])
23
+ Z[i, j] = p[0]
24
+ levels = [-1.0, 0.0, 1.0]
25
+ linestyles = ["dashed", "solid", "dashed"]
26
+ colors = "k"
27
+ fig = plt.figure()
28
+ plt.contour(X1, X2, Z, levels, colors=colors, linestyles=linestyles)
29
+ plt.scatter(X[:, 0], X[:, 1], c=Y, cmap=plt.cm.Paired, edgecolor="black", s=20)
30
+
31
+ plt.axis("tight")
32
+ #plt.show()
33
+ return fig
34
+
35
+ heading = 'πŸ€—πŸ§‘πŸ€πŸ’™ SGD: Maximum Margin Separating Hyperplane'
36
+
37
+ with gr.blocks(title = heading) as demo:
38
+ gr.Markdown("# {}".format(heading))
39
+ gr.Markdown(
40
+ """
41
+ ## This demo visualizes the maximum margin hyperplane that seperates\
42
+ a two-class separable dataset using a linear SVM classifier trained using SGD.
43
+ """
44
+ )
45
+ gr.Markdown('Demo is based on [this script](https://scikit-learn.org/stable/auto_examples/linear_model/plot_sgd_separating_hyperplane.html#sphx-glr-auto-examples-linear-model-plot-sgd-separating-hyperplane-py)')
46
+
47
+ button = gr.Button(value = 'Visualize SGD Maximum Margin Hyperplane')
48
+ button.click(plot_max_margin_hyperplane, outputs = gr.Plot())
49
+
50
+ demo.launch()