
Merge branch 'main' into feature/package_distribution

Matthias Reso, 1 year ago
commit eb69e754a0
1 file changed, 4 insertions(+), 0 deletions(-)

--- a/src/llama_recipes/finetuning.py
+++ b/src/llama_recipes/finetuning.py

@@ -65,6 +65,7 @@ def main(**kwargs):
         setup_environ_flags(rank)
 
     # Load the pre-trained model and setup its configuration
+    use_cache = False if train_config.enable_fsdp else None
     if train_config.enable_fsdp and train_config.low_cpu_fsdp:
         """
         for FSDP, we can save cpu memory by loading pretrained model on rank0 only.
@@ -82,9 +83,11 @@ def main(**kwargs):
                 train_config.model_name,
                 load_in_8bit=True if train_config.quantization else None,
                 device_map="auto" if train_config.quantization else None,
+                use_cache=use_cache,
             )
         else:
             llama_config = LlamaConfig.from_pretrained(train_config.model_name)
+            llama_config.use_cache = use_cache
             with torch.device("meta"):
                 model = LlamaForCausalLM(llama_config)
 
@@ -93,6 +96,7 @@ def main(**kwargs):
             train_config.model_name,
             load_in_8bit=True if train_config.quantization else None,
             device_map="auto" if train_config.quantization else None,
+            use_cache=use_cache,
         )
     if train_config.enable_fsdp and train_config.use_fast_kernels:
         """