Browse Source

Add dist.barrier() before and after checkpointing

Hamid Shojanazeri 1 year ago
parent
commit
4ba4400a75
1 changed file with 3 additions and 3 deletions
  1. 3 3
      utils/train_utils.py

+ 3 - 3
utils/train_utils.py

@@ -125,7 +125,7 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
         if train_config.run_validation:
             eval_ppl, eval_epoch_loss = evaluation(model, train_config, eval_dataloader, rank, tokenizer)   
             if train_config.save_model and eval_epoch_loss < best_val_loss:
-                
+                dist.barrier()
                 if  train_config.use_peft:
                     
                     print(f"we are in the saving the PEFT modules")
@@ -148,8 +148,8 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
                     if not train_config.use_peft and  train_config.save_optimizer:
                         model_checkpointing.save_optimizer_checkpoint(
                             model, optimizer, rank, train_config, epoch=epoch
-                        )   
-                                
+                        )                      
+                dist.barrier()
             
             if eval_epoch_loss < best_val_loss:
                 best_val_loss = eval_epoch_loss