# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.

import torch

from torch.distributed.fsdp import (
    # FullyShardedDataParallel as FSDP,
    # CPUOffload,
    MixedPrecision,
    # BackwardPrefetch,
    # ShardingStrategy,
)

# fp16 policy: requires a grad scaler (e.g. ShardedGradScaler) in the main
# training loop, since fp16 gradients can underflow without loss scaling.
fpSixteen = MixedPrecision(
    # Parameter precision.
    param_dtype=torch.float16,
    # Gradient communication precision.
    reduce_dtype=torch.float16,
    # Buffer precision.
    buffer_dtype=torch.float16,
)
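
# A minimal sketch (not part of this module) of the grad-scaler loop the fp16
# policy needs; `model`, `optimizer`, and `dataloader` are placeholders, and
# ShardedGradScaler is the FSDP-aware scaler shipped with PyTorch:
#
#   from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
#   from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
#
#   model = FSDP(model, mixed_precision=fpSixteen)
#   scaler = ShardedGradScaler()
#   for batch in dataloader:
#       loss = model(**batch).loss
#       scaler.scale(loss).backward()  # scale the loss to avoid fp16 underflow
#       scaler.step(optimizer)         # unscales grads, skips step on inf/nan
#       scaler.update()                # adjusts scale factor for the next step
#       optimizer.zero_grad()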

# bf16 policy: no grad scaler needed, since bf16 has the same dynamic range as fp32.
bfSixteen = MixedPrecision(
    # Parameter precision.
    param_dtype=torch.bfloat16,
    # Gradient communication precision.
    reduce_dtype=torch.bfloat16,
    # Buffer precision.
    buffer_dtype=torch.bfloat16,
    # Also cast floating-point forward inputs to bf16.
    cast_forward_inputs=True,
)

# Mixed policy: keep parameters in fp32, but reduce gradients and hold buffers in bf16.
bfSixteen_mixed = MixedPrecision(
    param_dtype=torch.float32,
    reduce_dtype=torch.bfloat16,
    buffer_dtype=torch.bfloat16,
)

# Full fp32: no mixed precision; useful as a debugging or accuracy baseline.
fp32_policy = MixedPrecision(
    param_dtype=torch.float32,
    reduce_dtype=torch.float32,
    buffer_dtype=torch.float32,
)
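
# A hedged helper (an assumption, not part of the original file) for choosing a
# policy at runtime: torch.cuda.is_bf16_supported() checks that the GPU can run
# bfloat16 math (Ampere or newer); distributed bf16 gradient reduction
# additionally needs NCCL >= 2.10, which a fuller check would verify.
def get_mixed_precision_policy():
    """Return bfSixteen when bf16 is usable on this GPU, else fpSixteen.

    The fp16 fallback requires a ShardedGradScaler in the training loop.
    """
    if torch.cuda.is_available() and torch.cuda.is_bf16_supported():
        return bfSixteen
    return fpSixteen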