|
@@ -80,11 +80,11 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
|
|
|
results = {}
|
|
|
best_val_loss = float("inf")
|
|
|
total_train_steps = 0
|
|
|
+ max_steps_reached = False # Flag to indicate max training steps reached
|
|
|
+ # Start the training loop
|
|
|
for epoch in range(train_config.num_epochs):
|
|
|
# stop when the maximum number of training steps is reached
|
|
|
- if train_config.max_train_step > 0 and total_train_steps > train_config.max_train_step:
|
|
|
- if not train_config.enable_fsdp or local_rank==0:
|
|
|
- print("max training steps reached, stopping training, total_train_steps: ", total_train_steps-1)
|
|
|
+ if max_steps_reached:
|
|
|
break
|
|
|
epoch_start_time = time.perf_counter()
|
|
|
with MemoryTrace() as memtrace: # track the memory usage
|
|
@@ -96,6 +96,9 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
|
|
|
total_train_steps += 1
|
|
|
# stop when the maximum number of training steps is reached
|
|
|
if train_config.max_train_step > 0 and total_train_steps > train_config.max_train_step:
|
|
|
+ max_steps_reached = True
|
|
|
+ if not train_config.enable_fsdp or local_rank==0:
|
|
|
+ print("max training steps reached, stopping training, total_train_steps: ", total_train_steps-1)
|
|
|
break
|
|
|
for key in batch.keys():
|
|
|
if train_config.enable_fsdp:
|