"""Reusable FSDP ``MixedPrecision`` policies.

Each module-level constant below is a ``torch.distributed.fsdp.MixedPrecision``
instance intended to be passed as the ``mixed_precision=`` argument when
wrapping a model with ``FullyShardedDataParallel``. The three dtype fields are:

* ``param_dtype``  -- dtype parameters are cast to for forward/backward compute;
* ``reduce_dtype`` -- dtype used for gradient reduction (communication);
* ``buffer_dtype`` -- dtype module buffers are kept in.

The names (``fpSixteen``, ``bfSixteen``, ...) are kept as-is because callers
import them directly.
"""
import torch
from torch.distributed.fsdp import MixedPrecision

# Pure float16: compute, gradient communication and buffers all in fp16.
# fp16 has a narrow exponent range, so training with this policy normally
# needs loss scaling in the training loop (e.g. FSDP's ShardedGradScaler).
fpSixteen = MixedPrecision(
    param_dtype=torch.float16,
    # Gradient communication precision.
    reduce_dtype=torch.float16,
    # Buffer precision.
    buffer_dtype=torch.float16,
)

# Pure bfloat16: same layout as the fp16 policy, but bf16 keeps fp32's
# exponent range, so a gradient scaler is typically not required.
bfSixteen = MixedPrecision(
    param_dtype=torch.bfloat16,
    # Gradient communication precision.
    reduce_dtype=torch.bfloat16,
    # Buffer precision.
    buffer_dtype=torch.bfloat16,
    # Have every FSDP-wrapped module cast its floating-point forward
    # args/kwargs to ``param_dtype`` (the fp16 policy leaves this at the
    # library default).
    cast_forward_inputs=True,
)

# Parameters stay in float32 for compute; only gradient communication and
# buffers use bfloat16, halving reduction traffic relative to fp32.
bfSixteen_mixed = MixedPrecision(
    param_dtype=torch.float32,
    reduce_dtype=torch.bfloat16,
    buffer_dtype=torch.bfloat16,
)

# Everything in float32, i.e. no reduced precision at all. Useful as a
# baseline, or when the hardware lacks fast fp16/bf16 support.
fp32_policy = MixedPrecision(
    param_dtype=torch.float32,
    reduce_dtype=torch.float32,
    buffer_dtype=torch.float32,
)