# Project / extraction.py
# Captured from the Hugging Face file viewer: uploaded by SorterSon
# ("files uploaded", commit 9dca909, verified), 5.1 kB.
from transformers import AutoModel, AutoTokenizer
import torch
from tqdm import tqdm
from torch.utils.data import Dataset, DataLoader
import os
import spacy
import certifi
import streamlit as st
# Point OpenSSL-based HTTPS clients at certifi's CA bundle. This must run before
# the from_pretrained/spacy.load calls below, which may fetch files over HTTPS
# (presumably a workaround for missing system CA certificates — confirm).
os.environ['SSL_CERT_FILE'] = certifi.where()
# spaCy English pipeline; extract_summary uses it only for sentence segmentation.
nlp = spacy.load("en_core_web_lg")
# Checkpoint shared by the tokenizer below and SentenceBERTClassifier's encoder.
model_name = "microsoft/MiniLM-L12-H384-uncased"
# Module-level tokenizer used implicitly by predict().
tokenizer = AutoTokenizer.from_pretrained(model_name)
def mean_pooling(model_output, attention_mask):
    """Average token embeddings over the positions selected by ``attention_mask``.

    Args:
        model_output: Encoder output whose element ``[0]`` holds the per-token
            embeddings (assumed ``(batch, seq_len, hidden)``, e.g. a Hugging Face
            model's ``last_hidden_state`` — confirm for other encoders).
        attention_mask: ``(batch, seq_len)`` tensor; 1 for real tokens, 0 for padding.

    Returns:
        ``(batch, hidden)`` tensor of masked mean embeddings. Rows whose mask is
        all zeros come out as zeros, because the divisor is clamped to 1e-9
        instead of being allowed to reach zero.
    """
    hidden = model_output[0]
    # Broadcast the mask over the hidden dimension so it can weight each vector.
    weights = attention_mask.unsqueeze(-1).expand(hidden.size()).float()
    totals = (hidden * weights).sum(dim=1)
    counts = weights.sum(dim=1).clamp(min=1e-9)
    return totals / counts
class SentenceBERTClassifier(torch.nn.Module):
    """Score how well a sentence represents the document it came from.

    Sentence and document are both run through the same pretrained encoder
    (``self.model``) and mean-pooled into fixed-size vectors. The sentence
    vector, the document vector and their element-wise product are
    concatenated and fed to a two-layer ReLU/dropout MLP. The MLP's single
    output is passed through a sigmoid.

    Args:
        model_name: Checkpoint loaded with ``AutoModel.from_pretrained``.
        input_dim: Width of one pooled embedding. It must equal the encoder's
            hidden size; the default 384 matches the "H384" in the default
            checkpoint name (confirm when switching checkpoints).
    """

    def __init__(self, model_name="microsoft/MiniLM-L12-H384-uncased", input_dim=384):
        super().__init__()
        # NOTE: these attribute names are the state_dict keys of saved
        # checkpoints (loaded via load_state_dict); do not rename them.
        self.model = AutoModel.from_pretrained(model_name)
        # Head input is [sentence, document, sentence * document].
        self.dense1 = torch.nn.Linear(3 * input_dim, 768)
        self.relu1 = torch.nn.ReLU()
        self.dropout1 = torch.nn.Dropout(0.1)
        self.dense2 = torch.nn.Linear(768, 384)
        self.relu2 = torch.nn.ReLU()
        self.dropout2 = torch.nn.Dropout(0.1)
        self.classifier = torch.nn.Linear(384, 1)
        self.sigmoid = torch.nn.Sigmoid()

    def _encode(self, ids, mask):
        """Run the shared encoder on ``ids`` and mean-pool its output under ``mask``."""
        return mean_pooling(self.model(input_ids=ids, attention_mask=mask), mask)

    def forward(self, sent_ids, doc_ids, sent_mask, doc_mask):
        """Return a ``(batch, 1)`` tensor of sigmoid scores, one per sentence/document pair.

        The id and mask tensors are assumed to be ``(batch, seq_len)`` token
        ids and attention masks, paired row by row (confirm against callers).
        """
        sentence_vec = self._encode(sent_ids, sent_mask)
        document_vec = self._encode(doc_ids, doc_mask)
        features = torch.cat((sentence_vec, document_vec, sentence_vec * document_vec), dim=1)
        hidden = self.dropout1(self.relu1(self.dense1(features)))
        hidden = self.dropout2(self.relu2(self.dense2(hidden)))
        return self.sigmoid(self.classifier(hidden))
# Build the summarizer once at import time and load the fine-tuned weights.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
extractive_model = SentenceBERTClassifier(model_name=model_name)
# os.path.join replaces the former literal "model_path\minilm_bal_exsum.pth".
# That literal only named the intended file on Windows (on POSIX the backslash is
# part of the filename), and "\m" is an invalid escape sequence that Python warns about.
# NOTE(review): torch.load unpickles the checkpoint; only load files you trust.
# NOTE(review): map_location puts the loaded tensors on `device`, but
# load_state_dict copies them into the module's existing parameters and the module
# is never moved with .to(device). The model therefore runs on the device it was
# built on, which matches predict() feeding it CPU tensors.
checkpoint_path = os.path.join("model_path", "minilm_bal_exsum.pth")
extractive_model.load_state_dict(torch.load(checkpoint_path, map_location=torch.device(device)))
# Inference mode: turns the classifier head's dropout layers into no-ops.
extractive_model.eval()
def get_tokens(text, tokenizer):
    """Tokenize a batch of strings into fixed-width PyTorch tensors.

    Every string is given special tokens, then truncated or padded to exactly
    512 tokens.

    Args:
        text: List of strings to encode.
        tokenizer: Hugging Face tokenizer exposing ``batch_encode_plus``.

    Returns:
        ``(input_ids, attention_mask)``, each a ``"pt"`` tensor with one row
        per input string. Token type ids are requested from the tokenizer but
        not returned.
    """
    options = {
        "add_special_tokens": True,
        "max_length": 512,
        "padding": "max_length",
        "return_token_type_ids": True,
        "truncation": True,
        "return_tensors": "pt",
    }
    encoded = tokenizer.batch_encode_plus(text, **options)
    return encoded["input_ids"], encoded["attention_mask"]
def predict(model, sents, doc):
    """Score each sentence in ``sents`` for relevance to ``doc``.

    The document is tokenized once, and its id/mask rows are repeated so that
    every sentence in the batch is paired with the same document. Tokenization
    goes through ``get_tokens`` with the module-level ``tokenizer``. The model
    call runs under ``torch.no_grad()`` because this is inference only.

    Args:
        model: Callable invoked as ``model(sent_ids, doc_ids, sent_mask, doc_mask)``,
            e.g. a ``SentenceBERTClassifier``.
        sents: List of sentence strings to score.
        doc: Full document text the sentences belong to.

    Returns:
        The model's output for the batch. For ``SentenceBERTClassifier`` this
        is a ``(len(sents), 1)`` tensor of sigmoid scores. If ``sents`` is
        empty, an empty ``(0, 1)`` tensor is returned and the model is not called.
    """
    if not sents:
        return torch.empty((0, 1))

    sent_id, sent_mask = get_tokens(sents, tokenizer)
    doc_id, doc_mask = get_tokens([doc], tokenizer)
    # One copy of the document row per sentence so the batch dimensions match.
    doc_id = doc_id.repeat(len(sents), 1)
    doc_mask = doc_mask.repeat(len(sents), 1)

    # get_tokens already returns tensors. .long() only guarantees int64 ids/masks.
    # The former torch.tensor(tensor) re-copied every input and raised PyTorch's
    # "copy construct from a tensor" UserWarning.
    sent_id, sent_mask = sent_id.long(), sent_mask.long()
    doc_id, doc_mask = doc_id.long(), doc_mask.long()

    # Map any id at or beyond tokenizer.vocab_size to the unknown-token id
    # (presumably to keep ids inside the encoder's embedding table).
    sent_id[sent_id >= tokenizer.vocab_size] = tokenizer.unk_token_id
    doc_id[doc_id >= tokenizer.vocab_size] = tokenizer.unk_token_id

    # Inference only: skip building an autograd graph over the encoder passes,
    # which would otherwise cost memory and time on every call.
    with torch.no_grad():
        preds = model(sent_id, doc_id, sent_mask, doc_mask)
    return preds
def extract_summary(doc, model=extractive_model, min_sentence_length=14, top_k=4, batch_size=4):
    """Build an extractive summary of ``doc`` from its highest-scoring sentences.

    Steps:
    1. Line breaks in the document are replaced by single spaces.
    2. The text is split into sentences with the module-level spaCy pipeline
       ``nlp``; sentences with ``min_sentence_length`` or fewer spaCy tokens
       are dropped.
    3. Each remaining sentence is scored against the document by ``predict``,
       ``batch_size`` sentences at a time.
    4. The ``top_k`` best sentences are joined with spaces in their original order.

    Args:
        doc: Raw document text.
        model: Scoring model passed through to ``predict``.
        min_sentence_length: A sentence needs strictly more spaCy tokens than
            this to be considered.
        top_k: Maximum number of sentences in the summary.
        batch_size: Sentences per ``predict`` call; must be a positive integer.

    Returns:
        The summary string, or ``""`` when no sentence passes the length filter.
    """
    # Join lines with a space instead of deleting the newline. Deleting it fuses
    # the last word of one line with the first word of the next
    # ("end.\nNext" -> "end.Next"), which also defeats sentence splitting.
    doc = " ".join(line.strip() for line in doc.splitlines() if line.strip())

    doc_sentences = [str(sent) for sent in nlp(doc).sents if len(sent) > min_sentence_length]

    scores = []
    # The range step yields one start index per batch, including a final
    # partial batch, with no empty trailing iteration.
    for start in tqdm(range(0, len(doc_sentences), batch_size)):
        batch = doc_sentences[start:start + batch_size]
        preds = predict(model, batch, doc)
        # predict yields one single-element row per sentence; keep the score.
        scores.extend(row[0] for row in preds.tolist())

    # Stable sort: tied scores keep document order, as before.
    ranked = sorted(range(len(doc_sentences)), key=lambda i: scores[i], reverse=True)
    # Keep the best top_k, then restore document order so the summary reads naturally.
    chosen = sorted(ranked[:top_k])
    return " ".join(doc_sentences[i] for i in chosen)