Spaces:
Running
Running
✨ add ability to add custom options in CLI
Browse filesSigned-off-by: peter szemraj <peterszemraj@gmail.com>
app.py
CHANGED
@@ -2,6 +2,9 @@
|
|
2 |
app.py - the main module for the gradio app for summarization
|
3 |
|
4 |
Usage:
|
|
|
|
|
|
|
5 |
python app.py --help
|
6 |
|
7 |
Environment Variables:
|
@@ -18,6 +21,8 @@ import logging
|
|
18 |
import os
|
19 |
import random
|
20 |
import re
|
|
|
|
|
21 |
import time
|
22 |
from pathlib import Path
|
23 |
|
@@ -52,7 +57,7 @@ _here = Path(__file__).parent
|
|
52 |
nltk.download("punkt", force=True, quiet=True)
|
53 |
nltk.download("popular", force=True, quiet=True)
|
54 |
|
55 |
-
|
56 |
MODEL_OPTIONS = [
|
57 |
"pszemraj/long-t5-tglobal-base-16384-book-summary",
|
58 |
"pszemraj/long-t5-tglobal-base-sci-simplify",
|
@@ -60,6 +65,14 @@ MODEL_OPTIONS = [
|
|
60 |
"pszemraj/long-t5-tglobal-base-16384-booksci-summary-v1",
|
61 |
"pszemraj/pegasus-x-large-book-summary",
|
62 |
] # models users can choose from
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
63 |
|
64 |
SUMMARY_PLACEHOLDER = "<p><em>Output will appear below:</em></p>"
|
65 |
AGGREGATE_MODEL = "MBZUAI/LaMini-Flan-T5-783M" # model to use for aggregation
|
@@ -67,8 +80,11 @@ AGGREGATE_MODEL = "MBZUAI/LaMini-Flan-T5-783M" # model to use for aggregation
|
|
67 |
# if duplicating space: uncomment this line to adjust the max words
|
68 |
# os.environ["APP_MAX_WORDS"] = str(2048) # set the max words to 2048
|
69 |
# os.environ["APP_OCR_MAX_PAGES"] = str(40) # set the max pages to 40
|
|
|
70 |
|
71 |
-
aggregator = BatchAggregator(
|
|
|
|
|
72 |
|
73 |
|
74 |
def aggregate_text(
|
@@ -364,10 +380,11 @@ def load_uploaded_file(file_obj, max_pages: int = 20, lower: bool = False) -> st
|
|
364 |
def parse_args():
|
365 |
"""arguments for the command line interface"""
|
366 |
parser = argparse.ArgumentParser(
|
367 |
-
description="Document Summarization with Long-Document Transformers Demo",
|
368 |
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
|
369 |
-
epilog="Runs a local-only web
|
370 |
)
|
|
|
371 |
parser.add_argument(
|
372 |
"--share",
|
373 |
dest="share",
|
@@ -379,16 +396,34 @@ def parse_args():
|
|
379 |
"--model",
|
380 |
type=str,
|
381 |
default=None,
|
382 |
-
help=f"Add a custom model to the list of models: {
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
383 |
)
|
384 |
parser.add_argument(
|
385 |
"-level",
|
386 |
-
"--
|
387 |
type=str,
|
388 |
default="INFO",
|
389 |
choices=["DEBUG", "INFO", "WARNING", "ERROR"],
|
390 |
help="Set the logging level",
|
391 |
)
|
|
|
|
|
|
|
|
|
392 |
return parser.parse_args()
|
393 |
|
394 |
|
@@ -397,11 +432,19 @@ if __name__ == "__main__":
|
|
397 |
logger = logging.getLogger(__name__)
|
398 |
args = parse_args()
|
399 |
logger.setLevel(args.log_level)
|
400 |
-
logger.info(f"args: {args}")
|
|
|
|
|
401 |
if args.model is not None:
|
402 |
logger.info(f"Adding model {args.model} to the list of models")
|
403 |
MODEL_OPTIONS.append(args.model)
|
404 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
405 |
logger.info("Loading OCR model")
|
406 |
with contextlib.redirect_stdout(None):
|
407 |
ocr_model = ocr_predictor(
|
@@ -410,11 +453,14 @@ if __name__ == "__main__":
|
|
410 |
pretrained=True,
|
411 |
assume_straight_pages=True,
|
412 |
)
|
|
|
|
|
413 |
name_to_path = load_example_filenames(_here / "examples")
|
414 |
logger.info(f"Loaded {len(name_to_path)} examples")
|
415 |
|
416 |
demo = gr.Blocks(title="Document Summarization with Long-Document Transformers")
|
417 |
_examples = list(name_to_path.keys())
|
|
|
418 |
with demo:
|
419 |
gr.Markdown("# Document Summarization with Long-Document Transformers")
|
420 |
gr.Markdown(
|
@@ -436,9 +482,9 @@ if __name__ == "__main__":
|
|
436 |
label="Model Name",
|
437 |
)
|
438 |
num_beams = gr.Radio(
|
439 |
-
choices=
|
440 |
label="Beam Search: # of Beams",
|
441 |
-
value=
|
442 |
)
|
443 |
load_examples_button = gr.Button(
|
444 |
"Load Example in Dropdown",
|
@@ -542,9 +588,10 @@ if __name__ == "__main__":
|
|
542 |
step=0.05,
|
543 |
)
|
544 |
token_batch_length = gr.Radio(
|
545 |
-
choices=
|
546 |
label="token batch length",
|
547 |
-
|
|
|
548 |
)
|
549 |
|
550 |
with gr.Row(variant="compact"):
|
|
|
2 |
app.py - the main module for the gradio app for summarization
|
3 |
|
4 |
Usage:
|
5 |
+
app.py [-h] [--share] [-m MODEL] [-nb ADD_BEAM_OPTION] [-batch TOKEN_BATCH_OPTION]
|
6 |
+
[-level {DEBUG,INFO,WARNING,ERROR}]
|
7 |
+
Details:
|
8 |
python app.py --help
|
9 |
|
10 |
Environment Variables:
|
|
|
21 |
import os
|
22 |
import random
|
23 |
import re
|
24 |
+
import pprint as pp
|
25 |
+
import sys
|
26 |
import time
|
27 |
from pathlib import Path
|
28 |
|
|
|
57 |
nltk.download("punkt", force=True, quiet=True)
|
58 |
nltk.download("popular", force=True, quiet=True)
|
59 |
|
60 |
+
# Constants & Globals
|
61 |
MODEL_OPTIONS = [
|
62 |
"pszemraj/long-t5-tglobal-base-16384-book-summary",
|
63 |
"pszemraj/long-t5-tglobal-base-sci-simplify",
|
|
|
65 |
"pszemraj/long-t5-tglobal-base-16384-booksci-summary-v1",
|
66 |
"pszemraj/pegasus-x-large-book-summary",
|
67 |
] # models users can choose from
|
68 |
+
BEAM_OPTIONS = [2, 3, 4] # beam sizes users can choose from
|
69 |
+
TOKEN_BATCH_OPTIONS = [
|
70 |
+
1024,
|
71 |
+
1536,
|
72 |
+
2048,
|
73 |
+
2560,
|
74 |
+
3072,
|
75 |
+
] # token batch sizes users can choose from
|
76 |
|
77 |
SUMMARY_PLACEHOLDER = "<p><em>Output will appear below:</em></p>"
|
78 |
AGGREGATE_MODEL = "MBZUAI/LaMini-Flan-T5-783M" # model to use for aggregation
|
|
|
80 |
# if duplicating space: uncomment this line to adjust the max words
|
81 |
# os.environ["APP_MAX_WORDS"] = str(2048) # set the max words to 2048
|
82 |
# os.environ["APP_OCR_MAX_PAGES"] = str(40) # set the max pages to 40
|
83 |
+
# os.environ["APP_AGG_FORCE_CPU"] = str(1) # force cpu for aggregation
|
84 |
|
85 |
+
aggregator = BatchAggregator(
|
86 |
+
AGGREGATE_MODEL, force_cpu=os.environ.get("APP_AGG_FORCE_CPU", False)
|
87 |
+
)
|
88 |
|
89 |
|
90 |
def aggregate_text(
|
|
|
380 |
def parse_args():
|
381 |
"""arguments for the command line interface"""
|
382 |
parser = argparse.ArgumentParser(
|
383 |
+
description="Document Summarization with Long-Document Transformers - Demo",
|
384 |
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
|
385 |
+
epilog="Runs a local-only web UI to summarize documents. pass --share for a public link to share.",
|
386 |
)
|
387 |
+
|
388 |
parser.add_argument(
|
389 |
"--share",
|
390 |
dest="share",
|
|
|
396 |
"--model",
|
397 |
type=str,
|
398 |
default=None,
|
399 |
+
help=f"Add a custom model to the list of models: {pp.pformat(MODEL_OPTIONS, compact=True)}",
|
400 |
+
)
|
401 |
+
parser.add_argument(
|
402 |
+
"-nb",
|
403 |
+
"--add_beam_option",
|
404 |
+
type=int,
|
405 |
+
default=None,
|
406 |
+
help=f"Add a beam search option to the list of beam search options: {pp.pformat(BEAM_OPTIONS, compact=True)}",
|
407 |
+
)
|
408 |
+
parser.add_argument(
|
409 |
+
"-batch",
|
410 |
+
"--token_batch_option",
|
411 |
+
type=int,
|
412 |
+
default=None,
|
413 |
+
help=f"Add a token batch option to the list of token batch options: {pp.pformat(TOKEN_BATCH_OPTIONS, compact=True)}",
|
414 |
)
|
415 |
parser.add_argument(
|
416 |
"-level",
|
417 |
+
"--log_level",
|
418 |
type=str,
|
419 |
default="INFO",
|
420 |
choices=["DEBUG", "INFO", "WARNING", "ERROR"],
|
421 |
help="Set the logging level",
|
422 |
)
|
423 |
+
# if "--help" in sys.argv or "-h" in sys.argv:
|
424 |
+
# parser.print_help()
|
425 |
+
# sys.exit(0)
|
426 |
+
|
427 |
return parser.parse_args()
|
428 |
|
429 |
|
|
|
432 |
logger = logging.getLogger(__name__)
|
433 |
args = parse_args()
|
434 |
logger.setLevel(args.log_level)
|
435 |
+
logger.info(f"args: {pp.pformat(args.__dict__, compact=True)}")
|
436 |
+
|
437 |
+
# add any custom options
|
438 |
if args.model is not None:
|
439 |
logger.info(f"Adding model {args.model} to the list of models")
|
440 |
MODEL_OPTIONS.append(args.model)
|
441 |
+
if args.add_beam_option is not None:
|
442 |
+
logger.info(f"Adding beam search option {args.add_beam_option} to the list")
|
443 |
+
BEAM_OPTIONS.append(args.add_beam_option)
|
444 |
+
if args.token_batch_option is not None:
|
445 |
+
logger.info(f"Adding token batch option {args.token_batch_option} to the list")
|
446 |
+
TOKEN_BATCH_OPTIONS.append(args.token_batch_option)
|
447 |
+
|
448 |
logger.info("Loading OCR model")
|
449 |
with contextlib.redirect_stdout(None):
|
450 |
ocr_model = ocr_predictor(
|
|
|
453 |
pretrained=True,
|
454 |
assume_straight_pages=True,
|
455 |
)
|
456 |
+
|
457 |
+
# load the examples
|
458 |
name_to_path = load_example_filenames(_here / "examples")
|
459 |
logger.info(f"Loaded {len(name_to_path)} examples")
|
460 |
|
461 |
demo = gr.Blocks(title="Document Summarization with Long-Document Transformers")
|
462 |
_examples = list(name_to_path.keys())
|
463 |
+
logger.info("Starting app instance")
|
464 |
with demo:
|
465 |
gr.Markdown("# Document Summarization with Long-Document Transformers")
|
466 |
gr.Markdown(
|
|
|
482 |
label="Model Name",
|
483 |
)
|
484 |
num_beams = gr.Radio(
|
485 |
+
choices=BEAM_OPTIONS,
|
486 |
label="Beam Search: # of Beams",
|
487 |
+
value=BEAM_OPTIONS[0],
|
488 |
)
|
489 |
load_examples_button = gr.Button(
|
490 |
"Load Example in Dropdown",
|
|
|
588 |
step=0.05,
|
589 |
)
|
590 |
token_batch_length = gr.Radio(
|
591 |
+
choices=TOKEN_BATCH_OPTIONS,
|
592 |
label="token batch length",
|
593 |
+
# select median option
|
594 |
+
value=TOKEN_BATCH_OPTIONS[len(TOKEN_BATCH_OPTIONS) // 2],
|
595 |
)
|
596 |
|
597 |
with gr.Row(variant="compact"):
|