import os

import certifi
import spacy
import streamlit as st
import torch
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
from transformers import AutoModel, AutoTokenizer

# Point the SSL stack at certifi's CA bundle so model downloads work
# behind strict certificate setups.
os.environ['SSL_CERT_FILE'] = certifi.where()

# spaCy is used only for sentence segmentation (requires:
# python -m spacy download en_core_web_lg).
nlp = spacy.load("en_core_web_lg")

model_name = "microsoft/MiniLM-L12-H384-uncased"
tokenizer = AutoTokenizer.from_pretrained(model_name)

def mean_pooling(model_output, attention_mask):
    """Average token embeddings over the sequence, ignoring padded positions."""
    token_embeddings = model_output[0]  # (batch, seq_len, hidden_dim)
    input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
    sum_embeddings = torch.sum(token_embeddings * input_mask_expanded, 1)
    sum_mask = torch.clamp(input_mask_expanded.sum(1), min=1e-9)  # guard against all-zero masks
    return sum_embeddings / sum_mask
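
# Quick shape check for mean_pooling (illustrative only; random stand-in
# tensors, not real encoder output):
#   hidden = torch.rand(2, 512, 384)              # (batch, seq_len, hidden_dim)
#   mask = torch.ones(2, 512, dtype=torch.long)
#   mean_pooling((hidden,), mask).shape           # -> torch.Size([2, 384])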

class SentenceBERTClassifier(torch.nn.Module):
    """Scores a sentence's relevance to its document from pooled MiniLM embeddings."""

    def __init__(self, model_name="microsoft/MiniLM-L12-H384-uncased", input_dim=384):
        super(SentenceBERTClassifier, self).__init__()
        self.model = AutoModel.from_pretrained(model_name)
        # Input is [sentence; document; sentence * document] -> 3 * input_dim features.
        self.dense1 = torch.nn.Linear(input_dim * 3, 768)
        self.relu1 = torch.nn.ReLU()
        self.dropout1 = torch.nn.Dropout(0.1)
        self.dense2 = torch.nn.Linear(768, 384)
        self.relu2 = torch.nn.ReLU()
        self.dropout2 = torch.nn.Dropout(0.1)
        self.classifier = torch.nn.Linear(384, 1)
        self.sigmoid = torch.nn.Sigmoid()

    def forward(self, sent_ids, doc_ids, sent_mask, doc_mask):
        # Encode the sentence and the full document with the same shared encoder.
        sent_output = self.model(input_ids=sent_ids, attention_mask=sent_mask)
        sent_embedding = mean_pooling(sent_output, sent_mask)
        doc_output = self.model(input_ids=doc_ids, attention_mask=doc_mask)
        doc_embedding = mean_pooling(doc_output, doc_mask)
        # Elementwise interaction term, then concatenate all three views.
        combined_embedding = sent_embedding * doc_embedding
        concat_embedding = torch.cat((sent_embedding, doc_embedding, combined_embedding), dim=1)
        dense_output1 = self.dense1(concat_embedding)
        relu_output1 = self.relu1(dense_output1)
        dropout_output1 = self.dropout1(relu_output1)
        dense_output2 = self.dense2(dropout_output1)
        relu_output2 = self.relu2(dense_output2)
        dropout_output2 = self.dropout2(relu_output2)
        logits = self.classifier(dropout_output2)
        probs = self.sigmoid(logits)
        return probs
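
# Shape walk-through for a hypothetical batch of 4 sentences (512 tokens each):
# the pooled sentence and document embeddings are each (4, 384); concatenating
# them with their elementwise product gives (4, 1152), which dense1/dense2
# project to 768 and then 384 features before the final 1-unit sigmoid head
# emits one relevance probability per sentence.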

device = 'cuda' if torch.cuda.is_available() else 'cpu'
extractive_model = SentenceBERTClassifier(model_name=model_name)
# Use a forward slash so the checkpoint path also resolves on non-Windows
# systems; map_location lets a CPU-only machine load a GPU-trained checkpoint.
extractive_model.load_state_dict(torch.load("model_path/minilm_bal_exsum.pth", map_location=torch.device(device)))
extractive_model.eval()  # disable dropout for deterministic inference

def get_tokens(text, tokenizer):
    """Tokenize a list of strings into fixed-length (512) id/mask tensors."""
    inputs = tokenizer.batch_encode_plus(
        text,
        add_special_tokens=True,
        max_length=512,
        padding="max_length",
        return_token_type_ids=True,
        truncation=True,
        return_tensors="pt",
    )
    ids = inputs["input_ids"]
    mask = inputs["attention_mask"]
    return ids, mask
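
# Example (hypothetical input): both returned tensors are already
# (batch, 512) LongTensors thanks to return_tensors="pt":
#   ids, mask = get_tokens(["A short example sentence."], tokenizer)
#   ids.shape, mask.shape   # -> torch.Size([1, 512]) each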

# Predict the relevance score of each sentence in a document.
def predict(model, sents, doc):
    sent_id, sent_mask = get_tokens(sents, tokenizer)
    # get_tokens already returns LongTensors (return_tensors="pt"), so no
    # torch.tensor(...) re-wrapping is needed.
    doc_id, doc_mask = get_tokens([doc], tokenizer)
    # Repeat the document row so every sentence is paired with the same document.
    doc_id, doc_mask = doc_id.repeat(len(sents), 1), doc_mask.repeat(len(sents), 1)
    # Handle OOV tokens: replace out-of-vocabulary ids with the unk token id
    # before the forward pass.
    sent_id[sent_id >= tokenizer.vocab_size] = tokenizer.unk_token_id
    doc_id[doc_id >= tokenizer.vocab_size] = tokenizer.unk_token_id
    with torch.no_grad():  # inference only; skip gradient tracking
        preds = model(sent_id, doc_id, sent_mask, doc_mask)
    return preds
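
# Usage sketch (hypothetical sentences, not from the original data):
#   sents = ["First candidate sentence.", "Second candidate sentence."]
#   scores = predict(extractive_model, sents, " ".join(sents))
#   scores.shape   # -> torch.Size([2, 1]), sigmoid probabilities in [0, 1]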

def extract_summary(doc, model=extractive_model, min_sentence_length=14, top_k=4, batch_size=4):
    # Replace newlines with spaces; replacing them with "" would glue
    # adjacent words together and confuse sentence segmentation.
    doc = doc.replace("\n", " ")
    # Keep sentences longer than min_sentence_length tokens as candidates.
    doc_sentences = []
    for sent in nlp(doc).sents:
        if len(sent) > min_sentence_length:
            doc_sentences.append(str(sent))

    # Score candidates in batches to bound memory use.
    scores = []
    for i in tqdm(range(int(len(doc_sentences) / batch_size) + 1)):
        batch_start = i * batch_size
        batch_end = min((i + 1) * batch_size, len(doc_sentences))
        batch = doc_sentences[batch_start:batch_end]
        if batch:
            preds = predict(model, batch, doc)
            scores = scores + preds.tolist()
    # Rank sentences by predicted relevance, keep the top_k, then restore
    # their original document order before joining into the summary.
    sent_pred_list = [
        {"sentence": doc_sentences[i], "score": scores[i][0], "index": i}
        for i in range(len(doc_sentences))
    ]
    sorted_sentences = sorted(sent_pred_list, key=lambda k: k["score"], reverse=True)
    sorted_result = sorted(sorted_sentences[:top_k], key=lambda k: k["index"])
    summary = " ".join(x["sentence"] for x in sorted_result)
    return summary
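
# A minimal sketch of how extract_summary could be wired into a Streamlit UI
# (the widget labels are placeholders, not taken from the original app):
#   input_text = st.text_area("Paste a document to summarize")
#   if st.button("Summarize") and input_text:
#       st.write(extract_summary(input_text))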