File size: 4,118 Bytes
3464b48
 
 
 
 
 
 
 
 
 
 
 
 
 
4df22ff
3464b48
 
 
 
 
 
 
be4b326
 
 
 
 
 
 
 
3464b48
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
import pandas as pd
import numpy as np
import plotly.graph_objects as go
from datetime import timedelta
from statsmodels.tsa.arima.model import ARIMA
from config import FORECAST_PERIOD, ticker_dict, CONFIDENCE_INTERVAL
from data_fetcher import get_stock_data, get_company_info

def is_business_day(a_date):
    """Return True when *a_date* falls on a weekday (Monday through Friday).

    Only the weekday is considered; public holidays are not excluded.
    """
    saturday, sunday = 5, 6
    return a_date.weekday() not in (saturday, sunday)

def forecast_series(series, model="ARIMA", forecast_period=FORECAST_PERIOD):
    """Forecast the next *forecast_period* values of a price series.

    Args:
        series: Price history. A multi-column DataFrame is forecast on its
            ``'Close'`` column; a single-column DataFrame, a Series, or any
            1-D sequence of numbers is used as-is.
        model: Name of the forecasting method. Only ``"ARIMA"`` is
            implemented; ``"Prophet"``, ``"LSTM"`` and any other name yield
            empty results.
        forecast_period: Number of future steps to forecast.

    Returns:
        ``(predictions, confidence_intervals)``: a 1-D float array with one
        forecast per step, and a DataFrame with ``lower``/``upper`` columns
        (in that order) holding the interval at the ``CONFIDENCE_INTERVAL``
        level, one row per step. Both are empty when *model* is not
        implemented.

    Raises:
        KeyError: If *series* is a multi-column DataFrame without ``'Close'``.
    """
    # Empty defaults share the shape/type of a real result, so callers can
    # treat implemented and placeholder models uniformly.
    predictions = np.array([], dtype=float)
    confidence_intervals = pd.DataFrame(columns=["lower", "upper"], dtype=float)

    if isinstance(series, pd.DataFrame):
        # Multi-column frames (e.g. OHLC data) are forecast on their close.
        series = series['Close'] if series.shape[1] > 1 else series.iloc[:, 0]
    # A plain float array carries no (gappy, trading-day) date index, so the
    # model does not try to infer a date frequency from it.
    values = np.asarray(series, dtype=float).ravel()

    if model == "ARIMA":
        fitted = ARIMA(values, order=(5, 1, 0)).fit()
        # get_forecast yields both the point forecast and its interval;
        # alpha must be given to conf_int, otherwise it defaults to 0.05.
        # Assumes CONFIDENCE_INTERVAL is a fraction such as 0.95.
        forecast = fitted.get_forecast(steps=forecast_period)
        predictions = np.asarray(forecast.predicted_mean, dtype=float).ravel()
        bounds = np.asarray(
            forecast.conf_int(alpha=1 - CONFIDENCE_INTERVAL), dtype=float
        ).reshape(-1, 2)
        confidence_intervals = pd.DataFrame(bounds, columns=["lower", "upper"])
    elif model == "Prophet":
        # TODO: implement Prophet forecasting.
        pass
    elif model == "LSTM":
        # TODO: implement LSTM forecasting.
        pass

    return predictions, confidence_intervals

def _add_exchange_suffix(index_name, ticker_name):
    """Return *ticker_name* with the exchange suffix its index requires, if any."""
    if index_name == 'FTSE 100':
        # Tickers that already end in '.' only need the trailing 'L'.
        return ticker_name + ('L' if ticker_name.endswith('.') else '.L')
    if index_name == 'CAC 40':
        return ticker_name + '.PA'
    return ticker_name


def _build_forecast_frame(series, predictions, confidence_intervals):
    """Pair each forecast step with a business day after the last observed date."""
    values = np.asarray(predictions, dtype=float).ravel()
    # Accepts a two-column DataFrame, an (n, 2) array, or an empty result;
    # column 0 is the lower bound and column 1 the upper bound.
    bounds = np.asarray(confidence_intervals, dtype=float).reshape(-1, 2)
    last_date = pd.to_datetime(series['Date'].values[-1])
    # freq="B" produces Monday-Friday dates (the same rule as is_business_day)
    # and exactly one date per prediction, so all columns have equal length.
    # NOTE(review): assumes one forecast step per business day; for intraday
    # intervals the dates are likely not meaningful — confirm.
    forecast_dates = pd.date_range(
        start=last_date + timedelta(days=1), periods=len(values), freq="B"
    )
    return pd.DataFrame({
        "Date": forecast_dates,
        "Forecast": values,
        "Lower_CI": bounds[:, 0],
        "Upper_CI": bounds[:, 1],
    })


def _add_forecast_traces(fig, forecast):
    """Overlay the forecast line and its shaded confidence band on *fig*."""
    fig.add_trace(go.Scatter(x=forecast['Date'], y=forecast['Forecast'], mode='lines', name='Forecast'))
    # Walk the upper bound forward and the lower bound backward so the points
    # form a closed polygon that fill='toself' can shade.
    band_x = forecast['Date'].tolist() + forecast['Date'].tolist()[::-1]
    band_y = forecast['Upper_CI'].tolist() + forecast['Lower_CI'].tolist()[::-1]
    fig.add_trace(go.Scatter(
        x=band_x,
        y=band_y,
        fill='toself',
        fillcolor='rgba(0,100,80,0.2)',
        line=dict(color='rgba(255,255,255,0)'),
        hoverinfo="skip",
        showlegend=False
    ))


def get_stock_graph_and_info(idx, stock, interval, graph_type, forecast_method, start_date, end_date):
    """Build a price chart with a forecast overlay and fetch company fundamentals.

    Args:
        idx: Key into ``ticker_dict``; its value names the index the stock
            belongs to (``'FTSE 100'`` and ``'CAC 40'`` get exchange suffixes).
        stock: ``"<display name>:<ticker>"`` string.
        interval: Data interval passed through to ``get_stock_data``.
        graph_type: ``'Line Graph'`` for a line chart; anything else draws a
            candlestick chart.
        forecast_method: Model name passed to ``forecast_series``.
        start_date: Start of the history window, passed to ``get_stock_data``.
        end_date: End of the history window, passed to ``get_stock_data``.

    Returns:
        ``(fig, fundamentals)``: the plotly figure and whatever
        ``get_company_info`` returns for the resolved ticker.
    """
    stock_name, ticker_name = stock.split(":")
    ticker_name = _add_exchange_suffix(ticker_dict[idx], ticker_name)

    # Assumes get_stock_data returns a DataFrame with Date/Open/High/Low/Close
    # columns (all are read below) — confirm against data_fetcher.
    series = get_stock_data(ticker_name, interval, start_date, end_date)
    predictions, confidence_intervals = forecast_series(series, model=forecast_method)
    forecast = _build_forecast_frame(series, predictions, confidence_intervals)

    if graph_type == 'Line Graph':
        fig = go.Figure()
        fig.add_trace(go.Scatter(x=series['Date'], y=series['Close'], mode='lines', name='Historical'))
    else:  # Candlestick Graph
        fig = go.Figure(data=[go.Candlestick(x=series['Date'],
                                             open=series['Open'],
                                             high=series['High'],
                                             low=series['Low'],
                                             close=series['Close'],
                                             name='Historical')])
    _add_forecast_traces(fig, forecast)

    fig.update_layout(title=f"Stock Price of {stock_name}",
                      xaxis_title="Date",
                      yaxis_title="Price")

    fundamentals = get_company_info(ticker_name)

    return fig, fundamentals