
Pass `use_cache=False` when training with FSDP (#165)
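
The KV cache (`use_cache`) is only needed for autoregressive generation, not training, so it is turned off whenever FSDP is enabled. A minimal sketch of the resulting pattern, assuming a `train_config` that exposes `enable_fsdp` and `model_name` as in `llama_finetuning.py`:

```python
from transformers import LlamaForCausalLM

def load_model(train_config):
    # Sketch only: disable the KV cache under FSDP training,
    # otherwise keep the model's default behavior (None).
    use_cache = False if train_config.enable_fsdp else None
    return LlamaForCausalLM.from_pretrained(
        train_config.model_name,
        use_cache=use_cache,  # unused kwargs are forwarded to the model config
    )
```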

Geeta Chauhan 1 year ago
commit 486b880964
1 changed file with 4 additions and 0 deletions

+ 4 - 0
llama_finetuning.py

@@ -66,6 +66,7 @@ def main(**kwargs):
         setup_environ_flags(rank)
 
     # Load the pre-trained model and setup its configuration
+    use_cache = False if train_config.enable_fsdp else None
     if train_config.enable_fsdp and train_config.low_cpu_fsdp:
         """
         for FSDP, we can save cpu memory by loading pretrained model on rank0 only.
@@ -83,9 +84,11 @@ def main(**kwargs):
                 train_config.model_name,
                 load_in_8bit=True if train_config.quantization else None,
                 device_map="auto" if train_config.quantization else None,
+                use_cache=use_cache,
             )
         else:
             llama_config = LlamaConfig.from_pretrained(train_config.model_name)
+            llama_config.use_cache = use_cache
             with torch.device("meta"):
                 model = LlamaForCausalLM(llama_config)
 
@@ -94,6 +97,7 @@ def main(**kwargs):
             train_config.model_name,
             load_in_8bit=True if train_config.quantization else None,
             device_map="auto" if train_config.quantization else None,
+            use_cache=use_cache,
         )
     if train_config.enable_fsdp and train_config.use_fast_kernels:
         """