fsdp_utils.py

# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.

def fsdp_auto_wrap_policy(model, transformer_layer_name):
    """Build an FSDP auto-wrap policy that wraps transformer blocks and trainable PEFT modules."""
    import functools

    from torch.distributed.fsdp.wrap import _or_policy, lambda_auto_wrap_policy, transformer_auto_wrap_policy
    from peft.tuners import PrefixEncoder, PromptEmbedding, PromptEncoder

    def lambda_policy_fn(module):
        # Wrap leaf modules whose weights require gradients (e.g. trainable PEFT parameters).
        if (
            len(list(module.named_children())) == 0
            and getattr(module, "weight", None) is not None
            and module.weight.requires_grad
        ):
            return True
        return False

    lambda_policy = functools.partial(lambda_auto_wrap_policy, lambda_fn=lambda_policy_fn)
    transformer_wrap_policy = functools.partial(
        transformer_auto_wrap_policy,
        transformer_layer_cls=(
            PrefixEncoder,
            PromptEncoder,
            PromptEmbedding,
            transformer_layer_name,
            # FullyShardedDataParallelPlugin.get_module_class_from_name(
            #     model, transformer_layer_name
            # ),
        ),
    )

    # Wrap a module if either policy matches.
    auto_wrap_policy = functools.partial(_or_policy, policies=[lambda_policy, transformer_wrap_policy])
    return auto_wrap_policy
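
A minimal usage sketch (not part of the original file), assuming a Hugging Face Llama model whose transformer block class is LlamaDecoderLayer; adapt the layer class to your architecture:

    # Usage sketch: pass the returned policy to FSDP so decoder layers and
    # trainable PEFT modules are each wrapped into their own FSDP units.
    import torch
    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
    from transformers.models.llama.modeling_llama import LlamaDecoderLayer

    wrapping_policy = fsdp_auto_wrap_policy(model, LlamaDecoderLayer)
    model = FSDP(
        model,
        auto_wrap_policy=wrapping_policy,
        device_id=torch.cuda.current_device(),
    )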