
Fix div by zero if run_validation=False

Author: Matthias Reso
Commit: 5b58afc754
2 changed files with 3 additions and 3 deletions:
  1. src/llama_recipes/finetuning.py (+1, -0)
  2. src/llama_recipes/utils/train_utils.py (+2, -3)

src/llama_recipes/finetuning.py (+1, -0)

@@ -203,6 +203,7 @@ def main(**kwargs):
         collate_fn=default_data_collator,
     )
 
+    eval_dataloader = None
     if train_config.run_validation:
         eval_dataloader = torch.utils.data.DataLoader(
             dataset_val,

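The added line ensures `eval_dataloader` is always bound before it is later passed to the training loop, even when validation is disabled. A minimal sketch of the pattern with illustrative names (not the repo's exact signatures):

```python
# Sketch of the "initialize to None, create only if needed" pattern from the
# hunk above; build_dataloaders and its arguments are illustrative only.
import torch
from torch.utils.data import DataLoader, TensorDataset

def build_dataloaders(dataset_train, dataset_val=None, run_validation=True):
    train_dataloader = DataLoader(dataset_train, batch_size=4, shuffle=True)

    eval_dataloader = None  # always defined, even when validation is skipped
    if run_validation:
        eval_dataloader = DataLoader(dataset_val, batch_size=4)

    return train_dataloader, eval_dataloader

# With run_validation=False the second loader is simply None and downstream
# code can branch on that instead of hitting an unbound variable.
train_dl, eval_dl = build_dataloaders(
    TensorDataset(torch.arange(16).float()), run_validation=False
)
assert eval_dl is None
```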
src/llama_recipes/utils/train_utils.py (+2, -3)

@@ -179,14 +179,13 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
                     print(f"best eval loss on epoch {epoch} is {best_val_loss}")
             val_loss.append(best_val_loss)
             val_prep.append(eval_ppl)
-        
         if train_config.enable_fsdp:
             if rank==0:
                 print(f"Epoch {epoch+1}: train_perplexity={train_perplexity:.4f}, train_epoch_loss={train_epoch_loss:.4f}, epcoh time {epoch_end_time}s")
         else:
             print(f"Epoch {epoch+1}: train_perplexity={train_perplexity:.4f}, train_epoch_loss={train_epoch_loss:.4f}, epcoh time {epoch_end_time}s")
-    avg_epoch_time = sum(epoch_times)/ len(epoch_times) 
-    avg_checkpoint_time = sum(checkpoint_times)/ len(checkpoint_times)   
+    avg_epoch_time = sum(epoch_times)/ len(epoch_times)
+    avg_checkpoint_time = sum(checkpoint_times)/ len(checkpoint_times) if len(checkpoint_times) > 0 else 0
     avg_train_prep = sum(train_prep)/len(train_prep)
     avg_train_loss = sum(train_loss)/len(train_loss)
     if train_config.run_validation:
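The train_utils.py change guards the checkpoint-time average: with `run_validation=False`, `checkpoint_times` can stay empty for the whole run (checkpointing appears tied to the validation step here), so the unconditional `sum()/len()` raised a ZeroDivisionError. A standalone sketch of the guarded average, using a hypothetical helper that is not part of the repo:

```python
# Hypothetical helper mirroring the guard added above; not llama_recipes code.
def safe_mean(values, default=0.0):
    """Average a list, returning `default` instead of raising on empty input."""
    return sum(values) / len(values) if values else default

# With run_validation=False, checkpoint_times may never be populated,
# so the unguarded division would raise ZeroDivisionError.
checkpoint_times = []
print(safe_mean(checkpoint_times))    # 0.0
print(safe_mean([12.5, 13.1, 11.9]))  # 12.5
```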