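# Streamlit demo for long-document (book) summarization with the
# UNIST-Eunchan/bart-dnc-booksum checkpoint: the input book is split into
# segments of roughly 512-768 model tokens with an embedding-based chunker,
# and each segment is summarized independently.
# Typical entry point (the file name here is an assumption): streamlit run app.py
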
import json

import nltk
import numpy as np
import streamlit as st
from nltk import sent_tokenize
from sentence_transformers import SentenceTransformer
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

nltk.download('punkt')

with open('testbook.json') as f:
    test_book = json.load(f)

tokenizer = AutoTokenizer.from_pretrained("UNIST-Eunchan/bart-dnc-booksum")
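

# Load the sentence encoder (used for chunk-boundary decisions) and the BART
# summarizer once; st.cache_resource keeps them in memory across Streamlit reruns.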
@st.cache_resource
def load_model(model_name):
    sentence_transformer_model = SentenceTransformer("sentence-transformers/all-roberta-large-v1")
    model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
    return sentence_transformer_model, model


sentence_transformer_model, model = load_model("UNIST-Eunchan/bart-dnc-booksum")
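

# Sampling-based generation helper (top-k / nucleus sampling); the demo below
# calls generate_output() for its summaries.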
def infer(input_ids, max_length, temperature, top_k, top_p):
    output_sequences = model.generate(
        input_ids=input_ids,
        max_length=max_length,
        temperature=temperature,
        top_k=top_k,
        top_p=top_p,
        do_sample=True,
        num_return_sequences=1,
        no_repeat_ngram_size=2,
    )
    return output_sequences
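

# Split the book into segments of roughly 512-768 BART tokens. Sentences are
# appended until a segment reaches 512 tokens; between 512 and 768 tokens a
# boundary is placed only if the current sentence is semantically closer to the
# upcoming text (a look-ahead pseudo-segment) than to the current segment,
# measured by cosine similarity of sentence embeddings; past 768 tokens the
# segment is always closed.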
@st.cache_data
def chunking(book_text):
    sentences = sent_tokenize(book_text)
    segments = []
    token_lens = []

    def cos_similarity(v1, v2):
        dot_product = np.dot(v1, v2)
        l2_norm = np.sqrt(np.sum(np.square(v1))) * np.sqrt(np.sum(np.square(v2)))
        return dot_product / l2_norm

    # Pre-compute the BART token length of every sentence.
    for sent_i_th in sentences:
        token_lens.append(len(tokenizer.tokenize(sent_i_th)))

    current_segment = ""
    total_token_lens = 0
    for i in range(len(sentences)):
        if total_token_lens < 512:
            # Segment is still short: keep appending sentences.
            total_token_lens += token_lens[i]
            current_segment += (sentences[i] + " ")
        elif total_token_lens > 768:
            # Segment is too long: close it and start a new one.
            segments.append(current_segment)
            current_segment = sentences[i]
            total_token_lens = token_lens[i]
        else:
            # Between 512 and 768 tokens: build a look-ahead pseudo-segment from
            # the next few sentences and decide where sentence i fits best.
            next_pseudo_segment = ""
            next_token_len = 0
            for t in range(10):
                if (i + t < len(sentences)) and (next_token_len + token_lens[i + t] < 512):
                    next_token_len += token_lens[i + t]
                    next_pseudo_segment += (sentences[i + t] + " ")

            embs = sentence_transformer_model.encode([current_segment, next_pseudo_segment, sentences[i]])
            if cos_similarity(embs[1], embs[2]) > cos_similarity(embs[0], embs[2]):
                # Sentence i is closer to the upcoming text: start a new segment.
                segments.append(current_segment)
                current_segment = sentences[i]
                total_token_lens = token_lens[i]
            else:
                total_token_lens += token_lens[i]
                current_segment += (sentences[i] + " ")

    # Keep the final (possibly short) segment.
    if current_segment.strip():
        segments.append(current_segment)

    return segments
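

# Streamlit UI: display an excerpt of the sample book and expose the decoding
# parameters in the sidebar.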
book_index = 0
_book = test_book[book_index]['book'][:10000]

st.title("Book Summarization")
st.write("Book summarization using BART-BOOK!")

docu_size = len(tokenizer.tokenize(_book))
st.write(f"Document size: {docu_size} tokens")
sent = st.text_area("Text", _book, height=1500)

max_length = st.sidebar.slider("Max Length", min_value=10, max_value=1024, value=512)
temperature = st.sidebar.slider("Temperature", min_value=0.0, max_value=1.0, step=0.05, value=1.0)
top_k = st.sidebar.slider("Top-k", min_value=0, max_value=5, value=0)
top_p = st.sidebar.slider("Top-p", min_value=0.0, max_value=1.0, step=0.05, value=0.92)
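

# Tokenize one segment (truncated to the model's 1024-token input limit) and
# summarize it with beam search; returns both the raw output ids and the
# decoded summary strings.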
def generate_output(test_samples):
    inputs = tokenizer(
        test_samples,
        padding='max_length',
        truncation=True,
        max_length=1024,
        return_tensors="pt",
    )
    input_ids = inputs.input_ids
    attention_mask = inputs.attention_mask

    outputs = model.generate(
        input_ids,
        attention_mask=attention_mask,
        max_length=max_length,
        min_length=32,
        top_p=top_p,
        top_k=top_k,
        temperature=temperature,
        num_beams=2,
        no_repeat_ngram_size=2,
    )
    output_str = tokenizer.batch_decode(outputs, skip_special_tokens=True)
    return outputs, output_str
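

# Segment the full sample book (cached by st.cache_data) and report how many
# segments were produced.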
chunked_segments = chunking(test_book[0]['book'])
st.write(f"Number of segments: {len(chunked_segments)}.")
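

# Summarize up to the first 10 segments and display each summary together with
# its token count.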
for i in range(min(10, len(chunked_segments))):
    # generate_output returns (output ids, decoded strings); keep the text.
    _, summary_str = generate_output(chunked_segments[i])
    summary_text = summary_str[0]

    st.write(f"A summary of segment {i}:")
    chunk_size = len(tokenizer.tokenize(summary_text))
    st.write(f"Summary size: {chunk_size} tokens")
    st.success(summary_text)