# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
# A script to run multinode evaluation/testing with submitit.
# --------------------------------------------------------
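# Example invocation (paths and values below are illustrative):
#   python submitit_test.py --model /path/to/run_dir --step 250000 \
#       --ngpus 8 --nodes 1 --ae_cfg 1.0 3.0 --test_num_slots 32 64
# All combinations of the swept CFG parameters are evaluated within a single job.
# --------------------------------------------------------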
import argparse
import os.path as osp
import submitit
import itertools
from omegaconf import OmegaConf
from paintmind.engine.util import instantiate_from_config
from paintmind.utils.device_utils import configure_compute_backend
def parse_args():
    parser = argparse.ArgumentParser(description="Submitit for accelerator training")
# Slurm configuration
parser.add_argument("--ngpus", default=8, type=int, help="Number of gpus to request on each node")
parser.add_argument("--nodes", default=1, type=int, help="Number of nodes to request")
parser.add_argument("--timeout", default=7000, type=int, help="Duration of the job, default 5 days")
parser.add_argument("--qos", default="normal", type=str, help="QOS to request")
parser.add_argument("--job_dir", default="", type=str, help="Job dir. Leave empty for automatic.")
parser.add_argument("--partition", default="h100-camera-train", type=str, help="Partition where to submit")
parser.add_argument("--exclude", default="", type=str, help="Exclude nodes from the partition")
parser.add_argument("--nodelist", default="", type=str, help="Nodelist to request")
parser.add_argument('--comment', default="", type=str, help="Comment to pass to scheduler")
# Model and testing configuration
    parser.add_argument('--model', type=str, nargs='+', default=[None],
                        help="Path(s) to model run directories (each containing config.yaml and models/step<N>/)")
    parser.add_argument('--step', type=int, nargs='+', default=[250000], help="Checkpoint step number(s) to evaluate")
parser.add_argument('--cfg', type=str, default=None, help="Path to config file")
    parser.add_argument('--dataset', type=str, default='imagenet', help="Dataset to evaluate on ('imagenet' or 'coco')")
# Legacy parameter (preserved for backward compatibility)
parser.add_argument('--cfg_value', type=float, nargs='+', default=[None],
help='Legacy parameter for GPT classifier-free guidance scale')
# CFG-related parameters - all with nargs='+' to support multiple values
parser.add_argument('--ae_cfg', type=float, nargs='+', default=[None],
help="Autoencoder classifier-free guidance scale")
parser.add_argument('--diff_cfg', type=float, nargs='+', default=[None],
help="Diffusion classifier-free guidance scale")
parser.add_argument('--cfg_schedule', type=str, nargs='+', default=[None],
help="CFG schedule type (e.g., constant, linear)")
parser.add_argument('--diff_cfg_schedule', type=str, nargs='+', default=[None],
help="Diffusion CFG schedule type (e.g., constant, inv_linear)")
parser.add_argument('--test_num_slots', type=int, nargs='+', default=[None],
help="Number of slots to use for inference")
parser.add_argument('--temperature', type=float, nargs='+', default=[None],
help="Temperature for sampling")
return parser.parse_args()
def load_config(model_path, cfg_path=None):
"""Load configuration from file or model directory."""
if cfg_path is not None and osp.exists(cfg_path):
config_path = cfg_path
elif model_path and osp.exists(osp.join(model_path, 'config.yaml')):
config_path = osp.join(model_path, 'config.yaml')
else:
raise ValueError(f"No config file found at {model_path} or {cfg_path}")
return OmegaConf.load(config_path)
def setup_checkpoint_path(model_path, step, config):
    """Set up the checkpoint path based on model and step."""
    if model_path:
        ckpt_path = osp.join(model_path, 'models', f'step{step}')
        if not osp.exists(ckpt_path):
            print(f"Skipping non-existent checkpoint: {ckpt_path}")
            return None
    else:
        result_folder = config.trainer.params.result_folder
        ckpt_path = osp.join(result_folder, 'models', f'step{step}')
    # Configs expose the checkpointed network as either `model` or `gpt_model`.
    if hasattr(config.trainer.params, 'model'):
        config.trainer.params.model.params.ckpt_path = ckpt_path
    else:
        config.trainer.params.gpt_model.params.ckpt_path = ckpt_path
    return ckpt_path
def setup_test_config(config, use_coco=False):
"""Set up common test configuration parameters."""
config.trainer.params.test_dataset = config.trainer.params.dataset
if not use_coco:
config.trainer.params.test_dataset.params.split = 'val'
else:
config.trainer.params.test_dataset.target = 'paintmind.utils.datasets.COCO'
config.trainer.params.test_dataset.params.root = './dataset/coco'
config.trainer.params.test_dataset.params.split = 'val2017'
config.trainer.params.test_only = True
config.trainer.params.compile = False
config.trainer.params.eval_fid = True
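    # Reference statistics for FID (ADM ImageNet 256x256 stats), assumed to exist
    # relative to the working directory.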
config.trainer.params.fid_stats = 'fid_stats/adm_in256_stats.npz'
if hasattr(config.trainer.params, 'model'):
config.trainer.params.model.params.num_sampling_steps = '250'
else:
config.trainer.params.ae_model.params.num_sampling_steps = '250'
def apply_cfg_params(config, param_dict):
    """Apply CFG-related parameters to the config."""
    # Map argument names to trainer config keys; `cfg_value` is the legacy name
    # for the GPT classifier-free guidance scale `cfg`.
    key_map = {
        'cfg_value': 'cfg',
        'ae_cfg': 'ae_cfg',
        'diff_cfg': 'diff_cfg',
        'cfg_schedule': 'cfg_schedule',
        'diff_cfg_schedule': 'diff_cfg_schedule',
        'test_num_slots': 'test_num_slots',
        'temperature': 'temperature',
    }
    # Apply each parameter only if it was specified (i.e. not None).
    for arg_name, cfg_key in key_map.items():
        value = param_dict.get(arg_name)
        if value is not None:
            setattr(config.trainer.params, cfg_key, value)
            print(f"Setting {cfg_key} to {value}")
def run_test(config):
"""Instantiate trainer and run test."""
trainer = instantiate_from_config(config.trainer)
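    # test_only=True (set in setup_test_config) is expected to make train() run
    # evaluation / FID computation rather than training.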
trainer.train()
def generate_param_combinations(args):
"""Generate all combinations of parameters from the provided arguments."""
    # Create the parameter grid; each entry is either [None] (not specified) or
    # the list of values passed on the command line.
    param_grid = {
        'cfg_value': args.cfg_value,
        'ae_cfg': args.ae_cfg,
        'diff_cfg': args.diff_cfg,
        'cfg_schedule': args.cfg_schedule,
        'diff_cfg_schedule': args.diff_cfg_schedule,
        'test_num_slots': args.test_num_slots,
        'temperature': args.temperature,
    }
# Get all parameter names that have non-None values
active_params = [k for k, v in param_grid.items() if v != [None]]
if not active_params:
# If no parameters are specified, yield a dict with all None values
yield {k: None for k in param_grid.keys()}
return
# Generate all combinations of active parameters
active_values = [param_grid[k] for k in active_params]
for combination in itertools.product(*active_values):
param_dict = {k: None for k in param_grid.keys()} # Start with all None
for i, param_name in enumerate(active_params):
param_dict[param_name] = combination[i]
yield param_dict
class Trainer(object):
def __init__(self, args):
self.args = args
def __call__(self):
"""Main entry point for the submitit job."""
self._setup_gpu_args()
configure_compute_backend()
self._run_tests()
def _run_tests(self):
"""Run tests for all specified models and steps."""
for step in self.args.step:
for model in self.args.model:
print(f"Testing model: {model} at step: {step}")
# Load configuration
config = load_config(model, self.args.cfg)
# Setup checkpoint path
ckpt_path = setup_checkpoint_path(model, step, config)
if ckpt_path is None:
continue
                use_coco = self.args.dataset.lower() == 'coco'
# Setup test configuration
setup_test_config(config, use_coco)
# Generate and apply all parameter combinations
for param_dict in generate_param_combinations(self.args):
# Create a copy of the config for each parameter combination
current_config = OmegaConf.create(OmegaConf.to_container(config, resolve=True))
# Print parameter combination
param_str = ", ".join([f"{k}={v}" for k, v in param_dict.items() if v is not None])
print(f"Testing with parameters: {param_str}")
# Apply parameters and run test
apply_cfg_params(current_config, param_dict)
run_test(current_config)
def _setup_gpu_args(self):
"""Set up GPU and distributed environment variables."""
import submitit
print("Exporting PyTorch distributed environment variables")
dist_env = submitit.helpers.TorchDistributedEnvironment().export(set_cuda_visible_devices=False)
print(f"Master: {dist_env.master_addr}:{dist_env.master_port}")
print(f"Rank: {dist_env.rank}")
print(f"World size: {dist_env.world_size}")
print(f"Local rank: {dist_env.local_rank}")
print(f"Local world size: {dist_env.local_world_size}")
job_env = submitit.JobEnvironment()
self.args.output_dir = str(self.args.output_dir).replace("%j", str(job_env.job_id))
self.args.log_dir = self.args.output_dir
print(f"Process group: {job_env.num_tasks} tasks, rank: {job_env.global_rank}")
def main():
"""Main function to set up and submit the job."""
args = parse_args()
    # Determine the job directory from the config (load_config handles a missing --model)
    config = load_config(args.model[0], args.cfg)
    if not args.job_dir:
        args.job_dir = config.trainer.params.result_folder
# Set up the executor
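    # slurm_max_num_timeout caps the number of automatic requeues submitit will
    # attempt after job time-outs.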
executor = submitit.AutoExecutor(folder=args.job_dir, slurm_max_num_timeout=30)
# Configure slurm parameters
slurm_kwargs = {
'slurm_signal_delay_s': 120,
'slurm_qos': args.qos
}
if args.comment:
slurm_kwargs['slurm_comment'] = args.comment
if args.exclude:
slurm_kwargs['slurm_exclude'] = args.exclude
if args.nodelist:
slurm_kwargs['slurm_nodelist'] = args.nodelist
# Update executor parameters
executor.update_parameters(
gpus_per_node=args.ngpus,
tasks_per_node=args.ngpus, # one task per GPU
nodes=args.nodes,
timeout_min=args.timeout,
slurm_partition=args.partition,
name="fid",
**slurm_kwargs
)
args.output_dir = args.job_dir
# Submit the job
trainer = Trainer(args)
job = executor.submit(trainer)
print("Submitted job_id:", job.job_id)
if __name__ == "__main__":
main()