fsdp.py

# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.

from dataclasses import dataclass

from torch.distributed.fsdp import ShardingStrategy
from torch.distributed.fsdp.fully_sharded_data_parallel import StateDictType


@dataclass
class fsdp_config:
    mixed_precision: bool = True
    use_fp16: bool = False
    sharding_strategy: ShardingStrategy = ShardingStrategy.FULL_SHARD  # Alternatives: HYBRID_SHARD (full shard within a node, DDP across nodes), SHARD_GRAD_OP (shard only gradients and optimizer states), NO_SHARD (similar to DDP).
    hsdp: bool = False  # Requires sharding_strategy to be HYBRID_SHARD. Extends HYBRID_SHARD by sharding the model over a custom number of GPUs (sharding_group_size) and replicating it across sharding groups.
    sharding_group_size: int = 0  # Requires hsdp to be set. Number of GPUs the model is sharded across, i.e. the GPUs needed to hold one replica of the model.
    replica_group_size: int = 0  # Requires hsdp to be set. Number of model replicas, i.e. world_size / sharding_group_size.
    checkpoint_type: StateDictType = StateDictType.SHARDED_STATE_DICT  # SHARDED_STATE_DICT saves one file per rank and allows resizing the world size on resume; alternatively, FULL_STATE_DICT can be used.
    fsdp_activation_checkpointing: bool = True
    fsdp_cpu_offload: bool = False
    pure_bf16: bool = False
    optimizer: str = "AdamW"