
Adding a feature that will stop the training/eval process after reaching some max_steps (#428)

Hamid Shojanazeri 7 months ago
parent
commit
aaa9e2c863

+ 2 - 0
src/llama_recipes/configs/training.py

@@ -17,6 +17,8 @@ class train_config:
     gradient_clipping: bool = False
     gradient_clipping_threshold: float = 1.0
     num_epochs: int=3
+    max_train_step: int=0
+    max_eval_step: int=0
     num_workers_dataloader: int=1
     lr: float=1e-4
     weight_decay: float=0.0
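
The two new fields default to 0, which leaves the existing behavior unchanged; any positive value caps the number of batches processed. A minimal sketch of how they might be set programmatically (the override values below are illustrative and not part of this commit):

    from llama_recipes.configs.training import train_config

    # Hypothetical smoke-test configuration: stop after 50 training batches
    # and 10 evaluation batches; 0 (the default) means no limit.
    cfg = train_config()
    cfg.max_train_step = 50
    cfg.max_eval_step = 10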

+ 33 - 13
src/llama_recipes/utils/train_utils.py

@@ -57,9 +57,9 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
     elif train_config.use_fp16 and not train_config.enable_fsdp:
         scaler = torch.cuda.amp.GradScaler()
     if train_config.enable_fsdp:
-        world_size = int(os.environ["WORLD_SIZE"]) 
+        world_size = int(os.environ["WORLD_SIZE"])
+
 
-    
 
     autocast = torch.cuda.amp.autocast if train_config.use_fp16 else nullcontext
 
@@ -74,12 +74,18 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
         train_step_loss = []
         val_step_loss = []
         val_step_perplexity = []
-        
+
     epoch_times = []
     checkpoint_times = []
     results = {}
     best_val_loss = float("inf")
+    total_train_steps = 0
+    max_steps_reached = False  # Flag to indicate max training steps reached
+    # Start the training loop
     for epoch in range(train_config.num_epochs):
+        # stop when the maximum number of training steps is reached
+        if max_steps_reached:
+            break
         epoch_start_time = time.perf_counter()
         with MemoryTrace() as memtrace:  # track the memory usage
             model.train()
@@ -87,6 +93,13 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
             total_length = len(train_dataloader)//gradient_accumulation_steps
             pbar = tqdm(colour="blue", desc=f"Training Epoch: {epoch+1}", total=total_length, dynamic_ncols=True)
             for step, batch in enumerate(train_dataloader):
+                total_train_steps += 1
+                # stop when the maximum number of training steps is reached
+                if train_config.max_train_step > 0 and total_train_steps > train_config.max_train_step:
+                    max_steps_reached = True
+                    if not train_config.enable_fsdp or local_rank==0:
+                        print("max training steps reached, stopping training, total_train_steps: ", total_train_steps-1)
+                    break
                 for key in batch.keys():
                     if train_config.enable_fsdp:
                         if is_xpu_available():
@@ -98,7 +111,7 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
                         if is_xpu_available():
                             batch[key] = batch[key].to('xpu:0')
                         else:
-                            batch[key] = batch[key].to('cuda:0')              
+                            batch[key] = batch[key].to('cuda:0')
                 with autocast():
                     loss = model(**batch).loss
                 loss = loss / gradient_accumulation_steps
@@ -133,7 +146,7 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
                         optimizer.zero_grad()
                         pbar.update(1)
 
-                if wandb_run: 
+                if wandb_run:
                     if not train_config.enable_fsdp or rank==0:
                         wandb_run.log({
                             'train/epoch': epoch + 1,
@@ -158,10 +171,10 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
         if train_config.enable_fsdp:
             train_epoch_loss = train_epoch_loss/world_size
         train_perplexity = torch.exp(train_epoch_loss)
-        
+
         train_prep.append(float(train_perplexity))
         train_loss.append(float(train_epoch_loss))
-        
+
         if not train_config.enable_fsdp or rank==0:
             memtrace.print_stats()
 
@@ -231,7 +244,7 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
                 print(f"Epoch {epoch+1}: train_perplexity={train_perplexity:.4f}, train_epoch_loss={train_epoch_loss:.4f}, epoch time {epoch_end_time}s")
         else:
             print(f"Epoch {epoch+1}: train_perplexity={train_perplexity:.4f}, train_epoch_loss={train_epoch_loss:.4f}, epoch time {epoch_end_time}s")
-        
+
         # Saving the results every epoch to plot later
         if train_config.save_metrics:
             save_to_json(metrics_filename, train_step_loss, train_loss, train_step_perplexity, train_prep, val_step_loss, val_loss, val_step_perplexity, val_prep)
@@ -279,8 +292,15 @@ def evaluation(model,train_config, eval_dataloader, local_rank, tokenizer, wandb
     val_step_loss = []
     val_step_perplexity = []
     eval_loss = 0.0  # Initialize evaluation loss
+    total_eval_steps = 0
     with MemoryTrace() as memtrace:
         for step, batch in enumerate(tqdm(eval_dataloader,colour="green", desc="evaluating Epoch", dynamic_ncols=True)):
+            total_eval_steps += 1
+            # stop when the maximum number of eval steps is reached
+            if train_config.max_eval_step > 0 and total_eval_steps > train_config.max_eval_step:
+                if not train_config.enable_fsdp or local_rank==0:
+                    print("max eval steps reached, stopping evaluation, total_eval_steps: ", total_eval_steps - 1)
+                break
             for key in batch.keys():
                 if train_config.enable_fsdp:
                     batch[key] = batch[key].to(local_rank)
@@ -288,7 +308,7 @@ def evaluation(model,train_config, eval_dataloader, local_rank, tokenizer, wandb
                     if is_xpu_available():
                         batch[key] = batch[key].to('xpu:0')
                     else:
-                        batch[key] = batch[key].to('cuda:0')  
+                        batch[key] = batch[key].to('cuda:0')
             # Ensure no gradients are computed for this scope to save memory
             with torch.no_grad():
                 # Forward pass and compute loss
@@ -296,7 +316,7 @@ def evaluation(model,train_config, eval_dataloader, local_rank, tokenizer, wandb
                 loss = outputs.loss
                 if train_config.save_metrics:
                     val_step_loss.append(loss.detach().float().item())
-                    val_step_perplexity.append(float(torch.exp(loss.detach().float())))  
+                    val_step_perplexity.append(float(torch.exp(loss.detach().float())))
 
                 eval_loss += loss.detach().float()
             # Decode predictions and add to evaluation predictions list
@@ -324,12 +344,12 @@ def evaluation(model,train_config, eval_dataloader, local_rank, tokenizer, wandb
     else:
         print(f" {eval_ppl=} {eval_epoch_loss=}")
 
-    if wandb_run: 
+    if wandb_run:
         wandb_run.log({
                         'eval/perplexity': eval_ppl,
                         'eval/loss': eval_epoch_loss,
                     }, commit=False)
-        
+
     return eval_ppl, eval_epoch_loss, val_step_loss, val_step_perplexity
 
 def freeze_transformer_layers(model, num_layer):
@@ -410,7 +430,7 @@ def print_model_size(model, config, rank: int = 0) -> None:
 def get_policies(cfg, rank):
     """Get the policies for mixed precision and fsdp wrapping"""
 
-    
+
     verify_bfloat_support = ((
     torch.version.cuda
     and torch.cuda.is_bf16_supported()
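
Taken together, these changes implement a counter-and-flag pattern: the inner batch loop increments a running counter, flips a flag once the cap is exceeded, and the outer epoch loop checks the flag so no further epochs start. A condensed, self-contained sketch of that pattern (the loop bodies and constants only mirror the diff; this is not the actual training code):

    # Condensed sketch of the early-stop pattern added above.
    max_train_step = 5          # cap on training batches; 0 would mean "no limit"
    total_train_steps = 0
    max_steps_reached = False   # flag checked by the outer epoch loop

    for epoch in range(3):
        if max_steps_reached:
            break               # stop starting new epochs once the cap is hit
        for step in range(4):   # stands in for iterating over train_dataloader
            total_train_steps += 1
            if max_train_step > 0 and total_train_steps > max_train_step:
                max_steps_reached = True
                print("max training steps reached at", total_train_steps - 1)
                break
            # ... forward / backward / optimizer step would go here ...

Because the counter is incremented before the check, exactly max_train_step batches are processed, and the printed count subtracts one to report that number.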

+ 4 - 2
tests/test_train_utils.py

@@ -44,6 +44,8 @@ def test_gradient_accumulation(autocast, scaler, nullcontext, mem_trace, mocker)
     train_config.use_fp16 = False
     train_config.run_validation = False
     train_config.gradient_clipping = False
+    train_config.max_train_step = 0
+    train_config.max_eval_step = 0
     train_config.save_metrics = False
 
     train(
@@ -98,6 +100,8 @@ def test_save_to_json(temp_output_dir, mocker):
     train_config.run_validation = False
     train_config.gradient_clipping = False
     train_config.save_metrics = True
+    train_config.max_train_step = 0
+    train_config.max_eval_step = 0
     train_config.output_dir = temp_output_dir
 
     results = train(
@@ -114,5 +118,3 @@ def test_save_to_json(temp_output_dir, mocker):
 
     assert results["metrics_filename"] not in ["", None]
     assert os.path.isfile(results["metrics_filename"])
-
-