
save cpu mem by leveraging FSDP rank0 broadcasting (#77)

Geeta Chauhan 1 year ago
commit 205e5a4b81
6 changed files with 97 additions and 69 deletions
  1. README.md (+10, -0)
  2. configs/training.py (+2, -1)
  3. docs/multi_gpu.md (+10, -0)
  4. llama_finetuning.py (+72, -65)
  5. model_checkpointing/checkpoint_handler.py (+2, -2)
  6. utils/train_utils.py (+1, -1)

+ 10 - 0
README.md

@@ -125,6 +125,16 @@ torchrun --nnodes 1 --nproc_per_node 8  llama_finetuning.py --enable_fsdp --mode
 
 ```
 
+### Fine-tuning using FSDP on 70B Model
+
+If you want to run full-parameter fine-tuning on the 70B model, enable `low_cpu_fsdp` mode as in the command below. This option loads the model on rank 0 only before moving it to the devices and constructing FSDP, which dramatically reduces CPU memory when loading large models: on an 8-GPU node, every rank would otherwise hold its own full copy of the 70B weights (roughly 70B params * 4 bytes * 8 ranks, i.e. 2+ TB of CPU memory), whereas rank-0-only loading brings this down to about 280 GB. This has been tested with `BF16` on 16x A100 80GB GPUs.
+
+```bash
+
+torchrun --nnodes 1 --nproc_per_node 8 llama_finetuning.py --enable_fsdp --low_cpu_fsdp --pure_bf16 --model_name /path_of_model_folder/70B --batch_size_training 1 --micro_batch_size 1 --dist_checkpoint_root_folder model_checkpoints --dist_checkpoint_folder fine-tuned
+
+```
+
 ### Multi GPU Multi Node:
 
 ```bash

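For context on what the new `--low_cpu_fsdp` flag does under the hood, here is a minimal, hedged sketch of the pattern: only rank 0 loads the pretrained weights on CPU, every other rank builds the model on the meta device, and FSDP's `sync_module_states` broadcasts rank 0's weights to the rest. This is an illustrative standalone script, not the project's exact code; the model path and the plain `LlamaDecoderLayer` wrap policy are placeholders.

```python
# sketch_low_cpu_fsdp.py -- illustrative only; launch with:
#   torchrun --nnodes 1 --nproc_per_node 8 sketch_low_cpu_fsdp.py
import functools
import os

import torch
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
from transformers import LlamaConfig, LlamaForCausalLM
from transformers.models.llama.modeling_llama import LlamaDecoderLayer

MODEL_PATH = "/path_of_model_folder/70B"  # placeholder

def main():
    dist.init_process_group("nccl")
    rank = dist.get_rank()
    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))

    if rank == 0:
        # Only rank 0 pays the CPU cost of materializing the full checkpoint.
        model = LlamaForCausalLM.from_pretrained(MODEL_PATH)
    else:
        # Every other rank builds the architecture on the meta device,
        # which allocates no real parameter storage.
        config = LlamaConfig.from_pretrained(MODEL_PATH)
        with torch.device("meta"):
            model = LlamaForCausalLM(config)

    wrap_policy = functools.partial(
        transformer_auto_wrap_policy, transformer_layer_cls={LlamaDecoderLayer}
    )
    model = FSDP(
        model,
        auto_wrap_policy=wrap_policy,
        device_id=torch.cuda.current_device(),
        # Broadcast rank 0's loaded weights into every other rank's shards.
        sync_module_states=True,
        # Before the broadcast, give the meta tensors on non-zero ranks
        # real (uninitialized) storage on the GPU.
        param_init_fn=(
            (lambda m: m.to_empty(device=torch.device("cuda"), recurse=False))
            if rank != 0
            else None
        ),
    )
    dist.barrier()
    dist.destroy_process_group()

if __name__ == "__main__":
    main()
```
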
+ 2 - 1
configs/training.py

@@ -7,7 +7,8 @@ from typing import ClassVar
 @dataclass
 class train_config:
     model_name: str="PATH/to/LLAMA/7B"
-    enable_fsdp: bool= False 
+    enable_fsdp: bool=False
+    low_cpu_fsdp: bool=False
     run_validation: bool=True
     batch_size_training: int=4
     num_epochs: int=3

+ 10 - 0
docs/multi_gpu.md

@@ -62,6 +62,16 @@ torchrun --nnodes 1 --nproc_per_node 8  llama_finetuning.py --enable_fsdp --mode
 
 ```
 
+### Fine-tuning using FSDP on 70B Model
+
+If you want to run full-parameter fine-tuning on the 70B model, enable `low_cpu_fsdp` mode as in the command below. This option loads the model on rank 0 only before moving it to the devices and constructing FSDP, which dramatically reduces CPU memory when loading large models: on an 8-GPU node, every rank would otherwise hold its own full copy of the 70B weights (roughly 70B params * 4 bytes * 8 ranks, i.e. 2+ TB of CPU memory), whereas rank-0-only loading brings this down to about 280 GB. This has been tested with `BF16` on 16x A100 80GB GPUs.
+
+```bash
+
+torchrun --nnodes 1 --nproc_per_node 8 llama_finetuning.py --enable_fsdp --low_cpu_fsdp --pure_bf16 --model_name /path_of_model_folder/70B --batch_size_training 1 --micro_batch_size 1 --dist_checkpoint_root_folder model_checkpoints --dist_checkpoint_folder fine-tuned
+
+```
+
 **Multi GPU multi node**:
 
 Here we use a slurm script to schedule a job with slurm over multiple nodes.

+ 72 - 65
llama_finetuning.py

@@ -2,68 +2,47 @@
 # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
 
 import os
-import sys
-from typing import List, Union
 
 import fire
 import torch
-import transformers
-from datasets import load_dataset
-import os.path as osp
-from tqdm import tqdm
-
-# Unused imports removed
-from utils import fsdp_auto_wrap_policy
+import torch.distributed as dist
+import torch.optim as optim
+from peft import get_peft_model, prepare_model_for_int8_training
+from pkg_resources import packaging
+from torch.distributed.fsdp import (
+    FullyShardedDataParallel as FSDP,
+)
+from torch.optim.lr_scheduler import StepLR
+from torch.utils.data import DistributedSampler
 from transformers import (
     LlamaForCausalLM,
     LlamaTokenizer,
-    AutoModelForCausalLM,
-    AutoModelForSeq2SeqLM,
-    AutoTokenizer,
+    LlamaConfig,
     default_data_collator,
-    BitsAndBytesConfig
-)
-import torch.distributed as dist
-# Unused imports removed
-from utils.train_utils import (
-    set_tokenizer_params,
-    train,
-    evaluation,
-    freeze_transformer_layers,
-    check_frozen_layers_peft_model,
-    setup,
-    setup_environ_flags,
-    cleanup,
-    clear_gpu_cache,
-    get_parameter_dtypes,
-    print_model_size,
-    get_policies  
 )
+from transformers.models.llama.modeling_llama import LlamaDecoderLayer
 
-from utils.dataset_utils import get_preprocessed_dataset
+import policies
+from configs import fsdp_config, train_config
+from policies import AnyPrecisionAdamW
 
+from utils import fsdp_auto_wrap_policy
 from utils.config_utils import (
     update_config,
     generate_peft_config,
     generate_dataset_config,
 )
-from peft import get_peft_model, TaskType, prepare_model_for_int8_training
-import configs
-from torch.distributed.fsdp import (
-    FullyShardedDataParallel as FSDP,
-    MixedPrecision,
+from utils.dataset_utils import get_preprocessed_dataset
+
+from utils.train_utils import (
+    train,
+    freeze_transformer_layers,
+    setup,
+    setup_environ_flags,
+    clear_gpu_cache,
+    print_model_size,
+    get_policies
 )
-from torch.utils.data import DistributedSampler
-import policies
-from policies import AnyPrecisionAdamW
-from configs import fsdp_config, train_config
-import torch.optim as optim
-from torch.optim.lr_scheduler import StepLR
-from pkg_resources import packaging
-import torch
-import torch.cuda.nccl as nccl
-import torch.distributed as dist
-from transformers.models.llama.modeling_llama import LlamaDecoderLayer
 
 
 def main(**kwargs):
@@ -82,18 +61,43 @@ def main(**kwargs):
         world_size = int(os.environ["WORLD_SIZE"])
 
     if torch.distributed.is_initialized():
-        torch.cuda.set_device(rank)
+        torch.cuda.set_device(local_rank)
+        clear_gpu_cache(local_rank)
         setup_environ_flags(rank)
-    
+
     # Calculate gradient accumulation steps
     gradient_accumulation_steps = train_config.batch_size_training // train_config.micro_batch_size
-     
+
     # Load the pre-trained model and setup its configuration
-    model = LlamaForCausalLM.from_pretrained(
-        train_config.model_name,
-        load_in_8bit=True if train_config.quantization else None,
-        device_map="auto" if train_config.quantization else None,
-    )
+    if train_config.enable_fsdp and train_config.low_cpu_fsdp:
+        """
+        For FSDP, we can save CPU memory by loading the pretrained model on rank 0 only.
+        This avoids CPU OOM when loading large models like Llama 70B, where the full
+        model would consume 2+ TB of CPU memory (70B params * 4 bytes * 8 ranks). It adds
+        some communication overhead and currently requires the latest PyTorch nightly.
+        """
+        v = packaging.version.parse(torch.__version__)
+        verify_latest_nightly = v.is_devrelease and v.dev >= 20230701
+        if not verify_latest_nightly:
+            raise Exception("latest pytorch nightly build is required to run with low_cpu_fsdp config, "
+                            "please install latest nightly.")
+        if rank == 0:
+            model = LlamaForCausalLM.from_pretrained(
+                train_config.model_name,
+                load_in_8bit=True if train_config.quantization else None,
+                device_map="auto" if train_config.quantization else None,
+            )
+        else:
+            llama_config = LlamaConfig.from_pretrained(train_config.model_name)
+            with torch.device("meta"):
+                model = LlamaForCausalLM(llama_config)
+
+    else:
+        model = LlamaForCausalLM.from_pretrained(
+            train_config.model_name,
+            load_in_8bit=True if train_config.quantization else None,
+            device_map="auto" if train_config.quantization else None,
+        )
     if train_config.enable_fsdp and train_config.use_fast_kernels:
         """
         For FSDP and FSDP+PEFT, setting 'use_fast_kernels' will enable
@@ -106,11 +110,11 @@ def main(**kwargs):
         except ImportError:
             print("Module 'optimum' not found. Please install 'optimum' it before proceeding.")
     print_model_size(model, train_config, rank if train_config.enable_fsdp else 0)
-    
+
     # Prepare the model for int8 training if quantization is enabled
     if train_config.quantization:
         model = prepare_model_for_int8_training(model)
-        
+
     # Convert the model to bfloat16 if fsdp and pure_bf16 is enabled
     if train_config.enable_fsdp and fsdp_config.pure_bf16:
         model.to(torch.bfloat16)
@@ -119,7 +123,7 @@ def main(**kwargs):
     tokenizer = LlamaTokenizer.from_pretrained(train_config.model_name)
     tokenizer.add_special_tokens(
             {
-            
+
                 "pad_token": "<PAD>",
             }
         )
@@ -127,16 +131,16 @@ def main(**kwargs):
         peft_config = generate_peft_config(train_config, kwargs)
         model = get_peft_model(model, peft_config)
         model.print_trainable_parameters()
-    
+
     #setting up FSDP if enable_fsdp is enabled
     if train_config.enable_fsdp:
         if not train_config.use_peft and train_config.freeze_layers:
-            
+
             freeze_transformer_layers(train_config.num_freeze_layers)
 
         mixed_precision_policy, wrapping_policy = get_policies(fsdp_config, rank)
         my_auto_wrapping_policy = fsdp_auto_wrap_policy(model, LlamaDecoderLayer)
-   
+
         model = FSDP(
             model,
             auto_wrap_policy= my_auto_wrapping_policy if train_config.use_peft else wrapping_policy,
@@ -144,6 +148,9 @@ def main(**kwargs):
             sharding_strategy=fsdp_config.sharding_strategy,
             device_id=torch.cuda.current_device(),
             limit_all_gathers=True,
+            sync_module_states=train_config.low_cpu_fsdp,
+            param_init_fn=lambda module: module.to_empty(device=torch.device("cuda"), recurse=False)
+            if train_config.low_cpu_fsdp and rank != 0 else None,
         )
         if fsdp_config.fsdp_activation_checkpointing:
             policies.apply_fsdp_checkpointing(model)
@@ -151,14 +158,14 @@ def main(**kwargs):
         model.to("cuda")
 
     dataset_config = generate_dataset_config(train_config, kwargs)
-    
+
      # Load and preprocess the dataset for training and validation
     dataset_train = get_preprocessed_dataset(
         tokenizer,
         dataset_config,
         split="train",
     )
-    
+
     if not train_config.enable_fsdp or rank == 0:
         print(f"--> Training Set Length = {len(dataset_train)}")
 
@@ -185,7 +192,7 @@ def main(**kwargs):
                 rank=dist.get_rank(),
                 num_replicas=dist.get_world_size(),
             )
-        
+
     # Create DataLoaders for the training and validation dataset
     train_dataloader = torch.utils.data.DataLoader(
         dataset_train,
@@ -207,7 +214,7 @@ def main(**kwargs):
             drop_last=True,
             collate_fn=default_data_collator,
         )
-        
+
     # Initialize the optimizer and learning rate scheduler
     if fsdp_config.pure_bf16 and fsdp_config.optimizer == "anyprecision":
         optimizer = AnyPrecisionAdamW(
@@ -229,7 +236,7 @@ def main(**kwargs):
     results = train(
         model,
         train_dataloader,
-        eval_dataloader, 
+        eval_dataloader,
         tokenizer,
         optimizer,
         scheduler,

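One small but easy-to-miss fix in the llama_finetuning.py diff above is `torch.cuda.set_device(rank)` becoming `torch.cuda.set_device(local_rank)`. Under torchrun, `RANK` is the global process index across all nodes, while `LOCAL_RANK` is the GPU index within one node, so only `LOCAL_RANK` is a valid CUDA device ordinal once you go multi-node. A tiny hypothetical script to see the difference:

```python
# rank_vs_local_rank.py -- hypothetical demo; launch across two nodes with:
#   torchrun --nnodes 2 --nproc_per_node 8 rank_vs_local_rank.py
import os

import torch

rank = int(os.environ["RANK"])              # 0..15 over two 8-GPU nodes
local_rank = int(os.environ["LOCAL_RANK"])  # 0..7 on each node
world_size = int(os.environ["WORLD_SIZE"])  # 16

# On the second node, rank runs 8..15, which is not a valid GPU index for an
# 8-GPU box; local_rank always maps to a real device on the current node.
torch.cuda.set_device(local_rank)
print(f"rank {rank}/{world_size} uses cuda:{torch.cuda.current_device()} (local_rank {local_rank})")
```
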
+ 2 - 2
model_checkpointing/checkpoint_handler.py

@@ -65,7 +65,7 @@ def load_model_sharded(model, rank, cfg):
     reader = FileSystemReader(load_dir)
 
     with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT):
-        checkpoint = model.state_dict()
+        checkpoint = {"model": model.state_dict()}
         if rank == 0:
             ck = checkpoint.keys()
             print(f" checkpoint key len = {len(ck)} and \n keys =  {ck}")
@@ -78,7 +78,7 @@ def load_model_sharded(model, rank, cfg):
             print(f"checkpoint after load_state_dict()")
             ck = checkpoint.keys()
             print(f" checkpoint key len = {len(ck)} and \n keys =  {ck}")
-        model.load_state_dict(checkpoint)
+        model.load_state_dict(checkpoint["model"])
     if rank == 0:
         print(f"Sharded state checkpoint loaded from {load_dir}")
 

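On the checkpoint_handler.py change above: `torch.distributed.checkpoint` loads in place into whatever dictionary you hand it, so the load side has to rebuild the same nesting the writer used, here a state dict stored under a "model" key. A hedged sketch of that round trip follows; the `SAVE_DIR` value is a placeholder, and the distributed-checkpoint import path has moved between PyTorch releases (older versions expose it as `torch.distributed._shard.checkpoint`).

```python
# Sketch of a sharded FSDP checkpoint round trip where the state dict is
# nested under a "model" key on both the save and load paths.
import torch.distributed.checkpoint as dist_cp
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, StateDictType

SAVE_DIR = "model_checkpoints/fine-tuned"  # placeholder path

def save_sharded(model: FSDP) -> None:
    with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT):
        state_dict = {"model": model.state_dict()}
        dist_cp.save_state_dict(
            state_dict=state_dict,
            storage_writer=dist_cp.FileSystemWriter(SAVE_DIR),
        )

def load_sharded(model: FSDP) -> None:
    with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT):
        # The keys must match what the writer stored ("model" here);
        # load_state_dict fills this dict in place from the on-disk shards.
        state_dict = {"model": model.state_dict()}
        dist_cp.load_state_dict(
            state_dict=state_dict,
            storage_reader=dist_cp.FileSystemReader(SAVE_DIR),
        )
        model.load_state_dict(state_dict["model"])
```
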
+ 1 - 1
utils/train_utils.py

@@ -141,7 +141,7 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
         lr_scheduler.step()
           
         if train_config.run_validation:
-            eval_ppl, eval_epoch_loss = evaluation(model, train_config, eval_dataloader, rank, tokenizer)   
+            eval_ppl, eval_epoch_loss = evaluation(model, train_config, eval_dataloader, local_rank, tokenizer)
             checkpoint_start_time = time.perf_counter()
             if train_config.save_model and eval_epoch_loss < best_val_loss:
                 if train_config.enable_fsdp: