import logging
import time
from pathlib import Path
import gradio as gr
import nltk
from cleantext import clean
from summarize import load_model_and_tokenizer, summarize_via_tokenbatches
from utils import load_example_filenames, truncate_word_count
# Directory containing this script.
_here = Path(__file__).parent

# The NLTK stopwords corpus is needed at runtime.
# TODO: find where this requirement originates from.
# Only download it when it is not already installed, so that starting the app
# does not require a network round-trip every time.
try:
    nltk.data.find("corpora/stopwords")
except LookupError:
    nltk.download("stopwords")

logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
)
def proc_submission(
    input_text: str,
    model_size: str,
    num_beams,
    token_batch_length,
    length_penalty,
    repetition_penalty,
    no_repeat_ngram_size,
    max_input_length: int = 768,
):
    """
    proc_submission - a helper function for the gradio module to process submissions

    Args:
        input_text (str): the input text to summarize
        model_size (str): the size of the model to use; "base" selects the
            module-level `model_sm` / `tokenizer_sm` and forces the input word
            limit to 1024, anything else uses `model` / `tokenizer`
        num_beams (int): the number of beams to use
        token_batch_length (int): the length of the token batches to use; the
            generated summary per batch is capped at a quarter of this
        length_penalty (float): the length penalty to use
        repetition_penalty (float): the repetition penalty to use
        no_repeat_ngram_size (int): the no repeat ngram size to use
        max_input_length (int, optional): the maximum input length, in
            whitespace-separated words. Defaults to 768. Overridden to 1024
            when model_size == "base".

    Returns:
        tuple[str, str, str]: an HTML status string (runtime, plus a warning
        when the input was truncated), the per-section summary text, and the
        per-section summary scores, one section per line.
    """
    # Gradio widgets may deliver numeric values as floats; coerce once so both
    # the generation cap and the batch length are plain ints.
    token_batch_length = int(token_batch_length)
    settings = {
        "length_penalty": float(length_penalty),
        "repetition_penalty": float(repetition_penalty),
        "no_repeat_ngram_size": int(no_repeat_ngram_size),
        "encoder_no_repeat_ngram_size": 4,
        "num_beams": int(num_beams),
        "min_length": 4,
        "max_length": token_batch_length // 4,
        "early_stopping": True,
        "do_sample": False,
    }
    st = time.perf_counter()

    use_base = model_size == "base"
    clean_text = clean(input_text, lower=False)
    max_input_length = 1024 if use_base else max_input_length
    processed = truncate_word_count(clean_text, max_input_length)

    if processed["was_truncated"]:
        tr_in = processed["truncated_text"]
        msg = f"Input text was truncated to {max_input_length} words (based on whitespace)"
        logging.warning(msg)
    else:
        # Feed the cleaned text in this branch too, so short and long inputs
        # receive the same preprocessing (the raw input used to leak through).
        tr_in = clean_text
        msg = None

    _summaries = summarize_via_tokenbatches(
        tr_in,
        model_sm if use_base else model,
        tokenizer_sm if use_base else tokenizer,
        batch_length=token_batch_length,
        **settings,
    )

    sum_text_out = "\n".join(
        f"Section {i}: {s['summary'][0]}" for i, s in enumerate(_summaries)
    )
    scores_out = "\n".join(
        f" - Section {i}: {round(s['summary_score'], 4)}"
        for i, s in enumerate(_summaries)
    )

    rt = round((time.perf_counter() - st) / 60, 2)
    logging.info("Runtime: %s minutes", rt)

    html = f"<p>Runtime: {rt} minutes on CPU</p>"
    if msg is not None:
        # msg is built above from internal values only, so no escaping needed.
        html += f"<h2>WARNING:</h2><hr><b>{msg}</b><br><br>"

    return html, sum_text_out, scores_out
") gr.Markdown("### Summary Output") summary_text = gr.Textbox( label="Summary", placeholder="The generated summary will appear here" ) gr.Markdown( "The summary scores can be thought of as representing the quality of the summary. less-negative numbers (closer to 0) are better:" ) summary_scores = gr.Textbox( label="Summary Scores", placeholder="Summary scores will appear here" ) with gr.Column(): gr.Markdown("## About the Model") gr.Markdown( "- [This model](https://huggingface.co/pszemraj/led-large-book-summary) is a fine-tuned checkpoint of [allenai/led-large-16384](https://huggingface.co/allenai/led-large-16384) on the [BookSum dataset](https://arxiv.org/abs/2105.08209).The goal was to create a model that can generalize well and is useful in summarizing lots of text in academic and daily usage." ) gr.Markdown( "- The model can be used with tag [pszemraj/led-large-book-summary](https://huggingface.co/pszemraj/led-large-book-summary). See the model card for details on usage & a notebook for a tutorial." ) load_examples_button.click( fn=load_single_example_text, inputs=[example_name], outputs=[input_text] ) load_file_button.click( fn=load_uploaded_file, inputs=[uploaded_file], outputs=[input_text] ) summarize_button.click( fn=proc_submission, inputs=[ input_text, model_size, num_beams, token_batch_length, length_penalty, repetition_penalty, no_repeat_ngram_size, ], outputs=[output_text, summary_text, summary_scores], ) demo.launch(enable_queue=True, prevent_thread_lock=True)