Upload 13 files
Browse files- .streamlit/config.toml +6 -0
- Nio.png +0 -0
- Nvidia.png +0 -0
- SARIMAX_model_NIO.pkl +3 -0
- SARIMAX_model_NVDA.pkl +3 -0
- SARIMAX_model_TSLA.pkl +3 -0
- StocKnock.png +0 -0
- StocKnock1.png +0 -0
- StocKnock2.png +0 -0
- Tesla.png +0 -0
- app.py +189 -0
- model_LinReg.pkl +3 -0
- requirements.txt +11 -0
.streamlit/config.toml
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
[theme]
|
2 |
+
primaryColor="#254336"
|
3 |
+
backgroundColor="#Ffffff"
|
4 |
+
secondaryBackgroundColor="#B7B597"
|
5 |
+
textColor="#000000"
|
6 |
+
font="sans serif"
|
Nio.png
ADDED
![]() |
Nvidia.png
ADDED
![]() |
SARIMAX_model_NIO.pkl
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:049bdbb96f201b05d6ae0952f40a6cbec5a3454337325852faca28e0d97ce8c1
|
3 |
+
size 15751906
|
SARIMAX_model_NVDA.pkl
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:076546363600ff1215746c24dc8c956b0c085f2f5679214e2c75b7523bc469b3
|
3 |
+
size 15751906
|
SARIMAX_model_TSLA.pkl
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:08932ee8f840d9c4fc209bbe11101dc5da900b103ca5b422d40148e409022084
|
3 |
+
size 15751906
|
StocKnock.png
ADDED
![]() |
StocKnock1.png
ADDED
![]() |
StocKnock2.png
ADDED
![]() |
Tesla.png
ADDED
![]() |
app.py
ADDED
@@ -0,0 +1,189 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import streamlit as st
|
2 |
+
import pandas as pd
|
3 |
+
import yfinance as yf
|
4 |
+
import joblib
|
5 |
+
from vaderSentiment.vaderSentiment import SentimentIntensityAnalyzer
|
6 |
+
import numpy as np
|
7 |
+
import plotly.graph_objects as go
|
8 |
+
from plotly.subplots import make_subplots
|
9 |
+
from PIL import Image
|
10 |
+
|
11 |
+
# Load the banner image
|
12 |
+
banner = Image.open("StocKnock.png")
|
13 |
+
banner1 = Image.open("StocKnock2.png")
|
14 |
+
Tesla = Image.open('Tesla.png')
|
15 |
+
NVDA = Image.open('Nvidia.png')
|
16 |
+
Nio = Image.open('Nio.png')
|
17 |
+
|
18 |
+
# Load the model pipeline
|
19 |
+
model_pipeline = joblib.load('model_LinReg.pkl')
|
20 |
+
|
21 |
+
# Load SARIMA models for each company
|
22 |
+
sarima_models = {
|
23 |
+
'TSLA': joblib.load('SARIMAX_model_TSLA.pkl'),
|
24 |
+
'NVDA': joblib.load('SARIMAX_model_NVDA.pkl'),
|
25 |
+
'NIO': joblib.load('SARIMAX_model_NIO.pkl')
|
26 |
+
}
|
27 |
+
|
28 |
+
# Initialize VADER sentiment analyzer
|
29 |
+
sia = SentimentIntensityAnalyzer()
|
30 |
+
|
31 |
+
def analyze_sentiment(text):
|
32 |
+
return sia.polarity_scores(text)
|
33 |
+
|
34 |
+
def categorize_sentiment(compound_score):
|
35 |
+
if compound_score >= 0.05:
|
36 |
+
return 'Positive'
|
37 |
+
elif compound_score <= -0.05:
|
38 |
+
return 'Negative'
|
39 |
+
else:
|
40 |
+
return 'Neutral'
|
41 |
+
|
42 |
+
def get_stock_data(ticker):
|
43 |
+
stock_data = yf.download(ticker, period='1y', interval='1d') # Get 1 year of data for better SARIMA forecasting
|
44 |
+
if stock_data.empty:
|
45 |
+
return None
|
46 |
+
return stock_data
|
47 |
+
|
48 |
+
def create_input_df(company, headlines):
|
49 |
+
company_ticker = {'Tesla': 'TSLA', 'Nvidia': 'NVDA', 'NIO': 'NIO'}
|
50 |
+
ticker = company_ticker.get(company)
|
51 |
+
if not ticker:
|
52 |
+
return None
|
53 |
+
|
54 |
+
stock_data = get_stock_data(ticker)
|
55 |
+
if stock_data is None:
|
56 |
+
return None
|
57 |
+
|
58 |
+
# Filter stock data to include only entries from 2024
|
59 |
+
stock_data_2024 = stock_data[stock_data.index.year == 2024]
|
60 |
+
if stock_data_2024.empty:
|
61 |
+
return None
|
62 |
+
|
63 |
+
latest_stock = stock_data_2024.iloc[-1]
|
64 |
+
|
65 |
+
data = {
|
66 |
+
'Company_ID': [ticker],
|
67 |
+
'Open': [latest_stock['Open']],
|
68 |
+
'High': [latest_stock['High']],
|
69 |
+
'Low': [latest_stock['Low']],
|
70 |
+
'Close': [latest_stock['Close']],
|
71 |
+
'Volume': [latest_stock['Volume']],
|
72 |
+
'news_count': [len(headlines)]
|
73 |
+
}
|
74 |
+
|
75 |
+
# Initialize sentiment scores
|
76 |
+
pos_score = neg_score = neu_score = compound_score = 0
|
77 |
+
|
78 |
+
# Calculate sentiment scores for each headline
|
79 |
+
for headline in headlines:
|
80 |
+
sentiment = analyze_sentiment(headline)
|
81 |
+
pos_score += sentiment['pos']
|
82 |
+
neg_score += sentiment['neg']
|
83 |
+
neu_score += sentiment['neu']
|
84 |
+
compound_score += sentiment['compound']
|
85 |
+
|
86 |
+
# Calculate average sentiment scores
|
87 |
+
num_headlines = len(headlines)
|
88 |
+
avg_pos_score = pos_score / num_headlines
|
89 |
+
avg_neg_score = neg_score / num_headlines
|
90 |
+
avg_neu_score = neu_score / num_headlines
|
91 |
+
avg_compound_score = compound_score / num_headlines
|
92 |
+
|
93 |
+
# Categorize sentiment based on the average compound score
|
94 |
+
sentiment_category = categorize_sentiment(avg_compound_score)
|
95 |
+
|
96 |
+
# Add sentiment scores and category to the data dictionary
|
97 |
+
data.update({
|
98 |
+
'positive': [avg_pos_score],
|
99 |
+
'negative': [avg_neg_score],
|
100 |
+
'neutral': [avg_neu_score],
|
101 |
+
'compound': [avg_compound_score],
|
102 |
+
'sentiment_category': [sentiment_category]
|
103 |
+
})
|
104 |
+
|
105 |
+
return pd.DataFrame(data), stock_data_2024
|
106 |
+
|
107 |
+
def predict_stock_price(company, headlines):
|
108 |
+
if len(headlines) > 10:
|
109 |
+
return "Please provide up to 10 headlines."
|
110 |
+
|
111 |
+
input_df, stock_data_2024 = create_input_df(company, headlines)
|
112 |
+
if input_df is None:
|
113 |
+
return "Invalid company selected or no data available for 2024."
|
114 |
+
|
115 |
+
st.write("Input DataFrame:")
|
116 |
+
st.write(input_df) # Display the input DataFrame for debugging
|
117 |
+
|
118 |
+
# Predict the next closing price
|
119 |
+
predicted_next_close = model_pipeline.predict(input_df)[0]
|
120 |
+
|
121 |
+
# Perform SARIMA forecast
|
122 |
+
ticker = input_df['Company_ID'][0]
|
123 |
+
sarima_model = sarima_models.get(ticker)
|
124 |
+
if sarima_model is None:
|
125 |
+
return "SARIMA model not available for the selected company."
|
126 |
+
|
127 |
+
# Prepare data for SARIMA forecast with predicted value
|
128 |
+
history_with_predicted = stock_data_2024['Adj Close']
|
129 |
+
future_with_predicted = np.append(history_with_predicted, predicted_next_close)
|
130 |
+
|
131 |
+
# Prepare data for SARIMA forecast without predicted value
|
132 |
+
history_without_predicted = stock_data_2024['Adj Close']
|
133 |
+
|
134 |
+
# Forecast future prices with predicted value
|
135 |
+
forecast_steps = 30
|
136 |
+
forecast_with_predicted = sarima_model.forecast(steps=forecast_steps, exog=[predicted_next_close])
|
137 |
+
|
138 |
+
# Plot the results
|
139 |
+
fig = make_subplots(rows=1, cols=1)
|
140 |
+
|
141 |
+
# Historical data
|
142 |
+
fig.add_trace(go.Scatter(x=history_without_predicted.index, y=history_without_predicted, mode='lines', name='Historical Data'))
|
143 |
+
|
144 |
+
# Predicted next close price
|
145 |
+
predicted_date = history_with_predicted.index[-1] + pd.Timedelta(days=1)
|
146 |
+
fig.add_trace(go.Scatter(x=[predicted_date], y=[predicted_next_close], mode='markers', name='Predicted Next Close'))
|
147 |
+
|
148 |
+
# Forecast data with predicted value
|
149 |
+
forecast_index_with_predicted = [predicted_date + pd.Timedelta(days=i) for i in range(1, forecast_steps + 1)]
|
150 |
+
forecast_with_predicted_line = go.Scatter(x=forecast_index_with_predicted, y=forecast_with_predicted, mode='lines', name='Forecast')
|
151 |
+
fig.add_trace(forecast_with_predicted_line)
|
152 |
+
|
153 |
+
fig.update_layout(title=f"SARIMA Forecast for {company}", xaxis_title="Date", yaxis_title="Price")
|
154 |
+
|
155 |
+
st.plotly_chart(fig)
|
156 |
+
|
157 |
+
return f"Predicted Next Close Price: {predicted_next_close}"
|
158 |
+
|
159 |
+
def main():
|
160 |
+
st.sidebar.image(banner1, use_column_width=True)
|
161 |
+
st.sidebar.title("**StocKnock**")
|
162 |
+
st.sidebar.write("Welcome to **StocKnock**, where we use sentiment analysis on social media to predict stock prices. Join us for smarter investing!")
|
163 |
+
st.sidebar.title("What model do we use?")
|
164 |
+
st.sidebar.write("We utilize **Linear Regression** to predict the stock for the next day and **Sarimax** to forecast future stock prices, including the predicted results.")
|
165 |
+
st.sidebar.title("Stocks you can predict")
|
166 |
+
st.sidebar.write("For the time being, these are the stock that you can predict!")
|
167 |
+
st.sidebar.image(Tesla, use_column_width=True)
|
168 |
+
st.sidebar.image(NVDA, use_column_width=True)
|
169 |
+
st.sidebar.image(Nio, use_column_width=True)
|
170 |
+
st.image(banner, use_column_width=True)
|
171 |
+
st.title("Stock Price Prediction App")
|
172 |
+
st.write("Select a company and provide up to 10 headlines to predict the next stock price based on tweets.")
|
173 |
+
|
174 |
+
company_options = ['Tesla', 'Nvidia', 'NIO']
|
175 |
+
company = st.selectbox("Select Company", company_options, key="company_select")
|
176 |
+
|
177 |
+
headlines = st.text_area("Enter Headlines (up to 10 headlines)", key="headlines_input")
|
178 |
+
|
179 |
+
if st.button("Predict", key="predict_button"):
|
180 |
+
if headlines:
|
181 |
+
headlines = headlines.split("\n")
|
182 |
+
else:
|
183 |
+
st.error("Please enter headlines.")
|
184 |
+
|
185 |
+
prediction = predict_stock_price(company, headlines)
|
186 |
+
st.success(prediction)
|
187 |
+
|
188 |
+
if __name__ == "__main__":
|
189 |
+
main()
|
model_LinReg.pkl
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:16a59828a4054330b584f8462c952c79b70d517ffd261bd84bbfa4b7b814fc22
|
3 |
+
size 4174
|
requirements.txt
ADDED
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
streamlit
|
2 |
+
pandas
|
3 |
+
yfinance
|
4 |
+
joblib
|
5 |
+
vaderSentiment
|
6 |
+
numpy
|
7 |
+
plotly
|
8 |
+
Pillow
|
9 |
+
scikit-learn
|
10 |
+
statsmodels
|
11 |
+
transformers
|