Browse Source

fixing the train/eval_loss calcualtion

Hamid Shojanazeri 1 year ago
parent
commit
e9559d2669
1 changed files with 24 additions and 10 deletions
  1. 24 10
      utils/train_utils.py

+ 24 - 10
utils/train_utils.py

@@ -66,7 +66,8 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
         scaler = ShardedGradScaler()
     elif train_config.use_fp16 and not train_config.enable_fsdp:
         scaler = torch.cuda.amp.GradScaler() 
-        
+    if train_config.enable_fsdp:
+        world_size = int(os.environ["WORLD_SIZE"]) 
     train_prep = []
     train_loss = []
     val_prep = []
@@ -102,12 +103,18 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
                     if (step + 1) % gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1:
                         optimizer.step()
                         optimizer.zero_grad()
-                        
-                print(f"\n step {step} is completed and loss is {loss.detach().float()}")
+                if train_config.enable_fsdp:
+                    if rank==0:       
+                        print(f"\n step {step} is completed and loss is {loss.detach().float()}")
+                else:
+                    print(f"\n step {step} is completed and loss is {loss.detach().float()}")
+                    
         # Reducing total_loss across all devices if there's more than one CUDA device
         if torch.cuda.device_count() > 1 and train_config.enable_fsdp:
             dist.all_reduce(total_loss, op=dist.ReduceOp.SUM)
-        train_epoch_loss = total_loss / data_set_len
+        train_epoch_loss = total_loss / len(train_dataloader)
+        if train_config.enable_fsdp:
+            train_epoch_loss = train_epoch_loss/world_size
         train_perplexity = torch.exp(train_epoch_loss)
         
         train_prep.append(train_perplexity)
@@ -127,11 +134,9 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
             if train_config.save_model and eval_epoch_loss < best_val_loss:
                 dist.barrier()
                 if  train_config.use_peft:
-                    
                     print(f"we are in the saving the PEFT modules")
                     model.save_pretrained(train_config.output_dir)   
-                    print(f"PEFT modules are saved in {train_config.output_dir} directory")
-                    
+                    print(f"PEFT modules are saved in {train_config.output_dir} directory")    
                 else:
                     if not train_config.use_peft and fsdp_config.checkpoint_type == StateDictType.FULL_STATE_DICT:
                         
@@ -139,16 +144,21 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
                             model, optimizer, rank, train_config, epoch=epoch
                         )
                     elif not train_config.use_peft and fsdp_config.checkpoint_type == StateDictType.SHARDED_STATE_DICT:
-                        print(" we are about to save the models *******")
+                        print(" Saving the FSDP model checkpoints using SHARDED_STATE_DICT")
+                        print("=====================================================")
                         
                         model_checkpointing.save_model_and_optimizer_sharded(model, rank, train_config)
                         if train_config.save_optimizer:
                             model_checkpointing.save_model_and_optimizer_sharded(model, rank, train_config, optim=optimizer)
+                            print(" Saving the FSDP model checkpoints qnd optimizer using SHARDED_STATE_DICT")
+                            print("=====================================================")
 
                     if not train_config.use_peft and  train_config.save_optimizer:
                         model_checkpointing.save_optimizer_checkpoint(
                             model, optimizer, rank, train_config, epoch=epoch
-                        )                      
+                        )
+                        print(" Saving the FSDP model checkpoints qnd optimizer using FULL_STATE_DICT")
+                        print("=====================================================")                     
                 dist.barrier()
             
             if eval_epoch_loss < best_val_loss:
@@ -192,6 +202,8 @@ def evaluation(model,train_config, eval_dataloader, local_rank, tokenizer):
     
     Returns: eval_ppl, eval_epoch_loss
     """
+    if train_config.enable_fsdp:
+        world_size = int(os.environ["WORLD_SIZE"]) 
     model.eval()
     eval_preds = []
     eval_loss = 0.0  # Initialize evaluation loss
@@ -223,7 +235,9 @@ def evaluation(model,train_config, eval_dataloader, local_rank, tokenizer):
         dist.all_reduce(eval_loss, op=dist.ReduceOp.SUM)
     
     # Compute average loss and perplexity
-    eval_epoch_loss = eval_loss / eval_dataset_len
+    eval_epoch_loss = eval_loss / len(eval_dataloader)
+    if train_config.enable_fsdp:
+        eval_epoch_loss = eval_epoch_loss/world_size
     eval_ppl = torch.exp(eval_epoch_loss)
     
     # Print evaluation metrics