File size: 7,649 Bytes
4bc9607 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 |
import argparse, os, sys, datetime
from omegaconf import OmegaConf
from transformers import logging as transf_logging
import pytorch_lightning as pl
from pytorch_lightning import seed_everything
from pytorch_lightning.trainer import Trainer
import torch
sys.path.insert(1, os.path.join(sys.path[0], '..'))
from utils.utils import instantiate_from_config
from utils_train import get_trainer_callbacks, get_trainer_logger, get_trainer_strategy
from utils_train import set_logger, init_workspace, load_checkpoints
def get_parser(**parser_kwargs):
    """Build the command-line parser holding the script-level options.

    Args:
        **parser_kwargs: Forwarded unchanged to ``argparse.ArgumentParser``.

    Returns:
        argparse.ArgumentParser: Parser exposing the seed, experiment name,
        base config paths, run-mode flags (train/val/test), log directory,
        resume flags and the debug flag.
    """
    parser = argparse.ArgumentParser(**parser_kwargs)
    parser.add_argument("--seed", "-s", type=int, default=20230211, help="seed for seed_everything")
    parser.add_argument("--name", "-n", type=str, default="", help="experiment name, as saving folder")
    # Arguments this parser does not recognise are handed to
    # OmegaConf.from_dotlist by the script, so overrides must be written as
    # dot-list entries (`key=value`), not as `--key value` pairs.
    parser.add_argument(
        "--base", "-b", nargs="*", metavar="base_config.yaml", default=[],
        help="paths to base configs. Loaded from left-to-right. "
             "Parameters can be overwritten or added with command-line options of the form `key=value`.",
    )
    parser.add_argument("--train", "-t", action='store_true', default=False, help='train')
    parser.add_argument("--val", "-v", action='store_true', default=False, help='val')
    parser.add_argument("--test", action='store_true', default=False, help='test')
    parser.add_argument("--logdir", "-l", type=str, default="logs", help="directory for logging data")
    parser.add_argument("--auto_resume", action='store_true', default=False,
                        help="resume from full-info checkpoint")
    parser.add_argument("--auto_resume_weight_only", action='store_true', default=False,
                        help="resume from weight-only checkpoint")
    parser.add_argument("--debug", "-d", action='store_true', default=False,
                        help="enable post-mortem debugging")
    return parser
def get_nondefault_trainer_args(args):
    """List the Trainer options whose value in ``args`` differs from the default.

    A throwaway parser is populated via ``Trainer.add_argparse_args`` and
    parsed with an empty argument list to obtain the default value of every
    option it defines; each one is then compared against ``args``.

    Args:
        args: Namespace-like object that carries an attribute for every
            option registered by ``Trainer.add_argparse_args``.

    Returns:
        list[str]: Alphabetically sorted names of the options that differ.
    """
    reference_parser = Trainer.add_argparse_args(argparse.ArgumentParser())
    defaults = vars(reference_parser.parse_args([]))
    changed = [
        option
        for option, default_value in defaults.items()
        if getattr(args, option) != default_value
    ]
    changed.sort()
    return changed
if __name__ == "__main__":
    ## Training entry point: parse CLI + yaml configs, build the model, the data
    ## module and the Lightning Trainer, then run trainer.fit when --train is set.
    now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")

    ## Rank information is read from the environment. Fail early with an
    ## actionable message instead of the opaque TypeError raised by int(None)
    ## when the script is started without a distributed launcher (e.g. torchrun).
    missing_env = [var for var in ("LOCAL_RANK", "RANK", "WORLD_SIZE") if os.environ.get(var) is None]
    if missing_env:
        raise RuntimeError(
            "Missing environment variable(s): %s. Start this script through a distributed "
            "launcher (e.g. torchrun) that sets LOCAL_RANK, RANK and WORLD_SIZE." % ", ".join(missing_env)
        )
    local_rank = int(os.environ["LOCAL_RANK"])
    global_rank = int(os.environ["RANK"])
    num_rank = int(os.environ["WORLD_SIZE"])

    parser = get_parser()
    ## Extends existing argparse by default Trainer attributes
    parser = Trainer.add_argparse_args(parser)
    ## options unknown to the parser are kept and used as config overrides below
    args, unknown = parser.parse_known_args()
    ## disable transformer warning
    transf_logging.set_verbosity_error()
    seed_everything(args.seed)

    ## yaml configs: "model" | "data" | "lightning"
    ## merged left-to-right; command-line dot-list overrides are applied last
    configs = [OmegaConf.load(cfg) for cfg in args.base]
    cli = OmegaConf.from_dotlist(unknown)
    config = OmegaConf.merge(*configs, cli)
    lightning_config = config.pop("lightning", OmegaConf.create())
    trainer_config = lightning_config.get("trainer", OmegaConf.create())

    ## setup workspace directories
    workdir, ckptdir, cfgdir, loginfo = init_workspace(args.name, args.logdir, config, lightning_config, global_rank)
    logger = set_logger(logfile=os.path.join(loginfo, 'log_%d:%s.txt'%(global_rank, now)))
    logger.info("@lightning version: %s [>=1.8 required]"%(pl.__version__))

    ## MODEL CONFIG >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
    logger.info("***** Configing Model *****")
    config.model.params.logdir = workdir
    model = instantiate_from_config(config.model)

    ## load checkpoints
    model = load_checkpoints(model, config.model)

    ## register_schedule again to make ZTSNR work
    if model.rescale_betas_zero_snr:
        model.register_schedule(given_betas=model.given_betas, beta_schedule=model.beta_schedule,
                                timesteps=model.timesteps, linear_start=model.linear_start,
                                linear_end=model.linear_end, cosine_s=model.cosine_s)

    ## update trainer config: Trainer options set on the command line win over the yaml values
    for k in get_nondefault_trainer_args(args):
        trainer_config[k] = getattr(args, k)

    num_nodes = trainer_config.num_nodes
    ngpu_per_node = trainer_config.devices
    logger.info(f"Running on {num_rank}={num_nodes}x{ngpu_per_node} GPUs")

    ## setup learning rate
    base_lr = config.model.base_learning_rate
    bs = config.data.params.batch_size
    if getattr(config.model, 'scale_lr', True):
        ## scale the base learning rate by (world size * per-process batch size)
        model.learning_rate = num_rank * bs * base_lr
    else:
        model.learning_rate = base_lr

    ## DATA CONFIG >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
    logger.info("***** Configing Data *****")
    data = instantiate_from_config(config.data)
    data.setup()
    for k in data.datasets:
        logger.info(f"{k}, {data.datasets[k].__class__.__name__}, {len(data.datasets[k])}")

    ## TRAINER CONFIG >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
    logger.info("***** Configing Trainer *****")
    if "accelerator" not in trainer_config:
        trainer_config["accelerator"] = "gpu"

    ## setup trainer args: pl-logger and callbacks
    trainer_kwargs = dict()
    trainer_kwargs["num_sanity_val_steps"] = 0
    logger_cfg = get_trainer_logger(lightning_config, workdir, args.debug)
    trainer_kwargs["logger"] = instantiate_from_config(logger_cfg)

    ## setup callbacks
    callbacks_cfg = get_trainer_callbacks(lightning_config, config, workdir, ckptdir, logger)
    trainer_kwargs["callbacks"] = [instantiate_from_config(callbacks_cfg[k]) for k in callbacks_cfg]
    ## the strategy is either a plain name (passed through) or a config to instantiate
    strategy_cfg = get_trainer_strategy(lightning_config)
    trainer_kwargs["strategy"] = strategy_cfg if isinstance(strategy_cfg, str) else instantiate_from_config(strategy_cfg)
    trainer_kwargs['precision'] = lightning_config.get('precision', 32)
    trainer_kwargs["sync_batchnorm"] = False

    ## trainer config: others
    trainer_args = argparse.Namespace(**trainer_config)
    trainer = Trainer.from_argparse_args(trainer_args, **trainer_kwargs)

    ## allow checkpointing via USR1
    def melk(*_args, **_kwargs):
        ## signal handler: save a checkpoint on demand (global rank 0 only)
        ## NOTE(review): nesting under the rank check is reconstructed from syntax; confirm
        if trainer.global_rank == 0:
            print("Summoning checkpoint.")
            ckpt_path = os.path.join(ckptdir, "last_summoning.ckpt")
            trainer.save_checkpoint(ckpt_path)

    def divein(*_args, **_kwargs):
        ## signal handler: drop into the pudb debugger (global rank 0 only)
        if trainer.global_rank == 0:
            ## imported lazily so pudb is only needed when the signal is actually sent
            import pudb
            pudb.set_trace()

    import signal
    signal.signal(signal.SIGUSR1, melk)
    signal.signal(signal.SIGUSR2, divein)

    ## Running LOOP >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
    logger.info("***** Running the Loop *****")
    if args.train:
        ## The "strategy" entry may be a plain name or a nested config; only a
        ## string can be prefix-matched, so guard before calling startswith.
        strategy_name = lightning_config['strategy'] if "strategy" in lightning_config else None
        use_deepspeed = isinstance(strategy_name, str) and strategy_name.startswith('deepspeed')
        ## any exception raised by fit propagates unchanged to the caller
        if use_deepspeed:
            logger.info("<Training in DeepSpeed Mode>")
            ## deepspeed
            if trainer_kwargs['precision'] == 16:
                with torch.cuda.amp.autocast():
                    trainer.fit(model, data)
            else:
                trainer.fit(model, data)
        else:
            logger.info("<Training in DDPSharded Mode>") ## this is default
            ## ddpsharded
            trainer.fit(model, data)

    ## NOTE: validation / test runs are disabled; --val and --test are parsed but unused.
    # if args.val:
    #     trainer.validate(model, data)
    # if args.test or not trainer.interrupted:
    #     trainer.test(model, data)