# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.

import functools
from typing import Type

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.fsdp.fully_sharded_data_parallel import (
    FullyShardedDataParallel as FSDP,
    CPUOffload,
    BackwardPrefetch,
    MixedPrecision,
)
from torch.distributed.fsdp.wrap import (
    transformer_auto_wrap_policy,
    size_based_auto_wrap_policy,
    enable_wrap,
    wrap,
)
from transformers.models.llama.modeling_llama import LlamaDecoderLayer


def get_size_policy(min_params=1e8):
    """Return a size-based auto-wrap policy: any submodule whose parameter
    count exceeds min_params becomes its own FSDP unit."""
    num_wrap_policy = functools.partial(
        size_based_auto_wrap_policy, min_num_params=min_params
    )
    return num_wrap_policy
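
# Example usage (a sketch, not part of the original module): hand the size-based
# policy to FSDP so that large submodules are sharded individually. Assumes
# torch.distributed is already initialized and `model` is an illustrative
# placeholder for an nn.Module on the current device.
#
#   size_policy = get_size_policy(min_params=1e8)
#   sharded_model = FSDP(model, auto_wrap_policy=size_policy)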


def get_llama_wrapper():
    """Register the model's main decoder layer class with the FSDP transformer
    wrapping policy. This keeps embedding layers in the root FSDP unit for
    shared access and maps each FSDP unit to a transformer (decoder) layer.
    """
    # use the transformer auto-wrap policy, keyed on LlamaDecoderLayer
    llama_auto_wrap_policy = functools.partial(
        transformer_auto_wrap_policy,
        transformer_layer_cls={
            LlamaDecoderLayer,
        },
    )
    return llama_auto_wrap_policy
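
# Example usage (a sketch, not part of the original module): wrap a Llama model
# with FSDP using the transformer policy above, optionally with bf16 mixed
# precision. Assumes a distributed process group has been initialized; `model`
# is an illustrative placeholder for a Hugging Face LlamaForCausalLM instance.
#
#   llama_wrap_policy = get_llama_wrapper()
#   sharded_model = FSDP(
#       model,
#       auto_wrap_policy=llama_wrap_policy,
#       mixed_precision=MixedPrecision(param_dtype=torch.bfloat16),
#       device_id=torch.cuda.current_device(),
#   )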