@@ -179,14 +179,13 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
print(f"best eval loss on epoch {epoch} is {best_val_loss}")
val_loss.append(best_val_loss)
val_prep.append(eval_ppl)
-
if train_config.enable_fsdp:
if rank==0:
print(f"Epoch {epoch+1}: train_perplexity={train_perplexity:.4f}, train_epoch_loss={train_epoch_loss:.4f}, epcoh time {epoch_end_time}s")
else:
print(f"Epoch {epoch+1}: train_perplexity={train_perplexity:.4f}, train_epoch_loss={train_epoch_loss:.4f}, epcoh time {epoch_end_time}s")
- avg_epoch_time = sum(epoch_times)/ len(epoch_times)
- avg_checkpoint_time = sum(checkpoint_times)/ len(checkpoint_times)
+ avg_epoch_time = sum(epoch_times)/ len(epoch_times)
+ avg_checkpoint_time = sum(checkpoint_times)/ len(checkpoint_times) if len(checkpoint_times) > 0 else 0
avg_train_prep = sum(train_prep)/len(train_prep)
avg_train_loss = sum(train_loss)/len(train_loss)
if train_config.run_validation:
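The `+` line for avg_checkpoint_time guards against an empty checkpoint_times list, so a run that never saves a checkpoint no longer raises ZeroDivisionError. A minimal standalone sketch of that guarded-average pattern, using an illustrative helper named safe_mean that is not part of the patch itself:

def safe_mean(values):
    # Return 0 instead of raising ZeroDivisionError when the list is empty,
    # e.g. when no checkpoint was saved during the run.
    return sum(values) / len(values) if len(values) > 0 else 0

print(safe_mean([]))                  # 0, rather than a crash
print(safe_mean([12.3, 11.8, 12.1]))  # ~12.07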