import glob
from unittest import TestCase
from unittest.mock import patch

import pytest
import requests
import yaml

from evaluate.hub import push_to_hub
from tests.test_metric import DummyMetric
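
# Metadata payload that push_to_hub is expected to build when only the
# required arguments are supplied (asserted in test_push_metric_required_arguments).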
minimum_metadata = {
    "model-index": [
        {
            "results": [
                {
                    "task": {"type": "dummy-task"},
                    "dataset": {"type": "dataset_type", "name": "dataset_name"},
                    "metrics": [
                        {"type": "dummy_metric", "value": 1.0, "name": "Pretty Metric Name"},
                    ],
                }
            ]
        }
    ]
}
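
# Metadata payload expected when the optional dataset, task, and metric
# arguments are also provided (asserted in test_push_metric_extra_arguments).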
extras_metadata = {
    "model-index": [
        {
            "results": [
                {
                    "task": {"type": "dummy-task", "name": "task_name"},
                    "dataset": {
                        "type": "dataset_type",
                        "name": "dataset_name",
                        "config": "fr",
                        "split": "test",
                        "revision": "abc",
                        "args": {"a": 1, "b": 2},
                    },
                    "metrics": [
                        {
                            "type": "dummy_metric",
                            "value": 1.0,
                            "name": "Pretty Metric Name",
                            "config": "default",
                            "args": {"hello": 1, "world": 2},
                        },
                    ],
                }
            ]
        }
    ]
}


# Each test method below receives a `metadata_update` mock, so the class must be
# wrapped in class-level patches. The exact targets for the model/dataset lookup
# and the allowed-task list are assumptions inferred from how the tests use them.
@patch("evaluate.hub.HF_HUB_ALLOWED_TASKS", ["dummy-task"])  # assumed: lets "dummy-task" pass while "audio-classification" raises ValueError
@patch("evaluate.hub.dataset_info", lambda x: True)  # assumed: skip real Hub dataset lookups
@patch("evaluate.hub.model_info", lambda x: True)  # assumed: skip real Hub model lookups
@patch("evaluate.hub.metadata_update")
class TestHub(TestCase):
    @pytest.fixture(autouse=True)
    def inject_fixtures(self, caplog):
        self._caplog = caplog

    def setUp(self):
        self.metric = DummyMetric()
        self.metric.add()
        self.args = {"hello": 1, "world": 2}
        self.result = self.metric.compute()
    def test_push_metric_required_arguments(self, metadata_update):
        push_to_hub(
            model_id="username/repo",
            metric_value=self.result["accuracy"],
            metric_name="Pretty Metric Name",
            metric_type=self.metric.name,
            dataset_name="dataset_name",
            dataset_type="dataset_type",
            task_type="dummy-task",
        )
        metadata_update.assert_called_once_with(repo_id="username/repo", metadata=minimum_metadata, overwrite=False)

    def test_push_metric_missing_arguments(self, metadata_update):
        with pytest.raises(TypeError):
            push_to_hub(
                model_id="username/repo",
                metric_value=self.result["accuracy"],
                metric_name="Pretty Metric Name",
                metric_type=self.metric.name,
                dataset_name="dataset_name",
                dataset_type="dummy-task",
            )

    def test_push_metric_invalid_arguments(self, metadata_update):
        with pytest.raises(TypeError):
            push_to_hub(
                model_id="username/repo",
                metric_value=self.result["accuracy"],
                metric_name="Pretty Metric Name",
                metric_type=self.metric.name,
                dataset_name="dataset_name",
                dataset_type="dataset_type",
                task_type="dummy-task",
                random_value="incorrect",
            )

    def test_push_metric_extra_arguments(self, metadata_update):
        push_to_hub(
            model_id="username/repo",
            metric_value=self.result["accuracy"],
            metric_name="Pretty Metric Name",
            metric_type=self.metric.name,
            dataset_name="dataset_name",
            dataset_type="dataset_type",
            dataset_config="fr",
            dataset_split="test",
            dataset_revision="abc",
            dataset_args={"a": 1, "b": 2},
            task_type="dummy-task",
            task_name="task_name",
            metric_config=self.metric.config_name,
            metric_args=self.args,
        )
        metadata_update.assert_called_once_with(repo_id="username/repo", metadata=extras_metadata, overwrite=False)

    def test_push_metric_invalid_task_type(self, metadata_update):
        with pytest.raises(ValueError):
            push_to_hub(
                model_id="username/repo",
                metric_value=self.result["accuracy"],
                metric_name="Pretty Metric Name",
                metric_type=self.metric.name,
                dataset_name="dataset_name",
                dataset_type="dataset_type",
                task_type="audio-classification",
            )

    def test_push_metric_invalid_dataset_type(self, metadata_update):
        with patch("evaluate.hub.dataset_info") as mock_dataset_info:
            mock_dataset_info.side_effect = requests.HTTPError()
            push_to_hub(
                model_id="username/repo",
                metric_value=self.result["accuracy"],
                metric_name="Pretty Metric Name",
                metric_type=self.metric.name,
                dataset_name="dataset_name",
                dataset_type="dataset_type",
                task_type="dummy-task",
            )
            assert "Dataset dataset_type not found on the Hub at hf.co/datasets/dataset_type" in self._caplog.text
            metadata_update.assert_called_once_with(
                repo_id="username/repo", metadata=minimum_metadata, overwrite=False
            )

    def test_push_metric_invalid_model_id(self, metadata_update):
        with patch("evaluate.hub.model_info") as mock_model_info:
            mock_model_info.side_effect = requests.HTTPError()
            with pytest.raises(ValueError):
                push_to_hub(
                    model_id="username/bad-repo",
                    metric_value=self.result["accuracy"],
                    metric_name="Pretty Metric Name",
                    metric_type=self.metric.name,
                    dataset_name="dataset_name",
                    dataset_type="dataset_type",
                    task_type="dummy-task",
                )
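

# Globs the measurement, metric, and comparison cards (README.md files) and
# asserts that the first YAML document in each one parses into a dictionary,
# i.e. that every card's metadata block is valid YAML.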
class ValidateYaml(TestCase):
    def setUp(self):
        pass

    def testLoadingCards(self):
        readme_filepaths = []
        for glob_path in ["measurements/*/README.md", "metrics/*/README.md", "comparisons/*/README.md"]:
            readme_filepaths.extend(glob.glob(glob_path))
        for readme_file in readme_filepaths:
            with open(readme_file, encoding="utf8") as f_yaml:
                x = yaml.safe_load_all(f_yaml)
                self.assertIsInstance(next(x), dict)