import pandas as pd
import numpy as np
from statsmodels.tsa.arima.model import ARIMA
from statsmodels.tsa.seasonal import seasonal_decompose
from prophet import Prophet


class TimeSeriesForecaster:
    """Forecast a univariate time series with ARIMA or Prophet.

    The most recently fitted model (a fitted ARIMA results object or a
    fitted ``Prophet`` instance) is kept on ``self.model``.
    """

    _VALID_METHODS = ('auto', 'arima', 'prophet')

    def __init__(self):
        # Populated by the _forecast_* helpers after a successful fit.
        self.model = None

    def forecast(self, data, date_column, value_column, periods=30, method='auto'):
        """Fit a model on ``data`` and forecast ``periods`` steps ahead.

        Args:
            data: DataFrame holding at least ``date_column`` and
                ``value_column``. It is not modified.
            date_column: Name of the column holding the timestamps.
            value_column: Name of the column holding the observed values.
            periods: Number of future steps to forecast.
            method: ``'arima'``, ``'prophet'`` or ``'auto'``. ``'auto'``
                picks Prophet when the series looks seasonal and ARIMA
                otherwise.

        Returns:
            For ARIMA, a DataFrame with columns ``date`` and ``forecast``
            holding only the future steps. For Prophet, a DataFrame with
            columns ``ds``, ``yhat``, ``yhat_lower`` and ``yhat_upper``;
            Prophet's future frame includes the historical dates by
            default, so those rows are part of the result as well.

        Raises:
            ValueError: If ``method`` is not one of the supported names.
        """
        # Reject a bad method before doing any work on the data.
        if method not in self._VALID_METHODS:
            raise ValueError("Invalid method. Choose 'arima', 'prophet', or 'auto'.")

        # Both models expect chronologically ordered observations.
        indexed = data.sort_values(date_column).set_index(date_column)

        if method == 'auto':
            seasonal = self._has_seasonality(indexed[value_column])
            method = 'prophet' if seasonal else 'arima'

        if method == 'arima':
            return self._forecast_arima(indexed[value_column], periods)
        return self._forecast_prophet(
            indexed.reset_index(), date_column, value_column, periods
        )

    def _has_seasonality(self, series, threshold=0.1):
        """Return True when the seasonal component is large relative to the trend.

        The series is decomposed additively; it counts as seasonal when
        the mean absolute seasonal component exceeds ``threshold`` times
        the mean absolute trend.

        When the decomposition cannot be computed the series is reported
        as non-seasonal, so that ``method='auto'`` still yields a forecast
        instead of failing during model selection.
        """
        try:
            result = seasonal_decompose(
                series, model='additive', extrapolate_trend='freq'
            )
        except ValueError:
            # seasonal_decompose raises ValueError when no period can be
            # inferred from the index, when there are fewer than two full
            # cycles, or when the series has missing values.
            return False

        seasonal_strength = np.abs(result.seasonal).mean()
        trend_level = np.abs(result.trend).mean()
        return bool(seasonal_strength > threshold * trend_level)

    def _forecast_arima(self, series, periods):
        """Fit an ARIMA(1, 1, 1) model and forecast ``periods`` steps.

        Returns:
            DataFrame with columns ``date`` (the index of the forecast)
            and ``forecast`` (the predicted values).
        """
        model = ARIMA(series, order=(1, 1, 1))
        self.model = model.fit()
        forecast = self.model.forecast(steps=periods)
        return pd.DataFrame({'date': forecast.index, 'forecast': forecast.values})

    def _forecast_prophet(self, df, date_column, value_column, periods):
        """Fit a Prophet model and predict ``periods`` steps past the history.

        Returns:
            DataFrame with columns ``ds``, ``yhat``, ``yhat_lower`` and
            ``yhat_upper``.
        """
        # Prophet requires the columns to be named 'ds' and 'y'.
        df = df.rename(columns={date_column: 'ds', value_column: 'y'})
        model = Prophet()
        self.model = model.fit(df)
        # make_future_dataframe defaults to daily steps; use the spacing of
        # the history so weekly/monthly/hourly data is extended correctly.
        future = model.make_future_dataframe(
            periods=periods, freq=self._infer_frequency(df['ds'])
        )
        forecast = model.predict(future)
        return forecast[['ds', 'yhat', 'yhat_lower', 'yhat_upper']]

    @staticmethod
    def _infer_frequency(dates, default='D'):
        """Return the pandas frequency string of ``dates``, or ``default``.

        ``default`` is used when the dates are irregular, too few to infer
        a frequency from, or cannot be interpreted as datetimes.
        """
        try:
            freq = pd.infer_freq(pd.DatetimeIndex(pd.to_datetime(dates)))
        except (TypeError, ValueError):
            return default
        return freq or default

    def plot_forecast(self, original_data, forecast_data):
        """Plot the original series together with a forecast.

        Args:
            original_data: Series of observed values; plotted against its
                own index.
            forecast_data: A frame returned by ``forecast``: either the
                ARIMA layout (``date``/``forecast``) or the Prophet layout
                (``ds``/``yhat`` with optional ``yhat_lower``/``yhat_upper``).

        Returns:
            The ``matplotlib.pyplot`` module, with the figure current, so
            the caller can call ``show()`` or ``savefig()`` on it.

        Raises:
            KeyError: If ``forecast_data`` matches neither layout.
        """
        # Imported lazily so matplotlib is only needed when plotting.
        import matplotlib.pyplot as plt

        columns = set(forecast_data.columns)
        if {'date', 'forecast'} <= columns:
            dates = forecast_data['date']
            values = forecast_data['forecast']
            lower = upper = None
        elif {'ds', 'yhat'} <= columns:
            dates = forecast_data['ds']
            values = forecast_data['yhat']
            if {'yhat_lower', 'yhat_upper'} <= columns:
                lower = forecast_data['yhat_lower']
                upper = forecast_data['yhat_upper']
            else:
                lower = upper = None
        else:
            raise KeyError(
                "forecast_data must contain either 'date'/'forecast' "
                "or 'ds'/'yhat' columns"
            )

        if lower is None:
            # No interval columns available: shade +/- one standard
            # deviation of the forecast values. This is a rough visual
            # band, not a statistical confidence interval.
            spread = values.std()
            lower = values - spread
            upper = values + spread

        plt.figure(figsize=(12, 6))
        plt.plot(original_data.index, original_data, label='Original Data')
        plt.plot(dates, values, label='Forecast', color='red')
        plt.fill_between(dates, lower, upper, color='red', alpha=0.2)
        plt.legend()
        plt.title('Time Series Forecast')
        plt.xlabel('Date')
        plt.ylabel('Value')
        return plt