# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.

import torch

from torch.distributed.fsdp import (
    MixedPrecision,
)

# Full fp16 policy; requires a grad scaler (e.g. ShardedGradScaler) in the main training loop.
fpSixteen = MixedPrecision(
    # Parameter precision.
    param_dtype=torch.float16,
    # Gradient communication precision.
    reduce_dtype=torch.float16,
    # Buffer precision.
    buffer_dtype=torch.float16,
)

# Full bf16 policy; no grad scaler needed, since bf16 keeps fp32's exponent range.
bfSixteen = MixedPrecision(
    # Parameter precision.
    param_dtype=torch.bfloat16,
    # Gradient communication precision.
    reduce_dtype=torch.bfloat16,
    # Buffer precision.
    buffer_dtype=torch.bfloat16,
    # Also cast the module's forward inputs to bf16.
    cast_forward_inputs=True,
)

# Mixed policy: keep parameters in fp32, but reduce gradients and hold buffers in bf16.
bfSixteen_mixed = MixedPrecision(
    param_dtype=torch.float32,
    reduce_dtype=torch.bfloat16,
    buffer_dtype=torch.bfloat16,
)

# Full fp32 policy (effectively disables mixed precision); useful as a baseline or for debugging.
fp32_policy = MixedPrecision(
    param_dtype=torch.float32,
    reduce_dtype=torch.float32,
    buffer_dtype=torch.float32,
)
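
# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative, not part of the original module): one
# way these policies could be handed to FSDP. The toy model, optimizer, loss,
# and single-GPU torchrun launch are assumptions made for demonstration only;
# fpSixteen is paired with ShardedGradScaler, as its comment above requires.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    # Launch under torchrun, e.g.: torchrun --nproc_per_node=1 <this_file>.py
    import torch.nn as nn
    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
    from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler

    torch.distributed.init_process_group("nccl")

    model = FSDP(
        nn.Linear(16, 16).cuda(),  # toy model standing in for a real network
        mixed_precision=fpSixteen,  # or bfSixteen / bfSixteen_mixed / fp32_policy
    )
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
    scaler = ShardedGradScaler()  # required by the fp16 policy above

    x = torch.randn(4, 16, device="cuda")
    loss = model(x).float().pow(2).mean()  # dummy loss for demonstration
    scaler.scale(loss).backward()  # scale the loss so fp16 grads don't underflow
    scaler.step(optimizer)         # unscales gradients, then steps the optimizer
    scaler.update()

    torch.distributed.destroy_process_group()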