cabasus / funcs /ml_inference.py
arcan3
added
5124a31
raw
history blame contribute delete
878 Bytes
import torch
from funcs.dataloader import BaseDataset2, read_json_files
def get_som_mp4(file, reducer10d, cluster_som, slice_select):
    """Build the SOM activation plot for one recording.

    Reads the recording with ``read_json_files``, flattens and scales each
    sample, wraps the result in ``BaseDataset2``, embeds it with
    ``reducer10d`` and asks ``cluster_som`` to plot the activation for the
    requested slice.

    Args:
        file: Either a value ``read_json_files`` accepts directly, or an
            object carrying the real path in a ``.name`` attribute (e.g. an
            uploaded-file wrapper).
        reducer10d: Fitted reducer whose ``transform`` is applied to the
            dataset; its output is passed to the SOM (named for a 10-D
            embedding — not checked here).
        cluster_som: Model providing ``plot_activation_v2``.
        slice_select: Forwarded unchanged to ``plot_activation_v2``.

    Returns:
        The value returned by ``cluster_som.plot_activation_v2``.

    Raises:
        Exception: Whatever ``read_json_files`` raised, when ``file`` could
            not be read directly and has no ``.name`` to retry with.
    """
    try:
        train_x, train_y = read_json_files(file)
    except Exception:
        # Retry with the path carried by a file-like wrapper; if there is
        # none, surface the original read error instead of masking it with
        # an AttributeError.
        if not hasattr(file, "name"):
            raise
        train_x, train_y = read_json_files(file.name)

    # Detach and move to CPU first: Tensor.numpy() refuses tensors that
    # require grad or live on a non-CPU device.
    if isinstance(train_x, torch.Tensor):
        train_x = train_x.detach().cpu().numpy()
    if isinstance(train_y, torch.Tensor):
        train_y = train_y.detach().cpu().numpy()

    # One flat feature vector per sample, scaled by 32768 (presumably the
    # int16 full-scale range — TODO confirm against the sensor format).
    # Original author's layout note: 4*3*2*64 (feeds+axis*sensor*samples)
    # + 5 for time diff.
    features = train_x.reshape(len(train_x), -1) / 32768
    data = BaseDataset2(features, train_y)

    embedding10d = reducer10d.transform(data)
    return cluster_som.plot_activation_v2(embedding10d, slice_select)