kkastr
condensed the scraper into the main app file. changed the api keys config to use toml. cleaned up the aws files since the app will no longer be deployed there
cbcecfb
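# reddit api credentials are read from ./api_params.toml (see README).
# a minimal sketch of that file, inferred from the keys used in getComments;
# the placeholder values below are assumptions, not real credentials:
#
#   client_id = "YOUR_CLIENT_ID"
#   client_secret = "YOUR_CLIENT_SECRET"
#   user_agent = "YOUR_USER_AGENT"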
import os
import re
import sys

import toml
import praw
import praw.exceptions
import pandas as pd
import gradio as gr

from transformers import pipeline


def chunk(a):
    # split a into n near-equal consecutive slices; the first m slices get
    # one extra element. guard against n == 0 for very short inputs.
    n = max(1, round(0.3 * len(a)))
    k, m = divmod(len(a), n)
    return (a[i * k + min(i, m):(i + 1) * k + min(i + 1, m)] for i in range(n))
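# worked example (illustration only): chunk(list(range(10))) gives
# n = round(0.3 * 10) = 3 slices; divmod(10, 3) = (3, 1), so the first
# slice gets one extra element -> sizes 4, 3, 3.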


def preprocessData(df):
    # strip urls and quoted (">") lines, which add noise without signal
    df["text"] = df["text"].apply(lambda x: re.sub(r"http\S+", "", x, flags=re.M))
    df["text"] = df["text"].apply(lambda x: re.sub(r"^>.+", "", x, flags=re.M))

    # keep only comments scoring at least 5% of the top comment's score
    smax = df.score.max()
    threshold = round(0.05 * smax)
    df = df[df.score >= threshold]

    # empirically, more than 200 comments doesn't change the summary much
    # but slows the summarizer down considerably.
    if len(df.text) >= 200:
        df = df[:200]

    # chunk the comments so no single model input is large enough to crash it
    chunked = list(chunk(df.text))

    return chunked
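# illustration (assumed numbers): 150 comments surviving the score filter
# yield n = round(0.3 * 150) = 45 chunks of 3-4 comments each; each chunk
# is concatenated into a single summarizer input in summarizer() below.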


def getComments(url, debug=False):
    api_keys = toml.load('./api_params.toml')

    reddit = praw.Reddit(
        client_id=api_keys['client_id'],
        client_secret=api_keys['client_secret'],
        user_agent=api_keys['user_agent'],
    )

    try:
        submission = reddit.submission(url=url)
    except praw.exceptions.InvalidURL:
        # without a valid submission there is nothing to scrape, so re-raise
        print("The URL is invalid. Make sure that you have included the submission id.")
        raise

    if debug and os.path.isfile(f'./{submission.id}_comments.csv'):
        # reuse the cached comments instead of hitting the reddit api again
        return pd.read_csv(f"./{submission.id}_comments.csv")

    # flatten the comment tree without fetching "load more comments" stubs
    submission.comments.replace_more(limit=0)
    cols = [
        "text",
        "score",
        "id",
        "parent_id",
        "submission_title",
        "submission_score",
        "submission_id",
    ]
    rows = []
    for comment in submission.comments.list():
        # skip stickied comments (typically moderator notes, not discussion)
        if comment.stickied:
            continue

        data = [
            comment.body,
            comment.score,
            comment.id,
            comment.parent_id,
            submission.title,
            submission.score,
            submission.id,
        ]
        rows.append(data)

    df = pd.DataFrame(data=rows, columns=cols)

    if debug:
        # cache locally so debug runs don't keep re-requesting from reddit
        df.to_csv(f'{submission.id}_comments.csv', index=False)

    return df


def summarizer(url: str, summary_length: str = "Short") -> str:
    # the pushshift.io submission comments api no longer works, so the
    # comments are fetched through praw instead
    df = getComments(url=url)
    chunked_df = preprocessData(df)
    submission_title = df.submission_title.unique()[0]

    nlp = pipeline('summarization', model="model/")

    # summarize each chunk, treating a group of comments as one block of text
    lst_summaries = []
    for grp in chunked_df:
        result = nlp(grp.str.cat(), max_length=500)[0]["summary_text"]
        lst_summaries.append(result)

    stext = ' '.join(lst_summaries).replace(" .", ".")

    if summary_length == "Short":
        # a second pass condenses the chunk summaries into one short summary
        thread_summary = nlp(stext, max_length=500)[0]["summary_text"].replace(" .", ".")
        return submission_title + '\n' + '\n' + thread_summary
    else:
        return submission_title + '\n' + '\n' + stext


if __name__ == "__main__":
    if not os.path.isfile('./api_params.toml'):
        print("""
        Could not find the api params config file in this directory.
        Please create api_params.toml by following the instructions in the README.
        """)
        sys.exit(1)

    with gr.Blocks(css=".gradio-container {max-width: 900px; margin: auto;}") as demo:
        submission_url = gr.Textbox(label='Post URL')
        length_choice = gr.Radio(label='Summary Length', value="Short", choices=["Short", "Long"])
        sub_btn = gr.Button("Summarize")
        summary = gr.Textbox(label='Comment Summary')
        sub_btn.click(fn=summarizer, inputs=[submission_url, length_choice], outputs=summary)

    try:
        demo.launch()
    except KeyboardInterrupt:
        gr.close_all()
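
# usage note (not part of the commit): assuming this file is the app
# entrypoint and a summarization checkpoint exists under model/ as the
# pipeline call expects, run it with python, open the local URL that
# gradio prints, paste a reddit post URL, and click Summarize.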