File size: 2,658 Bytes
ed275c2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 |
import pandas as pd
import numpy as np
from statsmodels.tsa.arima.model import ARIMA
from statsmodels.tsa.seasonal import seasonal_decompose
from prophet import Prophet
class TimeSeriesForecaster:
    """Forecast a univariate time series with ARIMA (statsmodels) or Prophet.

    Attributes:
        model: The fitted model from the most recent ``forecast`` call (an
            ARIMA results object or a fitted ``Prophet`` instance), or None
            before the first call.
    """

    def __init__(self):
        self.model = None

    def forecast(self, data, date_column, value_column, periods=30, method='auto'):
        """Fit a model to ``data`` and forecast ``periods`` steps ahead.

        Args:
            data: DataFrame containing ``date_column`` and ``value_column``.
                It is not modified. The dates are sorted by their raw values,
                so they should presumably be datetime-like already.
            date_column: Name of the date column.
            value_column: Name of the column to forecast.
            periods: Number of future steps to forecast.
            method: ``'arima'``, ``'prophet'`` or ``'auto'``. ``'auto'`` picks
                Prophet when a seasonal decomposition indicates seasonality,
                otherwise ARIMA.

        Returns:
            DataFrame. The ARIMA path returns columns ``date``, ``forecast``,
            ``lower`` and ``upper``, with one row per future step. The Prophet
            path returns ``ds``, ``yhat``, ``yhat_lower`` and ``yhat_upper``,
            including the fitted history rows as well as the future rows.

        Raises:
            ValueError: If ``method`` is not one of the accepted values.
        """
        if method not in ('auto', 'arima', 'prophet'):
            raise ValueError("Invalid method. Choose 'arima', 'prophet', or 'auto'.")

        # sort_values/set_index return new frames; the caller's data is untouched.
        data = data.sort_values(date_column).set_index(date_column)
        series = data[value_column]

        if method == 'auto':
            # The ARIMA path uses a non-seasonal order, so seasonal data is
            # routed to Prophet instead.
            method = 'prophet' if self._has_seasonality(series) else 'arima'

        if method == 'arima':
            return self._forecast_arima(series, periods)
        # Prophet expects the dates as an ordinary column, not the index.
        return self._forecast_prophet(data.reset_index(), date_column, value_column, periods)

    def _has_seasonality(self, series, threshold=0.1, period=None):
        """Heuristically decide whether ``series`` shows additive seasonality.

        The series is decomposed additively. The result is True when the mean
        absolute seasonal component exceeds ``threshold`` times the mean
        absolute trend.

        Args:
            series: Values indexed by date.
            threshold: Seasonal-to-trend magnitude ratio above which the series
                counts as seasonal.
            period: Seasonal period passed to ``seasonal_decompose``. None lets
                statsmodels infer it from the index's inferred frequency.

        Returns:
            bool. Also False (with a RuntimeWarning) when the series cannot be
            decomposed.
        """
        try:
            result = seasonal_decompose(
                series, model='additive', extrapolate_trend='freq', period=period
            )
        except ValueError as err:
            # seasonal_decompose raises ValueError when no period can be
            # inferred (irregular dates), when there are fewer than two full
            # cycles, or when values are missing. Treat the series as
            # non-seasonal rather than aborting the forecast, but say so.
            import warnings
            warnings.warn(
                f"Seasonality check skipped ({err}); assuming no seasonality.",
                RuntimeWarning,
                stacklevel=3,
            )
            return False
        seasonal_strength = np.abs(result.seasonal).mean()
        trend_strength = np.abs(result.trend).mean()
        return bool(seasonal_strength > threshold * trend_strength)

    def _forecast_arima(self, series, periods, order=(1, 1, 1), alpha=0.2):
        """Fit ARIMA(``order``) to ``series`` and forecast ``periods`` steps.

        ``alpha=0.2`` gives an 80% interval, matching Prophet's default
        ``interval_width`` so both paths report comparable bands.

        Returns:
            DataFrame with ``date``, ``forecast``, ``lower`` and ``upper``.
        """
        model = ARIMA(series, order=order)
        self.model = model.fit()
        prediction = self.model.get_forecast(steps=periods)
        mean = prediction.predicted_mean
        # conf_int() yields one (lower, upper) pair per step; the asarray call
        # makes the column access independent of pandas vs. ndarray output.
        bounds = np.asarray(prediction.conf_int(alpha=alpha))
        return pd.DataFrame({
            'date': mean.index,
            'forecast': np.asarray(mean),
            'lower': bounds[:, 0],
            'upper': bounds[:, 1],
        })

    def _forecast_prophet(self, df, date_column, value_column, periods):
        """Fit Prophet on ``df`` and predict history plus ``periods`` future steps.

        Prophet requires the input columns to be named ``ds`` and ``y``.
        """
        df = df.rename(columns={date_column: 'ds', value_column: 'y'})
        model = Prophet()
        self.model = model.fit(df)
        future = model.make_future_dataframe(periods=periods)
        forecast = model.predict(future)
        return forecast[['ds', 'yhat', 'yhat_lower', 'yhat_upper']]

    @staticmethod
    def _forecast_columns(forecast_data):
        """Return ``(dates, values, lower, upper)`` from either output schema.

        Accepts the ARIMA schema (``date``/``forecast`` plus optional
        ``lower``/``upper``) and the Prophet schema (``ds``/``yhat`` plus
        optional ``yhat_lower``/``yhat_upper``). ``lower`` and ``upper`` are
        None when the frame has no interval columns.
        """
        if 'yhat' in forecast_data.columns:
            return (
                forecast_data['ds'],
                forecast_data['yhat'],
                forecast_data.get('yhat_lower'),
                forecast_data.get('yhat_upper'),
            )
        return (
            forecast_data['date'],
            forecast_data['forecast'],
            forecast_data.get('lower'),
            forecast_data.get('upper'),
        )

    def plot_forecast(self, original_data, forecast_data):
        """Plot the observed data together with a forecast from ``forecast``.

        Args:
            original_data: Observed values, plotted against their index
                (presumably the date-indexed value series; confirm with callers).
            forecast_data: DataFrame returned by ``forecast``, from either the
                ARIMA path or the Prophet path.

        Returns:
            The ``matplotlib.pyplot`` module, with the figure still open.
        """
        import matplotlib.pyplot as plt

        dates, values, lower, upper = self._forecast_columns(forecast_data)
        if lower is None or upper is None:
            # No model interval is available. Fall back to a +/- one standard
            # deviation band. This is a rough spread indicator, not a
            # confidence interval.
            spread = values.std()
            lower, upper = values - spread, values + spread

        plt.figure(figsize=(12, 6))
        plt.plot(original_data.index, original_data, label='Original Data')
        plt.plot(dates, values, label='Forecast', color='red')
        plt.fill_between(dates, lower, upper, color='red', alpha=0.2)
        plt.legend()
        plt.title('Time Series Forecast')
        plt.xlabel('Date')
        plt.ylabel('Value')
        return plt