# Author: Hnabil -- "Update app.py" (commit 2cecf9a)
import inspect
from functools import partial

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.figure import Figure
from matplotlib.ticker import NullFormatter
from sklearn import datasets, manifold
# One seed shared by every generator / estimator, and the target embedding size.
SEED, N_COMPONENTS = 0, 2
# Seed NumPy's global RNG as well, so any draw that does not pass
# random_state explicitly is still reproducible.
np.random.seed(SEED)
def get_circles(n_samples):
    """Return ``(X, labels)`` for two noisy concentric circles.

    Thin wrapper over ``sklearn.datasets.make_circles`` with an inner/outer
    scale factor of 0.5, noise level 0.05 and ``random_state=SEED`` so the
    same ``n_samples`` always yields the same points.
    """
    return datasets.make_circles(
        n_samples=n_samples, factor=0.5, noise=0.05, random_state=SEED
    )
def get_s_curve(n_samples):
    """Return ``(X, position)`` for a 3-D S-curve sampled with ``SEED``.

    Columns 1 and 2 of ``X`` are swapped in place, so the first two
    columns (the ones ``plot_data`` scatters for the raw view) are the
    original first and third coordinates.
    """
    points, position = datasets.make_s_curve(n_samples=n_samples, random_state=SEED)
    # The right-hand fancy index produces a copy, so this swap cannot
    # read values it has already overwritten.
    points[:, [1, 2]] = points[:, [2, 1]]
    return points, position
def get_uniform_grid(n_samples):
    """Return a regular grid on the unit square, colored by its x-coordinate.

    The grid has ``side = int(sqrt(n_samples))`` points per axis, so it
    holds ``side**2`` points (e.g. 144 for ``n_samples=150``). Rows of the
    result run along x first, then step in y.

    Returns:
        ``(X, color)`` where ``X`` has shape ``(side**2, 2)`` and ``color``
        is the x-coordinate of each row.
    """
    side = int(np.sqrt(n_samples))
    ticks = np.linspace(0, 1, side)
    grid_x, grid_y = np.meshgrid(ticks, ticks)
    points = np.hstack((grid_x.reshape(-1, 1), grid_y.reshape(-1, 1)))
    return points, grid_x.ravel()
# Dataset label -> generator taking ``n_samples``. The UI builds its radio
# choices from ``list(DATA_MAPPING)``, so insertion order is display order.
DATA_MAPPING = dict(
    [
        ("Circles", get_circles),
        ("S-curve", get_s_curve),
        ("Uniform Grid", get_uniform_grid),
    ]
)
def plot_data(dataset: str, perplexity: int, n_samples: int, tsne: bool):
    """Scatter one dataset, either raw or after a 2-D t-SNE embedding.

    Args:
        dataset: Key of ``DATA_MAPPING`` selecting the data generator.
        perplexity: t-SNE perplexity, given either as a number or as a dict
            holding it under ``'value'``. For t-SNE it is clamped to
            ``n_points - 1``, because scikit-learn rejects a perplexity that
            is not strictly smaller than the number of samples.
        n_samples: Requested number of points, forwarded to the generator.
        tsne: If True, embed the data with t-SNE before plotting; otherwise
            scatter the first two columns of the generated data.

    Returns:
        A ``matplotlib.figure.Figure`` containing the scatter plot.
    """
    if isinstance(perplexity, dict):
        perplexity = perplexity['value']
    perplexity = int(perplexity)
    # Slider values may arrive as floats; the generators need an int count.
    n_samples = int(n_samples)

    X, color = DATA_MAPPING[dataset](n_samples)

    if tsne:
        # The UI allows perplexity up to 100 with as few as 100 points,
        # which scikit-learn would reject with a ValueError.
        perplexity = min(perplexity, len(X) - 1)
        # scikit-learn 1.5 renamed `n_iter` to `max_iter` (the old name was
        # later removed); use whichever this installed version accepts.
        tsne_params = inspect.signature(manifold.TSNE).parameters
        iter_kwarg = "max_iter" if "max_iter" in tsne_params else "n_iter"
        model = manifold.TSNE(
            n_components=N_COMPONENTS,
            init="random",
            random_state=SEED,
            perplexity=perplexity,
            **{iter_kwarg: 400},
        )
        Y = model.fit_transform(X)
    else:
        Y = X

    # Build the figure without pyplot: pyplot keeps every figure registered
    # until it is explicitly closed (a leak in a long-running server), and its
    # global state is shared between concurrently running event handlers.
    fig = Figure(figsize=(7, 7))
    ax = fig.subplots()
    ax.scatter(Y[:, 0], Y[:, 1], c=color)
    ax.xaxis.set_major_formatter(NullFormatter())
    ax.yaxis.set_major_formatter(NullFormatter())
    ax.axis("tight")
    return fig
title = "t-SNE: The effect of various perplexity values on the shape"
# Markdown rendered under the title on the demo page.
description = """
t-distributed Stochastic Neighbor Embedding ([t-SNE](https://scikit-learn.org/stable/modules/generated/sklearn.manifold.TSNE.html)) is a powerful technique for dimensionality reduction and visualization of high-dimensional datasets.
One of the key parameters in t-SNE is perplexity, which is related to the number of nearest neighbors taken into account for each data point.
In this illustration, we explore the impact of various perplexity values on t-SNE visualizations using three commonly used datasets: Concentric Circles, S-curve and Uniform Grid.
By comparing the resulting visualizations, we demonstrate how changing the perplexity value affects the shape of the visualization.
Created by [@Hnabil](https://huggingface.co/Hnabil) based on [scikit-learn docs](https://scikit-learn.org/stable/auto_examples/manifold/plot_t_sne_perplexity.html)
"""
with gr.Blocks(title=title) as demo:
    gr.HTML(f"<b>{title}</b>")
    gr.Markdown(description)

    input_data = gr.Radio(list(DATA_MAPPING), value="Circles", label="dataset")
    n_samples = gr.Slider(
        minimum=100, maximum=1000, value=150, step=25, label='Number of Samples'
    )
    perplexity = gr.Slider(
        minimum=2, maximum=100, value=5, step=1, label='Perplexity'
    )

    # Every handler gets the same inputs, in plot_data's positional order.
    shared_inputs = [input_data, perplexity, n_samples]

    with gr.Row():
        with gr.Column():
            raw_plot = gr.Plot(label="Original data")
            draw_raw = partial(plot_data, tsne=False)
            # Perplexity does not affect the raw data, so it is not a trigger here.
            for trigger in (input_data.change, n_samples.change, demo.load):
                trigger(fn=draw_raw, inputs=shared_inputs, outputs=raw_plot)
        with gr.Column():
            embedded_plot = gr.Plot(label="t-SNE")
            draw_embedded = partial(plot_data, tsne=True)
            for trigger in (
                input_data.change,
                perplexity.change,
                n_samples.change,
                demo.load,
            ):
                trigger(fn=draw_embedded, inputs=shared_inputs, outputs=embedded_plot)

demo.launch()