# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.

import torch
from torch.distributed.fsdp import (
    # FullyShardedDataParallel as FSDP,
    # CPUOffload,
    MixedPrecision,
    # BackwardPrefetch,
    # ShardingStrategy,
)

# requires grad scaler in main loop
fpSixteen = MixedPrecision(
    param_dtype=torch.float16,
    # Gradient communication precision.
    reduce_dtype=torch.float16,
    # Buffer precision.
    buffer_dtype=torch.float16,
)

bfSixteen = MixedPrecision(
    param_dtype=torch.bfloat16,
    # Gradient communication precision.
    reduce_dtype=torch.bfloat16,
    # Buffer precision.
    buffer_dtype=torch.bfloat16,
    cast_forward_inputs=True,
)
# Parameters kept in full precision; gradient reduction and buffers in bfloat16.
bfSixteen_mixed = MixedPrecision(
    param_dtype=torch.float32,
    reduce_dtype=torch.bfloat16,
    buffer_dtype=torch.bfloat16,
)

# Full float32 for parameters, gradient reduction, and buffers, i.e. no mixed precision.
fp32_policy = MixedPrecision(
    param_dtype=torch.float32,
    reduce_dtype=torch.float32,
    buffer_dtype=torch.float32,
)
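
# Usage sketch (not part of this module, shown for context only): these policies
# are intended to be passed to FSDP via its `mixed_precision` argument. The
# wrapping below is illustrative; `model` is a placeholder module, and
# torch.distributed.init_process_group() must have been called beforehand.
#
#     from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
#
#     policy = bfSixteen if torch.cuda.is_bf16_supported() else fpSixteen
#     model = FSDP(model, mixed_precision=policy)
#     # With fpSixteen, pair the training loop with a gradient scaler
#     # (see the "requires grad scaler" note above); bfloat16 does not need one.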