Spaces:
Running
on
CPU Upgrade
Running
on
CPU Upgrade
File size: 13,080 Bytes
0cef1e2 fe0e9af 9bc2923 53cfd2d 01d78f2 fe0e9af 0cef1e2 fe0e9af 01d78f2 fe0e9af 0cef1e2 fe0e9af 01d78f2 0cef1e2 01d78f2 fe0e9af 0cef1e2 66e7228 fe0e9af 0cef1e2 fe0e9af 4dc1508 d3f22a6 0cef1e2 4dc1508 fe0e9af 0cef1e2 fe0e9af aa3c57c fe0e9af aa3c57c b9e8529 3247bd6 fe0e9af 0cef1e2 53cfd2d fe0e9af 01d78f2 fe0e9af 0cef1e2 9bc2923 fe0e9af 0cef1e2 01d78f2 fe0e9af 3094ba9 0cef1e2 66e7228 0cef1e2 3b66adc 0cef1e2 3b66adc fe0e9af 42e4d6f 01d78f2 42e4d6f b1e0e58 0cef1e2 01d78f2 53cfd2d 01d78f2 9bc2923 fe0e9af 01d78f2 fe0e9af edb05fc 4dc1508 f8734a4 4dc1508 11da24b 4dc1508 f8734a4 4dc1508 f8734a4 4dc1508 f8734a4 4dc1508 66e7228 edb05fc fe0e9af 0cef1e2 01d78f2 0cef1e2 9bc2923 01d78f2 d3f22a6 01d78f2 d3f22a6 01d78f2 1bc0267 0cef1e2 1bc0267 93b2cca 1bc0267 0cef1e2 1bc0267 9bc2923 1bc0267 9bc2923 1bc0267 9bc2923 0cef1e2 9bc2923 b6fced7 9bc2923 875f311 01d78f2 4dc1508 1bc0267 bc946c2 01d78f2 9bc2923 0cef1e2 9bc2923 0cef1e2 9bc2923 875f311 9bc2923 0cef1e2 9bc2923 0cef1e2 9bc2923 01d78f2 9bc2923 01d78f2 0f6a079 0cef1e2 c2e071e 01d78f2 e950125 01d78f2 1a7303a 01d78f2 4dc1508 11da24b 4dc1508 01d78f2 0cef1e2 01d78f2 0cef1e2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 |
"""
app.py - the main application file for the gradio app
"""
import gc
import logging
import random
import re
import time
from pathlib import Path
import gradio as gr
import nltk
import torch
from cleantext import clean
from summarize import load_model_and_tokenizer, summarize_via_tokenbatches
from utils import load_example_filenames, truncate_word_count
# Directory containing this script; the bundled examples are read from `_here / "examples"`.
_here = Path(__file__).parent

# Fetch the NLTK "stopwords" corpus quietly at import time.
# NOTE(review): not referenced directly in this file — presumably needed by an imported helper; verify.
nltk.download("stopwords", quiet=True)

# Root logging setup: INFO and above, each record prefixed with time, level and logger name.
logging.basicConfig(
    format="%(asctime)s - [%(levelname)s] %(name)s: %(message)s",
    level=logging.INFO,
)

# Model ids offered in the "Model Name" dropdown; index 0 is the default selection.
MODEL_OPTIONS = [
    "pszemraj/led-large-book-summary",
    "pszemraj/led-base-book-summary",
]
def predict(
    input_text: str,
    model_name: str,
    token_batch_length: int = 2048,
    empty_cache: bool = True,
    **settings,
) -> list:
    """
    predict - load a summarization model and summarize the input in token batches

    :param str input_text: the input text to summarize
    :param str model_name: model name to use (passed to load_model_and_tokenizer)
    :param int token_batch_length: the length of the token batches to use
    :param bool empty_cache: whether to empty the CUDA cache before loading the model
    :param settings: extra keyword arguments forwarded unchanged to
        summarize_via_tokenbatches (generation settings such as num_beams)
    :return: the list returned by summarize_via_tokenbatches, one entry per batch;
        callers in this file read the "summary" and "summary_score" keys of each entry
    """
    if torch.cuda.is_available() and empty_cache:
        torch.cuda.empty_cache()
    model, tokenizer = load_model_and_tokenizer(model_name)
    try:
        summaries = summarize_via_tokenbatches(
            input_text,
            model,
            tokenizer,
            batch_length=token_batch_length,
            **settings,
        )
    finally:
        # Drop our references and collect even if summarization raised, so the
        # model can be reclaimed before the next request. NOTE(review): this only
        # frees memory if load_model_and_tokenizer keeps no reference of its own
        # (e.g. a cache) — can't tell from here.
        del model, tokenizer
        gc.collect()
    return summaries
def proc_submission(
    input_text: str,
    model_name: str,
    num_beams: int,
    token_batch_length: int,
    length_penalty: float,
    repetition_penalty: float,
    no_repeat_ngram_size: int,
    max_input_length: int = 2560,
):
    """
    proc_submission - a helper function for the gradio module to process submissions

    Args:
        input_text (str): the input text to summarize (None is treated as "")
        model_name (str): the name of the model to use, e.g. an entry of MODEL_OPTIONS
        num_beams (int): the number of beams to use
        token_batch_length (int): the length of the token batches to use
        length_penalty (float): the length penalty to use
        repetition_penalty (float): the repetition penalty to use
        no_repeat_ngram_size (int): the no-repeat ngram size to use
        max_input_length (int, optional): the maximum number of words kept from the
            cleaned input. Defaults to 2560; forced to 4096 when "base" is in model_name.

    Returns:
        tuple of three str: an HTML status/warning message, the summary text
        (one entry per token batch) and the per-batch summary scores.
        Input shorter than 50 characters yields the HTML error message and two
        empty strings, without running the model.
    """
    logger = logging.getLogger(__name__)
    logger.info("Processing submission")
    input_text = input_text or ""

    # Reject (near-)empty input up front, before any cleaning or model loading.
    if len(input_text) < 50:
        msg = f"""
        <div style="background-color: #880808; color: white; padding: 20px;">
        <h3>Error</h3>
        <p>Input text is too short to summarize. Detected {len(input_text)} characters.
        Please load text by selecting an example from the dropdown menu or by pasting text into the text box.</p>
        </div>
        """
        logger.warning(msg)
        logger.warning("RETURNING EMPTY STRING")
        # all three outputs are text components, so return a string for each
        return msg, "", ""

    # generation kwargs forwarded to the summarizer through predict(**settings)
    settings = {
        "length_penalty": float(length_penalty),
        "repetition_penalty": float(repetition_penalty),
        "no_repeat_ngram_size": int(no_repeat_ngram_size),
        "encoder_no_repeat_ngram_size": 4,
        "num_beams": int(num_beams),
        "min_length": 4,
        # summary length limit is a quarter of the token batch length
        "max_length": int(token_batch_length // 4),
        "early_stopping": True,
        "do_sample": False,
    }
    if "base" in model_name:
        max_input_length = 4096
        logger.info(f"Updating max_input_length to {max_input_length} for base model")
    logger.info(f"max_input_length: {max_input_length}")

    st = time.perf_counter()
    clean_text = clean(input_text, lower=False)
    processed = truncate_word_count(clean_text, max_input_length)
    msg = None
    if processed["was_truncated"]:
        truncated_input = processed["truncated_text"]
        # rough word count of the cleaned text; only used for the percentage shown
        input_wc = len(clean_text.split())
        msg = f"""
        <div style="background-color: #FFA500; color: white; padding: 20px;">
        <h3>Warning</h3>
        <p>Input text was truncated to {max_input_length} words. That's about {100 * max_input_length / max(input_wc, 1):.2f}% of the submission.</p>
        </div>
        """
        logger.warning(msg)
    else:
        # use the cleaned text in this branch too, so the model always sees
        # the same normalised input that truncate_word_count was given
        truncated_input = clean_text

    summaries = predict(
        input_text=truncated_input,
        model_name=model_name,
        token_batch_length=token_batch_length,
        **settings,
    )
    sum_text_out = "\n".join(
        f"\nBatch {i}:\n\t" + s["summary"][0]
        for i, s in enumerate(summaries, start=1)
    )
    scores_out = "\n".join(
        f"\n- Batch {i}:\n\t{round(s['summary_score'], 4)}"
        for i, s in enumerate(summaries, start=1)
    )

    rt = round((time.perf_counter() - st) / 60, 2)
    logger.info(f"Runtime: {rt} minutes")
    html = f"<p>Runtime: {rt} minutes on CPU</p>"
    if msg is not None:
        html += msg
    return html, sum_text_out, scores_out
def load_single_example_text(
    example_path: str,
):
    """
    load_single_example_text - a helper function for the gradio module to load an example

    Args:
        example_path (str): the example's display name, i.e. a key of the
            module-level `name_to_path` mapping (the value of the "Examples" dropdown)

    Returns:
        str, the cleaned text of the selected example file
    """
    # name_to_path is only read here (it is assigned in the __main__ block),
    # so no `global` declaration is needed.
    full_ex_path = Path(name_to_path[example_path])
    raw_text = full_ex_path.read_text(encoding="utf-8", errors="ignore")
    return clean(raw_text, lower=False)
def load_uploaded_file(file_obj):
    """
    load_uploaded_file - process an uploaded file

    Args:
        file_obj: the Gradio upload value. Accepted shapes: an object exposing
            `.name` (the uploaded file's path), a plain `str`/`Path` path, a list
            whose first element is one of those, or None / an empty list when
            nothing has been uploaded.

    Returns:
        str, the cleaned file contents, or a message starting with "Error:" if
        no file was provided or it could not be read
    """
    logger = logging.getLogger(__name__)
    # the upload value may arrive wrapped in a list
    if isinstance(file_obj, list):
        file_obj = file_obj[0] if file_obj else None
    if file_obj is None:
        # the "Upload File" button can be clicked before any file is chosen
        return "Error: No file uploaded. Please choose a text file first."

    if isinstance(file_obj, (str, Path)):
        file_path = Path(file_obj)
    else:
        # Gradio file wrappers expose the path of the uploaded file via `.name`
        file_path = Path(file_obj.name)

    try:
        raw_text = file_path.read_text(encoding="utf-8", errors="ignore")
        text = clean(raw_text, lower=False)
    except Exception as e:  # UI boundary: report to the user rather than crash the callback
        logger.warning(f"Trying to load file with path {file_path}, error: {e}")
        return "Error: Could not read file. Ensure that it is a valid text file with encoding UTF-8."
    return text
if __name__ == "__main__":
    # Build and launch the Gradio UI.
    logger = logging.getLogger(__name__)
    logger.info("Starting up app")
    # module-level name: load_single_example_text looks examples up in this mapping
    name_to_path = load_example_filenames(_here / "examples")
    logger.info(f"Loaded {len(name_to_path)} examples")
    demo = gr.Blocks(
        title="Summarize Long-Form Text",
    )
    _examples = list(name_to_path.keys())
    with demo:
        gr.Markdown("# Long-Form Summarization: LED & BookSum")
        gr.Markdown(
            "LED models ([model card](https://huggingface.co/pszemraj/led-large-book-summary)) fine-tuned to summarize long-form text. A [space with other models can be found here](https://huggingface.co/spaces/pszemraj/document-summarization)"
        )
        with gr.Column():
            gr.Markdown("## Load Inputs & Select Parameters")
            gr.Markdown(
                "Enter or upload text below, and it will be summarized [using the selected parameters](https://huggingface.co/blog/how-to-generate). "
            )
            with gr.Row():
                model_name = gr.Dropdown(
                    choices=MODEL_OPTIONS,
                    value=MODEL_OPTIONS[0],
                    label="Model Name",
                )
                num_beams = gr.Radio(
                    choices=[2, 3, 4],
                    label="Beam Search: # of Beams",
                    value=2,
                )
            gr.Markdown(
                "Load a .txt - example or your own (_You may find [this OCR space](https://huggingface.co/spaces/pszemraj/pdf-ocr) useful_)"
            )
            with gr.Row():
                example_name = gr.Dropdown(
                    _examples,
                    label="Examples",
                    # random.choice raises IndexError if the examples folder is empty
                    value=random.choice(_examples) if _examples else None,
                )
                uploaded_file = gr.File(
                    label="File Upload",
                    file_count="single",
                    type="file",
                )
            with gr.Row():
                input_text = gr.Textbox(
                    lines=4,
                    max_lines=12,
                    label="Text to Summarize",
                    placeholder="Enter text to summarize, the text will be cleaned and truncated on Spaces. Narrative, academic (both papers and lecture transcription), and article text work well. May take a bit to generate depending on the input text :)",
                )
                with gr.Column():
                    load_examples_button = gr.Button(
                        "Load Example",
                    )
                    load_file_button = gr.Button("Upload File")
            gr.Markdown("---")
        with gr.Column():
            gr.Markdown("## Generate Summary")
            gr.Markdown(
                "Summary generation should take approximately 1-2 minutes for most settings."
            )
            summarize_button = gr.Button(
                "Summarize!",
                variant="primary",
            )
            output_text = gr.HTML("<p><em>Output will appear below:</em></p>")
            gr.Markdown("### Summary Output")
            summary_text = gr.Textbox(
                label="Summary", placeholder="The generated summary will appear here"
            )
            gr.Markdown(
                "The summary scores can be thought of as representing the quality of the summary. Less-negative numbers (closer to 0) are better:"
            )
            summary_scores = gr.Textbox(
                label="Summary Scores", placeholder="Summary scores will appear here"
            )
            gr.Markdown("---")
        with gr.Column():
            gr.Markdown("### Advanced Settings")
            with gr.Row():
                length_penalty = gr.Slider(
                    minimum=0.5,
                    maximum=1.0,
                    label="length penalty",
                    value=0.7,
                    step=0.05,
                )
                token_batch_length = gr.Radio(
                    choices=[512, 768, 1024, 1536],
                    label="token batch length",
                    value=1024,
                )
            with gr.Row():
                repetition_penalty = gr.Slider(
                    minimum=1.0,
                    maximum=5.0,
                    label="repetition penalty",
                    value=3.5,
                    step=0.1,
                )
                no_repeat_ngram_size = gr.Radio(
                    choices=[2, 3, 4],
                    label="no repeat ngram size",
                    value=3,
                )
        with gr.Column():
            gr.Markdown("### About the Model")
            gr.Markdown(
                "- [This model](https://huggingface.co/pszemraj/led-large-book-summary) is a fine-tuned checkpoint of [allenai/led-large-16384](https://huggingface.co/allenai/led-large-16384) on the [BookSum dataset](https://arxiv.org/abs/2105.08209). The goal was to create a model that can generalize well and is useful in summarizing lots of text in academic and daily usage."
            )
            gr.Markdown(
                "- The model can be used with tag [pszemraj/led-large-book-summary](https://huggingface.co/pszemraj/led-large-book-summary). See the model card for details on usage & a Colab notebook for a tutorial."
            )
            gr.Markdown(
                "- **Update May 1, 2023:** Enabled faster inference times via `use_cache=True`, the number of words the model will process has been increased! Not on this demo, but there is a [test model](https://huggingface.co/pszemraj/led-large-book-summary-continued) available: an extension of `led-large-book-summary`."
            )
        gr.Markdown("---")

        # Event wiring: each button writes its result into the listed output components.
        load_examples_button.click(
            fn=load_single_example_text, inputs=[example_name], outputs=[input_text]
        )
        load_file_button.click(
            fn=load_uploaded_file, inputs=uploaded_file, outputs=[input_text]
        )
        # input order must match proc_submission's positional parameters
        summarize_button.click(
            fn=proc_submission,
            inputs=[
                input_text,
                model_name,
                num_beams,
                token_batch_length,
                length_penalty,
                repetition_penalty,
                no_repeat_ngram_size,
            ],
            outputs=[output_text, summary_text, summary_scores],
        )

    # NOTE(review): `enable_queue` is a Gradio 3.x launch() argument — confirm
    # against the pinned gradio version before upgrading.
    demo.launch(
        enable_queue=True,
    )
|