adding HSDP as sharding strategy for FSDP training composable with PEFT (#281)

Hamid Shojanazeri 1 year ago
parent
commit
e9985fc4ad

+ 14 - 1
README.md

@@ -174,7 +174,20 @@ If you are interested in running full parameter fine-tuning on the 70B model, yo
 
 ```bash
 
-torchrun --nnodes 1 --nproc_per_node 8 examples/finetuning.py --enable_fsdp --low_cpu_fsdp --fsdp_config.pure_bf16 --model_name /path_of_model_folder/70B --batch_size_training 1 --dist_checkpoint_root_folder model_checkpoints --dist_checkpoint_folder fine-tuned
+torchrun --nnodes 1 --nproc_per_node 8 examples/finetuning.py --enable_fsdp --low_cpu_fsdp --fsdp_config.pure_bf16 --model_name /path_of_model_folder/70B --batch_size_training 1 --dist_checkpoint_root_folder model_checkpoints --dist_checkpoint_folder fine-tuned
+
+```
+
+If you are dealing with a slower interconnect between nodes, you can use the `--hsdp` flag to reduce communication overhead.
+
+HSDP (Hybrid Sharded Data Parallel) defines a hybrid sharding strategy: the model is sharded with FSDP within each group of `sharding_group_size` GPUs (which can be the minimum number of GPUs the model fits on), and the resulting replicas are kept in sync with DDP across the `replica_group_size` replica groups.
+
+This requires setting the sharding strategy in [fsdp config](./src/llama_recipes/configs/fsdp.py) to `ShardingStrategy.HYBRID_SHARD` and specifying two additional settings: `sharding_group_size`, the number of GPUs the model fits on (i.e. the size of one model replica), and `replica_group_size`, the number of replicas, which is world_size/sharding_group_size.
+
+
+```bash
+
+torchrun --nnodes 4 --nproc_per_node 8 examples/finetuning.py --enable_fsdp --low_cpu_fsdp --fsdp_config.pure_bf16 --model_name /path_of_model_folder/70B --batch_size_training 1 --dist_checkpoint_root_folder model_checkpoints --dist_checkpoint_folder fine-tuned --hsdp --sharding_group_size n --replica_group_size world_size/n
 
 ```
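+
+For example, with 4 nodes of 8 GPUs each (world_size = 32) and a model that fits on 8 GPUs, you would pass `--sharding_group_size 8 --replica_group_size 4` (32 / 8 = 4 model replicas): each node then holds one fully sharded copy of the model, and only gradient synchronization crosses nodes.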
 

+ 3 - 0
scripts/spellcheck_conf/wordlist.txt

@@ -1178,6 +1178,9 @@ gradio
 pdf
 quantized
 streamlit
+HSDP
+ShardingStrategy
+hsdp
 prem
 Prem
 OpenAI

+ 4 - 1
src/llama_recipes/configs/fsdp.py

@@ -10,7 +10,10 @@ from torch.distributed.fsdp.fully_sharded_data_parallel import StateDictType
 class fsdp_config:
     mixed_precision: bool=True
     use_fp16: bool=False
-    sharding_strategy: ShardingStrategy = ShardingStrategy.FULL_SHARD
+    sharding_strategy: ShardingStrategy = ShardingStrategy.FULL_SHARD # HYBRID_SHARD "Full Shard within a node, DDP across nodes", SHARD_GRAD_OP "Shard only gradients and optimizer states", NO_SHARD "Similar to DDP".
+    hsdp: bool = False # Requires sharding_strategy to be HYBRID_SHARD. Extends HYBRID_SHARD by sharding the model over a custom number of GPUs (sharding_group_size) and replicating it across replica_group_size replica groups.
+    sharding_group_size: int = 0 # Requires hsdp to be set. Number of GPUs the model fits on, i.e. the size of one model replica.
+    replica_group_size: int = 0 # Requires hsdp to be set. Number of model replicas, i.e. world_size / sharding_group_size.
     checkpoint_type: StateDictType = StateDictType.SHARDED_STATE_DICT  # SHARDED_STATE_DICT saves one file per rank and allows resizing the world size when loading.
     fsdp_activation_checkpointing: bool=True
     fsdp_cpu_offload: bool=False
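As a reference for how these fields fit together, here is a minimal sketch (not part of the commit) that configures the dataclass above for HSDP by hand; the values assume 4 nodes of 8 GPUs (world_size = 32) and a model that fits on 8 GPUs, and in an actual run they come from the `--hsdp`, `--sharding_group_size`, and `--replica_group_size` CLI overrides shown in the README hunk above.

```python
# Illustrative only: manual HSDP settings for the fsdp_config dataclass defined above.
from torch.distributed.fsdp import ShardingStrategy

from llama_recipes.configs.fsdp import fsdp_config

cfg = fsdp_config()
cfg.sharding_strategy = ShardingStrategy.HYBRID_SHARD  # HSDP requires the hybrid strategy
cfg.hsdp = True
cfg.sharding_group_size = 8   # FSDP shards one model replica across 8 GPUs
cfg.replica_group_size = 4    # world_size / sharding_group_size = 32 / 8
```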

+ 11 - 2
src/llama_recipes/finetuning.py

@@ -11,7 +11,9 @@ import torch.optim as optim
 from peft import get_peft_model, prepare_model_for_int8_training
 from torch.distributed.fsdp import (
     FullyShardedDataParallel as FSDP,
+    ShardingStrategy
 )
+
 from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload
 from torch.optim.lr_scheduler import StepLR
 from transformers import (
@@ -35,6 +37,7 @@ from llama_recipes.utils.config_utils import (
 )
 from llama_recipes.utils.dataset_utils import get_preprocessed_dataset
 
+from llama_recipes.utils.fsdp_utils import hsdp_device_mesh
 from llama_recipes.utils.train_utils import (
     train,
     freeze_transformer_layers,
@@ -42,7 +45,7 @@ from llama_recipes.utils.train_utils import (
     setup_environ_flags,
     clear_gpu_cache,
     print_model_size,
-    get_policies
+    get_policies,
 )
 from accelerate.utils import is_xpu_available
 
@@ -129,7 +132,12 @@ def main(**kwargs):
         peft_config = generate_peft_config(train_config, kwargs)
         model = get_peft_model(model, peft_config)
         model.print_trainable_parameters()
-
+
+    hsdp_device_mesh_plan = None
+    if fsdp_config.hsdp and fsdp_config.sharding_strategy == ShardingStrategy.HYBRID_SHARD:
+        # Build the 2-D (replica_group_size x sharding_group_size) device mesh for hybrid sharding.
+        hsdp_device_mesh_plan = hsdp_device_mesh(replica_group_size=fsdp_config.replica_group_size, sharding_group_size=fsdp_config.sharding_group_size)
+        print("HSDP device mesh is ready")
+
     #setting up FSDP if enable_fsdp is enabled
     if train_config.enable_fsdp:
         if not train_config.use_peft and train_config.freeze_layers:
@@ -145,6 +153,7 @@ def main(**kwargs):
             cpu_offload=CPUOffload(offload_params=True) if fsdp_config.fsdp_cpu_offload else None,
             mixed_precision=mixed_precision_policy if not fsdp_config.pure_bf16 else None,
             sharding_strategy=fsdp_config.sharding_strategy,
+            device_mesh=hsdp_device_mesh_plan,
             device_id=torch.xpu.current_device() if is_xpu_available() else torch.cuda.current_device(),
             limit_all_gathers=True,
             sync_module_states=train_config.low_cpu_fsdp,
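Note that `device_mesh` defaults to `None` in `FullyShardedDataParallel`, so runs without `--hsdp` (or with a sharding strategy other than `HYBRID_SHARD`) leave `hsdp_device_mesh_plan` as `None` and the FSDP call behaves exactly as it did before this change.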

+ 1 - 1
src/llama_recipes/utils/__init__.py

@@ -3,5 +3,5 @@
 
 from llama_recipes.utils.memory_utils import MemoryTrace
 from llama_recipes.utils.dataset_utils import *
-from llama_recipes.utils.fsdp_utils import fsdp_auto_wrap_policy
+from llama_recipes.utils.fsdp_utils import fsdp_auto_wrap_policy, hsdp_device_mesh
 from llama_recipes.utils.train_utils import *

+ 57 - 1
src/llama_recipes/utils/fsdp_utils.py

@@ -1,5 +1,7 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
 # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
+import os
+from torch.distributed._tensor.device_mesh import init_device_mesh
 
 def fsdp_auto_wrap_policy(model, transformer_layer_name):
     import functools
@@ -32,4 +34,58 @@ def fsdp_auto_wrap_policy(model, transformer_layer_name):
     )
 
     auto_wrap_policy = functools.partial(_or_policy, policies=[lambda_policy, transformer_wrap_policy])
-    return auto_wrap_policy
+    return auto_wrap_policy
+
+
+def hsdp_device_mesh(replica_group_size, sharding_group_size, device=None):
+    """
+     Initializes a device mesh for use with Hybrid Sharding strategy in FSDP (HSDP) training.
+
+    This function requires explicit sizes for replica and sharding groups to accommodate models
+    whose GPU fit is unknown, providing flexibility in distributed training setups.
+    
+    Args:
+        replica_group_size (int): The size of each replica group. Must be provided to ensure
+            the model fits within the available resources.
+        sharding_group_size (int): The size of each sharding group that the model can fit. Must be provided to 
+            ensure the correct distribution of model parameters.
+        device (str, optional): The device to use (e.g., "cuda:0"). If None, defaults to "cuda"
+            with the local rank as the device index.
+
+    Returns:
+        A device mesh object compatible with FSDP.
+
+    Raises:
+        ValueError: If replica_group_size or sharding_group_size are not provided, or if the
+            world size is not evenly divisible by the sharding group size.
+        RuntimeError: If a valid device mesh cannot be created.
+
+    Usage:
+        If your model fits on 4 GPUS, and you have 3 nodes of 8 GPUs, then:
+        Sharding_Group_Size = 4
+        Replica_Groups_Size = (24 total gpus, 4 per sharding group) = 6 Replica Groups
+        >>> device_mesh = initialize_device_mesh(replica_group_size, sharding_group_size)
+        >>> sharded_model = FSDP(model, device_mesh=device_mesh, ...)
+    """
+
+    if replica_group_size is None or sharding_group_size is None:
+        raise ValueError("Both replica_group_size and sharding_group_size must be provided.")
+
+    world_size = int(os.getenv("WORLD_SIZE", "1"))
+
+    device = device or "cuda"
+
+    if world_size % sharding_group_size != 0:
+        raise ValueError(f"World size {world_size} is not evenly divisible by "
+                         f"sharding group size {sharding_group_size}.")
+
+    if (world_size // sharding_group_size) % replica_group_size != 0:
+        raise ValueError(f"The calculated number of replica groups is not evenly divisible by "
+                         f"replica_group_size {replica_group_size}.")
+
+    device_mesh = init_device_mesh(device, (replica_group_size, sharding_group_size))
+    if device_mesh is None:
+        raise RuntimeError("Failed to create a valid device mesh.")
+
+    return device_mesh
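For context, here is a minimal end-to-end sketch (not part of the commit) of how the mesh returned by `hsdp_device_mesh` plugs into FSDP. It assumes a `torchrun` launch on 4 nodes with 8 GPUs each (so `WORLD_SIZE=32` and `LOCAL_RANK` are set) and uses a toy `nn.Linear` as a stand-in for the real model; the recipe itself does this wiring in `finetuning.py` when `--hsdp` is passed.

```python
# Minimal sketch (illustrative, not the recipe's entry point). Assumes launch with:
#   torchrun --nnodes 4 --nproc_per_node 8 this_script.py
# so that WORLD_SIZE=32 and LOCAL_RANK are set by torchrun.
import os

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, ShardingStrategy

from llama_recipes.utils.fsdp_utils import hsdp_device_mesh

dist.init_process_group("nccl")
torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))

# 2-D mesh: dim 0 = 4 replicas (DDP), dim 1 = 8-way shard group (FSDP).
# With torchrun's consecutive rank numbering, ranks 0-7 (node 0) form shard group 0,
# ranks 8-15 form shard group 1, and so on, so the heavy all-gather/reduce-scatter
# traffic stays inside a node and only gradient all-reduce crosses nodes.
device_mesh = hsdp_device_mesh(replica_group_size=4, sharding_group_size=8)

model = nn.Linear(1024, 1024).cuda()  # toy stand-in for the Llama model
sharded_model = FSDP(
    model,
    device_mesh=device_mesh,
    sharding_strategy=ShardingStrategy.HYBRID_SHARD,
)
```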