import unittest

from model.base_model import SummModel
from model import SUPPORTED_SUMM_MODELS
from pipeline import assemble_model_pipeline
from evaluation.base_metric import SummMetric
from evaluation import SUPPORTED_EVALUATION_METRICS
from dataset.st_dataset import SummInstance, SummDataset
from dataset import SUPPORTED_SUMM_DATASETS
from dataset.dataset_loaders import ScisummnetDataset, ArxivDataset

from helpers import print_with_color, retrieve_random_test_instances

import random
import time
from typing import List, Union, Tuple
import sys
import re
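
# Integration test: pairs every supported dataset with every compatible model
# pipeline and scores each prediction with every supported evaluation metric.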


class IntegrationTests(unittest.TestCase):
    def get_prediction(
        self, model: SummModel, dataset: SummDataset, test_instances: List[SummInstance]
    ) -> Tuple[Union[List[str], List[List[str]]], Union[List[str], List[List[str]]]]:
        """
        Get summary predictions for the given model and dataset instances.

        :param SummModel `model`: Model for summarization task.
        :param SummDataset `dataset`: Dataset for summarization task.
        :param List[SummInstance] `test_instances`: Instances from `dataset` to summarize.
        :returns Tuple containing a list of summary predictions and a list of target
            summaries, one per instance in `test_instances`.
        """
        src = (
            [ins.source[0] for ins in test_instances]
            if isinstance(dataset, ScisummnetDataset)
            else [ins.source for ins in test_instances]
        )
        tgt = [ins.summary for ins in test_instances]
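        # Only query-based datasets provide a per-instance query; all other
        # datasets pass None to the model.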
        query = (
            [ins.query for ins in test_instances] if dataset.is_query_based else None
        )
        prediction = model.summarize(src, query)
        return prediction, tgt

    def get_eval_dict(self, metric: SummMetric, prediction: List[str], tgt: List[str]):
        """
        Run an evaluation metric on summary predictions.

        :param SummMetric `metric`: Evaluation metric.
        :param List[str] `prediction`: Predicted summaries.
        :param List[str] `tgt`: Target summaries from the dataset.
        :returns Dictionary of scores produced by `metric`.
        """
        score_dict = metric.evaluate(prediction, tgt)
        return score_dict

    def test_all(self):
        """
        Runs integration test on all compatible dataset + model + evaluation metric pipelines supported by SummerTime.
        """
        print_with_color("\nInitializing all evaluation metrics...", "35")
        evaluation_metrics = []
        for eval_cls in SUPPORTED_EVALUATION_METRICS:
            # # TODO: Temporarily skipping Rouge/RougeWE metrics to avoid local bug.
            # if eval_cls in [Rouge, RougeWe]:
            #     continue
            print(eval_cls)
            evaluation_metrics.append(eval_cls())
print_with_color("\n\nBeginning integration tests...", "35") | |
for dataset_cls in SUPPORTED_SUMM_DATASETS: | |
# TODO: Temporarily skipping Arxiv (size/time) | |
if dataset_cls in [ArxivDataset]: | |
continue | |
dataset = dataset_cls() | |
if dataset.train_set is not None: | |
dataset_instances = list(dataset.train_set) | |
print( | |
f"\n{dataset.dataset_name} has a training set of {len(dataset_instances)} examples" | |
) | |
print_with_color( | |
f"Initializing all matching model pipelines for {dataset.dataset_name} dataset...", | |
"35", | |
) | |
# matching_model_instances = assemble_model_pipeline(dataset_cls, list(filter(lambda m: m != PegasusModel, SUPPORTED_SUMM_MODELS))) | |
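                # assemble_model_pipeline returns (model instance, model name) pairs
                # for every model that can handle this dataset.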
                matching_model_instances = assemble_model_pipeline(
                    dataset_cls, SUPPORTED_SUMM_MODELS
                )
                for model, model_name in matching_model_instances:
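                    # Sample a single random instance from the training set for
                    # this model run.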
                    test_instances = retrieve_random_test_instances(
                        dataset_instances=dataset_instances, num_instances=1
                    )
                    print_with_color(
                        f"{'#' * 20} Testing: {dataset.dataset_name} dataset, {model_name} model {'#' * 20}",
                        "35",
                    )
                    prediction, tgt = self.get_prediction(
                        model, dataset, test_instances
                    )
                    print(f"Prediction: {prediction}\nTarget: {tgt}\n")
                    for metric in evaluation_metrics:
                        print_with_color(f"{metric.metric_name} metric", "35")
                        score_dict = self.get_eval_dict(metric, prediction, tgt)
                        print(score_dict)

                    print_with_color(
                        f"{'#' * 20} Test for {dataset.dataset_name} dataset, {model_name} model COMPLETE {'#' * 20}\n\n",
                        "32",
                    )


if __name__ == "__main__":
    if len(sys.argv) > 2 or (
        len(sys.argv) == 2 and not re.match(r"^\d+$", sys.argv[1])
    ):
        print("Usage: python tests/integration_test.py [seed]", file=sys.stderr)
        sys.exit(1)
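    # Seed the RNG with the optional CLI argument, or with the current time, so a
    # run can be reproduced; the argument is popped so unittest.main() never sees it.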
    seed = int(time.time()) if len(sys.argv) == 1 else int(sys.argv.pop())
    random.seed(seed)
    print_with_color(f"(to reproduce) random seeded with {seed}\n", "32")
    unittest.main()