Spaces:
Runtime error
Runtime error
added app
Browse files
app.py
ADDED
@@ -0,0 +1,96 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from functools import partial
|
2 |
+
from typing import Dict
|
3 |
+
|
4 |
+
import gradio as gr
|
5 |
+
import numpy as np
|
6 |
+
import plotly.graph_objects as go
|
7 |
+
from huggingface_hub import from_pretrained_keras
|
8 |
+
|
9 |
+
ROOT_DATA_URL = "https://raw.githubusercontent.com/hfawaz/cd-diagram/master/FordA"
|
10 |
+
TRAIN_DATA_URL = f"{ROOT_DATA_URL}/FordA_TRAIN.tsv"
|
11 |
+
TEST_DATA_URL = f"{ROOT_DATA_URL}/FordA_TEST.tsv"
|
12 |
+
TIMESERIES_LEN = 500
|
13 |
+
CLASSES = {"Symptom does NOT exist", "Symptom exists"}
|
14 |
+
|
15 |
+
model = from_pretrained_keras("EdoAbati/timeseries-classification-from-scratch")
|
16 |
+
|
17 |
+
# Read data
|
18 |
+
def read_data(file_url: str):
|
19 |
+
data = np.loadtxt(file_url, delimiter="\t")
|
20 |
+
y = data[:, 0]
|
21 |
+
x = data[:, 1:]
|
22 |
+
return x, y.astype(int)
|
23 |
+
|
24 |
+
|
25 |
+
x_train, y_train = read_data(file_url=TRAIN_DATA_URL)
|
26 |
+
x_test, y_test = read_data(file_url=TEST_DATA_URL)
|
27 |
+
|
28 |
+
# Helper functions
|
29 |
+
def get_prediction(row_index: int, data: np.ndarray) -> Dict[str, float]:
|
30 |
+
x = data[row_index].reshape((1, TIMESERIES_LEN, 1))
|
31 |
+
predictions = model.predict(x).flatten()
|
32 |
+
return {k: float(v) for k, v in zip(CLASSES, predictions)}
|
33 |
+
|
34 |
+
|
35 |
+
def create_plot(row_index: int, dataset_name: str) -> go.Figure:
|
36 |
+
x = x_train
|
37 |
+
row = x[row_index]
|
38 |
+
scatter = go.Scatter(
|
39 |
+
x=list(range(TIMESERIES_LEN)),
|
40 |
+
y=row.flatten(),
|
41 |
+
mode="lines+markers",
|
42 |
+
)
|
43 |
+
fig = go.Figure(data=scatter)
|
44 |
+
fig.update_layout(title=f"Timeseries in row {row_index} of {dataset_name} set ")
|
45 |
+
return fig
|
46 |
+
|
47 |
+
|
48 |
+
def show_tab_section(data: np.ndarray, dataset_name: str):
|
49 |
+
num_indexes = data.shape[0]
|
50 |
+
index = gr.Slider(maximum=num_indexes - 1)
|
51 |
+
button = gr.Button("Predict")
|
52 |
+
plot = gr.Plot()
|
53 |
+
create_plot_data = partial(create_plot, dataset_name=dataset_name)
|
54 |
+
button.click(create_plot_data, inputs=[index], outputs=[plot])
|
55 |
+
get_prediction_data = partial(get_prediction, data=data)
|
56 |
+
label = gr.Label()
|
57 |
+
button.click(get_prediction_data, inputs=[index], outputs=[label])
|
58 |
+
|
59 |
+
|
60 |
+
# Gradio Demo
|
61 |
+
title = "# Timeseries classification from scratch"
|
62 |
+
description = """
|
63 |
+
Select a time series in the Training or Test dataset and ask the model to classify it!
|
64 |
+
<br />
|
65 |
+
<br />
|
66 |
+
The model was trained on the <a href="http://www.j-wichard.de/publications/FordPaper.pdf" target="_blank">FordA dataset</a>. Each row is a diagnostic session run on an automotive subsystem. In each session 500 samples were collected. Given a time series, the model was trained to identify if a specific symptom exists or it does not exist.
|
67 |
+
<br />
|
68 |
+
<br />
|
69 |
+
<p>
|
70 |
+
<b>Model:</b> <a href="https://huggingface.co/keras-io/timeseries-classification-from-scratch" target="_blank">https://huggingface.co/keras-io/timeseries-classification-from-scratch</a>
|
71 |
+
<br />
|
72 |
+
<b>Keras Example:</b> <a href="https://keras.io/examples/timeseries/timeseries_classification_from_scratch/" target="_blank">https://keras.io/examples/timeseries/timeseries_classification_from_scratch/</a>
|
73 |
+
</p>
|
74 |
+
<br />
|
75 |
+
"""
|
76 |
+
article = """
|
77 |
+
<div style="text-align: center;">
|
78 |
+
Space by <a href="https://github.com/EdAbati" target="_blank">Edoardo Abati</a>
|
79 |
+
<br />
|
80 |
+
Keras example by <a href="https://github.com/hfawaz/" target="_blank">hfawaz</a>
|
81 |
+
</div>
|
82 |
+
"""
|
83 |
+
|
84 |
+
demo = gr.Blocks()
|
85 |
+
|
86 |
+
with demo:
|
87 |
+
gr.Markdown(title)
|
88 |
+
gr.HTML(description)
|
89 |
+
with gr.Tabs():
|
90 |
+
with gr.TabItem("Training set"):
|
91 |
+
show_tab_section(data=x_train, dataset_name="Training")
|
92 |
+
with gr.TabItem("Test set"):
|
93 |
+
show_tab_section(data=x_test, dataset_name="Test")
|
94 |
+
gr.HTML(article)
|
95 |
+
|
96 |
+
demo.launch(enable_queue=True)
|