fsdp_utils.py

# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.

def fsdp_auto_wrap_policy(model, transformer_layer_name):
    import functools
    import os

    from accelerate import FullyShardedDataParallelPlugin
    from transformers.models.t5.modeling_t5 import T5Block
    from torch.distributed.fsdp.wrap import _or_policy, lambda_auto_wrap_policy, transformer_auto_wrap_policy

    from peft.tuners import PrefixEncoder, PromptEmbedding, PromptEncoder

    def lambda_policy_fn(module):
        # Wrap leaf modules whose weights are trainable (e.g. PEFT adapter layers).
        if (
            len(list(module.named_children())) == 0
            and getattr(module, "weight", None) is not None
            and module.weight.requires_grad
        ):
            return True
        return False

    lambda_policy = functools.partial(lambda_auto_wrap_policy, lambda_fn=lambda_policy_fn)
    transformer_wrap_policy = functools.partial(
        transformer_auto_wrap_policy,
        transformer_layer_cls=(
            PrefixEncoder,
            PromptEncoder,
            PromptEmbedding,
            transformer_layer_name,
            # FullyShardedDataParallelPlugin.get_module_class_from_name(
            #     model, transformer_layer_name
            # ),
        ),
    )

    # A module is wrapped as its own FSDP unit if either policy matches it.
    auto_wrap_policy = functools.partial(_or_policy, policies=[lambda_policy, transformer_wrap_policy])
    return auto_wrap_policy
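
For context, here is a minimal usage sketch of the policy returned above. It assumes a model that is already loaded on the current rank (the `model` variable and the choice of `LlamaDecoderLayer` as the transformer layer class are assumptions, not part of this file), and that `torch.distributed` has already been initialized before constructing FSDP.

import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from transformers.models.llama.modeling_llama import LlamaDecoderLayer

# `model` is assumed to be a (PEFT-wrapped) causal LM created elsewhere.
wrapping_policy = fsdp_auto_wrap_policy(model, LlamaDecoderLayer)

# Trainable adapter leaves and decoder blocks each become their own FSDP unit.
model = FSDP(
    model,
    auto_wrap_policy=wrapping_policy,
    device_id=torch.cuda.current_device(),
)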