import io
import pickle
from collections import Counter

import numpy as np
import hdbscan
from minisom import MiniSom
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap
import phate
import imageio
from tqdm import tqdm
import plotly.graph_objs as go
import plotly.subplots as sp
import umap
from sklearn.datasets import make_blobs
from sklearn.preprocessing import LabelEncoder
from sklearn.cluster import KMeans
from sklearn.semi_supervised import LabelSpreading
from moviepy.editor import *  # provides VideoFileClip (kept last, as in the original file)


class ClusterSOM:
    """Cluster a dataset with HDBSCAN and fit one MiniSom self-organizing map per cluster.

    After ``train``:
      * ``som_models[k]`` is the SOM fitted on the k-th most populated non-noise HDBSCAN
        cluster (keys start at 1),
      * ``cluster_mapping[k]`` is that cluster's original HDBSCAN label,
      * ``mean_values[k]`` / ``sigma_values[k]`` hold, per SOM node, the mean / standard
        deviation of the distances between the node's weight vector and the training samples
        mapped to it (0 for nodes that received no samples).

    ``train_label`` optionally adds one centroid per label (``label_centroids``) together with
    the fitted ``LabelEncoder`` (``label_encodings``); ``predict`` and the plotting helpers use
    them when present.
    """

    def __init__(self):
        self.hdbscan_model = None
        self.som_models = {}
        self.sigma_values = {}
        self.mean_values = {}
        self.cluster_mapping = {}
        # Cached 3-D embedding and the fitted reducer (UMAP or PHATE) used by plot_embedding.
        self.embedding = None
        self.dim_red_op = None

    def train(self, dataset, min_samples_per_cluster=100, n_clusters=None, som_size=(20, 20),
              sigma=1.0, learning_rate=0.5, num_iteration=200000, random_seed=42, n_neighbors=5,
              coverage=0.95):
        """Train HDBSCAN and one SOM per retained cluster on ``dataset``.

        Args:
            dataset: 2-D array-like, one row per sample.
            min_samples_per_cluster: HDBSCAN ``min_cluster_size``.
            n_clusters: number of largest non-noise clusters to fit SOMs for. When None, it is
                the smallest number of most-populated HDBSCAN labels (noise included in the
                count) whose points reach ``coverage`` of the dataset.
            som_size: (rows, cols) of every SOM grid.
            sigma, learning_rate, random_seed: forwarded to ``MiniSom``.
            num_iteration: iterations passed to ``MiniSom.train_random``.
            n_neighbors: forwarded to ``compute_sigma_values`` (which does not use it).
            coverage: fraction of points to cover when ``n_clusters`` is None.

        Models from a previous call, and the cached plot embedding, are discarded.
        """
        dataset = np.asarray(dataset)

        # Start from a clean state so SOMs from an earlier, larger run cannot linger.
        self.som_models = {}
        self.sigma_values = {}
        self.mean_values = {}
        self.cluster_mapping = {}
        self.embedding = None
        self.dim_red_op = None

        print('Identifying clusters in the embedding ...')
        self.hdbscan_model = hdbscan.HDBSCAN(min_cluster_size=min_samples_per_cluster)
        self.hdbscan_model.fit(dataset)
        point_labels = self.hdbscan_model.labels_

        # (label, count) pairs, most populated first; -1 marks noise.
        ranked = Counter(point_labels).most_common()
        if n_clusters is None:
            n_clusters = self._clusters_for_coverage([count for _, count in ranked], coverage)

        # Drop noise *before* truncating, so a real cluster is never discarded in its place.
        cluster_labels = [label for label, _ in ranked if label != -1][:n_clusters]

        for key, label in enumerate(tqdm(cluster_labels, desc="Fitting 2D maps"), start=1):
            cluster_data = dataset[point_labels == label]
            som = MiniSom(som_size[0], som_size[1], dataset.shape[1], sigma=sigma,
                          learning_rate=learning_rate, random_seed=random_seed)
            som.train_random(cluster_data, num_iteration)
            self.som_models[key] = som
            self.cluster_mapping[key] = label

            mean_cluster, sigma_cluster = self.compute_sigma_values(
                cluster_data, som_size, som, n_neighbors=n_neighbors)
            self.sigma_values[key] = sigma_cluster
            self.mean_values[key] = mean_cluster

    @staticmethod
    def _clusters_for_coverage(counts, coverage):
        """Return how many leading entries of descending ``counts`` reach ``coverage`` of the total."""
        total = sum(counts)
        covered = 0
        for n_taken, count in enumerate(counts, start=1):
            covered += count
            if covered / total >= coverage:
                return n_taken
        return len(counts)

    def compute_sigma_values(self, cluster_data, som_size, som, n_neighbors=5):
        """Compute the per-node mean and standard deviation of sample-to-node distances.

        Each sample in ``cluster_data`` is assigned to its best-matching unit (``som.winner``).
        For every node, the Euclidean distances between its assigned samples and its weight
        vector are summarised. Nodes without samples get 0 for both statistics.

        ``n_neighbors`` is accepted for API compatibility but is not used.

        Returns:
            (mean_cluster, sigma_cluster): two arrays of shape ``som_size``.
        """
        som_weights = som.get_weights()

        members = {}
        for sample in cluster_data:
            members.setdefault(som.winner(sample), []).append(sample)

        mean_cluster = np.zeros(som_size)
        sigma_cluster = np.zeros(som_size)
        for node, node_samples in members.items():
            distances = np.linalg.norm(np.asarray(node_samples) - som_weights[node], axis=-1)
            mean_cluster[node] = distances.mean()
            sigma_cluster[node] = distances.std()
        return mean_cluster, sigma_cluster

    def train_label(self, labeled_data, labels):
        """Compute one centroid per distinct label.

        ``label_centroids[i]`` is the mean of the rows of ``labeled_data`` whose label encodes
        to ``i`` under the fitted ``LabelEncoder`` (stored as ``label_encodings``). Centroid rows
        and ``label_encodings.classes_`` therefore share the same order.

        Raises:
            ValueError: if ``labels`` is empty or its length differs from ``labeled_data``.
        """
        labeled_data = np.asarray(labeled_data)
        if len(labels) == 0:
            raise ValueError("No labels provided.")
        if len(labels) != len(labeled_data):
            raise ValueError("labeled_data and labels must have the same length.")

        le = LabelEncoder()
        encoded_labels = le.fit_transform(labels)
        # Centroids must be indexed by encoded label; KMeans cluster ids carry no such order.
        self.label_centroids = np.vstack([
            labeled_data[encoded_labels == code].mean(axis=0)
            for code in range(len(le.classes_))
        ])
        self.label_encodings = le

    def predict(self, data, sigma_factor=1.5):
        """Assign each sample to its closest SOM node across all cluster SOMs.

        For every sample, the best-matching unit of each SOM is found. The SOM whose BMU weight
        vector is closest (Euclidean) wins. The sample is accepted when that distance is at most
        ``sigma_factor`` times the node's mean training distance (``mean_values``); otherwise it
        is reported as noise.

        Returns:
            A list with one tuple per sample:
              * ``(cluster_key, (x, y))`` when accepted and no labels are trained,
              * ``(cluster_key, (x, y), label, label_distance)`` when labels are trained, where
                ``label`` is the original label of the nearest label centroid and
                ``label_distance`` is the Euclidean distance to it,
              * ``(-1, None)`` for noise.

        Raises:
            ValueError: if no SOM has been trained.
        """
        if len(self.som_models) == 0:
            raise ValueError("SOM models not trained yet.")
        has_labels = hasattr(self, 'label_centroids')

        results = []
        for sample in data:
            sample = np.asarray(sample)
            best_key, best_node, best_distance = None, None, float('inf')
            for key, som in self.som_models.items():
                node = som.winner(sample)
                distance = np.linalg.norm(sample - som.get_weights()[node])
                if distance < best_distance:
                    best_key, best_node, best_distance = key, node, distance

            # sigma_factor scales the node's mean distance to its own training samples.
            accepted = (best_key is not None and
                        best_distance <= self.mean_values[best_key][best_node] * sigma_factor)
            if not accepted:
                results.append((-1, None))  # Noise
                continue

            if has_labels:
                centroid_distances = np.linalg.norm(self.label_centroids - sample, axis=1)
                nearest_label = int(np.argmin(centroid_distances))
                label = self.label_encodings.inverse_transform([nearest_label])[0]
                results.append((best_key, best_node, label, centroid_distances[nearest_label]))
            else:
                results.append((best_key, best_node))
        return results

    @staticmethod
    def _som_grid_lines(som_embedding):
        """Return the polylines of a (rows, cols, d) node grid: every row, then every column."""
        row_lines = [som_embedding[r, :, :] for r in range(som_embedding.shape[0])]
        col_lines = [som_embedding[:, c, :] for c in range(som_embedding.shape[1])]
        return row_lines + col_lines

    @staticmethod
    def _rgba_to_plotly(rgba):
        """Convert a matplotlib RGBA sequence with components in [0, 1] to a plotly colour string."""
        red, green, blue, alpha = (float(c) for c in rgba)
        return (f"rgba({int(round(red * 255))}, {int(round(green * 255))}, "
                f"{int(round(blue * 255))}, {alpha})")

    def plot_embedding(self, new_data=None, dim_reduction='umap', interactive=False):
        """Plot data points and every cluster's SOM grid in a 3-D UMAP/PHATE embedding.

        The reducer is fitted on ``new_data`` the first time this method is called, and the result
        is cached in ``embedding`` / ``dim_red_op``. Later calls reuse it and, if ``new_data`` is
        given, project it with ``transform``. Points are coloured by the HDBSCAN training labels,
        so the plotted data must have exactly one row per training sample (presumably the
        training dataset itself).

        Args:
            new_data: data to embed and plot; required on the first call.
            dim_reduction: 'umap' or 'phate'; only used when the reducer is fitted.
            interactive: plot with plotly instead of matplotlib.

        Raises:
            ValueError: if models are not trained, ``dim_reduction`` is invalid, no data is
                available to fit the reducer, or the number of points differs from the number
                of HDBSCAN training labels.
        """
        if self.hdbscan_model is None:
            raise ValueError("HDBSCAN model not trained yet.")
        if len(self.som_models) == 0:
            raise ValueError("SOM models not trained yet.")
        if dim_reduction not in ['phate', 'umap']:
            raise ValueError("Invalid dimensionality reduction method. Use 'phate' or 'umap'.")

        n_components = 3
        if self.dim_red_op is None or self.embedding is None:
            if new_data is None:
                raise ValueError("new_data is required to fit the dimensionality reduction "
                                 "on the first call to plot_embedding.")
            if dim_reduction == 'phate':
                self.dim_red_op = phate.PHATE(n_components=n_components, random_state=42)
            else:
                self.dim_red_op = umap.UMAP(n_components=n_components, random_state=42)
            self.embedding = self.dim_red_op.fit_transform(new_data)
            points = self.embedding
        elif new_data is not None:
            points = self.dim_red_op.transform(new_data)
        else:
            points = self.embedding
        points = np.asarray(points)

        point_labels = self.hdbscan_model.labels_
        if len(points) != len(point_labels):
            raise ValueError(
                f"Plotted data has {len(points)} rows but HDBSCAN was trained on "
                f"{len(point_labels)}; points are coloured by the training labels, so the "
                "sizes must match.")

        colors = plt.cm.rainbow(np.linspace(0, 1, len(self.som_models) + 1))

        if interactive:
            fig = sp.make_subplots(rows=1, cols=1, specs=[[{'type': 'scatter3d'}]])
        else:
            fig = plt.figure(figsize=(30, 30))
            ax = fig.add_subplot(111, projection='3d')

        for reindexed_label, som in self.som_models.items():
            color = colors[reindexed_label]
            cluster_points = points[point_labels == self.cluster_mapping[reindexed_label]]

            # Project the SOM weight vectors into the same embedding as the data.
            som_weights = som.get_weights()
            grid_rows, grid_cols, n_features = som_weights.shape
            som_embedding = self.dim_red_op.transform(
                som_weights.reshape(-1, n_features)).reshape(grid_rows, grid_cols, -1)
            node_xyz = som_embedding.reshape(-1, som_embedding.shape[-1])
            grid_lines = self._som_grid_lines(som_embedding)

            if interactive:
                plotly_color = self._rgba_to_plotly(color)
                fig.add_trace(go.Scatter3d(
                    x=cluster_points[:, 0], y=cluster_points[:, 1], z=cluster_points[:, 2],
                    mode='markers', marker=dict(color=plotly_color, size=1),
                    name=f"Cluster {reindexed_label}"))
                # One trace for all nodes, each labelled with its grid coordinate.
                fig.add_trace(go.Scatter3d(
                    x=node_xyz[:, 0], y=node_xyz[:, 1], z=node_xyz[:, 2],
                    mode='markers+text',
                    marker=dict(color=plotly_color, size=3, symbol='circle'),
                    text=[f"{x},{y}" for x, y in np.ndindex(grid_rows, grid_cols)],
                    textposition="top center", showlegend=False))
                # One trace for all grid edges; None breaks the line between polylines.
                line_x, line_y, line_z = [], [], []
                for line in grid_lines:
                    line_x.extend(line[:, 0].tolist() + [None])
                    line_y.extend(line[:, 1].tolist() + [None])
                    line_z.extend(line[:, 2].tolist() + [None])
                fig.add_trace(go.Scatter3d(
                    x=line_x, y=line_y, z=line_z, mode='lines',
                    line=dict(color=plotly_color, width=2), showlegend=False))
            else:
                ax.scatter(cluster_points[:, 0], cluster_points[:, 1], cluster_points[:, 2],
                           c=[color], alpha=0.3, s=5, label=f"Cluster {reindexed_label}")
                ax.plot(node_xyz[:, 0], node_xyz[:, 1], node_xyz[:, 2], '+',
                        markersize=8, mew=2, zorder=10, c=color)
                for line in grid_lines:
                    ax.plot(line[:, 0], line[:, 1], line[:, 2], lw=1, c=color)

        noise_points = points[point_labels == -1]
        if interactive:
            if len(noise_points) > 0:
                fig.add_trace(go.Scatter3d(
                    x=noise_points[:, 0], y=noise_points[:, 1], z=noise_points[:, 2],
                    mode='markers', marker=dict(color="gray", size=1), name="Noise"))
            fig.update_layout(scene=dict(xaxis_title='X', yaxis_title='Y', zaxis_title='Z'))
            fig.show()
        else:
            if len(noise_points) > 0:
                ax.scatter(noise_points[:, 0], noise_points[:, 1], noise_points[:, 2],
                           c="gray", label="Noise")
            ax.legend()
            plt.show()

    def plot_label_heatmap(self):
        """Show, for each cluster SOM, which label centroid is closest to every node.

        Each node is coloured by the index of the nearest ``label_centroids`` row (Euclidean
        distance; ties go to the lowest index). A shared colorbar names the labels via
        ``label_encodings.classes_``.

        Raises:
            ValueError: if labels or SOMs have not been trained.
        """
        if not hasattr(self, 'label_centroids'):
            raise ValueError("Labels not trained yet.")
        if len(self.som_models) == 0:
            raise ValueError("SOM models not trained yet.")

        centroids = np.asarray(self.label_centroids)
        n_labels = len(centroids)
        # Discrete colormap with one colour per label (a copy, so no global state is mutated).
        label_cmap = ListedColormap(plt.cm.rainbow(np.linspace(0, 1, n_labels)))
        n_clusters = len(self.som_models)

        # Near-square grid of subplots, one per cluster.
        n_rows = int(np.ceil(np.sqrt(n_clusters)))
        n_cols = n_rows if n_rows * (n_rows - 1) < n_clusters else n_rows - 1
        fig, axes = plt.subplots(n_rows, n_cols, figsize=(n_cols * 10, n_rows * 10),
                                 squeeze=False)

        im = None
        for i, (reindexed_label, som) in enumerate(self.som_models.items()):
            som_weights = som.get_weights()
            # (rows, cols, n_labels) distances from every node to every centroid.
            distances = np.linalg.norm(
                som_weights[:, :, np.newaxis, :] - centroids[np.newaxis, np.newaxis, :, :],
                axis=-1)
            label_map = np.argmin(distances, axis=-1)

            ax = axes[i // n_cols, i % n_cols]
            # Bin edges at half-integers so label k maps exactly onto colour k.
            im = ax.imshow(label_map, cmap=label_cmap, origin='lower', interpolation='none',
                           vmin=-0.5, vmax=n_labels - 0.5)
            ax.set_xticks(range(label_map.shape[1]))
            ax.set_yticks(range(label_map.shape[0]))
            ax.grid(True, linestyle='-', linewidth=0.5)
            ax.set_title(f"Label Heatmap for Cluster {reindexed_label}")

        # Hide grid cells that received no cluster.
        for j in range(n_clusters, n_rows * n_cols):
            axes[j // n_cols, j % n_cols].set_visible(False)

        cbar_ax = fig.add_axes([0.92, 0.15, 0.02, 0.7])
        cbar = fig.colorbar(im, cax=cbar_ax, ticks=range(n_labels))
        cbar.ax.set_yticklabels(self.label_encodings.classes_)

        fig.subplots_adjust(wspace=0.5, hspace=0.5, right=0.9)
        plt.show()

    def _draw_activation_maps(self, axes, sample, prediction):
        """Draw one distance map per SOM for ``sample`` onto ``axes`` (one Axes per SOM).

        Each map holds the Euclidean distance from ``sample`` to every node, with the SOM's own
        BMU set to 0. The SOM named in ``prediction[0]`` is drawn in viridis with its BMU marked;
        the others are drawn in grey.

        Returns:
            The image of the active SOM, or None when no SOM matches ``prediction[0]`` (noise).
        """
        active_image = None
        for ax, (som_key, som) in zip(axes, self.som_models.items()):
            weights = som.get_weights()
            activation_map = np.linalg.norm(weights - sample, axis=-1)
            winner = som.winner(sample)
            # Zero the BMU so it takes the colormap's minimum colour.
            activation_map[winner] = 0

            if som_key == prediction[0]:
                active_image = ax.imshow(activation_map, cmap='viridis', origin='lower',
                                         interpolation='none')
                ax.plot(winner[1], winner[0], 'r+')  # Mark the BMU with a red plus sign.
                ax.set_title(f"SOM {som_key}", color='blue', fontweight='bold')
                if len(prediction) > 2:
                    ax.set_xlabel(f"Label: {prediction[2]}", fontsize=12)
            else:
                ax.imshow(activation_map, cmap='gray', origin='lower', interpolation='none')
                ax.set_title(f"SOM {som_key}")

            ax.set_xticks(range(activation_map.shape[1]))
            ax.set_yticks(range(activation_map.shape[0]))
            ax.grid(True, linestyle='-', linewidth=0.5)
        return active_image

    def plot_activation(self, data, filename='prediction_output', start=None, end=None):
        """Render activation maps for ``data[start:end]`` as an animated GIF and an MP4.

        One frame is rendered per sample. Frames are written to ``{filename}.gif``, which is
        then converted to ``{filename}.mp4`` with the libx264 codec.

        Raises:
            ValueError: if no SOM is trained or the selected range is empty.
        """
        if len(self.som_models) == 0:
            raise ValueError("SOM models not trained yet.")
        start = 0 if start is None else start
        end = len(data) if end is None else end
        samples = data[start:end]
        if len(samples) == 0:
            raise ValueError("No samples in the selected range to visualize.")

        images = []
        for sample in tqdm(samples, desc="Visualizing prediction output"):
            prediction = self.predict([sample])[0]

            fig, axes = plt.subplots(1, len(self.som_models), figsize=(20, 5),
                                     sharex=True, sharey=True, squeeze=False)
            fig.suptitle(f"Activation map for SOM {prediction[0]}, node {prediction[1]}",
                         fontsize=16)
            active_image = self._draw_activation_maps(axes[0], sample, prediction)

            # Colorbar for this frame's active SOM; the axes are always reserved so every frame
            # has the same layout.
            fig.subplots_adjust(right=0.8)
            cbar_ax = fig.add_axes([0.85, 0.15, 0.05, 0.7])
            if active_image is not None:
                fig.colorbar(active_image, cax=cbar_ax)
            else:
                cbar_ax.set_axis_off()

            buf = io.BytesIO()
            fig.savefig(buf, format='png')
            plt.close(fig)
            buf.seek(0)
            images.append(imageio.imread(buf))

        gif_file = f"{filename}.gif"
        imageio.mimsave(gif_file, images, duration=500, loop=1)

        mp4_file = f"{filename}.mp4"
        clip = VideoFileClip(gif_file)
        try:
            clip.write_videofile(mp4_file, codec='libx264')
        finally:
            clip.close()  # Release the reader even if encoding fails.

    def save(self, file_path):
        """Pickle the trained models (plus label centroids/encoder, if trained) to ``file_path``.

        The cached plot embedding and reducer are not saved.
        """
        model_data = (self.hdbscan_model, self.som_models, self.mean_values, self.sigma_values,
                      self.cluster_mapping)
        if hasattr(self, 'label_centroids'):
            model_data += (self.label_centroids, self.label_encodings)
        with open(file_path, "wb") as f:
            pickle.dump(model_data, f)

    def load(self, file_path):
        """Load a model written by ``save``.

        SECURITY: ``pickle.load`` can execute arbitrary code; only load files from trusted
        sources.

        Label centroids from a previously trained or loaded model are dropped when the file has
        none. The cached plot embedding is cleared because it belonged to the previous model.
        """
        with open(file_path, "rb") as f:
            model_data = pickle.load(f)
        (self.hdbscan_model, self.som_models, self.mean_values, self.sigma_values,
         self.cluster_mapping) = model_data[:5]
        if len(model_data) > 5:
            self.label_centroids, self.label_encodings = model_data[5:]
        else:
            for attr in ('label_centroids', 'label_encodings'):
                if hasattr(self, attr):
                    delattr(self, attr)
        self.embedding = None
        self.dim_red_op = None

    def plot_activation_v2(self, data, slice_select):
        """Return a figure with the activation maps for one selected sample.

        ``slice_select`` is a 1-based position: ``data[int(slice_select) - 1]`` is shown. If that
        index does not exist, the preceding sample is used instead. The same sample is used for
        both the prediction and the maps.

        Raises:
            ValueError: if no SOM is trained.
        """
        if len(self.som_models) == 0:
            raise ValueError("SOM models not trained yet.")

        index = int(slice_select) - 1
        try:
            sample = data[index]
        except (IndexError, KeyError):
            sample = data[index - 1]
        prediction = self.predict([sample])[0]

        fig, axes = plt.subplots(1, len(self.som_models), figsize=(20, 5),
                                 sharex=True, sharey=True, squeeze=False)
        fig.suptitle(f"Activation map for SOM {prediction[0]}, node {prediction[1]}",
                     fontsize=16)
        self._draw_activation_maps(axes[0], sample, prediction)
        fig.tight_layout()
        return fig