File size: 3,772 Bytes
f96cf95
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
from sklearn.cluster import MeanShift, estimate_bandwidth
from sklearn.datasets import make_blobs


def get_clusters_plot(n_blobs, quantile, cluster_std):
    """Cluster freshly generated 2D blobs with Mean Shift and plot the result.

    On every call a new dataset of 10,000 points is drawn with ``make_blobs``
    (no fixed random state, so the data changes between calls). The Mean Shift
    bandwidth is estimated from a 500-point subsample, and the fitted clusters
    are scattered with each estimated center marked by a black cross.

    Parameters
    ----------
    n_blobs : int
        Number of true blob centers to generate. Cast to ``int`` because UI
        sliders may deliver numeric values as floats.
    quantile : float
        Quantile forwarded to ``estimate_bandwidth``. Larger values give a
        larger bandwidth estimate (typically fewer clusters).
    cluster_std : float
        Standard deviation of each generated blob.

    Returns
    -------
    tuple
        ``(fig, message)``. ``fig`` is the matplotlib figure, or ``None`` when
        no positive bandwidth could be estimated. ``message`` is an HTML
        snippet comparing the estimated and true numbers of clusters.
    """
    n_blobs = int(n_blobs)
    X, _, centers = make_blobs(
        n_samples=10000, cluster_std=cluster_std, centers=n_blobs, return_centers=True
    )

    bandwidth = estimate_bandwidth(X, quantile=quantile, n_samples=500)
    if bandwidth <= 0:
        # A very small quantile (e.g. 0) can collapse the estimate to 0, which
        # MeanShift rejects as a bandwidth; report it instead of crashing.
        message = (
            '<p style="text-align: center;">'
            + "The estimated bandwidth is zero for this quantile."
            + " Try increasing the `Quantile` parameter.</p>"
        )
        return None, message

    ms = MeanShift(bandwidth=bandwidth, bin_seeding=True)
    ms.fit(X)
    labels = ms.labels_
    cluster_centers = ms.cluster_centers_

    labels_unique = np.unique(labels)
    n_clusters_ = len(labels_unique)

    fig, ax = plt.subplots()
    # Iterate the labels actually present rather than range(n_clusters_), so
    # no points are skipped if the label values are not contiguous.
    for k in labels_unique:
        my_members = labels == k
        cluster_center = cluster_centers[k]
        ax.scatter(X[my_members, 0], X[my_members, 1])
        ax.plot(
            cluster_center[0],
            cluster_center[1],
            "x",
            markeredgecolor="k",
            markersize=14,
        )
    ax.set_xlabel("Feature 1")
    ax.set_ylabel("Feature 2")
    ax.set_title(f"Estimated number of clusters: {n_clusters_}")
    # Remove the figure from pyplot's global registry so repeated callbacks in
    # the long-running app do not accumulate open figures. The Figure object
    # itself stays usable for rendering by the caller.
    plt.close(fig)

    if len(centers) != n_clusters_:
        message = (
            '<p style="text-align: center;">'
            + f"The number of estimated clusters ({n_clusters_})"
            + f" differs from the true number of clusters ({n_blobs})."
            + " Try changing the `Quantile` parameter.</p>"
        )
    else:
        message = (
            '<p style="text-align: center;">'
            + f"The number of estimated clusters ({n_clusters_})"
            + f" matches the true number of clusters ({n_blobs})!</p>"
        )
    return fig, message


# Build the Gradio UI: parameter sliders on the left, plot and status message
# on the right. Every slider change (and the initial page load) re-runs
# get_clusters_plot with the current slider values.
with gr.Blocks() as demo:
    gr.Markdown(
        """
            # Mean Shift Clustering

            This space shows how to use the [Mean Shift Clustering](https://scikit-learn.org/stable/modules/generated/sklearn.cluster.MeanShift.html) algorithm to cluster 2D data points. You can change the parameters using the sliders and see how the model performs.
            
            This space is based on [sklearn's original demo](https://scikit-learn.org/stable/auto_examples/cluster/plot_mean_shift.html#sphx-glr-auto-examples-cluster-plot-mean-shift-py).
            """
    )
    with gr.Row():
        with gr.Column(scale=1):
            n_blobs = gr.Slider(
                minimum=2,
                maximum=10,
                label="Number of clusters in the data",
                step=1,
                value=3,
            )
            # Starts at the first step above 0: a quantile of 0 can produce a
            # zero bandwidth estimate, which MeanShift rejects.
            quantile = gr.Slider(
                minimum=0.05,
                maximum=1,
                step=0.05,
                value=0.2,
                label="Quantile",
                info="Used to determine clustering's bandwidth.",
            )
            cluster_std = gr.Slider(
                minimum=0.1,
                maximum=1,
                label="Clusters' standard deviation",
                step=0.1,
                value=0.6,
            )
        with gr.Column(scale=4):
            clusters_plots = gr.Plot(label="Clusters' Plot")
            message = gr.HTML()

    # All sliders drive the same callback with the same inputs and outputs,
    # so register them in one loop instead of copy-pasted calls.
    control_inputs = [n_blobs, quantile, cluster_std]
    plot_outputs = [clusters_plots, message]
    for control in control_inputs:
        control.change(
            get_clusters_plot,
            control_inputs,
            plot_outputs,
            queue=False,
        )
    # Draw an initial plot as soon as the page loads.
    demo.load(
        get_clusters_plot,
        control_inputs,
        plot_outputs,
        queue=False,
    )

if __name__ == "__main__":
    demo.launch()