```python
import torch.distributed as dist
import torch.nn as nn
import torch

from transformers.models.llama.modeling_llama import LlamaDecoderLayer

from torch.distributed.fsdp.fully_sharded_data_parallel import (
    FullyShardedDataParallel as FSDP,
    CPUOffload,
    BackwardPrefetch,
    MixedPrecision,
)
from torch.distributed.fsdp.wrap import (
    transformer_auto_wrap_policy,
    size_based_auto_wrap_policy,
    enable_wrap,
    wrap,
)

import functools
from typing import Type


def get_size_policy(min_params=1e8):
    """Wrap any module whose parameter count exceeds min_params."""
    num_wrap_policy = functools.partial(
        size_based_auto_wrap_policy, min_num_params=min_params
    )
    return num_wrap_policy


def get_llama_wrapper():
    """Register our main layer class and use the FSDP transformer wrapping policy.

    This ensures embedding layers stay in the root FSDP unit for shared access
    and that FSDP units map to the transformer decoder layers.
    """
    llama_auto_wrap_policy = functools.partial(
        transformer_auto_wrap_policy,
        transformer_layer_cls={
            LlamaDecoderLayer,
        },
    )
    return llama_auto_wrap_policy
```
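For context, here is a minimal usage sketch showing how either policy might be handed to the FSDP constructor. It assumes a distributed process group has already been initialized (e.g. via torchrun) and that a Llama model is loaded through Hugging Face transformers; the checkpoint path and the bfloat16 `MixedPrecision` settings are illustrative assumptions, not part of the file above.

```python
# Minimal sketch: sharding a Hugging Face Llama model with the policies above.
# Assumes torch.distributed is already initialized (e.g. launched via torchrun)
# and a GPU is available; the checkpoint path is a placeholder.
from transformers import LlamaForCausalLM

model = LlamaForCausalLM.from_pretrained("path/to/llama-checkpoint")

sharded_model = FSDP(
    model,
    auto_wrap_policy=get_llama_wrapper(),  # or get_size_policy() for size-based wrapping
    mixed_precision=MixedPrecision(        # illustrative bf16 settings, not from the original file
        param_dtype=torch.bfloat16,
        reduce_dtype=torch.bfloat16,
        buffer_dtype=torch.bfloat16,
    ),
    backward_prefetch=BackwardPrefetch.BACKWARD_PRE,
    device_id=torch.cuda.current_device(),
)
```

With the transformer policy, each `LlamaDecoderLayer` becomes its own FSDP unit, while the embedding and output layers remain in the root unit; the size-based policy instead groups modules purely by parameter count.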