Spaces:
Runtime error
Runtime error
File size: 1,612 Bytes
04a30fc |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 |
import pytest
from unittest.mock import Mock, call
from datasets import Dataset
from substra_template.substra_runner import SubstraRunner
class TestSubstraRunner:
    """Unit tests for ``SubstraRunner``.

    External dependencies (the substra ``Client`` class and the
    ``load_dataset`` function used by ``substra_template.substra_runner``)
    are replaced with ``Mock`` objects via ``monkeypatch`` so the tests do
    not construct real clients or download data.
    """

    @pytest.fixture
    def mock_substra_client_class(self, monkeypatch):
        """Replace ``substra_runner.Client`` with a ``Mock`` and return it."""
        mock_substra_client_class = Mock()
        monkeypatch.setattr("substra_template.substra_runner.Client", mock_substra_client_class)
        return mock_substra_client_class

    @pytest.fixture
    def mock_load_dataset(self, monkeypatch):
        """Replace ``substra_runner.load_dataset`` with a ``Mock`` and return it."""
        mock_load_dataset = Mock()
        monkeypatch.setattr("substra_template.substra_runner.load_dataset", mock_load_dataset)
        return mock_load_dataset

    def test_set_up_clients(self, mock_substra_client_class):
        """``set_up_clients`` instantiates the (mocked) substra ``Client``."""
        runner = SubstraRunner()
        runner.set_up_clients()
        mock_substra_client_class.assert_called()

    def test_prepare_data(self, mock_load_dataset):
        """``prepare_data`` loads both MNIST splits and keeps ``num_clients - 1`` datasets."""
        runner = SubstraRunner()
        runner.prepare_data()
        mock_load_dataset.assert_has_calls(calls=[
            call("mnist", split="train"),
            call("mnist", split="test"),
        ], any_order=True)
        assert len(runner.datasets) == runner.num_clients - 1

    def test_register_data(self, mock_load_dataset):
        """Smoke test: ``register_data`` runs over ``num_clients - 1`` empty datasets."""
        runner = SubstraRunner()
        runner.datasets = [Dataset.from_dict({}) for _ in range(runner.num_clients - 1)]
        runner.register_data()
        # TODO(review): no assertion on what gets registered; add one once the
        # runner's registration API is pinned down.

    def test_register_metric(self, mock_substra_client_class):
        """Smoke test: ``register_metric`` runs after clients are set up.

        ``mock_substra_client_class`` is requested so that ``set_up_clients``
        builds mocked clients instead of real substra ``Client`` instances.
        """
        runner = SubstraRunner()
        runner.set_up_clients()
        runner.register_metric()
        # TODO(review): assert on the calls made to the mocked clients once the
        # expected metric-registration call is known.

    @pytest.mark.skip(reason="not implemented yet")
    def test_set_aggregation(self):
        """Placeholder; skipped so it is reported as pending rather than passing."""

    @pytest.mark.skip(reason="not implemented yet")
    def test_set_testing(self):
        """Placeholder; skipped so it is reported as pending rather than passing."""
|