@@ -153,7 +153,11 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
if eval_epoch_loss < best_val_loss:
best_val_loss = eval_epoch_loss
- print(f"best eval loss on epoch {epoch} is {best_val_loss}")
+ if train_config.enable_fsdp:
+ if rank==0:
+ print(f"best eval loss on epoch {epoch} is {best_val_loss}")
+ else:
+ print(f"best eval loss on epoch {epoch} is {best_val_loss}")
val_loss.append(best_val_loss)
val_prep.append(eval_ppl)