Spaces:
Build error
Build error
import torch | |
import gradio as gr | |
import json | |
import os | |
from phate import PHATEAE | |
from funcs.som import ClusterSOM | |
from funcs.tools import numpy_to_native | |
from funcs.processor import process_data | |
from funcs.plot_func import plot_sensor_data_from_json | |
from funcs.dataloader import BaseDataset2, read_json_files | |
# CPU torch device. NOTE(review): not referenced anywhere else in this file
# as far as visible — confirm whether it is still needed.
DEVICE = torch.device("cpu")

# Constructor arguments for the 10-D PHATE autoencoder that get_som_mp4 uses
# (as its default ``reducer``) to embed slices. Presumably these must match
# the settings the checkpoint below was trained with — confirm before editing.
_REDUCER10D_PARAMS = {
    "epochs": 30,
    "n_components": 10,
    "lr": .0001,
    "batch_size": 128,
    "t": 'auto',
    "knn": 8,
    "relax": True,
    "metric": 'euclidean',
}
reducer10d = PHATEAE(**_REDUCER10D_PARAMS)
reducer10d.load('models/r10d.pth')

# Pre-trained self-organising map; get_som_mp4 uses it (as its default
# ``cluster``) to plot activations.
cluster_som = ClusterSOM()
cluster_som.load("models/cluster_som2.pkl")
# ml inference
def get_som_mp4(file, slice_select, reducer=reducer10d, cluster=cluster_som):
    """Embed the slices in an uploaded JSON file and plot their SOM activations.

    Despite the name, this returns whatever ``cluster.plot_activation_v2``
    produces (shown in a ``gr.Plot``), not an MP4 file.

    Args:
        file: The slice JSON file. Either a path (``str`` / ``os.PathLike``)
            or a Gradio tempfile-like object whose ``.name`` holds the path.
        slice_select: Slice index from the UI slider, forwarded unchanged to
            ``cluster.plot_activation_v2``.
        reducer: Fitted reducer exposing ``transform``; defaults to the
            module-level 10-D PHATE autoencoder.
        cluster: Fitted SOM exposing ``plot_activation_v2``; defaults to the
            module-level ``cluster_som``.

    Returns:
        The figure returned by ``cluster.plot_activation_v2``.

    Raises:
        gr.Error: If no file has been provided yet.
    """
    if file is None:
        raise gr.Error("Please upload or generate a JSON file first.")

    # Resolve the on-disk path once. The previous try/bare-except retry also
    # swallowed genuine read/parse errors (even KeyboardInterrupt) and then
    # surfaced a misleading AttributeError from ``file.name``.
    path = file if isinstance(file, (str, os.PathLike)) else file.name
    train_x, train_y = read_json_files(path)

    # Convert tensors to NumPy; detach/cpu make this safe for tensors that
    # track gradients or live on another device.
    if isinstance(train_x, torch.Tensor):
        train_x = train_x.detach().cpu().numpy()
    if isinstance(train_y, torch.Tensor):
        train_y = train_y.detach().cpu().numpy()

    # Flatten each sample into one feature row. NOTE(review): per the original
    # comment a sample is 4*3*2*64 values (feeds+axis*sensor*samples) + 5 for
    # time diff — confirm against read_json_files. Dividing by 32768 presumably
    # rescales signed 16-bit sensor counts to [-1, 1) — confirm.
    data = BaseDataset2(train_x.reshape(len(train_x), -1) / 32768, train_y)

    # Compute the 10-dimensional embedding vectors.
    embedding10d = reducer.transform(data)
    return cluster.plot_activation_v2(embedding10d, slice_select)
def attach_label_to_json(json_file, label_text):
    """Write a copy of a slice JSON file with its ``"label"`` key set.

    The copy is written to the current working directory as
    ``manual_labelled_<basename of the source file>``.

    Args:
        json_file: The slice JSON file. Either a path (``str`` /
            ``os.PathLike``) or a Gradio tempfile-like object whose ``.name``
            holds the path.
        label_text: Value stored under the top-level ``"label"`` key.

    Returns:
        The filename of the labelled copy, relative to the working directory.

    Raises:
        gr.Error: If no slice JSON file has been provided yet.
    """
    if json_file is None:
        raise gr.Error("Please generate a slice JSON file before attaching a label.")

    # Resolve the source path once and reuse it. Previously a plain path
    # string was read successfully but then crashed on ``json_file.name``
    # when the output filename was built.
    src_path = json_file if isinstance(json_file, (str, os.PathLike)) else json_file.name

    with open(src_path, "r", encoding="utf-8") as f:
        slices = json.load(f)

    slices['label'] = label_text

    out_path = f'manual_labelled_{os.path.basename(src_path)}'
    with open(out_path, "w", encoding="utf-8") as f:
        json.dump(numpy_to_native(slices), f, indent=2)
    return out_path
with gr.Blocks(title='Cabasus') as cabasus_sensor:
    # NOTE(review): this block was recovered from a copy with all indentation
    # stripped; the Row/Column nesting below is a best-effort reconstruction.
    # Confirm the visual layout against the deployed app.
    title = gr.Markdown("<h2><center>Data gathering and processing</center></h2>")
    with gr.Tab("Convert"):
        with gr.Row():
            csv_file_box = gr.File(label='Upload CSV File')
            with gr.Column():
                processed_file_box = gr.File(label='Processed CSV File')
                json_file_box = gr.File(label='Generated Json file')
        plot_box_leg = gr.Plot(label="Filtered Signal Plot")
        slice_slider = gr.Slider(minimum=1, maximum=300, label='Slice select', step=1)
        som_create = gr.Button('generate som')
        som_figures = gr.Plot(label="som activations")

        # Processing parameters. They are hidden and no event writes to them,
        # so process_data always receives these default values.
        with gr.Row():
            slice_size_slider = gr.Slider(minimum=16, maximum=512, step=1, value=64, label="Slice Size", visible=False)
            sample_rate = gr.Slider(minimum=1, maximum=199, step=1, value=20, label="Sample rate", visible=False)
        with gr.Row():
            window_size_slider = gr.Slider(minimum=0, maximum=100, step=2, value=10, label="Window Size", visible=False)
            repeat_process = gr.Button('Restart process', visible=False)

        with gr.Row():
            leg_dropdown = gr.Dropdown(choices=['GZ1', 'GZ2', 'GZ3', 'GZ4'], label='select leg', value='GZ1')
        with gr.Row():
            get_all_slice = gr.Plot(label="Real Signal Plot")
            plot_box_overlay = gr.Plot(label="Overlay Signal Plot")
        with gr.Row():
            plot_slice_leg = gr.Plot(label="Sliced Signal Plot", visible=False)

        # Manual labelling of the currently selected slice.
        with gr.Row():
            slice_json_box = gr.File(label='Slice json file')
            with gr.Column():
                label_name = gr.Textbox(label="enter the label name")
                button_label_Add = gr.Button('attach label')
            slice_json_label_box = gr.File(label='Slice json labelled file')

        with gr.Row():
            animation = gr.Video(label='animation')
            real_video = gr.Video(label='real')
        slices_per_leg = gr.Textbox(label="Debug information")

    # process_data runs on a new CSV upload and on the (hidden) restart
    # button; plot_sensor_data_from_json runs when the leg or the slice
    # changes. Each pair of triggers shares one input/output spec so the two
    # wirings cannot drift apart.
    process_inputs = [csv_file_box, slice_size_slider, sample_rate, window_size_slider]
    process_outputs = [processed_file_box, json_file_box, slices_per_leg, plot_box_leg,
                       plot_box_overlay, slice_slider, plot_slice_leg, get_all_slice,
                       slice_json_box]
    plot_inputs = [json_file_box, leg_dropdown, slice_slider]
    plot_outputs = [plot_box_leg, plot_slice_leg, get_all_slice, slice_json_box, plot_box_overlay]

    # Registration order is kept as before so Gradio's fn_index numbering
    # (visible to API clients) is unchanged.
    csv_file_box.change(process_data, inputs=process_inputs, outputs=process_outputs)
    leg_dropdown.change(plot_sensor_data_from_json, inputs=plot_inputs, outputs=plot_outputs)
    repeat_process.click(process_data, inputs=process_inputs, outputs=process_outputs)
    slice_slider.change(plot_sensor_data_from_json, inputs=plot_inputs, outputs=plot_outputs)
    som_create.click(get_som_mp4, inputs=[json_file_box, slice_slider], outputs=[som_figures])
    button_label_Add.click(attach_label_to_json, inputs=[slice_json_box, label_name],
                           outputs=[slice_json_label_box])
# Enable the request queue with a concurrency of 2, then start the server.
# Gradio 4 replaced ``queue(concurrency_count=...)`` with
# ``default_concurrency_limit`` and raises if the old argument is passed.
# Gradio 3.x rejects the new keyword with a TypeError. Try the new API first
# and fall back, so the app starts on either major version.
try:
    cabasus_sensor.queue(default_concurrency_limit=2)
except TypeError:
    cabasus_sensor.queue(concurrency_count=2)
cabasus_sensor.launch(debug=True)