
save cpu mem by leveraging FSDP rank0 broadcasting

lchu committed 1 year ago
parent commit d8a81bb531
1 changed file with 41 additions and 6 deletions

llama_finetuning.py (+41 -6)

@@ -17,6 +17,7 @@ from utils import fsdp_auto_wrap_policy
 from transformers import (
     LlamaForCausalLM,
     LlamaTokenizer,
+    LlamaConfig,
     AutoModelForCausalLM,
     AutoModelForSeq2SeqLM,
     AutoTokenizer,
@@ -41,6 +42,8 @@ from utils.train_utils import (
     get_policies  
 )
 
+from accelerate import init_empty_weights
+
 from utils.dataset_utils import get_preprocessed_dataset
 
 from utils.config_utils import (
@@ -62,8 +65,10 @@ import torch.optim as optim
 from torch.optim.lr_scheduler import StepLR
 from pkg_resources import packaging
 import torch
+import torch.nn as nn
 import torch.cuda.nccl as nccl
 import torch.distributed as dist
+from torch.distributed.fsdp._common_utils import _is_fsdp_flattened
 from transformers.models.llama.modeling_llama import LlamaDecoderLayer
 
 
@@ -90,11 +95,26 @@ def main(**kwargs):
     gradient_accumulation_steps = train_config.batch_size_training // train_config.micro_batch_size
      
     # Load the pre-trained model and setup its configuration
-    model = LlamaForCausalLM.from_pretrained(
-        train_config.model_name,
-        load_in_8bit=True if train_config.quantization else None,
-        device_map="auto" if train_config.quantization else None,
-    )
+    if train_config.enable_fsdp:
+        # for FSDP, we save cpu memory by loading the pretrained model on rank 0 only.
+        # this avoids cpu oom when loading large models like llama 70B, in which case
+        # the model alone would consume 2+ TB of cpu mem (70B params * 4 bytes * 8 ranks)
+        if rank == 0:
+            model = LlamaForCausalLM.from_pretrained(
+                train_config.model_name,
+                load_in_8bit=True if train_config.quantization else None,
+                device_map="auto" if train_config.quantization else None,
+            )
+        else:
+            llama_config = LlamaConfig.from_pretrained(train_config.model_name)
+            with init_empty_weights():
+                model = LlamaForCausalLM(llama_config)
+    else:
+        model = LlamaForCausalLM.from_pretrained(
+            train_config.model_name,
+            load_in_8bit=True if train_config.quantization else None,
+            device_map="auto" if train_config.quantization else None,
+        )
     
     print_model_size(model, train_config, rank if train_config.enable_fsdp else 0)
     
@@ -127,7 +147,20 @@ def main(**kwargs):
 
         mixed_precision_policy, wrapping_policy = get_policies(fsdp_config, rank)
         my_auto_wrapping_policy = fsdp_auto_wrap_policy(model, LlamaDecoderLayer)
-   
+
+        # given the fast evolving PRs around meta device init, I am not sure
+        # what the best param_init_fn is atm, maybe we can switch to simple to_empty.
+        def _param_init_fn(module: nn.Module):
+            torch.manual_seed(0)
+            for submodule in module.modules():
+                for param_name, param in submodule.named_parameters(recurse=False):
+                    if not _is_fsdp_flattened(param) and param.is_meta:
+                        materialized_param = nn.Parameter(
+                            torch.empty_like(param, device=torch.device("cuda"))
+                        )
+                        nn.init.uniform_(materialized_param)
+                        setattr(submodule, param_name, materialized_param)
+
         model = FSDP(
             model,
             auto_wrap_policy= my_auto_wrapping_policy if train_config.use_peft else wrapping_policy,
@@ -135,6 +168,8 @@ def main(**kwargs):
             sharding_strategy=fsdp_config.sharding_strategy,
             device_id=torch.cuda.current_device(),
             limit_all_gathers=True,
+            sync_module_states=True,
+            param_init_fn=None if rank == 0 else _param_init_fn,
         )
         if fsdp_config.fsdp_activation_checkpointing:
             policies.apply_fsdp_checkpointing(model)
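
A note on the rank-0 loading change: the pattern can be isolated into the minimal sketch below. It assumes torch.distributed is already initialized and that model_name points at a Hugging Face Llama checkpoint; the helper name load_llama_rank0 is illustrative only and not part of this commit.

    import torch.distributed as dist
    from accelerate import init_empty_weights
    from transformers import LlamaConfig, LlamaForCausalLM

    def load_llama_rank0(model_name: str):
        # Rank 0 materializes the full checkpoint in CPU memory.
        if dist.get_rank() == 0:
            return LlamaForCausalLM.from_pretrained(model_name)
        # Every other rank builds the same architecture on the meta device,
        # so its parameters take no storage until FSDP materializes them.
        config = LlamaConfig.from_pretrained(model_name)
        with init_empty_weights():
            return LlamaForCausalLM(config)

The meta-device copies are later materialized by the param_init_fn passed to FSDP and then overwritten with rank 0's weights, because sync_module_states=True broadcasts parameters and buffers from rank 0 during wrapping.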
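
On the "maybe we can switch to simple to_empty" comment: one possible simplification of _param_init_fn is sketched below. This is an untested sketch, not part of the commit, and it assumes a PyTorch version in which nn.Module.to_empty accepts the recurse keyword. to_empty allocates device storage without initializing values, which should be sufficient here because sync_module_states=True replaces every rank's parameters with rank 0's weights right after materialization.

    import torch
    import torch.nn as nn

    def simple_param_init_fn(module: nn.Module):
        # Allocate uninitialized GPU storage for this module's own meta parameters;
        # recurse=False leaves child modules to their own param_init_fn calls.
        module.to_empty(device=torch.device("cuda"), recurse=False)

It would be wired up the same way as in the diff, e.g. param_init_fn=None if rank == 0 else simple_param_init_fn.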