mixed_precision.py

# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.

import torch
from torch.distributed.fsdp import (
    MixedPrecision,
)

# Full fp16 policy; requires a gradient scaler in the main training loop.
fpSixteen = MixedPrecision(
    param_dtype=torch.float16,
    # Gradient communication precision.
    reduce_dtype=torch.float16,
    # Buffer precision.
    buffer_dtype=torch.float16,
)

# Full bf16 policy; no gradient scaler is needed because bf16 has the same
# exponent range as fp32.
bfSixteen = MixedPrecision(
    param_dtype=torch.bfloat16,
    # Gradient communication precision.
    reduce_dtype=torch.bfloat16,
    # Buffer precision.
    buffer_dtype=torch.bfloat16,
    # Cast floating-point inputs to bf16 on entry to each FSDP module's forward.
    cast_forward_inputs=True,
)

# Mixed policy: keep parameters in fp32, but reduce gradients and hold
# buffers in bf16.
bfSixteen_mixed = MixedPrecision(
    param_dtype=torch.float32,
    reduce_dtype=torch.bfloat16,
    buffer_dtype=torch.bfloat16,
)

# Pure fp32 policy; everything stays in full precision.
fp32_policy = MixedPrecision(
    param_dtype=torch.float32,
    reduce_dtype=torch.float32,
    buffer_dtype=torch.float32,
)
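
# Usage sketch (an assumption for illustration, not part of the original file):
# these policies are consumed by FSDP through its `mixed_precision` argument.
# The helper name and the `model` parameter below are hypothetical; the bf16
# capability check uses the standard torch.cuda.is_bf16_supported() API.
def wrap_with_mixed_precision(model):
    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

    # Prefer bf16 where the hardware supports it; otherwise fall back to
    # fp16, which additionally requires a gradient scaler in the main loop.
    policy = bfSixteen if torch.cuda.is_bf16_supported() else fpSixteen
    return FSDP(model, mixed_precision=policy)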