wrapping.py

# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.

import torch.distributed as dist
import torch.nn as nn
import torch

from transformers.models.llama.modeling_llama import LlamaDecoderLayer

from torch.distributed.fsdp.fully_sharded_data_parallel import (
    FullyShardedDataParallel as FSDP,
    CPUOffload,
    BackwardPrefetch,
    MixedPrecision,
)
from torch.distributed.fsdp.wrap import (
    transformer_auto_wrap_policy,
    size_based_auto_wrap_policy,
    enable_wrap,
    wrap,
)

import functools
from typing import Type


def get_size_policy(min_params=1e8):
    """Return a size-based auto-wrap policy: any module with more than
    min_params parameters becomes its own FSDP unit."""
    num_wrap_policy = functools.partial(
        size_based_auto_wrap_policy, min_num_params=min_params
    )
    return num_wrap_policy


def get_llama_wrapper():
    """Register our main layer class and use the FSDP transformer wrapping policy.

    This ensures embedding layers stay in the root FSDP unit for shared access
    and that FSDP units map to transformer layers.
    """
    # ==== use new transformer wrapper
    llama_auto_wrap_policy = functools.partial(
        transformer_auto_wrap_policy,
        transformer_layer_cls={
            LlamaDecoderLayer,
        },
    )

    return llama_auto_wrap_policy
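

# A minimal usage sketch (not part of the upstream module): it shows how one of
# these policies is typically passed to the FSDP constructor. It assumes the
# default process group has already been initialized via dist.init_process_group
# and that `model` is a Hugging Face Llama model; the function name
# `shard_llama_model` is illustrative only.
def shard_llama_model(model: nn.Module) -> FSDP:
    """Wrap a Llama model in FSDP using the transformer auto-wrap policy above."""
    wrapping_policy = get_llama_wrapper()
    return FSDP(
        model,
        auto_wrap_policy=wrapping_policy,       # each LlamaDecoderLayer becomes its own FSDP unit
        device_id=torch.cuda.current_device(),  # shard parameters onto this rank's GPU
    )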