|
|
|
|
|
|
|
import torch |
|
|
|
from torch.distributed.fsdp import ( |
|
|
|
|
|
MixedPrecision, |
|
|
|
|
|
) |
|
|
|
|
|
# All-FP16 policy: parameters, gradient reduction, and buffers are all
# kept in half precision during FSDP execution.
_half = torch.float16
fpSixteen = MixedPrecision(
    param_dtype=_half,
    reduce_dtype=_half,
    buffer_dtype=_half,
)
|
|
|
# All-BF16 policy: parameters, gradient reduction, and buffers run in
# bfloat16; forward inputs are also cast to bfloat16 automatically.
_brain_half = torch.bfloat16
bfSixteen = MixedPrecision(
    param_dtype=_brain_half,
    reduce_dtype=_brain_half,
    buffer_dtype=_brain_half,
    cast_forward_inputs=True,
)
|
|
|
# Mixed policy: parameters stay in full FP32, while gradient reduction
# and buffers use bfloat16.
_bf16_mixed_cfg = {
    "param_dtype": torch.float32,
    "reduce_dtype": torch.bfloat16,
    "buffer_dtype": torch.bfloat16,
}
bfSixteen_mixed = MixedPrecision(**_bf16_mixed_cfg)
|
|
|
# Full-precision policy: every dtype pinned to FP32, which effectively
# disables mixed precision while still going through the same API.
_full = torch.float32
fp32_policy = MixedPrecision(
    param_dtype=_full,
    reduce_dtype=_full,
    buffer_dtype=_full,
)
|
|