Spaces:
Runtime error
Runtime error
# Copyright (c) Facebook, Inc. and its affiliates. | |
# | |
# This source code is licensed under the MIT license found in the | |
# LICENSE file in the root directory of this source tree. | |
import contextlib | |
import logging | |
import unittest | |
from io import StringIO | |
from unittest.mock import MagicMock, patch | |
import torch | |
from fairseq import checkpoint_utils, data | |
from omegaconf import OmegaConf | |
def mock_trainer(epoch, num_updates, iterations_in_epoch):
    """Return a MagicMock trainer set up to resume from a given training position.

    ``load_checkpoint(...)`` returns extra state whose ``"train_iterator"`` entry
    records *epoch*, *iterations_in_epoch* and ``shuffle=False``.
    ``get_num_updates()`` returns *num_updates*.
    """
    iterator_state = dict(
        epoch=epoch,
        iterations_in_epoch=iterations_in_epoch,
        shuffle=False,
    )
    fake = MagicMock()
    fake.get_num_updates.return_value = num_updates
    fake.load_checkpoint.return_value = {"train_iterator": iterator_state}
    return fake
def mock_dict():
    """Return a MagicMock vocabulary whose special-symbol indices are fixed.

    ``pad()`` returns 1, ``eos()`` returns 2 and ``unk()`` returns 3.
    """
    vocab = MagicMock()
    vocab.configure_mock(
        **{
            "pad.return_value": 1,
            "eos.return_value": 2,
            "unk.return_value": 3,
        }
    )
    return vocab
def get_trainer_and_epoch_itr(epoch, epoch_size, num_updates, iterations_in_epoch):
    """Build a mock trainer together with a real epoch iterator over *epoch_size* items.

    The tokens ``0 .. epoch_size - 1`` are laid out in a single row, split with
    ``block_size=1``, and batched one index per batch (batch ``i`` is ``[i]``).
    The trainer comes from ``mock_trainer`` and reports *epoch*, *num_updates*
    and *iterations_in_epoch* as its checkpoint state.

    Returns:
        tuple: ``(trainer, epoch_itr)``.
    """
    token_row = torch.LongTensor([list(range(epoch_size))])
    block_ds = data.TokenBlockDataset(
        token_row,
        sizes=[token_row.size(-1)],
        block_size=1,
        pad=0,
        eos=1,
        include_targets=False,
    )
    pair_ds = data.LanguagePairDataset(
        block_ds, block_ds.sizes, mock_dict(), shuffle=False
    )
    one_index_per_batch = [[idx] for idx in range(epoch_size)]
    epoch_itr = data.EpochBatchIterator(
        dataset=pair_ds,
        collate_fn=pair_ds.collater,
        batch_sampler=one_index_per_batch,
    )
    trainer = mock_trainer(epoch, num_updates, iterations_in_epoch)
    return trainer, epoch_itr
def get_mock_cfg(finetune_from_model):
    """Create the minimal OmegaConf config the tests pass to ``load_checkpoint``.

    All ``reset_*`` flags are off, ``restore_file`` is ``"checkpoint_last.pt"``
    and ``save_dir`` is ``None``. *finetune_from_model* is stored verbatim under
    ``checkpoint.finetune_from_model``.
    """
    checkpoint_section = {
        "save_dir": None,
        "optimizer_overrides": "{}",
        "reset_dataloader": False,
        "reset_meters": False,
        "reset_optimizer": False,
        "reset_lr_scheduler": False,
        "finetune_from_model": finetune_from_model,
        "model_parallel_size": 1,
        "restore_file": "checkpoint_last.pt",
    }
    common_section = {"model_parallel_size": 1}
    return OmegaConf.create(
        {"checkpoint": checkpoint_section, "common": common_section}
    )
class TestLoadCheckpoint(unittest.TestCase):
    """Tests for ``checkpoint_utils.load_checkpoint`` driven by a mocked trainer.

    For every test, ``os.makedirs``, ``os.path.join``, ``os.path.isfile``,
    ``os.path.isabs`` and ``fairseq.file_io.PathManager.exists`` are patched
    with fresh MagicMocks. Individual tests reconfigure them through
    ``self.patches``.
    """

    def setUp(self):
        self.cfg_mock = get_mock_cfg(None)
        # Fresh mocks per test, so a test that reconfigures one (e.g. isfile or
        # exists) cannot leak that configuration into another test.
        self.patches = {
            "os.makedirs": MagicMock(),
            "os.path.join": MagicMock(),
            "os.path.isfile": MagicMock(return_value=True),
            "os.path.isabs": MagicMock(return_value=False),
            "fairseq.file_io.PathManager.exists": MagicMock(return_value=False),
        }
        self.applied_patches = [patch(p, d) for p, d in self.patches.items()]
        # A plain loop: starting patches is a side effect, not list construction.
        for applied in self.applied_patches:
            applied.start()
        logging.disable(logging.CRITICAL)

    def tearDown(self):
        patch.stopall()
        logging.disable(logging.NOTSET)

    @staticmethod
    def _make_trainer(epoch, epoch_size, num_updates, iterations_in_epoch):
        """Return a mock trainer whose ``get_train_iterator`` yields the helper-built iterator."""
        trainer, epoch_itr = get_trainer_and_epoch_itr(
            epoch, epoch_size, num_updates, iterations_in_epoch
        )
        trainer.get_train_iterator = MagicMock(return_value=epoch_itr)
        return trainer

    @staticmethod
    def _reset_flags(trainer):
        """Return (reset_optimizer, reset_lr_scheduler, reset_meters) from the last load_checkpoint call."""
        # Unpacking into four names also checks that exactly four positional
        # arguments were passed to trainer.load_checkpoint.
        _, reset_optimizer, reset_lr_scheduler, _ = trainer.load_checkpoint.call_args[0]
        reset_meters = trainer.load_checkpoint.call_args[1]["reset_meters"]
        return reset_optimizer, reset_lr_scheduler, reset_meters

    def test_load_partial_checkpoint(self):
        with contextlib.redirect_stdout(StringIO()):
            trainer = self._make_trainer(2, 150, 200, 50)
            _, epoch_itr = checkpoint_utils.load_checkpoint(
                self.cfg_mock.checkpoint, trainer
            )
            self.assertEqual(epoch_itr.epoch, 2)
            self.assertEqual(epoch_itr.iterations_in_epoch, 50)

            itr = epoch_itr.next_epoch_itr(shuffle=False)
            self.assertEqual(epoch_itr.epoch, 2)
            self.assertEqual(epoch_itr.iterations_in_epoch, 50)

            # Batch i carries token i, so resuming mid-epoch yields token 50 first.
            self.assertEqual(next(itr)["net_input"]["src_tokens"][0].item(), 50)
            self.assertEqual(epoch_itr.iterations_in_epoch, 51)

            for _ in range(150 - 52):
                next(itr)
            self.assertEqual(epoch_itr.iterations_in_epoch, 149)
            self.assertTrue(itr.has_next())
            next(itr)
            self.assertFalse(itr.has_next())

            itr = epoch_itr.next_epoch_itr(shuffle=False)
            self.assertTrue(itr.has_next())
            self.assertEqual(epoch_itr.epoch, 3)
            self.assertEqual(epoch_itr.iterations_in_epoch, 0)

    def test_load_full_checkpoint(self):
        with contextlib.redirect_stdout(StringIO()):
            trainer = self._make_trainer(2, 150, 300, 150)
            _, epoch_itr = checkpoint_utils.load_checkpoint(
                self.cfg_mock.checkpoint, trainer
            )
            itr = epoch_itr.next_epoch_itr(shuffle=False)
            self.assertEqual(epoch_itr.epoch, 3)
            self.assertEqual(epoch_itr.iterations_in_epoch, 0)
            self.assertEqual(next(itr)["net_input"]["src_tokens"][0].item(), 0)

    def test_load_no_checkpoint(self):
        with contextlib.redirect_stdout(StringIO()):
            trainer = self._make_trainer(1, 150, 0, 0)
            self.patches["os.path.isfile"].return_value = False

            _, epoch_itr = checkpoint_utils.load_checkpoint(
                self.cfg_mock.checkpoint, trainer
            )
            itr = epoch_itr.next_epoch_itr(shuffle=False)
            self.assertEqual(epoch_itr.epoch, 1)
            self.assertEqual(epoch_itr.iterations_in_epoch, 0)
            self.assertEqual(next(itr)["net_input"]["src_tokens"][0].item(), 0)

    def test_finetune_from_model_args_conflict(self):
        with contextlib.redirect_stdout(StringIO()):
            trainer = self._make_trainer(1, 150, 0, 0)

            for arg in [
                "reset_optimizer",
                "reset_lr_scheduler",
                "reset_meters",
                "reset_dataloader",
            ]:
                with self.subTest(arg=arg):
                    cfg_mock = get_mock_cfg("/temp/checkpoint_pretrained.pt")
                    cfg_mock["checkpoint"][arg] = True
                    with self.assertRaises(Exception) as context:
                        checkpoint_utils.load_checkpoint(cfg_mock.checkpoint, trainer)

                    self.assertIn(
                        "--finetune-from-model can not be set together with either --reset-optimizer"
                        " or reset_lr_scheduler or reset_meters or reset_dataloader",
                        str(context.exception),
                    )

    def test_finetune_from_model(self):
        with contextlib.redirect_stdout(StringIO()):
            trainer = self._make_trainer(1, 150, 0, 0)
            from_model_path = "/temp/checkpoint_pretrained.pt"

            def mock_finetune_exist(path):
                # Only the pretrained model is reported as existing.
                return path == from_model_path

            self.patches[
                "fairseq.file_io.PathManager.exists"
            ].side_effect = mock_finetune_exist
            cfg_mock = get_mock_cfg(from_model_path)
            cfg_mock.checkpoint.restore_file = "checkpoint_last.pt"
            checkpoint_utils.load_checkpoint(cfg_mock.checkpoint, trainer)

            reset_optimizer, reset_lr_scheduler, reset_meters = self._reset_flags(
                trainer
            )
            self.assertTrue(reset_optimizer)
            self.assertTrue(reset_lr_scheduler)
            self.assertTrue(reset_meters)

    def test_finetune_from_model_resume(self):
        with contextlib.redirect_stdout(StringIO()):
            trainer = self._make_trainer(1, 150, 0, 0)
            from_model_path = "/temp/checkpoint_pretrained.pt"

            # launch second time
            # both restore_file=checkpoint_last.pt and finetune_from_model are set
            def mock_finetune_exist(path):
                # str has ``endswith`` (not ``endsWith``). NOTE: setUp patches
                # os.path.join with a MagicMock, so a path built with it is a
                # MagicMock, whose ``endswith`` returns a truthy mock.
                return bool(
                    path == from_model_path or path.endswith("checkpoint_last.pt")
                )

            self.patches[
                "fairseq.file_io.PathManager.exists"
            ].side_effect = mock_finetune_exist
            cfg_mock = get_mock_cfg(from_model_path)
            cfg_mock.checkpoint.restore_file = "checkpoint_last.pt"
            checkpoint_utils.load_checkpoint(cfg_mock.checkpoint, trainer)

            reset_optimizer, reset_lr_scheduler, reset_meters = self._reset_flags(
                trainer
            )
            self.assertFalse(reset_optimizer)
            self.assertFalse(reset_lr_scheduler)
            self.assertFalse(reset_meters)
if __name__ == "__main__":
    # Run this module's test cases when the file is executed directly.
    unittest.main()