@@ -137,7 +137,7 @@ def main(**kwargs):
         sharding_strategy=fsdp_config.sharding_strategy,
         device_id=torch.cuda.current_device(),
         limit_all_gathers=True,
-        sync_module_states=True if train_config.low_cpu_fsdp else False,
+        sync_module_states=train_config.low_cpu_fsdp,
         param_init_fn=lambda module: module.to_empty(device=torch.device("cuda"), recurse=False)
         if train_config.low_cpu_fsdp and rank != 0 else None,
     )
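
For context, here is a minimal sketch of the low-CPU-memory FSDP initialization pattern this hunk configures. The `build_fsdp_model` helper and `model_fn` constructor are illustrative stand-ins, not names from this repository, and the sketch assumes the process group is already initialized:

import torch
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

def build_fsdp_model(model_fn, low_cpu_fsdp: bool):
    rank = dist.get_rank()
    if low_cpu_fsdp and rank != 0:
        # Non-zero ranks construct the model on the meta device, so no real
        # parameter storage is allocated and host RAM usage stays flat.
        with torch.device("meta"):
            model = model_fn()
    else:
        # Rank 0 (or every rank when the flag is off) holds real weights.
        model = model_fn()

    return FSDP(
        model,
        device_id=torch.cuda.current_device(),
        limit_all_gathers=True,
        # When low_cpu_fsdp is on, broadcast rank 0's real weights into the
        # shards on the other ranks; the flag is already a bool, so it can be
        # passed directly (the simplification this hunk makes).
        sync_module_states=low_cpu_fsdp,
        # Materialize the meta tensors as empty GPU storage first, so the
        # broadcast from rank 0 has somewhere to land.
        param_init_fn=(
            lambda module: module.to_empty(device=torch.device("cuda"), recurse=False)
        ) if low_cpu_fsdp and rank != 0 else None,
    )

With this arrangement only rank 0 ever holds a full CPU copy of the weights; `param_init_fn` gives every other rank empty GPU tensors, and `sync_module_states=True` fills them from rank 0 during wrapping.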