import os
import re

import nltk
nltk.download('stopwords')
nltk.download('punkt')
from nltk.corpus import stopwords
from nltk.tokenize import word_tokenize
from nltk.util import ngrams

import spacy

# from gensim.summarization.summarizer import summarize
# from gensim.summarization import keywords

# Abstractive Summarisation
from transformers import BartForConditionalGeneration
from transformers import AutoTokenizer
import torch

# Keyword/Keyphrase Extraction
from keybert import _highlight
from keybert import KeyBERT
from keyphrase_vectorizers import KeyphraseCountVectorizer, KeyphraseTfidfVectorizer
from sklearn.feature_extraction.text import CountVectorizer

import time
import threading
from collections import defaultdict


class AbstractiveSummarizer:
    def __init__(self):
        self.nlp = spacy.load('en_core_web_lg')
        self.summary = ""

    def generate_batch(self, text, tokenizer):
        """
        Convert the text into multiple sentence parts of appropriate input size to feed to the model
        Arguments:
            text: The license text to summarise
            tokenizer: The tokenizer corresponding to the model, used to convert the text into separate words (tokens)
        Returns:
            A list of encoded sentence batches to feed to the model
        """
        parsed = self.nlp(text)
        sents = [sent.text for sent in parsed.sents]
        max_size = tokenizer.model_max_length
        batch = tokenizer(sents, return_tensors='pt', return_length=True, padding='longest')

        # Greedily pack consecutive encoded sentences into batches that fit the model's maximum input size
        inp_batch = []
        cur_batch = torch.empty((0,), dtype=torch.int64)
        for enc_sent, length in zip(batch['input_ids'], batch['length']):
            cur_size = cur_batch.shape[0]
            if (cur_size + length.item()) <= max_size:
                cur_batch = torch.cat((cur_batch, enc_sent[:length.item()]))
            else:
                inp_batch.append(torch.unsqueeze(cur_batch, 0))
                cur_batch = enc_sent[:length.item()]
        inp_batch.append(torch.unsqueeze(cur_batch, 0))
        return inp_batch

    def summarize(self, src, tokenizer, model):
        """
        Use the pre-trained model to generate the summary
        Arguments:
            src: License text to summarise
            tokenizer: The tokenizer corresponding to the model, used to convert the text into separate words (tokens)
            model: The pre-trained model object used to perform the summarization
        Returns/Set as instance variable:
            summary: The summarised text
        """
        batch_texts = self.generate_batch(src, tokenizer)
        # Move each batch to the model's device so generation works whether the model is on CPU or GPU
        enc_summary_list = [model.generate(batch.to(model.device), max_length=512) for batch in batch_texts]
        summary_list = [tokenizer.batch_decode(enc_summ, skip_special_tokens=True) for enc_summ in enc_summary_list]
        # orig_list = [tokenizer.batch_decode(batch, skip_special_tokens=True) for batch in batch_texts]
        summary_texts = [summ[0] for summ in summary_list]
        summary = " ".join(summary_texts)
        self.summary = summary

    def bart(self, src):
        """
        Initialize the facebook BART pre-trained model and call the necessary functions to summarize
        Arguments:
            src: The text to summarise
        Returns/Set as instance variable:
            The summarized text
        """
        start_time = time.time()
        model_name = 'facebook/bart-large-cnn'
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        model = BartForConditionalGeneration.from_pretrained(model_name).to(device)
        self.summarize(src, tokenizer, model)
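

# Illustrative usage sketch (not part of the original pipeline): the class is driven
# through bart(), which populates the .summary attribute rather than returning a value.
#   summarizer = AbstractiveSummarizer()
#   summarizer.bart(license_text)   # license_text is any long license string
#   print(summarizer.summary)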


def get_summary(lic_txt):
    """
    Summarize the license text and return the summary
    Arguments:
        lic_txt: The full license text to summarise
    Returns:
        summary: The generated summary of the license
    """
    print('Summarising...')
    absSum = AbstractiveSummarizer()
    # Generate summary; bart() stores the result on the instance rather than returning it
    absSum.bart(lic_txt)
    return absSum.summary


def extract_ngrams(phrase):
    """Return all word n-grams (from 1 up to the full phrase length) contained in the phrase."""
    phrase = re.sub('[^a-zA-Z0-9]', ' ', phrase)
    tokens = word_tokenize(phrase)
    res = []
    # Start at 1 so the empty 0-gram is not produced as a candidate
    for num in range(1, len(tokens) + 1):
        temp = ngrams(tokens, num)
        res += [' '.join(grams) for grams in temp]
    return res
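
# Illustrative example (hypothetical input, not from the original code):
#   extract_ngrams("source code") -> ['source', 'code', 'source code']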


def get_highlight_text(text, keywords):
    """
    Custom function to find the exact position of keywords for highlighting
    """
    text = re.sub('[-/]', ' ', text)
    # text = re.sub('(\n *){2,}','\n',text)
    text = re.sub(' {2,}', ' ', text)
    # Group keywords by length
    kw_len = defaultdict(list)
    for kw in keywords:
        kw_len[len(kw)].append(kw)
    # Use a sliding window to find every substring of the text equal to a keyword
    spans = []
    for length in kw_len:
        w_start, w_end = 0, length
        while w_end <= len(text):
            for kw in kw_len[length]:
                j = w_start
                eq = True
                for i in range(len(kw)):
                    if text[j] != kw[i]:
                        eq = False
                        break
                    j += 1
                if eq:
                    spans.append([w_start, w_end])
                    break
            w_start += 1
            w_end += 1
    if not spans:
        return text
    # Merge overlapping spans
    spans.sort(key=lambda x: x[0])
    merged = []
    st, end = spans[0][0], spans[0][1]
    for i in range(1, len(spans)):
        s, e = spans[i]
        if st <= s <= end:
            end = max(e, end)
        else:
            merged.append([st, end])
            st, end = s, e
    merged.append([st, end])
    # Interleave plain text with (text, label, colour) tuples for the highlighted segments
    res = []
    sub_start = 0
    for s, e in merged:
        res.append(text[sub_start:s])
        res.append((text[s:e], "", "#f66"))
        sub_start = e
    res.append(text[sub_start:])
    return res
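
# Note: the interleaved string/tuple output above looks intended for an annotated-text
# style display component (e.g. st-annotated-text in Streamlit); that rendering target
# is an assumption, not stated in this file. Illustrative call:
#   get_highlight_text("no commercial use", ["commercial"])
#   -> ['no ', ('commercial', '', '#f66'), ' use']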


def get_keywords(datatype, task, field, pos_text, neg_text):
    """
    Extract the good and bad use tags from the summarised license text
    Arguments:
        datatype: Type of 'data' used under the license, e.g. Model, Data, Model Derivatives, Source Code
        task: The type of task the model is designed to do
        field: Which 'field' the data will be used in, e.g. Medical, Commercial, Non-Commercial, Research
        pos_text: The part of the license containing information about permitted use
        neg_text: The part of the license containing information about usage restrictions
    Returns:
        p_keywords: List of positive (permitted use) keywords extracted from the summary
        n_keywords: List of negative (restriction) keywords extracted from the summary
        contrd: Boolean flag showing whether there is a contradiction with the intended field of use
        hl_text: The license text formatted for highlighted display
    """
    print('Extracting keywords...')
    datatype, task, field = datatype.lower(), task.lower(), field.lower()

    # Domain-specific legal terms carry no signal for keyword extraction, so treat them as stopwords
    stop_words = set(stopwords.words('english'))
    stop_words = stop_words.union({'license', 'licensing', 'licensor', 'copyright', 'copyrights', 'patent'})

    pos_kw_model = KeyBERT()
    neg_kw_model = KeyBERT()

    # Candidate keyphrases built from the user's own datatype/task/field terms
    candidates = []
    for term in [datatype, task, field]:
        candidates += extract_ngrams(term)

    p_kw = pos_kw_model.extract_keywords(docs=pos_text, top_n=40, vectorizer=KeyphraseCountVectorizer(stop_words=stop_words))  # , pos_pattern='<N.*>+'))
    n_kw = neg_kw_model.extract_keywords(docs=neg_text, top_n=40, vectorizer=KeyphraseCountVectorizer(stop_words=stop_words))  # , pos_pattern='<N.*>+'))
    ngram_max = max([len(word_tokenize(x)) for x in [datatype, task, field]])
    pc_kw = pos_kw_model.extract_keywords(docs=pos_text, candidates=candidates, keyphrase_ngram_range=(1, ngram_max))
    nc_kw = neg_kw_model.extract_keywords(docs=neg_text, candidates=candidates, keyphrase_ngram_range=(1, ngram_max))

    # Check contradiction: does any restriction keyword overlap with the intended field of use?
    all_cont = [kw for (kw, _) in nc_kw]
    cont_terms = set(all_cont).intersection(set(extract_ngrams(field)))
    contrd = len(cont_terms) > 0
    hl_text = "" if not contrd else get_highlight_text(neg_text, all_cont)

    p_kw += pc_kw
    n_kw += nc_kw
    p_kw.sort(key=lambda x: x[1], reverse=True)
    n_kw.sort(key=lambda x: x[1], reverse=True)
    p_keywords = [kw for (kw, score) in p_kw if score < 0.5]
    n_keywords = [kw for (kw, score) in n_kw if score < 0.5]
    return p_keywords, n_keywords, contrd, hl_text
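

# Minimal end-to-end sketch (illustrative only): the sample license text, the
# permitted/restricted split, and the datatype/task/field values below are hypothetical
# stand-ins, since splitting a license into pos_text and neg_text happens outside this module.
if __name__ == '__main__':
    lic_txt = ("Permission is hereby granted to use the software for research purposes. "
               "The software may not be used for commercial purposes.")
    summary = get_summary(lic_txt)
    print(summary)

    pos_text = "Permission is hereby granted to use the software for research purposes."
    neg_text = "The software may not be used for commercial purposes."
    p_keywords, n_keywords, contrd, hl_text = get_keywords(
        'model', 'text classification', 'commercial', pos_text, neg_text)
    print(p_keywords, n_keywords, contrd)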