@@ -2,7 +2,6 @@
 # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.

 import os
-from pkg_resources import packaging

 import dataclasses
 import fire
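
Note on the hunk above: from pkg_resources import packaging reached into setuptools' deprecated pkg_resources shim to get the packaging helpers, and the nightly-build check removed later in this diff appears to have been its only caller, so the import goes with it. Purely as an illustration (not part of this change), an equivalent guard written against the standalone packaging distribution might look like:

    # Illustrative sketch only: a version gate using the maintained `packaging`
    # package instead of the deprecated pkg_resources re-export. The ">= 2.2"
    # floor is an assumption for this sketch, not something the diff states.
    from packaging import version

    import torch

    if version.parse(torch.__version__).release < (2, 2):
        raise RuntimeError("this sketch assumes a stable PyTorch >= 2.2 for low_cpu_fsdp")
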
@@ -51,7 +50,7 @@ from llama_recipes.utils.train_utils import (
 from accelerate.utils import is_xpu_available

 def setup_wandb(train_config, fsdp_config, **kwargs):
-    try:
+    try:
         import wandb
     except ImportError:
         raise ImportError(
@@ -97,7 +96,7 @@ def main(**kwargs):

     if train_config.use_wandb:
         if not train_config.enable_fsdp or rank==0:
-            wandb_run = setup_wandb(train_config, fsdp_config, **kwargs)
+            wandb_run = setup_wandb(train_config, fsdp_config, **kwargs)

     # Load the pre-trained model and setup its configuration
     use_cache = False if train_config.enable_fsdp else None
@@ -108,11 +107,6 @@ def main(**kwargs):
         model alone would consume 2+TB cpu mem (70 * 4 * 8). This will add some comms
         overhead and currently requires latest nightly.
         """
-        v = packaging.version.parse(torch.__version__)
-        verify_latest_nightly = v.is_devrelease and v.dev >= 20230701
-        if not verify_latest_nightly:
-            raise Exception("latest pytorch nightly build is required to run with low_cpu_fsdp config, "
-                            "please install latest nightly.")
         if rank == 0:
             model = LlamaForCausalLM.from_pretrained(
                 train_config.model_name,
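
The docstring retained in the hunk above justifies rank-0-only loading with the figure "2+TB cpu mem (70 * 4 * 8)"; the removed lines had gated that path on a July-2023+ PyTorch nightly, a requirement the diff simply drops rather than ports. As a quick back-of-the-envelope check of the arithmetic (a standalone sketch, not code from the repo):

    # 70B parameters in fp32, materialized once per local rank on an 8-GPU node.
    params = 70e9          # approximate Llama 2 70B parameter count
    bytes_per_param = 4    # fp32
    ranks_per_node = 8     # one full host-memory copy per rank without low_cpu_fsdp
    total_bytes = params * bytes_per_param * ranks_per_node
    print(f"{total_bytes / 1e12:.2f} TB")  # ~2.24 TB, i.e. the "2+TB" in the docstring
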
@@ -157,12 +151,12 @@ def main(**kwargs):
         if wandb_run:
             wandb_run.config.update(peft_config)

-
+
     hsdp_device_mesh = None
     if fsdp_config.hsdp and fsdp_config.sharding_strategy == ShardingStrategy.HYBRID_SHARD:
         hsdp_device_mesh = hsdp_device_mesh(replica_group_size=fsdp_config.replica_group_size, sharding_group_size=fsdp_config.sharding_group_size)
         print("HSDP device mesh is ready")
-
+
     #setting up FSDP if enable_fsdp is enabled
     if train_config.enable_fsdp:
         if not train_config.use_peft and train_config.freeze_layers:
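
For readers skimming the HSDP context lines above: with ShardingStrategy.HYBRID_SHARD, parameters are sharded within each sharding group and replicated across replica groups, so the two group sizes have to multiply out to the world size. A minimal sketch of that relationship (the example numbers are illustrative, not taken from the repo's configs):

    # Assumed HSDP layout: shard within a group of sharding_group_size ranks,
    # replicate that sharded layout across replica_group_size such groups.
    import os

    world_size = int(os.environ.get("WORLD_SIZE", "16"))    # e.g. 2 nodes x 8 GPUs
    sharding_group_size = 8                                  # shard inside each node
    replica_group_size = world_size // sharding_group_size   # replicate across nodes
    assert replica_group_size * sharding_group_size == world_size
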
@@ -171,7 +165,7 @@ def main(**kwargs):

         mixed_precision_policy, wrapping_policy = get_policies(fsdp_config, rank)
         my_auto_wrapping_policy = fsdp_auto_wrap_policy(model, LlamaDecoderLayer)
-
+
         device_id = 0
         if is_xpu_available():
             device_id = torch.xpu.current_device()