```python
def fsdp_auto_wrap_policy(model, transformer_layer_name):
    import functools

    from torch.distributed.fsdp.wrap import _or_policy, lambda_auto_wrap_policy, transformer_auto_wrap_policy

    from peft.tuners import PrefixEncoder, PromptEmbedding, PromptEncoder

    def lambda_policy_fn(module):
        # Wrap leaf modules whose weight is trainable (e.g. PEFT adapter layers),
        # so their parameters get their own FSDP unit.
        if (
            len(list(module.named_children())) == 0
            and getattr(module, "weight", None) is not None
            and module.weight.requires_grad
        ):
            return True
        return False

    lambda_policy = functools.partial(lambda_auto_wrap_policy, lambda_fn=lambda_policy_fn)
    # Also wrap the PEFT prompt/prefix encoders and the model's transformer block class.
    transformer_wrap_policy = functools.partial(
        transformer_auto_wrap_policy,
        transformer_layer_cls=(
            PrefixEncoder,
            PromptEncoder,
            PromptEmbedding,
            transformer_layer_name,
        ),
    )

    # A module is wrapped if either policy matches it.
    auto_wrap_policy = functools.partial(_or_policy, policies=[lambda_policy, transformer_wrap_policy])
    return auto_wrap_policy
```
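The returned callable is meant to be passed to FSDP as its `auto_wrap_policy`. Below is a minimal usage sketch; everything beyond that call is an assumption for illustration: the LLaMA checkpoint, the LoRA settings, and the `torchrun` launch are placeholders, and `LlamaDecoderLayer` stands in for whatever transformer block class your model uses.

```python
# Minimal usage sketch (illustrative, assumes a torchrun launch): feed the policy
# returned by fsdp_auto_wrap_policy into FSDP so adapter weights and decoder
# blocks each land in their own FSDP unit.
import os

import torch
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from transformers import AutoModelForCausalLM
from transformers.models.llama.modeling_llama import LlamaDecoderLayer
from peft import LoraConfig, get_peft_model

dist.init_process_group("nccl")
torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))

# Example model and adapter config; swap in your own checkpoint and PEFT setup.
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")
model = get_peft_model(
    model,
    LoraConfig(r=8, lora_alpha=16, target_modules=["q_proj", "v_proj"], task_type="CAUSAL_LM"),
)

# The transformer block class tells transformer_auto_wrap_policy which layers to shard.
wrap_policy = fsdp_auto_wrap_policy(model, LlamaDecoderLayer)
model = FSDP(
    model,
    auto_wrap_policy=wrap_policy,
    device_id=torch.cuda.current_device(),
)
```

Because the helper takes the block class itself rather than a hard-coded name, the same function can be reused for any architecture whose transformer layer class you can import.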