# Hugging Face Spaces status: Runtime error
from functools import partial
from typing import Dict, Optional

import gradio as gr
import numpy as np
import plotly.graph_objects as go
from huggingface_hub import from_pretrained_keras
# Remote location of the FordA dataset splits (tab-separated files).
ROOT_DATA_URL = "https://raw.githubusercontent.com/hfawaz/cd-diagram/master/FordA"
TRAIN_DATA_URL = f"{ROOT_DATA_URL}/FordA_TRAIN.tsv"
TEST_DATA_URL = f"{ROOT_DATA_URL}/FordA_TEST.tsv"
# Number of samples per series; rows are reshaped to (1, TIMESERIES_LEN, 1)
# before being fed to the model.
TIMESERIES_LEN = 500
# Class names, zipped positionally with the model's output scores. This must be
# an ordered sequence: a set has no defined iteration order, so labels could be
# paired with the wrong probabilities. The order is assumed to match the
# model's output units (index 0 first) -- confirm against the training labels.
CLASSES = ("Symptom does NOT exist", "Symptom exists")
# Pull the pretrained Keras classifier from the Hugging Face Hub when this
# module runs; every prediction below reuses this single instance.
model = from_pretrained_keras(
    "keras-io/timeseries-classification-from-scratch"
)
# Read data
def read_data(file_url: str):
    """Load a FordA split from a tab-separated file.

    The first column of every row is the label and the remaining columns are
    the time-series samples.

    Args:
        file_url: Path or URL passed straight to ``np.loadtxt``.

    Returns:
        A tuple ``(x, y)``. ``x`` is a 2-D float array with one series per row.
        ``y`` is the matching 1-D integer label array.
    """
    # ndmin=2 keeps a single-row file two-dimensional so the column slicing
    # below works for any number of rows.
    data = np.loadtxt(file_url, delimiter="\t", ndmin=2)
    y = data[:, 0]
    x = data[:, 1:]
    return x, y.astype(int)
# Fetch both splits once, training first, then test.
(x_train, y_train), (x_test, y_test) = (
    read_data(file_url=TRAIN_DATA_URL),
    read_data(file_url=TEST_DATA_URL),
)
# Helper functions
def get_prediction(row_index: int, data: np.ndarray) -> Dict[str, float]:
    """Classify a single row of ``data`` with the pretrained model.

    Args:
        row_index: Row of ``data`` to classify. It may arrive as a float
            (e.g. from a UI slider), so it is coerced to ``int`` first.
        data: Array of series, one per row. Each row must hold
            ``TIMESERIES_LEN`` values to fit the reshape below.

    Returns:
        Mapping from class name to the model's score, paired positionally
        with ``CLASSES``.
    """
    # NumPy rejects float indices, so normalise whatever the caller sends.
    series = data[int(row_index)].reshape((1, TIMESERIES_LEN, 1))
    # Batch of one univariate series: (batch, steps, channels).
    scores = model.predict(series).flatten()
    return {label: float(score) for label, score in zip(CLASSES, scores)}
def create_plot(
    row_index: int, dataset_name: str, data: Optional[np.ndarray] = None
) -> go.Figure:
    """Build a lines+markers Plotly figure of one time series.

    Args:
        row_index: Row to plot. It may arrive as a float, so it is coerced to
            ``int``.
        dataset_name: Split name shown in the figure title.
        data: Array of series to plot from, one per row. Defaults to
            ``x_train`` for backward compatibility. Pass the same array used
            for prediction so the plot and the label refer to the same row.

    Returns:
        The Plotly figure for the selected row.
    """
    if data is None:
        data = x_train
    row_index = int(row_index)
    row = data[row_index].flatten()
    scatter = go.Scatter(
        x=list(range(row.size)),
        y=row,
        mode="lines+markers",
    )
    fig = go.Figure(data=scatter)
    fig.update_layout(title=f"Timeseries in row {row_index} of {dataset_name} set ")
    return fig


def show_tab_section(data: np.ndarray, dataset_name: str):
    """Create the row selector, Predict button, plot and label for one split.

    This function creates Gradio components and wires click handlers. It is
    called inside a ``gr.Blocks`` / ``gr.TabItem`` context in this file.

    Args:
        data: Array of series for this split, one per row.
        dataset_name: Split name shown in the plot title.
    """
    num_indexes = data.shape[0]
    # Explicit bounds and step=1 make every row selectable. Without a step,
    # Gradio may derive a coarser one from the slider range.
    index = gr.Slider(
        minimum=0,
        maximum=num_indexes - 1,
        step=1,
        label="Select the index of the row you want to classify:",
    )
    button = gr.Button("Predict")
    plot = gr.Plot()
    # Plot from this tab's data so the figure shows the row being classified.
    create_plot_data = partial(create_plot, dataset_name=dataset_name, data=data)
    button.click(create_plot_data, inputs=[index], outputs=[plot])
    get_prediction_data = partial(get_prediction, data=data)
    label = gr.Label()
    button.click(get_prediction_data, inputs=[index], outputs=[label])
# Gradio Demo
# Markdown/HTML content rendered at the top and bottom of the page.
title = "# Timeseries classification from scratch"
description = """
Select a time series in the Training or Test dataset and ask the model to classify it!
<br />
<br />
The model was trained on the <a href="http://www.j-wichard.de/publications/FordPaper.pdf" target="_blank">FordA dataset</a>. Each row is a diagnostic session run on an automotive subsystem. In each session 500 samples were collected. Given a time series, the model was trained to identify if a specific symptom exists or it does not exist.
<br />
<br />
<p>
<b>Model:</b> <a href="https://huggingface.co/keras-io/timeseries-classification-from-scratch" target="_blank">https://huggingface.co/keras-io/timeseries-classification-from-scratch</a>
<br />
<b>Keras Example:</b> <a href="https://keras.io/examples/timeseries/timeseries_classification_from_scratch/" target="_blank">https://keras.io/examples/timeseries/timeseries_classification_from_scratch/</a>
</p>
<br />
"""
article = """
<div style="text-align: center;">
Space by <a href="https://github.com/EdAbati" target="_blank">Edoardo Abati</a>
<br />
Keras example by <a href="https://github.com/hfawaz/" target="_blank">hfawaz</a>
</div>
"""
# Page layout: header text, one tab per dataset split, then the footer.
demo = gr.Blocks()
with demo:
    gr.Markdown(title)
    gr.Markdown(description)
    with gr.Tabs():
        with gr.TabItem("Training set"):
            show_tab_section(data=x_train, dataset_name="Training")
        with gr.TabItem("Test set"):
            show_tab_section(data=x_test, dataset_name="Test")
    gr.Markdown(article)
# Enable request queuing through Blocks.queue(). The `enable_queue` argument to
# launch() is deprecated in later Gradio 3 releases and not accepted by Gradio 4.
demo.queue().launch()