Kr08 commited on
Commit
4df22ff
·
verified ·
1 Parent(s): bc87714

Update stock_analysis.py

Browse files
Files changed (1) hide show
  1. stock_analysis.py +40 -24
stock_analysis.py CHANGED
@@ -1,26 +1,27 @@
1
  import pandas as pd
 
2
  import plotly.graph_objects as go
3
  from datetime import timedelta
4
  from statsmodels.tsa.arima.model import ARIMA
5
- from config import FORECAST_PERIOD, ticker_dict
6
  from data_fetcher import get_stock_data, get_company_info
7
 
8
  def is_business_day(a_date):
9
  return a_date.weekday() < 5
10
 
11
  def forecast_series(series, model="ARIMA", forecast_period=FORECAST_PERIOD):
12
- predictions = list()
 
 
13
  if series.shape[1] > 1:
14
  series = series['Close'].values.tolist()
15
 
16
  if model == "ARIMA":
17
- for _ in range(forecast_period):
18
- model = ARIMA(series, order=(5, 1, 0))
19
- model_fit = model.fit()
20
- output = model_fit.forecast()
21
- yhat = output[0]
22
- predictions.append(yhat)
23
- series.append(yhat)
24
  elif model == "Prophet":
25
  # Implement Prophet forecasting method
26
  pass
@@ -28,9 +29,9 @@ def forecast_series(series, model="ARIMA", forecast_period=FORECAST_PERIOD):
28
  # Implement LSTM forecasting method
29
  pass
30
 
31
- return predictions
32
 
33
- def get_stock_graph_and_info(idx, stock, interval, graph_type, forecast_method):
34
  stock_name, ticker_name = stock.split(":")
35
 
36
  if ticker_dict[idx] == 'FTSE 100':
@@ -38,27 +39,33 @@ def get_stock_graph_and_info(idx, stock, interval, graph_type, forecast_method):
38
  elif ticker_dict[idx] == 'CAC 40':
39
  ticker_name += '.PA'
40
 
41
- series = get_stock_data(ticker_name, interval)
42
- predictions = forecast_series(series, model=forecast_method)
43
 
44
  last_date = pd.to_datetime(series['Date'].values[-1])
45
- forecast_week = []
46
- i = 1
47
- while len(forecast_week) < FORECAST_PERIOD:
48
- next_date = last_date + timedelta(days=i)
49
- if is_business_day(next_date):
50
- forecast_week.append(next_date)
51
- i += 1
52
-
53
- predictions = predictions[:len(forecast_week)]
54
- forecast_week = forecast_week[:len(predictions)]
55
 
56
- forecast = pd.DataFrame({"Date": forecast_week, "Forecast": predictions})
 
 
 
 
 
57
 
58
  if graph_type == 'Line Graph':
59
  fig = go.Figure()
60
  fig.add_trace(go.Scatter(x=series['Date'], y=series['Close'], mode='lines', name='Historical'))
61
  fig.add_trace(go.Scatter(x=forecast['Date'], y=forecast['Forecast'], mode='lines', name='Forecast'))
 
 
 
 
 
 
 
 
 
62
  else: # Candlestick Graph
63
  fig = go.Figure(data=[go.Candlestick(x=series['Date'],
64
  open=series['Open'],
@@ -67,6 +74,15 @@ def get_stock_graph_and_info(idx, stock, interval, graph_type, forecast_method):
67
  close=series['Close'],
68
  name='Historical')])
69
  fig.add_trace(go.Scatter(x=forecast['Date'], y=forecast['Forecast'], mode='lines', name='Forecast'))
 
 
 
 
 
 
 
 
 
70
 
71
  fig.update_layout(title=f"Stock Price of {stock_name}",
72
  xaxis_title="Date",
 
1
  import pandas as pd
2
+ import numpy as np
3
  import plotly.graph_objects as go
4
  from datetime import timedelta
5
  from statsmodels.tsa.arima.model import ARIMA
6
+ from config import FORECAST_PERIOD, ticker_dict, CONFIDENCE_INTERVAL
7
  from data_fetcher import get_stock_data, get_company_info
8
 
9
  def is_business_day(a_date):
10
  return a_date.weekday() < 5
11
 
12
  def forecast_series(series, model="ARIMA", forecast_period=FORECAST_PERIOD):
13
+ predictions = []
14
+ confidence_intervals = []
15
+
16
  if series.shape[1] > 1:
17
  series = series['Close'].values.tolist()
18
 
19
  if model == "ARIMA":
20
+ model = ARIMA(series, order=(5, 1, 0))
21
+ model_fit = model.fit()
22
+ forecast = model_fit.forecast(steps=forecast_period, alpha=(1 - CONFIDENCE_INTERVAL))
23
+ predictions = forecast.predicted_mean
24
+ confidence_intervals = forecast.conf_int()
 
 
25
  elif model == "Prophet":
26
  # Implement Prophet forecasting method
27
  pass
 
29
  # Implement LSTM forecasting method
30
  pass
31
 
32
+ return predictions, confidence_intervals
33
 
34
+ def get_stock_graph_and_info(idx, stock, interval, graph_type, forecast_method, start_date, end_date):
35
  stock_name, ticker_name = stock.split(":")
36
 
37
  if ticker_dict[idx] == 'FTSE 100':
 
39
  elif ticker_dict[idx] == 'CAC 40':
40
  ticker_name += '.PA'
41
 
42
+ series = get_stock_data(ticker_name, interval, start_date, end_date)
43
+ predictions, confidence_intervals = forecast_series(series, model=forecast_method)
44
 
45
  last_date = pd.to_datetime(series['Date'].values[-1])
46
+ forecast_dates = pd.date_range(start=last_date + timedelta(days=1), periods=FORECAST_PERIOD)
47
+ forecast_dates = [date for date in forecast_dates if is_business_day(date)]
 
 
 
 
 
 
 
 
48
 
49
+ forecast = pd.DataFrame({
50
+ "Date": forecast_dates,
51
+ "Forecast": predictions,
52
+ "Lower_CI": confidence_intervals.iloc[:, 0],
53
+ "Upper_CI": confidence_intervals.iloc[:, 1]
54
+ })
55
 
56
  if graph_type == 'Line Graph':
57
  fig = go.Figure()
58
  fig.add_trace(go.Scatter(x=series['Date'], y=series['Close'], mode='lines', name='Historical'))
59
  fig.add_trace(go.Scatter(x=forecast['Date'], y=forecast['Forecast'], mode='lines', name='Forecast'))
60
+ fig.add_trace(go.Scatter(
61
+ x=forecast['Date'].tolist() + forecast['Date'].tolist()[::-1],
62
+ y=forecast['Upper_CI'].tolist() + forecast['Lower_CI'].tolist()[::-1],
63
+ fill='toself',
64
+ fillcolor='rgba(0,100,80,0.2)',
65
+ line=dict(color='rgba(255,255,255,0)'),
66
+ hoverinfo="skip",
67
+ showlegend=False
68
+ ))
69
  else: # Candlestick Graph
70
  fig = go.Figure(data=[go.Candlestick(x=series['Date'],
71
  open=series['Open'],
 
74
  close=series['Close'],
75
  name='Historical')])
76
  fig.add_trace(go.Scatter(x=forecast['Date'], y=forecast['Forecast'], mode='lines', name='Forecast'))
77
+ fig.add_trace(go.Scatter(
78
+ x=forecast['Date'].tolist() + forecast['Date'].tolist()[::-1],
79
+ y=forecast['Upper_CI'].tolist() + forecast['Lower_CI'].tolist()[::-1],
80
+ fill='toself',
81
+ fillcolor='rgba(0,100,80,0.2)',
82
+ line=dict(color='rgba(255,255,255,0)'),
83
+ hoverinfo="skip",
84
+ showlegend=False
85
+ ))
86
 
87
  fig.update_layout(title=f"Stock Price of {stock_name}",
88
  xaxis_title="Date",