# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.

import os

from torch.distributed._tensor.device_mesh import init_device_mesh


def fsdp_auto_wrap_policy(model, transformer_layer_name):
    """Builds an FSDP auto-wrap policy that wraps the transformer layers as well as
    any standalone trainable leaf modules (e.g. PEFT prompt/prefix encoders)."""
    import functools

    from peft.tuners import PrefixEncoder, PromptEmbedding, PromptEncoder
    from torch.distributed.fsdp.wrap import (
        _or_policy,
        lambda_auto_wrap_policy,
        transformer_auto_wrap_policy,
    )

    def lambda_policy_fn(module):
        # Wrap leaf modules whose weights require gradients (e.g. PEFT adapter layers).
        if (
            len(list(module.named_children())) == 0
            and getattr(module, "weight", None) is not None
            and module.weight.requires_grad
        ):
            return True
        return False

    lambda_policy = functools.partial(lambda_auto_wrap_policy, lambda_fn=lambda_policy_fn)
    transformer_wrap_policy = functools.partial(
        transformer_auto_wrap_policy,
        transformer_layer_cls=(
            PrefixEncoder,
            PromptEncoder,
            PromptEmbedding,
            transformer_layer_name,
            # FullyShardedDataParallelPlugin.get_module_class_from_name(
            #     model, transformer_layer_name
            # ),
        ),
    )

    # A module is wrapped if either policy matches.
    auto_wrap_policy = functools.partial(_or_policy, policies=[lambda_policy, transformer_wrap_policy])
    return auto_wrap_policy
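
# Illustrative sketch (not part of the original module): how the returned policy is
# typically passed to FSDP. `LlamaDecoderLayer` is an assumed transformer layer class
# from Hugging Face Transformers; substitute the layer class of your own model.
#
#     from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
#     from transformers.models.llama.modeling_llama import LlamaDecoderLayer
#
#     wrapping_policy = fsdp_auto_wrap_policy(model, LlamaDecoderLayer)
#     model = FSDP(model, auto_wrap_policy=wrapping_policy)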


def hsdp_device_mesh(replica_group_size, sharding_group_size, device=None):
    """
    Initializes a device mesh for use with FSDP's Hybrid Sharding strategy (HSDP).
    This function requires explicit sizes for the replica and sharding groups to
    accommodate models whose GPU fit is unknown, providing flexibility in distributed
    training setups.

    Args:
        replica_group_size (int): The size of each replica group. Must be provided to ensure
            the model fits within the available resources.
        sharding_group_size (int): The size of each sharding group that the model can fit. Must be
            provided to ensure the correct distribution of model parameters.
        device (str, optional): The device type to use (e.g., "cuda"). If None, defaults to "cuda".

    Returns:
        A device mesh object compatible with FSDP.

    Raises:
        ValueError: If replica_group_size or sharding_group_size is not provided, or if the
            world size is not evenly divisible by the sharding group size.
        RuntimeError: If a valid device mesh cannot be created.

    Usage:
        If your model fits on 4 GPUs and you have 3 nodes of 8 GPUs, then:
        sharding_group_size = 4
        replica_group_size = 24 total GPUs / 4 GPUs per sharding group = 6 replica groups

        >>> device_mesh = hsdp_device_mesh(replica_group_size, sharding_group_size)
        >>> sharded_model = FSDP(model, device_mesh=device_mesh, ...)
    """
    if replica_group_size is None or sharding_group_size is None:
        raise ValueError("Both replica_group_size and sharding_group_size must be provided.")

    world_size = int(os.getenv("WORLD_SIZE", "1"))
    device = device or "cuda"

    if world_size % sharding_group_size != 0:
        raise ValueError(f"World size {world_size} is not evenly divisible by "
                         f"sharding group size {sharding_group_size}.")

    if (world_size // sharding_group_size) % replica_group_size != 0:
        raise ValueError(f"The calculated number of sharding groups ({world_size // sharding_group_size}) "
                         f"is not evenly divisible by replica_group_size {replica_group_size}.")

    device_mesh = init_device_mesh(device, (replica_group_size, sharding_group_size))
    if device_mesh is None:
        raise RuntimeError("Failed to create a valid device mesh.")

    return device_mesh
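
# Illustrative sketch (not part of the original module): combining the device mesh with
# FSDP's hybrid sharding. The sharding strategy and FSDP keyword arguments below are
# assumptions about a typical setup, not requirements of this helper.
#
#     from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, ShardingStrategy
#
#     # e.g. 24 GPUs total: model sharded across groups of 4, replicated across 6 groups
#     mesh = hsdp_device_mesh(replica_group_size=6, sharding_group_size=4)
#     model = FSDP(model, device_mesh=mesh, sharding_strategy=ShardingStrategy.HYBRID_SHARD)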