from transformers import AutoModel, AutoTokenizer
import torch
from tqdm import tqdm
from torch.utils.data import Dataset, DataLoader
import os
import spacy
import certifi
import streamlit as st

# Point Python's SSL machinery at certifi's CA bundle so model downloads
# succeed in environments with strict certificate checking.
os.environ["SSL_CERT_FILE"] = certifi.where()

# spaCy pipeline, used here only for sentence segmentation.
nlp = spacy.load("en_core_web_lg")

model_name = "microsoft/MiniLM-L12-H384-uncased"
tokenizer = AutoTokenizer.from_pretrained(model_name)


def mean_pooling(model_output, attention_mask):
    """Mean-pool token embeddings into one vector, ignoring padding tokens."""
    token_embeddings = model_output[0]  # (batch, seq_len, hidden)
    input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
    sum_embeddings = torch.sum(token_embeddings * input_mask_expanded, 1)
    sum_mask = torch.clamp(input_mask_expanded.sum(1), min=1e-9)
    return sum_embeddings / sum_mask


class SentenceBERTClassifier(torch.nn.Module):
    """Scores how relevant a sentence is to its source document.

    Sentence and document are embedded by a shared MiniLM encoder; the
    concatenation [sentence, document, sentence * document] is passed
    through a two-layer MLP with a sigmoid output.
    """

    def __init__(self, model_name="microsoft/MiniLM-L12-H384-uncased", input_dim=384):
        super().__init__()
        self.model = AutoModel.from_pretrained(model_name)
        self.dense1 = torch.nn.Linear(input_dim * 3, 768)
        self.relu1 = torch.nn.ReLU()
        self.dropout1 = torch.nn.Dropout(0.1)
        self.dense2 = torch.nn.Linear(768, 384)
        self.relu2 = torch.nn.ReLU()
        self.dropout2 = torch.nn.Dropout(0.1)
        self.classifier = torch.nn.Linear(384, 1)
        self.sigmoid = torch.nn.Sigmoid()

    def forward(self, sent_ids, doc_ids, sent_mask, doc_mask):
        # Embed the sentence batch and the document with the shared encoder.
        sent_output = self.model(input_ids=sent_ids, attention_mask=sent_mask)
        sent_embedding = mean_pooling(sent_output, sent_mask)
        doc_output = self.model(input_ids=doc_ids, attention_mask=doc_mask)
        doc_embedding = mean_pooling(doc_output, doc_mask)
        # Element-wise interaction term, then concatenate all three views.
        combined_embedding = sent_embedding * doc_embedding
        concat_embedding = torch.cat((sent_embedding, doc_embedding, combined_embedding), dim=1)
        x = self.dropout1(self.relu1(self.dense1(concat_embedding)))
        x = self.dropout2(self.relu2(self.dense2(x)))
        return self.sigmoid(self.classifier(x))


device = "cuda" if torch.cuda.is_available() else "cpu"
extractive_model = SentenceBERTClassifier(model_name=model_name)
# Use a forward slash (not a backslash) so the checkpoint path also works
# outside Windows.
extractive_model.load_state_dict(
    torch.load("model_path/minilm_bal_exsum.pth", map_location=torch.device(device))
)
extractive_model.to(device)
extractive_model.eval()


def get_tokens(text, tokenizer):
    """Tokenize a list of strings into fixed-length input IDs and attention masks."""
    inputs = tokenizer.batch_encode_plus(
        text,
        add_special_tokens=True,
        max_length=512,
        padding="max_length",
        return_token_type_ids=True,
        truncation=True,
        return_tensors="pt",
    )
    return inputs["input_ids"], inputs["attention_mask"]


# Predict the relevance score of each candidate sentence against its document.
def predict(model, sents, doc):
    # get_tokens already returns long tensors, so no dtype conversion is needed.
    sent_id, sent_mask = get_tokens(sents, tokenizer)
    doc_id, doc_mask = get_tokens([doc], tokenizer)
    # Repeat the document encoding so each sentence is paired with it.
    doc_id, doc_mask = doc_id.repeat(len(sents), 1), doc_mask.repeat(len(sents), 1)
    # Handle OOV tokens: replace any out-of-vocabulary token IDs with the UNK
    # token ID as a safety guard before passing them to the model.
    sent_id[sent_id >= tokenizer.vocab_size] = tokenizer.unk_token_id
    doc_id[doc_id >= tokenizer.vocab_size] = tokenizer.unk_token_id
    with torch.no_grad():
        preds = model(sent_id.to(device), doc_id.to(device), sent_mask.to(device), doc_mask.to(device))
    return preds


def extract_summary(doc, model=extractive_model, min_sentence_length=14, top_k=4, batch_size=4):
    # Join lines with a space (not "") so words at line breaks are not glued together.
    doc = doc.replace("\n", " ")
    # Keep only sentences longer than min_sentence_length spaCy tokens.
    doc_sentences = [str(sent) for sent in nlp(doc).sents if len(sent) > min_sentence_length]

    # Score the sentences in batches.
    scores = []
    for batch_start in tqdm(range(0, len(doc_sentences), batch_size)):
        batch = doc_sentences[batch_start : batch_start + batch_size]
        preds = predict(model, batch, doc)
        scores.extend(preds.tolist())

    # Keep the top_k highest-scoring sentences, restored to document order.
    sent_pred_list = [
        {"sentence": doc_sentences[i], "score": scores[i][0], "index": i}
        for i in range(len(doc_sentences))
    ]
    sorted_sentences = sorted(sent_pred_list, key=lambda k: k["score"], reverse=True)
    sorted_result = sorted(sorted_sentences[:top_k], key=lambda k: k["index"])
    return " ".join(x["sentence"] for x in sorted_result)
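

# Minimal Streamlit front end tying the pieces together: a usage sketch, not
# part of the original snippet. It assumes the app exposes a single text box;
# the widget labels and top_k value below are illustrative assumptions.
st.title("Extractive Summarization")
input_text = st.text_area("Paste the document to summarize")
if st.button("Summarize") and input_text.strip():
    st.write(extract_summary(input_text, top_k=4))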