|
@@ -65,9 +65,6 @@ def main(**kwargs):
|
|
|
clear_gpu_cache(local_rank)
|
|
|
setup_environ_flags(rank)
|
|
|
|
|
|
- # Calculate gradient accumulation steps
|
|
|
- gradient_accumulation_steps = train_config.batch_size_training // train_config.micro_batch_size
|
|
|
-
|
|
|
# Load the pre-trained model and setup its configuration
|
|
|
if train_config.enable_fsdp and train_config.low_cpu_fsdp:
|
|
|
"""
|
|
@@ -240,7 +237,7 @@ def main(**kwargs):
|
|
|
tokenizer,
|
|
|
optimizer,
|
|
|
scheduler,
|
|
|
- gradient_accumulation_steps,
|
|
|
+ train_config.gradient_accumulation_steps,
|
|
|
train_config,
|
|
|
fsdp_config if train_config.enable_fsdp else None,
|
|
|
local_rank if train_config.enable_fsdp else None,
|