File size: 13,643 Bytes
399b48d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 |
#!/usr/bin/env python
# This script extracts fp32 consolidated weights from ZeRO stage 2 and 3 DeepSpeed checkpoints. It gets
# copied into the top level checkpoint dir, so the user can easily do the conversion at any point in
# the future. Once extracted, the weights don't require DeepSpeed and can be used in any
# application.
#
# example: python zero_to_fp32.py . pytorch_model.bin
import argparse
import torch
import glob
import os
from collections import OrderedDict
# while this script doesn't use deepspeed to recover data, since the checkpoints are pickled with
# DeepSpeed data structures it has to be available in the current python environment.
import deepspeed
from deepspeed.utils import logger
# Module-wide verbosity switch: any truthy value enables the extra diagnostic prints guarded by
# ``if debug:`` in this file. The command-line entry point overwrites it from the ``--debug`` flag.
debug = 0
# load to cpu - this is the ``map_location`` passed to every ``torch.load`` call in this file
device = torch.device('cpu')
def get_model_state_file(checkpoint_dir, zero_stage):
    """
    Return the path of the model states file inside a DeepSpeed checkpoint folder.

    Args:
        - ``checkpoint_dir``: path to the folder that holds the ``*_model_states.pt`` file
        - ``zero_stage``: ZeRO stage of the checkpoint (2 or 3) - it selects the expected file name

    Returns:
        - path to the model states file

    Raises:
        - ``FileNotFoundError``: if ``checkpoint_dir`` or the expected file doesn't exist
        - ``ValueError``: if ``zero_stage`` is neither 2 nor 3
    """
    if not os.path.isdir(checkpoint_dir):
        raise FileNotFoundError(f"Directory '{checkpoint_dir}' doesn't exist")

    # there should be only one file
    if zero_stage == 2:
        file_name = "mp_rank_00_model_states.pt"
    elif zero_stage == 3:
        file_name = "zero_pp_rank_0_mp_rank_00_model_states.pt"
    else:
        # fail with a clear message instead of tripping over an unbound local further down
        raise ValueError(f"unknown zero stage {zero_stage}")

    file = os.path.join(checkpoint_dir, file_name)
    if not os.path.exists(file):
        raise FileNotFoundError(f"can't find model states file at '{file}'")

    return file
def _natural_sort_key(path):
    """Split ``path`` into text and integer chunks so that ``rank_2`` sorts before ``rank_10``."""
    import re

    # with a capturing group ``re.split`` keeps the digit runs: they always land on the odd indices
    chunks = re.split(r"(\d+)", path)
    return [int(chunk) if index % 2 else chunk for index, chunk in enumerate(chunks)]


def get_optim_files(checkpoint_dir):
    """
    Return the optimizer states files found in a checkpoint folder, ordered by the numbers in their names.

    The rank embedded in the file names is not zero-padded, so a plain lexicographic sort would
    place ``..._rank_10_...`` before ``..._rank_2_...``. Callers rely on the order of the returned
    list when stitching the partitions back together, therefore the files are sorted naturally
    (digit runs compared as integers).

    Args:
        - ``checkpoint_dir``: path to the folder that holds the ``*_optim_states.pt`` files

    Returns:
        - list of paths to the optimizer states files

    Raises:
        - ``FileNotFoundError``: if no ``*_optim_states.pt`` file is found
    """
    # XXX: need to test that this simple glob rule works for multi-node setup too
    optim_files = sorted(glob.glob(os.path.join(checkpoint_dir, "*_optim_states.pt")),
                         key=_natural_sort_key)

    if not optim_files:
        raise FileNotFoundError(
            f"can't find '*_optim_states.pt' files in directory '{checkpoint_dir}'")

    return optim_files
def parse_model_state(file):
    """
    Load a model states file and return only the buffers stored in it, converted with ``.float()``.

    Args:
        - ``file``: path to the model states file

    Returns:
        - dict mapping each buffer name to its tensor

    Raises:
        - ``ValueError``: if the loaded object has no ``buffer_names`` entry
    """
    checkpoint = torch.load(file, map_location=device)

    if "buffer_names" not in checkpoint:
        raise ValueError(f"{file} is not a model state checkpoint")

    buffer_names = checkpoint["buffer_names"]
    if debug:
        print("Found buffers:", buffer_names)

    # recover just the buffers while restoring them to fp32 if they were saved in fp16
    buffers = {}
    for key, tensor in checkpoint["module"].items():
        if key in buffer_names:
            buffers[key] = tensor.float()

    return buffers
def parse_optim_states(files, ds_checkpoint_dir):
    """
    Load every optimizer states file and extract what is needed to rebuild the fp32 weights.

    The ZeRO stage, the partition count and the param shapes are read from the first file only.

    Args:
        - ``files``: paths to the ``*_optim_states.pt`` files
        - ``ds_checkpoint_dir``: the folder the files came from (used only in the error message)

    Returns:
        - tuple ``(zero_stage, world_size, param_shapes, fp32_flat_groups)`` where
          ``fp32_flat_groups`` holds one merged flat tensor per file, in the order of ``files``

    Raises:
        - ``ValueError``: if the first file has no ``zero_stage`` entry, if the number of files
          differs from the recorded partition count, or if the stage is neither 2 nor 3
    """
    total_files = len(files)
    state_dicts = [torch.load(path, map_location=device) for path in files]

    first_optimizer_state = state_dicts[0]['optimizer_state_dict']
    if "zero_stage" not in first_optimizer_state:
        raise ValueError(f"{files[0]} is not a zero checkpoint")

    zero_stage = first_optimizer_state["zero_stage"]
    world_size = first_optimizer_state["partition_count"]
    param_shapes = state_dicts[0]["param_shapes"]

    if world_size != total_files:
        raise ValueError(
            f"Expected {world_size} of '*_optim_states.pt' under '{ds_checkpoint_dir}' but found {total_files} files. "
            "Possibly due to an overwrite of an old checkpoint, or a checkpoint didn't get saved by one or more processes."
        )

    # the groups are named differently in each stage
    if zero_stage == 2:
        fp32_groups_key = "single_partition_of_fp32_groups"
    elif zero_stage == 3:
        fp32_groups_key = "fp32_flat_groups"
    else:
        raise ValueError(f"unknown zero stage {zero_stage}")

    # if there is more than one param group, there will be multiple flattened tensors - one
    # flattened tensor per group - for simplicity merge them into a single tensor
    #
    # XXX: could make the script more memory efficient for when there are multiple groups - it
    # will require matching the sub-lists of param_shapes for each param group flattened tensor
    fp32_flat_groups = []
    for loaded in state_dicts:
        groups = loaded['optimizer_state_dict'][fp32_groups_key]
        fp32_flat_groups.append(torch.cat(groups, 0))

    return zero_stage, world_size, param_shapes, fp32_flat_groups
def zero3_partitioned_param_info(unpartitioned_numel, world_size):
    """
    Split a parameter's element count across ``world_size`` partitions.

    Args:
        - ``unpartitioned_numel``: number of elements of the full parameter
        - ``world_size``: number of partitions

    Returns:
        - tuple ``(partitioned_numel, padding_numel)``: the floor quotient
          ``unpartitioned_numel // world_size`` and the number of elements missing to reach the next
          multiple of ``world_size`` (0 when the count divides evenly)
    """
    # exact integer arithmetic: going through a float division and ``int()`` loses precision once
    # the element count no longer fits into a float's mantissa
    partitioned_numel, remainder = divmod(unpartitioned_numel, world_size)
    padding_numel = (world_size - remainder) if remainder else 0
    return partitioned_numel, padding_numel
def _get_fp32_state_dict_from_zero_checkpoint(ds_checkpoint_dir):
    """
    Rebuild a consolidated fp32 state_dict from the files of one deepspeed checkpoint folder

    Args:
        - ``ds_checkpoint_dir``: path to the deepspeed checkpoint folder (where the optimizer files are)

    Returns:
        - ``OrderedDict`` with the buffers first, followed by one entry per name in the saved param shapes

    Raises:
        - ``ValueError``: if the number of elements consumed differs from the number available
    """
    print(f"Processing zero checkpoint '{ds_checkpoint_dir}'")

    optimizer_paths = get_optim_files(ds_checkpoint_dir)
    zero_stage, world_size, param_shapes, fp32_flat_groups = parse_optim_states(optimizer_paths, ds_checkpoint_dir)
    print(
        f"Detected checkpoint of type zero stage {zero_stage}, world_size: {world_size}")

    buffers = parse_model_state(get_model_state_file(ds_checkpoint_dir, zero_stage))

    # Reconstruction protocol:
    #
    # - for zero2 we just need to concat the partitions back to back and reconsolidate over one huge
    #   flat buffer - no need to deal with padding since if there is any it will be only in the tail
    #   of the last partition so there it will be just left out
    #
    # - for zero3 we need to zip the partitions together at boundary of each param, re-consolidating
    #   each param, while dealing with padding if any

    if debug:
        for rank in range(world_size):
            print(f"fp32_flat_groups[i].shape={fp32_flat_groups[rank].shape}")

    if zero_stage == 2:
        # XXX: memory usage doubles here (zero2)
        merged_fp32_vector = torch.cat(fp32_flat_groups, 0)
        avail_numel = merged_fp32_vector.numel()
    elif zero_stage == 3:
        avail_numel = fp32_flat_groups[0].numel() * world_size

    if debug:
        wanted_numel = 0
        for wanted_shape in param_shapes.values():
            wanted_numel += wanted_shape.numel()
        wanted_params = len(param_shapes)
        # not asserting if there is a mismatch due to possible padding
        print(f"Have {avail_numel} numels to process.")
        print(f"Need {wanted_numel} numels in {wanted_params} params.")

    # buffers go in first, the params are appended below
    state_dict = OrderedDict(buffers)
    if debug:
        print(f"added {len(buffers)} buffers")

    # params
    # XXX: for huge models that can't fit into the host's RAM we will have to recode this to support
    # out-of-core computing solution
    offset = 0
    total_numel = 0
    total_params = 0  # stays 0 when there are no params, otherwise ends up as the running count
    for total_params, (name, shape) in enumerate(param_shapes.items(), start=1):
        unpartitioned_numel = shape.numel()
        total_numel += unpartitioned_numel

        if zero_stage == 2:
            if debug:
                print(
                    f"{name} full shape: {shape} unpartitioned numel {unpartitioned_numel} "
                )
            flat_param = merged_fp32_vector.narrow(0, offset, unpartitioned_numel)
            state_dict[name] = flat_param.view(shape)
            offset += unpartitioned_numel

        elif zero_stage == 3:
            partitioned_numel, partitioned_padding_numel = zero3_partitioned_param_info(unpartitioned_numel, world_size)
            if debug:
                print(
                    f"{total_params} {name} full shape: {shape} partition0 numel={partitioned_numel} partitioned_padding_numel={partitioned_padding_numel}"
                )
            # XXX: memory usage doubles here (zero3)
            shards = [
                fp32_flat_groups[rank].narrow(0, offset, partitioned_numel)
                for rank in range(world_size)
            ]
            state_dict[name] = torch.cat(shards, 0).view(shape)
            offset += partitioned_numel + partitioned_padding_numel

    # for zero3 the offset was tracked per partition, so scale it up to all partitions
    if zero_stage == 3:
        offset *= world_size

    # Sanity check
    if offset != avail_numel:
        raise ValueError(
            f"consumed {offset} numels out of {avail_numel} - something is wrong")

    print(
        f"Reconstructed fp32 state dict with {total_params} params {total_numel} elements"
    )

    return state_dict
def get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag=None):
    """
    Build a single consolidated fp32 ``state_dict`` out of a ZeRO 2 or 3 checkpoint. The result can
    be passed to ``load_state_dict()`` and used for training without DeepSpeed, or shared with
    others, for example via a model hub.

    Args:
        - ``checkpoint_dir``: path to the desired checkpoint folder
        - ``tag``: checkpoint tag used as a unique identifier for checkpoint, e.g., ``global_step14``. If not provided, the tag is read from the file named ``latest`` inside ``checkpoint_dir``

    Returns:
        - pytorch ``state_dict``

    Raises:
        - ``ValueError``: if ``tag`` is not provided and there is no ``latest`` file
        - ``FileNotFoundError``: if the folder for the tag doesn't exist under ``checkpoint_dir``

    Note: this approach may not work if your application doesn't have sufficient free CPU memory.
    In that case use the offline approach via the ``zero_to_fp32.py`` script that is saved with the
    checkpoint.

    A typical usage might be ::

        from deepspeed.utils.zero_to_fp32 import get_fp32_state_dict_from_zero_checkpoint
        # do the training and checkpoint saving
        state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir) # already on cpu
        model = model.cpu() # move to cpu
        model.load_state_dict(state_dict)
        # submit to model hub or save the model to share with others

    In this example the ``model`` will no longer be usable in the deepspeed context of the same
    application, i.e. you will need to re-initialize the deepspeed engine, since
    ``model.load_state_dict(state_dict)`` will remove all the deepspeed magic from it.

    If you want it all done for you, use ``load_state_dict_from_zero_checkpoint`` instead.
    """
    if tag is None:
        # no explicit tag - fall back to the one recorded in the 'latest' file
        latest_path = os.path.join(checkpoint_dir, 'latest')
        if not os.path.isfile(latest_path):
            raise ValueError(f"Unable to find 'latest' file at {latest_path}")
        with open(latest_path, 'r') as fd:
            tag = fd.read().strip()

    ds_checkpoint_dir = os.path.join(checkpoint_dir, tag)

    if os.path.isdir(ds_checkpoint_dir):
        return _get_fp32_state_dict_from_zero_checkpoint(ds_checkpoint_dir)

    raise FileNotFoundError(f"Directory '{ds_checkpoint_dir}' doesn't exist")
def convert_zero_checkpoint_to_fp32_state_dict(checkpoint_dir, output_file, tag=None):
    """
    Consolidate a ZeRO 2 or 3 checkpoint into a single fp32 ``state_dict`` and write it to
    ``output_file`` with ``torch.save``. The saved file can be loaded with ``torch.load(file)`` +
    ``load_state_dict()`` and used for training without DeepSpeed.

    Args:
        - ``checkpoint_dir``: path to the desired checkpoint folder. (one that contains the tag-folder, like ``global_step14``)
        - ``output_file``: path to the pytorch fp32 state_dict output file (e.g. path/pytorch_model.bin)
        - ``tag``: checkpoint tag used as a unique identifier for checkpoint, e.g., ``global_step14``. If not provided, the tag is read from the file named ``latest`` in the checkpoint folder
    """
    consolidated_weights = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag)

    print(f"Saving fp32 state dict to {output_file}")
    torch.save(consolidated_weights, output_file)
def load_state_dict_from_zero_checkpoint(model, checkpoint_dir, tag=None):
    """
    Overwrite the weights of ``model`` with the fp32 weights consolidated from a ZeRO 2 or 3
    checkpoint:

    1. Extract a single fp32 consolidated ``state_dict`` from the checkpoint
    2. Move the provided model to cpu
    3. Load the ``state_dict`` into it (with ``strict=False``)

    Args:
        - ``model``: the model object to update
        - ``checkpoint_dir``: path to the desired checkpoint folder. (one that contains the tag-folder, like ``global_step14``)
        - ``tag``: checkpoint tag used as a unique identifier for checkpoint, e.g., ``global_step14``. If not provided, the tag is read from the file named ``latest`` in the checkpoint folder

    Returns:
        - ``model``: the object returned by ``model.cpu()``, after the weights were loaded into it

    Make sure you have plenty of CPU memory available before you call this function. If you don't
    have enough, use the ``zero_to_fp32.py`` utility to do the conversion. You will find it
    conveniently placed for you in the checkpoint folder.

    A typical usage might be ::

        from deepspeed.utils.zero_to_fp32 import load_state_dict_from_zero_checkpoint
        model = load_state_dict_from_zero_checkpoint(trainer.model, checkpoint_dir)
        # submit to model hub or save the model to share with others

    Note that once this was run, the ``model`` will no longer be usable in the deepspeed context
    of the same application, i.e. you will need to re-initialize the deepspeed engine, since
    ``model.load_state_dict(state_dict)`` will remove all the deepspeed magic from it.
    """
    logger.info("Extracting fp32 weights")
    fp32_state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag)

    logger.info("Overwriting model with fp32 weights")
    cpu_model = model.cpu()
    cpu_model.load_state_dict(fp32_state_dict, strict=False)

    return cpu_model
if __name__ == "__main__":
    # Command-line entry point: consolidate the checkpoint found under ``checkpoint_dir`` and save
    # the resulting fp32 state_dict to ``output_file``.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "checkpoint_dir",
        type=str,
        help="path to the desired checkpoint folder, e.g., path/checkpoint-12")
    parser.add_argument(
        "output_file",
        type=str,
        help=
        "path to the pytorch fp32 state_dict output file (e.g. path/checkpoint-12/pytorch_model.bin)"
    )
    # optional: when omitted the tag stays None and is resolved from the 'latest' file, exactly as
    # before this option existed
    parser.add_argument(
        "-t",
        "--tag",
        type=str,
        default=None,
        help=
        "checkpoint tag used as a unique identifier for checkpoint, e.g., global_step14. "
        "If not provided, the tag is read from the file named 'latest' in the checkpoint folder")
    parser.add_argument("-d", "--debug", action='store_true', help="enable debug")
    args = parser.parse_args()

    # this runs at module level, so it rebinds the global ``debug`` switch read by the functions above
    debug = args.debug

    convert_zero_checkpoint_to_fp32_state_dict(args.checkpoint_dir,
                                               args.output_file,
                                               tag=args.tag)
|