# Spaces:
# Build error
# Build error
import io | |
import math | |
import pickle | |
import imageio | |
import hdbscan | |
import numpy as np | |
import matplotlib.pyplot as plt | |
from tqdm import tqdm | |
from minisom import MiniSom | |
from collections import Counter | |
from sklearn.cluster import KMeans | |
from moviepy.editor import ImageSequenceClip | |
from sklearn.preprocessing import LabelEncoder | |
from sklearn.semi_supervised import LabelSpreading | |
class ClusterSOM:
    """Cluster data with HDBSCAN, then fit one self-organizing map (SOM) per cluster.

    The most frequent non-noise HDBSCAN clusters each get their own ``MiniSom``.
    ``predict`` routes a sample to the SOM whose best-matching unit (BMU) is
    closest, and accepts it only if that distance is small relative to the mean
    sample-to-node distance recorded for that node during training.
    ``train_label`` optionally attaches per-class centroids so predictions can
    also report a label.

    SOMs are keyed by a 1-based cluster index (``1..n``). ``cluster_mapping``
    maps that index back to the original HDBSCAN label.
    """

    def __init__(self):
        self.hdbscan_model = None
        self.som_models = {}       # 1-based cluster index -> trained MiniSom
        self.sigma_values = {}     # cluster index -> per-node std of sample-to-BMU distance
        self.mean_values = {}      # cluster index -> per-node mean of sample-to-BMU distance
        self.cluster_mapping = {}  # cluster index -> original HDBSCAN label
        self.embedding = None
        self.dim_red_op = None

    def train(self, dataset, min_samples_per_cluster=100, n_clusters=None, som_size=(20, 20), sigma=1.0,
              learning_rate=0.5, num_iteration=200000, random_seed=42, n_neighbors=5, coverage=0.95):
        """Train the HDBSCAN model and one SOM per retained cluster.

        Args:
            dataset: 2-D array of samples (one per row). It is used both for
                clustering and for SOM training.
            min_samples_per_cluster: ``min_cluster_size`` passed to HDBSCAN.
            n_clusters: Number of non-noise clusters to model. If None, it is the
                smallest number of most-frequent labels (noise label -1 included
                in the count) whose points reach ``coverage`` of the dataset.
            som_size: ``(rows, cols)`` of each SOM grid.
            sigma, learning_rate, random_seed: Passed to ``MiniSom``.
            num_iteration: Iterations for ``MiniSom.train_random``.
            n_neighbors: Forwarded to ``compute_sigma_values``, which ignores it.
            coverage: Fraction of points to cover when ``n_clusters`` is None.
        """
        # Discard models from any previous call so stale clusters cannot linger.
        self.som_models = {}
        self.sigma_values = {}
        self.mean_values = {}
        self.cluster_mapping = {}

        print('Identifying clusters in the embedding ...')
        self.hdbscan_model = hdbscan.HDBSCAN(min_cluster_size=min_samples_per_cluster)
        self.hdbscan_model.fit(dataset)
        labels = self.hdbscan_model.labels_

        # (label, count) pairs from most to least frequent, noise (-1) included.
        ranked = Counter(labels).most_common()

        if n_clusters is None:
            total_points = sum(count for _, count in ranked)
            covered_points = 0
            n_clusters = 0
            for _, count in ranked:
                covered_points += count
                n_clusters += 1
                if covered_points / total_points >= coverage:
                    break

        # Keep the n_clusters most frequent non-noise labels. Filtering noise out
        # first avoids dropping a real cluster when there are fewer than
        # n_clusters + 1 distinct labels and no noise label.
        cluster_labels = [label for label, _ in ranked if label != -1][:n_clusters]

        for i, label in tqdm(enumerate(cluster_labels), total=len(cluster_labels), desc="Fitting 2D maps"):
            cluster_idx = i + 1
            cluster_data = dataset[labels == label]
            som = MiniSom(som_size[0], som_size[1], dataset.shape[1], sigma=sigma,
                          learning_rate=learning_rate, random_seed=random_seed)
            som.train_random(cluster_data, num_iteration)
            self.som_models[cluster_idx] = som
            self.cluster_mapping[cluster_idx] = label
            mean_cluster, sigma_cluster = self.compute_sigma_values(
                cluster_data, som_size, som, n_neighbors=n_neighbors)
            self.sigma_values[cluster_idx] = sigma_cluster
            self.mean_values[cluster_idx] = mean_cluster

    def compute_sigma_values(self, cluster_data, som_size, som, n_neighbors=5):
        """Compute the per-node mean and std of sample-to-BMU Euclidean distances.

        Args:
            cluster_data: Iterable of samples belonging to one cluster.
            som_size: ``(rows, cols)`` of the SOM grid.
            som: Trained SOM exposing ``get_weights()`` and ``winner(sample)``.
            n_neighbors: Unused; kept for API compatibility.

        Returns:
            ``(mean_cluster, sigma_cluster)``, two arrays of shape ``som_size``.
            Nodes that received no samples get 0 in both.
        """
        som_weights = som.get_weights()
        # Assign each sample to its best-matching node.
        partitions = {idx: [] for idx in np.ndindex(som_size[0], som_size[1])}
        for sample in cluster_data:
            x, y = som.winner(sample)
            partitions[(x, y)].append(sample)

        mean_cluster = np.zeros(som_size)
        sigma_cluster = np.zeros(som_size)
        for idx, members in partitions.items():
            if not members:
                continue  # empty node: leave mean/std at 0
            distances = np.linalg.norm(np.asarray(members) - som_weights[idx], axis=-1)
            mean_cluster[idx] = distances.mean()
            sigma_cluster[idx] = distances.std()
        return mean_cluster, sigma_cluster

    def train_label(self, labeled_data, labels):
        """Compute one centroid per class from labeled data.

        Labels are encoded with ``LabelEncoder``. ``label_centroids[k]`` is the
        mean of the rows of ``labeled_data`` whose encoded label is ``k``, so the
        centroid order matches ``label_encodings.classes_``.

        Args:
            labeled_data: 2-D array of samples (one per row).
            labels: One label per row of ``labeled_data``.
        """
        labeled_data = np.asarray(labeled_data)
        le = LabelEncoder()
        encoded_labels = le.fit_transform(labels)
        # Use per-class means, not unsupervised cluster centres. This keeps
        # centroid k aligned with encoded class k, which predict() and the
        # heatmap colorbar (classes_) rely on.
        self.label_centroids = np.array([
            labeled_data[encoded_labels == k].mean(axis=0)
            for k in range(len(le.classes_))
        ])
        self.label_encodings = le

    def _cluster_label(self, cluster_idx):
        """Decode the label associated with SOM ``cluster_idx`` (encoded class ``cluster_idx - 1``).

        NOTE(review): this assumes SOM cluster k corresponds to encoded label
        k - 1. Neither train() nor train_label() enforces that pairing; confirm.
        """
        return self.label_encodings.inverse_transform([cluster_idx - 1])[0]

    def predict(self, data, sigma_factor=1.5):
        """Assign each sample to a SOM cluster and node, or flag it as noise.

        The BMU is found on every SOM, and the SOM whose BMU weight vector is
        closest to the sample wins. The sample is accepted when that distance is
        at most ``sigma_factor`` times the mean training distance recorded for
        the winning node.

        Args:
            data: Iterable of samples.
            sigma_factor: Multiplier on the node's mean training distance that
                sets the acceptance threshold.

        Returns:
            One tuple per sample:
            ``(cluster_idx, (x, y))`` if accepted;
            ``(cluster_idx, (x, y), label, label_distance)`` if accepted and
            ``train_label`` was called;
            ``(-1, None)`` for noise.

        Raises:
            ValueError: If no SOM models have been trained.
        """
        if not self.som_models:
            raise ValueError("SOM models not trained yet.")
        results = []
        for sample in data:
            min_distance = float('inf')
            nearest_cluster_idx = None
            nearest_node = None
            for cluster_idx, som in self.som_models.items():
                x, y = som.winner(sample)
                distance = np.linalg.norm(sample - som.get_weights()[x, y])
                if distance < min_distance:
                    min_distance = distance
                    nearest_cluster_idx = cluster_idx
                    nearest_node = (x, y)

            threshold = self.mean_values[nearest_cluster_idx][nearest_node] * sigma_factor
            if min_distance <= threshold:
                if hasattr(self, 'label_centroids'):
                    # Index centroids by encoded position, not by the decoded
                    # label value (which may be a string or not start at 0).
                    encoded = nearest_cluster_idx - 1
                    label = self._cluster_label(nearest_cluster_idx)
                    label_distance = np.linalg.norm(sample - self.label_centroids[encoded])
                    results.append((nearest_cluster_idx, nearest_node, label, label_distance))
                else:
                    results.append((nearest_cluster_idx, nearest_node))
            else:
                results.append((-1, None))  # Noise
        return results

    def plot_label_heatmap(self):
        """Plot one heatmap per SOM showing which label centroid is closest to each node.

        Raises:
            ValueError: If labels have not been trained, or no SOM is trained.
        """
        if not hasattr(self, 'label_centroids'):
            raise ValueError("Labels not trained yet.")
        if not self.som_models:
            raise ValueError("SOM models not trained yet.")
        n_labels = len(self.label_centroids)
        n_clusters = len(self.som_models)
        centroids = np.asarray(self.label_centroids)

        # Near-square grid with one cell per SOM.
        n_rows = int(np.ceil(np.sqrt(n_clusters)))
        n_cols = n_rows if n_rows * (n_rows - 1) < n_clusters else n_rows - 1
        fig, axes = plt.subplots(n_rows, n_cols, figsize=(n_cols * 10, n_rows * 10), squeeze=False)

        im = None
        for i, (reindexed_label, som) in enumerate(self.som_models.items()):
            som_weights = som.get_weights()
            # Distance from every node to every centroid; argmin keeps the first
            # (lowest-index) label on ties.
            distances = np.linalg.norm(
                som_weights[:, :, np.newaxis, :] - centroids[np.newaxis, np.newaxis, :, :], axis=-1)
            label_map = np.argmin(distances, axis=-1)

            ax = axes[i // n_cols, i % n_cols]
            # Fixed limits give each label the same colour in every subplot, so the
            # single colorbar is valid for all of them and label 0 is coloured too.
            im = ax.imshow(label_map, cmap=plt.cm.rainbow, origin='lower', interpolation='none',
                           vmin=-0.5, vmax=n_labels - 0.5)
            ax.set_xticks(range(label_map.shape[1]))
            ax.set_yticks(range(label_map.shape[0]))
            ax.grid(True, linestyle='-', linewidth=0.5)
            ax.set_title(f"Label Heatmap for Cluster {reindexed_label}")

        # Hide grid cells that have no SOM.
        for j in range(n_clusters, n_rows * n_cols):
            axes[j // n_cols, j % n_cols].axis('off')

        cbar_ax = fig.add_axes([0.92, 0.15, 0.02, 0.7])
        cbar = fig.colorbar(im, cax=cbar_ax, ticks=range(n_labels))
        cbar.ax.set_yticklabels(self.label_encodings.classes_)
        fig.subplots_adjust(wspace=0.5, hspace=0.5, right=0.9)
        plt.show()

    def rearrange_subplots(self, num_subplots):
        """Create a near-square grid with at least ``num_subplots`` axes.

        Returns:
            ``(fig, axes)``. ``axes`` is always a flat array of Axes, even for a
            single subplot. Cells beyond ``num_subplots`` are hidden.
        """
        num_rows = max(1, math.isqrt(num_subplots))
        num_cols = max(1, math.ceil(num_subplots / num_rows))
        # squeeze=False keeps a 2-D array even for 1x1, so callers can index it.
        fig, axes = plt.subplots(num_rows, num_cols, figsize=(20, 5), sharex=True, sharey=True, squeeze=False)
        axes = axes.flatten()
        for ax in axes[num_subplots:]:
            ax.axis('off')
        return fig, axes

    def _draw_activation_maps(self, sample, prediction, axes):
        """Draw the sample's node-distance map on each SOM and highlight the predicted SOM.

        ``axes`` must provide at least one Axes per SOM, in ``som_models`` order.
        """
        for ax, (som_key, som) in zip(axes, self.som_models.items()):
            # Euclidean distance from the sample to every node's weight vector.
            activation_map = np.linalg.norm(som.get_weights() - sample, axis=-1)
            winner = som.winner(sample)  # BMU for this SOM
            activation_map[winner] = 0  # force the BMU to the colormap minimum
            if som_key == prediction[0]:  # Active SOM
                ax.imshow(activation_map, cmap='viridis', origin='lower', interpolation='none')
                ax.plot(winner[1], winner[0], 'r+')  # mark the BMU with a red plus sign
                ax.set_title(f"SOM {som_key}", color='blue', fontweight='bold')
                if hasattr(self, 'label_centroids'):
                    ax.set_xlabel(f"Label: {self._cluster_label(som_key)}", fontsize=12)
            else:  # Inactive SOM
                ax.imshow(activation_map, cmap='gray', origin='lower', interpolation='none')
                ax.set_title(f"SOM {som_key}")
            ax.set_xticks(range(activation_map.shape[1]))
            ax.set_yticks(range(activation_map.shape[0]))
            ax.grid(True, linestyle='-', linewidth=0.5)

    def plot_activation(self, data, filename='prediction_output', start=None, end=None):
        """Render one activation-map frame per sample and return the frames as a video clip.

        Args:
            data: Sequence of samples; ``data[start:end]`` is rendered.
            filename: Unused; the clip is returned rather than written to disk.
            start, end: Slice bounds. They default to the whole sequence.

        Returns:
            A moviepy ``ImageSequenceClip`` at 1 frame per second.

        Raises:
            ValueError: If no SOM models have been trained.
        """
        if len(self.som_models) == 0:
            raise ValueError("SOM models not trained yet.")
        start = 0 if start is None else start
        end = len(data) if end is None else end

        images = []
        for sample in tqdm(data[start:end], desc="Visualizing prediction output"):
            prediction = self.predict([sample])[0]
            fig, axes = self.rearrange_subplots(len(self.som_models))
            fig.suptitle(f"Activation map for SOM {prediction[0]}, node {prediction[1]}", fontsize=16)
            self._draw_activation_maps(sample, prediction, axes)
            fig.subplots_adjust(right=0.8)
            # Rasterise the frame through an in-memory PNG.
            buf = io.BytesIO()
            fig.savefig(buf, format='png')
            plt.close(fig)
            buf.seek(0)
            images.append(imageio.imread(buf))

        return ImageSequenceClip(images, fps=1)

    def save(self, file_path):
        """Pickle the trained models (plus label centroids/encoder, if present) to ``file_path``."""
        model_data = (self.hdbscan_model, self.som_models, self.mean_values, self.sigma_values, self.cluster_mapping)
        if hasattr(self, 'label_centroids'):
            model_data += (self.label_centroids, self.label_encodings)
        with open(file_path, "wb") as f:
            pickle.dump(model_data, f)

    def load(self, file_path):
        """Restore state written by ``save``.

        Security: ``pickle.load`` can execute arbitrary code. Only load files
        from trusted sources.
        """
        with open(file_path, "rb") as f:
            model_data = pickle.load(f)
        (self.hdbscan_model, self.som_models, self.mean_values,
         self.sigma_values, self.cluster_mapping) = model_data[:5]
        if len(model_data) > 5:
            self.label_centroids, self.label_encodings = model_data[5:]
        else:
            # The file has no label data: drop any left over from earlier state.
            for attr in ('label_centroids', 'label_encodings'):
                if hasattr(self, attr):
                    delattr(self, attr)

    def plot_activation_v2(self, data, slice_select):
        """Plot every SOM's activation map for a single sample and return the Figure.

        ``slice_select`` is a 1-based position into ``data`` (converted with
        ``int``). If that position raises ``IndexError``, the preceding sample is
        used instead, both for prediction and for the maps.

        Raises:
            ValueError: If no SOM models have been trained.
        """
        if len(self.som_models) == 0:
            raise ValueError("SOM models not trained yet.")
        index = int(slice_select) - 1
        try:
            sample = data[index]
        except IndexError:
            sample = data[index - 1]
        prediction = self.predict([sample])[0]

        # squeeze=False: axes[0] is indexable even when there is a single SOM.
        fig, axes = plt.subplots(1, len(self.som_models), figsize=(20, 5), sharex=True, sharey=True,
                                 squeeze=False)
        fig.suptitle(f"Activation map for SOM {prediction[0]}, node {prediction[1]}", fontsize=16)
        self._draw_activation_maps(sample, prediction, axes[0])
        fig.tight_layout()
        return fig

    def plot_activation_v3(self, data, slice_select):
        """Same as ``plot_activation_v2``; kept for backward compatibility."""
        return self.plot_activation_v2(data, slice_select)