Sarat Chandra commited on
Commit
991028b
·
1 Parent(s): 8d4f57d

"Added model files"

Browse files
Files changed (7) hide show
  1. Dockerfile +22 -0
  2. app.py +85 -0
  3. model_AAPL.h5 +3 -0
  4. model_AMZN.h5 +3 -0
  5. model_TSLA.h5 +3 -0
  6. requirements.txt +6 -0
  7. templates/index.html +67 -0
Dockerfile ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Use a Python base image
2
+ FROM python:3.9-slim
3
+
4
+ # Set the working directory inside the container
5
+ WORKDIR /code
6
+
7
+ # Copy the required files to the working directory
8
+ COPY app.py .
9
+ COPY ./templates/index.html /code/templates/index.html
10
+ COPY ./requirements.txt /code/requirements.txt
11
+ COPY model_AAPL.h5 .
12
+ COPY model_AMZN.h5 .
13
+ COPY model_TSLA.h5 .
14
+
15
+ # Install the required packages
16
+ RUN pip install --no-cache-dir -r /code/requirements.txt
17
+
18
+ # Expose the port that the Flask app will run on
19
+ EXPOSE 5000
20
+
21
+ # Start the Flask app
22
+ CMD ["python", "app.py"]
app.py ADDED
@@ -0,0 +1,85 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from flask import Flask, render_template, request
2
+ import pandas as pd
3
+ import numpy as np
4
+ import tensorflow as tf
5
+ import yfinance as yf
6
+ import pickle
7
+
8
+ app = Flask(__name__)
9
+
10
+ # Define tickers and load models
11
+ tickers = ['AMZN', 'TSLA', 'AAPL']
12
+ models = {}
13
+ for ticker in tickers:
14
+ model = tf.keras.models.load_model(f'model_{ticker}.h5')
15
+ models[ticker] = model
16
+
17
+ with open('min_max.pickle', 'rb') as handle:
18
+ min_max_scaling = pickle.load(handle)
19
+
20
+ # Function to prepare the data for model input
21
+ def prepare_data(data):
22
+ # Assuming the data is a 1D array of closing prices
23
+ # Reshape the data to have the shape (batch_size, timesteps, features)
24
+ data = np.array(data)
25
+ data = data.reshape(1, data.shape[0], 1)
26
+ return data
27
+
28
+ # Function to get the last 60 days data for a ticker
29
+ def get_last_60_days_data(ticker):
30
+ # Define the end date as yesterday
31
+ end_date = pd.Timestamp.today() - pd.Timedelta(days=1)
32
+ # Define the start date as 120 days before the end date
33
+ start_date = end_date - pd.Timedelta(days=120)
34
+
35
+ # Fetch the stock data using yfinance
36
+ stock_data = yf.download(ticker, start=start_date, end=end_date, progress=False)
37
+
38
+ # Ensure we have enough data (at least 60 days)
39
+ if len(stock_data) < 60:
40
+ return None
41
+
42
+ # Extract the last 60 days 'Close' prices from the stock data
43
+ last_60_days_data = stock_data['Close'].tolist()[-60:]
44
+
45
+ last_60_days_data = (last_60_days_data - min_max_scaling[ticker][0])/(min_max_scaling[ticker][1] - min_max_scaling[ticker][0])
46
+
47
+ return last_60_days_data.tolist()
48
+
49
+ # Function to predict the next day closing value using the model
50
+ def predict_next_day(ticker, data):
51
+ model = models[ticker]
52
+ data = prepare_data(data)
53
+ prediction = model.predict(data)
54
+ return prediction[0]
55
+
56
+ def scale_back_data(data,ticker):
57
+ data = np.array(data)
58
+ data = data * (min_max_scaling[ticker][1] - min_max_scaling[ticker][0]) + min_max_scaling[ticker][0]
59
+ return data.tolist()
60
+
61
+ # @app.route('/')
62
+ # def hello_world():
63
+ # return "Hello World"
64
+
65
+ # Flask route to handle the main page
66
+ @app.route('/', methods=['GET', 'POST'])
67
+ def index():
68
+ if request.method == 'POST':
69
+ selected_ticker = request.form['ticker']
70
+ last_60_days_data = get_last_60_days_data(selected_ticker)
71
+ last_60_days_data_original = get_last_60_days_data(selected_ticker)
72
+ predictions = predict_next_day(selected_ticker, last_60_days_data)
73
+ # for _ in range(10):
74
+ # next_day_prediction = predict_next_day(selected_ticker, last_60_days_data)
75
+ # predictions.append(next_day_prediction)
76
+ # last_60_days_data.append(next_day_prediction)
77
+ # last_60_days_data.pop(0)
78
+ predictions = scale_back_data(predictions,selected_ticker)
79
+ last_60_days_data_original = scale_back_data(last_60_days_data_original,selected_ticker)
80
+ return render_template('index.html', tickers=tickers, selected_ticker=selected_ticker, predictions=predictions, last_60_days_data=last_60_days_data_original)
81
+ else:
82
+ return render_template('index.html', tickers=tickers, selected_ticker=None, predictions=None, last_60_days_data=None)
83
+
84
+ if __name__ == '__main__':
85
+ app.run()
model_AAPL.h5 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a68a119cd48925080649e2f2aba403abeacdc232680b988ca635d21d551e43ed
3
+ size 746232
model_AMZN.h5 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4d1ea08fbc4894f499cfab2ced61c5f7c2975ba719a21cf153bd536a180d2c98
3
+ size 746208
model_TSLA.h5 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a22b259dfbf9d6b60967758b1bcda39ede97538aace25b75b744a489f76cb7c4
3
+ size 746208
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ flask
2
+ plotly
3
+ yfinance
4
+ pandas
5
+ tensorflow
6
+ numpy
templates/index.html ADDED
@@ -0,0 +1,67 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <!DOCTYPE html>
2
+ <html>
3
+ <head>
4
+ <title>Stock Price Prediction</title>
5
+ <!-- Add Plotly.js library -->
6
+ <script src="https://cdn.plot.ly/plotly-latest.min.js"></script>
7
+ </head>
8
+ <body>
9
+ <h1>Stock Price Prediction</h1>
10
+ <h3>A demonstration of docker and flask to predict a particular stock by selecting appropriate model.</h3>
11
+ <h4>Please select a ticker and click predict.</h4>
12
+ <form method="POST">
13
+ <select name="ticker">
14
+ {% for ticker in tickers %}
15
+ <option value="{{ ticker }}" {% if selected_ticker == ticker %}selected{% endif %}>{{ ticker }}</option>
16
+ {% endfor %}
17
+ </select>
18
+ <input type="submit" value="Predict">
19
+ </form>
20
+ {% if predictions and last_60_days_data %}
21
+ <div style="width: 80%; margin: auto;">
22
+ <div id="stockChart"></div>
23
+ </div>
24
+ <script>
25
+ // Flask variables
26
+ var predictions = {{predictions}};
27
+ var last_60_days_data = {{last_60_days_data}};
28
+
29
+
30
+ // Create the Plotly chart
31
+ var trace1 = {
32
+ x: Array.from({ length: 60 }, (_, i) => i + 1),
33
+ y: last_60_days_data,
34
+ type: 'scatter',
35
+ mode: 'lines',
36
+ name: 'Historical Data',
37
+ line: { color: 'rgba(54, 162, 235, 1)' }
38
+ };
39
+ var trace2 = {
40
+ x: Array.from({ length: 10 }, (_, i) => i + 61),
41
+ y: predictions,
42
+ type: 'scatter',
43
+ mode: 'lines',
44
+ name: 'Predictions',
45
+ line: { color: 'rgba(255, 99, 132, 1)', dash: 'dash' }
46
+ };
47
+ var data = [trace1, trace2];
48
+ var layout = {
49
+ xaxis: {
50
+ title: 'Days'
51
+ },
52
+ yaxis: {
53
+ title: 'Stock Price',
54
+ zeroline: true
55
+ },
56
+ legend: {
57
+ x: 0,
58
+ y: 1.0,
59
+ bgcolor: 'rgba(255, 255, 255, 0)',
60
+ bordercolor: 'rgba(255, 255, 255, 0)'
61
+ }
62
+ };
63
+ Plotly.newPlot('stockChart', data, layout);
64
+ </script>
65
+ {% endif %}
66
+ </body>
67
+ </html>