@@ -153,7 +153,11 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
 
             if eval_epoch_loss < best_val_loss:
                 best_val_loss = eval_epoch_loss
-                print(f"best eval loss on epoch {epoch} is {best_val_loss}")
+                if train_config.enable_fsdp:
+                    if rank==0:
+                        print(f"best eval loss on epoch {epoch} is {best_val_loss}")
+                else:
+                    print(f"best eval loss on epoch {epoch} is {best_val_loss}")
             val_loss.append(best_val_loss)
             val_prep.append(eval_ppl)
 