sfmajors commited on
Commit
d563a73
·
1 Parent(s): 9bbed42

Fixing referencing

Browse files
TSLASentimentAnalyzer/.ipynb_checkpoints/app-checkpoint.py ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import pandas as pd
3
+
4
+ import matplotlib.pyplot as plt
5
+ from scraper import RedditScraper
6
+ import pandas as pd
7
+ from classifier import predict
8
+ from config import settings
9
+ from transformers import pipeline
10
+ from loguru import logger
11
+
12
+ reddit = RedditScraper()
13
+ st.title("$TSLA Market Sentiment Analyzer using r/TSLA Subreddit")
14
+
15
+
16
+ def load_data(number, scraping_option):
17
+ st.write("loading new data")
18
+ # st.write(scraping_option)
19
+ comments = []
20
+ for submission in scraping_option(number):
21
+ comments.extend(reddit.get_comment_forest(submission.comments))
22
+ logger.debug(
23
+ submission.title,
24
+ submission.num_comments,
25
+ len(reddit.get_comment_forest(submission.comments)),
26
+ )
27
+ df = pd.DataFrame(comments)
28
+
29
+ return df
30
+
31
+
32
+ def select_scrap_type(option):
33
+ if option == "Hot":
34
+ st.write("Selected Hot submissions")
35
+ return reddit.get_hot
36
+ if option == "Rising":
37
+ st.write("Selected rising submissions")
38
+ return reddit.get_rising
39
+ if option == "New":
40
+ st.write("Selected new submissions")
41
+ return reddit.get_new
42
+
43
+
44
+ st.info(
45
+ "Option has been deactivated as the same submissions were scraped because the subreddit is not too active"
46
+ )
47
+ select = st.selectbox("choose option", ["Hot", "Rising", "New"], disabled=True)
48
+
49
+
50
+ number = st.number_input("Insert a number", step=1, max_value=30, min_value=3)
51
+
52
+
53
+ sentiment_pipeline = pipeline("sentiment-analysis", settings.model_path)
54
+
55
+ data = load_data(number, select_scrap_type("Hot"))
56
+
57
+
58
+ if st.button("Analyze"):
59
+ results = sentiment_pipeline(list(data["comment"]))
60
+ data["label"] = [res["label"] for res in results]
61
+ data["sentiment_score"] = [res["score"] for res in results]
62
+ st.write(data.groupby("label").count())
63
+ sizes = list(data.groupby("label").count()["comment"])
64
+ labels = "Negative", "Positive"
65
+ fig1, ax1 = plt.subplots()
66
+ ax1.pie(sizes, labels=labels, autopct="%1.1f%%", shadow=True, startangle=90)
67
+ ax1.axis("equal")
68
+ st.pyplot(fig1)
69
+
70
+
71
+ st.write(data)
TSLASentimentAnalyzer/.ipynb_checkpoints/classifier-checkpoint.py ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ from transformers import pipeline
2
+ # sentiment_pipeline = pipeline("sentiment-analysis")
3
+ # data = ["I love you", "I hate you"]
4
+ def predict(data, custom_model: str ="finiteautomata/bertweet-base-sentiment-analysis"):
5
+ sentiment_pipeline = pipeline("sentiment-analysis")
6
+ return sentiment_pipeline(data)
7
+
TSLASentimentAnalyzer/.ipynb_checkpoints/config-checkpoint.py ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Set
2
+
3
+ from pydantic import (
4
+ BaseModel,
5
+ BaseSettings,
6
+ RedisDsn,
7
+ PostgresDsn,
8
+ AmqpDsn,
9
+ Field,
10
+ )
11
+
12
+
13
+ class Settings(BaseSettings):
14
+
15
+ reddit_api_client_id: str
16
+ reddit_api_client_secret: str
17
+ stock_data_api_key: str
18
+ reddit_api_user_agent: str = "USERAGENT"
19
+ model_path: str = "fourthbrain-demo/model_trained_by_me2"
20
+ class Config:
21
+ env_file = ".env" # defaults to no prefix, i.e. ""
22
+
23
+
24
+ settings = Settings()
TSLASentimentAnalyzer/.ipynb_checkpoints/scraper-checkpoint.py ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import praw
2
+ from config import settings
3
+ from praw.models import MoreComments
4
+ from loguru import logger
5
+
6
+
7
+ class RedditScraper:
8
+ def __init__(self, subreddit: str = "TSLA"):
9
+ reddit = praw.Reddit(
10
+ client_id=settings.reddit_api_client_id,
11
+ client_secret=settings.reddit_api_client_secret,
12
+ user_agent=settings.reddit_api_user_agent,
13
+ )
14
+ self.subreddit = reddit.subreddit(subreddit)
15
+
16
+ def get_hot(self, posts: int = 10):
17
+ return self.subreddit.hot(limit=posts)
18
+
19
+ def get_new(self, posts: int = 10):
20
+ return self.subreddit.new(limit=posts)
21
+
22
+ def get_rising(self, posts: int = 10):
23
+ return self.subreddit.rising(limit=posts)
24
+
25
+ def get_top(self, posts: int = 10):
26
+ return self.subreddit.top(limit=posts)
27
+
28
+ def get_top_comments(self, submission, threshold: int = 5):
29
+ return [
30
+ comment.body
31
+ for comment in submission.comments
32
+ if comment.score >= threshold
33
+ ]
34
+
35
+ def get_comment_forest(self, comment_forest, all_comments=[]):
36
+ all_comments = []
37
+ if isinstance(comment_forest, MoreComments):
38
+ comments_list = comment_forest.comments()
39
+ else:
40
+ comments_list = comment_forest.list()
41
+ logger.debug(str(comment_forest), len(comments_list))
42
+ for comment in comments_list:
43
+ if isinstance(comment, MoreComments):
44
+ logger.info("more comments")
45
+ logger.debug(self.get_comment_forest(comment))
46
+ continue
47
+ item = {}
48
+ item["comment"] = comment.body
49
+ item["title"] = comment.submission.title
50
+ item["id"] = comment.id
51
+ item["created_at"] = int(comment.created_utc)
52
+ item["score"] = comment.score
53
+ all_comments.append(item)
54
+ return all_comments
55
+ if comment_forest.list():
56
+ for reply in comment_forest:
57
+ all_comments.append(reply)
58
+ return self.get_comment_forest(reply.replies, all_comments)
59
+ return all_comments
TSLASentimentAnalyzer/README.md ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: Demo
3
+ emoji: 📉
4
+ colorFrom: gray
5
+ colorTo: green
6
+ sdk: streamlit
7
+ sdk_version: 1.10.0
8
+ app_file: app.py
9
+ pinned: false
10
+ license: apache-2.0
11
+ ---
12
+
13
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
14
+
15
+
16
+ # quick start
17
+
18
+ 1. install requirements with `pip install -r requirements`
19
+ 2. create a .env file in the root directory and add the following variables:
20
+ `REDDIT_API_CLIENT_ID` : the client ID of your reddit app
21
+ `REDDIT_API_CLIENT_SECRET`: the client secret of your reddit app
22
+ follow this tutorial to generate them. <https://www.jcchouinard.com/get-reddit-api-credentials-with-praw/>
23
+
24
+ 3. run the streamlit app using : `streamlit run app.py`
25
+
26
+ # scraping
27
+
28
+ the app use praw library to scrape submissions from reddit. A class named `scraper.RedditScraper` implements and abstracts that feature.
29
+
30
+ # sentiment analysis model
31
+ The model used in the application is a fine-tuned BERT-based model trained with labeled data scraped from TSLA subbreddit using a script that uses the scraping module `scraper.RedditScraper`.
32
+
33
+ The data is available in <https://huggingface.co/datasets/fourthbrain-demo/reddit-comments-demo> , availalbe in 2 versions (used later with DVC), and splitted into train/test datasets.
34
+
35
+
TSLASentimentAnalyzer/__pycache__/app.cpython-310.pyc ADDED
Binary file (2.37 kB). View file
 
TSLASentimentAnalyzer/__pycache__/classifier.cpython-310.pyc ADDED
Binary file (535 Bytes). View file
 
TSLASentimentAnalyzer/__pycache__/config.cpython-310.pyc ADDED
Binary file (999 Bytes). View file
 
TSLASentimentAnalyzer/__pycache__/scraper.cpython-310.pyc ADDED
Binary file (2.47 kB). View file
 
TSLASentimentAnalyzer/app.py ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import pandas as pd
3
+
4
+ import matplotlib.pyplot as plt
5
+ from scraper import RedditScraper
6
+ import pandas as pd
7
+ from classifier import predict
8
+ from config import settings
9
+ from transformers import pipeline
10
+ from loguru import logger
11
+
12
+ reddit = RedditScraper()
13
+ st.title("$TSLA Market Sentiment Analyzer using r/TSLA Subreddit")
14
+
15
+
16
+ def load_data(number, scraping_option):
17
+ st.write("loading new data")
18
+ # st.write(scraping_option)
19
+ comments = []
20
+ for submission in scraping_option(number):
21
+ comments.extend(reddit.get_comment_forest(submission.comments))
22
+ logger.debug(
23
+ submission.title,
24
+ submission.num_comments,
25
+ len(reddit.get_comment_forest(submission.comments)),
26
+ )
27
+ df = pd.DataFrame(comments)
28
+
29
+ return df
30
+
31
+
32
+ def select_scrap_type(option):
33
+ if option == "Hot":
34
+ st.write("Selected Hot submissions")
35
+ return reddit.get_hot
36
+ if option == "Rising":
37
+ st.write("Selected rising submissions")
38
+ return reddit.get_rising
39
+ if option == "New":
40
+ st.write("Selected new submissions")
41
+ return reddit.get_new
42
+
43
+
44
+ st.info(
45
+ "Option has been deactivated as the same submissions were scraped because the subreddit is not too active"
46
+ )
47
+ select = st.selectbox("choose option", ["Hot", "Rising", "New"], disabled=True)
48
+
49
+
50
+ number = st.number_input("Insert a number", step=1, max_value=30, min_value=3)
51
+
52
+
53
+ sentiment_pipeline = pipeline("sentiment-analysis", settings.model_path)
54
+
55
+ data = load_data(number, select_scrap_type("Hot"))
56
+
57
+
58
+ if st.button("Analyze"):
59
+ results = sentiment_pipeline(list(data["comment"]))
60
+ data["label"] = [res["label"] for res in results]
61
+ data["sentiment_score"] = [res["score"] for res in results]
62
+ st.write(data.groupby("label").count())
63
+ sizes = list(data.groupby("label").count()["comment"])
64
+ labels = "Negative", "Positive"
65
+ fig1, ax1 = plt.subplots()
66
+ ax1.pie(sizes, labels=labels, autopct="%1.1f%%", shadow=True, startangle=90)
67
+ ax1.axis("equal")
68
+ st.pyplot(fig1)
69
+
70
+
71
+ st.write(data)
TSLASentimentAnalyzer/classifier.py ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ from transformers import pipeline
2
+ # sentiment_pipeline = pipeline("sentiment-analysis")
3
+ # data = ["I love you", "I hate you"]
4
+ def predict(data, custom_model: str ="finiteautomata/bertweet-base-sentiment-analysis"):
5
+ sentiment_pipeline = pipeline("sentiment-analysis")
6
+ return sentiment_pipeline(data)
7
+
TSLASentimentAnalyzer/config.py ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Set
2
+
3
+ from pydantic import (
4
+ BaseModel,
5
+ BaseSettings,
6
+ RedisDsn,
7
+ PostgresDsn,
8
+ AmqpDsn,
9
+ Field,
10
+ )
11
+
12
+
13
+ class Settings(BaseSettings):
14
+
15
+ reddit_api_client_id: str
16
+ reddit_api_client_secret: str
17
+ stock_data_api_key: str
18
+ reddit_api_user_agent: str = "USERAGENT"
19
+ model_path: str = "fourthbrain-demo/model_trained_by_me2"
20
+ class Config:
21
+ env_file = ".env" # defaults to no prefix, i.e. ""
22
+
23
+
24
+ settings = Settings()
TSLASentimentAnalyzer/scraper.py ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import praw
2
+ from config import settings
3
+ from praw.models import MoreComments
4
+ from loguru import logger
5
+
6
+
7
+ class RedditScraper:
8
+ def __init__(self, subreddit: str = "TSLA"):
9
+ reddit = praw.Reddit(
10
+ client_id=settings.reddit_api_client_id,
11
+ client_secret=settings.reddit_api_client_secret,
12
+ user_agent=settings.reddit_api_user_agent,
13
+ )
14
+ self.subreddit = reddit.subreddit(subreddit)
15
+
16
+ def get_hot(self, posts: int = 10):
17
+ return self.subreddit.hot(limit=posts)
18
+
19
+ def get_new(self, posts: int = 10):
20
+ return self.subreddit.new(limit=posts)
21
+
22
+ def get_rising(self, posts: int = 10):
23
+ return self.subreddit.rising(limit=posts)
24
+
25
+ def get_top(self, posts: int = 10):
26
+ return self.subreddit.top(limit=posts)
27
+
28
+ def get_top_comments(self, submission, threshold: int = 5):
29
+ return [
30
+ comment.body
31
+ for comment in submission.comments
32
+ if comment.score >= threshold
33
+ ]
34
+
35
+ def get_comment_forest(self, comment_forest, all_comments=[]):
36
+ all_comments = []
37
+ if isinstance(comment_forest, MoreComments):
38
+ comments_list = comment_forest.comments()
39
+ else:
40
+ comments_list = comment_forest.list()
41
+ logger.debug(str(comment_forest), len(comments_list))
42
+ for comment in comments_list:
43
+ if isinstance(comment, MoreComments):
44
+ logger.info("more comments")
45
+ logger.debug(self.get_comment_forest(comment))
46
+ continue
47
+ item = {}
48
+ item["comment"] = comment.body
49
+ item["title"] = comment.submission.title
50
+ item["id"] = comment.id
51
+ item["created_at"] = int(comment.created_utc)
52
+ item["score"] = comment.score
53
+ all_comments.append(item)
54
+ return all_comments
55
+ if comment_forest.list():
56
+ for reply in comment_forest:
57
+ all_comments.append(reply)
58
+ return self.get_comment_forest(reply.replies, all_comments)
59
+ return all_comments
TSLASentimentAnalyzer/sentiment_data.csv ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ ,timestamp,counter,close,volume,sentiment_score,close_lag1,perc_change_close,sentiment_score_lag1,perc_change_sentiment,sentiment_SMA3mo
2
+ 1,2022-07-09,625,0.0,,0.9639244723320007,752.28998,-1.0,0.980573832988739,-0.016979201460018443,0.0
3
+ 2,2022-07-10,324,0.0,,0.9888325300481584,0.0,0.0,0.9639244723320007,0.0258402586832329,0.9777769451229661
4
+ 3,2022-07-11,121,703.03003,33080400,0.9755014126951044,0.0,inf,0.9888325300481584,-0.01348167353718104,0.9760861383584212
5
+ 4,2022-07-12,9,699.21002,29310300,0.9687565366427103,703.03003,-0.005433637024011655,0.9755014126951044,-0.006914265796662843,0.9776968264619911
6
+ 5,2022-07-13,196,711.12,32651500,0.991240360907146,699.21002,0.017033480155218626,0.9687565366427103,0.023208952315671386,0.9784994367483201
7
+ 6,2022-07-14,100,714.94,26185800,0.9773943841457366,711.12,0.005371807852401916,0.991240360907146,-0.01396833432885843,0.9791304272318643
8
+ 7,2022-07-15,49,0.0,,0.9558297651154655,714.94,-1.0,0.9773943841457366,-0.022063375214826014,0.9748215033894493
9
+ 8,2022-07-16,64,0.0,,0.9682549461722374,0.0,0.0,0.9558297651154655,0.012999366111256171,0.9671596984778131
10
+ 9,2022-07-17,121,0.0,,0.9894618229432539,0.0,0.0,0.9682549461722374,0.02190216208536067,0.9711821780769855
app.py CHANGED
@@ -13,7 +13,7 @@ import time
13
  from plotly.subplots import make_subplots
14
 
15
  # Read CSV file into pandas and extract timestamp data
16
- dfSentiment = pd.read_csv("../TSLASentimentAnalyzer/sentiment_data.csv")
17
  dfSentiment['timestamp'] = [datetime.strptime(dt, '%Y-%m-%d') for dt in dfSentiment['timestamp'].tolist()]
18
 
19
  # Multi-select columns to build chart
 
13
  from plotly.subplots import make_subplots
14
 
15
  # Read CSV file into pandas and extract timestamp data
16
+ dfSentiment = pd.read_csv("./TSLASentimentAnalyzer/sentiment_data.csv")
17
  dfSentiment['timestamp'] = [datetime.strptime(dt, '%Y-%m-%d') for dt in dfSentiment['timestamp'].tolist()]
18
 
19
  # Multi-select columns to build chart