
add max_steps_reached to reduce redundancy

Kai Wu committed 7 months ago
commit e6f69f84ad
1 file changed, 6 additions and 3 deletions

+ 6 - 3
src/llama_recipes/utils/train_utils.py

@@ -80,11 +80,11 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
     results = {}
     best_val_loss = float("inf")
     total_train_steps = 0
+    max_steps_reached = False  # Flag to indicate max training steps reached
+    # Start the training loop
     for epoch in range(train_config.num_epochs):
         # stop when the maximum number of training steps is reached
-        if train_config.max_train_step > 0 and total_train_steps > train_config.max_train_step:
-            if not train_config.enable_fsdp or local_rank==0:
-                print("max training steps reached, stopping training, total_train_steps: ", total_train_steps-1)
+        if max_steps_reached:
             break
         epoch_start_time = time.perf_counter()
         with MemoryTrace() as memtrace:  # track the memory usage
@@ -96,6 +96,9 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
                 total_train_steps += 1
                 # stop when the maximum number of training steps is reached
                 if train_config.max_train_step > 0 and total_train_steps > train_config.max_train_step:
+                    max_steps_reached = True
+                    if not train_config.enable_fsdp or local_rank==0:
+                        print("max training steps reached, stopping training, total_train_steps: ", total_train_steps-1)
                     break
                 for key in batch.keys():
                     if train_config.enable_fsdp:
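
For readers skimming the diff, a minimal standalone sketch of the pattern this commit introduces is shown below: the step-limit check now lives only inside the inner batch loop, which sets `max_steps_reached` and prints the message once, while the outer epoch loop simply tests the flag instead of repeating the check. The `train()` signature, the `steps_per_epoch` parameter, and the loop bodies here are simplified stand-ins for illustration, not the actual llama_recipes code.

```python
def train(num_epochs, steps_per_epoch, max_train_step):
    total_train_steps = 0
    max_steps_reached = False  # Flag to indicate max training steps reached
    for epoch in range(num_epochs):
        # Outer loop only consults the flag; the limit check is not duplicated here.
        if max_steps_reached:
            break
        for step in range(steps_per_epoch):
            total_train_steps += 1
            # Single place where the step limit is checked and reported.
            if max_train_step > 0 and total_train_steps > max_train_step:
                max_steps_reached = True
                print("max training steps reached, stopping training, "
                      "total_train_steps:", total_train_steps - 1)
                break
            # ... forward/backward pass and optimizer step would go here ...

# Example: 3 epochs of 5 steps each, capped at 7 training steps overall.
train(num_epochs=3, steps_per_epoch=5, max_train_step=7)
```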