# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.

import functools

from transformers.models.llama.modeling_llama import LlamaDecoderLayer
from torch.distributed.fsdp.wrap import (
    transformer_auto_wrap_policy,
    size_based_auto_wrap_policy,
)


def get_size_policy(min_params=1e8):
    num_wrap_policy = functools.partial(
        size_based_auto_wrap_policy, min_num_params=min_params
    )
    return num_wrap_policy
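
# Note: the partial returned above matches the callable signature FSDP expects
# for `auto_wrap_policy`; a module is wrapped into its own FSDP unit once its
# unwrapped parameter count reaches `min_params` (default 1e8, i.e. ~100M).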


def get_llama_wrapper():
    """We register our main layer class (LlamaDecoderLayer) with FSDP's
    transformer wrapping policy. This ensures embedding layers stay in the
    root FSDP unit for shared access and that FSDP units map to transformer layers.
    """

    # ==== use new transformer wrapper
    llama_auto_wrap_policy = functools.partial(
        transformer_auto_wrap_policy,
        transformer_layer_cls={
            LlamaDecoderLayer,
        },
    )
    return llama_auto_wrap_policy
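
# Example usage (a minimal sketch; the FSDP import and the `model` object,
# e.g. a Hugging Face LlamaForCausalLM, are assumptions and not part of this
# module):
#
#   from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
#
#   wrapping_policy = get_llama_wrapper()  # or get_size_policy()
#   sharded_model = FSDP(model, auto_wrap_policy=wrapping_policy)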