import io
import math
import pickle
from collections import Counter

import hdbscan
import imageio
import matplotlib.pyplot as plt
import numpy as np
from minisom import MiniSom
from moviepy.editor import ImageSequenceClip
from sklearn.cluster import KMeans  # noqa: F401  (import kept; no longer used by this class)
from sklearn.preprocessing import LabelEncoder
from sklearn.semi_supervised import LabelSpreading  # noqa: F401  (import kept; no longer used by this class)
from tqdm import tqdm


class ClusterSOM:
    """Cluster a dataset with HDBSCAN and fit one MiniSom map per cluster.

    After ``train`` the instance holds, keyed by a 1-based cluster index:

    * ``som_models``      -- the trained ``MiniSom`` for that cluster
    * ``cluster_mapping`` -- the original HDBSCAN label of that cluster
    * ``mean_values``     -- per-node mean distance of the assigned training
                             samples to the node weight
    * ``sigma_values``    -- per-node standard deviation of those distances

    ``train_label`` optionally adds ``label_centroids`` and
    ``label_encodings``; their presence is detected with ``hasattr``.
    """

    def __init__(self):
        self.hdbscan_model = None
        self.som_models = {}
        self.sigma_values = {}
        self.mean_values = {}
        self.cluster_mapping = {}
        # Not read or written by any method of this class.
        self.embedding = None
        self.dim_red_op = None

    def train(self, dataset, min_samples_per_cluster=100, n_clusters=None,
              som_size=(20, 20), sigma=1.0, learning_rate=0.5,
              num_iteration=200000, random_seed=42, n_neighbors=5,
              coverage=0.95):
        """
        Train HDBSCAN and SOM models on the given dataset.

        Args:
            dataset: 2-D array-like of samples (rows) by features (columns).
            min_samples_per_cluster: passed to HDBSCAN as ``min_cluster_size``.
            n_clusters: number of largest clusters to fit a SOM for. When
                ``None`` it is the smallest number of most-common HDBSCAN
                labels (noise included in the count) whose points reach
                ``coverage`` of the dataset.
            som_size: (rows, cols) of every SOM grid.
            sigma, learning_rate, random_seed: forwarded to ``MiniSom``.
            num_iteration: forwarded to ``MiniSom.train_random``.
            n_neighbors: forwarded to ``compute_sigma_values``.
            coverage: fraction of points used to derive ``n_clusters``.
        """
        dataset = np.asarray(dataset)

        # Train HDBSCAN model
        print('Identifying clusters in the embedding ...')
        self.hdbscan_model = hdbscan.HDBSCAN(min_cluster_size=min_samples_per_cluster)
        self.hdbscan_model.fit(dataset)
        point_labels = self.hdbscan_model.labels_
        ranked = Counter(point_labels).most_common()

        # Calculate n_clusters if not provided
        if n_clusters is None:
            total_points = sum(count for _, count in ranked)
            covered_points = 0
            n_clusters = 0
            for _, count in ranked:
                covered_points += count
                n_clusters += 1
                if covered_points / total_points >= coverage:
                    break

        # Keep the n_clusters most common real clusters. Noise (-1) is
        # filtered out explicitly so a real cluster is never dropped in its
        # place when there are fewer labels than n_clusters + 1.
        cluster_labels = [label for label, _ in ranked if label != -1][:n_clusters]

        # Start from a clean state so a second call does not keep stale maps.
        self.som_models = {}
        self.cluster_mapping = {}
        self.sigma_values = {}
        self.mean_values = {}

        for key, label in tqdm(enumerate(cluster_labels, start=1),
                               total=len(cluster_labels),
                               desc="Fitting 2D maps"):
            cluster_data = dataset[point_labels == label]
            som = MiniSom(som_size[0], som_size[1], dataset.shape[1],
                          sigma=sigma, learning_rate=learning_rate,
                          random_seed=random_seed)
            som.train_random(cluster_data, num_iteration)
            self.som_models[key] = som
            self.cluster_mapping[key] = label

            # Per-node distance statistics, used by predict() as a threshold.
            mean_cluster, sigma_cluster = self.compute_sigma_values(
                cluster_data, som_size, som, n_neighbors=n_neighbors)
            self.sigma_values[key] = sigma_cluster
            self.mean_values[key] = mean_cluster

    def compute_sigma_values(self, cluster_data, som_size, som, n_neighbors=5):
        """
        Compute per-node distance statistics of the samples won by each node.

        Args:
            cluster_data: iterable of samples belonging to one cluster.
            som_size: (rows, cols) of the SOM grid; shape of both outputs.
            som: trained SOM exposing ``winner`` and ``get_weights``.
            n_neighbors: accepted for interface compatibility; not used.

        Returns:
            ``(mean_cluster, sigma_cluster)``: two arrays of shape
            ``som_size`` holding, for every node, the mean and the standard
            deviation of the Euclidean distance between the node weight and
            the samples whose best-matching unit is that node. Nodes that
            win no sample keep 0 in both arrays.
        """
        som_weights = som.get_weights()

        # Assign each datapoint to its nearest node
        partitions = {}
        for sample in cluster_data:
            partitions.setdefault(tuple(som.winner(sample)), []).append(sample)

        mean_cluster = np.zeros(som_size)
        sigma_cluster = np.zeros(som_size)
        for idx, members in partitions.items():
            distances = np.linalg.norm(np.array(members) - som_weights[idx], axis=-1)
            mean_cluster[idx] = np.mean(distances)
            sigma_cluster[idx] = np.std(distances)

        return mean_cluster, sigma_cluster

    def train_label(self, labeled_data, labels):
        """
        Compute one centroid per label from labeled data.

        Stores ``label_centroids`` (row ``k`` is the mean of the samples
        whose encoded label is ``k``) and ``label_encodings`` (the fitted
        ``LabelEncoder``). Row ``k`` therefore corresponds to
        ``label_encodings.classes_[k]``, which is what ``predict`` and
        ``plot_label_heatmap`` assume.

        Args:
            labeled_data: 2-D array-like of samples.
            labels: one label per sample, any type ``LabelEncoder`` accepts.
        """
        labeled_data = np.asarray(labeled_data)

        le = LabelEncoder()
        encoded_labels = le.fit_transform(labels)

        # Every class has at least one sample because the classes come from
        # these very labels, so no mean is taken over an empty selection.
        centroids = np.vstack([
            labeled_data[encoded_labels == encoded].mean(axis=0)
            for encoded in range(len(le.classes_))
        ])

        # Store the label centroids and label encodings
        self.label_centroids = centroids
        self.label_encodings = le

    def predict(self, data, sigma_factor=1.5):
        """
        Predict the cluster and BMU SOM coordinate for each sample.

        For every sample the best-matching unit (BMU) of every SOM is
        compared and the closest one wins. The sample is accepted when its
        distance to that node is at most ``sigma_factor`` times the node's
        mean training distance; otherwise it is reported as noise. Nodes
        that won no training sample have a mean of 0, so they only accept
        exact matches.

        Args:
            data: iterable of samples.
            sigma_factor: multiplier applied to the node's mean distance.

        Returns:
            A list with one tuple per sample:
            ``(cluster_idx, (x, y))`` when no labels are trained,
            ``(cluster_idx, (x, y), label, label_distance)`` when they are,
            or ``(-1, None)`` for noise.

        Raises:
            ValueError: if no SOM has been trained.
        """
        if not self.som_models:
            raise ValueError("SOM models not trained yet.")

        has_labels = hasattr(self, 'label_centroids')
        results = []

        for sample in data:
            sample = np.asarray(sample)
            min_distance = float('inf')
            nearest_cluster_idx = None
            nearest_node = None

            for cluster_idx, som in self.som_models.items():
                x, y = som.winner(sample)
                distance = np.linalg.norm(sample - som.get_weights()[x, y])
                if distance < min_distance:
                    min_distance = distance
                    nearest_cluster_idx = cluster_idx
                    nearest_node = (x, y)

            threshold = self.mean_values[nearest_cluster_idx][nearest_node] * sigma_factor
            if min_distance > threshold:
                results.append((-1, None))  # Noise
                continue

            if not has_labels:
                results.append((nearest_cluster_idx, nearest_node))
                continue

            # NOTE(review): cluster k is mapped positionally to encoded
            # label k - 1; confirm that this pairing is intended.
            encoded = nearest_cluster_idx - 1
            label = self.label_encodings.inverse_transform([encoded])[0]
            # Index the centroids with the encoded position, not with the
            # decoded label, which may be a string or a non-contiguous int.
            label_distance = np.linalg.norm(sample - self.label_centroids[encoded])
            results.append((nearest_cluster_idx, nearest_node, label, label_distance))

        return results

    def plot_label_heatmap(self):
        """
        Plot, for each SOM, which label centroid is closest to every node.

        One subplot per SOM is laid out on a near-square grid and a shared
        colorbar maps colors to ``label_encodings.classes_``. The figure is
        shown with ``plt.show()``; nothing is returned.

        Raises:
            ValueError: if labels or SOM models have not been trained.
        """
        if not hasattr(self, 'label_centroids'):
            raise ValueError("Labels not trained yet.")
        if not self.som_models:
            raise ValueError("SOM models not trained yet.")

        centroids = np.asarray(self.label_centroids)
        n_labels = len(centroids)
        n_clusters = len(self.som_models)

        # Create a subplot layout with a heatmap for each main cluster
        n_rows = int(np.ceil(np.sqrt(n_clusters)))
        n_cols = n_rows if n_rows * (n_rows - 1) < n_clusters else n_rows - 1
        fig, axes = plt.subplots(n_rows, n_cols,
                                 figsize=(n_cols * 10, n_rows * 10),
                                 squeeze=False)
        flat_axes = axes.ravel()

        im = None
        for ax, (reindexed_label, som) in zip(flat_axes, self.som_models.items()):
            som_weights = som.get_weights()
            # (rows, cols, n_labels) distances; argmin keeps the lowest
            # label index on ties.
            distances = np.linalg.norm(
                som_weights[:, :, np.newaxis, :] - centroids, axis=-1)
            label_map = distances.argmin(axis=-1)

            # Center each integer label in its own color bin so label 0 is
            # drawn like any other label and its colorbar tick is in range.
            im = ax.imshow(label_map, cmap=plt.cm.rainbow, origin='lower',
                           interpolation='none',
                           vmin=-0.5, vmax=n_labels - 0.5)
            ax.set_xticks(range(label_map.shape[1]))
            ax.set_yticks(range(label_map.shape[0]))
            ax.grid(True, linestyle='-', linewidth=0.5)
            ax.set_title(f"Label Heatmap for Cluster {reindexed_label}")

        # Grid cells beyond the number of SOMs stay blank.
        for ax in flat_axes[n_clusters:]:
            ax.axis('off')

        # Add a colorbar for label colors
        cbar_ax = fig.add_axes([0.92, 0.15, 0.02, 0.7])
        cbar = fig.colorbar(im, cax=cbar_ax, ticks=list(range(n_labels)))
        cbar.ax.set_yticklabels(self.label_encodings.classes_)

        # Adjust the layout to fit everything nicely
        fig.subplots_adjust(wspace=0.5, hspace=0.5, right=0.9)
        plt.show()

    def rearrange_subplots(self, num_subplots):
        """
        Create a near-square grid of subplots with shared axes.

        Args:
            num_subplots: number of subplots that will be drawn on.

        Returns:
            ``(fig, axes)`` where ``axes`` is always a flat 1-D array of at
            least ``num_subplots`` axes; surplus axes are switched off.

        Raises:
            ValueError: if ``num_subplots`` is smaller than 1.
        """
        if num_subplots < 1:
            raise ValueError("At least one subplot is required.")

        num_rows = math.isqrt(num_subplots)
        num_cols = math.ceil(num_subplots / num_rows)

        # squeeze=False keeps a 2-D array even for a single subplot, so the
        # result can always be flattened and indexed.
        fig, axes = plt.subplots(num_rows, num_cols, figsize=(20, 5),
                                 sharex=True, sharey=True, squeeze=False)
        axes = axes.flatten()

        # Hide any empty subplots
        for ax in axes[num_subplots:]:
            ax.axis('off')

        return fig, axes

    def _draw_activation_maps(self, axes, sample, prediction):
        """Draw the distance map of ``sample`` on every SOM, one axis each.

        The SOM whose key equals ``prediction[0]`` is highlighted (viridis
        colormap, red plus on its BMU, bold blue title); the others are
        drawn in gray.
        """
        sample = np.asarray(sample)
        for ax, (som_key, som) in zip(axes, self.som_models.items()):
            weights = som.get_weights()
            activation_map = np.linalg.norm(weights - sample, axis=-1)
            winner = som.winner(sample)  # Find the BMU for this SOM
            # Force the BMU to the lowest value of the map.
            activation_map[winner] = 0

            if som_key == prediction[0]:  # Active SOM
                ax.imshow(activation_map, cmap='viridis', origin='lower',
                          interpolation='none')
                ax.plot(winner[1], winner[0], 'r+')  # Mark the BMU with a red plus sign
                ax.set_title(f"SOM {som_key}", color='blue', fontweight='bold')
                if hasattr(self, 'label_centroids'):
                    label_idx = self.label_encodings.inverse_transform([som_key - 1])[0]
                    ax.set_xlabel(f"Label: {label_idx}", fontsize=12)
            else:  # Inactive SOM
                ax.imshow(activation_map, cmap='gray', origin='lower',
                          interpolation='none')
                ax.set_title(f"SOM {som_key}")

            ax.set_xticks(range(activation_map.shape[1]))
            ax.set_yticks(range(activation_map.shape[0]))
            ax.grid(True, linestyle='-', linewidth=0.5)

    def plot_activation(self, data, filename='prediction_output', start=None, end=None):
        """
        Render one activation-map frame per sample and return them as a clip.

        Args:
            data: sliceable sequence of samples.
            filename: accepted for interface compatibility; not used, nothing
                is written to disk.
            start, end: slice bounds into ``data``; ``None`` means the
                beginning / the end.

        Returns:
            A ``moviepy`` ``ImageSequenceClip`` at 1 frame per second.

        Raises:
            ValueError: if no SOM has been trained.
        """
        if not self.som_models:
            raise ValueError("SOM models not trained yet.")

        images = []
        for sample in tqdm(data[start:end], desc="Visualizing prediction output"):
            prediction = self.predict([sample])[0]

            fig, axes = self.rearrange_subplots(len(self.som_models))
            try:
                fig.suptitle(f"Activation map for SOM {prediction[0]}, node {prediction[1]}", fontsize=16)
                self._draw_activation_maps(axes, sample, prediction)
                fig.subplots_adjust(right=0.8)

                # Save the plot to a buffer
                buf = io.BytesIO()
                fig.savefig(buf, format='png')
                buf.seek(0)
                images.append(imageio.imread(buf))
            finally:
                # Always release the figure, even when drawing fails.
                plt.close(fig)

        return ImageSequenceClip(images, fps=1)

    def save(self, file_path):
        """
        Save the ClusterSOM model to a file.

        The model is pickled as a tuple of five items (HDBSCAN model, SOMs,
        mean values, sigma values, cluster mapping), followed by the label
        centroids and label encoder when labels are trained.
        """
        model_data = (self.hdbscan_model, self.som_models, self.mean_values,
                      self.sigma_values, self.cluster_mapping)
        if hasattr(self, 'label_centroids'):
            model_data += (self.label_centroids, self.label_encodings)

        with open(file_path, "wb") as f:
            pickle.dump(model_data, f)

    def load(self, file_path):
        """
        Load a ClusterSOM model from a file written by ``save``.

        SECURITY: this unpickles the file, which can execute arbitrary code.
        Only load files from a trusted source.
        """
        with open(file_path, "rb") as f:
            model_data = pickle.load(f)

        (self.hdbscan_model, self.som_models, self.mean_values,
         self.sigma_values, self.cluster_mapping) = model_data[:5]

        if len(model_data) > 5:
            self.label_centroids, self.label_encodings = model_data[5:]
        else:
            # Drop labels left over from a previous train_label()/load() so
            # they are not paired with the freshly loaded SOMs.
            for attr in ('label_centroids', 'label_encodings'):
                self.__dict__.pop(attr, None)

    def _plot_single_activation(self, data, slice_select):
        """Build the activation-map figure for one sample of ``data``."""
        if not self.som_models:
            raise ValueError("SOM models not trained yet.")

        # slice_select is 1-based; fall back to the preceding sample when it
        # points one past the end, and use that same sample everywhere.
        index = int(slice_select) - 1
        try:
            sample = data[index]
        except IndexError:
            sample = data[index - 1]

        prediction = self.predict([sample])[0]

        # squeeze=False so a single SOM still yields an indexable array.
        fig, axes = plt.subplots(1, len(self.som_models), figsize=(20, 5),
                                 sharex=True, sharey=True, squeeze=False)
        fig.suptitle(f"Activation map for SOM {prediction[0]}, node {prediction[1]}", fontsize=16)
        self._draw_activation_maps(axes.ravel(), sample, prediction)

        fig.tight_layout()
        return fig

    def plot_activation_v2(self, data, slice_select):
        """
        Plot the activation maps of all SOMs for a single sample.

        Args:
            data: indexable sequence of samples.
            slice_select: 1-based position of the sample (anything ``int()``
                accepts).

        Returns:
            The matplotlib figure, one subplot per SOM in a single row.

        Raises:
            ValueError: if no SOM has been trained.
        """
        return self._plot_single_activation(data, slice_select)

    def plot_activation_v3(self, data, slice_select):
        """
        Plot the activation maps of all SOMs for a single sample.

        Same behavior as ``plot_activation_v2``.

        Args:
            data: indexable sequence of samples.
            slice_select: 1-based position of the sample (anything ``int()``
                accepts).

        Returns:
            The matplotlib figure, one subplot per SOM in a single row.

        Raises:
            ValueError: if no SOM has been trained.
        """
        return self._plot_single_activation(data, slice_select)