import warnings

import numpy as np
import pandas as pd
from prophet import Prophet
from statsmodels.tsa.arima.model import ARIMA
from statsmodels.tsa.seasonal import seasonal_decompose


class TimeSeriesForecaster:
    """Forecast a univariate time series with ARIMA or Prophet.

    The most recently fitted model object is stored on ``self.model``
    (``None`` until one of the forecasting methods has run).
    """

    # Values accepted for the ``method`` argument of ``forecast``.
    _METHODS = ('auto', 'arima', 'prophet')

    def __init__(self):
        self.model = None

    def forecast(self, data, date_column, value_column, periods=30, method='auto'):
        """Fit a model on ``data`` and forecast ``periods`` steps ahead.

        Args:
            data: DataFrame holding the history. It is not modified; a sorted,
                re-indexed copy is used internally.
            date_column: Name of the column holding the observation dates.
            value_column: Name of the column holding the observed values.
            periods: Number of future steps to forecast.
            method: ``'arima'``, ``'prophet'`` or ``'auto'``. With ``'auto'``
                Prophet is used when the series looks seasonal, ARIMA
                otherwise.

        Returns:
            For ARIMA, a DataFrame with columns ``date`` and ``forecast``.
            For Prophet, a DataFrame with columns ``ds``, ``yhat``,
            ``yhat_lower`` and ``yhat_upper``.

        Raises:
            ValueError: If ``method`` is not one of the accepted values.
        """
        # Reject a bad method before doing any work on the data.
        if method not in self._METHODS:
            raise ValueError("Invalid method. Choose 'arima', 'prophet', or 'auto'.")

        data = data.sort_values(date_column).set_index(date_column)
        series = data[value_column]

        if method == 'auto':
            try:
                seasonal = self._has_seasonality(series)
            except ValueError as exc:
                # The decomposition cannot run on every series (for example
                # when no period can be inferred from the index or there are
                # too few observations). Auto mode should still pick a model,
                # so use the non-seasonal one and tell the caller why.
                warnings.warn(
                    f"Could not assess seasonality ({exc}); falling back to ARIMA.",
                    RuntimeWarning,
                    stacklevel=2,
                )
                seasonal = False
            method = 'prophet' if seasonal else 'arima'

        if method == 'arima':
            return self._forecast_arima(series, periods)
        return self._forecast_prophet(data.reset_index(), date_column, value_column, periods)

    def _has_seasonality(self, series, threshold=0.1):
        """Return True when the seasonal component is large relative to the trend.

        The series is decomposed additively; it counts as seasonal when the
        mean absolute seasonal component exceeds ``threshold`` times the mean
        absolute trend component.

        Raises:
            ValueError: Propagated from ``seasonal_decompose`` when it cannot
                decompose the series.
        """
        result = seasonal_decompose(series, model='additive', extrapolate_trend='freq')
        seasonal_strength = np.abs(result.seasonal).mean()
        trend_level = np.abs(result.trend).mean()
        return bool(seasonal_strength > threshold * trend_level)

    def _forecast_arima(self, series, periods):
        """Fit an ARIMA(1, 1, 1) model and forecast ``periods`` steps.

        Returns:
            DataFrame with columns ``date`` (the forecast index) and
            ``forecast`` (the predicted values). Only future steps are
            included.
        """
        model = ARIMA(series, order=(1, 1, 1))
        self.model = model.fit()
        forecast = self.model.forecast(steps=periods)
        return pd.DataFrame({'date': forecast.index, 'forecast': forecast.values})

    @staticmethod
    def _infer_frequency(dates, default='D'):
        """Return the pandas frequency alias of ``dates``, or ``default``.

        ``default`` is returned when the dates cannot be parsed, when there
        are too few of them to infer a frequency, or when they are not
        regularly spaced.
        """
        try:
            inferred = pd.infer_freq(pd.DatetimeIndex(pd.to_datetime(dates)))
        except (TypeError, ValueError):
            return default
        return inferred or default

    def _forecast_prophet(self, df, date_column, value_column, periods):
        """Fit a Prophet model and predict ``periods`` steps past the history.

        The future dates follow the frequency inferred from the history
        (daily when none can be inferred) rather than always being daily.

        Returns:
            DataFrame with columns ``ds``, ``yhat``, ``yhat_lower`` and
            ``yhat_upper``. ``make_future_dataframe`` includes the history by
            default, so the rows cover the fitted dates as well as the
            ``periods`` future ones.
        """
        df = df.rename(columns={date_column: 'ds', value_column: 'y'})
        model = Prophet()
        self.model = model.fit(df)
        future = model.make_future_dataframe(
            periods=periods,
            freq=self._infer_frequency(df['ds']),
        )
        forecast = model.predict(future)
        return forecast[['ds', 'yhat', 'yhat_lower', 'yhat_upper']]

    @staticmethod
    def _forecast_bands(forecast_data):
        """Return ``(dates, center, lower, upper)`` for either forecast schema.

        Accepts the ARIMA schema (``date``/``forecast``) or the Prophet schema
        (``ds``/``yhat`` with optional ``yhat_lower``/``yhat_upper``). When no
        explicit bounds are available the band is the point forecast plus or
        minus its own standard deviation.

        Raises:
            KeyError: If neither schema's columns are present.
        """
        columns = set(forecast_data.columns)
        if {'date', 'forecast'} <= columns:
            dates, center = forecast_data['date'], forecast_data['forecast']
        elif {'ds', 'yhat'} <= columns:
            dates, center = forecast_data['ds'], forecast_data['yhat']
        else:
            raise KeyError(
                "forecast_data must contain 'date'/'forecast' or 'ds'/'yhat' columns"
            )

        if {'yhat_lower', 'yhat_upper'} <= columns and 'forecast' not in columns:
            # Prophet supplies its own uncertainty interval; prefer it.
            lower, upper = forecast_data['yhat_lower'], forecast_data['yhat_upper']
        else:
            spread = center.std()
            lower, upper = center - spread, center + spread
        return dates, center, lower, upper

    def plot_forecast(self, original_data, forecast_data):
        """Plot the history together with a forecast and its shaded band.

        Args:
            original_data: The observed series; plotted against its ``index``.
            forecast_data: A DataFrame returned by ``forecast`` (either the
                ARIMA or the Prophet column layout).

        Returns:
            The ``matplotlib.pyplot`` module, with the figure drawn on it.

        Raises:
            KeyError: If ``forecast_data`` has neither column layout.
        """
        # Imported lazily so matplotlib is only needed when plotting.
        import matplotlib.pyplot as plt

        dates, center, lower, upper = self._forecast_bands(forecast_data)

        plt.figure(figsize=(12, 6))
        plt.plot(original_data.index, original_data, label='Original Data')
        plt.plot(dates, center, label='Forecast', color='red')
        plt.fill_between(dates, lower, upper, color='red', alpha=0.2)
        plt.legend()
        plt.title('Time Series Forecast')
        plt.xlabel('Date')
        plt.ylabel('Value')
        return plt