import unittest
from typing import List

from dataset.dataset_loaders import CnndmDataset, MultinewsDataset, PubmedqaDataset
from model import SUPPORTED_SUMM_MODELS, list_all_models
from model.single_doc import LexRankModel, LongformerModel
from model.dialogue import HMNetModel
from helpers import (
    print_with_color,
    get_summarization_set,
    get_query_based_summarization_set,
)

class TestModels(unittest.TestCase):
    single_doc_dataset = CnndmDataset()
    multi_doc_dataset = MultinewsDataset()
    query_based_dataset = PubmedqaDataset()
    # TODO: HMNet is temporarily skipped (see test_model_summarize), so no
    # dialogue-based dataset is loaded for now.
    # dialogue_based_dataset = SamsumDataset()

    def test_list_models(self):
        print_with_color(f"{'#'*10} Testing test_list_models... {'#'*10}\n", "35")
        all_models = list_all_models()
        for model_class, model_description in all_models:
            print(f"{model_class} : {model_description}")
        self.assertEqual(len(all_models), len(SUPPORTED_SUMM_MODELS))
        print_with_color(
            f"{'#'*10} test_list_models {__name__} test complete {'#'*10}\n\n", "32"
        )

    def validate_prediction(self, prediction: List[str], src: List):
        """
        Verify that prediction instances match source instances: the
        prediction must be a list of strings with one summary per
        source instance.
        """
        self.assertIsInstance(prediction, list)
        self.assertTrue(all(isinstance(ins, str) for ins in prediction))
        self.assertEqual(len(prediction), len(src))
        print("Prediction typing and length match source instances!")

    def test_model_summarize(self):
        """
        Test all supported models on instances from the datasets above.
        """
        print_with_color(f"{'#'*10} Testing all models... {'#'*10}\n", "35")

        num_models = 0
        all_models = list_all_models()

        for model_class, _ in all_models:
            if model_class in [HMNetModel]:
                # TODO: temporarily skip HMNet (requires a large pre-trained
                # model download and a GPU)
                continue

            print_with_color(f"Testing {model_class.model_name} model...", "35")

            if model_class == LexRankModel:
                # the current LexRankModel requires a training set
                training_src, training_tgt = get_summarization_set(
                    self.single_doc_dataset, 100
                )
                model = model_class(training_src)
            else:
                model = model_class()

            if model.is_query_based:
                test_src, test_tgt, test_query = get_query_based_summarization_set(
                    self.query_based_dataset, 1
                )
                prediction = model.summarize(test_src, test_query)
                print(
                    f"Query: {test_query}\nGold summary: {test_tgt}\nPredicted summary: {prediction}"
                )
            elif model.is_multi_document:
                test_src, test_tgt = get_summarization_set(self.multi_doc_dataset, 1)
                prediction = model.summarize(test_src)
                print(f"Gold summary: {test_tgt}\nPredicted summary: {prediction}")
                self.validate_prediction(prediction, test_src)
            elif model.is_dialogue_based:
                # unreachable while HMNet is skipped above and
                # self.dialogue_based_dataset is commented out
                test_src, test_tgt = get_summarization_set(
                    self.dialogue_based_dataset, 1
                )
                prediction = model.summarize(test_src)
                print(f"Gold summary: {test_tgt}\nPredicted summary: {prediction}")
                self.validate_prediction(prediction, test_src)
            else:
                test_src, test_tgt = get_summarization_set(self.single_doc_dataset, 1)
                # LongformerModel targets long inputs, so the single source
                # document is repeated five times to produce a long document
                prediction = model.summarize(
                    [test_src[0] * 5] if model_class == LongformerModel else test_src
                )
                print(f"Gold summary: {test_tgt}\nPredicted summary: {prediction}")
                self.validate_prediction(
                    prediction,
                    [test_src[0] * 5] if model_class == LongformerModel else test_src,
                )

            print_with_color(f"{model_class.model_name} model test complete\n", "32")
            num_models += 1

        print_with_color(
            f"{'#'*10} test_model_summarize complete ({num_models} models) {'#'*10}\n",
            "32",
        )


if __name__ == "__main__":
    unittest.main()
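
# Usage sketch: run this module directly (python <this_file>.py) or through
# unittest's CLI, e.g. `python -m unittest -v <module_name>`. The module name
# is an assumption here; substitute the actual path of this test file.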