# Gael Le Lan
# Initial commit
# 9d0d223
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""
Wrapper around FSDP for more convenient use in the training loops.
"""
from contextlib import contextmanager
import typing as tp
import dora
import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import (
MixedPrecision, ShardingStrategy, FullStateDictConfig, StateDictType)
from torch.distributed._shard.sharded_tensor.api import ShardedTensor
def is_fsdp_used() -> bool:
    """Tell whether FSDP is enabled for the current Dora experiment.

    Returns:
        bool: the value of `cfg.fsdp.use` when running inside a Dora XP whose
            config has an `fsdp` entry, `False` in every other case.
    """
    # Slightly hacky, but it can be queried from anywhere without having
    # to pass the config around.
    if not dora.is_xp():
        return False
    xp_cfg = dora.get_xp().cfg
    if not hasattr(xp_cfg, 'fsdp'):
        return False
    return xp_cfg.fsdp.use
def is_sharded_tensor(x: tp.Any) -> bool:
    """Return whether `x` is an instance of `ShardedTensor`."""
    sharded = isinstance(x, ShardedTensor)
    return sharded
@contextmanager
def switch_to_full_state_dict(models: tp.List[FSDP]):
    """Context manager putting the given FSDP models in full state dict mode.

    On entry, every model is switched to `StateDictType.FULL_STATE_DICT`,
    configured with `offload_to_cpu=True` and `rank0_only=True`. On exit,
    whether or not the body raised, every model is switched to
    `StateDictType.LOCAL_STATE_DICT`.

    Args:
        models (list of FSDP): FSDP wrapped models to switch.
    """
    # Because of a bug in FSDP, the `state_dict_type` API cannot be used here,
    # so the switch is done manually.
    for fsdp_model in models:
        full_config = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
        FSDP.set_state_dict_type(  # type: ignore
            fsdp_model, StateDictType.FULL_STATE_DICT, full_config)
    try:
        yield
    finally:
        for fsdp_model in models:
            FSDP.set_state_dict_type(  # type: ignore
                fsdp_model, StateDictType.LOCAL_STATE_DICT)
def wrap_with_fsdp(cfg, model: torch.nn.Module,
                   block_classes: tp.Optional[tp.Set[tp.Type]] = None) -> FSDP:
    """Wrap a model with FSDP, following the given FSDP config.

    Args:
        cfg: FSDP config. The attributes read here are `use`, `param_dtype`,
            `reduce_dtype`, `buffer_dtype`, `sharding_strategy` and `per_block`.
        model (torch.nn.Module): model to wrap.
        block_classes (set of types, optional): module classes handed to
            `ModuleWrapPolicy` when `cfg.per_block` is true. Defaults to
            `{StreamingTransformerLayer, ConditioningProvider}`.

    Returns:
        FSDP: the wrapped model (a `_FSDPFixStateDict`), with its state dict
            type set to `LOCAL_STATE_DICT`.

    Raises:
        AssertionError: if `cfg.use` is false, if the sharding strategy is
            `full_shard`, or if the local rank is not below the CUDA device count.
        KeyError: if a dtype or sharding strategy name in `cfg` is not one of
            the supported names listed below.
    """
    # Part of the typing is disabled until this lands
    # in the stable version of PyTorch.
    from torch.distributed.fsdp.wrap import ModuleWrapPolicy  # type: ignore

    # Imported here rather than at module level to avoid a circular import.
    from ..modules.transformer import StreamingTransformerLayer
    from ..modules.conditioners import ConditioningProvider

    _fix_post_backward_hook()

    assert cfg.use

    strategies_by_name = {
        "no_shard": ShardingStrategy.NO_SHARD,
        "shard_grad_op": ShardingStrategy.SHARD_GRAD_OP,
        "full_shard": ShardingStrategy.FULL_SHARD,
    }
    dtypes_by_name = {
        "float32": torch.float32,
        "float16": torch.float16,
        "bfloat16": torch.bfloat16,
    }

    precision = MixedPrecision(
        param_dtype=dtypes_by_name[cfg.param_dtype],
        reduce_dtype=dtypes_by_name[cfg.reduce_dtype],
        buffer_dtype=dtypes_by_name[cfg.buffer_dtype],
    )
    strategy = strategies_by_name[cfg.sharding_strategy]

    # Full sharding will require being a bit smart when doing LM, because it
    # would flush the weights at every time step during generation.
    # One possibility is to use hybrid sharding:
    # See: https://pytorch.org/docs/master/fsdp.html#torch.distributed.fsdp.ShardingStrategy
    assert strategy != ShardingStrategy.FULL_SHARD, \
        "Not supported at the moment, requires a bit more work."

    local_rank = dora.distrib.get_distrib_spec().local_rank
    assert local_rank < torch.cuda.device_count(), "Please upgrade Dora!"

    if block_classes is None:
        block_classes = {StreamingTransformerLayer, ConditioningProvider}
    # Only wrap block by block when asked to, otherwise no auto wrap policy.
    auto_wrap_policy = ModuleWrapPolicy(block_classes) if cfg.per_block else None

    wrapped = _FSDPFixStateDict(
        model,
        sharding_strategy=strategy,
        mixed_precision=precision,
        device_id=local_rank,
        sync_module_states=True,
        use_orig_params=True,
        auto_wrap_policy=auto_wrap_policy,
    )  # type: ignore
    FSDP.set_state_dict_type(wrapped, StateDictType.LOCAL_STATE_DICT)  # type: ignore

    # Tell each wrapped module about its FSDP wrapper. Going through `__dict__`
    # keeps the reference out of the state dict. It is a bit dirty, but needed
    # during generation, as otherwise the wrapped model would call itself
    # and bypass FSDP.
    for fsdp_module in FSDP.fsdp_modules(wrapped):
        inner = fsdp_module._fsdp_wrapped_module
        inner.__dict__['_fsdp'] = fsdp_module
    return wrapped
def purge_fsdp(model: FSDP):
    """Purge the FSDP cached shard inside the model. This should
    allow setting the best state or switching to the EMA.

    Every FSDP module found in `model` whose padded unsharded flat parameter
    still has a non empty storage gets resharded.

    Args:
        model (FSDP): model whose FSDP submodules should be purged.
    """
    from torch.distributed.fsdp._runtime_utils import _reshard  # type: ignore

    def _unsharded_storage_size(handle) -> int:
        # Size of the storage backing the padded unsharded flat parameter.
        flat_param = handle._get_padded_unsharded_flat_param()
        return flat_param._typed_storage()._size()  # type: ignore

    for fsdp_module in FSDP.fsdp_modules(model):
        if hasattr(fsdp_module, "_handles"):
            # support for FSDP with torch<2.1.0, where a module has a list of handles.
            handles = fsdp_module._handles
            # Only the first handle is probed, then all of them are resharded.
            if handles and _unsharded_storage_size(handles[0]) != 0:
                _reshard(fsdp_module, handles, [True] * len(handles))
        else:
            handle = fsdp_module._handle
            if handle and _unsharded_storage_size(handle) != 0:
                _reshard(fsdp_module, handle, True)
class _FSDPFixStateDict(FSDP):
    """FSDP subclass with patched `state_dict` and `load_state_dict`.

    `state_dict` drops sharded tensors from the result, and `load_state_dict`
    copies values in place when not in `FULL_STATE_DICT` mode. Both loading
    paths purge the FSDP cached weights afterwards.
    """

    @staticmethod
    def _name_without_fsdp_prefix(name: str) -> str:
        """Remove every FSDP wrapper component from a dotted parameter name."""
        from torch.distributed.fsdp._common_utils import FSDP_WRAPPED_MODULE  # type: ignore
        return '.'.join(
            piece for piece in name.split('.') if piece != FSDP_WRAPPED_MODULE)

    def state_dict(self, *args, **kwargs) -> tp.Dict[str, tp.Any]:  # type: ignore
        """Return the FSDP state dict, without the sharded tensor entries."""
        full_state = super().state_dict(*args, **kwargs)
        return {
            key: value for key, value in full_state.items()
            if not is_sharded_tensor(value)
        }

    def load_state_dict(self, state: tp.Dict[str, tp.Any]):  # type: ignore
        """Load `state` into the model, then purge the FSDP cached weights.

        Raises:
            RuntimeError: outside of `FULL_STATE_DICT` mode, if a key of `state`
                (once stripped of FSDP prefixes) is not in the current state.
        """
        if self._state_dict_type is StateDictType.FULL_STATE_DICT:
            super().load_state_dict(state)
        else:
            # Fix FSDP load state dict in all situations.
            # Use this only with LOCAL_STATE_DICT !!!
            own_state = dict(super().state_dict())
            for raw_key, value in state.items():
                clean_key = _FSDPFixStateDict._name_without_fsdp_prefix(raw_key)
                if clean_key not in own_state:
                    # Emulate strict loading manually.
                    raise RuntimeError(f"Unknown state key {clean_key}")
                own_state[clean_key].copy_(value)

        # Purging cached weights from previous forward.
        purge_fsdp(self)
# Set once the FSDP post backward hook has been patched.
_hook_fixed = False


def _fix_post_backward_hook():
    """Patch the FSDP `_post_backward_hook` so that it copes with checkpointing.

    The patch is applied at most once per process: later calls return
    immediately. The replacement hook forces the training state of modules
    flagged with `_audiocraft_checkpointed` before calling the original hook.
    """
    global _hook_fixed
    if _hook_fixed:
        return
    _hook_fixed = True

    from torch.distributed.fsdp import _runtime_utils
    from torch.distributed.fsdp._common_utils import TrainingState, HandleTrainingState

    original_hook = _runtime_utils._post_backward_hook

    def _post_backward_hook(state, handle, *args, **kwargs):
        if getattr(state._fsdp_wrapped_module, '_audiocraft_checkpointed', False):
            # With checkpointing there is one more forward inside the backward,
            # which massively confuses FSDP, so we make it believe that
            # everything is going according to the plan.
            state.training_state = TrainingState.FORWARD_BACKWARD
            handle._training_state = HandleTrainingState.BACKWARD_PRE
        original_hook(state, handle, *args, **kwargs)

    _runtime_utils._post_backward_hook = _post_backward_hook