@@ -69,7 +69,7 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
             model.train()
             total_loss = 0.0
             total_length = len(train_dataloader)//gradient_accumulation_steps
-            pbar = tqdm(colour="blue", desc=f"Training Epoch: {epoch}", total=total_length)
+            pbar = tqdm(colour="blue", desc=f"Training Epoch: {epoch+1}", total=total_length)
             for step, batch in enumerate(train_dataloader):
                 for key in batch.keys():
                     if train_config.enable_fsdp:
@@ -95,7 +95,7 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
                         optimizer.zero_grad()
                         pbar.update(gradient_accumulation_steps)
 
-                pbar.set_description(f"Training Epoch: {epoch}/{train_config.num_epochs}, step {step}/{len(train_dataloader)} completed (loss: {loss.detach().float()})")
+                pbar.set_description(f"Training Epoch: {epoch+1}/{train_config.num_epochs}, step {step}/{len(train_dataloader)} completed (loss: {loss.detach().float()})")
             pbar.close()
 
         epoch_end_time = time.perf_counter()-epoch_start_time
@@ -177,16 +177,16 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
                 best_val_loss = eval_epoch_loss
                 if train_config.enable_fsdp:
                     if rank==0:
-                        print(f"best eval loss on epoch {epoch} is {best_val_loss}")
+                        print(f"best eval loss on epoch {epoch+1} is {best_val_loss}")
                 else:
-                    print(f"best eval loss on epoch {epoch} is {best_val_loss}")
+                    print(f"best eval loss on epoch {epoch+1} is {best_val_loss}")
             val_loss.append(best_val_loss)
             val_prep.append(eval_ppl)
         if train_config.enable_fsdp:
             if rank==0:
-                print(f"Epoch {epoch+1}: train_perplexity={train_perplexity:.4f}, train_epoch_loss={train_epoch_loss:.4f}, epcoh time {epoch_end_time}s")
+                print(f"Epoch {epoch+1}: train_perplexity={train_perplexity:.4f}, train_epoch_loss={train_epoch_loss:.4f}, epoch time {epoch_end_time}s")
         else:
-            print(f"Epoch {epoch+1}: train_perplexity={train_perplexity:.4f}, train_epoch_loss={train_epoch_loss:.4f}, epcoh time {epoch_end_time}s")
+            print(f"Epoch {epoch+1}: train_perplexity={train_perplexity:.4f}, train_epoch_loss={train_epoch_loss:.4f}, epoch time {epoch_end_time}s")
     avg_epoch_time = sum(epoch_times)/ len(epoch_times)
     avg_checkpoint_time = sum(checkpoint_times)/ len(checkpoint_times) if len(checkpoint_times) > 0 else 0
     avg_train_prep = sum(train_prep)/len(train_prep)
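
For reference, a minimal self-contained sketch of the 1-indexed progress display the first two hunks introduce: epoch stays 0-based in the loop, and only the user-facing strings add 1. The dataloader, loss value, and accumulation settings below are stand-ins, not the repo's objects.

from tqdm import tqdm

num_epochs = 3
train_dataloader = list(range(8))        # stand-in for a real DataLoader
gradient_accumulation_steps = 2

for epoch in range(num_epochs):          # epoch is 0-based internally
    total_length = len(train_dataloader) // gradient_accumulation_steps
    pbar = tqdm(colour="blue", desc=f"Training Epoch: {epoch+1}", total=total_length)
    for step, batch in enumerate(train_dataloader):
        loss = 0.0                       # stand-in for the real training loss
        if (step + 1) % gradient_accumulation_steps == 0:
            pbar.update(1)               # advance once per optimizer step
        pbar.set_description(
            f"Training Epoch: {epoch+1}/{num_epochs}, "
            f"step {step}/{len(train_dataloader)} completed (loss: {loss})"
        )
    pbar.close()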
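
The rank check around each print in the last hunk follows a common pattern: under FSDP, log only on rank 0; in single-process runs, log unconditionally. A hypothetical helper (not part of this diff; the name print_rank0 is made up) illustrates it using the standard torch.distributed calls:

import torch.distributed as dist

def print_rank0(message: str) -> None:
    # dist.is_initialized() is False in single-process runs, so the
    # message always prints there; under FSDP only rank 0 prints.
    if not dist.is_initialized() or dist.get_rank() == 0:
        print(message)

print_rank0("best eval loss on epoch 1 is 1.2345")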