# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
import torch
import os
import torch.distributed as dist
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    checkpoint_wrapper,
    CheckpointImpl,
    apply_activation_checkpointing,
)
from transformers.models.t5.modeling_t5 import T5Block
from transformers.models.llama.modeling_llama import LlamaDecoderLayer
from functools import partial
# Checkpoint wrapper pre-configured with the non-reentrant implementation,
# which supports more cases (e.g. keyword arguments, partial backward) than
# the legacy reentrant variant.
non_reentrant_wrapper = partial(
    checkpoint_wrapper,
    checkpoint_impl=CheckpointImpl.NO_REENTRANT,
)


def check_fn(submodule):
    """Return True for submodules that should be activation-checkpointed.

    Only whole Llama decoder layers are wrapped; all other submodules are
    left untouched.
    """
    return isinstance(submodule, LlamaDecoderLayer)
def apply_fsdp_checkpointing(model):
    """Apply non-reentrant activation checkpointing to *model* in place.

    Wraps every ``LlamaDecoderLayer`` submodule (selected by the module-level
    ``check_fn``) with ``non_reentrant_wrapper``.

    Args:
        model: the (FSDP-wrapped) model to modify.

    Returns:
        None — the model is updated directly.
    """
    # Plain string: there are no placeholders, so an f-string is unnecessary.
    # Typo fixed: "fdsp" -> "fsdp".
    print("--> applying fsdp activation checkpointing...")
    apply_activation_checkpointing(
        model,
        checkpoint_wrapper_fn=non_reentrant_wrapper,
        check_fn=check_fn,
    )