|
from transformers import AutoModel, AutoTokenizer
|
|
import torch
|
|
from tqdm import tqdm
|
|
from torch.utils.data import Dataset, DataLoader
|
|
import os
|
|
import spacy
|
|
import certifi
|
|
import streamlit as st
|
|
|
|
# Export certifi's CA bundle path through SSL_CERT_FILE; this is done before
# the loaders below run.
os.environ["SSL_CERT_FILE"] = certifi.where()

# spaCy pipeline, loaded once at import time.
nlp = spacy.load("en_core_web_lg")

# Hugging Face checkpoint name; the tokenizer is loaded from the same name.
model_name = "microsoft/MiniLM-L12-H384-uncased"
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
|
|
|
def mean_pooling(model_output, attention_mask):
    """Average the token embeddings of each sequence, ignoring masked positions.

    Args:
        model_output: Indexable encoder output whose element ``0`` holds the
            per-token embeddings (assumed ``(batch, seq_len, hidden)`` — the
            reduction below runs over axis 1).
        attention_mask: Tensor with one entry per token; positions holding 0
            contribute nothing to the average.

    Returns:
        One pooled embedding per sequence.
    """
    hidden_states = model_output[0]

    # Stretch the mask over the embedding axis so it can weight every feature.
    weights = attention_mask.unsqueeze(-1).expand(hidden_states.size()).float()

    masked_total = (hidden_states * weights).sum(dim=1)
    # Clamp the denominator so a fully masked row divides by 1e-9, not by 0.
    token_counts = weights.sum(dim=1).clamp(min=1e-9)

    return masked_total / token_counts
|
|
|
|
class SentenceBERTClassifier(torch.nn.Module):
    """Scores a (sentence, document) pair with a shared transformer encoder.

    Both inputs go through the same encoder and are mean-pooled. The two
    pooled vectors and their element-wise product are concatenated and fed
    through a two-layer MLP that ends in a sigmoid.
    """

    def __init__(self, model_name="microsoft/MiniLM-L12-H384-uncased", input_dim=384):
        """Load the encoder and build the classification head.

        Args:
            model_name: Checkpoint name passed to ``AutoModel.from_pretrained``.
            input_dim: Width of one pooled embedding; the first dense layer
                takes three of them concatenated.
        """
        super().__init__()
        self.model = AutoModel.from_pretrained(model_name)

        # Attribute names and creation order are kept as-is: they determine
        # the state_dict keys of saved checkpoints.
        self.dense1 = torch.nn.Linear(input_dim * 3, 768)
        self.relu1 = torch.nn.ReLU()
        self.dropout1 = torch.nn.Dropout(0.1)

        self.dense2 = torch.nn.Linear(768, 384)
        self.relu2 = torch.nn.ReLU()
        self.dropout2 = torch.nn.Dropout(0.1)

        self.classifier = torch.nn.Linear(384, 1)
        self.sigmoid = torch.nn.Sigmoid()

    def _encode(self, ids, mask):
        """Run the encoder on one input and mean-pool its token embeddings."""
        encoder_output = self.model(input_ids=ids, attention_mask=mask)
        return mean_pooling(encoder_output, mask)

    def forward(self, sent_ids, doc_ids, sent_mask, doc_mask):
        """Return the sigmoid score for each sentence/document pair.

        Args:
            sent_ids: Token ids of the sentences.
            doc_ids: Token ids of the documents, row-aligned with ``sent_ids``.
            sent_mask: Attention mask for ``sent_ids``.
            doc_mask: Attention mask for ``doc_ids``.

        Returns:
            Tensor of probabilities with a trailing dimension of 1.
        """
        sent_vec = self._encode(sent_ids, sent_mask)
        doc_vec = self._encode(doc_ids, doc_mask)

        # Features: both embeddings plus their element-wise interaction.
        features = torch.cat((sent_vec, doc_vec, sent_vec * doc_vec), dim=1)

        hidden = self.dropout1(self.relu1(self.dense1(features)))
        hidden = self.dropout2(self.relu2(self.dense2(hidden)))

        return self.sigmoid(self.classifier(hidden))
|
|
|
|
# Use CUDA when torch reports it as available, otherwise the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

extractive_model = SentenceBERTClassifier(model_name=model_name)

# Build the checkpoint path with os.path.join. The previous hard-coded
# backslash literal contained the invalid escape sequence "\m" and only
# pointed into the model_path directory on Windows.
checkpoint_path = os.path.join("model_path", "minilm_bal_exsum.pth")

# NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
# NOTE(review): map_location only decides where the checkpoint tensors land.
# The module is never moved with .to(device), so it appears to stay on the
# CPU — confirm this is intended.
extractive_model.load_state_dict(
    torch.load(checkpoint_path, map_location=torch.device(device))
)

# Inference mode: disables dropout in the classification head.
extractive_model.eval()
|
|
|
|
def get_tokens(text, tokenizer):
    """Encode a batch of texts and return ``(input_ids, attention_mask)``.

    Args:
        text: Batch of strings, passed through to the tokenizer unchanged.
        tokenizer: Object exposing ``batch_encode_plus``.

    Returns:
        Tuple of the ``"input_ids"`` and ``"attention_mask"`` entries of the
        tokenizer output.
    """
    # Every input is padded or truncated to exactly 512 positions, and the
    # tokenizer is asked for PyTorch tensors.
    encode_options = {
        "add_special_tokens": True,
        "max_length": 512,
        "padding": "max_length",
        "return_token_type_ids": True,
        "truncation": True,
        "return_tensors": "pt",
    }
    encoded = tokenizer.batch_encode_plus(text, **encode_options)

    return encoded["input_ids"], encoded["attention_mask"]
|
|
|
|
|
|
def predict(model, sents, doc):
    """Score every sentence in ``sents`` against the document ``doc``.

    Args:
        model: Callable invoked as
            ``model(sent_ids, doc_ids, sent_mask, doc_mask)``.
        sents: List of sentence strings forming one batch.
        doc: Full document text; it is encoded once and its encoding is
            repeated so each sentence row has a matching document row.

    Returns:
        Whatever ``model`` returns for the batch, computed without gradient
        tracking.
    """
    sent_id, sent_mask = get_tokens(sents, tokenizer)
    doc_id, doc_mask = get_tokens([doc], tokenizer)

    # as_tensor reuses an existing tensor instead of copy-constructing it,
    # which avoids the UserWarning torch.tensor(tensor) raises.
    sent_id = torch.as_tensor(sent_id, dtype=torch.long)
    sent_mask = torch.as_tensor(sent_mask, dtype=torch.long)

    # Tile the single document encoding to one row per sentence.
    row_count = len(sents)
    doc_id = torch.as_tensor(doc_id, dtype=torch.long).repeat(row_count, 1)
    doc_mask = torch.as_tensor(doc_mask, dtype=torch.long).repeat(row_count, 1)

    # Map any id at or beyond the tokenizer's vocab_size to the unknown token.
    sent_id[sent_id >= tokenizer.vocab_size] = tokenizer.unk_token_id
    doc_id[doc_id >= tokenizer.vocab_size] = tokenizer.unk_token_id

    # Inference only: skip autograd bookkeeping to save memory and time.
    with torch.no_grad():
        preds = model(sent_id, doc_id, sent_mask, doc_mask)
    return preds
|
|
|
|
def extract_summary(doc, model=extractive_model, min_sentence_length=14, top_k=4, batch_size=4):
    """Build an extractive summary from the highest-scoring sentences of ``doc``.

    Args:
        doc: Document text.
        model: Scoring model handed to ``predict``; the default is bound to
            ``extractive_model`` when this function is defined.
        min_sentence_length: A sentence is kept only if its spaCy span length
            is strictly greater than this value.
        top_k: Maximum number of sentences in the summary.
        batch_size: Number of sentences scored per ``predict`` call.

    Returns:
        The selected sentences in their original document order, joined by
        single spaces. An empty string when no sentence is long enough.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1.
    """
    if batch_size < 1:
        raise ValueError("batch_size must be at least 1")

    # NOTE(review): newlines are removed without inserting a space, so words
    # separated only by a line break are fused. Confirm this matches the
    # preprocessing the model was trained with before changing it.
    doc = doc.replace("\n", "")

    doc_sentences = [
        str(sent) for sent in nlp(doc).sents if len(sent) > min_sentence_length
    ]

    # Step through the sentences in batch_size strides. Every batch is
    # non-empty and tqdm is given the true number of batches.
    scores = []
    for batch_start in tqdm(range(0, len(doc_sentences), batch_size)):
        batch = doc_sentences[batch_start:batch_start + batch_size]
        scores.extend(predict(model, batch, doc).tolist())

    # Each score row is a one-element list; element 0 is the sentence score.
    sent_pred_list = [
        {"sentence": sentence, "score": score[0], "index": index}
        for index, (sentence, score) in enumerate(zip(doc_sentences, scores))
    ]

    # Take the top_k by score, then restore document order for readability.
    top_scored = sorted(sent_pred_list, key=lambda k: k["score"], reverse=True)[:top_k]
    top_scored.sort(key=lambda k: k["index"])

    return " ".join(entry["sentence"] for entry in top_scored)