""" app.py - the main module for the gradio app Usage: python app.py Environment Variables: USE_TORCH (str): whether to use torch (1) or not (0) TOKENIZERS_PARALLELISM (str): whether to use parallelism (true) or not (false) Optional Environment Variables: APP_MAX_WORDS (int): the maximum number of words to use for summarization APP_OCR_MAX_PAGES (int): the maximum number of pages to use for OCR """ import contextlib import gc import logging import os import random import re import time from pathlib import Path os.environ["USE_TORCH"] = "1" os.environ["TOKENIZERS_PARALLELISM"] = "false" logging.basicConfig( level=logging.INFO, format="%(asctime)s [%(levelname)s] %(name)s - %(message)s", ) import gradio as gr import nltk import torch from cleantext import clean from doctr.models import ocr_predictor from pdf2text import convert_PDF_to_Text from summarize import load_model_and_tokenizer, summarize_via_tokenbatches from utils import ( load_example_filenames, saves_summary, textlist2html, truncate_word_count, ) _here = Path(__file__).parent nltk.download("punkt", force=True, quiet=True) nltk.download("popular", force=True, quiet=True) MODEL_OPTIONS = [ "pszemraj/long-t5-tglobal-base-16384-book-summary", "pszemraj/long-t5-tglobal-base-sci-simplify", "pszemraj/long-t5-tglobal-base-sci-simplify-elife", "pszemraj/long-t5-tglobal-base-16384-booksci-summary-v1", "pszemraj/pegasus-x-large-book-summary", ] # models users can choose from # if duplicating space,, uncomment this line to adjust the max words # os.environ["APP_MAX_WORDS"] = str(2048) # set the max words to 2048 # os.environ["APP_OCR_MAX_PAGES"] = str(40) # set the max pages to 40 def predict( input_text: str, model_name: str, token_batch_length: int = 1024, empty_cache: bool = True, **settings, ) -> list: """ predict - helper fn to support multiple models for summarization at once :param str input_text: the input text to summarize :param str model_name: model name to use :param int token_batch_length: the length of the token batches to use :param bool empty_cache: whether to empty the cache before loading a new= model :return: list of dicts with keys "summary" and "score" """ if torch.cuda.is_available() and empty_cache: torch.cuda.empty_cache() model, tokenizer = load_model_and_tokenizer(model_name) summaries = summarize_via_tokenbatches( input_text, model, tokenizer, batch_length=token_batch_length, **settings, ) del model del tokenizer gc.collect() return summaries def proc_submission( input_text: str, model_name: str, num_beams: int, token_batch_length: int, length_penalty: float, repetition_penalty: float, no_repeat_ngram_size: int, max_input_length: int = 6144, ): """ proc_submission - a helper function for the gradio module to process submissions Args: input_text (str): the input text to summarize model_name (str): the hf model tag of the model to use num_beams (int): the number of beams to use token_batch_length (int): the length of the token batches to use length_penalty (float): the length penalty to use repetition_penalty (float): the repetition penalty to use no_repeat_ngram_size (int): the no repeat ngram size to use max_input_length (int, optional): the maximum input length to use. Defaults to 6144. Note: the max_input_length is set to 6144 by default, but can be changed by setting the environment variable APP_MAX_WORDS to a different value. Returns: str in HTML format, string of the summary, str of score """ settings = { "length_penalty": float(length_penalty), "repetition_penalty": float(repetition_penalty), "no_repeat_ngram_size": int(no_repeat_ngram_size), "encoder_no_repeat_ngram_size": 4, "num_beams": int(num_beams), "min_length": 4, "max_length": int(token_batch_length // 4), "early_stopping": True, "do_sample": False, } max_input_length = int(os.environ.get("APP_MAX_WORDS", max_input_length)) logging.info(f"max_input_length set to: {max_input_length}") st = time.perf_counter() history = {} clean_text = clean(input_text, lower=False) processed = truncate_word_count(clean_text, max_words=max_input_length) if processed["was_truncated"]: tr_in = processed["truncated_text"] # create elaborate HTML warning input_wc = re.split(r"\s+", input_text) msg = f"""
Input text was truncated to {max_input_length} words. That's about {100*max_input_length/len(input_wc):.2f}% of the submission.
Input text is too short to summarize. Detected {len(input_text)} characters. Please load text by selecting an example from the dropdown menu or by pasting text into the text box.
Runtime: {rt} minutes with model: {model_name}
" if msg is not None: html += msg html += "" # save to file settings["model_name"] = model_name saved_file = saves_summary(summarize_output=_summaries, outpath=None, **settings) return html, full_summary, scores_out, saved_file def load_single_example_text( example_path: str or Path, max_pages: int = 20, ) -> str: """ load_single_example_text - loads a single example text file :param strorPath example_path: name of the example to load :param int max_pages: the maximum number of pages to load from a PDF :return str: the text of the example """ global name_to_path full_ex_path = name_to_path[example_path] full_ex_path = Path(full_ex_path) if full_ex_path.suffix in [".txt", ".md"]: with open(full_ex_path, "r", encoding="utf-8", errors="ignore") as f: raw_text = f.read() text = clean(raw_text, lower=False) elif full_ex_path.suffix == ".pdf": logging.info(f"Loading PDF file {full_ex_path}") max_pages = int(os.environ.get("APP_MAX_PAGES", max_pages)) logging.info(f"max_pages set to: {max_pages}") conversion_stats = convert_PDF_to_Text( full_ex_path, ocr_model=ocr_model, max_pages=max_pages, ) text = conversion_stats["converted_text"] else: logging.error(f"Unknown file type {full_ex_path.suffix}") text = "ERROR - check example path" return text def load_uploaded_file(file_obj, max_pages: int = 20, lower: bool = False) -> str: """ load_uploaded_file - loads a file uploaded by the user :param file_obj (POTENTIALLY list): Gradio file object inside a list :param int max_pages: the maximum number of pages to load from a PDF :param bool lower: whether to lowercase the text :return str: the text of the file """ logger = logging.getLogger(__name__) # check if mysterious file object is a list if isinstance(file_obj, list): file_obj = file_obj[0] file_path = Path(file_obj.name) try: logger.info(f"Loading file:\t{file_path}") if file_path.suffix in [".txt", ".md"]: with open(file_path, "r", encoding="utf-8", errors="ignore") as f: raw_text = f.read() text = clean(raw_text, lower=lower) elif file_path.suffix == ".pdf": logger.info(f"loading as PDF file {file_path}") max_pages = int(os.environ.get("APP_MAX_PAGES", max_pages)) logger.info(f"max_pages set to: {max_pages}") conversion_stats = convert_PDF_to_Text( file_path, ocr_model=ocr_model, max_pages=max_pages, ) text = conversion_stats["converted_text"] else: logger.error(f"Unknown file type:\t{file_path.suffix}") text = "ERROR - check file - unknown file type. PDF, TXT, and MD are supported." return text except Exception as e: logger.error(f"Trying to load file:\t{file_path},\nerror:\t{e}") return "Error: Could not read file. Ensure that it is a valid text file with encoding UTF-8 if text, and a PDF if PDF." if __name__ == "__main__": logger = logging.getLogger(__name__) logger.info("Starting app instance") logger.info("Loading OCR model") with contextlib.redirect_stdout(None): ocr_model = ocr_predictor( "db_resnet50", "crnn_mobilenet_v3_large", pretrained=True, assume_straight_pages=True, ) name_to_path = load_example_filenames(_here / "examples") logger.info(f"Loaded {len(name_to_path)} examples") demo = gr.Blocks(title="Document Summarization with Long-Document Transformers") _examples = list(name_to_path.keys()) with demo: gr.Markdown("# Document Summarization with Long-Document Transformers") gr.Markdown( "An example use case for fine-tuned long document transformers. Model(s) are trained on [book summaries](https://huggingface.co/datasets/kmfoda/booksum). Architectures in this demo are [LongT5-base](https://huggingface.co/pszemraj/long-t5-tglobal-base-16384-book-summary) and [Pegasus-X-Large](https://huggingface.co/pszemraj/pegasus-x-large-book-summary)." ) with gr.Column(): gr.Markdown("## Load Inputs & Select Parameters") gr.Markdown( """Enter/paste text below, or upload a file. Pick a model & adjust params (_optional_), and press **Summarize!** See [the guide doc](https://gist.github.com/pszemraj/722a7ba443aa3a671b02d87038375519) for details. """ ) with gr.Row(variant="compact"): with gr.Column(scale=0.5, variant="compact"): model_name = gr.Dropdown( choices=MODEL_OPTIONS, value=MODEL_OPTIONS[0], label="Model Name", ) num_beams = gr.Radio( choices=[2, 3, 4], label="Beam Search: # of Beams", value=2, ) load_examples_button = gr.Button( "Load Example in Dropdown", ) load_file_button = gr.Button("Load an Uploaded File") with gr.Column(variant="compact"): example_name = gr.Dropdown( _examples, label="Examples", value=random.choice(_examples), ) uploaded_file = gr.File( label="File Upload", file_count="single", file_types=[".txt", ".md", ".pdf"], type="file", ) with gr.Row(): input_text = gr.Textbox( lines=4, max_lines=12, label="Input Text (for summarization)", placeholder="Enter text to summarize, the text will be cleaned and truncated on Spaces. Narrative, academic (both papers and lecture transcription), and article text work well. May take a bit to generate depending on the input text :)", ) gr.Markdown("---") with gr.Column(): gr.Markdown("## Generate Summary") gr.Markdown( "_Summarization should take ~1-2 minutes for most settings, but may extend up to 5-10 minutes in some scenarios._" ) summarize_button = gr.Button( "Summarize!", variant="primary", ) output_text = gr.HTML("Output will appear below:
") with gr.Column(): gr.Markdown("#### Results & Scores") with gr.Row(): with gr.Column(variant="compact"): gr.Markdown( "Download the summary as a text file, with parameters and scores." ) text_file = gr.File( label="Download as Text File", file_count="single", type="file", interactive=False, ) with gr.Column(variant="compact"): gr.Markdown( "Scores represent the summary quality **roughly** as a measure of the model's 'confidence'. less-negative numbers (closer to 0) are better." ) summary_scores = gr.Textbox( label="Summary Scores", placeholder="Summary scores will appear here", ) gr.Markdown("#### **Summary Output**") summary_text = gr.HTML( label="Summary", value="Summary will appear here!" ) gr.Markdown("---") with gr.Column(): gr.Markdown("### Advanced Settings") gr.Markdown( "Refer to [the guide doc](https://gist.github.com/pszemraj/722a7ba443aa3a671b02d87038375519) for what these are, and how they impact _quality_ and _speed_." ) with gr.Row(variant="compact"): length_penalty = gr.Slider( minimum=0.5, maximum=1.0, label="length penalty", value=0.7, step=0.05, ) token_batch_length = gr.Radio( choices=[1024, 1536, 2048, 2560, 3072], label="token batch length", value=2048, ) with gr.Row(variant="compact"): repetition_penalty = gr.Slider( minimum=1.0, maximum=5.0, label="repetition penalty", value=1.5, step=0.1, ) no_repeat_ngram_size = gr.Radio( choices=[2, 3, 4], label="no repeat ngram size", value=3, ) with gr.Column(): gr.Markdown("### About") gr.Markdown( "- Models are fine-tuned on the [BookSum dataset](https://arxiv.org/abs/2105.08209). The goal was to create a model that generalizes well and is useful for summarizing text in academic and everyday use." ) gr.Markdown( "- _Update April 2023:_ Additional models fine-tuned on the [PLOS](https://huggingface.co/datasets/pszemraj/scientific_lay_summarisation-plos-norm) and [ELIFE](https://huggingface.co/datasets/pszemraj/scientific_lay_summarisation-elife-norm) subsets of the [scientific lay summaries](https://arxiv.org/abs/2210.09932) dataset are available (see dropdown at the top)." ) gr.Markdown( "Adjust the max input words & max PDF pages for OCR by duplicating this space and [setting the environment variables](https://huggingface.co/docs/hub/spaces-overview#managing-secrets) `APP_MAX_WORDS` and `APP_OCR_MAX_PAGES` to the desired integer values." ) gr.Markdown("---") load_examples_button.click( fn=load_single_example_text, inputs=[example_name], outputs=[input_text] ) load_file_button.click( fn=load_uploaded_file, inputs=uploaded_file, outputs=[input_text] ) summarize_button.click( fn=proc_submission, inputs=[ input_text, model_name, num_beams, token_batch_length, length_penalty, repetition_penalty, no_repeat_ngram_size, ], outputs=[output_text, summary_text, summary_scores, text_file], ) demo.launch(enable_queue=True)