# cabasus / test.py
# Author: arcan3 · commit 41ed540 ("added finalise") · 1.56 kB
import torch
import gradio as gr
import json
import os
import matplotlib.pyplot as plt
from phate import PHATEAE
from funcs.som import ClusterSOM
from funcs.tools import numpy_to_native
from funcs.processor import process_data
from funcs.plot_func import plot_sensor_data_from_json
from funcs.dataloader import BaseDataset2, read_json_files
# CPU device handle (not referenced by the code in this script as shown).
DEVICE = torch.device("cpu")

# Pretrained 10-D PHATE autoencoder, loaded once at import time and used as
# the default `reducer` of get_som_mp4. epochs/lr/batch_size presumably only
# matter for training, not for the transform used here — TODO confirm.
reducer10d = PHATEAE(
    epochs=30,
    n_components=10,
    lr=1e-4,
    batch_size=128,
    t="auto",
    knn=8,
    relax=True,
    metric="euclidean",
)
reducer10d.load("models/r10d_2.pth")

# Pretrained SOM clusterer, used as the default `cluster` of get_som_mp4.
# NOTE(review): the .pkl extension suggests pickle-based loading — only load
# trusted model files.
cluster_som = ClusterSOM()
cluster_som.load("models/cluster_som2.pkl")
# ml inference
def get_som_mp4(file, slice_select, reducer=reducer10d, cluster=cluster_som, *, save_path='test.png'):
    """Embed a JSON sensor recording with ``reducer`` and plot SOM activations.

    The recording is loaded with ``read_json_files``, flattened to one row per
    sample and scaled by 1/32768, embedded by ``reducer.transform`` and passed
    to ``cluster.plot_activation_v2`` together with ``slice_select``.

    Args:
        file: A path accepted by ``read_json_files``, or an object exposing
            that path via a ``.name`` attribute (presumably a Gradio upload
            wrapper — confirm against the UI wiring).
        slice_select: Forwarded unchanged to ``cluster.plot_activation_v2``.
        reducer: Fitted reducer; defaults to the module-level ``reducer10d``.
        cluster: Fitted SOM; defaults to the module-level ``cluster_som``.
        save_path: Where to write a snapshot of the plot. ``None`` skips
            writing. Defaults to ``'test.png'``, the previously hard-coded path.

    Returns:
        Whatever ``cluster.plot_activation_v2`` returns.

    Raises:
        Re-raises the error from ``read_json_files(file)`` when ``file`` has
        no ``.name`` attribute to fall back on.
    """
    try:
        train_x, train_y = read_json_files(file)
    except Exception:
        # Retry with `.name` only for objects that have one; otherwise surface
        # the real error instead of a misleading AttributeError. Unlike a bare
        # `except:`, this never swallows KeyboardInterrupt/SystemExit.
        if not hasattr(file, 'name'):
            raise
        train_x, train_y = read_json_files(file.name)

    # Convert tensors to numpy arrays if necessary; detach()/cpu() keep this
    # working for tensors that track gradients or live on another device.
    if isinstance(train_x, torch.Tensor):
        train_x = train_x.detach().cpu().numpy()
    if isinstance(train_y, torch.Tensor):
        train_y = train_y.detach().cpu().numpy()

    # Flatten each sample to a single row. Per the original author, a sample
    # holds 4*3*2*64 values (feeds+axis * sensor * samples) + 5 for time diff.
    # Dividing by 32768 (2**15) presumably scales int16 readings to roughly
    # [-1, 1) — TODO confirm.
    data = BaseDataset2(train_x.reshape(len(train_x), -1) / 32768, train_y)

    # Compute the 10-dimensional embedding vectors.
    embedding10d = reducer.transform(data)

    fig = cluster.plot_activation_v2(embedding10d, slice_select)

    if save_path is not None:
        # Save the returned figure itself when it supports savefig; otherwise
        # fall back to pyplot's current figure (the original behavior).
        saver = fig if hasattr(fig, 'savefig') else plt
        saver.savefig(save_path)
    return fig
if __name__ == '__main__':
    # Smoke test: run inference on one bundled recording (slice 1) and write
    # the plot snapshot, only when executed as a script — importing this
    # module no longer triggers model inference and a file write.
    get_som_mp4('Data-JSON/Dressage/Tempi/Trab/Arbeitstrab/20210906-093200-Don-Arbeitstrab.json', 1)