Pārlūkot izejas kodu

minor code optimization

lchu 1 gadu atpakaļ
vecāks
revīzija
3d1e9cd58c
1 mainītis faili ar 1 papildinājumiem un 1 dzēšanām
  1. 1 1
      llama_finetuning.py

+ 1 - 1
llama_finetuning.py

@@ -137,7 +137,7 @@ def main(**kwargs):
             sharding_strategy=fsdp_config.sharding_strategy,
             device_id=torch.cuda.current_device(),
             limit_all_gathers=True,
-            sync_module_states=True if train_config.low_cpu_fsdp else False,
+            sync_module_states=train_config.low_cpu_fsdp,
             param_init_fn=lambda module: module.to_empty(device=torch.device("cuda"), recurse=False)
             if train_config.low_cpu_fsdp and rank != 0 else None,
         )