@@ -57,8 +57,6 @@ def main(**kwargs):
     # Set the seeds for reproducibility
     if is_xpu_available():
         torch.xpu.manual_seed(train_config.seed)
-    else:
-        torch.cuda.manual_seed(train_config.seed)
     torch.manual_seed(train_config.seed)
     random.seed(train_config.seed)
 
@@ -72,7 +70,7 @@ def main(**kwargs):
     if torch.distributed.is_initialized():
         if is_xpu_available():
             torch.xpu.set_device(local_rank)
-        else:
+        elif torch.cuda.is_available():
             torch.cuda.set_device(local_rank)
         clear_gpu_cache(local_rank)
         setup_environ_flags(rank)
@@ -135,7 +133,7 @@ def main(**kwargs):
 
     hsdp_device_mesh = None
     if fsdp_config.hsdp and fsdp_config.sharding_strategy == ShardingStrategy.HYBRID_SHARD:
-        hsdp_device_mesh = hdsp_device_mesh(replica_group_size=fsdp_config.replica_group_size, sharding_group_size=fsdp_config.sharding_group_size)
+        hsdp_device_mesh = hsdp_device_mesh(replica_group_size=fsdp_config.replica_group_size, sharding_group_size=fsdp_config.sharding_group_size)
         print("HSDP device mesh is ready")
 
     #setting up FSDP if enable_fsdp is enabled
@@ -146,6 +144,12 @@ def main(**kwargs):
 
         mixed_precision_policy, wrapping_policy = get_policies(fsdp_config, rank)
         my_auto_wrapping_policy = fsdp_auto_wrap_policy(model, LlamaDecoderLayer)
+
+        device_id = 0
+        if is_xpu_available():
+            device_id = torch.xpu.current_device()
+        elif torch.cuda.is_available():
+            device_id = torch.cuda.current_device()
 
         model = FSDP(
             model,
@@ -154,7 +158,7 @@ def main(**kwargs):
             mixed_precision=mixed_precision_policy if not fsdp_config.pure_bf16 else None,
             sharding_strategy=fsdp_config.sharding_strategy,
             device_mesh=hsdp_device_mesh,
-            device_id=torch.xpu.current_device() if is_xpu_available() else torch.cuda.current_device(),
+            device_id=device_id,
             limit_all_gathers=True,
             sync_module_states=train_config.low_cpu_fsdp,
             param_init_fn=lambda module: module.to_empty(device=torch.device("cuda"), recurse=False)
@@ -165,7 +169,7 @@ def main(**kwargs):
     elif not train_config.quantization and not train_config.enable_fsdp:
         if is_xpu_available():
             model.to("xpu:0")
-        else:
+        elif torch.cuda.is_available():
             model.to("cuda")
 
     dataset_config = generate_dataset_config(train_config, kwargs)