# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
from dataclasses import dataclass

from torch.distributed.fsdp import ShardingStrategy
from torch.distributed.fsdp.fully_sharded_data_parallel import StateDictType


@dataclass
class fsdp_config:
    mixed_precision: bool = True
    use_fp16: bool = False
    sharding_strategy: ShardingStrategy = ShardingStrategy.FULL_SHARD
    # SHARDED_STATE_DICT saves one file per rank and allows resuming with a different world size.
    checkpoint_type: StateDictType = StateDictType.SHARDED_STATE_DICT
    fsdp_activation_checkpointing: bool = True
    fsdp_cpu_offload: bool = False
    pure_bf16: bool = False
    optimizer: str = "AdamW"
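

# --- Illustrative usage sketch (not part of the original file) ---
# A minimal example of how these flags might be mapped onto FSDP arguments
# when wrapping a model. The helper name `get_mixed_precision_policy` is
# hypothetical; MixedPrecision, CPUOffload, and FullyShardedDataParallel
# are real PyTorch FSDP APIs.
import torch
from torch.distributed.fsdp import MixedPrecision


def get_mixed_precision_policy(cfg: fsdp_config):
    """Map the config flags onto an FSDP MixedPrecision policy (or None)."""
    if not cfg.mixed_precision:
        return None
    # use_fp16 selects float16; otherwise fall back to bfloat16.
    dtype = torch.float16 if cfg.use_fp16 else torch.bfloat16
    return MixedPrecision(param_dtype=dtype, reduce_dtype=dtype, buffer_dtype=dtype)


# Wrapping a model would then look roughly like:
#   from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, CPUOffload
#   model = FSDP(
#       model,
#       sharding_strategy=cfg.sharding_strategy,
#       mixed_precision=get_mixed_precision_policy(cfg),
#       cpu_offload=CPUOffload(offload_params=cfg.fsdp_cpu_offload),
#   )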