files uploaded
- app.py +25 -0
- app_func.py +11 -0
- extraction.py +127 -0
app.py
ADDED
@@ -0,0 +1,25 @@
import streamlit as st
import os
import app_func
import extraction
# import abstraction

st.title("Document Summarizer")

sidebar = st.sidebar

page = sidebar.radio("Navigation", ["About Project", "Summarization", "Comparison"])

if page == "About Project":
    pass

elif page == "Summarization":
    uploaded_doc = st.file_uploader("Upload a document", type=["txt", "pdf"])

    if uploaded_doc is not None:
        doc_text = app_func.read_uploaded_doc(uploaded_doc)

        extractive_summary = extraction.extract_summary(doc_text)
        # st.write("Extractive Summary:", extractive_summary)
        # abstractive_summary = abstraction.generate_summary(extractive_summary)
        st.write("Summary:", extractive_summary)
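
The abstraction module referenced in the comments above is not part of this upload; a minimal sketch of what it could look like, assuming a transformers summarization pipeline (the model choice is illustrative):

    # abstraction.py (hypothetical sketch, not in this commit)
    from transformers import pipeline

    summarizer = pipeline("summarization", model="sshleifer/distilbart-cnn-12-6")

    def generate_summary(text, max_length=150, min_length=30):
        # truncation=True clips inputs that exceed the model's context window
        result = summarizer(text, max_length=max_length, min_length=min_length, truncation=True)
        return result[0]["summary_text"]
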
app_func.py
ADDED
@@ -0,0 +1,11 @@
import PyPDF2

def read_uploaded_doc(uploaded_doc):
    if uploaded_doc.type == "application/pdf":
        reader = PyPDF2.PdfReader(uploaded_doc)
        text = ""
        for page in reader.pages:
            text += page.extract_text() or ""  # pages with no extractable text contribute nothing
        return text
    else:  # assume a plain-text upload (app.py also accepts .txt)
        return str(uploaded_doc.read(), "utf-8")
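
For testing the helper outside Streamlit, any file-like object carrying a type attribute will do; a minimal stand-in (hypothetical, for illustration only):

    import io

    class FakeUpload(io.BytesIO):
        # Mimics the two things read_uploaded_doc relies on: read() and .type
        def __init__(self, data, mime):
            super().__init__(data)
            self.type = mime

    read_uploaded_doc(FakeUpload(b"hello world", "text/plain"))  # -> 'hello world'
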
extraction.py
ADDED
@@ -0,0 +1,127 @@
from transformers import AutoModel, AutoTokenizer
import torch
from tqdm import tqdm
import os
import spacy
import certifi

# Point Python's SSL at certifi's CA bundle so the model download works
# in environments with an incomplete certificate store.
os.environ['SSL_CERT_FILE'] = certifi.where()

nlp = spacy.load("en_core_web_lg")

model_name = "microsoft/MiniLM-L12-H384-uncased"

tokenizer = AutoTokenizer.from_pretrained(model_name)

def mean_pooling(model_output, attention_mask):
    # Average the token embeddings, counting only non-padding positions.
    token_embeddings = model_output[0]
    input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
    sum_embeddings = torch.sum(token_embeddings * input_mask_expanded, 1)
    sum_mask = torch.clamp(input_mask_expanded.sum(1), min=1e-9)  # avoid division by zero
    return sum_embeddings / sum_mask

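A quick sanity check of the masking arithmetic (toy tensors, not MiniLM's real shapes):

    # Batch of 1, three tokens (the last is padding), hidden size 2.
    emb = torch.tensor([[[1.0, 2.0], [3.0, 4.0], [9.0, 9.0]]])
    mask = torch.tensor([[1, 1, 0]])
    mean_pooling((emb,), mask)  # tensor([[2., 3.]]) -- the padded token is ignored
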
class SentenceBERTClassifier(torch.nn.Module):
    def __init__(self, model_name="microsoft/MiniLM-L12-H384-uncased", input_dim=384):
        super(SentenceBERTClassifier, self).__init__()
        self.model = AutoModel.from_pretrained(model_name)
        # The head sees [sentence, document, sentence*document], hence input_dim * 3.
        self.dense1 = torch.nn.Linear(input_dim * 3, 768)
        self.relu1 = torch.nn.ReLU()
        self.dropout1 = torch.nn.Dropout(0.1)
        self.dense2 = torch.nn.Linear(768, 384)
        self.relu2 = torch.nn.ReLU()
        self.dropout2 = torch.nn.Dropout(0.1)
        self.classifier = torch.nn.Linear(384, 1)
        self.sigmoid = torch.nn.Sigmoid()

    def forward(self, sent_ids, doc_ids, sent_mask, doc_mask):
        # Embed the candidate sentence and the full document with the same encoder.
        sent_output = self.model(input_ids=sent_ids, attention_mask=sent_mask)
        sent_embedding = mean_pooling(sent_output, sent_mask)

        doc_output = self.model(input_ids=doc_ids, attention_mask=doc_mask)
        doc_embedding = mean_pooling(doc_output, doc_mask)

        # SBERT-style feature: (u, v, u*v) concatenated along the feature axis.
        combined_embedding = sent_embedding * doc_embedding
        concat_embedding = torch.cat((sent_embedding, doc_embedding, combined_embedding), dim=1)

        dense_output1 = self.dense1(concat_embedding)
        relu_output1 = self.relu1(dense_output1)
        dropout_output1 = self.dropout1(relu_output1)
        dense_output2 = self.dense2(dropout_output1)
        relu_output2 = self.relu2(dense_output2)
        dropout_output2 = self.dropout2(relu_output2)
        logits = self.classifier(dropout_output2)
        probs = self.sigmoid(logits)
        return probs

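The head maps each (sentence, document) pair to one relevance probability; a shape sketch with random token ids (illustrative only; instantiating the model downloads MiniLM):

    ids = torch.randint(0, 1000, (2, 512))
    mask = torch.ones(2, 512, dtype=torch.long)
    SentenceBERTClassifier()(ids, ids, mask, mask).shape  # torch.Size([2, 1])
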
device = 'cuda' if torch.cuda.is_available() else 'cpu'

extractive_model = SentenceBERTClassifier(model_name=model_name)
# Forward slashes keep the checkpoint path portable; the original backslash
# ("model_path\minilm...") would not resolve on non-Windows platforms.
extractive_model.load_state_dict(torch.load("model_path/minilm_bal_exsum.pth", map_location=torch.device(device)))
extractive_model.eval()

def get_tokens(text, tokenizer):
    # Tokenize a list of strings, padding/truncating every item to 512 tokens.
    inputs = tokenizer.batch_encode_plus(
        text,
        add_special_tokens=True,
        max_length=512,
        padding="max_length",
        return_token_type_ids=True,
        truncation=True,
        return_tensors="pt")

    ids = inputs["input_ids"]
    mask = inputs["attention_mask"]

    return ids, mask

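With return_tensors="pt" the tokenizer hands back LongTensors directly, so callers receive model-ready inputs:

    ids, mask = get_tokens(["An example sentence.", "Another one."], tokenizer)
    ids.shape, mask.shape  # (torch.Size([2, 512]), torch.Size([2, 512]))
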
# Predict the relevance score of each sentence in a document.
def predict(model, sents, doc):
    sent_id, sent_mask = get_tokens(sents, tokenizer)

    # Encode the document once, then repeat it so each sentence is paired with it.
    doc_id, doc_mask = get_tokens([doc], tokenizer)
    doc_id, doc_mask = doc_id.repeat(len(sents), 1), doc_mask.repeat(len(sents), 1)

    # Guard against out-of-vocabulary ids by mapping them to the unk token.
    sent_id[sent_id >= tokenizer.vocab_size] = tokenizer.unk_token_id
    doc_id[doc_id >= tokenizer.vocab_size] = tokenizer.unk_token_id

    # Inference only: no gradients needed.
    with torch.no_grad():
        preds = model(sent_id, doc_id, sent_mask, doc_mask)
    return preds

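Each score is a sigmoid output in (0, 1), one row per candidate sentence; assuming the checkpoint loaded above, a call looks like:

    doc = "First sentence here. A second, longer sentence follows it."
    predict(extractive_model, ["First sentence here."], doc)  # tensor of shape (1, 1)
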
def extract_summary(doc, model=extractive_model, min_sentence_length=14, top_k=4, batch_size=4):
    # Replace newlines with spaces (stripping them outright would fuse words
    # across line breaks in PDF-extracted text).
    doc = doc.replace("\n", " ")
    doc_sentences = []
    for sent in nlp(doc).sents:
        # Keep only sentences longer than min_sentence_length tokens.
        if len(sent) > min_sentence_length:
            doc_sentences.append(str(sent))

    scores = []
    # Score the sentences in batches.
    for i in tqdm(range(int(len(doc_sentences) / batch_size) + 1)):
        batch_start = i * batch_size
        batch_end = min((i + 1) * batch_size, len(doc_sentences))
        batch = doc_sentences[batch_start:batch_end]
        if batch:
            preds = predict(model, batch, doc)
            scores = scores + preds.tolist()

    sent_pred_list = [{"sentence": doc_sentences[i], "score": scores[i][0], "index": i} for i in range(len(doc_sentences))]
    sorted_sentences = sorted(sent_pred_list, key=lambda k: k['score'], reverse=True)

    # Take the top_k highest-scoring sentences, then restore document order.
    sorted_result = sorted_sentences[:top_k]
    sorted_result = sorted(sorted_result, key=lambda k: k['index'])

    summary = [x["sentence"] for x in sorted_result]
    summary = " ".join(summary)

    return summary
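
End-to-end usage outside the app (hypothetical filename; assumes en_core_web_lg and the checkpoint are installed):

    text = open("article.txt", encoding="utf-8").read()
    print(extract_summary(text, top_k=4))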