# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.

from functools import partial

from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    checkpoint_wrapper,
    CheckpointImpl,
    apply_activation_checkpointing,
)
from transformers.models.llama.modeling_llama import LlamaDecoderLayer

# Wrapper that applies non-reentrant activation checkpointing to a submodule.
non_reentrant_wrapper = partial(
    checkpoint_wrapper,
    checkpoint_impl=CheckpointImpl.NO_REENTRANT,
)

# Only wrap the transformer decoder layers; all other submodules are left untouched.
check_fn = lambda submodule: isinstance(submodule, LlamaDecoderLayer)


def apply_fsdp_checkpointing(model):
    """Apply activation checkpointing to every LlamaDecoderLayer in the model.

    Returns None; the model is updated in place.
    """
    print("--> applying fsdp activation checkpointing...")

    apply_activation_checkpointing(
        model, checkpoint_wrapper_fn=non_reentrant_wrapper, check_fn=check_fn
    )
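

# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the original module): shows the
# typical order of operations, assuming torch.distributed has already been
# set up and the model follows the Llama architecture. The checkpoint name
# and FSDP arguments below are assumptions for demonstration only.
#
#     import torch.distributed as dist
#     from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
#     from transformers import LlamaForCausalLM
#
#     dist.init_process_group("nccl")
#     model = LlamaForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")
#     model = FSDP(model)                   # shard the model first
#     apply_fsdp_checkpointing(model)       # then wrap each LlamaDecoderLayer
# ---------------------------------------------------------------------------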