|
|
|
|
|
"""Input/output checkpointing.""" |
|
|
|
import os |
|
import random |
|
import sys |
|
|
|
import numpy as np |
|
import torch |
|
|
|
from megatron import update_num_microbatches |
|
from megatron.core import mpu, tensor_parallel |
|
|
|
from .global_vars import get_args |
|
from .utils import print_rank_0, unwrap_model |
|
|
|
# Checkpoint format version shared by this module; stays None until
# set_checkpoint_version() stores a value.
_CHECKPOINT_VERSION = None
|
|
|
|
|
def set_checkpoint_version(value):
    """Record the checkpoint format version for this process.

    The version may be set any number of times, but only ever to the value
    that was stored first; a conflicting value trips an assertion.
    """
    global _CHECKPOINT_VERSION
    previously_set = _CHECKPOINT_VERSION is not None
    if previously_set:
        assert _CHECKPOINT_VERSION == value, "checkpoint versions do not match"
    _CHECKPOINT_VERSION = value
|
|
|
|
|
def get_checkpoint_version():
    """Return the stored checkpoint format version (``None`` if never set)."""
    # Reading a module-level name does not require a ``global`` declaration.
    return _CHECKPOINT_VERSION
|
|
|
|
|
def check_checkpoint_args(checkpoint_args):
    """Ensure fixed arguments for a model are the same for the input
    arguments and the ones retrieved from the checkpoint.

    Args:
        checkpoint_args: argument namespace that was stored in the checkpoint.

    Raises:
        AssertionError: if one of the compared arguments differs between the
            checkpoint and the current run.
    """
    args = get_args()

    def _compare(arg_name, old_arg_name=None):
        # Older checkpoints may store the value under a different attribute.
        if old_arg_name is not None:
            checkpoint_value = getattr(checkpoint_args, old_arg_name)
        else:
            checkpoint_value = getattr(checkpoint_args, arg_name)
        args_value = getattr(args, arg_name)
        # Raise explicitly (rather than ``assert``) so the check is still
        # enforced when Python runs with -O.
        if checkpoint_value != args_value:
            raise AssertionError(
                "{} value from checkpoint ({}) is not equal to the "
                "input argument value ({}).".format(
                    arg_name, checkpoint_value, args_value
                )
            )

    _compare("num_layers")
    _compare("hidden_size")
    _compare("num_attention_heads")
    if args.vocab_file:
        _compare("max_position_embeddings")
        _compare("make_vocab_size_divisible_by")
        _compare("padded_vocab_size")
        _compare("tokenizer_type")
    if args.data_parallel_random_init:
        _compare("data_parallel_random_init")
    if get_checkpoint_version() < 3.0:
        # Before version 3.0 the tensor-parallel size was stored under
        # ``model_parallel_size``.
        _compare("tensor_model_parallel_size", old_arg_name="model_parallel_size")
    if get_checkpoint_version() >= 3.0:
        _compare("tensor_model_parallel_size")
        _compare("pipeline_model_parallel_size")
|
|
|
|
|
def ensure_directory_exists(filename):
    """Create the parent directory of ``filename`` if it does not exist yet.

    Args:
        filename: path of a file that is about to be written.

    Safe to call concurrently from several processes targeting the same
    directory, and a no-op when ``filename`` has no directory component.
    """
    dirname = os.path.dirname(filename)
    # A bare filename has an empty dirname; os.makedirs("") would raise.
    if dirname:
        # exist_ok avoids the check-then-create race between processes that
        # create the same directory at the same time.
        os.makedirs(dirname, exist_ok=True)
|
|
|
|
|
def get_checkpoint_name(
    checkpoints_path,
    iteration,
    release=False,
    pipeline_parallel=None,
    tensor_rank=None,
    pipeline_rank=None,
):
    """Return the path of the ``model_optim_rng.pt`` file for a rank.

    The file lives under ``<checkpoints_path>/<release | iter_NNNNNNN>/``
    in a ``mp_rank_TT`` directory, or ``mp_rank_TT_PPP`` when pipeline
    parallelism is used. Any of ``pipeline_parallel``, ``tensor_rank`` and
    ``pipeline_rank`` left as ``None`` is queried from ``mpu``.
    """
    # The iteration is only formatted when this is not a release checkpoint.
    directory = "release" if release else "iter_{:07d}".format(iteration)

    # Fill in whatever the caller did not specify from the parallel state.
    if pipeline_parallel is None:
        pipeline_parallel = mpu.get_pipeline_model_parallel_world_size() > 1
    if tensor_rank is None:
        tensor_rank = mpu.get_tensor_model_parallel_rank()
    if pipeline_rank is None:
        pipeline_rank = mpu.get_pipeline_model_parallel_rank()

    # The pipeline rank is part of the directory name only when pipeline
    # parallelism is in use.
    rank_dir = f"mp_rank_{tensor_rank:02d}"
    if pipeline_parallel:
        rank_dir = f"{rank_dir}_{pipeline_rank:03d}"

    return os.path.join(checkpoints_path, directory, rank_dir, "model_optim_rng.pt")
|
|
|
|
|
def get_checkpoint_names(
    checkpoints_path,
    iteration,
    use_distributed_optimizer,
    release=False,
    pipeline_parallel=None,
    tensor_rank=None,
    pipeline_rank=None,
):
    """Return ``(model_name, optim_name)`` checkpoint paths for a rank.

    Without the distributed optimizer both entries are the same
    ``model_optim_rng.pt`` path. With it, the model goes to ``model_rng.pt``
    and the optimizer to ``optim.pt`` in a sibling directory suffixed with
    the data-parallel rank. Any of ``pipeline_parallel``, ``tensor_rank``
    and ``pipeline_rank`` left as ``None`` is queried from ``mpu``.
    """
    # The iteration is only formatted when this is not a release checkpoint.
    directory = "release" if release else "iter_{:07d}".format(iteration)

    # Fill in whatever the caller did not specify from the parallel state.
    if pipeline_parallel is None:
        pipeline_parallel = mpu.get_pipeline_model_parallel_world_size() > 1
    if tensor_rank is None:
        tensor_rank = mpu.get_tensor_model_parallel_rank()
    if pipeline_rank is None:
        pipeline_rank = mpu.get_pipeline_model_parallel_rank()

    rank_dir = f"mp_rank_{tensor_rank:02d}"
    if pipeline_parallel:
        rank_dir = f"{rank_dir}_{pipeline_rank:03d}"
    common_path = os.path.join(checkpoints_path, directory, rank_dir)

    if not use_distributed_optimizer:
        # Model and optimizer share a single file.
        combined_name = os.path.join(common_path, "model_optim_rng.pt")
        return combined_name, combined_name

    model_name = os.path.join(common_path, "model_rng.pt")
    optim_dir = common_path + "_%03d" % mpu.get_data_parallel_rank()
    return model_name, os.path.join(optim_dir, "optim.pt")
|
|
|
|
|
def find_checkpoint_rank_0(
    checkpoints_path, iteration, use_distributed_optimizer, release=False
):
    """Locate the rank 0 checkpoint without knowing the pipeline setup.

    The directory naming differs depending on whether pipeline parallelism
    was used, so both layouts are probed for tensor rank 0 / pipeline
    rank 0: first the one without pipeline parallelism, then the one with.

    Returns the ``(model_name, optim_name)`` pair of the first layout whose
    model file exists, or ``(None, None)`` if neither does.
    """
    for uses_pipeline in (False, True):
        candidate = get_checkpoint_names(
            checkpoints_path,
            iteration,
            use_distributed_optimizer,
            release,
            pipeline_parallel=uses_pipeline,
            tensor_rank=0,
            pipeline_rank=0,
        )
        # Only the model file decides whether this layout is present.
        if os.path.isfile(candidate[0]):
            return candidate

    return None, None
|
|
|
|
|
def get_checkpoint_tracker_filename(checkpoints_path):
    """Return the path of the tracker file inside ``checkpoints_path``.

    The tracker file records the latest checkpoint written during training
    so that a run can restart from it.
    """
    tracker_basename = "latest_checkpointed_iteration.txt"
    return os.path.join(checkpoints_path, tracker_basename)
|
|
|
|
|
def read_metadata(tracker_filename):
    """Read the checkpoint tracker file.

    The file holds either an integer iteration or the literal string
    ``release``.

    Args:
        tracker_filename: path of the tracker file.

    Returns:
        ``(iteration, release)``. ``iteration`` is 0 for a release
        checkpoint. When torch.distributed is initialized, ``iteration`` is
        the maximum value found across all ranks.

    Exits the process when the file content is neither an integer nor
    ``release``.
    """
    iteration = 0
    release = False
    with open(tracker_filename, "r") as f:
        metastring = f.read().strip()
    try:
        iteration = int(metastring)
    except ValueError:
        release = metastring == "release"
        if not release:
            print_rank_0(
                "ERROR: Invalid metadata file {}. Exiting".format(tracker_filename)
            )
            sys.exit()
    assert iteration > 0 or release, "error parsing metadata file {}".format(
        tracker_filename
    )

    if torch.distributed.is_initialized():
        # Agree on the largest iteration seen by any rank.
        iters_cuda = torch.cuda.LongTensor([iteration])
        torch.distributed.all_reduce(iters_cuda, op=torch.distributed.ReduceOp.MAX)
        max_iter = iters_cuda[0].item()

        # Every rank should have read the same iteration; if this one did
        # not, warn and continue with the maximum.
        if iteration != max_iter:
            rank = torch.distributed.get_rank()
            print(
                "WARNING: on rank {} found iteration {} in the "
                "metadata while max iteration across the ranks "
                "is {}, replacing it with max iteration.".format(
                    rank, iteration, max_iter
                ),
                flush=True,
            )
    else:
        # Without torch.distributed there is nothing to compare against, so
        # the locally read iteration is used as is.
        max_iter = iteration
    return max_iter, release
|
|
|
|
|
def get_rng_state():
    """Collect the RNG state, gathered across data parallel ranks if needed.

    Returns a list of state dicts. It holds one entry per data parallel
    rank when torch.distributed is initialized, the data parallel world
    size is larger than one and ``data_parallel_random_init`` is set;
    otherwise it holds only this process's state.
    """
    args = get_args()
    local_state = {
        "random_rng_state": random.getstate(),
        "np_rng_state": np.random.get_state(),
        "torch_rng_state": torch.get_rng_state(),
        "cuda_rng_state": torch.cuda.get_rng_state(),
        "rng_tracker_states": tensor_parallel.get_cuda_rng_tracker().get_states(),
    }

    gather_across_ranks = (
        torch.distributed.is_initialized()
        and mpu.get_data_parallel_world_size() > 1
        and args.data_parallel_random_init
    )
    if not gather_across_ranks:
        return [local_state]

    # One slot per data parallel rank, filled in by the all-gather.
    gathered_states = [None] * mpu.get_data_parallel_world_size()
    torch.distributed.all_gather_object(
        gathered_states, local_state, group=mpu.get_data_parallel_group()
    )
    return gathered_states
|
|
|
|
|
def save_checkpoint(iteration, model, optimizer, opt_param_scheduler):
    """Save a model checkpoint.

    Args:
        iteration: iteration number, or the string ``"release"`` to write a
            release checkpoint.
        model: model (or list of model chunks); unwrapped before saving.
        optimizer: optimizer whose state is saved, or ``None``.
        opt_param_scheduler: scheduler whose state is saved, or ``None``.

    Model state is collected on data parallel rank 0 only. Optimizer state
    is collected on data parallel rank 0, or on every rank when the
    distributed optimizer is used. After all ranks finish, global rank 0
    writes ``iteration`` to the tracker file.
    """
    args = get_args()

    model = unwrap_model(model)

    release = iteration == "release"
    if release:
        print_rank_0("saving checkpoint marked as release to {}".format(args.save))
    else:
        print_rank_0(
            "saving checkpoint at iteration {:7d} to {}".format(iteration, args.save)
        )

    # Collected on every rank: gathering may involve a collective call.
    rng_state = get_rng_state()

    model_checkpoint_name, optim_checkpoint_name = get_checkpoint_names(
        args.save, iteration, args.use_distributed_optimizer, release=release
    )

    # Arguments, iteration, model weights and RNG state.
    model_state_dict = {}
    if not torch.distributed.is_initialized() or mpu.get_data_parallel_rank() == 0:
        model_state_dict["args"] = args
        model_state_dict["checkpoint_version"] = 3.0
        model_state_dict["iteration"] = iteration
        if len(model) == 1:
            model_state_dict["model"] = model[0].state_dict_for_save_checkpoint()
        else:
            # Several model chunks: store each under its own "model<i>" key.
            for i, model_chunk in enumerate(model):
                mpu.set_virtual_pipeline_model_parallel_rank(i)
                model_state_dict["model%d" % i] = (
                    model_chunk.state_dict_for_save_checkpoint()
                )

        if not args.no_save_rng:
            model_state_dict["rng_state"] = rng_state

    # Optimizer and scheduler state.
    optim_state_dict = {}
    if not args.no_save_optim and (
        not torch.distributed.is_initialized()
        or mpu.get_data_parallel_rank() == 0
        or args.use_distributed_optimizer
    ):
        if optimizer is not None:
            optim_state_dict["optimizer"] = optimizer.state_dict()
        if opt_param_scheduler is not None:
            optim_state_dict["opt_param_scheduler"] = opt_param_scheduler.state_dict()

    if args.use_distributed_optimizer:
        # Model and optimizer go to separate files.
        if model_state_dict:
            ensure_directory_exists(model_checkpoint_name)
            torch.save(model_state_dict, model_checkpoint_name)
        if optim_state_dict:
            ensure_directory_exists(optim_checkpoint_name)
            torch.save(optim_state_dict, optim_checkpoint_name)
    else:
        # Model and optimizer share one file; an empty dict means this rank
        # has nothing to write.
        state_dict = {**model_state_dict, **optim_state_dict}
        if state_dict:
            ensure_directory_exists(model_checkpoint_name)
            torch.save(state_dict, model_checkpoint_name)

    # Wait until every rank has finished before updating the tracker.
    if torch.distributed.is_initialized():
        torch.distributed.barrier()

    if release:
        print_rank_0(
            " successfully saved checkpoint marked as release to {}".format(args.save)
        )
    else:
        print_rank_0(
            " successfully saved checkpoint at iteration {:7d} to {}".format(
                iteration, args.save
            )
        )

    # Record the latest iteration in the tracker file.
    if not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0:
        tracker_filename = get_checkpoint_tracker_filename(args.save)
        with open(tracker_filename, "w") as f:
            f.write(str(iteration))

    if torch.distributed.is_initialized():
        torch.distributed.barrier()
|
|
|
|
|
def _transpose_first_dim(t, num_splits, num_splits_first, model): |
|
input_shape = t.size() |
|
|
|
|
|
while hasattr(model, "module"): |
|
model = model.module |
|
attention_module = model.language_model.encoder.layers[0].self_attention |
|
hidden_size_per_attention_head = attention_module.hidden_size_per_attention_head |
|
num_attention_heads_per_partition = ( |
|
attention_module.num_attention_heads_per_partition |
|
) |
|
if num_splits_first: |
|
"""[num_splits * np * hn, h] |
|
-->(view) [num_splits, np, hn, h] |
|
-->(tranpose) [np, num_splits, hn, h] |
|
-->(view) [np * num_splits * hn, h]""" |
|
|
|
intermediate_shape = ( |
|
num_splits, |
|
num_attention_heads_per_partition, |
|
hidden_size_per_attention_head, |
|
) + input_shape[1:] |
|
|
|
t = t.view(*intermediate_shape) |
|
t = t.transpose(0, 1).contiguous() |
|
else: |
|
"""[np * hn * num_splits, h] |
|
-->(view) [np, hn, num_splits, h] |
|
-->(tranpose) [np, num_splits, hn, h] |
|
-->(view) [np * num_splits * hn, h]""" |
|
|
|
intermediate_shape = ( |
|
num_attention_heads_per_partition, |
|
hidden_size_per_attention_head, |
|
num_splits, |
|
) + input_shape[1:] |
|
|
|
t = t.view(*intermediate_shape) |
|
t = t.transpose(1, 2).contiguous() |
|
t = t.view(*input_shape) |
|
|
|
return t |
|
|
|
|
|
def fix_query_key_value_ordering(model, checkpoint_version):
    """Fix up the query/key/value matrix ordering of old checkpoints.

    Does nothing unless ``checkpoint_version`` is smaller than 2.0.
    Otherwise the ``query_key_value`` and ``key_value`` weights and biases
    of ``model`` are reordered in place. ``query_key_value`` parameters are
    skipped when ``num_attention_heads_kv`` differs from
    ``num_attention_heads``. The process exits if a parameter needs fixing
    and the version is neither 0 nor 1.0.
    """
    if not checkpoint_version < 2.0:
        return

    if isinstance(model, list):
        assert len(model) == 1
        model = model[0]

    def _reordered(param, num_splits):
        # Version 0 stored the splits first, version 1.0 stored them last.
        if checkpoint_version == 0:
            return _transpose_first_dim(param.data, num_splits, True, model)
        if checkpoint_version == 1.0:
            return _transpose_first_dim(param.data, num_splits, False, model)
        print_rank_0(f"Invalid checkpoint version {checkpoint_version}.")
        sys.exit()

    for name, param in model.named_parameters():
        if name.endswith((".query_key_value.weight", ".query_key_value.bias")):
            args = get_args()
            if args.num_attention_heads_kv != args.num_attention_heads:
                continue
            param.data.copy_(_reordered(param, 3))
        if name.endswith((".key_value.weight", ".key_value.bias")):
            param.data.copy_(_reordered(param, 2))

    print_rank_0(
        " succesfully fixed query-key-values ordering for"
        " checkpoint version {}".format(checkpoint_version)
    )
|
|
|
|
|
def _load_base_checkpoint(load_dir, use_distributed_optimizer, rank0=False):
    """Load the base state dicts from the given directory.

    If rank0 is true, just loads rank 0 checkpoint, ignoring arguments.

    Args:
        load_dir: directory containing the tracker file and checkpoints.
        use_distributed_optimizer: whether the optimizer state lives in a
            file separate from the model state.
        rank0: load the rank 0 checkpoint, probing both directory layouts.

    Returns:
        ``(model_state_dict, optim_state_dict, release)``. Both dicts are
        ``None`` (and ``release`` is ``False``) when the tracker file is
        missing or, in ``rank0`` mode, when no rank 0 checkpoint file is
        found. Without the distributed optimizer both entries are the same
        loaded object.

    Exits the process when the checkpoint files cannot be loaded.
    """
    tracker_filename = get_checkpoint_tracker_filename(load_dir)

    # No tracker file: nothing to load.
    if not os.path.isfile(tracker_filename):
        if not rank0:
            print_rank_0(
                "WARNING: could not find the metadata file {} ".format(tracker_filename)
            )
            print_rank_0(" will not load any checkpoints and will start from random")
        return None, None, False

    # Either an iteration number or a release marker.
    iteration, release = read_metadata(tracker_filename)

    if rank0:
        checkpoint_names = find_checkpoint_rank_0(
            load_dir, iteration, use_distributed_optimizer, release
        )
    else:
        checkpoint_names = get_checkpoint_names(
            load_dir, iteration, use_distributed_optimizer, release
        )
        if release:
            print_rank_0(f" loading release checkpoint from {load_dir}")
        else:
            print_rank_0(
                f" loading checkpoint from {load_dir} at iteration {iteration}"
            )

    model_checkpoint_name, optim_checkpoint_name = checkpoint_names

    # find_checkpoint_rank_0 yields (None, None) when neither layout exists.
    # Report "no checkpoint" instead of passing None on to torch.load.
    if model_checkpoint_name is None:
        return None, None, False

    def _load_state_dicts():
        # Without the distributed optimizer everything is in one file, so
        # it is loaded once and shared.
        loaded_model = torch.load(model_checkpoint_name, map_location="cpu")
        if use_distributed_optimizer:
            loaded_optim = torch.load(optim_checkpoint_name, map_location="cpu")
        else:
            loaded_optim = loaded_model
        return loaded_model, loaded_optim

    try:
        model_state_dict, optim_state_dict = _load_state_dicts()
    except ModuleNotFoundError:
        # Imported for its side effect of registering the module in
        # sys.modules.
        from megatron.fp16_deprecated import loss_scaler  # noqa: F401

        # Old checkpoints reference the loss scaler under its former module
        # paths; alias them for the duration of the load.
        if not rank0:
            print_rank_0(" > deserializing using the old code structure ...")
        sys.modules["fp16.loss_scaler"] = sys.modules[
            "megatron.fp16_deprecated.loss_scaler"
        ]
        sys.modules["megatron.fp16.loss_scaler"] = sys.modules[
            "megatron.fp16_deprecated.loss_scaler"
        ]
        try:
            model_state_dict, optim_state_dict = _load_state_dicts()
        finally:
            # Remove the aliases even if the retry fails.
            sys.modules.pop("fp16.loss_scaler", None)
            sys.modules.pop("megatron.fp16.loss_scaler", None)
    except Exception as e:
        # Exception rather than BaseException, so KeyboardInterrupt and
        # SystemExit propagate unchanged.
        print_rank_0("could not load the checkpoint")
        print_rank_0(e)
        sys.exit()
    return model_state_dict, optim_state_dict, release
|
|
|
|
|
def load_args_from_checkpoint(args, load_arg="load"):
    """Set required arguments from the checkpoint specified in the
    arguments.

    An argument is taken from the checkpoint when its current value in
    ``args`` is ``None`` (or it is missing), or when it is one of the
    arguments that are always forced to the checkpoint value. A value that
    is ``None`` or absent in the checkpoint never overwrites anything.

    Args:
        args: argument namespace to update in place.
        load_arg: name of the attribute of ``args`` holding the checkpoint
            directory.

    Returns:
        The same ``args`` namespace with the new values added/updated. It
        is returned unmodified when no load directory is specified, no
        checkpoint is found, or the checkpoint has no saved arguments.
    """
    load_dir = getattr(args, load_arg)

    if load_dir is None:
        print_rank_0("No load directory specified, using provided arguments.")
        return args

    # Only the model state dict carries the saved arguments.
    state_dict, _, _ = _load_base_checkpoint(
        load_dir, use_distributed_optimizer=args.use_distributed_optimizer, rank0=True
    )

    if not state_dict:
        print_rank_0(
            "Checkpoint not found to provide arguments, using provided arguments."
        )
        return args

    if "args" not in state_dict:
        print_rank_0(
            "Checkpoint provided does not have arguments saved, using provided arguments."
        )
        return args

    checkpoint_args = state_dict["args"]
    checkpoint_version = state_dict.get("checkpoint_version", 0)
    args.iteration = state_dict["iteration"]

    def _set_arg(arg_name, old_arg_name=None, force=False):
        # Keep a value that is already set unless it must be overridden.
        if not force and getattr(args, arg_name, None) is not None:
            return

        if old_arg_name is not None:
            checkpoint_value = getattr(checkpoint_args, old_arg_name, None)
        else:
            checkpoint_value = getattr(checkpoint_args, arg_name, None)

        if checkpoint_value is not None:
            print_rank_0(f"Setting {arg_name} to {checkpoint_value} from checkpoint")
            setattr(args, arg_name, checkpoint_value)

    _set_arg("num_layers")
    _set_arg("hidden_size")
    _set_arg("ffn_hidden_size")
    _set_arg("seq_length")
    _set_arg("num_attention_heads")
    _set_arg("kv_channels")
    _set_arg("max_position_embeddings")
    _set_arg("tokenizer_type")
    _set_arg("padded_vocab_size", force=True)

    _set_arg("position_embedding_type", force=True)
    _set_arg("num_attention_heads_kv")
    # The name was previously misspelled "bias_droput_fusion", which matched
    # no attribute and so never restored anything.
    _set_arg("bias_dropout_fusion")
    _set_arg("bias_gelu_fusion")
    _set_arg("hidden_dropout")
    _set_arg("parallel_attn", force=True)
    _set_arg("parallel_layernorm", force=True)
    _set_arg("use_flash_attn")
    _set_arg("use_rms_norm", force=True)
    _set_arg("glu_activation")
    _set_arg("tie_embed_logits", force=True)
    _set_arg("make_vocab_size_divisible_by", force=True)
    _set_arg("train_iters")
    _set_arg("sliding_window_size")
    if checkpoint_version < 3.0:
        # Before version 3.0 the tensor-parallel size was stored under
        # ``model_parallel_size``.
        _set_arg("tensor_model_parallel_size", "model_parallel_size")
    else:
        _set_arg("tensor_model_parallel_size", force=True)
        _set_arg("pipeline_model_parallel_size", force=True)
        _set_arg("num_layers_per_virtual_pipeline_stage")
    return args
|
|
|
|
|
def load_checkpoint(
    model, optimizer, opt_param_scheduler, load_arg="load", strict=True
):
    """Load a model checkpoint and return the iteration.

    Args:
        model: model (or list of model chunks) to load weights into.
        optimizer: optimizer to restore, or ``None``.
        opt_param_scheduler: scheduler to restore, or ``None``.
        load_arg: name of the attribute of the global args holding the
            checkpoint directory.
        strict (bool): whether to strictly enforce that the keys in
            :attr:`state_dict` of the checkpoint match the names of
            parameters and buffers in model.

    Returns:
        The iteration stored in the checkpoint, or 0 when no checkpoint was
        found, when finetuning or annealing, or for a release checkpoint.

    Exits the process when the iteration, optimizer state or RNG state is
    requested but missing from the checkpoint.
    """
    args = get_args()
    load_dir = getattr(args, load_arg)
    model = unwrap_model(model)

    model_state_dict, optim_state_dict, release = _load_base_checkpoint(
        load_dir, use_distributed_optimizer=args.use_distributed_optimizer, rank0=False
    )

    # Nothing was loaded: start from iteration 0.
    if model_state_dict is None:
        return 0

    set_checkpoint_version(model_state_dict.get("checkpoint_version", 0))

    # Iteration: reset to 0 when finetuning, annealing, or for a release.
    if args.finetune or release or args.annealing:
        iteration = 0
    else:
        try:
            iteration = model_state_dict["iteration"]
        except KeyError:
            try:
                # Fallback key used when "iteration" is absent.
                iteration = model_state_dict["total_iters"]
            except KeyError:
                print_rank_0(
                    "A metadata file exists but unable to load "
                    "iteration from checkpoint {}, exiting".format(load_dir)
                )
                sys.exit()

    # Arguments and consumed-sample counters.
    assert args.consumed_train_samples == 0
    assert args.consumed_valid_samples == 0
    if "args" in model_state_dict and not args.finetune and not args.annealing:
        checkpoint_args = model_state_dict["args"]
        check_checkpoint_args(checkpoint_args)
        args.consumed_train_samples = getattr(
            checkpoint_args, "consumed_train_samples", 0
        )
        update_num_microbatches(consumed_samples=args.consumed_train_samples)
        args.consumed_valid_samples = getattr(
            checkpoint_args, "consumed_valid_samples", 0
        )
    else:
        print_rank_0("could not find arguments in the checkpoint ...")

    # Model weights.
    if len(model) == 1:
        model[0].load_state_dict(model_state_dict["model"], strict=strict)
    else:
        # Several model chunks: each was stored under its own "model<i>" key.
        for i, model_chunk in enumerate(model):
            mpu.set_virtual_pipeline_model_parallel_rank(i)
            model_chunk.load_state_dict(model_state_dict["model%d" % i], strict=strict)

    # Old checkpoints store the fused QKV parameters in a different order.
    checkpoint_version = get_checkpoint_version()
    print_rank_0(f" checkpoint version {checkpoint_version}")
    fix_query_key_value_ordering(model, checkpoint_version)

    # Optimizer and scheduler.
    if not release and not args.finetune and not args.no_load_optim:
        try:
            if optimizer is not None:
                optimizer.load_state_dict(optim_state_dict["optimizer"])
            if opt_param_scheduler is not None and not args.annealing:
                # Checkpoints may store the scheduler under "lr_scheduler".
                if "lr_scheduler" in optim_state_dict:
                    opt_param_scheduler.load_state_dict(
                        optim_state_dict["lr_scheduler"]
                    )
                else:
                    opt_param_scheduler.load_state_dict(
                        optim_state_dict["opt_param_scheduler"]
                    )
        except KeyError:
            print_rank_0(
                "Unable to load optimizer from checkpoint {}. "
                "Specify --no_load_optim or --finetune to prevent "
                "attempting to load the optimizer state, "
                "exiting ...".format(load_dir)
            )
            sys.exit()
    else:
        # The optimizer state was not restored.
        if args.fp16 and optimizer is not None:
            optimizer.reload_model_params()

    # RNG states.
    if not release and not args.finetune and not args.no_load_rng:
        try:
            if "rng_state" in model_state_dict:
                # One entry per data parallel rank when each rank was
                # initialized with its own random state.
                if args.data_parallel_random_init:
                    rng_index = mpu.get_data_parallel_rank()
                else:
                    rng_index = 0
                rng_state = model_state_dict["rng_state"][rng_index]
            else:
                # Without "rng_state" the keys sit at the top level.
                rng_state = model_state_dict
            random.setstate(rng_state["random_rng_state"])
            np.random.set_state(rng_state["np_rng_state"])
            torch.set_rng_state(rng_state["torch_rng_state"])
            torch.cuda.set_rng_state(rng_state["cuda_rng_state"])

            # An empty tracker state is treated like a missing one.
            if not rng_state["rng_tracker_states"]:
                raise KeyError
            tensor_parallel.get_cuda_rng_tracker().set_states(
                rng_state["rng_tracker_states"]
            )
        except KeyError:
            print_rank_0(
                "Unable to load rng state from checkpoint {}. "
                "Specify --no_load_rng or --finetune to prevent "
                "attempting to load the rng state, "
                "exiting ...".format(load_dir)
            )
            sys.exit()

    if torch.distributed.is_initialized():
        torch.distributed.barrier()

    # Report the directory that was actually used (load_arg may not be "load").
    print_rank_0(
        f" successfully loaded checkpoint from {load_dir} "
        f"at iteration {iteration}"
    )

    return iteration
|
|
|
|
|
def load_biencoder_checkpoint(
    model, only_query_model=False, only_context_model=False, custom_load_path=None
):
    """Selectively load retrieval models for indexing/retrieving from
    saved checkpoints.

    Args:
        model: model wrapper holding exactly one model chunk.
        only_query_model: drop the ``context_model`` entry before loading.
        only_context_model: drop the ``query_model`` entry before loading.
        custom_load_path: checkpoint directory; defaults to ``args.load``.

    Returns:
        The unwrapped model list with the checkpoint weights loaded.
    """
    args = get_args()

    model = unwrap_model(model)
    load_path = custom_load_path if custom_load_path is not None else args.load

    tracker_filename = get_checkpoint_tracker_filename(load_path)
    with open(tracker_filename, "r") as f:
        iteration = int(f.read().strip())

    checkpoint_name, _ = get_checkpoint_names(
        load_path, iteration, args.use_distributed_optimizer, release=False
    )

    if mpu.get_data_parallel_rank() == 0:
        print(
            "global rank {} is loading checkpoint {}".format(
                torch.distributed.get_rank(), checkpoint_name
            )
        )

    # Load from the path computed above (this previously referenced an
    # undefined name).
    state_dict = torch.load(checkpoint_name, map_location="cpu")
    ret_state_dict = state_dict["model"]

    if only_query_model:
        ret_state_dict.pop("context_model")
    if only_context_model:
        ret_state_dict.pop("query_model")

    assert len(model) == 1
    model[0].load_state_dict(ret_state_dict)
    torch.distributed.barrier()

    if mpu.get_data_parallel_rank() == 0:
        print(" successfully loaded {}".format(checkpoint_name))

    return model
|
|