Spaces:
Build error
Build error
import numpy as np | |
import hdbscan | |
from minisom import MiniSom | |
import pickle | |
from collections import Counter | |
import matplotlib.pyplot as plt | |
import phate | |
import imageio | |
from tqdm import tqdm | |
import io | |
import plotly.graph_objs as go | |
import plotly.subplots as sp | |
import umap | |
from sklearn.datasets import make_blobs | |
from sklearn.preprocessing import LabelEncoder | |
from sklearn.cluster import KMeans | |
from sklearn.semi_supervised import LabelSpreading | |
from moviepy.editor import * | |
class ClusterSOM:
    """HDBSCAN clustering followed by one MiniSom map per selected cluster.

    After :meth:`train`, ``som_models``, ``mean_values``, ``sigma_values`` and
    ``cluster_mapping`` share the same 1-based cluster index; ``cluster_mapping``
    maps that index back to the HDBSCAN label the SOM was trained on.
    """

    def __init__(self):
        self.hdbscan_model = None
        self.som_models = {}       # cluster index -> MiniSom
        self.sigma_values = {}     # cluster index -> per-node std of sample-to-BMU distances
        self.mean_values = {}      # cluster index -> per-node mean of sample-to-BMU distances
        self.cluster_mapping = {}  # cluster index -> HDBSCAN label
        self.embedding = None      # 3-D embedding of the data the reducer was fitted on
        self.dim_red_op = None     # fitted PHATE / UMAP reducer
        self.train_data = None     # array passed to the last train() call (used for plotting)

    def train(self, dataset, min_samples_per_cluster=100, n_clusters=None, som_size=(20, 20), sigma=1.0,
              learning_rate=0.5, num_iteration=200000, random_seed=42, n_neighbors=5, coverage=0.95):
        """Train HDBSCAN on ``dataset`` and one SOM for each selected cluster.

        Args:
            dataset: 2-D array-like ``(n_samples, n_features)``.
            min_samples_per_cluster: ``min_cluster_size`` passed to HDBSCAN.
            n_clusters: number of non-noise clusters to model. When ``None`` it is the
                smallest number of most frequent HDBSCAN labels (noise included in the
                count) whose points reach ``coverage`` of the dataset.
            som_size: ``(rows, cols)`` of every SOM grid.
            sigma, learning_rate, random_seed: forwarded to ``MiniSom``.
            num_iteration: iterations for ``MiniSom.train_random``.
            n_neighbors: forwarded to :meth:`compute_sigma_values` (unused there).
            coverage: fraction of points used to pick ``n_clusters`` automatically.
        """
        dataset = np.asarray(dataset)
        # Retraining must not leave SOMs or plotting state from a previous run behind.
        self.som_models = {}
        self.sigma_values = {}
        self.mean_values = {}
        self.cluster_mapping = {}
        self.embedding = None
        self.dim_red_op = None
        self.train_data = dataset

        print('Identifying clusters in the embedding ...')
        self.hdbscan_model = hdbscan.HDBSCAN(min_cluster_size=min_samples_per_cluster)
        self.hdbscan_model.fit(dataset)
        labels = self.hdbscan_model.labels_

        cluster_labels = self._select_cluster_labels(labels, n_clusters, coverage)
        for cluster_idx, label in tqdm(enumerate(cluster_labels, start=1), total=len(cluster_labels),
                                       desc="Fitting 2D maps"):
            cluster_data = dataset[labels == label]
            som = MiniSom(som_size[0], som_size[1], dataset.shape[1], sigma=sigma,
                          learning_rate=learning_rate, random_seed=random_seed)
            som.train_random(cluster_data, num_iteration)
            self.som_models[cluster_idx] = som
            self.cluster_mapping[cluster_idx] = label
            mean_cluster, sigma_cluster = self.compute_sigma_values(cluster_data, som_size, som,
                                                                    n_neighbors=n_neighbors)
            self.sigma_values[cluster_idx] = sigma_cluster
            self.mean_values[cluster_idx] = mean_cluster

    @staticmethod
    def _select_cluster_labels(labels, n_clusters=None, coverage=0.95):
        """Return the ``n_clusters`` most frequent non-noise labels, most frequent first.

        When ``n_clusters`` is ``None`` it is derived from ``coverage`` as described in
        :meth:`train`. Noise (``-1``) is never returned.
        """
        ranked = Counter(np.asarray(labels).tolist()).most_common()
        if n_clusters is None:
            total = sum(count for _, count in ranked)
            covered = 0
            n_clusters = 0
            for _, count in ranked:
                covered += count
                n_clusters += 1
                if covered / total >= coverage:
                    break
        # Filtering noise first avoids dropping a real cluster when fewer labels exist
        # than requested.
        return [label for label, _ in ranked if label != -1][:n_clusters]

    def compute_sigma_values(self, cluster_data, som_size, som, n_neighbors=5):
        """Per-node mean and std of the distances between samples and their BMU.

        Args:
            cluster_data: iterable of samples belonging to one cluster.
            som_size: ``(rows, cols)`` of ``som``; the returned arrays have this shape.
            som: trained SOM exposing ``winner`` and ``get_weights``.
            n_neighbors: accepted for API compatibility; not used.

        Returns:
            ``(mean_cluster, sigma_cluster)``; nodes that win no sample stay 0.
        """
        som_weights = som.get_weights()
        mean_cluster = np.zeros(som_size)
        sigma_cluster = np.zeros(som_size)
        partitions = {}
        for sample in cluster_data:
            partitions.setdefault(tuple(som.winner(sample)), []).append(sample)
        for node, samples in partitions.items():
            distances = np.linalg.norm(np.asarray(samples) - som_weights[node], axis=-1)
            mean_cluster[node] = distances.mean()
            sigma_cluster[node] = distances.std()
        return mean_cluster, sigma_cluster

    def train_label(self, labeled_data, labels):
        """Compute one centroid per label from labelled samples.

        ``label_centroids[k]`` is the mean of the samples whose label is
        ``label_encodings.classes_[k]``.
        """
        labeled_data = np.asarray(labeled_data)
        le = LabelEncoder()
        encoded_labels = le.fit_transform(labels)
        # Centroids must be indexable by encoded label (predict and plot_label_heatmap rely
        # on that); KMeans centres ignore the labels and come in arbitrary order.
        self.label_centroids = np.array(
            [labeled_data[encoded_labels == k].mean(axis=0) for k in range(len(le.classes_))]
        )
        self.label_encodings = le

    def _has_labels(self):
        """True when :meth:`train_label` (or :meth:`load`) provided label centroids."""
        return getattr(self, 'label_centroids', None) is not None

    def _label_for_cluster(self, cluster_idx):
        """Return ``(label, encoded_index)`` paired with SOM ``cluster_idx``, or ``(None, None)``.

        NOTE(review): SOM ``k`` is paired with encoded label ``k - 1``; nothing in training
        enforces that correspondence — confirm intent.
        """
        classes = self.label_encodings.classes_
        encoded = cluster_idx - 1
        if 0 <= encoded < len(classes):
            return classes[encoded], encoded
        return None, None

    def predict(self, data, sigma_factor=1.5):
        """Assign each sample to the nearest node over all SOMs.

        A sample is accepted when its distance to that node is at most
        ``mean_values[cluster][node] * sigma_factor``; otherwise it yields ``(-1, None)``.

        Returns:
            One tuple per sample: ``(cluster_idx, (x, y))``, or, when labels are trained,
            ``(cluster_idx, (x, y), label, label_distance)`` where ``label`` and
            ``label_distance`` are ``None`` if the cluster has no paired label.
        """
        has_labels = self._has_labels()
        results = []
        for sample in data:
            sample = np.asarray(sample)
            min_distance = float('inf')
            nearest_cluster_idx = None
            nearest_node = None
            for cluster_idx, som in self.som_models.items():
                node = tuple(int(v) for v in som.winner(sample))
                distance = np.linalg.norm(sample - som.get_weights()[node])
                if distance < min_distance:
                    min_distance = distance
                    nearest_cluster_idx = cluster_idx
                    nearest_node = node
            accepted = (nearest_cluster_idx is not None and
                        min_distance <= self.mean_values[nearest_cluster_idx][nearest_node] * sigma_factor)
            if not accepted:
                results.append((-1, None))  # Noise (or no SOMs trained)
                continue
            if has_labels:
                label, encoded = self._label_for_cluster(nearest_cluster_idx)
                label_distance = None
                if encoded is not None:
                    label_distance = np.linalg.norm(sample - self.label_centroids[encoded])
                results.append((nearest_cluster_idx, nearest_node, label, label_distance))
            else:
                results.append((nearest_cluster_idx, nearest_node))
        return results

    @staticmethod
    def _to_plotly_color(rgba):
        """Convert a matplotlib RGBA tuple (floats in [0, 1]) to a plotly ``rgba(...)`` string."""
        r, g, b, a = rgba
        return f"rgba({int(round(r * 255))}, {int(round(g * 255))}, {int(round(b * 255))}, {a:.3f})"

    def plot_embedding(self, new_data=None, dim_reduction='umap', interactive=False):
        """Plot data points and the SOM grids of every cluster in a 3-D embedding.

        The reducer is fitted once on the training data (or on ``new_data`` when no
        training data is available, e.g. after :meth:`load`). Without ``new_data`` the
        training points are coloured by their HDBSCAN label; with ``new_data`` its points
        are coloured by :meth:`predict`.
        """
        if self.hdbscan_model is None:
            raise ValueError("HDBSCAN model not trained yet.")
        if len(self.som_models) == 0:
            raise ValueError("SOM models not trained yet.")
        if dim_reduction not in ['phate', 'umap']:
            raise ValueError("Invalid dimensionality reduction method. Use 'phate' or 'umap'.")

        if self.dim_red_op is None or self.embedding is None:
            train_data = getattr(self, 'train_data', None)
            fit_data = train_data if train_data is not None else new_data
            if fit_data is None:
                raise ValueError("No data available to fit the dimensionality reduction; pass new_data.")
            n_components = 3
            if dim_reduction == 'phate':
                self.dim_red_op = phate.PHATE(n_components=n_components, random_state=42)
            else:
                self.dim_red_op = umap.UMAP(n_components=n_components, random_state=42)
            self.embedding = self.dim_red_op.fit_transform(fit_data)

        if new_data is not None:
            points = self.dim_red_op.transform(new_data)
            assigned = np.array([result[0] for result in self.predict(new_data)])
            cluster_masks = {k: assigned == k for k in self.som_models}
            noise_mask = assigned == -1
        else:
            labels = self.hdbscan_model.labels_
            if len(self.embedding) != len(labels):
                raise ValueError("Stored embedding does not match the HDBSCAN labels; pass new_data.")
            points = self.embedding
            cluster_masks = {k: labels == self.cluster_mapping[k] for k in self.som_models}
            noise_mask = labels == -1

        colors = plt.cm.rainbow(np.linspace(0, 1, len(self.som_models) + 1))
        if interactive:
            fig = sp.make_subplots(rows=1, cols=1, specs=[[{'type': 'scatter3d'}]])
            ax = None
        else:
            fig = plt.figure(figsize=(30, 30))
            ax = fig.add_subplot(111, projection='3d')

        for position, (cluster_idx, som) in enumerate(self.som_models.items(), start=1):
            som_weights = som.get_weights()
            rows, cols, n_features = som_weights.shape
            # Project every SOM node into the same 3-D space as the data points.
            grid = self.dim_red_op.transform(som_weights.reshape(-1, n_features)).reshape(rows, cols, -1)
            cluster_points = points[cluster_masks[cluster_idx]]
            if interactive:
                self._plot_cluster_plotly(fig, cluster_idx, cluster_points, grid, colors[position])
            else:
                self._plot_cluster_mpl(ax, cluster_idx, cluster_points, grid, colors[position])

        noise_points = points[noise_mask]
        if interactive:
            if len(noise_points) > 0:
                fig.add_trace(
                    go.Scatter3d(
                        x=noise_points[:, 0],
                        y=noise_points[:, 1],
                        z=noise_points[:, 2],
                        mode='markers',
                        marker=dict(color="gray", size=1),
                        name="Noise"
                    )
                )
            fig.update_layout(scene=dict(xaxis_title='X', yaxis_title='Y', zaxis_title='Z'))
            fig.show()
        else:
            if len(noise_points) > 0:
                ax.scatter(noise_points[:, 0], noise_points[:, 1], noise_points[:, 2], c="gray", label="Noise")
            ax.legend()
            plt.show()

    def _plot_cluster_plotly(self, fig, cluster_idx, cluster_points, grid, rgba):
        """Add one cluster's points, SOM nodes and SOM grid lines to a plotly figure."""
        color = self._to_plotly_color(rgba)
        fig.add_trace(
            go.Scatter3d(
                x=cluster_points[:, 0],
                y=cluster_points[:, 1],
                z=cluster_points[:, 2],
                mode='markers',
                marker=dict(color=color, size=1),
                name=f"Cluster {cluster_idx}"
            )
        )
        nodes = grid.reshape(-1, grid.shape[-1])
        node_text = [f"{x},{y}" for x in range(grid.shape[0]) for y in range(grid.shape[1])]
        fig.add_trace(
            go.Scatter3d(
                x=nodes[:, 0],
                y=nodes[:, 1],
                z=nodes[:, 2],
                mode='markers+text',
                marker=dict(color=color, size=3, symbol='circle'),
                text=node_text,
                textposition="top center",
                showlegend=False
            )
        )
        # One polyline per grid row and column; None breaks the line between them.
        xs, ys, zs = [], [], []
        for line in list(grid) + list(grid.transpose(1, 0, 2)):
            xs.extend(line[:, 0].tolist() + [None])
            ys.extend(line[:, 1].tolist() + [None])
            zs.extend(line[:, 2].tolist() + [None])
        fig.add_trace(
            go.Scatter3d(
                x=xs,
                y=ys,
                z=zs,
                mode='lines',
                line=dict(color=color, width=2),
                showlegend=False
            )
        )

    def _plot_cluster_mpl(self, ax, cluster_idx, cluster_points, grid, color):
        """Draw one cluster's points, SOM nodes and SOM grid lines on a 3-D matplotlib axis."""
        if len(cluster_points) > 0:
            ax.scatter(cluster_points[:, 0], cluster_points[:, 1], cluster_points[:, 2], c=[color],
                       alpha=0.3, s=5, label=f"Cluster {cluster_idx}")
        nodes = grid.reshape(-1, grid.shape[-1])
        ax.plot(nodes[:, 0], nodes[:, 1], nodes[:, 2], '+', markersize=8, mew=2, zorder=10, c=color)
        # Rows and columns of the grid, including the last row/column.
        for line in list(grid) + list(grid.transpose(1, 0, 2)):
            ax.plot(line[:, 0], line[:, 1], line[:, 2], lw=1, c=color)

    def plot_label_heatmap(self):
        """Plot, for each SOM, which label centroid is closest to every node."""
        from matplotlib.colors import ListedColormap

        if not self._has_labels():
            raise ValueError("Labels not trained yet.")
        n_clusters = len(self.som_models)
        if n_clusters == 0:
            raise ValueError("SOM models not trained yet.")

        centroids = np.asarray(self.label_centroids, dtype=float)
        n_labels = len(centroids)
        # A private discrete colormap: one colour per label, label 0 included, and the
        # globally registered 'rainbow' colormap is left untouched.
        cmap = ListedColormap(plt.cm.rainbow(np.linspace(0, 1, n_labels)))

        n_rows = int(np.ceil(np.sqrt(n_clusters)))
        n_cols = n_rows if n_rows * (n_rows - 1) < n_clusters else n_rows - 1
        fig, axes = plt.subplots(n_rows, n_cols, figsize=(n_cols * 10, n_rows * 10), squeeze=False)

        im = None
        for i, (cluster_idx, som) in enumerate(self.som_models.items()):
            som_weights = som.get_weights()
            # (rows, cols, n_labels) distances; argmin keeps the first label on ties.
            distances = np.linalg.norm(som_weights[:, :, None, :] - centroids[None, None, :, :], axis=-1)
            label_map = np.argmin(distances, axis=-1)
            ax = axes[i // n_cols, i % n_cols]
            im = ax.imshow(label_map, cmap=cmap, origin='lower', interpolation='none',
                           vmin=-0.5, vmax=n_labels - 0.5)
            ax.set_xticks(range(label_map.shape[1]))
            ax.set_yticks(range(label_map.shape[0]))
            ax.grid(True, linestyle='-', linewidth=0.5)
            ax.set_title(f"Label Heatmap for Cluster {cluster_idx}")
        for ax in axes.flat[n_clusters:]:
            ax.axis('off')

        cbar_ax = fig.add_axes([0.92, 0.15, 0.02, 0.7])
        cbar = fig.colorbar(im, cax=cbar_ax, ticks=range(n_labels))
        cbar.ax.set_yticklabels(self.label_encodings.classes_)
        fig.subplots_adjust(wspace=0.5, hspace=0.5, right=0.9)
        plt.show()

    def _draw_activation_frame(self, sample, prediction):
        """Draw one figure with a distance map per SOM; return ``(fig, active_image)``.

        ``active_image`` is the image of the SOM named by ``prediction[0]``, or ``None``
        when no SOM matches (e.g. the sample was predicted as noise).
        """
        sample = np.asarray(sample)
        fig, axes = plt.subplots(1, len(self.som_models), figsize=(20, 5), sharex=True, sharey=True,
                                 squeeze=False)
        fig.suptitle(f"Activation map for SOM {prediction[0]}, node {prediction[1]}", fontsize=16)
        active_image = None
        for ax, (som_key, som) in zip(axes[0], self.som_models.items()):
            activation_map = np.linalg.norm(som.get_weights() - sample, axis=-1)
            winner = som.winner(sample)
            # Force the BMU to the minimum of the colour scale.
            activation_map[winner] = 0
            if som_key == prediction[0]:
                active_image = ax.imshow(activation_map, cmap='viridis', origin='lower', interpolation='none')
                ax.plot(winner[1], winner[0], 'r+')  # Mark the BMU with a red plus sign
                ax.set_title(f"SOM {som_key}", color='blue', fontweight='bold')
                if self._has_labels():
                    label, _ = self._label_for_cluster(som_key)
                    ax.set_xlabel(f"Label: {label}", fontsize=12)
            else:
                ax.imshow(activation_map, cmap='gray', origin='lower', interpolation='none')
                ax.set_title(f"SOM {som_key}")
            ax.set_xticks(range(activation_map.shape[1]))
            ax.set_yticks(range(activation_map.shape[0]))
            ax.grid(True, linestyle='-', linewidth=0.5)
        return fig, active_image

    def plot_activation(self, data, filename='prediction_output', start=None, end=None):
        """Write ``{filename}.gif`` and ``{filename}.mp4`` of per-sample SOM activation maps.

        One frame is produced for each sample in ``data[start:end]``.
        """
        if len(self.som_models) == 0:
            raise ValueError("SOM models not trained yet.")
        start = 0 if start is None else start
        end = len(data) if end is None else end

        images = []
        for sample in tqdm(data[start:end], desc="Visualizing prediction output"):
            prediction = self.predict([sample])[0]
            fig, active_image = self._draw_activation_frame(sample, prediction)
            fig.subplots_adjust(right=0.8)
            cbar_ax = fig.add_axes([0.85, 0.15, 0.05, 0.7])
            # Only this frame's active SOM may drive its colorbar.
            if active_image is not None:
                fig.colorbar(active_image, cax=cbar_ax)
            buf = io.BytesIO()
            fig.savefig(buf, format='png')
            plt.close(fig)
            buf.seek(0)
            images.append(imageio.imread(buf))
        if not images:
            raise ValueError("No samples selected between start and end.")

        gif_file = f"{filename}.gif"
        imageio.mimsave(gif_file, images, duration=500, loop=1)
        mp4_file = f"{filename}.mp4"
        clip = VideoFileClip(gif_file)
        try:
            clip.write_videofile(mp4_file, codec='libx264')
        finally:
            clip.close()

    def save(self, file_path):
        """Pickle the HDBSCAN model, SOMs, per-node statistics, mapping and (if any) labels."""
        model_data = (self.hdbscan_model, self.som_models, self.mean_values, self.sigma_values, self.cluster_mapping)
        if self._has_labels():
            model_data += (self.label_centroids, self.label_encodings)
        with open(file_path, "wb") as f:
            pickle.dump(model_data, f)

    def load(self, file_path):
        """Load a model written by :meth:`save`.

        Only load files you trust: unpickling can execute arbitrary code.
        """
        with open(file_path, "rb") as f:
            model_data = pickle.load(f)
        (self.hdbscan_model, self.som_models, self.mean_values,
         self.sigma_values, self.cluster_mapping) = model_data[:5]
        if len(model_data) > 5:
            self.label_centroids, self.label_encodings = model_data[5:]
        else:
            # Do not keep labels from a previous state alongside unrelated SOMs.
            self.__dict__.pop('label_centroids', None)
            self.__dict__.pop('label_encodings', None)
        # The training data and any embedding belong to the previous state.
        self.train_data = None
        self.embedding = None
        self.dim_red_op = None

    def plot_activation_v2(self, data, slice_select):
        """Return a figure of per-SOM activation maps for the 1-based sample ``slice_select``.

        If that index is past the end of ``data``, the preceding sample is used.
        """
        if len(self.som_models) == 0:
            raise ValueError("SOM models not trained yet.")
        index = int(slice_select) - 1
        try:
            sample = data[index]
        except IndexError:
            sample = data[index - 1]
        prediction = self.predict([sample])[0]
        fig, _ = self._draw_activation_frame(sample, prediction)
        fig.tight_layout()
        return fig