File size: 1,237 Bytes
46dbe9e
46a14b8
 
 
 
7cd99ed
 
 
46a14b8
46dbe9e
 
46a14b8
 
 
 
46dbe9e
46a14b8
 
 
c029c40
46a14b8
 
 
46dbe9e
46a14b8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
46dbe9e
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
import gradio as gr
import pandas as pd
from gluonts.dataset.pandas import PandasDataset
from gluonts.dataset.split import split
from gluonts.torch.model.deepar import DeepAREstimator
import matplotlib

matplotlib.use("Agg")
import matplotlib.pyplot as plt


def fn(upload_data):
    """Train a DeepAR model on an uploaded CSV and plot backtest forecasts.

    The CSV is read with its first column as a parsed datetime index; the
    first remaining column is used as the forecast target. The last
    ``prediction_length * windows`` observations are held out of training
    and forecast over ``windows`` test windows.

    Args:
        upload_data: The value Gradio passes from ``gr.UploadButton``. This
            may be a file wrapper exposing the path as ``.name`` or (depending
            on the Gradio version) the path string itself; both are accepted.

    Returns:
        A matplotlib ``Figure`` showing the observed target series in black
        and one coloured forecast per test window.
    """
    prediction_length = 12
    windows = 3

    # Accept both a file wrapper (with a .name path) and a plain path string.
    path = getattr(upload_data, "name", upload_data)
    df = pd.read_csv(path, index_col=0, parse_dates=True)
    target = df.columns[0]
    dataset = PandasDataset(df, target=target)
    # Hold out exactly the span covered by the test windows, so the split
    # offset can never drift out of sync with prediction_length / windows.
    training_data, test_gen = split(dataset, offset=-prediction_length * windows)

    model = DeepAREstimator(
        prediction_length=prediction_length,
        freq=dataset.freq,
        trainer_kwargs=dict(max_epochs=10),
    ).train(
        training_data=training_data,
    )

    test_data = test_gen.generate_instances(
        prediction_length=prediction_length, windows=windows
    )
    forecasts = list(model.predict(test_data.input))

    fig = plt.figure()
    # Plot the same column the model was trained on rather than a hard-coded
    # name ("#Passengers") that only exists in one sample dataset.
    df[target].plot(color="black")
    for forecast, color in zip(forecasts, ["green", "blue", "purple"]):
        forecast.plot(color=f"tab:{color}")
    plt.legend(["True values"], loc="upper left", fontsize="xx-large")
    return fig


# Minimal UI: an upload button feeding the forecasting pipeline, whose
# resulting matplotlib figure is rendered in the plot component.
with gr.Blocks() as demo:
    plot = gr.Plot()
    upload_btn = gr.UploadButton()

    # Re-run training + forecasting every time a file is uploaded.
    upload_btn.upload(fn=fn, inputs=[upload_btn], outputs=[plot])

# Only start the server when executed as a script, not on import.
if __name__ == "__main__":
    demo.launch()