# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.


def fsdp_auto_wrap_policy(model, transformer_layer_name):
    """Build an FSDP auto-wrap policy that wraps PEFT prompt modules, trainable
    leaf modules, and the given transformer block class as separate FSDP units."""
    import functools

    from accelerate import FullyShardedDataParallelPlugin  # only needed by the commented-out alternative below
    from peft.tuners import PrefixEncoder, PromptEmbedding, PromptEncoder
    from torch.distributed.fsdp.wrap import (
        _or_policy,
        lambda_auto_wrap_policy,
        transformer_auto_wrap_policy,
    )

    def lambda_policy_fn(module):
        # Wrap leaf modules whose weight is trainable (e.g. PEFT adapter parameters),
        # so trainable and frozen parameters do not share an FSDP unit.
        if (
            len(list(module.named_children())) == 0
            and getattr(module, "weight", None) is not None
            and module.weight.requires_grad
        ):
            return True
        return False

    lambda_policy = functools.partial(lambda_auto_wrap_policy, lambda_fn=lambda_policy_fn)

    # Wrap the PEFT prompt-tuning modules and each transformer block as its own FSDP unit.
    transformer_wrap_policy = functools.partial(
        transformer_auto_wrap_policy,
        transformer_layer_cls=(
            PrefixEncoder,
            PromptEncoder,
            PromptEmbedding,
            transformer_layer_name,
            # FullyShardedDataParallelPlugin.get_module_class_from_name(
            #     model, transformer_layer_name
            # ),
        ),
    )

    # A module is wrapped if either policy matches.
    auto_wrap_policy = functools.partial(_or_policy, policies=[lambda_policy, transformer_wrap_policy])
    return auto_wrap_policy
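

# Usage sketch (illustrative, not part of the original module): the returned policy is
# typically passed to torch.distributed.fsdp.FullyShardedDataParallel via its
# auto_wrap_policy argument. `model` and LlamaDecoderLayer below are assumptions
# standing in for whatever model and transformer block class the caller actually uses.
#
#   from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
#   from transformers.models.llama.modeling_llama import LlamaDecoderLayer
#
#   wrap_policy = fsdp_auto_wrap_policy(model, LlamaDecoderLayer)
#   model = FSDP(model, auto_wrap_policy=wrap_policy)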