from transformers import AutoModel, AutoTokenizer
import torch
from tqdm import tqdm
from torch.utils.data import Dataset, DataLoader
import os
import spacy
import certifi
import streamlit as st

# Point the SSL stack at certifi's CA bundle so that model downloads succeed
# in environments with certificate issues.
os.environ['SSL_CERT_FILE'] = certifi.where()

# spaCy is used only for sentence segmentation
# (requires: python -m spacy download en_core_web_lg).
nlp = spacy.load("en_core_web_lg")

# Lightweight MiniLM backbone used as the shared sentence/document encoder.
model_name = "microsoft/MiniLM-L12-H384-uncased"

tokenizer = AutoTokenizer.from_pretrained(model_name)

def mean_pooling(model_output, attention_mask):
    """Mean-pool token embeddings, ignoring padding positions via the attention mask."""
    token_embeddings = model_output[0]  # (batch, seq_len, hidden)
    input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
    sum_embeddings = torch.sum(token_embeddings * input_mask_expanded, 1)
    sum_mask = torch.clamp(input_mask_expanded.sum(1), min=1e-9)  # avoid division by zero
    return sum_embeddings / sum_mask

class SentenceBERTClassifier(torch.nn.Module):
  """Scores how relevant a sentence is to its source document: sentence and
  document are encoded with a shared MiniLM backbone, and their mean-pooled
  embeddings plus their element-wise product feed a small MLP with a
  sigmoid output."""
  def __init__(self, model_name="microsoft/MiniLM-L12-H384-uncased", input_dim=384):
    super(SentenceBERTClassifier, self).__init__()
    self.model = AutoModel.from_pretrained(model_name)
    self.dense1 = torch.nn.Linear(input_dim*3, 768)  # 3x: sentence, document, and their element-wise product
    self.relu1 = torch.nn.ReLU()
    self.dropout1 = torch.nn.Dropout(0.1)
    self.dense2 = torch.nn.Linear(768, 384)
    self.relu2 = torch.nn.ReLU()
    self.dropout2 = torch.nn.Dropout(0.1)
    self.classifier = torch.nn.Linear(384, 1)
    self.sigmoid = torch.nn.Sigmoid()

  def forward(self, sent_ids, doc_ids, sent_mask, doc_mask):
    sent_output = self.model(input_ids=sent_ids, attention_mask=sent_mask)
    sent_embedding = mean_pooling(sent_output, sent_mask)

    doc_output = self.model(input_ids=doc_ids, attention_mask=doc_mask)
    doc_embedding = mean_pooling(doc_output, doc_mask)

    # Concatenate the sentence embedding, the document embedding, and their
    # element-wise product as the interaction feature.
    combined_embedding = sent_embedding * doc_embedding
    concat_embedding = torch.cat((sent_embedding, doc_embedding, combined_embedding), dim=1)

    dense_output1 = self.dense1(concat_embedding)
    relu_output1 = self.relu1(dense_output1)
    dropout_output1 = self.dropout1(relu_output1)
    dense_output2 = self.dense2(dropout_output1)
    relu_output2 = self.relu2(dense_output2)
    dropout_output2 = self.dropout2(relu_output2)
    logits = self.classifier(dropout_output2)
    probs = self.sigmoid(logits)
    return probs
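
# Dimension check (derived from the layers above): each MiniLM embedding is
# 384-dimensional, so concat_embedding has 3 * 384 = 1152 features, which
# dense1 maps to 768, dense2 to 384, and the final classifier to a single
# sigmoid probability per (sentence, document) pair.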

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Load the trained extractive-summarization checkpoint (a forward slash keeps
# the path portable across operating systems) and put the model in eval mode.
extractive_model = SentenceBERTClassifier(model_name=model_name)
extractive_model.load_state_dict(torch.load("model_path/minilm_bal_exsum.pth", map_location=torch.device(device)))
extractive_model.to(device)
extractive_model.eval()

def get_tokens(text, tokenizer):
  """Tokenize a list of strings into fixed-length (512-token) ids and attention masks."""
  inputs = tokenizer.batch_encode_plus(
      text,
      add_special_tokens=True,
      max_length=512,
      padding="max_length",
      return_token_type_ids=True,
      truncation=True,
      return_tensors="pt")

  ids = inputs["input_ids"]
  mask = inputs["attention_mask"]

  return ids, mask

# Predict the relevance score of each sentence with respect to the full document
def predict(model, sents, doc):
  # get_tokens already returns long tensors, so no re-wrapping with
  # torch.tensor is needed; just move everything to the model's device.
  sent_id, sent_mask = get_tokens(sents, tokenizer)
  sent_id, sent_mask = sent_id.to(device), sent_mask.to(device)

  # Encode the document once and repeat it for every sentence in the batch.
  doc_id, doc_mask = get_tokens([doc], tokenizer)
  doc_id, doc_mask = doc_id.repeat(len(sents), 1), doc_mask.repeat(len(sents), 1)
  doc_id, doc_mask = doc_id.to(device), doc_mask.to(device)

  # Replace any out-of-vocabulary ids with the [UNK] token id before passing
  # them to the model.
  sent_id[sent_id >= tokenizer.vocab_size] = tokenizer.unk_token_id
  doc_id[doc_id >= tokenizer.vocab_size] = tokenizer.unk_token_id

  with torch.no_grad():
    preds = model(sent_id, doc_id, sent_mask, doc_mask)
  return preds

def extract_summary(doc, model=extractive_model, min_sentence_length=14, top_k=4, batch_size=4):
  # Normalize newlines to spaces so sentence segmentation is not thrown off
  # by hard line wraps.
  doc = doc.replace("\n", " ")

  # Keep only sentences longer than min_sentence_length tokens.
  doc_sentences = []
  for sent in nlp(doc).sents:
    if len(sent) > min_sentence_length:
      doc_sentences.append(str(sent))

  # Score the sentences in batches against the full document.
  scores = []
  for batch_start in tqdm(range(0, len(doc_sentences), batch_size)):
    batch = doc_sentences[batch_start: batch_start + batch_size]
    preds = predict(model, batch, doc)
    scores.extend(preds.tolist())

  # Pair each sentence with its predicted score and its original position.
  sent_pred_list = [{"sentence": doc_sentences[i], "score": scores[i][0], "index": i} for i in range(len(doc_sentences))]

  # Keep the top_k highest-scoring sentences, then restore document order.
  sorted_sentences = sorted(sent_pred_list, key=lambda k: k['score'], reverse=True)
  sorted_result = sorted(sorted_sentences[:top_k], key=lambda k: k['index'])

  summary = " ".join(x["sentence"] for x in sorted_result)

  return summary
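
# --- Example usage (illustrative sketch, not part of the original script) ---
# The streamlit import above suggests this module backs a simple web UI; the
# wiring below is one possible way to hook it up, and the widget labels are
# made up for illustration. Run with: streamlit run <this_file>.py
if __name__ == "__main__":
  st.title("Extractive Summarizer")
  input_text = st.text_area("Paste a document to summarize")
  if st.button("Summarize") and input_text.strip():
    with st.spinner("Scoring sentences..."):
      st.write(extract_summary(input_text))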