# cabasus / app.py
# Last commit df766c6 by arcan3: "added slice part" (file size 5.35 kB)
import torch
import gradio as gr
import json
import os
from phate import PHATEAE
from funcs.som import ClusterSOM
from funcs.tools import numpy_to_native
from funcs.processor import process_data
from funcs.plot_func import plot_sensor_data_from_json
from funcs.dataloader import BaseDataset2, read_json_files
DEVICE = torch.device("cpu")

# Dimensionality reducer: a PHATEAE model configured for 10 output components.
# NOTE(review): these constructor arguments presumably need to match the ones
# used when models/r10d.pth was saved -- confirm before changing them.
reducer10d = PHATEAE(
    epochs=30,
    n_components=10,
    lr=1e-4,
    batch_size=128,
    t="auto",
    knn=8,
    relax=True,
    metric="euclidean",
)
reducer10d.load("models/r10d.pth")

# Self-organising map whose activations are plotted for the embeddings
# (it is the default `cluster` used by get_som_mp4).
cluster_som = ClusterSOM()
cluster_som.load("models/cluster_som2.pkl")
# ml inference
def get_som_mp4(file, slice_select, reducer=reducer10d, cluster=cluster_som):
    """Embed the sensor JSON data and plot SOM activations for one slice.

    Despite the name, this returns whatever ``cluster.plot_activation_v2``
    returns (a figure shown in a ``gr.Plot``), not an MP4 file.

    Args:
        file: The JSON input, either something ``read_json_files`` accepts
            directly, or an object exposing the path via ``.name`` (such as
            a value coming from a ``gr.File`` component).
        slice_select: Slice selector forwarded to
            ``cluster.plot_activation_v2``.
        reducer: Object with a ``transform`` method that produces the
            embedding; defaults to the module-level ``reducer10d``.
        cluster: Object with a ``plot_activation_v2`` method; defaults to the
            module-level ``cluster_som``.

    Returns:
        The result of ``cluster.plot_activation_v2(embedding, slice_select)``.

    Raises:
        ValueError: If ``file`` is None (nothing generated/uploaded yet).
        Exception: Whatever ``read_json_files`` raises when the input cannot
            be read (the original error is propagated, not masked).
    """
    if file is None:
        raise ValueError("No JSON file to process; upload a CSV file first.")

    try:
        train_x, train_y = read_json_files(file)
    except Exception:
        # Retry with the wrapped path for upload objects that carry `.name`.
        # If there is no `.name` (e.g. a plain str path that failed to load),
        # re-raise the original error instead of masking it with AttributeError.
        path = getattr(file, "name", None)
        if path is None:
            raise
        train_x, train_y = read_json_files(path)

    # Convert tensors to numpy arrays if necessary.
    if isinstance(train_x, torch.Tensor):
        train_x = train_x.numpy()
    if isinstance(train_y, torch.Tensor):
        train_y = train_y.numpy()

    # Flatten each sample to one row and scale by 32768 (presumably the int16
    # full-scale range -- confirm). Per the original author's note, each slice
    # holds 4*3*2*64 values (feeds+axis*sensor*samples) + 5 for time diff.
    data = BaseDataset2(train_x.reshape(len(train_x), -1) / 32768, train_y)

    # Compute the 10-dimensional embedding vectors.
    embedding10d = reducer.transform(data)

    return cluster.plot_activation_v2(embedding10d, slice_select)
def attach_label_to_json(json_file, label_text):
    """Write a copy of a slice JSON file with a top-level ``label`` key added.

    Args:
        json_file: Path to the JSON file (``str`` or ``os.PathLike``), or an
            object exposing the path via ``.name`` (such as a value coming
            from a ``gr.File`` component).
        label_text: Value stored under the top-level ``'label'`` key.

    Returns:
        The output filename ``manual_labelled_<basename>``. The file is
        written relative to the current working directory.

    Raises:
        ValueError: If ``json_file`` is None.
        OSError / json.JSONDecodeError: If the source cannot be read/parsed.
    """
    if json_file is None:
        raise ValueError("No slice JSON file to label; generate one first.")

    # Resolve the on-disk path once so reading and naming the output agree.
    # A plain path must not be asked for `.name` (str has none, and
    # pathlib.Path.name is only the basename).
    if isinstance(json_file, (str, os.PathLike)):
        src_path = os.fspath(json_file)
    else:
        src_path = json_file.name

    with open(src_path, "r", encoding="utf-8") as f:
        slices = json.load(f)

    slices['label'] = label_text

    out_name = f'manual_labelled_{os.path.basename(src_path)}'
    with open(out_name, "w", encoding="utf-8") as f:
        # numpy_to_native presumably converts numpy values so json can
        # serialise them -- confirm in funcs.tools.
        json.dump(numpy_to_native(slices), f, indent=2)
    return out_name
with gr.Blocks(title='Cabasus') as cabasus_sensor:
    # UI: upload a CSV, view filtered/sliced sensor plots, generate SOM
    # activation plots, and attach a manual label to a slice JSON file.
    # NOTE(review): confirm the Row/Column nesting below matches the intended layout.
    title = gr.Markdown("<h2><center>Data gathering and processing</center></h2>")

    with gr.Tab("Convert"):
        # Upload and processed-file outputs.
        with gr.Row():
            csv_file_box = gr.File(label='Upload CSV File')
            with gr.Column():
                processed_file_box = gr.File(label='Processed CSV File')
                json_file_box = gr.File(label='Generated Json file')

        plot_box_leg = gr.Plot(label="Filtered Signal Plot")
        slice_slider = gr.Slider(minimum=1, maximum=300, label='Slice select', step=1)
        som_create = gr.Button('generate som')
        som_figures = gr.Plot(label="som activations")

        # Processing parameters and restart button (all created hidden).
        with gr.Row():
            slice_size_slider = gr.Slider(minimum=16, maximum=512, step=1, value=64, label="Slice Size", visible=False)
            sample_rate = gr.Slider(minimum=1, maximum=199, step=1, value=20, label="Sample rate", visible=False)
        with gr.Row():
            window_size_slider = gr.Slider(minimum=0, maximum=100, step=2, value=10, label="Window Size", visible=False)
            repeat_process = gr.Button('Restart process', visible=False)

        with gr.Row():
            leg_dropdown = gr.Dropdown(choices=['GZ1', 'GZ2', 'GZ3', 'GZ4'], label='select leg', value='GZ1')

        with gr.Row():
            get_all_slice = gr.Plot(label="Real Signal Plot")
            plot_box_overlay = gr.Plot(label="Overlay Signal Plot")
        with gr.Row():
            plot_slice_leg = gr.Plot(label="Sliced Signal Plot", visible=False)

        # Manual labelling of the currently selected slice.
        with gr.Row():
            slice_json_box = gr.File(label='Slice json file')
            with gr.Column():
                label_name = gr.Textbox(label="enter the label name")
                button_label_Add = gr.Button('attach label')
            slice_json_label_box = gr.File(label='Slice json labelled file')

        with gr.Row():
            animation = gr.Video(label='animation')
            real_video = gr.Video(label='real')

        slices_per_leg = gr.Textbox(label="Debug information")

    # Shared wiring: two triggers (re)run processing, two redraw the leg plots.
    process_inputs = [csv_file_box, slice_size_slider, sample_rate, window_size_slider]
    process_outputs = [
        processed_file_box, json_file_box, slices_per_leg, plot_box_leg,
        plot_box_overlay, slice_slider, plot_slice_leg, get_all_slice,
        slice_json_box,
    ]
    plot_inputs = [json_file_box, leg_dropdown, slice_slider]
    plot_outputs = [plot_box_leg, plot_slice_leg, get_all_slice, slice_json_box, plot_box_overlay]

    # Handlers are registered in a fixed order; keep it stable, since API
    # clients may address handlers by position.
    csv_file_box.change(process_data, inputs=process_inputs, outputs=process_outputs)
    leg_dropdown.change(plot_sensor_data_from_json, inputs=plot_inputs, outputs=plot_outputs)
    repeat_process.click(process_data, inputs=process_inputs, outputs=process_outputs)
    slice_slider.change(plot_sensor_data_from_json, inputs=plot_inputs, outputs=plot_outputs)
    som_create.click(get_som_mp4, inputs=[json_file_box, slice_slider], outputs=[som_figures])
    button_label_Add.click(attach_label_to_json, inputs=[slice_json_box, label_name], outputs=[slice_json_label_box])
if __name__ == "__main__":
    # Start the server only when run as a script (`python app.py`), so that
    # importing this module (tests, tooling, reload runners) does not launch
    # a blocking Gradio server as a side effect.
    cabasus_sensor.queue(concurrency_count=2).launch(debug=True)