Explorar el Código

add max_step feature for training and eval

Kai Wu hace 7 meses
padre
commit
fa0a389f74
Se han modificado 2 ficheros con 32 adiciones y 13 borrados
  1. 2 0
      src/llama_recipes/configs/training.py
  2. 30 13
      src/llama_recipes/utils/train_utils.py

+ 2 - 0
src/llama_recipes/configs/training.py

@@ -17,6 +17,8 @@ class train_config:
     gradient_clipping: bool = False
     gradient_clipping_threshold: float = 1.0
     num_epochs: int=3
+    max_train_step: int=0
+    max_eval_step: int=0
     num_workers_dataloader: int=1
     lr: float=1e-4
     weight_decay: float=0.0

+ 30 - 13
src/llama_recipes/utils/train_utils.py

@@ -57,9 +57,9 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
     elif train_config.use_fp16 and not train_config.enable_fsdp:
         scaler = torch.cuda.amp.GradScaler()
     if train_config.enable_fsdp:
-        world_size = int(os.environ["WORLD_SIZE"]) 
+        world_size = int(os.environ["WORLD_SIZE"])
+
 
-    
 
     autocast = torch.cuda.amp.autocast if train_config.use_fp16 else nullcontext
 
@@ -74,12 +74,18 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
         train_step_loss = []
         val_step_loss = []
         val_step_perplexity = []
-        
+
     epoch_times = []
     checkpoint_times = []
     results = {}
     best_val_loss = float("inf")
+    total_train_steps = 0
     for epoch in range(train_config.num_epochs):
+        # stop when the maximum number of training steps is reached
+        if train_config.max_train_step > 0 and total_train_steps > train_config.max_train_step:
+            if not train_config.enable_fsdp or local_rank==0:
+                print("max training steps reached, stopping training, total_train_steps: ", total_train_steps-1)
+            break
         epoch_start_time = time.perf_counter()
         with MemoryTrace() as memtrace:  # track the memory usage
             model.train()
@@ -87,6 +93,10 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
             total_length = len(train_dataloader)//gradient_accumulation_steps
             pbar = tqdm(colour="blue", desc=f"Training Epoch: {epoch+1}", total=total_length, dynamic_ncols=True)
             for step, batch in enumerate(train_dataloader):
+                total_train_steps += 1
+                # stop when the maximum number of training steps is reached
+                if train_config.max_train_step > 0 and total_train_steps > train_config.max_train_step:
+                    break
                 for key in batch.keys():
                     if train_config.enable_fsdp:
                         if is_xpu_available():
@@ -98,7 +108,7 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
                         if is_xpu_available():
                             batch[key] = batch[key].to('xpu:0')
                         else:
-                            batch[key] = batch[key].to('cuda:0')              
+                            batch[key] = batch[key].to('cuda:0')
                 with autocast():
                     loss = model(**batch).loss
                 loss = loss / gradient_accumulation_steps
@@ -133,7 +143,7 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
                         optimizer.zero_grad()
                         pbar.update(1)
 
-                if wandb_run: 
+                if wandb_run:
                     if not train_config.enable_fsdp or rank==0:
                         wandb_run.log({
                             'train/epoch': epoch + 1,
@@ -158,10 +168,10 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
         if train_config.enable_fsdp:
             train_epoch_loss = train_epoch_loss/world_size
         train_perplexity = torch.exp(train_epoch_loss)
-        
+
         train_prep.append(float(train_perplexity))
         train_loss.append(float(train_epoch_loss))
-        
+
         if not train_config.enable_fsdp or rank==0:
             memtrace.print_stats()
 
@@ -231,7 +241,7 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
                 print(f"Epoch {epoch+1}: train_perplexity={train_perplexity:.4f}, train_epoch_loss={train_epoch_loss:.4f}, epoch time {epoch_end_time}s")
         else:
             print(f"Epoch {epoch+1}: train_perplexity={train_perplexity:.4f}, train_epoch_loss={train_epoch_loss:.4f}, epoch time {epoch_end_time}s")
-        
+
         # Saving the results every epoch to plot later
         if train_config.save_metrics:
             save_to_json(metrics_filename, train_step_loss, train_loss, train_step_perplexity, train_prep, val_step_loss, val_loss, val_step_perplexity, val_prep)
@@ -279,8 +289,15 @@ def evaluation(model,train_config, eval_dataloader, local_rank, tokenizer, wandb
     val_step_loss = []
     val_step_perplexity = []
     eval_loss = 0.0  # Initialize evaluation loss
+    total_eval_steps = 0
     with MemoryTrace() as memtrace:
         for step, batch in enumerate(tqdm(eval_dataloader,colour="green", desc="evaluating Epoch", dynamic_ncols=True)):
+            total_eval_steps += 1
+            # stop when the maximum number of eval steps is reached
+            if train_config.max_eval_step > 0 and total_eval_steps > train_config.max_eval_step:
+                if not train_config.enable_fsdp or local_rank==0:
+                    print("max eval steps reached, stopping evaluation, total_eval_steps: ", total_eval_steps - 1)
+                break
             for key in batch.keys():
                 if train_config.enable_fsdp:
                     batch[key] = batch[key].to(local_rank)
@@ -288,7 +305,7 @@ def evaluation(model,train_config, eval_dataloader, local_rank, tokenizer, wandb
                     if is_xpu_available():
                         batch[key] = batch[key].to('xpu:0')
                     else:
-                        batch[key] = batch[key].to('cuda:0')  
+                        batch[key] = batch[key].to('cuda:0')
             # Ensure no gradients are computed for this scope to save memory
             with torch.no_grad():
                 # Forward pass and compute loss
@@ -296,7 +313,7 @@ def evaluation(model,train_config, eval_dataloader, local_rank, tokenizer, wandb
                 loss = outputs.loss
                 if train_config.save_metrics:
                     val_step_loss.append(loss.detach().float().item())
-                    val_step_perplexity.append(float(torch.exp(loss.detach().float())))  
+                    val_step_perplexity.append(float(torch.exp(loss.detach().float())))
 
                 eval_loss += loss.detach().float()
             # Decode predictions and add to evaluation predictions list
@@ -324,12 +341,12 @@ def evaluation(model,train_config, eval_dataloader, local_rank, tokenizer, wandb
     else:
         print(f" {eval_ppl=} {eval_epoch_loss=}")
 
-    if wandb_run: 
+    if wandb_run:
         wandb_run.log({
                         'eval/perplexity': eval_ppl,
                         'eval/loss': eval_epoch_loss,
                     }, commit=False)
-        
+
     return eval_ppl, eval_epoch_loss, val_step_loss, val_step_perplexity
 
 def freeze_transformer_layers(model, num_layer):
@@ -410,7 +427,7 @@ def print_model_size(model, config, rank: int = 0) -> None:
 def get_policies(cfg, rank):
     """Get the policies for mixed precision and fsdp wrapping"""
 
-    
+
     verify_bfloat_support = ((
     torch.version.cuda
     and torch.cuda.is_bf16_supported()