Reaumur committed on
Commit
a02b141
·
verified ·
1 Parent(s): e030839

Upload 13 files

Browse files
.streamlit/config.toml ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ [theme]
2
+ primaryColor="#254336"
3
+ backgroundColor="#Ffffff"
4
+ secondaryBackgroundColor="#B7B597"
5
+ textColor="#000000"
6
+ font="sans serif"
Nio.png ADDED
Nvidia.png ADDED
SARIMAX_model_NIO.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:049bdbb96f201b05d6ae0952f40a6cbec5a3454337325852faca28e0d97ce8c1
3
+ size 15751906
SARIMAX_model_NVDA.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:076546363600ff1215746c24dc8c956b0c085f2f5679214e2c75b7523bc469b3
3
+ size 15751906
SARIMAX_model_TSLA.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:08932ee8f840d9c4fc209bbe11101dc5da900b103ca5b422d40148e409022084
3
+ size 15751906
StocKnock.png ADDED
StocKnock1.png ADDED
StocKnock2.png ADDED
Tesla.png ADDED
app.py ADDED
@@ -0,0 +1,189 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import pandas as pd
3
+ import yfinance as yf
4
+ import joblib
5
+ from vaderSentiment.vaderSentiment import SentimentIntensityAnalyzer
6
+ import numpy as np
7
+ import plotly.graph_objects as go
8
+ from plotly.subplots import make_subplots
9
+ from PIL import Image
10
+
11
# Streamlit re-executes this script from top to bottom on every widget
# interaction.  The model artefacts are large (each SARIMAX pickle is ~15 MB),
# so they are loaded through ``st.cache_resource`` helpers: the first run pays
# the deserialization cost and later reruns reuse the same objects.


@st.cache_resource
def _load_joblib(path):
    """Deserialize the joblib artefact at *path*, cached across reruns."""
    # NOTE: joblib.load unpickles arbitrary objects; only ship trusted files.
    return joblib.load(path)


@st.cache_resource
def _load_sentiment_analyzer():
    """Build the VADER analyzer once and share it across reruns."""
    return SentimentIntensityAnalyzer()


# Banner and logo images shown in the sidebar / page header.
banner = Image.open("StocKnock.png")
banner1 = Image.open("StocKnock2.png")
Tesla = Image.open('Tesla.png')
NVDA = Image.open('Nvidia.png')
Nio = Image.open('Nio.png')

# Regression pipeline used to predict the next closing price.
model_pipeline = _load_joblib('model_LinReg.pkl')

# One pre-fitted SARIMAX model per supported ticker.
sarima_models = {
    ticker: _load_joblib(f'SARIMAX_model_{ticker}.pkl')
    for ticker in ('TSLA', 'NVDA', 'NIO')
}

# Shared VADER sentiment analyzer used by analyze_sentiment().
sia = _load_sentiment_analyzer()
30
+
31
def analyze_sentiment(text):
    """Score *text* with the module-level VADER analyzer ``sia``.

    Returns the mapping produced by ``sia.polarity_scores``; callers in this
    file read its 'pos', 'neg', 'neu' and 'compound' entries.
    """
    polarity = sia.polarity_scores(text)
    return polarity
33
+
34
def categorize_sentiment(compound_score, threshold=0.05):
    """Map a compound sentiment score to a coarse label.

    Args:
        compound_score: Compound polarity score to classify.
        threshold: Non-negative cut-off.  Scores at or above ``threshold``
            are 'Positive', scores at or below ``-threshold`` are 'Negative'.
            Defaults to 0.05, the value this function previously hard-coded.

    Returns:
        'Positive', 'Negative' or 'Neutral'.
    """
    if compound_score >= threshold:
        return 'Positive'
    if compound_score <= -threshold:
        return 'Negative'
    return 'Neutral'
41
+
42
def get_stock_data(ticker):
    """Download one year of daily price data for *ticker* from Yahoo Finance.

    Args:
        ticker: Ticker symbol understood by yfinance (e.g. 'TSLA').

    Returns:
        A DataFrame of daily rows with flat column names, or ``None`` when
        the download produced no rows.
    """
    # One year of history gives the SARIMA forecast enough context.
    # auto_adjust=False is requested explicitly because callers in this file
    # read the 'Adj Close' column, which yfinance omits when prices are
    # auto-adjusted (the default in recent releases).
    stock_data = yf.download(
        ticker, period='1y', interval='1d', auto_adjust=False
    )
    if stock_data is None or stock_data.empty:
        return None

    # Recent yfinance releases return (field, ticker) MultiIndex columns even
    # for a single ticker; flatten them so stock_data['Open'] yields a Series
    # and row lookups such as latest['Open'] yield scalars.
    if isinstance(stock_data.columns, pd.MultiIndex):
        stock_data.columns = stock_data.columns.get_level_values(0)

    return stock_data
47
+
48
def _average_sentiment(headlines):
    """Return the mean 'pos'/'neg'/'neu'/'compound' scores over *headlines*.

    An empty collection yields all-zero scores instead of raising
    ZeroDivisionError.
    """
    totals = {'pos': 0.0, 'neg': 0.0, 'neu': 0.0, 'compound': 0.0}
    for headline in headlines:
        scores = analyze_sentiment(headline)
        for key in totals:
            totals[key] += scores[key]

    count = len(headlines)
    if count == 0:
        return totals
    return {key: total / count for key, total in totals.items()}


def create_input_df(company, headlines, year=2024):
    """Build the single-row feature frame fed to the regression pipeline.

    Args:
        company: Display name of the company ('Tesla', 'Nvidia' or 'NIO').
        headlines: Sequence of headline strings to score for sentiment.
        year: Calendar year the price history is restricted to.  Defaults to
            2024, the value this function previously hard-coded; pass
            ``None`` to keep every downloaded row.

    Returns:
        ``None`` when the company is unknown, the download is empty, or no
        rows fall in *year*.  Otherwise a tuple ``(input_df, history)`` where
        ``input_df`` is a one-row DataFrame (latest OHLCV values, headline
        count, averaged sentiment scores and sentiment category) and
        ``history`` is the filtered price DataFrame.
    """
    company_ticker = {'Tesla': 'TSLA', 'Nvidia': 'NVDA', 'NIO': 'NIO'}
    ticker = company_ticker.get(company)
    if not ticker:
        return None

    stock_data = get_stock_data(ticker)
    if stock_data is None:
        return None

    # NOTE(review): get_stock_data only fetches the trailing year, so with
    # the default year=2024 this filter eventually matches nothing and the
    # function returns None -- confirm whether the filter is still wanted.
    if year is not None:
        stock_data = stock_data[stock_data.index.year == year]
    if stock_data.empty:
        return None

    latest_stock = stock_data.iloc[-1]
    averages = _average_sentiment(headlines)

    # Key order is preserved from the original implementation so the frame's
    # column order stays the same for the downstream pipeline.
    data = {
        'Company_ID': [ticker],
        'Open': [latest_stock['Open']],
        'High': [latest_stock['High']],
        'Low': [latest_stock['Low']],
        'Close': [latest_stock['Close']],
        'Volume': [latest_stock['Volume']],
        'news_count': [len(headlines)],
        'positive': [averages['pos']],
        'negative': [averages['neg']],
        'neutral': [averages['neu']],
        'compound': [averages['compound']],
        'sentiment_category': [categorize_sentiment(averages['compound'])],
    }

    return pd.DataFrame(data), stock_data
106
+
107
def predict_stock_price(company, headlines):
    """Predict the next close for *company* and plot a 30-step forecast.

    Renders the input feature frame and a Plotly chart (history, predicted
    next close, SARIMA forecast) into the Streamlit page as a side effect.

    Args:
        company: Display name of the company ('Tesla', 'Nvidia' or 'NIO').
        headlines: Sequence of at most 10 headline strings.

    Returns:
        A human-readable string: either the predicted next closing price or
        a message explaining why no prediction was made.
    """
    if len(headlines) > 10:
        return "Please provide up to 10 headlines."

    # create_input_df signals failure with a bare None, so the result must
    # be checked before it is unpacked; unpacking None raises TypeError.
    result = create_input_df(company, headlines)
    if result is None or result[0] is None:
        return "Invalid company selected or no data available for 2024."
    input_df, stock_data_2024 = result

    st.write("Input DataFrame:")
    st.write(input_df)  # Display the input DataFrame for debugging

    # Predict the next closing price
    predicted_next_close = model_pipeline.predict(input_df)[0]

    # Look up the SARIMA model for the ticker chosen by create_input_df.
    ticker = input_df['Company_ID'].iloc[0]
    sarima_model = sarima_models.get(ticker)
    if sarima_model is None:
        return "SARIMA model not available for the selected company."

    history = stock_data_2024['Adj Close']

    # NOTE(review): a single exog value is passed for a 30-step forecast.
    # Whether the fitted model accepts, ignores or rejects this depends on
    # how it was trained, which is not visible here -- confirm against the
    # training code.
    forecast_steps = 30
    forecast_with_predicted = sarima_model.forecast(
        steps=forecast_steps, exog=[predicted_next_close]
    )

    # Future points are placed on business days so the prediction and the
    # forecast do not land on weekends (exchange holidays are not handled).
    future_days = pd.bdate_range(
        start=history.index[-1] + pd.offsets.BDay(1),
        periods=forecast_steps + 1,
    )
    predicted_date = future_days[0]
    forecast_index = future_days[1:]

    # Plot the results
    fig = make_subplots(rows=1, cols=1)
    fig.add_trace(go.Scatter(
        x=history.index, y=history, mode='lines', name='Historical Data'
    ))
    fig.add_trace(go.Scatter(
        x=[predicted_date], y=[predicted_next_close],
        mode='markers', name='Predicted Next Close'
    ))
    fig.add_trace(go.Scatter(
        x=forecast_index, y=forecast_with_predicted,
        mode='lines', name='Forecast'
    ))
    fig.update_layout(
        title=f"SARIMA Forecast for {company}",
        xaxis_title="Date",
        yaxis_title="Price",
    )

    st.plotly_chart(fig)

    return f"Predicted Next Close Price: {predicted_next_close}"
158
+
159
def main():
    """Render the StocKnock page and run a prediction on demand.

    Draws the sidebar and header, collects a company and newline-separated
    headlines, and on "Predict" shows the result of predict_stock_price.
    """
    st.sidebar.image(banner1, use_column_width=True)
    st.sidebar.title("**StocKnock**")
    st.sidebar.write("Welcome to **StocKnock**, where we use sentiment analysis on social media to predict stock prices. Join us for smarter investing!")
    st.sidebar.title("What model do we use?")
    st.sidebar.write("We utilize **Linear Regression** to predict the stock for the next day and **Sarimax** to forecast future stock prices, including the predicted results.")
    st.sidebar.title("Stocks you can predict")
    st.sidebar.write("For the time being, these are the stock that you can predict!")
    st.sidebar.image(Tesla, use_column_width=True)
    st.sidebar.image(NVDA, use_column_width=True)
    st.sidebar.image(Nio, use_column_width=True)
    st.image(banner, use_column_width=True)
    st.title("Stock Price Prediction App")
    st.write("Select a company and provide up to 10 headlines to predict the next stock price based on tweets.")

    company_options = ['Tesla', 'Nvidia', 'NIO']
    company = st.selectbox("Select Company", company_options, key="company_select")

    headlines = st.text_area("Enter Headlines (up to 10 headlines)", key="headlines_input")

    if st.button("Predict", key="predict_button"):
        # One headline per line; blank lines are dropped so they are neither
        # scored nor counted as news items.
        headline_list = [
            line.strip() for line in headlines.splitlines() if line.strip()
        ]
        if not headline_list:
            st.error("Please enter headlines.")
            # Stop here: predicting with no headlines has nothing to score.
            return

        prediction = predict_stock_price(company, headline_list)
        # predict_stock_price reports failures as plain strings too; only
        # the actual prediction message is shown as a success.
        if prediction.startswith("Predicted Next Close Price"):
            st.success(prediction)
        else:
            st.error(prediction)


if __name__ == "__main__":
    main()
model_LinReg.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:16a59828a4054330b584f8462c952c79b70d517ffd261bd84bbfa4b7b814fc22
3
+ size 4174
requirements.txt ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ streamlit
2
+ pandas
3
+ yfinance
4
+ joblib
5
+ vaderSentiment
6
+ numpy
7
+ plotly
8
+ Pillow
9
+ scikit-learn
10
+ statsmodels
11
+ transformers