|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import doctest |
|
import sys |
|
import warnings |
|
from os.path import abspath, dirname, join |
|
|
|
import _pytest |
|
import pytest |
|
|
|
from transformers.testing_utils import HfDoctestModule, HfDocTestParser |
|
|
|
|
|
# Node-id fragments for tests that are tagged with the ``not_device_test``
# marker (registered in ``pytest_configure`` as "mark the tests always running
# on cpu"). ``pytest_collection_modifyitems`` checks every collected item and
# adds the marker when ANY entry below is a substring of the item's ``nodeid``.
#
# Because matching is plain substring containment, entries are intentionally
# loose: e.g. "test_config" already covers "test_configuration_utils", and a
# path fragment such as "/utils/" tags every test whose node id contains that
# path segment.
NOT_DEVICE_TESTS = {
    # Test-module / test-method name fragments.
    "test_tokenization",
    "test_processor",
    "test_processing",
    "test_beam_constraints",
    "test_configuration_utils",
    "test_data_collator",
    "test_trainer_callback",
    "test_trainer_utils",
    "test_feature_extraction",
    "test_image_processing",
    "test_image_processor",
    "test_image_transforms",
    "test_optimization",
    "test_retrieval",
    "test_config",
    "test_from_pretrained_no_checkpoint",
    "test_keep_in_fp32_modules",
    "test_gradient_checkpointing_backward_compatibility",
    "test_gradient_checkpointing_enable_disable",
    "test_save_load_fast_init_from_base",
    "test_fast_init_context_manager",
    "test_fast_init_tied_embeddings",
    "test_save_load_fast_init_to_base",
    "test_torch_save_load",
    "test_initialization",
    "test_forward_signature",
    "test_model_common_attributes",
    "test_model_main_input_name",
    "test_correct_missing_keys",
    "test_tie_model_weights",
    "test_can_use_safetensors",
    "test_load_save_without_tied_weights",
    "test_tied_weights_keys",
    "test_model_weights_reload_no_missing_tied_weights",
    "test_pt_tf_model_equivalence",
    "test_mismatched_shapes_have_properly_initialized_weights",
    "test_matched_shapes_have_loaded_weights_when_some_mismatched_shapes_exist",
    "test_model_is_small",
    "test_tf_from_pt_safetensors",
    "test_flax_from_pt_safetensors",
    # "::" is pytest's node-id separator: these match methods whose names start
    # with "test_pipeline_" inside classes whose names end in ModelTest/ModelTester.
    "ModelTest::test_pipeline_",
    "ModelTester::test_pipeline_",
    # Path fragments: tag every test whose node id contains these directories.
    "/repo_utils/",
    "/utils/",
    "/tools/",
}
|
|
|
|
|
|
|
# Put this checkout's ``src`` directory on the import path so the tests import
# the package sources that sit next to this file. It goes in at index 1 rather
# than 0 so that sys.path[0] keeps its usual first position; presumably this
# places it ahead of installed copies (site-packages) — confirm for your setup.
git_repo_path = abspath(join(dirname(__file__), "src"))
sys.path[1:1] = [git_repo_path]  # same effect as sys.path.insert(1, git_repo_path)

# Hide FutureWarning noise for the whole test session.
warnings.simplefilter(action="ignore", category=FutureWarning)
|
|
|
|
|
def pytest_configure(config):
    """Pytest hook: register this suite's custom markers.

    Each entry is appended to the ``markers`` ini value as
    ``"<name>: <description>"``, which is how pytest learns about custom marks.

    Args:
        config: the pytest ``Config`` object handed to the hook.
    """
    custom_markers = (
        "is_pt_tf_cross_test: mark test to run only when PT and TF interactions are tested",
        "is_pt_flax_cross_test: mark test to run only when PT and FLAX interactions are tested",
        "is_pipeline_test: mark test to run only when pipelines are tested",
        "is_staging_test: mark test to run only in the staging environment",
        "accelerate_tests: mark test that require accelerate",
        "tool_tests: mark the tool tests that are run on their specific schedule",
        "not_device_test: mark the tests always running on cpu",
    )
    for marker_line in custom_markers:
        config.addinivalue_line("markers", marker_line)
|
|
|
|
|
def pytest_collection_modifyitems(items):
    """Pytest hook: add the ``not_device_test`` marker to matching items.

    An item matches when any fragment in ``NOT_DEVICE_TESTS`` occurs as a
    substring of its ``nodeid``. Items are modified in place; the list itself
    is neither reordered nor filtered.

    Args:
        items: the collected pytest items.
    """
    for collected in items:
        node_id = collected.nodeid
        if not any(fragment in node_id for fragment in NOT_DEVICE_TESTS):
            continue
        collected.add_marker(pytest.mark.not_device_test)
|
|
|
|
|
def pytest_addoption(parser):
    """Pytest hook: add command-line options via ``pytest_addoption_shared``.

    The helper is imported when the hook runs rather than at module level.

    Args:
        parser: the pytest option parser handed to the hook.
    """
    from transformers.testing_utils import pytest_addoption_shared as register_shared_options

    register_shared_options(parser)
|
|
|
|
|
def pytest_terminal_summary(terminalreporter):
    """Pytest hook: run ``pytest_terminal_summary_main`` when ``--make-reports`` is set.

    The value of ``--make-reports`` is passed through as the ``id`` argument.
    Nothing extra happens when the option's value is falsy.

    Args:
        terminalreporter: the pytest terminal reporter handed to the hook.
    """
    from transformers.testing_utils import pytest_terminal_summary_main

    report_id = terminalreporter.config.getoption("--make-reports")
    if not report_id:
        return
    pytest_terminal_summary_main(terminalreporter, id=report_id)
|
|
|
|
|
def pytest_sessionfinish(session, exitstatus):
    """Pytest hook: report "no tests collected" as a successful run.

    pytest ends with exit status 5 when no tests were collected. That status
    is rewritten to 0 so an empty selection is not reported as a failure.
    Every other status is left untouched.

    Args:
        session: the pytest session; its ``exitstatus`` may be overwritten.
        exitstatus: the exit status pytest is about to return.
    """
    no_tests_collected = 5  # value of pytest.ExitCode.NO_TESTS_COLLECTED
    if exitstatus != no_tests_collected:
        return
    session.exitstatus = 0
|
|
|
|
|
|
|
# New doctest directive: ``# doctest: +IGNORE_RESULT`` makes an example pass
# whatever it prints, while still executing it.
IGNORE_RESULT = doctest.register_optionflag("IGNORE_RESULT")

# Keep a handle on the stock checker class before it is replaced elsewhere in
# this module, so the subclass below extends the original implementation.
OutputChecker = doctest.OutputChecker


class CustomOutputChecker(OutputChecker):
    """Doctest output checker that understands the ``IGNORE_RESULT`` flag.

    When ``IGNORE_RESULT`` is among the active option flags, every output is
    accepted. Otherwise comparison is delegated unchanged to the standard
    ``doctest.OutputChecker``.
    """

    def check_output(self, want, got, optionflags):
        """Return True if ``got`` is acceptable for ``want`` under ``optionflags``."""
        if optionflags & IGNORE_RESULT:
            return True
        return super().check_output(want, got, optionflags)
|
|
|
|
|
# Install the doctest customisations globally, for every consumer of these
# modules in this process:
#  - doctest.OutputChecker -> the IGNORE_RESULT-aware checker defined above;
#  - pytest's doctest collector class -> transformers' HfDoctestModule;
#  - doctest.DocTestParser -> transformers' HfDocTestParser.
# NOTE(review): only ``_pytest`` is imported here, so the next line relies on
# the ``_pytest.doctest`` submodule having already been imported by the time
# this conftest loads (presumably by pytest's builtin doctest plugin and/or
# transformers.testing_utils) — confirm before running with ``-p no:doctest``.
doctest.OutputChecker = CustomOutputChecker
_pytest.doctest.DoctestModule = HfDoctestModule
doctest.DocTestParser = HfDocTestParser
|
|