Spaces:
Build error
Build error
from .base_query_based_model import QueryBasedSummModel | |
from model.base_model import SummModel | |
from model.single_doc import TextRankModel | |
from typing import List | |
from gensim.summarization.bm25 import BM25 | |
from nltk import word_tokenize | |
class BM25SummModel(QueryBasedSummModel):
    """Query-based summarization model whose retrieval step uses BM25.

    This class only implements the retrieval hook (``_retrieve``): it ranks
    the sentences of a document against a query with BM25 and keeps the
    best-scoring ones. Everything else (including how the retrieved
    sentences are turned into a summary by ``model_backend``) is inherited
    from ``QueryBasedSummModel``.
    """

    # static variables
    model_name = "BM25"
    is_extractive = True  # only represents the retrieval part
    is_neural = False  # only represents the retrieval part
    is_query_based = True

    def __init__(
        self,
        trained_domain: str = None,
        max_input_length: int = None,
        max_output_length: int = None,
        model_backend: SummModel = TextRankModel,
        retrieval_ratio: float = 0.5,
        preprocess: bool = True,
        **kwargs
    ):
        """Create the model by forwarding every argument to the base class.

        No argument is interpreted here; see ``QueryBasedSummModel.__init__``
        for their meaning. The defaults chosen by this subclass are
        ``TextRankModel`` as the backend, a ``retrieval_ratio`` of 0.5 and
        ``preprocess=True``.

        NOTE(review): the default ``model_backend`` is the ``TextRankModel``
        class itself rather than an instance, although the annotation reads
        ``SummModel`` -- presumably the base class instantiates it; confirm.
        """
        super().__init__(
            trained_domain=trained_domain,
            max_input_length=max_input_length,
            max_output_length=max_output_length,
            model_backend=model_backend,
            retrieval_ratio=retrieval_ratio,
            preprocess=preprocess,
            **kwargs
        )

    def _retrieve(
        self, instance: List[str], query: List[str], n_best
    ) -> List[str]:
        """Return the ``n_best`` sentences of ``instance`` most relevant to ``query``.

        Args:
            instance: The sentences of one document; each sentence is
                tokenized with ``word_tokenize`` before indexing.
            query: The query, passed to BM25 unchanged -- presumably already
                a list of tokens; verify against the caller.
            n_best: Maximum number of sentences to keep.

        Returns:
            The selected sentences in their original document order (not in
            score order). An empty list when ``instance`` is empty.
        """
        # An empty corpus makes BM25 divide by a corpus size of zero while
        # computing the average document length; there is nothing to
        # retrieve anyway, so short-circuit.
        if not instance:
            return []

        # Materialize the tokenized corpus: a list (unlike a one-shot
        # generator) can be sized and iterated more than once, which some
        # BM25 implementations need.
        tokenized_sentences = [word_tokenize(sentence) for sentence in instance]
        bm25 = BM25(tokenized_sentences)
        scores = bm25.get_scores(query)

        # Rank sentence indices by descending score. The sort is stable, so
        # ties keep the earlier sentence first.
        best_sent_ind = sorted(
            range(len(scores)), key=lambda i: scores[i], reverse=True
        )[:n_best]

        # Re-sort the winning indices so the output reads in document order.
        return [instance[ind] for ind in sorted(best_sent_ind)]