|
import numpy as np |
|
import matplotlib.pyplot as plt |
|
|
|
from sklearn.cluster import AgglomerativeClustering |
|
from sklearn.neighbors import kneighbors_graph |
|
|
|
import gradio as gr |
|
|
|
# Select matplotlib's non-interactive Agg backend: figures are handed to the
# Gradio plot component rather than shown in a window, so no GUI toolkit or
# display is required.
plt.switch_backend("agg")

# Seed NumPy's global RNG once at import time.  The generator state keeps
# advancing afterwards, so successive plot requests draw different samples
# (the overall sequence is repeatable per process start).
np.random.seed(42)
|
|
|
def agglomorative_cluster(n_samples: int, n_neighbours: int, n_clusters: int, linkage: str, connectivity: bool) -> "plt.Figure":
    """Cluster a noisy 2-D spiral with agglomerative clustering and plot it.

    Points are drawn along a spiral (``x = t * cos(t)``, ``y = t * sin(t)``)
    using NumPy's global random state, perturbed with Gaussian noise and
    clustered with ``AgglomerativeClustering``.  When ``connectivity`` is
    true, merges are constrained by a k-nearest-neighbours graph of the
    points; otherwise the clustering is unstructured.

    Args:
        n_samples: Number of points to generate (at least 2).
        n_neighbours: Neighbours per point in the k-NN graph.  Only used when
            ``connectivity`` is true, where it must lie in
            ``[1, n_samples - 1]`` because a point is never its own neighbour.
        n_clusters: Number of clusters to find (between 1 and ``n_samples``).
        linkage: Linkage criterion forwarded to ``AgglomerativeClustering``.
        connectivity: Whether to impose the k-NN connectivity constraint.

    Returns:
        A matplotlib figure showing the points coloured by cluster label.

    Raises:
        ValueError: If ``n_samples``, ``n_clusters`` or ``n_neighbours`` is
            outside the ranges described above.
    """
    # UI sliders may deliver floats; NumPy shapes and scikit-learn need ints.
    n_samples = int(n_samples)
    n_neighbours = int(n_neighbours)
    n_clusters = int(n_clusters)

    if n_samples < 2:
        raise ValueError(f"n_samples must be at least 2, got {n_samples}")
    if not 1 <= n_clusters <= n_samples:
        raise ValueError(
            f"n_clusters must be between 1 and n_samples ({n_samples}), got {n_clusters}"
        )
    if connectivity and not 1 <= n_neighbours < n_samples:
        raise ValueError(
            f"n_neighbours must be between 1 and n_samples - 1 ({n_samples - 1}), "
            f"got {n_neighbours}"
        )

    # Sample angles along the spiral, then add isotropic Gaussian noise.
    t = 1.5 * np.pi * (1 + 3 * np.random.rand(1, n_samples))
    x = t * np.cos(t)
    y = t * np.sin(t)

    X = np.concatenate((x, y))
    X += 0.7 * np.random.randn(2, n_samples)
    X = X.T  # shape (n_samples, 2): one row per point

    # Build the k-NN graph only when the structured variant is requested.
    # The previous expression was inverted (it used the graph when the flag
    # was False) and always paid for the graph construction.
    connectivity_graph = (
        kneighbors_graph(X, n_neighbors=n_neighbours, include_self=False)
        if connectivity
        else None
    )

    model = AgglomerativeClustering(
        linkage=linkage,
        connectivity=connectivity_graph,
        n_clusters=n_clusters,
    )
    model.fit(X)

    fig, ax = plt.subplots(1, 1, figsize=(24, 15))
    ax.scatter(X[:, 0], X[:, 1], c=model.labels_, cmap=plt.cm.nipy_spectral)
    ax.axis("equal")
    ax.axis("off")

    return fig
|
|
|
|
|
|
|
# Introductory text shown at the top of the page.  Kept at column 0 so the
# Markdown renderer never sees indented lines (which would become code blocks).
_DESCRIPTION = """
# Agglomerative Clustering with and without Structure

This space is an implementation of the scikit-learn document [Agglomerative clustering with and without structure](https://scikit-learn.org/stable/auto_examples/cluster/plot_agglomerative_clustering.html#sphx-glr-auto-examples-cluster-plot-agglomerative-clustering-py)
This space shows the effects of imposing a **connectivity graph** to capture local structure in the data.
You can uncheck the checkbox `connectivity` to see the effect on data clustering without **connectivity graph**. There are other parameters in this space
which you can play with such as `n_samples` (the number of data samples), `n_neighbours` (the number of neighbours), `n_clusters` (the number of clusters) and
what type of linkage to use for Agglomerative clustering `linkage`.

Have fun playing with the tool 🤗
"""

with gr.Blocks() as demo:
    gr.Markdown(_DESCRIPTION)

    # Every control starts on a valid value so that pressing "Plot" on the
    # untouched form works.  The n_samples minimum exceeds the maximum of
    # n_neighbours and n_clusters (30), so any slider combination is valid.
    n_samples = gr.Slider(50, 20_000, value=1500, step=1, label="n_samples", info="the number of samples in the data.")
    n_neighbours = gr.Slider(1, 30, value=30, step=1, label="n_neighbours", info="the number of neighbours in the data")
    # step=1 so that every count in the advertised 3-30 range is selectable.
    n_clusters = gr.Slider(3, 30, value=3, step=1, label="n_clusters", info="the number of clusters in the data")
    linkage = gr.Dropdown(['average', 'complete', 'ward', 'single'], value="average", label="linkage", info="the different types of agglomerative clustering techniques")
    connectivity = gr.Checkbox(True, label="connectivity", info="whether to impose a connectivity into the graph")
    output = gr.Plot(label="Plot")

    plot_btn = gr.Button("Plot")
    plot_btn.click(
        fn=agglomorative_cluster,
        inputs=[n_samples, n_neighbours, n_clusters, linkage, connectivity],
        outputs=output,
        api_name="plotcluster",
    )

demo.launch()
|
|
|
|