File size: 3,042 Bytes
67cc8ee
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25762a5
67cc8ee
 
 
 
 
 
25762a5
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
import gradio as gr
import time
import numpy as np
from scipy.ndimage import gaussian_filter
import matplotlib.pyplot as plt
from skimage.data import coins
from skimage.transform import rescale
from sklearn.feature_extraction import image
from sklearn.cluster import spectral_clustering
import gradio as gr


def getClusteringPlot(algorithm):
    """Segment the scikit-image coins picture with spectral clustering.

    The picture is smoothed, downscaled and converted to a pixel-adjacency
    graph whose edge weights decay with the gradient between neighbouring
    pixels; the graph is then partitioned with
    ``sklearn.cluster.spectral_clustering``.

    Args:
        algorithm: label-assignment strategy forwarded unchanged to
            ``spectral_clustering(assign_labels=...)``. The UI offers
            "kmeans", "discretize" and "cluster_qr".

    Returns:
        A ``(figure, running_time)`` tuple: a matplotlib ``Figure`` showing
        the rescaled picture with one coloured contour per displayed region,
        and the clustering wall time formatted with three decimals, e.g.
        ``"1.234s"``.
    """
    # Load the coins picture as a numpy array.
    orig_coins = coins()

    # Pre-process: smooth first, then shrink to 20% of the original size so
    # the spectral decomposition stays cheap.
    smoothened_coins = gaussian_filter(orig_coins, sigma=2)
    rescaled_coins = rescale(smoothened_coins, 0.2, mode="reflect", anti_aliasing=False)

    # Convert the image into a graph with the value of the gradient on the
    # edges.
    graph = image.img_to_graph(rescaled_coins)

    # Turn gradients into similarities with a decreasing exponential of the
    # normalised gradient; the small eps keeps every edge weight positive.
    beta = 10
    eps = 1e-6
    graph.data = np.exp(-beta * graph.data / graph.data.std()) + eps

    # The number of segmented regions to display needs to be chosen manually.
    n_regions = 26

    # The spectral clustering quality may also benefit from requesting extra
    # regions for segmentation; only the first n_regions labels are drawn.
    n_regions_plus = 3

    # perf_counter is monotonic and high-resolution, which makes it the
    # right clock for measuring an interval (time.time() can jump).
    t0 = time.perf_counter()
    labels = spectral_clustering(
        graph,
        n_clusters=(n_regions + n_regions_plus),
        eigen_tol=1e-7,
        assign_labels=algorithm,
        random_state=42,
    )
    elapsed = time.perf_counter() - t0

    labels = labels.reshape(rescaled_coins.shape)

    # Draw on an explicit Figure/Axes pair instead of pyplot's global
    # "current figure", so the drawing does not depend on shared state.
    fig, ax = plt.subplots(figsize=(5, 5))
    ax.imshow(rescaled_coins, cmap=plt.cm.gray)
    ax.set_xticks(())
    ax.set_yticks(())
    ax.set_title("Spectral clustering: %s, %.2fs" % (algorithm, elapsed))
    for region in range(n_regions):
        colors = [plt.cm.nipy_spectral((region + 4) / float(n_regions + 4))]
        ax.contour(labels == region, colors=colors)

    # Unregister the figure from pyplot so repeated calls do not accumulate
    # open figures; the returned Figure object remains usable for saving.
    plt.close(fig)
    return (fig, "%.3fs" % elapsed)


# Assemble the Gradio page: header text, a strategy selector, the
# segmentation result next to its timing, and the original picture below.
with gr.Blocks() as demo:
    # Page header.
    gr.Markdown("## Segmenting the picture of Greek coins in regions 🪙")
    gr.Markdown("This demo is based on this [scikit-learn example](https://scikit-learn.org/stable/auto_examples/cluster/plot_coin_segmentation.html#sphx-glr-auto-examples-cluster-plot-coin-segmentation-py).")
    gr.Markdown("In this demo, we compare three strategies for performing segmentation-clustering and breaking the below image of Greek coins into multiple partly-homogeneous regions.")

    # The selected value is passed as the `algorithm` argument of
    # getClusteringPlot.
    inp = gr.Radio(
        choices=["kmeans", "discretize", "cluster_qr"],
        value="kmeans",
        label="Solver",
        info="Choose a clustering algorithm",
    )

    # Results side by side: the contour plot and the clustering time.
    with gr.Row():
        plot = gr.Plot(label="Plot")
        num = gr.Textbox(label="Running Time")

    # Recompute whenever the selection changes, and once on page load so the
    # default strategy's result is shown without any interaction.
    inp.change(fn=getClusteringPlot, inputs=[inp], outputs=[plot, num])
    demo.load(fn=getClusteringPlot, inputs=[inp], outputs=[plot, num])

    # Footer: the unprocessed picture for reference.
    gr.HTML("<hr>")
    gr.Image(coins(), label="An image of 24 Greek coins")
    gr.Markdown("The image is retrieved from scikit-image's data [gallery](https://scikit-image.org/docs/stable/auto_examples/).")

if __name__ == "__main__":
    demo.launch()