import os
import re
import sys
import nltk
import praw
import matplotlib
from tqdm import tqdm
import gradio as gr
import pandas as pd
import praw.exceptions
import matplotlib.pyplot as plt
from wordcloud import WordCloud
from transformers import pipeline
matplotlib.use('Agg')
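# Assumption: the 'punkt' tokenizer data used by nltk.word_tokenize below may not be
# installed in the runtime environment; this download is a no-op if it is already present.
nltk.download('punkt', quiet=True)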
def index_chunk(a):
    # Currently unused alternative to sentence_chunk: splits `a` into ~0.3 * len(a) near-equal chunks.
    n = round(0.3 * len(a))
    k, m = divmod(len(a), n)
    return (a[i * k + min(i, m):(i + 1) * k + min(i + 1, m)] for i in range(n))
def sentence_chunk(a):
    # Greedily pack items into buffers of roughly 512 tokens each, so no single
    # chunk overflows the summarization model's input.
    sentences = []
    buffer = ""
    # the 512 token threshold is empirical
    for item in a:
        token_length_estimation = len(nltk.word_tokenize(buffer + item))
        if token_length_estimation > 512:
            sentences.append(buffer)
            buffer = ""
        buffer += item
    sentences.append(buffer)
    return sentences
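# For illustration (hypothetical sizes): a thread of a few hundred short comments
# typically collapses into a handful of ~512-token strings, each of which is later
# passed to the summarizer as a single block.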
def preprocessData(df):
    # strip URLs and quoted ("> ...") lines from comment bodies
    df["text"] = df["text"].apply(lambda x: re.sub(r"http\S+", "", x, flags=re.M))
    df["text"] = df["text"].apply(lambda x: re.sub(r"^>.+", "", x, flags=re.M))
    # The df is sorted by comment score.
    # Empirically, having more than ~100 comments barely changes the summary but slows the summarizer down.
    # The slowdown is not present with the load API, but it still seems good to limit low-score comments.
    if len(df.text) >= 128:
        df = df[:128]
    # chunk the comments so the model is never given an input large enough to crash it
    chunked = sentence_chunk(df.text)
    return chunked
def getComments(url, debug=False):
    if debug and os.path.isfile('./debug_comments.csv'):
        df = pd.read_csv("./debug_comments.csv")
        return df
    client_id = os.environ['REDDIT_CLIENT_ID']
    client_secret = os.environ['REDDIT_CLIENT_SECRET']
    user_agent = os.environ['REDDIT_USER_AGENT']
    reddit = praw.Reddit(client_id=client_id, client_secret=client_secret, user_agent=user_agent)
    try:
        submission = reddit.submission(url=url)
    except praw.exceptions.InvalidURL:
        print("The URL is invalid. Make sure that you have included the submission id")
        # re-raise so we don't fall through with `submission` unbound
        raise
    submission.comments.replace_more(limit=0)
    cols = [
        "text",
        "score",
        "id",
        "parent_id",
        "submission_title",
        "submission_score",
        "submission_id",
    ]
    rows = []
    for comment in submission.comments.list():
        if comment.stickied:
            continue
        data = [
            comment.body,
            comment.score,
            comment.id,
            comment.parent_id,
            submission.title,
            submission.score,
            submission.id,
        ]
        rows.append(data)
    df = pd.DataFrame(data=rows, columns=cols)
    if debug:
        # save for debugging to avoid sending tons of requests to reddit
        df.to_csv('debug_comments.csv', index=False)
    return df
def summarizer(url: str):
    # The pushshift.io submission comments API doesn't work, so PRAW is used instead.
    df = getComments(url=url)
    submission_title = '## ' + df.submission_title.unique()[0]
    chunked_df = preprocessData(df)
    text = ' '.join(chunked_df)
    # for a transparent background: background_color=None, mode='RGBA'
    wc_opts = dict(collocations=False, width=1920, height=1080)
    wcloud = WordCloud(**wc_opts).generate(text)
    plt.imshow(wcloud, aspect='auto')
    plt.axis("off")
    plt.gca().set_position([0, 0, 1, 1])
    plt.autoscale(tight=True)
    fig = plt.gcf()
    fig.patch.set_alpha(0.0)
    fig.set_size_inches((12, 7))
    lst_summaries = []
    for grp in tqdm(chunked_df):
        # treat each group of comments as one block of text
        result = sum_api(grp)
        lst_summaries.append(result)
    long_output = ' '.join(lst_summaries).replace(" .", ".")
    short_output = sum_api(long_output).replace(" .", ".")
    sentiment = clf_api(short_output)
    return submission_title, short_output, long_output, sentiment, fig
if __name__ == "__main__":
    sum_model = "models/sshleifer/distilbart-cnn-12-6"
    clf_model = "models/finiteautomata/bertweet-base-sentiment-analysis"
    hf_token = os.environ["HF_TOKEN"]
    sum_api = gr.Interface.load(sum_model, api_key=hf_token)
    clf_api = gr.Interface.load(clf_model, api_key=hf_token)
    with gr.Blocks(css=".gradio-container {max-width: 900px !important; width: 100%}") as demo:
        submission_url = gr.Textbox(label='Post URL')
        sub_btn = gr.Button("Summarize")
        title = gr.Markdown("")
        with gr.Row():
            short_summary = gr.Textbox(label='Short Summary')
            summary_sentiment = gr.Label(label='Sentiment')
        thread_cloud = gr.Plot(label='Word Cloud')
        long_summary = gr.Textbox(label='Long Summary')
        sub_btn.click(fn=summarizer,
                      inputs=[submission_url],
                      outputs=[title, short_summary, long_summary, summary_sentiment, thread_cloud])
    try:
        demo.launch()
    except KeyboardInterrupt:
        gr.close_all()