SorterSon committed on
Commit
9dca909
·
verified ·
1 Parent(s): aca07c4

files uploaded

Browse files
Files changed (3) hide show
  1. app.py +25 -0
  2. app_func.py +11 -0
  3. extraction.py +127 -0
app.py ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
"""Streamlit front-end for the document summarizer."""
import streamlit as st
import os
import app_func
import extraction
# TODO: abstractive stage is not wired up yet.
# import abstraction

st.title("Document Summarizer")

sidebar = st.sidebar

# NOTE(review): fixed label typo "Comparision" -> "Comparison"; no branch
# below compared against the old spelling, so this is a safe UI-only fix.
page = sidebar.radio("Navigation", ["About Project", "Summarization", "Comparison"])

if page == "About Project":
    # Placeholder: about page content not implemented yet.
    pass

elif page == "Summarization":
    uploaded_doc = st.file_uploader("Upload a document", type=["txt", "pdf"])

    if uploaded_doc is not None:
        doc_text = app_func.read_uploaded_doc(uploaded_doc)

        # Extractive pass; abstractive refinement is commented out for now.
        extractive_summary = extraction.extract_summary(doc_text)
        # abstractive_summary = abstraction.generate_summary(extractive_summary)
        st.write("Summary:", extractive_summary)
app_func.py ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
import PyPDF2


def read_uploaded_doc(uploaded_doc):
    """Extract the plain text of every page of an uploaded PDF.

    Args:
        uploaded_doc: file-like object (e.g. from st.file_uploader).
            Currently always treated as a PDF — assumes .txt uploads are
            not yet supported; TODO confirm against app.py's uploader types.

    Returns:
        str: concatenated text of all pages (may be "" for image-only PDFs).
    """
    # if uploaded_doc.type == "application/pdf":
    reader = PyPDF2.PdfReader(uploaded_doc)
    text = ""
    for page in reader.pages:
        # extract_text() returns None for pages with no text layer;
        # `or ""` prevents `str += None` from raising TypeError.
        text += page.extract_text() or ""
    return text
    # else: # Assuming text file
    # return str(uploaded_doc.read(), "utf-8")
extraction.py ADDED
@@ -0,0 +1,127 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
"""Extractive summarization: score sentences with a MiniLM classifier."""
from transformers import AutoModel, AutoTokenizer
import torch
from tqdm import tqdm
from torch.utils.data import Dataset, DataLoader
import os
import spacy
import certifi
import streamlit as st

# Point SSL verification at certifi's CA bundle so model downloads work
# in environments with incomplete system certificates.
os.environ['SSL_CERT_FILE'] = certifi.where()

# spaCy pipeline used only for sentence segmentation in extract_summary().
nlp = spacy.load("en_core_web_lg")

# Encoder checkpoint shared by the tokenizer and SentenceBERTClassifier.
model_name = "microsoft/MiniLM-L12-H384-uncased"

tokenizer = AutoTokenizer.from_pretrained(model_name)
18
def mean_pooling(model_output, attention_mask):
    """Mask-aware mean of token embeddings.

    Args:
        model_output: transformer output; index [0] holds the token
            embeddings of shape (batch, seq_len, dim).
        attention_mask: (batch, seq_len) tensor; 1 for real tokens,
            0 for padding.

    Returns:
        (batch, dim) tensor: per-sequence average over unmasked tokens.
    """
    embeddings = model_output[0]
    mask = attention_mask.unsqueeze(-1).expand(embeddings.size()).float()
    summed = (embeddings * mask).sum(dim=1)
    # Clamp avoids division by zero for fully-masked rows.
    counts = mask.sum(dim=1).clamp(min=1e-9)
    return summed / counts
24
+
25
class SentenceBERTClassifier(torch.nn.Module):
    """Scores sentence relevance against its document with a shared MiniLM
    encoder and a small MLP head ending in a sigmoid.

    NOTE: attribute names (model, dense1, relu1, ...) must stay exactly as
    they are — the saved checkpoint's state_dict keys depend on them.
    """

    def __init__(self, model_name="microsoft/MiniLM-L12-H384-uncased", input_dim=384):
        super(SentenceBERTClassifier, self).__init__()
        self.model = AutoModel.from_pretrained(model_name)
        # Head consumes [sentence; document; sentence*document] -> 3 * input_dim.
        self.dense1 = torch.nn.Linear(input_dim * 3, 768)
        self.relu1 = torch.nn.ReLU()
        self.dropout1 = torch.nn.Dropout(0.1)
        self.dense2 = torch.nn.Linear(768, 384)
        self.relu2 = torch.nn.ReLU()
        self.dropout2 = torch.nn.Dropout(0.1)
        self.classifier = torch.nn.Linear(384, 1)
        self.sigmoid = torch.nn.Sigmoid()

    def forward(self, sent_ids, doc_ids, sent_mask, doc_mask):
        """Return (batch, 1) relevance probabilities in (0, 1)."""
        # Encode sentence and document with the same backbone, then
        # mean-pool each over its attention mask.
        sent_vec = mean_pooling(self.model(input_ids=sent_ids, attention_mask=sent_mask), sent_mask)
        doc_vec = mean_pooling(self.model(input_ids=doc_ids, attention_mask=doc_mask), doc_mask)

        # Concatenate both embeddings with their element-wise interaction.
        features = torch.cat((sent_vec, doc_vec, sent_vec * doc_vec), dim=1)

        hidden = self.dropout1(self.relu1(self.dense1(features)))
        hidden = self.dropout2(self.relu2(self.dense2(hidden)))
        return self.sigmoid(self.classifier(hidden))
58
+
59
# Use the GPU when present; checkpoint tensors are mapped onto this device.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

extractive_model = SentenceBERTClassifier(model_name=model_name)
# os.path.join keeps the checkpoint path portable: the original literal
# "model_path\minilm_bal_exsum.pth" relied on a Windows backslash (and an
# un-escaped backslash in a non-raw string) and fails on Linux/macOS.
extractive_model.load_state_dict(
    torch.load(os.path.join("model_path", "minilm_bal_exsum.pth"),
               map_location=torch.device(device))
)
# Inference-only: disables dropout layers in the scoring head.
extractive_model.eval()
64
+
65
def get_tokens(text, tokenizer):
    """Tokenize a list of strings to fixed-length id/mask pairs.

    Args:
        text: list of strings to encode.
        tokenizer: HuggingFace tokenizer (or compatible object) exposing
            batch_encode_plus.

    Returns:
        (input_ids, attention_mask) padded/truncated to 512 tokens;
        PyTorch tensors because return_tensors="pt" is requested.
    """
    encoded = tokenizer.batch_encode_plus(
        text,
        add_special_tokens=True,
        max_length=512,
        padding="max_length",
        return_token_type_ids=True,
        truncation=True,
        return_tensors="pt",
    )
    return encoded["input_ids"], encoded["attention_mask"]
79
+
80
# Predicting the relevance scores of sentences in a document
def predict(model, sents, doc):
    """Score each sentence in `sents` against the full document `doc`.

    Args:
        model: SentenceBERTClassifier (or compatible forward signature).
        sents: list of sentence strings.
        doc: the whole document as one string.

    Returns:
        (len(sents), 1) tensor of relevance probabilities.
    """
    sent_id, sent_mask = get_tokens(sents, tokenizer)
    # get_tokens already returns torch tensors (return_tensors="pt");
    # .long() converts dtype in place of the original
    # torch.tensor(existing_tensor), which torch warns against and which
    # forced a needless copy.
    sent_id, sent_mask = sent_id.long(), sent_mask.long()

    doc_id, doc_mask = get_tokens([doc], tokenizer)
    # Repeat the single document row so each sentence pairs with it.
    doc_id, doc_mask = doc_id.repeat(len(sents), 1), doc_mask.repeat(len(sents), 1)
    doc_id, doc_mask = doc_id.long(), doc_mask.long()

    # Replace out-of-vocabulary ids with [UNK] before the forward pass.
    sent_id[sent_id >= tokenizer.vocab_size] = tokenizer.unk_token_id
    doc_id[doc_id >= tokenizer.vocab_size] = tokenizer.unk_token_id

    preds = model(sent_id, doc_id, sent_mask, doc_mask)
    return preds
96
+
97
def extract_summary(doc, model=extractive_model, min_sentence_length=14, top_k=4, batch_size=4):
    """Build an extractive summary: score sentences, keep the top_k, and
    return them joined in original document order.

    Args:
        doc: full document text.
        model: sentence scorer (defaults to the module-level checkpoint).
        min_sentence_length: minimum sentence length (spaCy token count) to
            be considered.
        top_k: number of sentences in the summary.
        batch_size: sentences scored per model call.

    Returns:
        str: the selected sentences joined with spaces.
    """
    # Join lines with a space: the original replace("\n", "") glued the
    # last word of one line to the first word of the next, corrupting
    # sentence segmentation and tokenization.
    doc = doc.replace("\n", " ")

    # Keep only sentences longer than min_sentence_length tokens.
    doc_sentences = []
    for sent in nlp(doc).sents:
        if len(sent) > min_sentence_length:
            doc_sentences.append(str(sent))

    scores = []
    # Stepping by batch_size replaces the original fragile
    # int(len/batch)+1 index arithmetic; every slice is non-empty.
    for batch_start in tqdm(range(0, len(doc_sentences), batch_size)):
        batch = doc_sentences[batch_start:batch_start + batch_size]
        preds = predict(model, batch, doc)
        scores.extend(preds.tolist())

    # Rank sentences by score, then restore document order for the top_k.
    sent_pred_list = [
        {"sentence": sentence, "score": scores[i][0], "index": i}
        for i, sentence in enumerate(doc_sentences)
    ]
    sorted_sentences = sorted(sent_pred_list, key=lambda k: k['score'], reverse=True)

    sorted_result = sorted(sorted_sentences[:top_k], key=lambda k: k['index'])

    return " ".join(x["sentence"] for x in sorted_result)