arcan3 committed
Commit 46fcc2f · 1 Parent(s): 03507e5

added labelling stages

.gitignore CHANGED
@@ -1,4 +1,7 @@
 # Byte-compiled / optimized / DLL files
+*.zip
+Data-*
+drive-*
 __pycache__/
 *.py[cod]
 *$py.class
app.py CHANGED
@@ -10,8 +10,7 @@ with gr.Blocks(title='Cabasus') as cabasus_sensor:
     with gr.Row():
         processed_file_box = gr.File(label='Processed CSV File')
         json_file_box = gr.File(label='Generated Json file')
-
-        video_box = gr.PlayableVideo(label='Video box')
+
     with gr.Row():
         slice_size_slider = gr.inputs.Slider(16, 512, 1, 64, label="Slice Size")
         sample_rate = gr.inputs.Slider(1, 199, 1, 20, label="Sample rate")
@@ -26,13 +25,20 @@ with gr.Blocks(title='Cabasus') as cabasus_sensor:
         plot_box_overlay = gr.Plot(label="Overlay Signal Plot")
 
     with gr.Row():
-        slice_slider = gr.Slider(minimum=0, maximum=300, label='Current slice', step=1)
+        slice_slider = gr.Slider(minimum=1, maximum=300, label='Current slice', step=1)
 
-    slices_per_leg = gr.Textbox(label="Number of slices found per LEG")
+    with gr.Row():
+        plot_slice_leg = gr.Plot(label="Sliced Signal Plot")
+        get_real_slice = gr.Plot(label="Real Signal Plot")
+
+    with gr.Row():
+        animation = gr.PlayableVideo(label="Animated horse steps")
+
+    slices_per_leg = gr.Textbox(label="Number of slices found per LEG")
 
-    csv_file_box.change(process_data, inputs=[csv_file_box, slice_size_slider, sample_rate, window_size_slider], outputs=[processed_file_box, json_file_box, slices_per_leg, plot_box_leg, plot_box_overlay, slice_slider])
-    leg_dropdown.change(plot_sensor_data_from_json, inputs=[json_file_box, leg_dropdown], outputs=[plot_box_leg])
-    repeat_process.click(process_data, inputs=[csv_file_box, slice_size_slider, sample_rate, window_size_slider], outputs=[processed_file_box, json_file_box, slices_per_leg, plot_box_leg, plot_box_overlay, slice_slider])
+    csv_file_box.change(process_data, inputs=[csv_file_box, slice_size_slider, sample_rate, window_size_slider], outputs=[processed_file_box, json_file_box, slices_per_leg, plot_box_leg, plot_box_overlay, slice_slider, plot_slice_leg, get_real_slice])
+    leg_dropdown.change(plot_sensor_data_from_json, inputs=[json_file_box, leg_dropdown, slice_slider], outputs=[plot_box_leg, plot_slice_leg, get_real_slice])
+    repeat_process.click(process_data, inputs=[csv_file_box, slice_size_slider, sample_rate, window_size_slider], outputs=[processed_file_box, json_file_box, slices_per_leg, plot_box_leg, plot_box_overlay, slice_slider, plot_slice_leg, get_real_slice])
+    slice_slider.change(plot_sensor_data_from_json, inputs=[json_file_box, leg_dropdown, slice_slider], outputs=[plot_box_leg, plot_slice_leg, get_real_slice])
 
     cabasus_sensor.queue(concurrency_count=2).launch(debug=True)
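
Since process_data now drives eight output components, each .change/.click handler above must return exactly one value per entry in its outputs= list; this is why the early returns in funcs/processor.py below were widened to eight Nones. A minimal sketch of that contract with a hypothetical handler (illustrative only, not repository code):

    import gradio as gr

    def handler(text):
        # one return value per component in outputs=, in the same order
        return f'echo: {text}', gr.Slider.update(value=1)

    with gr.Blocks() as demo:
        box = gr.Textbox()
        out = gr.Textbox()
        slider = gr.Slider(1, 10, step=1)
        box.change(handler, inputs=[box], outputs=[out, slider])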
funcs/dataloader.py ADDED
@@ -0,0 +1,107 @@
+import glob, json, os
+import torch
+import warnings
+
+from torch.utils.data import Dataset
+
+class BaseDataset2(Dataset):
+    """Template class for all datasets in the project."""
+
+    def __init__(self, x, y):
+        """Initialize dataset.
+
+        Args:
+            x(ndarray): Input features.
+            y(ndarray): Targets.
+        """
+        self.data = torch.from_numpy(x).float()
+        self.targets = torch.from_numpy(y).float()
+        self.latents = None
+
+        self.labels = None
+        self.is_radial = []
+        self.partition = True
+
+    def __getitem__(self, index):
+        return self.data[index], self.targets[index], index
+
+    def __len__(self):
+        return len(self.data)
+
+    def numpy(self, idx=None):
+        """Get dataset as ndarray.
+
+        Specify indices to return a subset of the dataset, otherwise return whole dataset.
+
+        Args:
+            idx(int, optional): Specify index or indices to return.
+
+        Returns:
+            ndarray: Return flattened dataset as a ndarray.
+
+        """
+        n = len(self)
+
+        data = self.data.numpy().reshape((n, -1))
+
+        if idx is None:
+            return data, self.targets.numpy()
+        else:
+            return data[idx], self.targets[idx].numpy()
+
+    def get_latents(self):
+        """Get latent variables.
+
+        Returns:
+            latents(ndarray): Latent variables for each sample.
+        """
+        return self.latents
+
+
+def load_json(file_path):
+    with open(file_path, 'r') as f:
+        data = json.load(f)
+    return data
+
+def read_json_files(file):
+    data_x = []
+    data_y = []
+
+    samples = load_json(file)
+    valid_samples = 0
+
+    for sample in samples:
+        data = []
+        skip_sample = False
+        for key in ['AX1', 'AX2', 'AX3', 'AX4', 'AY1', 'AY2', 'AY3', 'AY4', 'AZ1', 'AZ2', 'AZ3', 'AZ4', 'GX1', 'GX2', 'GX3', 'GX4', 'GY1', 'GY2', 'GY3', 'GY4', 'GZ1', 'GZ2', 'GZ3', 'GZ4', 'GZ1_precise_time_diff', 'GZ2_precise_time_diff', 'GZ3_precise_time_diff', 'GZ4_precise_time_diff', 'precise_time_diff']:
+            if key in sample:
+                if key.endswith('_precise_time_diff') or key == 'precise_time_diff':
+                    if sample[key] is None:
+                        skip_sample = True
+                        break
+                    data.append(round(sample[key])*20)
+                else:
+                    data.extend(sample[key])
+            else:
+                warnings.warn(f"KeyError: {key} not found in JSON file: {file}")
+
+        if skip_sample:
+            # warnings.warn(f"Skipped sample with null values in JSON file: {file}")
+            continue
+
+        if len(data) != 768*2 + 5:  # 24 keys * 64 values each + 5 additional values
+            warnings.warn(f"Incomplete sample in JSON file: {file}")
+            continue
+
+        valid_samples += 1
+        tensor = torch.tensor(data, dtype=torch.float32)
+        data_x.append(tensor)
+        data_y.append(1)
+
+    if valid_samples == 0:
+        warnings.warn(f"No valid samples found in JSON file: {file}")
+
+    if not data_x:
+        raise ValueError("No valid samples found in all the JSON files.")
+
+    return torch.stack(data_x), torch.tensor(data_y, dtype=torch.long)
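
Since BaseDataset2 subclasses torch.utils.data.Dataset and __getitem__ yields (sample, target, index) triples, it plugs directly into a standard DataLoader. A minimal sketch, with random arrays standing in for real sensor slices:

    import numpy as np
    from torch.utils.data import DataLoader

    x = np.random.rand(32, 768*2 + 5)  # stand-in for 24 channels * 64 samples + 5 time diffs
    y = np.ones(32)
    loader = DataLoader(BaseDataset2(x, y), batch_size=8, shuffle=True)
    for batch, targets, indices in loader:
        pass  # batch has shape (8, 1541)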
funcs/plot_func.py CHANGED
@@ -1,10 +1,12 @@
 import json
 import matplotlib
 
+import numpy as np
 import pandas as pd
 import matplotlib.pyplot as plt
 
 matplotlib.use('Agg')
+plt.style.use('ggplot')
 
 def plot_sensor_data_from_json(json_file, sensor, slice_select=1):
     # Read the JSON file
@@ -18,35 +20,45 @@ def plot_sensor_data_from_json(json_file, sensor, slice_select=1):
     # Concatenate the slices and create a new timestamp series with 20ms intervals
     timestamps = []
     sensor_data = []
-    for slice_dict in slices:
+    slice_item = []
+    temp_end = 0
+    for slice_count, slice_dict in enumerate(slices):
         start_timestamp = slice_dict["timestamp"]
         slice_length = len(slice_dict[sensor])
 
-        slice_timestamps = [start_timestamp + 20 * i for i in range(slice_length)]
+        slice_timestamps = [start_timestamp + 20 * i for i in range(temp_end, slice_length + temp_end)]
         timestamps.extend(slice_timestamps)
         sensor_data.extend(slice_dict[sensor])
 
+        temp_end += slice_length
+        slice_item.extend([slice_count+1]*len(slice_timestamps))
+
     # Create a DataFrame with the sensor data
-    data = pd.DataFrame({sensor: sensor_data}, index=timestamps)
+    data = pd.DataFrame({sensor: sensor_data, 'slice selection': slice_item, 'time': timestamps})
 
     # Plot the sensor data
     fig, ax = plt.subplots(figsize=(12, 6))
-    ax = plt.plot(data[sensor], label=sensor)
+    ax = plt.plot(data['time'].to_list(), data[sensor].to_list())
 
-    # Mark the slice start and end points
-    for slice_dict in slices:
-        start_timestamp = slice_dict["timestamp"]
-        end_timestamp = start_timestamp + 20 * (len(slice_dict[sensor]) - 1)
-
-        plt.axvline(x=start_timestamp, color='black', linestyle=':', label='Start' if start_timestamp == slices[0]["timestamp"] else None)
-        plt.axvline(x=end_timestamp, color='red', linestyle=':', label='End' if end_timestamp == slices[0]["timestamp"] + 20 * (len(slices[0][sensor]) - 1) else None)
-
+    df_temp = data[data['slice selection'] == int(slice_select)].reset_index()
+    y = [np.NaN]*((int(slice_select)-1)*len(df_temp[sensor].to_list())) + df_temp[sensor].to_list() + [np.NaN]*((len(slices) - int(slice_select))*len(df_temp[sensor].to_list()))
+    x = data['time'].to_list()
+    ax = plt.plot(x, y, '-')
+
+    plt.xlabel("Timestamp")
+    plt.ylabel(sensor)
+    plt.legend()
+    plt.tight_layout()
+
+    fig1, ax1 = plt.subplots(figsize=(12, 6))
+    ax1 = plt.plot(df_temp['time'].to_list(), df_temp[sensor].to_list())
 
     plt.xlabel("Timestamp")
     plt.ylabel(sensor)
     plt.legend()
     plt.tight_layout()
 
-    return fig
+    return fig, fig1
 
 def plot_overlay_data_from_json(json_file, sensors, use_precise_timestamp=False):
     # Read the JSON file
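
The highlight trace in the new plot_sensor_data_from_json relies on NaN padding: matplotlib skips NaN points in a line plot, so padding the selected slice's values with NaNs on both sides overlays it at the correct x positions without replotting the rest. A small worked example of the same idea (note the padding arithmetic assumes every slice shares the selected slice's length):

    import numpy as np

    slice_len, n_slices, slice_select = 4, 3, 2
    seg = [10, 11, 12, 13]  # values of the selected slice
    y = [np.nan] * ((slice_select - 1) * slice_len) + seg + [np.nan] * ((n_slices - slice_select) * slice_len)
    # y == [nan, nan, nan, nan, 10, 11, 12, 13, nan, nan, nan, nan]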
funcs/processor.py CHANGED
@@ -9,11 +9,11 @@ def process_data(input_file, slice_size=64, min_slice_size=16, sample_rate=20, w
     # Read the data from the file, including the CRC column
     try:
         if input_file.name is None:
-            return None, None, None, None, None, None
+            return None, None, None, None, None, None, None, None
         data = pd.read_csv(input_file.name, delimiter=";", index_col="NR", usecols=["NR", "TS", "LEG", "GX", "GY", "GZ", "AX", "AY", "AZ", "CRC"])
     except:
         if input_file is None:
-            return None, None, None, None, None, None
+            return None, None, None, None, None, None, None, None
         data = pd.read_csv(input_file, delimiter=";", index_col="NR", usecols=["NR", "TS", "LEG", "GX", "GY", "GZ", "AX", "AY", "AZ", "CRC"])
 
 
@@ -69,7 +69,7 @@ def process_data(input_file, slice_size=64, min_slice_size=16, sample_rate=20, w
     if not no_significant_change_index.empty:
         # Save the data up to the point where no significant change appears in all channels
         data = data.loc[:no_significant_change_index[0]]
-        return None, None, f'Warning: gap of {gap_size} ms found at line {gap_start_index}', None, None, None
+        return None, None, f'Warning: gap of {gap_size} ms found at line {gap_start_index}', None, None, None, None, None
 
     # Save the resulting DataFrame to a new file
     data.to_csv('output.csv', sep=";", na_rep="NaN", float_format="%.0f")
@@ -77,10 +77,10 @@ def process_data(input_file, slice_size=64, min_slice_size=16, sample_rate=20, w
     file, len_ = slice_csv_to_json('output.csv', slice_size, min_slice_size, sample_rate, window_size=window_size)
 
     # get the plot automatically
-    sensor_fig = plot_sensor_data_from_json(file, "GZ1")
+    sensor_fig, slice_fig = plot_sensor_data_from_json(file, "GZ1")
     overlay_fig = plot_overlay_data_from_json(file, ["GZ1", "GZ2", "GZ3", "GZ4"], use_precise_timestamp=True)
 
     #
 
 
-    return 'output.csv', file, f'{len_}', sensor_fig, overlay_fig, gr.Slider.update(interactive=True, maximum=len_, minimum=1, value=1)
+    return 'output.csv', file, f'{len_}', sensor_fig, overlay_fig, gr.Slider.update(interactive=True, maximum=len_, minimum=1, value=1), slice_fig, None
funcs/som.py ADDED
@@ -0,0 +1,425 @@
+import numpy as np
+import hdbscan
+from minisom import MiniSom
+import pickle
+from collections import Counter
+import matplotlib.pyplot as plt
+import phate
+import imageio
+from tqdm import tqdm
+import io
+import plotly.graph_objs as go
+import plotly.subplots as sp
+import umap
+from sklearn.datasets import make_blobs
+from sklearn.preprocessing import LabelEncoder
+from sklearn.cluster import KMeans
+from sklearn.semi_supervised import LabelSpreading
+from moviepy.editor import *
+
+class ClusterSOM:
+    def __init__(self):
+        self.hdbscan_model = None
+        self.som_models = {}
+        self.sigma_values = {}
+        self.mean_values = {}
+        self.cluster_mapping = {}
+        self.embedding = None
+        self.dim_red_op = None
+
+    def train(self, dataset, min_samples_per_cluster=100, n_clusters=None, som_size=(20, 20), sigma=1.0, learning_rate=0.5, num_iteration=200000, random_seed=42, n_neighbors=5, coverage=0.95):
+        """
+        Train HDBSCAN and SOM models on the given dataset.
+        """
+        # Train HDBSCAN model
+        print('Identifying clusters in the embedding ...')
+        self.hdbscan_model = hdbscan.HDBSCAN(min_cluster_size=min_samples_per_cluster)
+        self.hdbscan_model.fit(dataset)
+
+        # Calculate n_clusters if not provided
+        if n_clusters is None:
+            cluster_labels, counts = zip(*Counter(self.hdbscan_model.labels_).most_common())
+            cluster_labels = list(cluster_labels)
+            total_points = sum(counts)
+            covered_points = 0
+            n_clusters = 0
+            for count in counts:
+                covered_points += count
+                n_clusters += 1
+                if covered_points / total_points >= coverage:
+                    break
+
+        # Train SOM models for the n_clusters most common clusters in the HDBSCAN model
+        cluster_labels, counts = zip(*Counter(self.hdbscan_model.labels_).most_common(n_clusters + 1))
+        cluster_labels = list(cluster_labels)
+
+        if -1 in cluster_labels:
+            cluster_labels.remove(-1)
+        else:
+            cluster_labels.pop()
+
+        for i, label in tqdm(enumerate(cluster_labels), total=len(cluster_labels), desc="Fitting 2D maps"):
+            if label == -1:
+                continue  # Ignore noise
+            cluster_data = dataset[self.hdbscan_model.labels_ == label]
+            som = MiniSom(som_size[0], som_size[1], dataset.shape[1], sigma=sigma, learning_rate=learning_rate, random_seed=random_seed)
+            som.train_random(cluster_data, num_iteration)
+            self.som_models[i+1] = som
+            self.cluster_mapping[i+1] = label
+
+            # Compute sigma values
+            mean_cluster, sigma_cluster = self.compute_sigma_values(cluster_data, som_size, som, n_neighbors=n_neighbors)
+            self.sigma_values[i+1] = sigma_cluster
+            self.mean_values[i+1] = mean_cluster
+
+    def compute_sigma_values(self, cluster_data, som_size, som, n_neighbors=5):
+        som_weights = som.get_weights()
+
+        # Assign each datapoint to its nearest node
+        partitions = {idx: [] for idx in np.ndindex(som_size[0], som_size[1])}
+        for sample in cluster_data:
+            x, y = som.winner(sample)
+            partitions[(x, y)].append(sample)
+
+        # Compute the mean distance and std deviation of these partitions
+        mean_cluster = np.zeros(som_size)
+        sigma_cluster = np.zeros(som_size)
+        for idx in partitions:
+            if len(partitions[idx]) > 0:
+                partition_data = np.array(partitions[idx])
+                mean_distance = np.mean(np.linalg.norm(partition_data - som_weights[idx], axis=-1))
+                std_distance = np.std(np.linalg.norm(partition_data - som_weights[idx], axis=-1))
+            else:
+                mean_distance = 0
+                std_distance = 0
+            mean_cluster[idx] = mean_distance
+            sigma_cluster[idx] = std_distance
+
+        return mean_cluster, sigma_cluster
+
+    def train_label(self, labeled_data, labels):
+        """
+        Train on labeled data to find centroids and compute distances to the labels.
+        """
+        le = LabelEncoder()
+        encoded_labels = le.fit_transform(labels)
+        unique_labels = np.unique(encoded_labels)
+
+        # Use label spreading to propagate the labels
+        label_prop_model = LabelSpreading(kernel='knn', n_neighbors=5)
+        label_prop_model.fit(labeled_data, encoded_labels)
+
+        # Find the centroids for each label using KMeans
+        kmeans = KMeans(n_clusters=len(unique_labels), random_state=42)
+        kmeans.fit(labeled_data)
+
+        # Store the label centroids and label encodings
+        self.label_centroids = kmeans.cluster_centers_
+        self.label_encodings = le
+
+    def predict(self, data, sigma_factor=1.5):
+        """
+        Predict the cluster and BMU SOM coordinate for each sample in the data if it's inside the sigma value.
+        Also, predict the label and distance to the center of the label if labels are trained.
+        """
+        results = []
+
+        for sample in data:
+            min_distance = float('inf')
+            nearest_cluster_idx = None
+            nearest_node = None
+
+            for i, som in self.som_models.items():
+                x, y = som.winner(sample)
+                node = som.get_weights()[x, y]
+                distance = np.linalg.norm(sample - node)
+
+                if distance < min_distance:
+                    min_distance = distance
+                    nearest_cluster_idx = i
+                    nearest_node = (x, y)
+
+            # Check if the nearest node is within the sigma value
+            if min_distance <= self.mean_values[nearest_cluster_idx][nearest_node] * 1.5:  # * self.sigma_values[nearest_cluster_idx][nearest_node] * sigma_factor:
+                if hasattr(self, 'label_centroids'):
+                    # Predict the label and distance to the center of the label
+                    label_idx = self.label_encodings.inverse_transform([nearest_cluster_idx - 1])[0]
+                    label_distance = np.linalg.norm(sample - self.label_centroids[label_idx])
+                    results.append((nearest_cluster_idx, nearest_node, label_idx, label_distance))
+                else:
+                    results.append((nearest_cluster_idx, nearest_node))
+            else:
+                results.append((-1, None))  # Noise
+
+        return results
+
+    def plot_embedding(self, new_data=None, dim_reduction='umap', interactive=False):
+        """
+        Plot the dataset and SOM grids for each cluster.
+        If new_data is provided, it will be used for plotting instead of the entire dataset.
+        """
+
+        if self.hdbscan_model is None:
+            raise ValueError("HDBSCAN model not trained yet.")
+
+        if len(self.som_models) == 0:
+            raise ValueError("SOM models not trained yet.")
+
+        if dim_reduction not in ['phate', 'umap']:
+            raise ValueError("Invalid dimensionality reduction method. Use 'phate' or 'umap'.")
+
+        n_components = 3  # defined up front; also needed when the operator is already fitted
+        if self.dim_red_op is None or self.embedding is None:
+            if dim_reduction == 'phate':
+                self.dim_red_op = phate.PHATE(n_components=n_components, random_state=42)
+            elif dim_reduction == 'umap':
+                self.dim_red_op = umap.UMAP(n_components=n_components, random_state=42)
+
+            self.embedding = self.dim_red_op.fit_transform(new_data)
+
+        if new_data is not None:
+            new_embedding = self.dim_red_op.transform(new_data)
+        else:
+            new_embedding = self.embedding
+
+        if interactive:
+            fig = sp.make_subplots(rows=1, cols=1, specs=[[{'type': 'scatter3d'}]])
+        else:
+            fig = plt.figure(figsize=(30, 30))
+            ax = fig.add_subplot(111, projection='3d')
+
+        colors = plt.cm.rainbow(np.linspace(0, 1, len(self.som_models) + 1))
+
+        for reindexed_label, som in self.som_models.items():
+            original_label = self.cluster_mapping[reindexed_label]
+            cluster_data = self.embedding[self.hdbscan_model.labels_ == original_label]
+            som_weights = som.get_weights()
+
+            som_embedding = self.dim_red_op.transform(som_weights.reshape(-1, som_weights.shape[-1])).reshape(som_weights.shape[0], som_weights.shape[1], n_components)
+
+            if interactive:
+                # Plot the original data points
+                fig.add_trace(
+                    go.Scatter3d(
+                        x=cluster_data[:, 0],
+                        y=cluster_data[:, 1],
+                        z=cluster_data[:, 2],
+                        mode='markers',
+                        marker=dict(color=colors[reindexed_label], size=1),
+                        name=f"Cluster {reindexed_label}"
+                    )
+                )
+            else:
+                # Plot the original data points
+                ax.scatter(cluster_data[:, 0], cluster_data[:, 1], cluster_data[:, 2], c=[colors[reindexed_label]], alpha=0.3, s=5, label=f"Cluster {reindexed_label}")
+
+            for x in range(som_embedding.shape[0]):
+                for y in range(som_embedding.shape[1]):
+                    if interactive:
+                        # Plot the SOM grid
+                        fig.add_trace(
+                            go.Scatter3d(
+                                x=[som_embedding[x, y, 0]],
+                                y=[som_embedding[x, y, 1]],
+                                z=[som_embedding[x, y, 2]],
+                                mode='markers+text',
+                                marker=dict(color=colors[reindexed_label], size=3, symbol='circle'),
+                                text=[f"{x},{y}"],
+                                textposition="top center"
+                            )
+                        )
+                    else:
+                        # Plot the SOM grid
+                        ax.plot([som_embedding[x, y, 0]], [som_embedding[x, y, 1]], [som_embedding[x, y, 2]], '+', markersize=8, mew=2, zorder=10, c=colors[reindexed_label])
+
+            for i in range(som_embedding.shape[0] - 1):
+                for j in range(som_embedding.shape[1] - 1):
+                    if interactive:
+                        # Plot the SOM connections
+                        fig.add_trace(
+                            go.Scatter3d(
+                                x=np.append(som_embedding[i:i+2, j, 0], som_embedding[i, j:j+2, 0]),
+                                y=np.append(som_embedding[i:i+2, j, 1], som_embedding[i, j:j+2, 1]),
+                                z=np.append(som_embedding[i:i+2, j, 2], som_embedding[i, j:j+2, 2]),
+                                mode='lines',
+                                line=dict(color=colors[reindexed_label], width=2),
+                                showlegend=False
+                            )
+                        )
+                    else:
+                        # Plot the SOM connections
+                        ax.plot(som_embedding[i:i+2, j, 0], som_embedding[i:i+2, j, 1], som_embedding[i:i+2, j, 2], lw=1, c=colors[reindexed_label])
+                        ax.plot(som_embedding[i, j:j+2, 0], som_embedding[i, j:j+2, 1], som_embedding[i, j:j+2, 2], lw=1, c=colors[reindexed_label])
+
+        if interactive:
+            # Plot noise
+            noise_data = self.embedding[self.hdbscan_model.labels_ == -1]
+            if len(noise_data) > 0:
+                fig.add_trace(
+                    go.Scatter3d(
+                        x=noise_data[:, 0],
+                        y=noise_data[:, 1],
+                        z=noise_data[:, 2],
+                        mode='markers',
+                        marker=dict(color="gray", size=1),
+                        name="Noise"
+                    )
+                )
+            fig.update_layout(scene=dict(xaxis_title='X', yaxis_title='Y', zaxis_title='Z'))
+            fig.show()
+        else:
+            # Plot noise
+            noise_data = self.embedding[self.hdbscan_model.labels_ == -1]
+            if len(noise_data) > 0:
+                ax.scatter(noise_data[:, 0], noise_data[:, 1], noise_data[:, 2], c="gray", label="Noise")
+            ax.legend()
+            plt.show()
+
+
+    def plot_label_heatmap(self):
+        """
+        Plot a heatmap for each main cluster showing the best label for each coordinate in a single subplot layout.
+        """
+        if not hasattr(self, 'label_centroids'):
+            raise ValueError("Labels not trained yet.")
+
+        n_labels = len(self.label_centroids)
+        label_colors = plt.cm.rainbow(np.linspace(0, 1, n_labels))
+        n_clusters = len(self.som_models)
+
+        # Create a subplot layout with a heatmap for each main cluster
+        n_rows = int(np.ceil(np.sqrt(n_clusters)))
+        n_cols = n_rows if n_rows * (n_rows - 1) < n_clusters else n_rows - 1
+        fig, axes = plt.subplots(n_rows, n_cols, figsize=(n_cols * 10, n_rows * 10), squeeze=False)
+
+        for i, (reindexed_label, som) in enumerate(self.som_models.items()):
+            som_weights = som.get_weights()
+            label_map = np.zeros(som_weights.shape[:2], dtype=int)
+            label_distance_map = np.full(som_weights.shape[:2], np.inf)
+
+            for label_idx, label_centroid in enumerate(self.label_centroids):
+                for x in range(som_weights.shape[0]):
+                    for y in range(som_weights.shape[1]):
+                        node = som_weights[x, y]
+                        distance = np.linalg.norm(label_centroid - node)
+
+                        if distance < label_distance_map[x, y]:
+                            label_distance_map[x, y] = distance
+                            label_map[x, y] = label_idx
+
+            row, col = i // n_cols, i % n_cols
+            ax = axes[row, col]
+            cmap = plt.cm.rainbow
+            cmap.set_under(color='white')
+            im = ax.imshow(label_map, cmap=cmap, origin='lower', interpolation='none', vmin=0.5)
+            ax.set_xticks(range(label_map.shape[1]))
+            ax.set_yticks(range(label_map.shape[0]))
+            ax.grid(True, linestyle='-', linewidth=0.5)
+            ax.set_title(f"Label Heatmap for Cluster {reindexed_label}")
+
+        # Add a colorbar for label colors
+        cbar_ax = fig.add_axes([0.92, 0.15, 0.02, 0.7])
+        cbar = fig.colorbar(im, cax=cbar_ax, ticks=range(n_labels))
+        cbar.ax.set_yticklabels(self.label_encodings.classes_)
+
+        # Adjust the layout to fit everything nicely
+        fig.subplots_adjust(wspace=0.5, hspace=0.5, right=0.9)
+
+        plt.show()
+
+
+    def plot_activation(self, data, filename='prediction_output', start=None, end=None):
+        """
+        Generate a GIF visualization of the prediction output using the activation maps of individual SOMs.
+        """
+        if len(self.som_models) == 0:
+            raise ValueError("SOM models not trained yet.")
+
+        if start is None:
+            start = 0
+
+        if end is None:
+            end = len(data)
+
+        images = []
+        for sample in tqdm(data[start:end], desc="Visualizing prediction output"):
+            prediction = self.predict([sample])[0]
+            if prediction[0] == -1:  # Noise
+                continue
+
+            fig, axes = plt.subplots(1, len(self.som_models), figsize=(20, 5), sharex=True, sharey=True)
+            fig.suptitle(f"Activation map for SOM {prediction[0]}, node {prediction[1]}", fontsize=16)
+
+            for idx, (som_key, som) in enumerate(self.som_models.items()):
+                ax = axes[idx]
+                activation_map = np.zeros(som._weights.shape[:2])
+                for x in range(som._weights.shape[0]):
+                    for y in range(som._weights.shape[1]):
+                        activation_map[x, y] = np.linalg.norm(sample - som._weights[x, y])
+
+                winner = som.winner(sample)  # Find the BMU for this SOM
+                activation_map[winner] = 0  # Set the BMU's value to 0 so it will be red in the colormap
+
+                if som_key == prediction[0]:  # Active SOM
+                    im_active = ax.imshow(activation_map, cmap='viridis', origin='lower', interpolation='none')
+                    ax.plot(winner[1], winner[0], 'r+')  # Mark the BMU with a red plus sign
+                    ax.set_title(f"SOM {som_key}", color='blue', fontweight='bold')
+                    if hasattr(self, 'label_centroids'):
+                        label_idx = self.label_encodings.inverse_transform([som_key - 1])[0]
+                        ax.set_xlabel(f"Label: {label_idx}", fontsize=12)
+                else:  # Inactive SOM
+                    im_inactive = ax.imshow(activation_map, cmap='gray', origin='lower', interpolation='none')
+                    ax.set_title(f"SOM {som_key}")
+
+                ax.set_xticks(range(activation_map.shape[1]))
+                ax.set_yticks(range(activation_map.shape[0]))
+                ax.grid(True, linestyle='-', linewidth=0.5)
+
+            # Create a colorbar for each frame
+            fig.subplots_adjust(right=0.8)
+            cbar_ax = fig.add_axes([0.85, 0.15, 0.05, 0.7])
+            fig.colorbar(im_active, cax=cbar_ax)
+
+            # Save the plot to a buffer
+            buf = io.BytesIO()
+            plt.savefig(buf, format='png')
+            buf.seek(0)
+            img = imageio.imread(buf)
+            images.append(img)
+            plt.close()
+
+        # Save the images as a GIF
+        imageio.mimsave(f"{filename}.gif", images, duration=500, loop=1)
+
+        # Load the gif
+        gif_file = f"{filename}.gif"  # Replace with the path to your GIF file
+        clip = VideoFileClip(gif_file)
+
+        # Convert the gif to mp4
+        mp4_file = f"{filename}.mp4"  # Replace with the desired output path
+        clip.write_videofile(mp4_file, codec='libx264')
+
+        # Close the clip to release resources
+        clip.close()
+
+    def save(self, file_path):
+        """
+        Save the ClusterSOM model to a file.
+        """
+        model_data = (self.hdbscan_model, self.som_models, self.mean_values, self.sigma_values, self.cluster_mapping)
+        if hasattr(self, 'label_centroids'):
+            model_data += (self.label_centroids, self.label_encodings)
+
+        with open(file_path, "wb") as f:
+            pickle.dump(model_data, f)
+
+    def load(self, file_path):
+        """
+        Load a ClusterSOM model from a file.
+        """
+        with open(file_path, "rb") as f:
+            model_data = pickle.load(f)
+
+        self.hdbscan_model, self.som_models, self.mean_values, self.sigma_values, self.cluster_mapping = model_data[:5]
+        if len(model_data) > 5:
+            self.label_centroids, self.label_encodings = model_data[5:]
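
The "labelling stages" named in the commit message map onto train, train_label, and predict above. A minimal end-to-end sketch, assuming a 10-d embedding array and a parallel list of string labels (embedding10d and step_labels are illustrative names, not repository data):

    som = ClusterSOM()
    som.train(embedding10d)                     # stage 1: HDBSCAN clusters + one SOM per cluster
    som.train_label(embedding10d, step_labels)  # stage 2: centroids for human-provided labels
    som.save('models/cluster_som2.pkl')

    results = som.predict(embedding10d)         # (cluster, node[, label, distance]) per sample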
ml_inference.py ADDED
@@ -0,0 +1,30 @@
+import torch
+from phate import PHATEAE
+from funcs.som import ClusterSOM
+
+from funcs.dataloader import BaseDataset2, read_json_files
+
+
+DEVICE = torch.device("cpu")
+
+reducer10d = PHATEAE(epochs=30, n_components=10, lr=.0001, batch_size=128, t='auto', knn=8, relax=True, metric='euclidean')
+reducer10d.load('models/r10d_2.pth')
+
+cluster_som = ClusterSOM()
+cluster_som.load("models/cluster_som2.pkl")
+
+train_x, train_y = read_json_files('output.json')
+# Convert tensors to numpy arrays if necessary
+if isinstance(train_x, torch.Tensor):
+    train_x = train_x.numpy()
+if isinstance(train_y, torch.Tensor):
+    train_y = train_y.numpy()
+
+# load the time series slices of the data: 4*3*2*64 (feeds * axes * sensors * samples) + 5 for time diff
+data = BaseDataset2(train_x.reshape(len(train_x), -1) / 32768, train_y)
+
+# compute the 10-dimensional embedding vector
+embedding10d = reducer10d.transform(data)
+
+prediction = cluster_som.predict(embedding10d)
+cluster_som.plot_activation(embedding10d)
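
The division by 32768 presumably rescales raw signed 16-bit sensor counts into roughly [-1, 1) before embedding (an assumption about the IMU format, not stated in the repository). A one-line check of the arithmetic:

    import numpy as np
    raw = np.array([-32768, 0, 32767], dtype=np.int16)
    scaled = raw / 32768  # -> [-1.0, 0.0, ~0.99997]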
models/cluster_som2.pkl ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:5282b68cae29910b6b38c03e0e7e9ab528fb67ef689812d6b02012950303c2d6
+size 8367290
models/r10d_2.pth ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:5c8272bde6d372c90002f6d6afe39584255a99371bdbf18c54f5f574725d9902
+size 13100259
requirements.txt CHANGED
@@ -3,66 +3,121 @@ aiohttp==3.8.4
 aiosignal==1.3.1
 altair==4.2.2
 anyio==3.6.2
+appnope==0.1.3
+asttokens==2.2.1
 async-timeout==4.0.2
 attrs==23.1.0
+babyplots==1.7.0
+backcall==0.2.0
 certifi==2022.12.7
 charset-normalizer==3.1.0
 click==8.1.3
 contourpy==1.0.7
 cycler==0.11.0
+Cython==0.29.34
+decorator==4.4.2
+Deprecated==1.2.13
 entrypoints==0.4
+executing==1.2.0
 fastapi==0.95.1
 ffmpy==0.3.0
 filelock==3.12.0
 fonttools==4.39.3
 frozenlist==1.3.3
 fsspec==2023.4.0
-gradio==3.28.2
+future==0.18.3
+gradio==3.28.3
 gradio_client==0.1.4
+graphtools==1.5.3
 h11==0.14.0
+hdbscan==0.8.29
 httpcore==0.17.0
 httpx==0.24.0
 huggingface-hub==0.14.1
 idna==3.4
+imageio==2.28.1
+imageio-ffmpeg==0.4.8
+importlib-metadata==6.6.0
 importlib-resources==5.12.0
+ipython==8.12.2
+jedi==0.18.2
 Jinja2==3.1.2
+joblib==1.2.0
 jsonschema==4.17.3
 kiwisolver==1.4.4
+lazy_loader==0.2
 linkify-it-py==2.0.2
+llvmlite==0.40.0
 markdown-it-py==2.2.0
 MarkupSafe==2.1.2
 matplotlib==3.7.1
+matplotlib-inline==0.1.6
 mdit-py-plugins==0.3.3
 mdurl==0.1.2
+MiniSom==2.3.1
+moviepy==1.0.3
+mpmath==1.3.0
 multidict==6.0.4
+networkx==3.1
+numba==0.57.0
+numexpr==2.8.4
 numpy==1.24.3
 orjson==3.8.11
 packaging==23.1
 pandas==2.0.1
+parso==0.8.3
+pexpect==4.8.0
+phate @ git+https://github.com/metric-space-ai/phate.git@5fcb5bc29f6634391b0ad3831544b09a23123122
+pickleshare==0.7.5
 Pillow==9.5.0
 pkgutil_resolve_name==1.3.10
+plotly==5.14.1
+proglog==0.1.10
+prompt-toolkit==3.0.38
+ptyprocess==0.7.0
+pure-eval==0.2.2
 pydantic==1.10.7
+pydiffmap==0.2.0.1
 pydub==0.25.1
 Pygments==2.15.1
+PyGSP==0.5.1
+pynndescent==0.5.10
 pyparsing==3.0.9
 pyrsistent==0.19.3
 python-dateutil==2.8.2
 python-multipart==0.0.6
 pytz==2023.3
+PyWavelets==1.4.1
 PyYAML==6.0
-requests==2.29.0
-scipy==1.10.1
+requests==2.30.0
+scikit-image==0.20.0
+scikit-learn==1.2.2
+scipy==1.9.1
+seaborn==0.12.2
 semantic-version==2.10.0
 six==1.16.0
 sniffio==1.3.0
+stack-data==0.6.2
 starlette==0.26.1
+sympy==1.12
+tasklogger==1.2.0
+tenacity==8.2.2
+threadpoolctl==3.1.0
+tifffile==2023.4.12
 toolz==0.12.0
+torch==2.0.1
+torchaudio==2.0.2
+torchvision==0.15.2
 tqdm==4.65.0
+traitlets==5.9.0
 typing_extensions==4.5.0
 tzdata==2023.3
 uc-micro-py==1.0.2
-urllib3==1.26.15
+umap-learn==0.5.3
+urllib3==2.0.2
 uvicorn==0.22.0
+wcwidth==0.2.6
 websockets==11.0.2
+wrapt==1.15.0
 yarl==1.9.2
 zipp==3.15.0
test_plot.py CHANGED
@@ -0,0 +1,40 @@
+import matplotlib.pyplot as plt
+import json
+import pandas as pd
+import numpy as np
+
+plt.style.use('ggplot')
+
+def plot_overlay_data_from_json(json_file, sensors, use_precise_timestamp=False, slice_select=1):
+    # Read the JSON file
+    with open(json_file, "r") as f:
+        slices = json.load(f)
+
+    # Set up the colormap
+    cmap = plt.get_cmap('viridis')
+
+    # Create subplots for each sensor
+    fig, axs = plt.subplots(len(sensors), 1, figsize=(12, 2 * len(sensors)), sharex=True)
+
+    for idx, sensor in enumerate(sensors):
+        # Plot the overlay of the slices
+        for slice_idx, slice_dict in enumerate(slices):
+            slice_length = len(slice_dict[sensor])
+
+            # Create timestamp array starting from 0 for each slice
+            slice_timestamps = [20 * i for i in range(slice_length)]
+            sensor_data = slice_dict[sensor]
+
+            data = pd.DataFrame({sensor: sensor_data}, index=slice_timestamps)
+            color = cmap(slice_idx / len(slices))
+
+            axs[idx].plot(data[sensor], color=color, label=f'Slice {slice_idx + 1}')
+
+        axs[idx].set_ylabel(sensor)
+
+    axs[-1].set_xlabel("Timestamp")
+    axs[0].legend()
+
+    return fig
+
+plot_overlay_data_from_json('output.json', ["GZ1", "GZ2", "GZ3", "GZ4"], slice_select=4)