Spaces:
Build error
Build error
File size: 1,206 Bytes
66a6dc0 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 |
import os
import sys
import shutil
import pytorch_lightning as pl
class CopyPretrainedCheckpoints(pl.callbacks.Callback):
    """Copy pre-trained proxy checkpoints into the current experiment's log dir.

    At the start of fitting, when the module's ``processor_model`` hparam is
    ``"proxy"``, every path in ``hparams.proxy_ckpts`` is copied into
    ``<logger.experiment.log_dir>/pretrained_checkpoints`` and the hparam is
    rewritten to point at the copies.
    """

    def __init__(self):
        super().__init__()

    def on_fit_start(self, trainer, pl_module):
        """Before training, copy the pre-trained checkpoints
        to the current checkpoint directory.

        Args:
            trainer: The Lightning trainer (not used by this hook).
            pl_module: The module being fit. Its ``hparams`` must provide
                ``processor_model`` and, when that equals ``"proxy"``,
                ``proxy_ckpts`` (an iterable of checkpoint paths). Its
                ``logger.experiment`` must expose ``log_dir``.

        Side effects:
            Creates the ``pretrained_checkpoints`` directory if needed, copies
            each checkpoint file into it, replaces
            ``pl_module.hparams.proxy_ckpts`` with the list of copied paths,
            and prints progress to stdout.
        """
        if pl_module.hparams.processor_model != "proxy":
            return

        pretrained_ckpt_dir = os.path.join(
            pl_module.logger.experiment.log_dir, "pretrained_checkpoints"
        )
        # exist_ok avoids the check-then-create race (e.g. several processes
        # running this hook at once) that a separate isdir() check has.
        os.makedirs(pretrained_ckpt_dir, exist_ok=True)

        cp_proxy_ckpts = []
        used_names = set()
        for idx, proxy_ckpt in enumerate(pl_module.hparams.proxy_ckpts):
            name = os.path.basename(proxy_ckpt)
            # Two sources with the same file name (e.g. both "last.ckpt")
            # would otherwise overwrite each other in the destination, leaving
            # every entry pointing at the last copy. Only duplicates get
            # renamed, so the common case keeps the original file name.
            if name in used_names:
                name = f"{idx}_{name}"
            used_names.add(name)
            dest = os.path.join(pretrained_ckpt_dir, name)
            try:
                new_ckpt = shutil.copy(proxy_ckpt, dest)
            except shutil.SameFileError:
                # The checkpoint already is the destination file; reuse it
                # rather than aborting the fit.
                new_ckpt = dest
            cp_proxy_ckpts.append(new_ckpt)
            print(f"Copied checkpoint to {new_ckpt}.")
        # overwrite to the paths in current experiment logs
        pl_module.hparams.proxy_ckpts = cp_proxy_ckpts
        print(pl_module.hparams.proxy_ckpts)
|