Blaise-g commited on
Commit
102cc89
Β·
1 Parent(s): a7a06e0

Create new file

Browse files
Files changed (1) hide show
  1. app.py +288 -0
app.py ADDED
@@ -0,0 +1,288 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import time
3
+ from pathlib import Path
4
+
5
+ import gradio as gr
6
+ import nltk
7
+ from cleantext import clean
8
+
9
+ from summarize import load_model_and_tokenizer, summarize_via_tokenbatches
10
+ from utils import load_example_filenames, truncate_word_count
11
+
12
+ _here = Path(__file__).parent
13
+
14
+ nltk.download("stopwords") # TODO=find where this requirement originates from
15
+
16
+ logging.basicConfig(
17
+ level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
18
+ )
19
+
20
+
21
+ def proc_submission(
22
+ input_text: str,
23
+ model_size: str,
24
+ num_beams,
25
+ token_batch_length,
26
+ length_penalty,
27
+ repetition_penalty,
28
+ no_repeat_ngram_size,
29
+ max_input_length: int = 768,
30
+ ):
31
+ """
32
+ proc_submission - a helper function for the gradio module to process submissions
33
+
34
+ Args:
35
+ input_text (str): the input text to summarize
36
+ model_size (str): the size of the model to use
37
+ num_beams (int): the number of beams to use
38
+ token_batch_length (int): the length of the token batches to use
39
+ length_penalty (float): the length penalty to use
40
+ repetition_penalty (float): the repetition penalty to use
41
+ no_repeat_ngram_size (int): the no repeat ngram size to use
42
+ max_input_length (int, optional): the maximum input length to use. Defaults to 768.
43
+
44
+ Returns:
45
+ str in HTML format, string of the summary, str of score
46
+ """
47
+
48
+ settings = {
49
+ "length_penalty": float(length_penalty),
50
+ "repetition_penalty": float(repetition_penalty),
51
+ "no_repeat_ngram_size": int(no_repeat_ngram_size),
52
+ "encoder_no_repeat_ngram_size": 4,
53
+ "num_beams": int(num_beams),
54
+ "min_length": 4,
55
+ "max_length": int(token_batch_length // 4),
56
+ "early_stopping": True,
57
+ "do_sample": False,
58
+ }
59
+ st = time.perf_counter()
60
+ history = {}
61
+ clean_text = clean(input_text, lower=False)
62
+ max_input_length = 2048 if model_size == "base" else max_input_length
63
+ processed = truncate_word_count(clean_text, max_input_length)
64
+
65
+ if processed["was_truncated"]:
66
+ tr_in = processed["truncated_text"]
67
+ msg = f"Input text was truncated to {max_input_length} words (based on whitespace)"
68
+ logging.warning(msg)
69
+ history["WARNING"] = msg
70
+ else:
71
+ tr_in = input_text
72
+ msg = None
73
+
74
+ _summaries = summarize_via_tokenbatches(
75
+ tr_in,
76
+ model_sm if model_size == "base" else model,
77
+ tokenizer_sm if model_size == "base" else tokenizer,
78
+ batch_length=token_batch_length,
79
+ **settings,
80
+ )
81
+ sum_text = [f"Section {i}: " + s["summary"][0] for i, s in enumerate(_summaries)]
82
+ sum_scores = [
83
+ f" - Section {i}: {round(s['summary_score'],4)}"
84
+ for i, s in enumerate(_summaries)
85
+ ]
86
+
87
+ sum_text_out = "\n".join(sum_text)
88
+ history["Summary Scores"] = "<br><br>"
89
+ scores_out = "\n".join(sum_scores)
90
+ rt = round((time.perf_counter() - st) / 60, 2)
91
+ print(f"Runtime: {rt} minutes")
92
+ html = ""
93
+ html += f"<p>Runtime: {rt} minutes on CPU</p>"
94
+ if msg is not None:
95
+ html += f"<h2>WARNING:</h2><hr><b>{msg}</b><br><br>"
96
+
97
+ html += ""
98
+
99
+ return html, sum_text_out, scores_out
100
+
101
+
102
+ def load_single_example_text(
103
+ example_path: str or Path,
104
+ ):
105
+ """
106
+ load_single_example - a helper function for the gradio module to load examples
107
+ Returns:
108
+ list of str, the examples
109
+ """
110
+ global name_to_path
111
+ full_ex_path = name_to_path[example_path]
112
+ full_ex_path = Path(full_ex_path)
113
+ # load the examples into a list
114
+ with open(full_ex_path, "r", encoding="utf-8", errors="ignore") as f:
115
+ raw_text = f.read()
116
+ text = clean(raw_text, lower=False)
117
+ return text
118
+
119
+
120
+ def load_uploaded_file(file_obj):
121
+ """
122
+ load_uploaded_file - process an uploaded file
123
+ Args:
124
+ file_obj (POTENTIALLY list): Gradio file object inside a list
125
+ Returns:
126
+ str, the uploaded file contents
127
+ """
128
+
129
+ # file_path = Path(file_obj[0].name)
130
+
131
+ # check if mysterious file object is a list
132
+ if isinstance(file_obj, list):
133
+ file_obj = file_obj[0]
134
+ file_path = Path(file_obj.name)
135
+ try:
136
+ with open(file_path, "r", encoding="utf-8", errors="ignore") as f:
137
+ raw_text = f.read()
138
+ text = clean(raw_text, lower=False)
139
+ return text
140
+ except Exception as e:
141
+ logging.info(f"Trying to load file with path {file_path}, error: {e}")
142
+ return "Error: Could not read file. Ensure that it is a valid text file with encoding UTF-8."
143
+
144
+
145
+ if __name__ == "__main__":
146
+
147
+ model, tokenizer = load_model_and_tokenizer("pszemraj/led-large-book-summary")
148
+ model_sm, tokenizer_sm = load_model_and_tokenizer("pszemraj/led-base-book-summary")
149
+
150
+ name_to_path = load_example_filenames(_here / "examples")
151
+ logging.info(f"Loaded {len(name_to_path)} examples")
152
+ demo = gr.Blocks()
153
+
154
+ with demo:
155
+
156
+ gr.Markdown("# Long-Form Summarization: LED & BookSum")
157
+ gr.Markdown(
158
+ "A simple demo using a fine-tuned LED model to summarize long-form text. See [model card](https://huggingface.co/pszemraj/led-large-book-summary) for a notebook with GPU inference (much faster) on Colab."
159
+ )
160
+ with gr.Column():
161
+
162
+ gr.Markdown("## Load Inputs & Select Parameters")
163
+ gr.Markdown(
164
+ "Enter text below in the text area. The text will be summarized [using the selected parameters](https://huggingface.co/blog/how-to-generate). Optionally load an example below or upload a file."
165
+ )
166
+ with gr.Row():
167
+ model_size = gr.Radio(
168
+ choices=["base", "large"], label="Model Variant", value="large"
169
+ )
170
+ num_beams = gr.Radio(
171
+ choices=[2, 3, 4],
172
+ label="Beam Search: # of Beams",
173
+ value=2,
174
+ )
175
+ gr.Markdown(
176
+ "_The base model is less performant than the large model, but is faster and will accept up to 2048 words per input (Large model accepts up to 768)._"
177
+ )
178
+ with gr.Row():
179
+ length_penalty = gr.inputs.Slider(
180
+ minimum=0.5,
181
+ maximum=1.0,
182
+ label="length penalty",
183
+ default=0.7,
184
+ step=0.05,
185
+ )
186
+ token_batch_length = gr.Radio(
187
+ choices=[512, 768, 1024],
188
+ label="token batch length",
189
+ value=512,
190
+ )
191
+
192
+ with gr.Row():
193
+ repetition_penalty = gr.inputs.Slider(
194
+ minimum=1.0,
195
+ maximum=5.0,
196
+ label="repetition penalty",
197
+ default=3.5,
198
+ step=0.1,
199
+ )
200
+ no_repeat_ngram_size = gr.Radio(
201
+ choices=[2, 3, 4],
202
+ label="no repeat ngram size",
203
+ value=3,
204
+ )
205
+ with gr.Row():
206
+ example_name = gr.Dropdown(
207
+ list(name_to_path.keys()),
208
+ label="Choose an Example",
209
+ )
210
+ load_examples_button = gr.Button(
211
+ "Load Example",
212
+ )
213
+ input_text = gr.Textbox(
214
+ lines=6,
215
+ label="Input Text (for summarization)",
216
+ placeholder="Enter text to summarize, the text will be cleaned and truncated on Spaces. Narrative, academic (both papers and lecture transcription), and article text work well. May take a bit to generate depending on the input text :)",
217
+ )
218
+ gr.Markdown("Upload your own file:")
219
+ with gr.Row():
220
+ uploaded_file = gr.File(
221
+ label="Upload a text file",
222
+ file_count="single",
223
+ type="file",
224
+ )
225
+ load_file_button = gr.Button("Load Uploaded File")
226
+
227
+ gr.Markdown("---")
228
+
229
+ with gr.Column():
230
+ gr.Markdown("## Generate Summary")
231
+ gr.Markdown(
232
+ "Summary generation should take approximately 1-2 minutes for most settings."
233
+ )
234
+ summarize_button = gr.Button(
235
+ "Summarize!",
236
+ variant="primary",
237
+ )
238
+
239
+ output_text = gr.HTML("<p><em>Output will appear below:</em></p>")
240
+ gr.Markdown("### Summary Output")
241
+ summary_text = gr.Textbox(
242
+ label="Summary", placeholder="The generated summary will appear here"
243
+ )
244
+ gr.Markdown(
245
+ "The summary scores can be thought of as representing the quality of the summary. less-negative numbers (closer to 0) are better:"
246
+ )
247
+ summary_scores = gr.Textbox(
248
+ label="Summary Scores", placeholder="Summary scores will appear here"
249
+ )
250
+
251
+ gr.Markdown("---")
252
+
253
+ with gr.Column():
254
+ gr.Markdown("## About the Model")
255
+ gr.Markdown(
256
+ "- [This model](https://huggingface.co/pszemraj/led-large-book-summary) is a fine-tuned checkpoint of [allenai/led-large-16384](https://huggingface.co/allenai/led-large-16384) on the [BookSum dataset](https://arxiv.org/abs/2105.08209).The goal was to create a model that can generalize well and is useful in summarizing lots of text in academic and daily usage."
257
+ )
258
+ gr.Markdown(
259
+ "- The two most important parameters-empirically-are the `num_beams` and `token_batch_length`. However, increasing these will also increase the amount of time it takes to generate a summary. The `length_penalty` and `repetition_penalty` parameters are also important for the model to generate good summaries."
260
+ )
261
+ gr.Markdown(
262
+ "- The model can be used with tag [pszemraj/led-large-book-summary](https://huggingface.co/pszemraj/led-large-book-summary). See the model card for details on usage & a notebook for a tutorial."
263
+ )
264
+ gr.Markdown("---")
265
+
266
+ load_examples_button.click(
267
+ fn=load_single_example_text, inputs=[example_name], outputs=[input_text]
268
+ )
269
+
270
+ load_file_button.click(
271
+ fn=load_uploaded_file, inputs=uploaded_file, outputs=[input_text]
272
+ )
273
+
274
+ summarize_button.click(
275
+ fn=proc_submission,
276
+ inputs=[
277
+ input_text,
278
+ model_size,
279
+ num_beams,
280
+ token_batch_length,
281
+ length_penalty,
282
+ repetition_penalty,
283
+ no_repeat_ngram_size,
284
+ ],
285
+ outputs=[output_text, summary_text, summary_scores],
286
+ )
287
+
288
+ demo.launch(enable_queue=True, share=True)