arcan3 commited on
Commit
7a69981
1 Parent(s): 0231841

som name changed, placeholder added, new models added

Browse files
app.py CHANGED
@@ -16,23 +16,36 @@ from funcs.dataloader import BaseDataset2, read_json_files
16
 
17
  DEVICE = torch.device("cpu")
18
  reducer10d = PHATEAE(epochs=30, n_components=10, lr=.0001, batch_size=128, t='auto', knn=8, relax=True, metric='euclidean')
19
- reducer10d.load('models/r10d_3.pth')
20
 
21
  cluster_som = ClusterSOM()
22
- cluster_som.load("models/cluster_som3.pkl")
23
 
24
  def map_som2animation(som_value):
25
  mapping = {
26
- 2: 0, # walk
27
- 1: 1, # trot
28
- 3: 2, # gallop
29
- 5: 3, # idle
30
- 4: 3, # other
31
- -1:3, #other
32
- }
33
 
34
  return mapping.get(som_value, None)
35
 
 
 
 
 
 
 
 
 
 
 
 
 
 
36
  def deviation_scores(tensor_data, scale=50):
37
  if len(tensor_data) < 5:
38
  raise ValueError("The input tensor must have at least 5 elements.")
@@ -97,8 +110,12 @@ def get_som_mp4_v2(csv_file_box, slice_size_slider, sample_rate, window_size_sli
97
  processed_file_box, json_file_box, slices_per_leg, plot_box_leg, plot_box_overlay, slice_slider, plot_slice_leg, get_all_slice, slice_json_box = process_data(csv_file_box, slice_size_slider, sample_rate, window_size_slider)
98
 
99
  try:
 
 
100
  train_x, train_y = read_json_files(json_file_box)
101
  except:
 
 
102
  train_x, train_y = read_json_files(json_file_box.name)
103
 
104
  # Convert tensors to numpy arrays if necessary
@@ -124,13 +141,14 @@ def get_som_mp4_v2(csv_file_box, slice_size_slider, sample_rate, window_size_sli
124
  csv_writer.writerow(header)
125
  csv_writer.writerows(processed_data)
126
 
127
- os.system('curl -X POST -F "csv_file=@animation_table.csv" https://metric-space.ngrok.io/generate --output animation.mp4')
128
 
129
  # prediction = cluster_som.predict(embedding10d)
130
  som_video = cluster.plot_activation(embedding10d)
131
  som_video.write_videofile('som_sequence.mp4')
132
-
133
- return processed_file_box, json_file_box, slices_per_leg, plot_box_leg, plot_box_overlay, slice_slider, plot_slice_leg, get_all_slice, slice_json_box, 'som_sequence.mp4', 'animation.mp4'
 
134
 
135
  # ml inference
136
  def get_som_mp4(file, slice_select, reducer=reducer10d, cluster=cluster_som):
@@ -183,7 +201,11 @@ with gr.Blocks(title='Cabasus') as cabasus_sensor:
183
 
184
  with gr.Row():
185
  animation = gr.Video(label='animation')
186
- activation_video = gr.Video(label='real')
 
 
 
 
187
 
188
  plot_box_leg = gr.Plot(label="Filtered Signal Plot")
189
  slice_slider = gr.Slider(minimum=1, maximum=300, label='Slice select', step=1)
 
16
 
17
  DEVICE = torch.device("cpu")
18
  reducer10d = PHATEAE(epochs=30, n_components=10, lr=.0001, batch_size=128, t='auto', knn=8, relax=True, metric='euclidean')
19
+ reducer10d.load('models/r10d_6.pth')
20
 
21
  cluster_som = ClusterSOM()
22
+ cluster_som.load("models/cluster_som6.pkl")
23
 
24
  def map_som2animation(som_value):
25
  mapping = {
26
+ 2: 0, # walk
27
+ 1: 1, # trot
28
+ 3: 2, # gallop
29
+ 5: 3, # idle
30
+ 4: 3, # other
31
+ -1:3, #other
32
+ }
33
 
34
  return mapping.get(som_value, None)
35
 
36
+ # def map_som2animation_v2(som_value):
37
+ # mapping = {
38
+ # versammelter_trab: center of SOM-1,
39
+ # arbeits-trab: south-east od SOM-1,
40
+ # mittels-trab: North of SOM-1,
41
+ # starker-trab: North-west of SOM1,
42
+
43
+ # starker-schritt:
44
+
45
+ # }
46
+
47
+ # return mapping.get(som_value, None)
48
+
49
  def deviation_scores(tensor_data, scale=50):
50
  if len(tensor_data) < 5:
51
  raise ValueError("The input tensor must have at least 5 elements.")
 
110
  processed_file_box, json_file_box, slices_per_leg, plot_box_leg, plot_box_overlay, slice_slider, plot_slice_leg, get_all_slice, slice_json_box = process_data(csv_file_box, slice_size_slider, sample_rate, window_size_slider)
111
 
112
  try:
113
+ if json_file_box is None:
114
+ return processed_file_box, json_file_box, slices_per_leg, plot_box_leg, plot_box_overlay, slice_slider, plot_slice_leg, get_all_slice, slice_json_box, None, None
115
  train_x, train_y = read_json_files(json_file_box)
116
  except:
117
+ if json_file_box.name is None:
118
+ return processed_file_box, json_file_box, slices_per_leg, plot_box_leg, plot_box_overlay, slice_slider, plot_slice_leg, get_all_slice, slice_json_box, None, None
119
  train_x, train_y = read_json_files(json_file_box.name)
120
 
121
  # Convert tensors to numpy arrays if necessary
 
141
  csv_writer.writerow(header)
142
  csv_writer.writerows(processed_data)
143
 
144
+ # os.system('curl -X POST -F "csv_file=@animation_table.csv" https://metric-space.ngrok.io/generate --output animation.mp4')
145
 
146
  # prediction = cluster_som.predict(embedding10d)
147
  som_video = cluster.plot_activation(embedding10d)
148
  som_video.write_videofile('som_sequence.mp4')
149
+
150
+ # return processed_file_box, json_file_box, slices_per_leg, plot_box_leg, plot_box_overlay, slice_slider, plot_slice_leg, get_all_slice, slice_json_box, 'som_sequence.mp4', 'animation.mp4'
151
+ return processed_file_box, json_file_box, slices_per_leg, plot_box_leg, plot_box_overlay, slice_slider, plot_slice_leg, get_all_slice, slice_json_box, 'som_sequence.mp4', None
152
 
153
  # ml inference
154
  def get_som_mp4(file, slice_select, reducer=reducer10d, cluster=cluster_som):
 
201
 
202
  with gr.Row():
203
  animation = gr.Video(label='animation')
204
+ activation_video = gr.Video(label='activation channels')
205
+
206
+ with gr.Row():
207
+ real_video = gr.Video(label='real video')
208
+ trend_graph = gr.Video(label='trend graph')
209
 
210
  plot_box_leg = gr.Plot(label="Filtered Signal Plot")
211
  slice_slider = gr.Slider(minimum=1, maximum=300, label='Slice select', step=1)
funcs/not_needed_som_funcs.py ADDED
@@ -0,0 +1,401 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import io
2
+ import math
3
+ import pickle
4
+ import imageio
5
+ import hdbscan
6
+
7
+ import numpy as np
8
+ import matplotlib.pyplot as plt
9
+
10
+ from tqdm import tqdm
11
+ from minisom import MiniSom
12
+ from collections import Counter
13
+ from sklearn.cluster import KMeans
14
+ from moviepy.editor import ImageSequenceClip
15
+ from sklearn.preprocessing import LabelEncoder
16
+ from sklearn.semi_supervised import LabelSpreading
17
+
18
+ class ClusterSOM:
19
+ def __init__(self):
20
+ self.hdbscan_model = None
21
+ self.som_models = {}
22
+ self.sigma_values = {}
23
+ self.mean_values = {}
24
+ self.cluster_mapping = {}
25
+ self.embedding = None
26
+ self.dim_red_op = None
27
+
28
+ def train(self, dataset, min_samples_per_cluster=100, n_clusters=None, som_size=(20, 20), sigma=1.0, learning_rate=0.5, num_iteration=200000, random_seed=42, n_neighbors=5, coverage=0.95):
29
+ """
30
+ Train HDBSCAN and SOM models on the given dataset.
31
+ """
32
+ # Train HDBSCAN model
33
+ print('Identifying clusters in the embedding ...')
34
+ self.hdbscan_model = hdbscan.HDBSCAN(min_cluster_size=min_samples_per_cluster)
35
+ self.hdbscan_model.fit(dataset)
36
+
37
+ # Calculate n_clusters if not provided
38
+ if n_clusters is None:
39
+ cluster_labels, counts = zip(*Counter(self.hdbscan_model.labels_).most_common())
40
+ cluster_labels = list(cluster_labels)
41
+ total_points = sum(counts)
42
+ covered_points = 0
43
+ n_clusters = 0
44
+ for count in counts:
45
+ covered_points += count
46
+ n_clusters += 1
47
+ if covered_points / total_points >= coverage:
48
+ break
49
+
50
+ # Train SOM models for the n_clusters most common clusters in the HDBSCAN model
51
+ cluster_labels, counts = zip(*Counter(self.hdbscan_model.labels_).most_common(n_clusters + 1))
52
+ cluster_labels = list(cluster_labels)
53
+
54
+ if -1 in cluster_labels:
55
+ cluster_labels.remove(-1)
56
+ else:
57
+ cluster_labels.pop()
58
+
59
+ for i, label in tqdm(enumerate(cluster_labels), total=len(cluster_labels), desc="Fitting 2D maps"):
60
+ if label == -1:
61
+ continue # Ignore noise
62
+ cluster_data = dataset[self.hdbscan_model.labels_ == label]
63
+ som = MiniSom(som_size[0], som_size[1], dataset.shape[1], sigma=sigma, learning_rate=learning_rate, random_seed=random_seed)
64
+ som.train_random(cluster_data, num_iteration)
65
+ self.som_models[i+1] = som
66
+ self.cluster_mapping[i+1] = label
67
+
68
+ # Compute sigma values
69
+ mean_cluster, sigma_cluster = self.compute_sigma_values(cluster_data, som_size, som, n_neighbors=n_neighbors)
70
+ self.sigma_values[i+1] = sigma_cluster
71
+ self.mean_values[i+1] = mean_cluster
72
+
73
+ def compute_sigma_values(self, cluster_data, som_size, som, n_neighbors=5):
74
+ som_weights = som.get_weights()
75
+
76
+ # Assign each datapoint to its nearest node
77
+ partitions = {idx: [] for idx in np.ndindex(som_size[0], som_size[1])}
78
+ for sample in cluster_data:
79
+ x, y = som.winner(sample)
80
+ partitions[(x, y)].append(sample)
81
+
82
+ # Compute the mean distance and std deviation of these partitions
83
+ mean_cluster = np.zeros(som_size)
84
+ sigma_cluster = np.zeros(som_size)
85
+ for idx in partitions:
86
+ if len(partitions[idx]) > 0:
87
+ partition_data = np.array(partitions[idx])
88
+ mean_distance = np.mean(np.linalg.norm(partition_data - som_weights[idx], axis=-1))
89
+ std_distance = np.std(np.linalg.norm(partition_data - som_weights[idx], axis=-1))
90
+ else:
91
+ mean_distance = 0
92
+ std_distance = 0
93
+ mean_cluster[idx] = mean_distance
94
+ sigma_cluster[idx] = std_distance
95
+
96
+ return mean_cluster, sigma_cluster
97
+
98
+ def train_label(self, labeled_data, labels):
99
+ """
100
+ Train on labeled data to find centroids and compute distances to the labels.
101
+ """
102
+ le = LabelEncoder()
103
+ encoded_labels = le.fit_transform(labels)
104
+ unique_labels = np.unique(encoded_labels)
105
+
106
+ # Use label spreading to propagate the labels
107
+ label_prop_model = LabelSpreading(kernel='knn', n_neighbors=5)
108
+ label_prop_model.fit(labeled_data, encoded_labels)
109
+
110
+ # Find the centroids for each label using KMeans
111
+ kmeans = KMeans(n_clusters=len(unique_labels), random_state=42)
112
+ kmeans.fit(labeled_data)
113
+
114
+ # Store the label centroids and label encodings
115
+ self.label_centroids = kmeans.cluster_centers_
116
+ self.label_encodings = le
117
+
118
+ def predict(self, data, sigma_factor=1.5):
119
+ """
120
+ Predict the cluster and BMU SOM coordinate for each sample in the data if it's inside the sigma value.
121
+ Also, predict the label and distance to the center of the label if labels are trained.
122
+ """
123
+ results = []
124
+
125
+ for sample in data:
126
+ min_distance = float('inf')
127
+ nearest_cluster_idx = None
128
+ nearest_node = None
129
+
130
+ for i, som in self.som_models.items():
131
+ x, y = som.winner(sample)
132
+ node = som.get_weights()[x, y]
133
+ distance = np.linalg.norm(sample - node)
134
+
135
+ if distance < min_distance:
136
+ min_distance = distance
137
+ nearest_cluster_idx = i
138
+ nearest_node = (x, y)
139
+
140
+ # Check if the nearest node is within the sigma value
141
+ if min_distance <= self.mean_values[nearest_cluster_idx][nearest_node] * 1.5: # * self.sigma_values[nearest_cluster_idx][nearest_node] * sigma_factor:
142
+ if hasattr(self, 'label_centroids'):
143
+ # Predict the label and distance to the center of the label
144
+ label_idx = self.label_encodings.inverse_transform([nearest_cluster_idx - 1])[0]
145
+ label_distance = np.linalg.norm(sample - self.label_centroids[label_idx])
146
+ results.append((nearest_cluster_idx, nearest_node, label_idx, label_distance))
147
+ else:
148
+ results.append((nearest_cluster_idx, nearest_node))
149
+ else:
150
+ results.append((-1, None)) # Noise
151
+
152
+ return results
153
+
154
+ def plot_label_heatmap(self):
155
+ """
156
+ Plot a heatmap for each main cluster showing the best label for each coordinate in a single subplot layout.
157
+ """
158
+ if not hasattr(self, 'label_centroids'):
159
+ raise ValueError("Labels not trained yet.")
160
+
161
+ n_labels = len(self.label_centroids)
162
+ label_colors = plt.cm.rainbow(np.linspace(0, 1, n_labels))
163
+ n_clusters = len(self.som_models)
164
+
165
+ # Create a subplot layout with a heatmap for each main cluster
166
+ n_rows = int(np.ceil(np.sqrt(n_clusters)))
167
+ n_cols = n_rows if n_rows * (n_rows - 1) < n_clusters else n_rows - 1
168
+ fig, axes = plt.subplots(n_rows, n_cols, figsize=(n_cols * 10, n_rows * 10), squeeze=False)
169
+
170
+ for i, (reindexed_label, som) in enumerate(self.som_models.items()):
171
+ som_weights = som.get_weights()
172
+ label_map = np.zeros(som_weights.shape[:2], dtype=int)
173
+ label_distance_map = np.full(som_weights.shape[:2], np.inf)
174
+
175
+ for label_idx, label_centroid in enumerate(self.label_centroids):
176
+ for x in range(som_weights.shape[0]):
177
+ for y in range(som_weights.shape[1]):
178
+ node = som_weights[x, y]
179
+ distance = np.linalg.norm(label_centroid - node)
180
+
181
+ if distance < label_distance_map[x, y]:
182
+ label_distance_map[x, y] = distance
183
+ label_map[x, y] = label_idx
184
+
185
+ row, col = i // n_cols, i % n_cols
186
+ ax = axes[row, col]
187
+ cmap = plt.cm.rainbow
188
+ cmap.set_under(color='white')
189
+ im = ax.imshow(label_map, cmap=cmap, origin='lower', interpolation='none', vmin=0.5)
190
+ ax.set_xticks(range(label_map.shape[1]))
191
+ ax.set_yticks(range(label_map.shape[0]))
192
+ ax.grid(True, linestyle='-', linewidth=0.5)
193
+ ax.set_title(f"Label Heatmap for Cluster {reindexed_label}")
194
+
195
+ # Add a colorbar for label colors
196
+ cbar_ax = fig.add_axes([0.92, 0.15, 0.02, 0.7])
197
+ cbar = fig.colorbar(im, cax=cbar_ax, ticks=range(n_labels))
198
+ cbar.ax.set_yticklabels(self.label_encodings.classes_)
199
+
200
+ # Adjust the layout to fit everything nicely
201
+ fig.subplots_adjust(wspace=0.5, hspace=0.5, right=0.9)
202
+
203
+ plt.show()
204
+
205
+ # rearranging the subplots in the closest square format
206
+ def rearrange_subplots(self, num_subplots):
207
+ # Calculate the number of rows and columns for the subplot grid
208
+ num_rows = math.isqrt(num_subplots)
209
+ num_cols = math.ceil(num_subplots / num_rows)
210
+
211
+ # Create the figure and subplots
212
+ fig, axes = plt.subplots(num_rows, num_cols, figsize=(20, 5), sharex=True, sharey=True)
213
+
214
+ # Flatten the axes array if it is multidimensional
215
+ if isinstance(axes, np.ndarray):
216
+ axes = axes.flatten()
217
+
218
+ # Hide any empty subplots
219
+ for i in range(num_subplots, len(axes)):
220
+ axes[i].axis('off')
221
+
222
+ return fig, axes
223
+
224
+ def plot_activation(self, data, filename='prediction_output', start=None, end=None):
225
+ """
226
+ Generate a GIF visualization of the prediction output using the activation maps of individual SOMs.
227
+ """
228
+ if len(self.som_models) == 0:
229
+ raise ValueError("SOM models not trained yet.")
230
+
231
+ if start is None:
232
+ start = 0
233
+
234
+ if end is None:
235
+ end = len(data)
236
+
237
+ images = []
238
+ for sample in tqdm(data[start:end], desc="Visualizing prediction output"):
239
+ prediction = self.predict([sample])[0]
240
+
241
+ fig, axes = self.rearrange_subplots(len(self.som_models))
242
+
243
+ # fig, axes = plt.subplots(1, len(self.som_models), figsize=(20, 5), sharex=True, sharey=True)
244
+ fig.suptitle(f"Activation map for SOM {prediction[0]}, node {prediction[1]}", fontsize=16)
245
+
246
+ for idx, (som_key, som) in enumerate(self.som_models.items()):
247
+ ax = axes[idx]
248
+ activation_map = np.zeros(som._weights.shape[:2])
249
+ for x in range(som._weights.shape[0]):
250
+ for y in range(som._weights.shape[1]):
251
+ activation_map[x, y] = np.linalg.norm(sample - som._weights[x, y])
252
+
253
+ winner = som.winner(sample) # Find the BMU for this SOM
254
+ activation_map[winner] = 0 # Set the BMU's value to 0 so it will be red in the colormap
255
+
256
+ if som_key == prediction[0]: # Active SOM
257
+ im_active = ax.imshow(activation_map, cmap='viridis', origin='lower', interpolation='none')
258
+ ax.plot(winner[1], winner[0], 'r+') # Mark the BMU with a red plus sign
259
+ ax.set_title(f"SOM {som_key}", color='blue', fontweight='bold')
260
+ if hasattr(self, 'label_centroids'):
261
+ label_idx = self.label_encodings.inverse_transform([som_key - 1])[0]
262
+ ax.set_xlabel(f"Label: {label_idx}", fontsize=12)
263
+ else: # Inactive SOM
264
+ im_inactive = ax.imshow(activation_map, cmap='gray', origin='lower', interpolation='none')
265
+ ax.set_title(f"SOM {som_key}")
266
+
267
+ ax.set_xticks(range(activation_map.shape[1]))
268
+ ax.set_yticks(range(activation_map.shape[0]))
269
+ ax.grid(True, linestyle='-', linewidth=0.5)
270
+
271
+ # Create a colorbar for each frame
272
+ fig.subplots_adjust(right=0.8)
273
+ # cbar_ax = fig.add_axes([0.85, 0.15, 0.05, 0.7])
274
+ # try:
275
+ # fig.colorbar(im_active, cax=cbar_ax)
276
+ # except:
277
+ # pass
278
+
279
+ # Save the plot to a buffer
280
+ buf = io.BytesIO()
281
+ plt.savefig(buf, format='png')
282
+ buf.seek(0)
283
+ img = imageio.imread(buf)
284
+ images.append(img)
285
+ plt.close()
286
+
287
+ # Create the video using moviepy and save it as a mp4 file
288
+ video = ImageSequenceClip(images, fps=1)
289
+
290
+ return video
291
+
292
+ def save(self, file_path):
293
+ """
294
+ Save the ClusterSOM model to a file.
295
+ """
296
+ model_data = (self.hdbscan_model, self.som_models, self.mean_values, self.sigma_values, self.cluster_mapping)
297
+ if hasattr(self, 'label_centroids'):
298
+ model_data += (self.label_centroids, self.label_encodings)
299
+
300
+ with open(file_path, "wb") as f:
301
+ pickle.dump(model_data, f)
302
+
303
+ def load(self, file_path):
304
+ """
305
+ Load a ClusterSOM model from a file.
306
+ """
307
+ with open(file_path, "rb") as f:
308
+ model_data = pickle.load(f)
309
+
310
+ self.hdbscan_model, self.som_models, self.mean_values, self.sigma_values, self.cluster_mapping = model_data[:5]
311
+ if len(model_data) > 5:
312
+ self.label_centroids, self.label_encodings = model_data[5:]
313
+
314
+
315
+ def plot_activation_v2(self, data, slice_select):
316
+ """
317
+ Generate a GIF visualization of the prediction output using the activation maps of individual SOMs.
318
+ """
319
+ if len(self.som_models) == 0:
320
+ raise ValueError("SOM models not trained yet.")
321
+
322
+ try:
323
+ prediction = self.predict([data[int(slice_select)-1]])[0]
324
+ except:
325
+ prediction = self.predict([data[int(slice_select)-2]])[0]
326
+
327
+ fig, axes = plt.subplots(1, len(self.som_models), figsize=(20, 5), sharex=True, sharey=True)
328
+ fig.suptitle(f"Activation map for SOM {prediction[0]}, node {prediction[1]}", fontsize=16)
329
+
330
+ for idx, (som_key, som) in enumerate(self.som_models.items()):
331
+ ax = axes[idx]
332
+ activation_map = np.zeros(som._weights.shape[:2])
333
+ for x in range(som._weights.shape[0]):
334
+ for y in range(som._weights.shape[1]):
335
+ activation_map[x, y] = np.linalg.norm(data[int(slice_select)-1] - som._weights[x, y])
336
+
337
+ winner = som.winner(data[int(slice_select)-1]) # Find the BMU for this SOM
338
+ activation_map[winner] = 0 # Set the BMU's value to 0 so it will be red in the colormap
339
+
340
+ if som_key == prediction[0]: # Active SOM
341
+ im_active = ax.imshow(activation_map, cmap='viridis', origin='lower', interpolation='none')
342
+ ax.plot(winner[1], winner[0], 'r+') # Mark the BMU with a red plus sign
343
+ ax.set_title(f"SOM {som_key}", color='blue', fontweight='bold')
344
+ if hasattr(self, 'label_centroids'):
345
+ label_idx = self.label_encodings.inverse_transform([som_key - 1])[0]
346
+ ax.set_xlabel(f"Label: {label_idx}", fontsize=12)
347
+ else: # Inactive SOM
348
+ im_inactive = ax.imshow(activation_map, cmap='gray', origin='lower', interpolation='none')
349
+ ax.set_title(f"SOM {som_key}")
350
+
351
+ ax.set_xticks(range(activation_map.shape[1]))
352
+ ax.set_yticks(range(activation_map.shape[0]))
353
+ ax.grid(True, linestyle='-', linewidth=0.5)
354
+
355
+ plt.tight_layout()
356
+
357
+ return fig
358
+
359
+ def plot_activation_v3(self, data, slice_select):
360
+ """
361
+ Generate a GIF visualization of the prediction output using the activation maps of individual SOMs.
362
+ """
363
+ if len(self.som_models) == 0:
364
+ raise ValueError("SOM models not trained yet.")
365
+
366
+ try:
367
+ prediction = self.predict([data[int(slice_select)-1]])[0]
368
+ except:
369
+ prediction = self.predict([data[int(slice_select)-2]])[0]
370
+
371
+ fig, axes = plt.subplots(1, len(self.som_models), figsize=(20, 5), sharex=True, sharey=True)
372
+ fig.suptitle(f"Activation map for SOM {prediction[0]}, node {prediction[1]}", fontsize=16)
373
+
374
+ for idx, (som_key, som) in enumerate(self.som_models.items()):
375
+ ax = axes[idx]
376
+ activation_map = np.zeros(som._weights.shape[:2])
377
+ for x in range(som._weights.shape[0]):
378
+ for y in range(som._weights.shape[1]):
379
+ activation_map[x, y] = np.linalg.norm(data[int(slice_select)-1] - som._weights[x, y])
380
+
381
+ winner = som.winner(data[int(slice_select)-1]) # Find the BMU for this SOM
382
+ activation_map[winner] = 0 # Set the BMU's value to 0 so it will be red in the colormap
383
+
384
+ if som_key == prediction[0]: # Active SOM
385
+ im_active = ax.imshow(activation_map, cmap='viridis', origin='lower', interpolation='none')
386
+ ax.plot(winner[1], winner[0], 'r+') # Mark the BMU with a red plus sign
387
+ ax.set_title(f"SOM {som_key}", color='blue', fontweight='bold')
388
+ if hasattr(self, 'label_centroids'):
389
+ label_idx = self.label_encodings.inverse_transform([som_key - 1])[0]
390
+ ax.set_xlabel(f"Label: {label_idx}", fontsize=12)
391
+ else: # Inactive SOM
392
+ im_inactive = ax.imshow(activation_map, cmap='gray', origin='lower', interpolation='none')
393
+ ax.set_title(f"SOM {som_key}")
394
+
395
+ ax.set_xticks(range(activation_map.shape[1]))
396
+ ax.set_yticks(range(activation_map.shape[0]))
397
+ ax.grid(True, linestyle='-', linewidth=0.5)
398
+
399
+ plt.tight_layout()
400
+
401
+ return fig
funcs/plot_func.py CHANGED
@@ -73,10 +73,6 @@ def plot_overlay_data_from_json(json_file, slice_select, sensors=['GZ1', 'GZ2',
73
  with open(json_file.name, "r") as f:
74
  slices = json.load(f)
75
 
76
- # # Read the JSON file
77
- # with open(json_file, "r") as f:
78
- # slices = json.load(f)
79
-
80
  # Create subplots for each sensor
81
  fig, axs = plt.subplots(len(sensors), 1, figsize=(12, 2 * len(sensors)), sharex=True)
82
 
 
73
  with open(json_file.name, "r") as f:
74
  slices = json.load(f)
75
 
 
 
 
 
76
  # Create subplots for each sensor
77
  fig, axs = plt.subplots(len(sensors), 1, figsize=(12, 2 * len(sensors)), sharex=True)
78
 
funcs/som.py CHANGED
@@ -1,20 +1,12 @@
1
- import numpy as np
2
- import hdbscan
3
- from minisom import MiniSom
4
  import pickle
5
- from collections import Counter
6
- import matplotlib.pyplot as plt
7
- import phate
8
  import imageio
 
 
 
 
9
  from tqdm import tqdm
10
- import io
11
- import plotly.graph_objs as go
12
- import plotly.subplots as sp
13
- import umap
14
- from sklearn.datasets import make_blobs
15
- from sklearn.preprocessing import LabelEncoder
16
- from sklearn.cluster import KMeans
17
- from sklearn.semi_supervised import LabelSpreading
18
  from moviepy.editor import ImageSequenceClip
19
 
20
  class ClusterSOM:
@@ -26,97 +18,18 @@ class ClusterSOM:
26
  self.cluster_mapping = {}
27
  self.embedding = None
28
  self.dim_red_op = None
29
-
30
- def train(self, dataset, min_samples_per_cluster=100, n_clusters=None, som_size=(20, 20), sigma=1.0, learning_rate=0.5, num_iteration=200000, random_seed=42, n_neighbors=5, coverage=0.95):
31
- """
32
- Train HDBSCAN and SOM models on the given dataset.
33
- """
34
- # Train HDBSCAN model
35
- print('Identifying clusters in the embedding ...')
36
- self.hdbscan_model = hdbscan.HDBSCAN(min_cluster_size=min_samples_per_cluster)
37
- self.hdbscan_model.fit(dataset)
38
-
39
- # Calculate n_clusters if not provided
40
- if n_clusters is None:
41
- cluster_labels, counts = zip(*Counter(self.hdbscan_model.labels_).most_common())
42
- cluster_labels = list(cluster_labels)
43
- total_points = sum(counts)
44
- covered_points = 0
45
- n_clusters = 0
46
- for count in counts:
47
- covered_points += count
48
- n_clusters += 1
49
- if covered_points / total_points >= coverage:
50
- break
51
-
52
- # Train SOM models for the n_clusters most common clusters in the HDBSCAN model
53
- cluster_labels, counts = zip(*Counter(self.hdbscan_model.labels_).most_common(n_clusters + 1))
54
- cluster_labels = list(cluster_labels)
55
-
56
- if -1 in cluster_labels:
57
- cluster_labels.remove(-1)
58
- else:
59
- cluster_labels.pop()
60
-
61
- for i, label in tqdm(enumerate(cluster_labels), total=len(cluster_labels), desc="Fitting 2D maps"):
62
- if label == -1:
63
- continue # Ignore noise
64
- cluster_data = dataset[self.hdbscan_model.labels_ == label]
65
- som = MiniSom(som_size[0], som_size[1], dataset.shape[1], sigma=sigma, learning_rate=learning_rate, random_seed=random_seed)
66
- som.train_random(cluster_data, num_iteration)
67
- self.som_models[i+1] = som
68
- self.cluster_mapping[i+1] = label
69
-
70
- # Compute sigma values
71
- mean_cluster, sigma_cluster = self.compute_sigma_values(cluster_data, som_size, som, n_neighbors=n_neighbors)
72
- self.sigma_values[i+1] = sigma_cluster
73
- self.mean_values[i+1] = mean_cluster
74
-
75
- def compute_sigma_values(self, cluster_data, som_size, som, n_neighbors=5):
76
- som_weights = som.get_weights()
77
-
78
- # Assign each datapoint to its nearest node
79
- partitions = {idx: [] for idx in np.ndindex(som_size[0], som_size[1])}
80
- for sample in cluster_data:
81
- x, y = som.winner(sample)
82
- partitions[(x, y)].append(sample)
83
-
84
- # Compute the mean distance and std deviation of these partitions
85
- mean_cluster = np.zeros(som_size)
86
- sigma_cluster = np.zeros(som_size)
87
- for idx in partitions:
88
- if len(partitions[idx]) > 0:
89
- partition_data = np.array(partitions[idx])
90
- mean_distance = np.mean(np.linalg.norm(partition_data - som_weights[idx], axis=-1))
91
- std_distance = np.std(np.linalg.norm(partition_data - som_weights[idx], axis=-1))
92
- else:
93
- mean_distance = 0
94
- std_distance = 0
95
- mean_cluster[idx] = mean_distance
96
- sigma_cluster[idx] = std_distance
97
-
98
- return mean_cluster, sigma_cluster
99
-
100
- def train_label(self, labeled_data, labels):
101
  """
102
- Train on labeled data to find centroids and compute distances to the labels.
103
  """
104
- le = LabelEncoder()
105
- encoded_labels = le.fit_transform(labels)
106
- unique_labels = np.unique(encoded_labels)
107
-
108
- # Use label spreading to propagate the labels
109
- label_prop_model = LabelSpreading(kernel='knn', n_neighbors=5)
110
- label_prop_model.fit(labeled_data, encoded_labels)
111
-
112
- # Find the centroids for each label using KMeans
113
- kmeans = KMeans(n_clusters=len(unique_labels), random_state=42)
114
- kmeans.fit(labeled_data)
115
-
116
- # Store the label centroids and label encodings
117
- self.label_centroids = kmeans.cluster_centers_
118
- self.label_encodings = le
119
 
 
 
 
 
120
  def predict(self, data, sigma_factor=1.5):
121
  """
122
  Predict the cluster and BMU SOM coordinate for each sample in the data if it's inside the sigma value.
@@ -153,182 +66,26 @@ class ClusterSOM:
153
 
154
  return results
155
 
156
- def plot_embedding(self, new_data=None, dim_reduction='umap', interactive=False):
157
- """
158
- Plot the dataset and SOM grids for each cluster.
159
- If new_data is provided, it will be used for plotting instead of the entire dataset.
160
- """
161
-
162
- if self.hdbscan_model is None:
163
- raise ValueError("HDBSCAN model not trained yet.")
164
-
165
- if len(self.som_models) == 0:
166
- raise ValueError("SOM models not trained yet.")
167
-
168
- if dim_reduction not in ['phate', 'umap']:
169
- raise ValueError("Invalid dimensionality reduction method. Use 'phate' or 'umap'.")
170
 
171
- if self.dim_red_op is None or self.embedding is None:
172
- n_components = 3
173
- if dim_reduction == 'phate':
174
- self.dim_red_op = phate.PHATE(n_components=n_components, random_state=42)
175
- elif dim_reduction == 'umap':
176
- self.dim_red_op = umap.UMAP(n_components=n_components, random_state=42)
177
-
178
- self.embedding = self.dim_red_op.fit_transform(new_data)
179
-
180
- if new_data is not None:
181
- new_embedding = self.dim_red_op.transform(new_data)
182
- else:
183
- new_embedding = self.embedding
184
-
185
- if interactive:
186
- fig = sp.make_subplots(rows=1, cols=1, specs=[[{'type': 'scatter3d'}]])
187
- else:
188
- fig = plt.figure(figsize=(30, 30))
189
- ax = fig.add_subplot(111, projection='3d')
190
-
191
- colors = plt.cm.rainbow(np.linspace(0, 1, len(self.som_models) + 1))
192
-
193
- for reindexed_label, som in self.som_models.items():
194
- original_label = self.cluster_mapping[reindexed_label]
195
- cluster_data = embedding[self.hdbscan_model.labels_ == original_label]
196
- som_weights = som.get_weights()
197
-
198
- som_embedding = dim_red_op.transform(som_weights.reshape(-1, dataset.shape[1])).reshape(som_weights.shape[0], som_weights.shape[1], n_components)
199
-
200
- if interactive:
201
- # Plot the original data points
202
- fig.add_trace(
203
- go.Scatter3d(
204
- x=cluster_data[:, 0],
205
- y=cluster_data[:, 1],
206
- z=cluster_data[:, 2],
207
- mode='markers',
208
- marker=dict(color=colors[reindexed_label], size=1),
209
- name=f"Cluster {reindexed_label}"
210
- )
211
- )
212
- else:
213
- # Plot the original data points
214
- ax.scatter(cluster_data[:, 0], cluster_data[:, 1], cluster_data[:, 2], c=[colors[reindexed_label]], alpha=0.3, s=5, label=f"Cluster {reindexed_label}")
215
-
216
- for x in range(som_embedding.shape[0]):
217
- for y in range(som_embedding.shape[1]):
218
- if interactive:
219
- # Plot the SOM grid
220
- fig.add_trace(
221
- go.Scatter3d(
222
- x=[som_embedding[x, y, 0]],
223
- y=[som_embedding[x, y, 1]],
224
- z=[som_embedding[x, y, 2]],
225
- mode='markers+text',
226
- marker=dict(color=colors[reindexed_label], size=3, symbol='circle'),
227
- text=[f"{x},{y}"],
228
- textposition="top center"
229
- )
230
- )
231
- else:
232
- # Plot the SOM grid
233
- ax.plot([som_embedding[x, y, 0]], [som_embedding[x, y, 1]], [som_embedding[x, y, 2]], '+', markersize=8, mew=2, zorder=10, c=colors[reindexed_label])
234
-
235
- for i in range(som_embedding.shape[0] - 1):
236
- for j in range(som_embedding.shape[1] - 1):
237
- if interactive:
238
- # Plot the SOM connections
239
- fig.add_trace(
240
- go.Scatter3d(
241
- x=np.append(som_embedding[i:i+2, j, 0], som_embedding[i, j:j+2, 0]),
242
- y=np.append(som_embedding[i:i+2, j, 1], som_embedding[i, j:j+2, 1]),
243
- z=np.append(som_embedding[i:i+2, j, 2], som_embedding[i, j:j+2, 2]),
244
- mode='lines',
245
- line=dict(color=colors[reindexed_label], width=2),
246
- showlegend=False
247
- )
248
- )
249
- else:
250
- # Plot the SOM connections
251
- ax.plot(som_embedding[i:i+2, j, 0], som_embedding[i:i+2, j, 1], som_embedding[i:i+2, j, 2], lw=1, c=colors[reindexed_label])
252
- ax.plot(som_embedding[i, j:j+2, 0], som_embedding[i, j:j+2, 1], som_embedding[i, j:j+2, 2], lw=1, c=colors[reindexed_label])
253
-
254
- if interactive:
255
- # Plot noise
256
- noise_data = embedding[self.hdbscan_model.labels_ == -1]
257
- if len(noise_data) > 0:
258
- fig.add_trace(
259
- go.Scatter3d(
260
- x=noise_data[:, 0],
261
- y=noise_data[:, 1],
262
- z=noise_data[:, 2],
263
- mode='markers',
264
- marker=dict(color="gray", size=1),
265
- name="Noise"
266
- )
267
- )
268
- fig.update_layout(scene=dict(xaxis_title='X', yaxis_title='Y', zaxis_title='Z'))
269
- fig.show()
270
- else:
271
- # Plot noise
272
- noise_data = embedding[self.hdbscan_model.labels_ == -1]
273
- if len(noise_data) > 0:
274
- ax.scatter(noise_data[:, 0], noise_data[:, 1], noise_data[:, 2], c="gray", label="Noise")
275
- ax.legend()
276
- plt.show()
277
-
278
-
279
- def plot_label_heatmap(self):
280
- """
281
- Plot a heatmap for each main cluster showing the best label for each coordinate in a single subplot layout.
282
- """
283
- if not hasattr(self, 'label_centroids'):
284
- raise ValueError("Labels not trained yet.")
285
-
286
- n_labels = len(self.label_centroids)
287
- label_colors = plt.cm.rainbow(np.linspace(0, 1, n_labels))
288
- n_clusters = len(self.som_models)
289
-
290
- # Create a subplot layout with a heatmap for each main cluster
291
- n_rows = int(np.ceil(np.sqrt(n_clusters)))
292
- n_cols = n_rows if n_rows * (n_rows - 1) < n_clusters else n_rows - 1
293
- fig, axes = plt.subplots(n_rows, n_cols, figsize=(n_cols * 10, n_rows * 10), squeeze=False)
294
-
295
- for i, (reindexed_label, som) in enumerate(self.som_models.items()):
296
- som_weights = som.get_weights()
297
- label_map = np.zeros(som_weights.shape[:2], dtype=int)
298
- label_distance_map = np.full(som_weights.shape[:2], np.inf)
299
-
300
- for label_idx, label_centroid in enumerate(self.label_centroids):
301
- for x in range(som_weights.shape[0]):
302
- for y in range(som_weights.shape[1]):
303
- node = som_weights[x, y]
304
- distance = np.linalg.norm(label_centroid - node)
305
-
306
- if distance < label_distance_map[x, y]:
307
- label_distance_map[x, y] = distance
308
- label_map[x, y] = label_idx
309
-
310
- row, col = i // n_cols, i % n_cols
311
- ax = axes[row, col]
312
- cmap = plt.cm.rainbow
313
- cmap.set_under(color='white')
314
- im = ax.imshow(label_map, cmap=cmap, origin='lower', interpolation='none', vmin=0.5)
315
- ax.set_xticks(range(label_map.shape[1]))
316
- ax.set_yticks(range(label_map.shape[0]))
317
- ax.grid(True, linestyle='-', linewidth=0.5)
318
- ax.set_title(f"Label Heatmap for Cluster {reindexed_label}")
319
 
320
- # Add a colorbar for label colors
321
- cbar_ax = fig.add_axes([0.92, 0.15, 0.02, 0.7])
322
- cbar = fig.colorbar(im, cax=cbar_ax, ticks=range(n_labels))
323
- cbar.ax.set_yticklabels(self.label_encodings.classes_)
324
 
325
- # Adjust the layout to fit everything nicely
326
- fig.subplots_adjust(wspace=0.5, hspace=0.5, right=0.9)
 
327
 
328
- plt.show()
329
-
330
-
331
- def plot_activation(self, data, filename='prediction_output', start=None, end=None):
332
  """
333
  Generate a GIF visualization of the prediction output using the activation maps of individual SOMs.
334
  """
@@ -344,10 +101,10 @@ class ClusterSOM:
344
  images = []
345
  for sample in tqdm(data[start:end], desc="Visualizing prediction output"):
346
  prediction = self.predict([sample])[0]
347
- # if prediction[0] == -1: # Noise
348
- # continue
349
 
350
- fig, axes = plt.subplots(1, len(self.som_models), figsize=(20, 5), sharex=True, sharey=True)
351
  fig.suptitle(f"Activation map for SOM {prediction[0]}, node {prediction[1]}", fontsize=16)
352
 
353
  for idx, (som_key, som) in enumerate(self.som_models.items()):
@@ -363,25 +120,22 @@ class ClusterSOM:
363
  if som_key == prediction[0]: # Active SOM
364
  im_active = ax.imshow(activation_map, cmap='viridis', origin='lower', interpolation='none')
365
  ax.plot(winner[1], winner[0], 'r+') # Mark the BMU with a red plus sign
366
- ax.set_title(f"SOM {som_key}", color='blue', fontweight='bold')
367
  if hasattr(self, 'label_centroids'):
368
  label_idx = self.label_encodings.inverse_transform([som_key - 1])[0]
369
  ax.set_xlabel(f"Label: {label_idx}", fontsize=12)
370
  else: # Inactive SOM
371
  im_inactive = ax.imshow(activation_map, cmap='gray', origin='lower', interpolation='none')
372
- ax.set_title(f"SOM {som_key}")
 
 
 
373
 
374
- ax.set_xticks(range(activation_map.shape[1]))
375
- ax.set_yticks(range(activation_map.shape[0]))
376
  ax.grid(True, linestyle='-', linewidth=0.5)
377
 
378
  # Create a colorbar for each frame
379
- fig.subplots_adjust(right=0.8)
380
- cbar_ax = fig.add_axes([0.85, 0.15, 0.05, 0.7])
381
- try:
382
- fig.colorbar(im_active, cax=cbar_ax)
383
- except:
384
- pass
385
 
386
  # Save the plot to a buffer
387
  buf = io.BytesIO()
@@ -396,29 +150,6 @@ class ClusterSOM:
396
 
397
  return video
398
 
399
- def save(self, file_path):
400
- """
401
- Save the ClusterSOM model to a file.
402
- """
403
- model_data = (self.hdbscan_model, self.som_models, self.mean_values, self.sigma_values, self.cluster_mapping)
404
- if hasattr(self, 'label_centroids'):
405
- model_data += (self.label_centroids, self.label_encodings)
406
-
407
- with open(file_path, "wb") as f:
408
- pickle.dump(model_data, f)
409
-
410
- def load(self, file_path):
411
- """
412
- Load a ClusterSOM model from a file.
413
- """
414
- with open(file_path, "rb") as f:
415
- model_data = pickle.load(f)
416
-
417
- self.hdbscan_model, self.som_models, self.mean_values, self.sigma_values, self.cluster_mapping = model_data[:5]
418
- if len(model_data) > 5:
419
- self.label_centroids, self.label_encodings = model_data[5:]
420
-
421
-
422
  def plot_activation_v2(self, data, slice_select):
423
  """
424
  Generate a GIF visualization of the prediction output using the activation maps of individual SOMs.
@@ -462,47 +193,4 @@ class ClusterSOM:
462
  plt.tight_layout()
463
 
464
  return fig
465
-
466
- def plot_activation_v3(self, data, slice_select):
467
- """
468
- Generate a GIF visualization of the prediction output using the activation maps of individual SOMs.
469
- """
470
- if len(self.som_models) == 0:
471
- raise ValueError("SOM models not trained yet.")
472
-
473
- try:
474
- prediction = self.predict([data[int(slice_select)-1]])[0]
475
- except:
476
- prediction = self.predict([data[int(slice_select)-2]])[0]
477
-
478
- fig, axes = plt.subplots(1, len(self.som_models), figsize=(20, 5), sharex=True, sharey=True)
479
- fig.suptitle(f"Activation map for SOM {prediction[0]}, node {prediction[1]}", fontsize=16)
480
-
481
- for idx, (som_key, som) in enumerate(self.som_models.items()):
482
- ax = axes[idx]
483
- activation_map = np.zeros(som._weights.shape[:2])
484
- for x in range(som._weights.shape[0]):
485
- for y in range(som._weights.shape[1]):
486
- activation_map[x, y] = np.linalg.norm(data[int(slice_select)-1] - som._weights[x, y])
487
-
488
- winner = som.winner(data[int(slice_select)-1]) # Find the BMU for this SOM
489
- activation_map[winner] = 0 # Set the BMU's value to 0 so it will be red in the colormap
490
-
491
- if som_key == prediction[0]: # Active SOM
492
- im_active = ax.imshow(activation_map, cmap='viridis', origin='lower', interpolation='none')
493
- ax.plot(winner[1], winner[0], 'r+') # Mark the BMU with a red plus sign
494
- ax.set_title(f"SOM {som_key}", color='blue', fontweight='bold')
495
- if hasattr(self, 'label_centroids'):
496
- label_idx = self.label_encodings.inverse_transform([som_key - 1])[0]
497
- ax.set_xlabel(f"Label: {label_idx}", fontsize=12)
498
- else: # Inactive SOM
499
- im_inactive = ax.imshow(activation_map, cmap='gray', origin='lower', interpolation='none')
500
- ax.set_title(f"SOM {som_key}")
501
-
502
- ax.set_xticks(range(activation_map.shape[1]))
503
- ax.set_yticks(range(activation_map.shape[0]))
504
- ax.grid(True, linestyle='-', linewidth=0.5)
505
-
506
- plt.tight_layout()
507
-
508
- return fig
 
1
+ import io
2
+ import math
 
3
  import pickle
 
 
 
4
  import imageio
5
+
6
+ import numpy as np
7
+ import matplotlib.pyplot as plt
8
+
9
  from tqdm import tqdm
 
 
 
 
 
 
 
 
10
  from moviepy.editor import ImageSequenceClip
11
 
12
  class ClusterSOM:
 
18
  self.cluster_mapping = {}
19
  self.embedding = None
20
  self.dim_red_op = None
21
+
22
+ def load(self, file_path):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23
  """
24
+ Load a ClusterSOM model from a file.
25
  """
26
+ with open(file_path, "rb") as f:
27
+ model_data = pickle.load(f)
 
 
 
 
 
 
 
 
 
 
 
 
 
28
 
29
+ self.hdbscan_model, self.som_models, self.mean_values, self.sigma_values, self.cluster_mapping = model_data[:5]
30
+ if len(model_data) > 5:
31
+ self.label_centroids, self.label_encodings = model_data[5:]
32
+
33
  def predict(self, data, sigma_factor=1.5):
34
  """
35
  Predict the cluster and BMU SOM coordinate for each sample in the data if it's inside the sigma value.
 
66
 
67
  return results
68
 
69
+ # rearranging the subplots in the closest square format
70
+ def rearrange_subplots(self, num_subplots):
71
+ # Calculate the number of rows and columns for the subplot grid
72
+ num_rows = math.isqrt(num_subplots)
73
+ num_cols = math.ceil(num_subplots / num_rows)
 
 
 
 
 
 
 
 
 
74
 
75
+ # Create the figure and subplots
76
+ fig, axes = plt.subplots(num_rows, num_cols, sharex=True, sharey=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
77
 
78
+ # Flatten the axes array if it is multidimensional
79
+ if isinstance(axes, np.ndarray):
80
+ axes = axes.flatten()
 
81
 
82
+ # Hide any empty subplots
83
+ for i in range(num_subplots, len(axes)):
84
+ axes[i].axis('off')
85
 
86
+ return fig, axes
87
+
88
+ def plot_activation(self, data, start=None, end=None):
 
89
  """
90
  Generate a GIF visualization of the prediction output using the activation maps of individual SOMs.
91
  """
 
101
  images = []
102
  for sample in tqdm(data[start:end], desc="Visualizing prediction output"):
103
  prediction = self.predict([sample])[0]
104
+
105
+ fig, axes = self.rearrange_subplots(len(self.som_models))
106
 
107
+ # fig, axes = plt.subplots(1, len(self.som_models), figsize=(20, 5), sharex=True, sharey=True)
108
  fig.suptitle(f"Activation map for SOM {prediction[0]}, node {prediction[1]}", fontsize=16)
109
 
110
  for idx, (som_key, som) in enumerate(self.som_models.items()):
 
120
  if som_key == prediction[0]: # Active SOM
121
  im_active = ax.imshow(activation_map, cmap='viridis', origin='lower', interpolation='none')
122
  ax.plot(winner[1], winner[0], 'r+') # Mark the BMU with a red plus sign
123
+ ax.set_title(f"A {som_key}", color='blue', fontweight='bold', fontsize=10)
124
  if hasattr(self, 'label_centroids'):
125
  label_idx = self.label_encodings.inverse_transform([som_key - 1])[0]
126
  ax.set_xlabel(f"Label: {label_idx}", fontsize=12)
127
  else: # Inactive SOM
128
  im_inactive = ax.imshow(activation_map, cmap='gray', origin='lower', interpolation='none')
129
+ ax.set_title(f"A {som_key}", fontsize=10)
130
+
131
+ ax.set_xticks([])
132
+ ax.set_yticks([])
133
 
 
 
134
  ax.grid(True, linestyle='-', linewidth=0.5)
135
 
136
  # Create a colorbar for each frame
137
+ plt.tight_layout()
138
+ fig.subplots_adjust(wspace=0, hspace=0)
 
 
 
 
139
 
140
  # Save the plot to a buffer
141
  buf = io.BytesIO()
 
150
 
151
  return video
152
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
153
  def plot_activation_v2(self, data, slice_select):
154
  """
155
  Generate a GIF visualization of the prediction output using the activation maps of individual SOMs.
 
193
  plt.tight_layout()
194
 
195
  return fig
196
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
models/cluster_som6.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:33382cbda76042b3ed585814f52d5a82f64c042e9721a630e19e12363f2dbf4f
3
+ size 9207489
models/r10d_6.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a6bb76c4aaae152ed11e4cd16e63a24ccd3ce684092521489f576ae27f62ea19
3
+ size 13100259