awacke1 commited on
Commit
cce4162
·
1 Parent(s): 2d3e15c

Create new file

Browse files
Files changed (1) hide show
  1. app.py +286 -0
app.py ADDED
@@ -0,0 +1,286 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import time
3
+ from pathlib import Path
4
+
5
+ import gradio as gr
6
+ import nltk
7
+ from cleantext import clean
8
+
9
+ from summarize import load_model_and_tokenizer, summarize_via_tokenbatches
10
+ from utils import load_example_filenames, truncate_word_count
11
+
12
+ _here = Path(__file__).parent
13
+
14
+ nltk.download("stopwords") # TODO=find where this requirement originates from
15
+
16
+ logging.basicConfig(
17
+ level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
18
+ )
19
+
20
+
21
+ def proc_submission(
22
+ input_text: str,
23
+ model_size: str,
24
+ num_beams,
25
+ token_batch_length,
26
+ length_penalty,
27
+ repetition_penalty,
28
+ no_repeat_ngram_size,
29
+ max_input_length: int = 768,
30
+ ):
31
+ """
32
+ proc_submission - a helper function for the gradio module to process submissions
33
+ Args:
34
+ input_text (str): the input text to summarize
35
+ model_size (str): the size of the model to use
36
+ num_beams (int): the number of beams to use
37
+ token_batch_length (int): the length of the token batches to use
38
+ length_penalty (float): the length penalty to use
39
+ repetition_penalty (float): the repetition penalty to use
40
+ no_repeat_ngram_size (int): the no repeat ngram size to use
41
+ max_input_length (int, optional): the maximum input length to use. Defaults to 768.
42
+ Returns:
43
+ str in HTML format, string of the summary, str of score
44
+ """
45
+
46
+ settings = {
47
+ "length_penalty": float(length_penalty),
48
+ "repetition_penalty": float(repetition_penalty),
49
+ "no_repeat_ngram_size": int(no_repeat_ngram_size),
50
+ "encoder_no_repeat_ngram_size": 4,
51
+ "num_beams": int(num_beams),
52
+ "min_length": 4,
53
+ "max_length": int(token_batch_length // 4),
54
+ "early_stopping": True,
55
+ "do_sample": False,
56
+ }
57
+ st = time.perf_counter()
58
+ history = {}
59
+ clean_text = clean(input_text, lower=False)
60
+ max_input_length = 2048 if model_size == "base" else max_input_length
61
+ processed = truncate_word_count(clean_text, max_input_length)
62
+
63
+ if processed["was_truncated"]:
64
+ tr_in = processed["truncated_text"]
65
+ msg = f"Input text was truncated to {max_input_length} words (based on whitespace)"
66
+ logging.warning(msg)
67
+ history["WARNING"] = msg
68
+ else:
69
+ tr_in = input_text
70
+ msg = None
71
+
72
+ _summaries = summarize_via_tokenbatches(
73
+ tr_in,
74
+ model_sm if model_size == "base" else model,
75
+ tokenizer_sm if model_size == "base" else tokenizer,
76
+ batch_length=token_batch_length,
77
+ **settings,
78
+ )
79
+ sum_text = [f"Section {i}: " + s["summary"][0] for i, s in enumerate(_summaries)]
80
+ sum_scores = [
81
+ f" - Section {i}: {round(s['summary_score'],4)}"
82
+ for i, s in enumerate(_summaries)
83
+ ]
84
+
85
+ sum_text_out = "\n".join(sum_text)
86
+ history["Summary Scores"] = "<br><br>"
87
+ scores_out = "\n".join(sum_scores)
88
+ rt = round((time.perf_counter() - st) / 60, 2)
89
+ print(f"Runtime: {rt} minutes")
90
+ html = ""
91
+ html += f"<p>Runtime: {rt} minutes on CPU</p>"
92
+ if msg is not None:
93
+ html += f"<h2>WARNING:</h2><hr><b>{msg}</b><br><br>"
94
+
95
+ html += ""
96
+
97
+ return html, sum_text_out, scores_out
98
+
99
+
100
+ def load_single_example_text(
101
+ example_path: str or Path,
102
+ ):
103
+ """
104
+ load_single_example - a helper function for the gradio module to load examples
105
+ Returns:
106
+ list of str, the examples
107
+ """
108
+ global name_to_path
109
+ full_ex_path = name_to_path[example_path]
110
+ full_ex_path = Path(full_ex_path)
111
+ # load the examples into a list
112
+ with open(full_ex_path, "r", encoding="utf-8", errors="ignore") as f:
113
+ raw_text = f.read()
114
+ text = clean(raw_text, lower=False)
115
+ return text
116
+
117
+
118
+ def load_uploaded_file(file_obj):
119
+ """
120
+ load_uploaded_file - process an uploaded file
121
+ Args:
122
+ file_obj (POTENTIALLY list): Gradio file object inside a list
123
+ Returns:
124
+ str, the uploaded file contents
125
+ """
126
+
127
+ # file_path = Path(file_obj[0].name)
128
+
129
+ # check if mysterious file object is a list
130
+ if isinstance(file_obj, list):
131
+ file_obj = file_obj[0]
132
+ file_path = Path(file_obj.name)
133
+ try:
134
+ with open(file_path, "r", encoding="utf-8", errors="ignore") as f:
135
+ raw_text = f.read()
136
+ text = clean(raw_text, lower=False)
137
+ return text
138
+ except Exception as e:
139
+ logging.info(f"Trying to load file with path {file_path}, error: {e}")
140
+ return "Error: Could not read file. Ensure that it is a valid text file with encoding UTF-8."
141
+
142
+
143
+ if __name__ == "__main__":
144
+
145
+ model, tokenizer = load_model_and_tokenizer("pszemraj/led-large-book-summary")
146
+ model_sm, tokenizer_sm = load_model_and_tokenizer("pszemraj/led-base-book-summary")
147
+
148
+ name_to_path = load_example_filenames(_here / "examples")
149
+ logging.info(f"Loaded {len(name_to_path)} examples")
150
+ demo = gr.Blocks()
151
+
152
+ with demo:
153
+
154
+ gr.Markdown("# Long-Form Summarization: LED & BookSum")
155
+ gr.Markdown(
156
+ "A simple demo using a fine-tuned LED model to summarize long-form text. See [model card](https://huggingface.co/pszemraj/led-large-book-summary) for a notebook with GPU inference (much faster) on Colab."
157
+ )
158
+ with gr.Column():
159
+
160
+ gr.Markdown("## Load Inputs & Select Parameters")
161
+ gr.Markdown(
162
+ "Enter text below in the text area. The text will be summarized [using the selected parameters](https://huggingface.co/blog/how-to-generate). Optionally load an example below or upload a file."
163
+ )
164
+ with gr.Row():
165
+ model_size = gr.Radio(
166
+ choices=["base", "large"], label="Model Variant", value="large"
167
+ )
168
+ num_beams = gr.Radio(
169
+ choices=[2, 3, 4],
170
+ label="Beam Search: # of Beams",
171
+ value=2,
172
+ )
173
+ gr.Markdown(
174
+ "_The base model is less performant than the large model, but is faster and will accept up to 2048 words per input (Large model accepts up to 768)._"
175
+ )
176
+ with gr.Row():
177
+ length_penalty = gr.inputs.Slider(
178
+ minimum=0.5,
179
+ maximum=1.0,
180
+ label="length penalty",
181
+ default=0.7,
182
+ step=0.05,
183
+ )
184
+ token_batch_length = gr.Radio(
185
+ choices=[512, 768, 1024],
186
+ label="token batch length",
187
+ value=512,
188
+ )
189
+
190
+ with gr.Row():
191
+ repetition_penalty = gr.inputs.Slider(
192
+ minimum=1.0,
193
+ maximum=5.0,
194
+ label="repetition penalty",
195
+ default=3.5,
196
+ step=0.1,
197
+ )
198
+ no_repeat_ngram_size = gr.Radio(
199
+ choices=[2, 3, 4],
200
+ label="no repeat ngram size",
201
+ value=3,
202
+ )
203
+ with gr.Row():
204
+ example_name = gr.Dropdown(
205
+ list(name_to_path.keys()),
206
+ label="Choose an Example",
207
+ )
208
+ load_examples_button = gr.Button(
209
+ "Load Example",
210
+ )
211
+ input_text = gr.Textbox(
212
+ lines=6,
213
+ label="Input Text (for summarization)",
214
+ placeholder="Enter text to summarize, the text will be cleaned and truncated on Spaces. Narrative, academic (both papers and lecture transcription), and article text work well. May take a bit to generate depending on the input text :)",
215
+ )
216
+ gr.Markdown("Upload your own file:")
217
+ with gr.Row():
218
+ uploaded_file = gr.File(
219
+ label="Upload a text file",
220
+ file_count="single",
221
+ type="file",
222
+ )
223
+ load_file_button = gr.Button("Load Uploaded File")
224
+
225
+ gr.Markdown("---")
226
+
227
+ with gr.Column():
228
+ gr.Markdown("## Generate Summary")
229
+ gr.Markdown(
230
+ "Summary generation should take approximately 1-2 minutes for most settings."
231
+ )
232
+ summarize_button = gr.Button(
233
+ "Summarize!",
234
+ variant="primary",
235
+ )
236
+
237
+ output_text = gr.HTML("<p><em>Output will appear below:</em></p>")
238
+ gr.Markdown("### Summary Output")
239
+ summary_text = gr.Textbox(
240
+ label="Summary", placeholder="The generated summary will appear here"
241
+ )
242
+ gr.Markdown(
243
+ "The summary scores can be thought of as representing the quality of the summary. less-negative numbers (closer to 0) are better:"
244
+ )
245
+ summary_scores = gr.Textbox(
246
+ label="Summary Scores", placeholder="Summary scores will appear here"
247
+ )
248
+
249
+ gr.Markdown("---")
250
+
251
+ with gr.Column():
252
+ gr.Markdown("## About the Model")
253
+ gr.Markdown(
254
+ "- [This model](https://huggingface.co/pszemraj/led-large-book-summary) is a fine-tuned checkpoint of [allenai/led-large-16384](https://huggingface.co/allenai/led-large-16384) on the [BookSum dataset](https://arxiv.org/abs/2105.08209).The goal was to create a model that can generalize well and is useful in summarizing lots of text in academic and daily usage."
255
+ )
256
+ gr.Markdown(
257
+ "- The two most important parameters-empirically-are the `num_beams` and `token_batch_length`. However, increasing these will also increase the amount of time it takes to generate a summary. The `length_penalty` and `repetition_penalty` parameters are also important for the model to generate good summaries."
258
+ )
259
+ gr.Markdown(
260
+ "- The model can be used with tag [pszemraj/led-large-book-summary](https://huggingface.co/pszemraj/led-large-book-summary). See the model card for details on usage & a notebook for a tutorial."
261
+ )
262
+ gr.Markdown("---")
263
+
264
+ load_examples_button.click(
265
+ fn=load_single_example_text, inputs=[example_name], outputs=[input_text]
266
+ )
267
+
268
+ load_file_button.click(
269
+ fn=load_uploaded_file, inputs=uploaded_file, outputs=[input_text]
270
+ )
271
+
272
+ summarize_button.click(
273
+ fn=proc_submission,
274
+ inputs=[
275
+ input_text,
276
+ model_size,
277
+ num_beams,
278
+ token_batch_length,
279
+ length_penalty,
280
+ repetition_penalty,
281
+ no_repeat_ngram_size,
282
+ ],
283
+ outputs=[output_text, summary_text, summary_scores],
284
+ )
285
+
286
+ demo.launch(enable_queue=True, share=True)