EdoAbati commited on
Commit
8b9ef60
·
1 Parent(s): 2e5fce4
Files changed (1) hide show
  1. app.py +96 -0
app.py ADDED
@@ -0,0 +1,96 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from functools import partial
2
+ from typing import Dict
3
+
4
+ import gradio as gr
5
+ import numpy as np
6
+ import plotly.graph_objects as go
7
+ from huggingface_hub import from_pretrained_keras
8
+
9
+ ROOT_DATA_URL = "https://raw.githubusercontent.com/hfawaz/cd-diagram/master/FordA"
10
+ TRAIN_DATA_URL = f"{ROOT_DATA_URL}/FordA_TRAIN.tsv"
11
+ TEST_DATA_URL = f"{ROOT_DATA_URL}/FordA_TEST.tsv"
12
+ TIMESERIES_LEN = 500
13
+ CLASSES = {"Symptom does NOT exist", "Symptom exists"}
14
+
15
+ model = from_pretrained_keras("EdoAbati/timeseries-classification-from-scratch")
16
+
17
+ # Read data
18
+ def read_data(file_url: str):
19
+ data = np.loadtxt(file_url, delimiter="\t")
20
+ y = data[:, 0]
21
+ x = data[:, 1:]
22
+ return x, y.astype(int)
23
+
24
+
25
+ x_train, y_train = read_data(file_url=TRAIN_DATA_URL)
26
+ x_test, y_test = read_data(file_url=TEST_DATA_URL)
27
+
28
+ # Helper functions
29
+ def get_prediction(row_index: int, data: np.ndarray) -> Dict[str, float]:
30
+ x = data[row_index].reshape((1, TIMESERIES_LEN, 1))
31
+ predictions = model.predict(x).flatten()
32
+ return {k: float(v) for k, v in zip(CLASSES, predictions)}
33
+
34
+
35
+ def create_plot(row_index: int, dataset_name: str) -> go.Figure:
36
+ x = x_train
37
+ row = x[row_index]
38
+ scatter = go.Scatter(
39
+ x=list(range(TIMESERIES_LEN)),
40
+ y=row.flatten(),
41
+ mode="lines+markers",
42
+ )
43
+ fig = go.Figure(data=scatter)
44
+ fig.update_layout(title=f"Timeseries in row {row_index} of {dataset_name} set ")
45
+ return fig
46
+
47
+
48
+ def show_tab_section(data: np.ndarray, dataset_name: str):
49
+ num_indexes = data.shape[0]
50
+ index = gr.Slider(maximum=num_indexes - 1)
51
+ button = gr.Button("Predict")
52
+ plot = gr.Plot()
53
+ create_plot_data = partial(create_plot, dataset_name=dataset_name)
54
+ button.click(create_plot_data, inputs=[index], outputs=[plot])
55
+ get_prediction_data = partial(get_prediction, data=data)
56
+ label = gr.Label()
57
+ button.click(get_prediction_data, inputs=[index], outputs=[label])
58
+
59
+
60
+ # Gradio Demo
61
+ title = "# Timeseries classification from scratch"
62
+ description = """
63
+ Select a time series in the Training or Test dataset and ask the model to classify it!
64
+ <br />
65
+ <br />
66
+ The model was trained on the <a href="http://www.j-wichard.de/publications/FordPaper.pdf" target="_blank">FordA dataset</a>. Each row is a diagnostic session run on an automotive subsystem. In each session 500 samples were collected. Given a time series, the model was trained to identify if a specific symptom exists or it does not exist.
67
+ <br />
68
+ <br />
69
+ <p>
70
+ <b>Model:</b> <a href="https://huggingface.co/keras-io/timeseries-classification-from-scratch" target="_blank">https://huggingface.co/keras-io/timeseries-classification-from-scratch</a>
71
+ <br />
72
+ <b>Keras Example:</b> <a href="https://keras.io/examples/timeseries/timeseries_classification_from_scratch/" target="_blank">https://keras.io/examples/timeseries/timeseries_classification_from_scratch/</a>
73
+ </p>
74
+ <br />
75
+ """
76
+ article = """
77
+ <div style="text-align: center;">
78
+ Space by <a href="https://github.com/EdAbati" target="_blank">Edoardo Abati</a>
79
+ <br />
80
+ Keras example by <a href="https://github.com/hfawaz/" target="_blank">hfawaz</a>
81
+ </div>
82
+ """
83
+
84
+ demo = gr.Blocks()
85
+
86
+ with demo:
87
+ gr.Markdown(title)
88
+ gr.HTML(description)
89
+ with gr.Tabs():
90
+ with gr.TabItem("Training set"):
91
+ show_tab_section(data=x_train, dataset_name="Training")
92
+ with gr.TabItem("Test set"):
93
+ show_tab_section(data=x_test, dataset_name="Test")
94
+ gr.HTML(article)
95
+
96
+ demo.launch(enable_queue=True)