import json

import nltk
import numpy as np
import streamlit as st
from nltk import sent_tokenize
from sentence_transformers import SentenceTransformer
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

nltk.download('punkt')

with open('testbook.json') as f:
    test_book = json.load(f)

tokenizer = AutoTokenizer.from_pretrained("UNIST-Eunchan/bart-dnc-booksum")


@st.cache_resource
def load_model(model_name):
    sentence_transformer_model = SentenceTransformer("sentence-transformers/all-roberta-large-v1")
    model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
    return sentence_transformer_model, model


sentence_transformer_model, model = load_model("UNIST-Eunchan/bart-dnc-booksum")


def infer(input_ids, max_length, temperature, top_k, top_p):
    """Sampling-based generation wired to the sidebar controls (not used by the main loop below)."""
    output_sequences = model.generate(
        input_ids=input_ids,
        max_length=max_length,
        temperature=temperature,
        top_k=top_k,
        top_p=top_p,
        do_sample=True,
        num_return_sequences=1,
        num_beams=4,
        no_repeat_ngram_size=2,
    )
    return output_sequences


def cos_similarity(v1, v2):
    """Cosine similarity between two 1-D vectors."""
    return np.dot(v1, v2) / (np.linalg.norm(v1) * np.linalg.norm(v2))


@st.cache_data
def chunking(book_text):
    """Split a book into segments of roughly 512-768 tokens.

    Segments grow greedily up to 512 tokens. Between 512 and 768 tokens, each
    new sentence is attached to whichever side it is semantically closer to
    (the current segment vs. a pseudo-segment of upcoming sentences), measured
    with sentence embeddings. Past 768 tokens the segment is cut unconditionally.
    """
    sentences = sent_tokenize(book_text)
    token_lens = [len(tokenizer.tokenize(sent)) for sent in sentences]

    segments = []
    current_segment = ""
    total_token_lens = 0
    for i in range(len(sentences)):
        if total_token_lens < 512:
            total_token_lens += token_lens[i]
            current_segment += sentences[i] + " "
        elif total_token_lens > 768:
            segments.append(current_segment)
            current_segment = sentences[i]
            total_token_lens = token_lens[i]
        else:
            # Build a pseudo-segment starting at sentence i, capped at 10 sentences / 512 tokens.
            next_pseudo_segment = ""
            next_token_len = 0
            for t in range(10):
                if i + t < len(sentences) and next_token_len + token_lens[i + t] < 512:
                    next_token_len += token_lens[i + t]
                    next_pseudo_segment += sentences[i + t] + " "
            # Attach sentence i to whichever side it is more similar to.
            embs = sentence_transformer_model.encode([current_segment, next_pseudo_segment, sentences[i]])
            if cos_similarity(embs[1], embs[2]) > cos_similarity(embs[0], embs[2]):
                segments.append(current_segment)
                current_segment = sentences[i]
                total_token_lens = token_lens[i]
            else:
                total_token_lens += token_lens[i]
                current_segment += sentences[i] + " "
    if current_segment.strip():
        segments.append(current_segment)  # keep the final, partially filled segment
    return segments


book_index = 0
_book = test_book[book_index]['book']

# Prompts
st.title("Book Summarization 📚")
st.write(
    "This demo summarizes long-form text with UNIST-Eunchan/bart-dnc-booksum, "
    "a BART model fine-tuned for book summarization on the BookSum dataset. "
    "Because the model's input window is limited to 1,024 tokens, the book is "
    "first split into semantically coherent segments of roughly 512-768 tokens "
    "using all-roberta-large-v1 sentence embeddings, and each segment is "
    "summarized independently. "
    "Dataset from the paper: BookSum: A Collection of Datasets for Long-form "
    "Narrative Summarization by Wojciech Kryściński, Nazneen Rajani, Divyansh "
    "Agarwal, Caiming Xiong and Dragomir Radev."
)
# book_index = st.sidebar.slider("Select Book Example", value=0, min_value=0, max_value=4)
sent = st.text_area("Text", _book[:512], height=550)

max_length = st.sidebar.slider("Max Length", value=512, min_value=10, max_value=1024)
temperature = st.sidebar.slider("Temperature", value=1.0, min_value=0.0, max_value=1.0, step=0.05)
top_k = st.sidebar.slider("Top-k", min_value=0, max_value=5, value=0)
top_p = st.sidebar.slider("Top-p", min_value=0.0, max_value=1.0, step=0.05, value=0.92)


def generate_output(test_samples):
    inputs = tokenizer(
        test_samples,
        padding="max_length",  # was padding=max_length, which passed an int instead of the strategy name
        truncation=True,
        max_length=1024,
        return_tensors="pt",
    )
    input_ids = inputs.input_ids
    attention_mask = inputs.attention_mask
    outputs = model.generate(
        input_ids,
        attention_mask=attention_mask,
        max_length=256,
        min_length=32,
        top_p=0.92,
        num_beams=5,
        no_repeat_ngram_size=2,
    )
    output_str = tokenizer.batch_decode(outputs, skip_special_tokens=True)
    return outputs, output_str


chunked_segments = chunking(_book)

# Summarize each segment and display the decoded text
# (was generate_output(segment) on an integer loop index, and st.write on the raw string list).
for segment in chunked_segments:
    _, output_str = generate_output(segment)
    st.write(output_str[0])
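# --- Optional second pass (a sketch, not part of the original app) ---
# The loop above prints one summary per chunk. A common follow-up is to run the
# same model once more over the concatenated chunk summaries to produce a single
# book-level summary. The names below (`chunk_summaries`, `final_summary`) are
# illustrative, and this assumes the concatenation fits the 1,024-token input
# window (generate_output truncates anything longer). Uncomment to try it:
#
# chunk_summaries = [generate_output(seg)[1][0] for seg in chunked_segments]
# _, final_summary = generate_output(" ".join(chunk_summaries))
# st.subheader("Book-level summary (second pass)")
# st.write(final_summary[0])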