pszemraj committed on
Commit
9c491e8
·
1 Parent(s): 9cdd2ba

✨ basic CLI

Browse files

Signed-off-by: peter szemraj <peterszemraj@gmail.com>

Files changed (1) hide show
  1. app.py +43 -4
app.py CHANGED
@@ -11,6 +11,7 @@ Optional Environment Variables:
11
  APP_MAX_WORDS (int): the maximum number of words to use for summarization
12
  APP_OCR_MAX_PAGES (int): the maximum number of pages to use for OCR
13
  """
 
14
  import contextlib
15
  import gc
16
  import logging
@@ -72,13 +73,15 @@ aggregator = BatchAggregator("MBZUAI/LaMini-Flan-T5-783M")
72
  def aggregate_text(
73
  summary_text: str,
74
  text_file: gr.inputs.File = None,
75
- ):
76
  """
77
  Aggregate the text from the batches.
78
 
79
  NOTE: you should probably include passing the BatchAggregator object as a parameter if using this code
80
  outside of this file.
81
  :param batches_html: The batches to aggregate, in html format
 
 
82
  """
83
  if summary_text is None or summary_text == SUMMARY_PLACEHOLDER:
84
  logging.error("No text provided. Make sure a summary has been generated first.")
@@ -292,7 +295,7 @@ def load_single_example_text(
292
  :param int max_pages: the maximum number of pages to load from a PDF
293
  :return str: the text of the example
294
  """
295
- global name_to_path
296
  full_ex_path = name_to_path[example_path]
297
  full_ex_path = Path(full_ex_path)
298
  if full_ex_path.suffix in [".txt", ".md"]:
@@ -325,7 +328,7 @@ def load_uploaded_file(file_obj, max_pages: int = 20, lower: bool = False) -> st
325
  :param bool lower: whether to lowercase the text
326
  :return str: the text of the file
327
  """
328
-
329
  logger = logging.getLogger(__name__)
330
  # check if mysterious file object is a list
331
  if isinstance(file_obj, list):
@@ -357,8 +360,44 @@ def load_uploaded_file(file_obj, max_pages: int = 20, lower: bool = False) -> st
357
  return "Error: Could not read file. Ensure that it is a valid text file with encoding UTF-8 if text, and a PDF if PDF."
358
 
359
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
360
  if __name__ == "__main__":
 
361
  logger = logging.getLogger(__name__)
 
 
 
 
 
 
362
  logger.info("Starting app instance")
363
  logger.info("Loading OCR model")
364
  with contextlib.redirect_stdout(None):
@@ -538,4 +577,4 @@ if __name__ == "__main__":
538
  inputs=[summary_text, text_file],
539
  outputs=[aggregated_summary],
540
  )
541
- demo.launch(enable_queue=True, share=True)
 
11
  APP_MAX_WORDS (int): the maximum number of words to use for summarization
12
  APP_OCR_MAX_PAGES (int): the maximum number of pages to use for OCR
13
  """
14
+ import argparse
15
  import contextlib
16
  import gc
17
  import logging
 
73
  def aggregate_text(
74
  summary_text: str,
75
  text_file: gr.inputs.File = None,
76
+ ) -> str:
77
  """
78
  Aggregate the text from the batches.
79
 
80
  NOTE: you should probably include passing the BatchAggregator object as a parameter if using this code
81
  outside of this file.
82
  :param batches_html: The batches to aggregate, in html format
83
+ :param text_file: The text file to append the aggregate summary to
84
+ :return: The aggregate summary in html format
85
  """
86
  if summary_text is None or summary_text == SUMMARY_PLACEHOLDER:
87
  logging.error("No text provided. Make sure a summary has been generated first.")
 
295
  :param int max_pages: the maximum number of pages to load from a PDF
296
  :return str: the text of the example
297
  """
298
+ global name_to_path, ocr_model
299
  full_ex_path = name_to_path[example_path]
300
  full_ex_path = Path(full_ex_path)
301
  if full_ex_path.suffix in [".txt", ".md"]:
 
328
  :param bool lower: whether to lowercase the text
329
  :return str: the text of the file
330
  """
331
+ global ocr_model
332
  logger = logging.getLogger(__name__)
333
  # check if mysterious file object is a list
334
  if isinstance(file_obj, list):
 
360
  return "Error: Could not read file. Ensure that it is a valid text file with encoding UTF-8 if text, and a PDF if PDF."
361
 
362
 
363
+ def parse_args():
364
+ parser = argparse.ArgumentParser(
365
+ description="Document Summarization with Long-Document Transformers",
366
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter,
367
+ )
368
+ parser.add_argument(
369
+ "--share",
370
+ dest="share",
371
+ action="store_true",
372
+ help="Create a public link to share",
373
+ )
374
+ parser.add_argument(
375
+ "-m",
376
+ "--model",
377
+ type=str,
378
+ default=None,
379
+ help=f"Add a custom model to the list of models: {', '.join(MODEL_OPTIONS)}",
380
+ )
381
+ parser.add_argument(
382
+ "-level",
383
+ "--log-level",
384
+ type=str,
385
+ default="INFO",
386
+ choices=["DEBUG", "INFO", "WARNING", "ERROR"],
387
+ help="Set the logging level",
388
+ )
389
+ return parser.parse_args()
390
+
391
+
392
  if __name__ == "__main__":
393
+ """main - the main function of the app"""
394
  logger = logging.getLogger(__name__)
395
+ args = parse_args()
396
+ logger.setLevel(args.log_level)
397
+ logger.info(f"args: {args}")
398
+ if args.model is not None:
399
+ logger.info(f"Adding model {args.model} to the list of models")
400
+ MODEL_OPTIONS.append(args.model)
401
  logger.info("Starting app instance")
402
  logger.info("Loading OCR model")
403
  with contextlib.redirect_stdout(None):
 
577
  inputs=[summary_text, text_file],
578
  outputs=[aggregated_summary],
579
  )
580
+ demo.launch(enable_queue=True, share=args.share)