File size: 3,955 Bytes
cbcecfb 3b3dbc9 cbcecfb 00320ff 320952b 9f1606d 320952b 00320ff cbcecfb 3b3dbc9 00320ff 320952b 00320ff 3cc406f 00320ff 320952b 00320ff 1d197a9 320952b 00320ff 9f1606d 320952b 00320ff cbcecfb 00320ff 320952b 00320ff cbcecfb 9f1606d 320952b 00320ff 3b3dbc9 00320ff 3b3dbc9 cbcecfb 3b3dbc9 cbcecfb 00320ff 320952b cbcecfb 9f1606d 3cc406f 3b3dbc9 3cc406f 1d197a9 9f1606d cbcecfb |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 |
import os
import re
import sys
import toml
import praw
import gradio as gr
import pandas as pd
import praw.exceptions
from transformers import pipeline
def chunk(a):
    """Split sequence *a* into roughly-equal consecutive slices.

    The number of slices is ~30% of ``len(a)`` (at least 1), so each slice
    averages 3-4 items. Returns a generator of slices.
    """
    # max(1, ...) guards against ZeroDivisionError: round(0.3 * len(a))
    # is 0 whenever len(a) <= 1.
    n = max(1, round(0.3 * len(a)))
    k, m = divmod(len(a), n)
    # First m slices get k+1 items, the rest get k items.
    return (a[i * k + min(i, m):(i + 1) * k + min(i + 1, m)] for i in range(n))
def preprocessData(df):
    """Clean and trim a comment DataFrame, returning chunked text groups.

    Strips URLs and quoted (">") lines from each comment's text, drops
    comments scoring below 5% of the top comment score, caps the set at
    200 comments, and splits the remaining ``text`` column into chunks
    small enough for the summarization model.

    Returns a list of pandas Series slices (one per chunk).
    """
    # Work on a copy so the caller's DataFrame is not mutated in place.
    df = df.copy()
    df["text"] = df["text"].apply(lambda x: re.sub(r"http\S+", "", x, flags=re.M))
    df["text"] = df["text"].apply(lambda x: re.sub(r"^>.+", "", x, flags=re.M))
    # Keep only comments with at least 5% of the highest comment score.
    threshold = round(0.05 * df.score.max())
    df = df[df.score >= threshold]
    # empirically, having more than 200 comments doesn't change much but
    # slows down the summarizer.
    df = df.head(200)
    # chunking to handle giving the model too large of an input which crashes
    return list(chunk(df.text))
def getComments(url, debug=False):
    """Fetch all comments of a Reddit submission into a DataFrame.

    Parameters
    ----------
    url : str
        Full URL of the Reddit submission (must include the submission id).
    debug : bool
        When True, read from / write to a local CSV cache
        (``<submission_id>_comments.csv``) to avoid sending tons of
        requests to Reddit during development.

    Returns
    -------
    pandas.DataFrame with one row per non-stickied comment.

    Raises
    ------
    praw.exceptions.InvalidURL
        If the URL cannot be resolved to a submission.
    """
    api_keys = toml.load('./api_params.toml')
    reddit = praw.Reddit(
        client_id=api_keys['client_id'],
        client_secret=api_keys['client_secret'],
        user_agent=api_keys['user_agent']
    )
    try:
        submission = reddit.submission(url=url)
    except praw.exceptions.InvalidURL:
        print("The URL is invalid. Make sure that you have included the submission id")
        # Re-raise instead of falling through: the original code continued
        # and crashed with a NameError on the undefined `submission`.
        raise
    if debug and os.path.isfile(f'./{submission.id}_comments.csv'):
        # Serve from the local cache when debugging.
        return pd.read_csv(f"./{submission.id}_comments.csv")
    # limit=0 removes every "load more comments" placeholder so .list()
    # yields the full fetched comment forest.
    submission.comments.replace_more(limit=0)
    cols = [
        "text",
        "score",
        "id",
        "parent_id",
        "submission_title",
        "submission_score",
        "submission_id",
    ]
    rows = [
        [
            comment.body,
            comment.score,
            comment.id,
            comment.parent_id,
            submission.title,
            submission.score,
            submission.id,
        ]
        for comment in submission.comments.list()
        if not comment.stickied  # skip pinned mod/bot comments
    ]
    df = pd.DataFrame(data=rows, columns=cols)
    if debug:
        # save for debugging to avoid sending tons of requests to reddit
        df.to_csv(f'{submission.id}_comments.csv', index=False)
    return df
def summarizer(url: str, summary_length: str = "Short") -> str:
    """Summarize the comments of a Reddit thread.

    Fetches the thread's comments, cleans and chunks them, summarizes each
    chunk with the local transformer model, and — for the "Short" setting —
    condenses the chunk summaries into a single final summary. Returns the
    submission title followed by the summary text.
    """
    # pushshift.io submission comments api doesn't work so have to use praw
    df = getComments(url=url)
    chunked_df = preprocessData(df)
    title = df.submission_title.unique()[0]
    nlp = pipeline('summarization', model="model/")
    # Each chunk of comments is concatenated and summarized as one block of text.
    partial_summaries = [
        nlp(group.str.cat(), max_length=500)[0]["summary_text"]
        for group in chunked_df
    ]
    combined = ' '.join(partial_summaries).replace(" .", ".")
    if summary_length != "Short":
        return f"{title}\n\n{combined}"
    # "Short": run one more summarization pass over the joined chunk summaries.
    short = nlp(combined, max_length=500)[0]["summary_text"].replace(" .", ".")
    return f"{title}\n\n{short}"
if __name__ == "__main__":
    # Abort early with a helpful message when the API credentials file is absent.
    if not os.path.isfile('./api_params.toml'):
        print("""
Could not find api params config file in directory.
Please create api_params.toml by following the instructions in the README.
""")
        sys.exit(1)
    # Minimal Gradio UI: a post URL and a summary-length choice in,
    # the generated comment summary out.
    with gr.Blocks(css=".gradio-container {max-width: 900px; margin: auto;}") as demo:
        submission_url = gr.Textbox(label='Post URL')
        length_choice = gr.Radio(label='Summary Length', value="Short", choices=["Short", "Long"])
        sub_btn = gr.Button("Summarize")
        summary = gr.Textbox(label='Comment Summary')
        # Wire the button to the summarizer function defined above.
        sub_btn.click(fn=summarizer, inputs=[submission_url, length_choice], outputs=summary)
    try:
        demo.launch()
    except KeyboardInterrupt:
        # Shut down any open Gradio servers on Ctrl-C.
        gr.close_all()
|