import logging from pathlib import Path import os import re import gradio as gr import nltk import torch from cleantext import clean from summarize import load_model_and_tokenizer, summarize_via_tokenbatches _here = Path(__file__).parent nltk.download("stopwords") # TODO=find where this requirement originates from import transformers transformers.logging.set_verbosity_error() logging.basicConfig() def truncate_word_count(text, max_words=512): """ truncate_word_count - a helper function for the gradio module Parameters ---------- text : str, required, the text to be processed max_words : int, optional, the maximum number of words, default=512 Returns ------- dict, the text and whether it was truncated """ # split on whitespace with regex words = re.split(r"\s+", text) processed = {} if len(words) > max_words: processed["was_truncated"] = True processed["truncated_text"] = " ".join(words[:max_words]) else: processed["was_truncated"] = False processed["truncated_text"] = text return processed def proc_submission( input_text: str, num_beams, length_penalty, repetition_penalty, no_repeat_ngram_size, token_batch_length, max_input_length: int = 512, ): """ proc_submission - a helper function for the gradio module Parameters ---------- input_text : str, required, the text to be processed max_input_length : int, optional, the maximum length of the input text, default=512 Returns ------- str of HTML, the interactive HTML form for the model """ settings = { "length_penalty": length_penalty, "repetition_penalty": repetition_penalty, "no_repeat_ngram_size": no_repeat_ngram_size, "encoder_no_repeat_ngram_size": 4, "num_beams": num_beams, } history = {} clean_text = clean(input_text, lower=False) processed = truncate_word_count(clean_text, max_input_length) if processed["was_truncated"]: history["input_text"] = processed["truncated_text"] history["was_truncated"] = True msg = f"Input text was truncated to {max_input_length} characters." logging.warning(msg) history["WARNING"] = msg else: history["input_text"] = input_text history["was_truncated"] = False _summaries = summarize_via_tokenbatches( history["input_text"], model, tokenizer, batch_length=token_batch_length, **settings, ) sum_text = [s["summary"][0] for s in _summaries] sum_scores = [f"\n - {round(s['summary_score'],4)}" for s in _summaries] history["Input"] = input_text history["Summary Text"] = "\n\t".join(sum_text) history["Summary Scores"] = "\n".join(sum_scores) html = "" for name, item in history.items(): html += ( f"