import streamlit as st
import pandas as pd
import numpy as np
import torch
from chronos import ChronosPipeline
import plotly.graph_objects as go
import plotly.express as px
import base64

@st.cache_resource
def load_pipeline():
    return ChronosPipeline.from_pretrained(
        "amazon/chronos-t5-small",
        device_map="auto",
        torch_dtype=torch.bfloat16,
    )

@st.cache_data
def preprocess_data(data, date_column, metric_column, date_format):
    # Work on a copy so the caller's DataFrame is not modified in place
    data = data.copy()
    if date_format == "day-month-year":
        data[date_column] = pd.to_datetime(data[date_column], dayfirst=True)
    elif date_format == "month-day-year":
        data[date_column] = pd.to_datetime(data[date_column], dayfirst=False)

    # Sort chronologically: the forecast dates start after the last index entry
    time_series_data = data.set_index(date_column)[metric_column].astype(float).sort_index()
    return time_series_data

def make_forecast(time_series_data, prediction_length, interval):
    pipeline = load_pipeline()
    context = torch.tensor(time_series_data.values)
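    # Chronos returns sample paths with shape [num_series, num_samples, prediction_length]
    forecast = pipeline.predict(context, prediction_length)

    # 10th, 50th and 90th percentiles across the sample paths of the single input series
    low, median, high = np.quantile(forecast[0].numpy(), [0.1, 0.5, 0.9], axis=0)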
    last_date = time_series_data.index[-1]
    forecast_index = pd.date_range(start=last_date + pd.Timedelta(days=interval), periods=prediction_length, freq=f'{interval}D')
    forecast_df = pd.DataFrame({
        "Date": forecast_index,
        "Low": low,
        "Median": median,
        "High": high
    })
    
    # Ensure 'Date' is a column, not the index
    forecast_df.reset_index(drop=True, inplace=True)
    
    return forecast_df

def get_csv_download_link(df, filename):
    csv = df.to_csv(index=True)
    b64 = base64.b64encode(csv.encode()).decode()
    href = f'<a href="data:text/csv;base64,{b64}" download="{filename}">Download {filename}</a>'
    return href

def visualize_initial_forecast(forecast_df, time_series_data):
    fig = go.Figure()

    fig.add_trace(go.Scatter(x=time_series_data.index, y=time_series_data, 
                             mode='lines', name='Historical Data', 
                             line=dict(color='blue')))
    fig.add_trace(go.Scatter(x=forecast_df['Date'], y=forecast_df['Low'], 
                             mode='lines+markers', name='Low Forecast', 
                             line=dict(color='red')))
    fig.add_trace(go.Scatter(x=forecast_df['Date'], y=forecast_df['Median'], 
                             mode='lines+markers', name='Median Forecast', 
                             line=dict(color='green')))
    fig.add_trace(go.Scatter(x=forecast_df['Date'], y=forecast_df['High'], 
                             mode='lines+markers', name='High Forecast', 
                             line=dict(color='orange')))

    fig.update_layout(
        title="Chronos Forecast - click and drag the crosshairs to select an area to zoom in on",
        xaxis_title="Date",
        yaxis_title="Metric Value",
        legend_title="Legend",
        font=dict(size=12),
        xaxis=dict(rangeslider=dict(visible=False), type="date"),
    )
    
    # The figure is returned and rendered by the caller with st.plotly_chart

    # Add download links
    st.markdown(get_csv_download_link(time_series_data.reset_index(), "historical_data.csv"), unsafe_allow_html=True)
    st.markdown(get_csv_download_link(forecast_df, "forecast_data.csv"), unsafe_allow_html=True)

    return fig

def visualize_forecast(forecast_df, time_series_data, ground_truth_df=None, ground_truth_date_format=None, ground_truth_date_col=None, ground_truth_metric_col=None):
    try:
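        # Work on a copy so st.session_state.forecast_df is not modified in place,
        # then ensure it has a datetime index
        forecast_df = forecast_df.copy()
        if 'Date' in forecast_df.columns: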
            forecast_df['Date'] = pd.to_datetime(forecast_df['Date'])
            forecast_df.set_index('Date', inplace=True)
        elif not isinstance(forecast_df.index, pd.DatetimeIndex):
            forecast_df.index = pd.to_datetime(forecast_df.index)

        # Prepare ground truth data if available
        if ground_truth_df is not None:
            if ground_truth_date_format == "day-month-year":
                ground_truth_df[ground_truth_date_col] = pd.to_datetime(ground_truth_df[ground_truth_date_col], dayfirst=True)
            elif ground_truth_date_format == "month-day-year":
                ground_truth_df[ground_truth_date_col] = pd.to_datetime(ground_truth_df[ground_truth_date_col], dayfirst=False)
            ground_truth_df.set_index(ground_truth_date_col, inplace=True)
            ground_truth_df = ground_truth_df.rename(columns={ground_truth_metric_col: 'Actual'})

            # Merge forecast with ground truth
            merged_df = forecast_df.join(ground_truth_df['Actual'], how='outer')
        else:
            merged_df = forecast_df

        # Sort the index to ensure correct plotting
        merged_df.sort_index(inplace=True)

        # Calculate variances and percentage variances
        if 'Actual' in merged_df.columns:
            merged_df['Low Variance'] = merged_df['Low'] - merged_df['Actual']
            merged_df['Median Variance'] = merged_df['Median'] - merged_df['Actual']
            merged_df['High Variance'] = merged_df['High'] - merged_df['Actual']
            
            merged_df['Low % Variance'] = (merged_df['Low'] - merged_df['Actual']) / merged_df['Actual'] * 100
            merged_df['Median % Variance'] = (merged_df['Median'] - merged_df['Actual']) / merged_df['Actual'] * 100
            merged_df['High % Variance'] = (merged_df['High'] - merged_df['Actual']) / merged_df['Actual'] * 100

        # Determine the maximum value for the y-axis scale
        max_value = merged_df[['Low', 'Median', 'High'] + (['Actual'] if 'Actual' in merged_df.columns else [])].max().max()

        # Plot the trendlines using Plotly
        fig = go.Figure()

        if 'Actual' in merged_df.columns:
            fig.add_trace(go.Scatter(x=merged_df.index, y=merged_df['Actual'], 
                                     mode='lines+markers', name='Actual', 
                                     line=dict(color='black', dash='dot')))

        fig.add_trace(go.Scatter(x=merged_df.index, y=merged_df['Low'], 
                                 mode='lines+markers', name='Low Forecast', 
                                 line=dict(color='red')))
        fig.add_trace(go.Scatter(x=merged_df.index, y=merged_df['Median'], 
                                 mode='lines+markers', name='Median Forecast', 
                                 line=dict(color='green')))
        fig.add_trace(go.Scatter(x=merged_df.index, y=merged_df['High'], 
                                 mode='lines+markers', name='High Forecast', 
                                 line=dict(color='blue')))

        # Update layout
        fig.update_layout(
            title="Actual vs Forecast - click and drag the crosshairs to select an area to zoom in on",
            xaxis_title="Date",
            yaxis_title="Metric Value",
            legend_title="Legend",
            font=dict(size=12),
            xaxis=dict(
                rangeslider=dict(visible=False),
                type="date"
            ),
            yaxis=dict(range=[0, max_value * 1.1])  # Set y-axis range dynamically with some padding
        )

        st.plotly_chart(fig)
        
        # Prepare CSV for download
        csv_df = merged_df.copy()
        csv_df = csv_df.round(2)  # Round all float columns to 2 decimal places
        csv_df = csv_df.replace([np.inf, -np.inf], np.nan).fillna('')  # Replace inf (from division by zero) and missing values with empty strings
        
        # Add download link for the comparison chart data
        st.markdown(get_csv_download_link(csv_df, "forecast_vs_actual.csv"), unsafe_allow_html=True)

        # Calculate and display variances if ground truth is available
        if 'Actual' in merged_df.columns:
            # Filter for only the forecasted period
            forecast_period = merged_df.dropna(subset=['Low', 'Median', 'High', 'Actual'])
            
            # Calculate total variances for the forecasted period only
            totals = forecast_period[["Low", "Median", "High", "Actual"]].sum()
            total_low_variance = (totals["Low"] - totals["Actual"]) / totals["Actual"] if totals["Actual"] != 0 else 0
            total_median_variance = (totals["Median"] - totals["Actual"]) / totals["Actual"] if totals["Actual"] != 0 else 0
            total_high_variance = (totals["High"] - totals["Actual"]) / totals["Actual"] if totals["Actual"] != 0 else 0

            # Create a bar chart for percentage variances
            bar_df = pd.DataFrame({
                'Metric': ['Low Variance', 'Median Variance', 'High Variance'],
                'Value': [total_low_variance * 100, total_median_variance * 100, total_high_variance * 100]
            })

            bar_fig = px.bar(bar_df, x='Metric', y='Value', title='Percentage Variances', labels={'Value': 'Percentage (%)'})
            st.plotly_chart(bar_fig)
            
            # Add download link for the variance data
            st.markdown(get_csv_download_link(bar_df, "variance_data.csv"), unsafe_allow_html=True)

            # The totals are fractions, so the % format multiplies them by 100
            st.write(f"Total Low Variance: {total_low_variance:.2%}")
            st.write(f"Total Median Variance: {total_median_variance:.2%}")
            st.write(f"Total High Variance: {total_high_variance:.2%}")

    except Exception as e:
        st.error(f"An error occurred during visualization: {str(e)}")
        st.exception(e)

def main():
    st.title("Amazon Chronos Forecasting App")

    tab1, tab2, tab3 = st.tabs(["Run a Forecast", "Compare to Actual", "User Guide"])

    with tab1:
        uploaded_file = st.file_uploader("Upload CSV file with historical data", type=["csv"])
        if uploaded_file is not None:
            data = pd.read_csv(uploaded_file)
            st.write("File uploaded successfully")
            st.subheader("Uploaded Data")
            st.write(data)

            date_column = st.selectbox("Select the Date column", data.columns)
            metric_column = st.selectbox("Select the Metric column", data.columns)
            date_format = st.radio("Select the date format of the Date column", ("day-month-year", "month-day-year"))
            
            prediction_length = st.number_input("Enter the prediction length", min_value=1, value=12)
            interval = st.number_input("Enter the interval in days", min_value=1, value=7)

            if st.button("Make Forecast"):
                time_series_data = preprocess_data(data, date_column, metric_column, date_format)
                forecast_df = make_forecast(time_series_data, prediction_length, interval)
                
                st.session_state.forecast_df = forecast_df
                st.session_state.time_series_data = time_series_data
                
                st.subheader("Forecast Visualization")
                st.write("Forecasted Values:")
                st.write(forecast_df)
                
                initial_forecast_fig = visualize_initial_forecast(forecast_df, time_series_data)
                st.session_state.initial_forecast_fig = initial_forecast_fig
                st.plotly_chart(initial_forecast_fig)

    with tab2:
        st.subheader("Compare Forecast to Actual Data")
        
        if 'forecast_df' not in st.session_state or 'time_series_data' not in st.session_state:
            st.warning("Please make a forecast in the 'Run a Forecast' tab first.")
        else:
            ground_truth_file = st.file_uploader("Upload CSV file with your actual 'ground truth' data to see how accurate the forecast is", type=["csv"], key="ground_truth_file")
            if ground_truth_file is not None:
                ground_truth_df = pd.read_csv(ground_truth_file)
                st.write("Actual data file uploaded successfully")
                st.subheader("Actual Data")
                st.write(ground_truth_df)

                ground_truth_date_col = st.selectbox("Select the Date column for actual data", ground_truth_df.columns, key="gt_date_col")
                ground_truth_metric_col = st.selectbox("Select the Metric column for actual data", ground_truth_df.columns, key="gt_metric_col")
                ground_truth_date_format = st.radio("Select the date format for actual data", ("day-month-year", "month-day-year"), key="gt_date_format")

                if st.button("Compare Forecast to Actual Data"):
                    st.subheader("Comparison with Actual Data")
                    if 'initial_forecast_fig' in st.session_state:
                        st.subheader("Chronos Forecast")
                        st.plotly_chart(st.session_state.initial_forecast_fig)
                    
                    st.subheader("Forecast vs Actual Data")
                    visualize_forecast(st.session_state.forecast_df, st.session_state.time_series_data, 
                                       ground_truth_df, ground_truth_date_format, ground_truth_date_col, ground_truth_metric_col)

    with tab3:
        st.subheader("User Guide")
        st.write("""
        This is a demo Hugging Face app that gives you everything you need to test Amazon Chronos T5 Small using a demo e-commerce sales dataset.

        As per the Hugging Face description:

        'Chronos is a family of pretrained time series forecasting models based on language model architectures. Chronos models have been trained on a large corpus of publicly available time series data, as well as synthetic data generated using Gaussian processes.'

        For more info see:
        - [Hugging Face Chronos T5 Small](https://huggingface.co/amazon/chronos-t5-small)
        - [GitHub: Chronos Forecasting](https://github.com/amazon-science/chronos-forecasting)

        Please Share, Cite and Connect with Me:

        If you found this app helpful, please share it and cite me as the original source. Feel free to connect with me on LinkedIn:
        - [LinkedIn: James Bentley](https://www.linkedin.com/in/james-bentley-1b329214/)

        YouTube video walkthrough of the Google Colab notebook I built previously, which this app is based on:
        - [Watch here](https://www.youtube.com/watch?v=jyrOmIiI2Bc&t=103s)

        Disclaimer: This is purely for educational purposes.

        **Upload Your CSV File From Your Computer:**
        It should contain two columns: the first with your dates and the second with the metric you would like to predict.
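
        If you prepare your file with pandas, here is a minimal sketch of the expected layout. The column names `Date` and `Sales` and the values are made up for illustration. It writes a two-column CSV and then reads it back, parsing the dates the same way this app does when you choose "day-month-year":

        ```python
        import io

        import pandas as pd

        # Hypothetical sample data: weekly dates in day-month-year order and a sales metric
        sample = pd.DataFrame({
            "Date": ["01-01-2024", "08-01-2024", "15-01-2024"],
            "Sales": [120.0, 135.5, 128.0],
        })
        csv_text = sample.to_csv(index=False)

        # Read the CSV back and parse the dates with dayfirst=True, as the app does
        data = pd.read_csv(io.StringIO(csv_text))
        data["Date"] = pd.to_datetime(data["Date"], dayfirst=True)
        series = data.set_index("Date")["Sales"].astype(float)
        print(series)
        ```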

        You can download copies of the CSV files I use for this demo here (be sure to save them as CSV):
        - [Sales.csv](https://docs.google.com/spreadsheets/d/1_tyquxKwYRWFyp0r8tMvpWoAIqJmS8fEG0wsxFT58B0/edit?usp=sharing)
        - [Actual.csv](https://docs.google.com/spreadsheets/d/1yjebWmbmY-rAyB_TDXAye8i-yoiqKA2dW_SHmtL2ihM/edit?usp=sharing)

        **Confirm Your Column Names:**
        Next, confirm which column contains your dates and which contains the metric you want to forecast, so the app can handle them correctly whatever you have named them.

        **Generate Forecast and CSV File:**
        To run your forecast, you need to confirm two settings:

        - The prediction length: the number of time points to forecast. For example, for a 31-day forecast covering a month, enter 31; for just the next 7 days, enter 7; for 12 months with one forecast per month, enter 12. The default is 12 (to work with the demo). If you plan to assess forecast accuracy against test data, make sure this number matches the number of dates in your test data.

        - The interval in days: how many days separate consecutive forecast points. Enter 1 for daily forecasts or 7 for weekly forecasts.
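
        To see which dates these two settings produce, here is a minimal sketch that mirrors how this app builds the forecast dates. The first point falls one interval after the last date in your data, and each later point is one more interval apart. The last date and the settings below are hypothetical:

        ```python
        import pandas as pd

        last_date = pd.Timestamp("2024-03-25")  # hypothetical last date in the uploaded data
        prediction_length = 4  # number of forecast points
        interval = 7  # days between forecast points

        # Same construction as make_forecast in this app
        forecast_dates = pd.date_range(
            start=last_date + pd.Timedelta(days=interval),
            periods=prediction_length,
            freq=f"{interval}D",
        )
        print(list(forecast_dates.strftime("%Y-%m-%d")))
        # ['2024-04-01', '2024-04-08', '2024-04-15', '2024-04-22']
        ```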

        **Check the Accuracy of Your Forecast Against Actual Data:**
        
        If you want to check the forecast's accuracy against real data that you didn't include in the original CSV, upload an actual.csv file (or whatever you choose to name it) in the 'Compare to Actual' tab.
        
        This file should contain the actual data for the dates you ran the forecast for.

        This should be a two-column file with dates in the first column and the metric in the second. Comparing it to the forecast shows how accurate the forecast is.
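
        The comparison matches forecast rows and actual rows by exact date, so your actual dates need to line up with the forecast dates. Here is a minimal sketch with hypothetical values that uses the same kind of outer join on the date index as this app. A date that is one day off ends up on its own row with a blank value, and rows like that are left out of the total variances:

        ```python
        import pandas as pd

        # Hypothetical forecast for two weekly dates
        forecast = pd.DataFrame(
            {"Median": [110.0, 95.0]},
            index=pd.to_datetime(["2024-04-01", "2024-04-08"]),
        )
        # Hypothetical actuals: the second date is one day later than the forecast date
        actual = pd.Series(
            [100.0, 102.0],
            index=pd.to_datetime(["2024-04-01", "2024-04-09"]),
            name="Actual",
        )

        # Outer join on the date index, then sort by date
        merged = forecast.join(actual, how="outer").sort_index()
        print(merged)
        ```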

        Below is the file I use in my demo:

        - [Actual.csv](https://docs.google.com/spreadsheets/d/1yjebWmbmY-rAyB_TDXAye8i-yoiqKA2dW_SHmtL2ihM/edit?usp=sharing)


        **Select the Actual.csv File and Confirm The Column Names:**
        Now you just need to confirm the column names that need to be used.

        **Generate Actual vs Forecast Trendline Chart and CSV:**
        Now that you have set up your actual file, you can generate a trendline chart showing how the forecasts tracked against your actual data over the forecasted date range.

        A CSV file is also available to download, which combines the forecasts and actuals with absolute and percentage variances.
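
        Here is a minimal sketch of how these numbers are calculated, using hypothetical values and only the Median forecast (the app does the same for Low and High). Each row's variance is the forecast minus the actual, and the % variance divides that difference by the actual. The total % variance compares the summed forecast with the summed actuals over the dates where both exist:

        ```python
        import pandas as pd

        # Hypothetical Median forecasts and actuals for three forecast dates
        df = pd.DataFrame(
            {"Median": [110.0, 95.0, 130.0], "Actual": [100.0, 100.0, 125.0]},
            index=pd.to_datetime(["2024-04-01", "2024-04-08", "2024-04-15"]),
        )

        # Per-row variance and % variance, as in the downloadable CSV
        df["Median Variance"] = df["Median"] - df["Actual"]
        df["Median % Variance"] = df["Median Variance"] / df["Actual"] * 100

        # Total % variance: summed forecast vs summed actuals over rows where both exist
        period = df.dropna(subset=["Median", "Actual"])
        totals = period[["Median", "Actual"]].sum()
        total_pct = (totals["Median"] - totals["Actual"]) / totals["Actual"] * 100
        print(df.round(2))
        print(f"Total Median Variance: {total_pct:.2f}%")
        # Total Median Variance: 3.08%
        ```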
        """)

if __name__ == "__main__":
    main()