# Streamlit demo: book summarization with the UNIST-Eunchan/bart-dnc-booksum model.
import transformers
import streamlit as st
import nltk
from nltk import sent_tokenize
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import json
import numpy as np
from sentence_transformers import SentenceTransformer

# Sentence tokenizer model required by sent_tokenize().
nltk.download('punkt')

# Example books used by the demo.
with open('testbook.json') as f:
    test_book = json.load(f)

tokenizer = AutoTokenizer.from_pretrained("UNIST-Eunchan/bart-dnc-booksum")

@st.cache_resource
def load_model(model_name):
    # Cache the heavy models so they are loaded only once per session.
    nltk.download('punkt')
    sentence_transformer_model = SentenceTransformer("sentence-transformers/all-roberta-large-v1")
    model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
    return sentence_transformer_model, model

sentence_transformer_model, model = load_model("UNIST-Eunchan/bart-dnc-booksum")

def infer(input_ids, max_length, temperature, top_k, top_p):
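    # Sampling-based generation helper; the UI below uses generate_output() instead.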

    output_sequences = model.generate(
        input_ids=input_ids,
        max_length=max_length,
        temperature=temperature,
        top_k=top_k,
        top_p=top_p,
        do_sample=True,
        num_return_sequences=1,
        #num_beams=4,
        no_repeat_ngram_size=2
    )
    
    return output_sequences


@st.cache_data
def chunking(book_text):
    # Split the book into segments of roughly 512-768 tokens. Boundaries are chosen
    # greedily by length; in the 512-768 token range, sentence-embedding similarity
    # decides whether the current sentence stays in this segment or starts the next.
    sentences = sent_tokenize(book_text)
    segments = []
    token_lens = []

    def cos_similarity(v1, v2):
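        # Cosine similarity between two embedding vectors.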
        dot_product = np.dot(v1, v2)
        l2_norm = (np.sqrt(sum(np.square(v1))) * np.sqrt(sum(np.square(v2))))
        similarity = dot_product / l2_norm     
        
        return similarity

    # Token length of each sentence, measured with the summarization tokenizer.
    for sent_i_th in sentences:
        token_lens.append(len(tokenizer.tokenize(sent_i_th)))

    current_segment = ""
    total_token_lens = 0
    for i in range(len(sentences)):

        if total_token_lens < 512:
            # Segment is still short: always absorb the next sentence.
            total_token_lens += token_lens[i]
            current_segment += (sentences[i] + " ")

        elif total_token_lens > 768:
            # Segment is too long: close it and start a new one with this sentence.
            segments.append(current_segment)
            current_segment = sentences[i] + " "
            total_token_lens = token_lens[i]

        else:
            # Segment is in the 512-768 token range: build a pseudo-segment from the
            # next few sentences and compare embeddings to decide whether sentence i
            # belongs to the current segment or to the upcoming one.
            next_pseudo_segment = ""
            next_token_len = 0
            for t in range(10):
                if (i + t < len(sentences)) and (next_token_len + token_lens[i + t] < 512):
                    next_token_len += token_lens[i + t]
                    next_pseudo_segment += sentences[i + t] + " "
                
            embs = sentence_transformer_model.encode([current_segment, next_pseudo_segment, sentences[i]])  # current, next, sentence i
            if cos_similarity(embs[1], embs[2]) > cos_similarity(embs[0], embs[2]):
                # Sentence i is closer to the upcoming text: start a new segment with it.
                segments.append(current_segment)
                current_segment = sentences[i] + " "
                total_token_lens = token_lens[i]
            else:
                total_token_lens += token_lens[i]
                current_segment += (sentences[i] + " ")

    # Flush the final partially filled segment so the tail of the book is not dropped.
    if current_segment.strip():
        segments.append(current_segment)

    return segments


# Use the first example book; only its first 10,000 characters are shown in the text box.
book_index = 0
_book = test_book[book_index]['book'][:10000]

# UI prompts
st.title("Book Summarization 📚")
st.write("Book Summarization using BART-BOOK!")
#book_index = st.sidebar.slider("Select Book Example", value = 0, min_value = 0, max_value=4)

docu_size = len(tokenizer.tokenize(_book))
st.write(f"Document size: {docu_size} tokens")
sent = st.text_area("Text", _book, height = 1500)
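# Sidebar generation settings; these values are read by generate_output() below.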
max_length = st.sidebar.slider("Max Length", value = 512,min_value = 10, max_value=1024)
temperature = st.sidebar.slider("Temperature", value = 1.0, min_value = 0.0, max_value=1.0, step=0.05)
top_k = st.sidebar.slider("Top-k", min_value = 0, max_value=5, value = 0)
top_p = st.sidebar.slider("Top-p", min_value = 0.0, max_value=1.0, step = 0.05, value = 0.92)




def generate_output(test_samples):
    # Tokenize one segment and summarize it with the current sidebar settings.
    inputs = tokenizer(
        test_samples,
        padding='max_length',
        truncation=True,
        max_length=1024,
        return_tensors="pt",
    )

    input_ids = inputs.input_ids
    attention_mask = inputs.attention_mask

    # Generate with beam search (num_beams=2). Note that top_k/top_p/temperature
    # only influence decoding when do_sample=True is passed to generate().
    outputs = model.generate(input_ids,
                             max_length = max_length,
                             min_length=32,
                             top_p = top_p,
                             top_k = top_k,
                             temperature= temperature,
                             num_beams=2,
                             no_repeat_ngram_size=2,
                             attention_mask=attention_mask)
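    # Decode the generated token IDs back into text, skipping special tokens.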
    output_str = tokenizer.batch_decode(outputs, skip_special_tokens=True)
    return outputs, output_str


# Segment the full example book (not just the 10,000-character preview shown above).
chunked_segments = chunking(test_book[0]['book'])
st.write(f'The number of segments: {len(chunked_segments)}.')

# Summarize the first 10 segments (or all of them, if there are fewer than 10).
#for i in range(len(chunked_segments)):
for i in range(min(10, len(chunked_segments))):

    outputs, output_str = generate_output(chunked_segments[i])
    summary = output_str[0]
    st.write(f'Summary of segment {i}:')

    summary_size = len(tokenizer.tokenize(summary))
    st.write(f"Summary size: {summary_size} tokens")
    st.success(summary)