mixed_precision.py

# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.

import torch
from torch.distributed.fsdp import (
    # FullyShardedDataParallel as FSDP,
    # CPUOffload,
    MixedPrecision,
    # BackwardPrefetch,
    # ShardingStrategy,
)

# requires grad scaler in main loop
fpSixteen = MixedPrecision(
    param_dtype=torch.float16,
    # Gradient communication precision.
    reduce_dtype=torch.float16,
    # Buffer precision.
    buffer_dtype=torch.float16,
)

bfSixteen = MixedPrecision(
    param_dtype=torch.bfloat16,
    # Gradient communication precision.
    reduce_dtype=torch.bfloat16,
    # Buffer precision.
    buffer_dtype=torch.bfloat16,
    cast_forward_inputs=True,
)

# Parameters stay in full fp32; gradient reduction and buffers use bf16.
bfSixteen_mixed = MixedPrecision(
    param_dtype=torch.float32,
    reduce_dtype=torch.bfloat16,
    buffer_dtype=torch.bfloat16,
)

# Full fp32 for params, gradient reduction, and buffers (no reduced precision).
fp32_policy = MixedPrecision(
    param_dtype=torch.float32,
    reduce_dtype=torch.float32,
    buffer_dtype=torch.float32,
)