
Add FSDP CPU offloading option

Howard Liberty, 1 year ago
parent
commit
cc356b6017
2 changed files with 3 additions and 2 deletions
  1. src/llama_recipes/configs/fsdp.py (+1, -2)
  2. src/llama_recipes/finetuning.py (+2, -0)

src/llama_recipes/configs/fsdp.py (+1, -2)

@@ -13,8 +13,7 @@ class fsdp_config:
     sharding_strategy: ShardingStrategy = ShardingStrategy.FULL_SHARD
     checkpoint_type: StateDictType = StateDictType.SHARDED_STATE_DICT  # alternatively can use FULL_STATE_DICT; SHARDED_STATE_DICT saves one file per rank and allows resizing the world size.
     fsdp_activation_checkpointing: bool=True
+    fsdp_cpu_offload: bool=False
     pure_bf16: bool = False
     optimizer: str= "AdamW"
     
-    
-    

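For reference, a minimal sketch of toggling the new flag, assuming fsdp_config is a dataclass that can be instantiated directly with keyword overrides (the field name fsdp_cpu_offload comes from the diff above; the rest is illustrative, not the recipe's CLI flow):

    # Illustrative sketch: enable the CPU offload option added in this commit.
    from llama_recipes.configs.fsdp import fsdp_config

    cfg = fsdp_config(fsdp_cpu_offload=True)  # request CPU offload of FSDP parameters
    print(cfg.fsdp_cpu_offload)  # True
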
src/llama_recipes/finetuning.py (+2, -0)

@@ -12,6 +12,7 @@ from peft import get_peft_model, prepare_model_for_int8_training
 from torch.distributed.fsdp import (
     FullyShardedDataParallel as FSDP,
 )
+from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload
 from torch.optim.lr_scheduler import StepLR
 from torch.utils.data import DistributedSampler
 from transformers import (
@@ -144,6 +145,7 @@ def main(**kwargs):
         model = FSDP(
             model,
             auto_wrap_policy= my_auto_wrapping_policy if train_config.use_peft else wrapping_policy,
+            cpu_offload=CPUOffload(offload_params=True) if fsdp_config.fsdp_cpu_offload else None,
             mixed_precision=mixed_precision_policy if not fsdp_config.pure_bf16 else None,
             sharding_strategy=fsdp_config.sharding_strategy,
             device_id=torch.cuda.current_device(),
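
For context, a minimal sketch of the same cpu_offload pattern in isolation. The helper wrap_with_offload and its argument names are hypothetical; only FSDP, CPUOffload(offload_params=True), and device_id mirror the diff, and a torch.distributed process group is assumed to be initialized already (e.g. via torchrun):

    import torch
    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
    from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload

    # Hypothetical helper, not part of the recipe: wrap a module with FSDP and
    # optionally keep its sharded parameters in CPU memory.
    def wrap_with_offload(model: torch.nn.Module, offload_to_cpu: bool) -> FSDP:
        return FSDP(
            model,
            cpu_offload=CPUOffload(offload_params=True) if offload_to_cpu else None,
            device_id=torch.cuda.current_device(),
        )

With offload_params=True, FSDP keeps each shard's parameters (and gradients) in CPU memory and copies them to the GPU only around their forward/backward computation, so the optimizer step runs on CPU; this trades step time for a smaller GPU memory footprint.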