
Modify to step the lr scheduler each epoch

Hamid Shojanazeri 1 year ago
commit 20b061e01c
1 changed file with 7 additions and 4 deletions:
    utils/train_utils.py (+7 -4)

utils/train_utils.py  +7 -4

@@ -105,11 +105,9 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
                     loss.backward()
                     if (step + 1) % gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1:
                         optimizer.step()
-                        lr_scheduler.step()
                         optimizer.zero_grad()
                         
-                print(f"\n step {step} is completed and loss is {loss.detach().float()}")
-
+                print(f"\n step {step} is completed and loss is {loss.detach().float()}")        
         # Reducing total_loss across all devices if there's more than one CUDA device
         if torch.cuda.device_count() > 1 and train_config.enable_fsdp:
             dist.all_reduce(total_loss, op=dist.ReduceOp.SUM)
@@ -123,7 +121,10 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
         print(f"Max CUDA memory reserved was {memtrace.max_reserved} GB")
         print(f"Cuda Malloc retires : {memtrace.cuda_malloc_retires}")
         print(f"CPU Total Peak Memory consumed during the train (max): {memtrace.cpu_peaked + memtrace.cpu_begin} GB")
-            
+        
+        # Update the learning rate as needed
+        lr_scheduler.step()
+          
         if train_config.run_validation:
             eval_ppl, eval_epoch_loss = evaluation(model, train_config, eval_dataloader, rank, tokenizer)   
             if train_config.save_model and eval_epoch_loss < best_val_loss:
@@ -159,7 +160,9 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
             val_loss.append(best_val_loss)
             val_prep.append(eval_ppl)
         
+        
         print(f"Epoch {epoch+1}: train_perplexity={train_perplexity:.4f}, train_epoch_loss={train_epoch_loss:.4f}")
+        lr_scheduler.step()
 
     avg_train_prep = sum(train_prep)/len(train_prep)
     avg_train_loss = sum(train_loss)/len(train_loss)
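For context, here is a minimal, self-contained sketch (not the repo's actual train() in utils/train_utils.py) of the scheduling pattern this commit moves to: the optimizer still steps per effective batch under gradient accumulation, while lr_scheduler.step() now runs once per epoch. The model, data, and StepLR settings below are placeholders chosen only for illustration.

import torch
from torch import nn
from torch.optim.lr_scheduler import StepLR

# Placeholder model, optimizer, and per-epoch scheduler (illustrative values only)
model = nn.Linear(10, 2)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
lr_scheduler = StepLR(optimizer, step_size=1, gamma=0.85)

# Dummy dataloader: a few (inputs, labels) batches
train_dataloader = [(torch.randn(4, 10), torch.randint(0, 2, (4,))) for _ in range(8)]
gradient_accumulation_steps = 2
num_epochs = 3

for epoch in range(num_epochs):
    for step, (x, y) in enumerate(train_dataloader):
        loss = nn.functional.cross_entropy(model(x), y)
        (loss / gradient_accumulation_steps).backward()
        if (step + 1) % gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1:
            optimizer.step()        # weight update per effective batch
            optimizer.zero_grad()   # note: no lr_scheduler.step() here anymore
    lr_scheduler.step()             # decay the learning rate once per epoch
    print(f"epoch {epoch+1}: lr={lr_scheduler.get_last_lr()[0]:.6f}")

Stepping once per epoch matches schedulers such as StepLR whose step_size is expressed in epochs; leaving step() inside the gradient-accumulation branch, as the old code did, would decay the learning rate once per optimizer update rather than once per epoch, far faster than intended.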