"""Predefined FSDP ``MixedPrecision`` policies.

Each module-level name below is a ready-made ``MixedPrecision`` instance.
The three dtype fields follow the PyTorch FSDP documentation:

* ``param_dtype``  -- dtype used for parameters during forward/backward.
* ``reduce_dtype`` -- dtype used for gradient reduction (communication).
* ``buffer_dtype`` -- dtype used for module buffers.
"""

import torch
from torch.distributed.fsdp import MixedPrecision

# Pure float16: parameters, gradient reduction and buffers all in fp16.
# NOTE(review): fp16 training normally needs gradient scaling in the
# training loop -- confirm the caller provides a grad scaler.
fpSixteen = MixedPrecision(
    param_dtype=torch.float16,
    reduce_dtype=torch.float16,
    buffer_dtype=torch.float16,
)

# Pure bfloat16: parameters, gradient reduction and buffers all in bf16.
# This is the only policy here that sets ``cast_forward_inputs``; per the
# FSDP docs that makes the wrapped module cast its forward args/kwargs to
# ``param_dtype``.
bfSixteen = MixedPrecision(
    param_dtype=torch.bfloat16,
    reduce_dtype=torch.bfloat16,
    buffer_dtype=torch.bfloat16,
    cast_forward_inputs=True,
)

# Mixed: parameters stay in float32 while gradient reduction and buffers
# use bfloat16.
bfSixteen_mixed = MixedPrecision(
    param_dtype=torch.float32,
    reduce_dtype=torch.bfloat16,
    buffer_dtype=torch.bfloat16,
)

# Full precision: every field is float32.
fp32_policy = MixedPrecision(
    param_dtype=torch.float32,
    reduce_dtype=torch.float32,
    buffer_dtype=torch.float32,
)