import os

import torch
import torch.nn as nn
import torch.optim as optim
from torch.nn.parallel import DistributedDataParallel as DDP


def continue_training(checkpoint_path, generator: DDP, mpd: DDP, mrd: DDP,
                      optimizer_d: optim.Optimizer, optimizer_g: optim.Optimizer) -> int:
    """Load the latest model and optimizer checkpoints and return the epoch to resume from.

    Returns 0 if no epoch has a complete set of checkpoint files.
    """
    generator_dict = {}
    mpd_dict = {}
    mrd_dict = {}
    optimizer_d_dict = {}
    optimizer_g_dict = {}

    # gather all checkpoints in the directory, grouped by component and keyed by epoch
    for file in os.listdir(checkpoint_path):
        if file.endswith(".pt"):
            name, epoch_str = file.rsplit('_', 1)
            epoch = int(epoch_str.split('.')[0])
            if name.startswith("generator"):
                generator_dict[epoch] = file
            elif name.startswith("mpd"):
                mpd_dict[epoch] = file
            elif name.startswith("mrd"):
                mrd_dict[epoch] = file
            elif name.startswith("optimizerd"):
                optimizer_d_dict[epoch] = file
            elif name.startswith("optimizerg"):
                optimizer_g_dict[epoch] = file

    # find the largest epoch for which every component has a checkpoint file
    common_epochs = (set(generator_dict.keys()) & set(mpd_dict.keys()) & set(mrd_dict.keys())
                     & set(optimizer_d_dict.keys()) & set(optimizer_g_dict.keys()))
    if common_epochs:
        max_epoch = max(common_epochs)
        generator_path = os.path.join(checkpoint_path, generator_dict[max_epoch])
        mpd_path = os.path.join(checkpoint_path, mpd_dict[max_epoch])
        mrd_path = os.path.join(checkpoint_path, mrd_dict[max_epoch])
        optimizer_d_path = os.path.join(checkpoint_path, optimizer_d_dict[max_epoch])
        optimizer_g_path = os.path.join(checkpoint_path, optimizer_g_dict[max_epoch])

        # load weights into the modules wrapped by DDP and restore optimizer state
        generator.module.load_state_dict(torch.load(generator_path, map_location='cpu'))
        mpd.module.load_state_dict(torch.load(mpd_path, map_location='cpu'))
        mrd.module.load_state_dict(torch.load(mrd_path, map_location='cpu'))
        optimizer_d.load_state_dict(torch.load(optimizer_d_path, map_location='cpu'))
        optimizer_g.load_state_dict(torch.load(optimizer_g_path, map_location='cpu'))
        print(f'resumed models and optimizers from epoch {max_epoch}')
        return max_epoch + 1
    else:
        return 0
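

# --- Usage sketch (hypothetical, for illustration only) ----------------------
# The Generator / MultiPeriodDiscriminator / MultiResolutionDiscriminator
# classes, the "checkpoints" directory, the learning rate, and num_epochs below
# are assumptions, not definitions from this module; adapt them to your setup.
# The checkpoint files are expected to follow the "<name>_<epoch>.pt" pattern
# that continue_training() parses above.
#
#   generator = DDP(Generator().to(rank), device_ids=[rank])
#   mpd = DDP(MultiPeriodDiscriminator().to(rank), device_ids=[rank])
#   mrd = DDP(MultiResolutionDiscriminator().to(rank), device_ids=[rank])
#   optimizer_g = optim.AdamW(generator.parameters(), lr=2e-4)
#   optimizer_d = optim.AdamW(
#       list(mpd.parameters()) + list(mrd.parameters()), lr=2e-4)
#
#   start_epoch = continue_training("checkpoints", generator, mpd, mrd,
#                                   optimizer_d, optimizer_g)
#   for epoch in range(start_epoch, num_epochs):
#       ...  # train one epoch, then save "<name>_<epoch>.pt" checkpoints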