# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.

import torch

from torch.distributed.fsdp import (
    # FullyShardedDataParallel as FSDP,
    # CPUOffload,
    MixedPrecision,
    # BackwardPrefetch,
    # ShardingStrategy,
)


def _single_dtype_policy(dtype: torch.dtype, **overrides: object) -> MixedPrecision:
    """Return a ``MixedPrecision`` policy that uses *dtype* for every precision slot.

    Args:
        dtype: Value passed as ``param_dtype``, ``reduce_dtype`` (gradient
            communication precision) and ``buffer_dtype`` (buffer precision).
        **overrides: Additional keyword arguments forwarded unchanged to
            ``MixedPrecision``.

    Returns:
        The constructed ``MixedPrecision`` instance.
    """
    return MixedPrecision(
        param_dtype=dtype,
        reduce_dtype=dtype,
        buffer_dtype=dtype,
        **overrides,
    )


# float16 everywhere.
# requires grad scaler in main loop
fpSixteen = _single_dtype_policy(torch.float16)

# bfloat16 everywhere, with ``cast_forward_inputs`` switched on.
bfSixteen = _single_dtype_policy(torch.bfloat16, cast_forward_inputs=True)

# float32 parameters; gradient communication and buffers in bfloat16.
bfSixteen_mixed = MixedPrecision(
    param_dtype=torch.float32,
    reduce_dtype=torch.bfloat16,
    buffer_dtype=torch.bfloat16,
)

# float32 everywhere.
fp32_policy = _single_dtype_policy(torch.float32)