@@ -33,7 +33,7 @@ def set_tokenizer_params(tokenizer: LlamaTokenizer):
def byte2mb(x):
return int(x / 2**20)
-def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_scheduler, gradient_accumulation_steps, train_config, fsdp_config=None, local_rank=None, rank=None):
+def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_scheduler, gradient_accumulation_steps, train_config, fsdp_config=None, local_rank=None, rank=None, wandb_run=None):
"""
Trains the model on the given dataloader
@@ -133,6 +133,14 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
optimizer.zero_grad()
pbar.update(1)
+ if wandb_run:
+ if not train_config.enable_fsdp or rank==0:
+ wandb_run.log({
+ 'train/epoch': epoch + 1,
+ 'train/step': epoch * len(train_dataloader) + step,
+ 'train/loss': loss.detach().float(),
+ })
+
pbar.set_description(f"Training Epoch: {epoch+1}/{train_config.num_epochs}, step {step}/{len(train_dataloader)} completed (loss: {loss.detach().float()})")
if train_config.save_metrics:
@@ -161,7 +169,7 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
lr_scheduler.step()
if train_config.run_validation:
- eval_ppl, eval_epoch_loss, temp_val_loss, temp_step_perplexity = evaluation(model, train_config, eval_dataloader, local_rank, tokenizer)
+ eval_ppl, eval_epoch_loss, temp_val_loss, temp_step_perplexity = evaluation(model, train_config, eval_dataloader, local_rank, tokenizer, wandb_run)
if train_config.save_metrics:
val_step_loss.extend(temp_val_loss)
val_step_perplexity.extend(temp_step_perplexity)
@@ -252,7 +260,7 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
return results
-def evaluation(model,train_config, eval_dataloader, local_rank, tokenizer):
+def evaluation(model,train_config, eval_dataloader, local_rank, tokenizer, wandb_run):
"""
Evaluates the model on the given dataloader
@@ -315,6 +323,12 @@ def evaluation(model,train_config, eval_dataloader, local_rank, tokenizer):
print(f" {eval_ppl=} {eval_epoch_loss=}")
else:
print(f" {eval_ppl=} {eval_epoch_loss=}")
+
+ if wandb_run:
+ wandb_run.log({
+ 'eval/perplexity': eval_ppl,
+ 'eval/loss': eval_epoch_loss,
+ }, commit=False)
return eval_ppl, eval_epoch_loss, val_step_loss, val_step_perplexity
@@ -478,4 +492,4 @@ def save_to_json(output_filename, train_step_loss, train_epoch_loss, train_step_
"val_epoch_perplexity": val_epoch_ppl
}
with open(output_filename, "w") as f:
- json.dump(metrics_data, f)
+ json.dump(metrics_data, f)
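
Note: below is a minimal caller-side sketch of how the new `wandb_run` argument might be wired up. The `wandb.init` arguments, the project name, and the surrounding `model`/`train_config`/`rank`/`fsdp_config` variables are illustrative assumptions, not part of this patch; passing `wandb_run=None` (the default) keeps the previous behaviour.

import wandb

# Hypothetical setup: create the run only on rank 0 when FSDP is enabled,
# mirroring the `not train_config.enable_fsdp or rank==0` guard inside train().
wandb_run = None
if not train_config.enable_fsdp or rank == 0:
    wandb_run = wandb.init(project="llama-finetuning", config=vars(train_config))

results = train(
    model,
    train_dataloader,
    eval_dataloader,
    tokenizer,
    optimizer,
    lr_scheduler,
    gradient_accumulation_steps,
    train_config,
    fsdp_config=fsdp_config,
    local_rank=local_rank,
    rank=rank,
    wandb_run=wandb_run,  # omit or pass None to disable W&B logging
)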