Spaces:
Running
on
Zero
Running
on
Zero
# Copyright (c) Meta Platforms, Inc. and affiliates. | |
# All rights reserved. | |
# | |
# This source code is licensed under the license found in the | |
# LICENSE file in the root directory of this source tree. | |
""" | |
Wrapper around FSDP for more convenient use in the training loops. | |
""" | |
from contextlib import contextmanager | |
import typing as tp | |
import dora | |
import torch | |
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP | |
from torch.distributed.fsdp import ( | |
MixedPrecision, ShardingStrategy, FullStateDictConfig, StateDictType) | |
from torch.distributed._shard.sharded_tensor.api import ShardedTensor | |
def is_fsdp_used() -> bool:
    """Return whether FSDP is enabled for the current Dora experiment.

    The answer is read from the `fsdp.use` entry of the config of the current
    Dora XP. `False` is returned when we are not inside a Dora XP, or when
    its config has no `fsdp` attribute.
    """
    # A bit of a hack, but it can be called from anywhere in the codebase.
    if not dora.is_xp():
        return False
    xp_cfg = dora.get_xp().cfg
    if not hasattr(xp_cfg, 'fsdp'):
        return False
    return xp_cfg.fsdp.use
def is_sharded_tensor(x: tp.Any) -> bool:
    """Return True if `x` is an instance of `ShardedTensor`, False otherwise."""
    return isinstance(x, ShardedTensor)
@contextmanager
def switch_to_full_state_dict(models: tp.List[FSDP]):
    """Context manager temporarily switching `models` to the full state dict type.

    On entry, every model is set to `StateDictType.FULL_STATE_DICT`, configured with
    `FullStateDictConfig(offload_to_cpu=True, rank0_only=True)`. On exit, including
    when the body raised, every model is set to `StateDictType.LOCAL_STATE_DICT`
    (not to whatever type it had before entering).

    Args:
        models (list of FSDP): FSDP wrapped models to switch. May be empty.
    """
    # Another bug in FSDP makes it that we cannot use the `state_dict_type` API,
    # so let's do things manually.
    for model in models:
        FSDP.set_state_dict_type(  # type: ignore
            model, StateDictType.FULL_STATE_DICT,
            FullStateDictConfig(offload_to_cpu=True, rank0_only=True))
    try:
        yield
    finally:
        for model in models:
            FSDP.set_state_dict_type(model, StateDictType.LOCAL_STATE_DICT)  # type: ignore
def wrap_with_fsdp(cfg, model: torch.nn.Module,
                   block_classes: tp.Optional[tp.Set[tp.Type]] = None) -> FSDP:
    """Wrap `model` with FSDP according to the `cfg` settings.

    Args:
        cfg: FSDP configuration. The attributes read are `use` (asserted truthy),
            `param_dtype`, `reduce_dtype` and `buffer_dtype` (each one of
            "float32", "float16", "bfloat16"), `sharding_strategy` ("no_shard" or
            "shard_grad_op"; "full_shard" is rejected by an assertion) and `per_block`.
        model (torch.nn.Module): Model to wrap.
        block_classes (set of types, optional): Module classes given to the auto wrap
            policy when `cfg.per_block` is truthy. Defaults to
            `{StreamingTransformerLayer, ConditioningProvider}`.
    Returns:
        FSDP: The wrapped model, with its state dict type set to `LOCAL_STATE_DICT`.
    """
    # Some of the typing is disabled until this gets integrated
    # into the stable version of PyTorch.
    from torch.distributed.fsdp.wrap import ModuleWrapPolicy  # type: ignore
    # Imported here rather than at module level to prevent a circular import.
    from ..modules.transformer import StreamingTransformerLayer
    from ..modules.conditioners import ConditioningProvider

    _fix_post_backward_hook()

    assert cfg.use

    strategies = {
        "no_shard": ShardingStrategy.NO_SHARD,
        "shard_grad_op": ShardingStrategy.SHARD_GRAD_OP,
        "full_shard": ShardingStrategy.FULL_SHARD,
    }
    dtypes = {
        "float32": torch.float32,
        "float16": torch.float16,
        "bfloat16": torch.bfloat16,
    }

    precision = MixedPrecision(
        param_dtype=dtypes[cfg.param_dtype],
        reduce_dtype=dtypes[cfg.reduce_dtype],
        buffer_dtype=dtypes[cfg.buffer_dtype],
    )
    strategy = strategies[cfg.sharding_strategy]

    # The following is going to require being a bit smart
    # when doing LM, because this would flush the weights for every time step
    # during generation. One possibility is to use hybrid sharding:
    # See: https://pytorch.org/docs/master/fsdp.html#torch.distributed.fsdp.ShardingStrategy
    assert strategy != ShardingStrategy.FULL_SHARD, \
        "Not supported at the moment, requires a bit more work."

    rank = dora.distrib.get_distrib_spec().local_rank
    assert rank < torch.cuda.device_count(), "Please upgrade Dora!"

    if block_classes is None:
        block_classes = {StreamingTransformerLayer, ConditioningProvider}
    # Only wrap blocks individually when requested by the config.
    wrap_policy = ModuleWrapPolicy(block_classes) if cfg.per_block else None

    wrapped = _FSDPFixStateDict(
        model,
        sharding_strategy=strategy,
        mixed_precision=precision,
        device_id=rank,
        sync_module_states=True,
        use_orig_params=True,
        auto_wrap_policy=wrap_policy,
    )  # type: ignore
    FSDP.set_state_dict_type(wrapped, StateDictType.LOCAL_STATE_DICT)  # type: ignore

    # Let the wrapped model know about the wrapping!
    # We use __dict__ to avoid it going into the state dict.
    # This is a bit dirty, but needed during generation, as otherwise
    # the wrapped model would call itself and bypass FSDP.
    for fsdp_module in FSDP.fsdp_modules(wrapped):
        fsdp_module._fsdp_wrapped_module.__dict__['_fsdp'] = fsdp_module

    return wrapped
def purge_fsdp(model: FSDP):
    """Purge the FSDP cached shard inside the model. This should
    allow setting the best state or switching to the EMA.

    Every FSDP module found in `model` is resharded, unless it has no handle
    or the storage of its padded unsharded flat parameter is already empty.

    Args:
        model (FSDP): Model whose FSDP modules should be resharded.
    """
    from torch.distributed.fsdp._runtime_utils import _reshard  # type: ignore

    def _has_unsharded_storage(flat_handle) -> bool:
        # True when the padded unsharded flat param currently has a non-empty storage.
        padded_param = flat_handle._get_padded_unsharded_flat_param()
        return padded_param._typed_storage()._size() != 0  # type: ignore

    for fsdp_module in FSDP.fsdp_modules(model):
        if hasattr(fsdp_module, "_handles"):
            # support for FSDP with torch<2.1.0: a list of handles per module,
            # only the first one is inspected to decide whether to reshard.
            all_handles = fsdp_module._handles
            if all_handles and _has_unsharded_storage(all_handles[0]):
                _reshard(fsdp_module, all_handles, [True for _ in all_handles])
        else:
            # Otherwise the module exposes a single `_handle`.
            single_handle = fsdp_module._handle
            if single_handle and _has_unsharded_storage(single_handle):
                _reshard(fsdp_module, single_handle, True)
class _FSDPFixStateDict(FSDP):
    """FSDP subclass overriding `state_dict` and `load_state_dict`.

    `state_dict` drops sharded tensors from the result, and `load_state_dict`
    copies values in place into the current state when the state dict type
    is not `FULL_STATE_DICT`.
    """

    @staticmethod
    def _name_without_fsdp_prefix(name: str) -> str:
        """Return `name` with every FSDP wrapper component removed.

        Args:
            name (str): Dotted state key, e.g. `a._fsdp_wrapped_module.b`.
        Returns:
            str: The same dotted key without the `FSDP_WRAPPED_MODULE` parts.
        """
        from torch.distributed.fsdp._common_utils import FSDP_WRAPPED_MODULE  # type: ignore
        parts = name.split('.')
        new_parts = [part for part in parts if part != FSDP_WRAPPED_MODULE]
        return '.'.join(new_parts)

    def state_dict(self, *args, **kwargs) -> tp.Dict[str, tp.Any]:  # type: ignore
        """Return the parent state dict, without the entries that are sharded tensors."""
        full_state = super().state_dict(*args, **kwargs)
        return {key: value for key, value in full_state.items()
                if not is_sharded_tensor(value)}

    def load_state_dict(self, state: tp.Dict[str, tp.Any]):  # type: ignore
        """Load `state` into the model, then purge the FSDP cached weights.

        With `FULL_STATE_DICT`, loading is delegated to the parent class. Otherwise,
        each value is copied in place into the matching entry of the current state.

        Args:
            state (dict): State to load. Keys may contain FSDP wrapper components.
        Raises:
            RuntimeError: If a key of `state` is not in the current state
                (only when not using `FULL_STATE_DICT`).
        """
        if self._state_dict_type is StateDictType.FULL_STATE_DICT:
            super().load_state_dict(state)
            purge_fsdp(self)
            return
        # Fix FSDP load state dict in all situations.
        # Use this only with LOCAL_STATE_DICT !!!
        # We go through the parent `state_dict` so that no entry is filtered out.
        current_state = dict(super().state_dict())
        for key, value in state.items():
            key = _FSDPFixStateDict._name_without_fsdp_prefix(key)
            if key not in current_state:
                # Emulate strict loading manually.
                raise RuntimeError(f"Unknown state key {key}")
            current_state[key].copy_(value)

        # Purging cached weights from previous forward.
        purge_fsdp(self)
# Module-level flag, set once `_fix_post_backward_hook` has installed its patch.
_hook_fixed = False


def _fix_post_backward_hook():
    """Monkey patch `_post_backward_hook` in FSDP's `_runtime_utils`.

    The replacement hook first checks whether the wrapped module carries a truthy
    `_audiocraft_checkpointed` attribute, in which case it resets the training states
    of the FSDP state and of the handle, and then calls the original hook.
    The patch is installed at most once, guarded by the `_hook_fixed` flag.
    """
    global _hook_fixed
    if _hook_fixed:
        return
    # The flag is raised before patching, as in the guard above.
    _hook_fixed = True

    from torch.distributed.fsdp import _runtime_utils
    from torch.distributed.fsdp._common_utils import TrainingState, HandleTrainingState

    original_hook = _runtime_utils._post_backward_hook

    def _post_backward_hook(state, handle, *args, **kwargs):
        if getattr(state._fsdp_wrapped_module, '_audiocraft_checkpointed', False):
            # there will be one more forward in the backward with checkpointing and that will
            # massively confuse FSDP, so we have to make it think everything
            # is going according to the plan.
            state.training_state = TrainingState.FORWARD_BACKWARD
            handle._training_state = HandleTrainingState.BACKWARD_PRE
        original_hook(state, handle, *args, **kwargs)

    _runtime_utils._post_backward_hook = _post_backward_hook