import os
import re
import sys
import toml
import praw
import gradio as gr
import pandas as pd
import praw.exceptions
from transformers import pipeline


def chunk(a):
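    """Split a sequence into roughly 30% as many near-equally sized groups.

    A sketch of the behavior: a 10-item input yields n = 3 groups of
    sizes 4, 3 and 3.
    """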
    n = max(1, round(0.3 * len(a)))  # max() guards against n == 0 (ZeroDivisionError) on very short inputs
    k, m = divmod(len(a), n)
    return (a[i * k + min(i, m):(i + 1) * k + min(i + 1, m)] for i in range(n))


def preprocessData(df):
    # strip URLs and quoted (">") lines, which only add noise for the summarizer
    df["text"] = df["text"].apply(lambda x: re.sub(r"http\S+", "", x, flags=re.M))
    df["text"] = df["text"].apply(lambda x: re.sub(r"^>.+", "", x, flags=re.M))

    # keep only comments scoring at least 5% of the top comment's score
    threshold = round(0.05 * df.score.max())
    df = df[df.score >= threshold]

    # empirically, having more than 200 comments doesn't change much but slows down the summarizer.
    if len(df.text) >= 200:
        df = df[:200]

    # chunk the input so a single oversized block of text doesn't crash the model
    chunked = list(chunk(df.text))

    return chunked


def getComments(url, debug=False):
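    """Fetch a submission's comments into a DataFrame via praw.

    Credentials are read from ./api_params.toml, which (per the README) is
    expected to contain the three keys used below, e.g. with placeholder values:

        client_id = "YOUR_CLIENT_ID"
        client_secret = "YOUR_CLIENT_SECRET"
        user_agent = "script:reddit-comment-summarizer:v0.1 (by u/YOUR_USERNAME)"
    """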

    api_keys = toml.load('./api_params.toml')

    reddit = praw.Reddit(
        client_id=api_keys['client_id'],
        client_secret=api_keys['client_secret'],
        user_agent=api_keys['user_agent']
    )

    try:
        # a valid URL looks like https://www.reddit.com/r/<subreddit>/comments/<id>/<slug>/
        submission = reddit.submission(url=url)
    except praw.exceptions.InvalidURL:
        print("The URL is invalid. Make sure that you have included the submission id.")
        raise  # don't fall through with `submission` unbound

    if debug and os.path.isfile(f'./{submission.id}_comments.csv'):
        # reuse the cached copy instead of re-requesting the thread from reddit
        return pd.read_csv(f"./{submission.id}_comments.csv")

    # resolve "load more comments" stubs so comments.list() returns only real comments
    submission.comments.replace_more(limit=0)

    cols = [
        "text",
        "score",
        "id",
        "parent_id",
        "submission_title",
        "submission_score",
        "submission_id"
    ]
    rows = []

    for comment in submission.comments.list():

        # skip stickied comments (usually moderator notes), which aren't part of the discussion
        if comment.stickied:
            continue

        data = [
            comment.body,
            comment.score,
            comment.id,
            comment.parent_id,
            submission.title,
            submission.score,
            submission.id,
        ]

        rows.append(data)

    df = pd.DataFrame(data=rows, columns=cols)

    if debug:
        # save for debugging to avoid sending tons of requests to reddit
        df.to_csv(f'{submission.id}_comments.csv', index=False)

    return df


def summarizer(url: str, summary_length: str = "Short") -> str:
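    """Summarize the comments of a reddit thread.

    Returns the submission title followed by either a second summarization
    pass over the chunk summaries ("Short") or their full concatenation ("Long").
    """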

    # the pushshift.io submission comments api no longer works, so praw is used instead
    df = getComments(url=url)
    chunked_df = preprocessData(df)

    submission_title = df.submission_title.unique()[0]

    # "model/" is the locally saved summarization checkpoint; note it is reloaded on every request
    nlp = pipeline('summarization', model="model/")

    lst_summaries = []

    for grp in chunked_df:
        # treat each group of comments as one block of text; sep=' ' keeps adjacent comments from fusing
        result = nlp(grp.str.cat(sep=' '), max_length=500)[0]["summary_text"]
        lst_summaries.append(result)

    # clean up the " ." artifacts the model emits before periods
    stext = ' '.join(lst_summaries).replace(" .", ".")

    if summary_length == "Short":
        # second pass: condense the concatenated chunk summaries into one short summary
        thread_summary = nlp(stext, max_length=500)[0]["summary_text"].replace(" .", ".")
        return submission_title + '\n\n' + thread_summary
    else:
        return submission_title + '\n\n' + stext


if __name__ == "__main__":
    if not os.path.isfile('./api_params.toml'):
        print("""
                Could not find api params config file in directory.
                Please create api_params.toml by following the instructions in the README.
              """)
        sys.exit(1)

    with gr.Blocks(css=".gradio-container {max-width: 900px; margin: auto;}") as demo:
        submission_url = gr.Textbox(label='Post URL')

        length_choice = gr.Radio(label='Summary Length', value="Short", choices=["Short", "Long"])

        sub_btn = gr.Button("Summarize")

        summary = gr.Textbox(label='Comment Summary')

        sub_btn.click(fn=summarizer, inputs=[submission_url, length_choice], outputs=summary)

    try:
        demo.launch()
    except KeyboardInterrupt:
        gr.close_all()