|
@@ -125,7 +125,7 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
|
|
|
if train_config.run_validation:
|
|
|
eval_ppl, eval_epoch_loss = evaluation(model, train_config, eval_dataloader, rank, tokenizer)
|
|
|
if train_config.save_model and eval_epoch_loss < best_val_loss:
|
|
|
-
|
|
|
+ dist.barrier()
|
|
|
if train_config.use_peft:
|
|
|
|
|
|
print(f"we are in the saving the PEFT modules")
|
|
@@ -148,8 +148,8 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
|
|
|
if not train_config.use_peft and train_config.save_optimizer:
|
|
|
model_checkpointing.save_optimizer_checkpoint(
|
|
|
model, optimizer, rank, train_config, epoch=epoch
|
|
|
- )
|
|
|
-
|
|
|
+ )
|
|
|
+ dist.barrier()
|
|
|
|
|
|
if eval_epoch_loss < best_val_loss:
|
|
|
best_val_loss = eval_epoch_loss
|