Hamid Shojanazeri 1 year ago
parent
commit
a2403c7c1a
1 changed files with 38 additions and 19 deletions
  1. 38 19
      utils/train_utils.py

+ 38 - 19
utils/train_utils.py

@@ -78,7 +78,6 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
         with MemoryTrace() as memtrace:  # track the memory usage
             model.train()
             total_loss = 0.0
-            data_set_len = 0
             for step, batch in enumerate(tqdm(train_dataloader,colour="blue", desc=f"Training Epoch{epoch}")):
                 for key in batch.keys():
                     if train_config.enable_fsdp:
@@ -88,8 +87,6 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
                 loss = model(**batch).loss
                 loss = loss / gradient_accumulation_steps
                 total_loss += loss.detach().float()
-                first_key = next(iter(batch))
-                data_set_len += len(batch[first_key])
                 if train_config.use_fp16:
                     # if fp16 is enabled, use gradient scaler to handle gradient update
                     scaler.scale(loss).backward()
@@ -119,12 +116,20 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
         
         train_prep.append(train_perplexity)
         train_loss.append(train_epoch_loss)
+        if train_config.enable_fsdp:
+            if rank==0:
+                print(f"Max CUDA memory allocated was {memtrace.peak} GB")
+                print(f"Max CUDA memory reserved was {memtrace.max_reserved} GB")
+                print(f"Peak active CUDA memory was {memtrace.peak_active_gb} GB")
+                print(f"Cuda Malloc retires : {memtrace.cuda_malloc_retires}")
+                print(f"CPU Total Peak Memory consumed during the train (max): {memtrace.cpu_peaked + memtrace.cpu_begin} GB")
+        else:
+            print(f"Max CUDA memory allocated was {memtrace.peak} GB")
+            print(f"Max CUDA memory reserved was {memtrace.max_reserved} GB")
+            print(f"Peak active CUDA memory was {memtrace.peak_active_gb} GB")
+            print(f"Cuda Malloc retires : {memtrace.cuda_malloc_retires}")
+            print(f"CPU Total Peak Memory consumed during the train (max): {memtrace.cpu_peaked + memtrace.cpu_begin} GB")
         
-        print(f"Max CUDA memory allocated was {memtrace.peak} GB")
-        print(f"Max CUDA memory reserved was {memtrace.max_reserved} GB")
-        print(f"Peak active CUDA memory was {memtrace.peak_active_gb} GB")
-        print(f"Cuda Malloc retires : {memtrace.cuda_malloc_retires}")
-        print(f"CPU Total Peak Memory consumed during the train (max): {memtrace.cpu_peaked + memtrace.cpu_begin} GB")
         
         # Update the learning rate as needed
         lr_scheduler.step()
@@ -133,10 +138,19 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
             eval_ppl, eval_epoch_loss = evaluation(model, train_config, eval_dataloader, rank, tokenizer)   
             if train_config.save_model and eval_epoch_loss < best_val_loss:
                 dist.barrier()
-                if  train_config.use_peft:
-                    print(f"we are in the saving the PEFT modules")
-                    model.save_pretrained(train_config.output_dir)   
-                    print(f"PEFT modules are saved in {train_config.output_dir} directory")    
+                if train_config.use_peft:
+                    if train_config.enable_fsdp:
+                        if rank==0:
+                            print(f"we are about to save the PEFT modules")
+                    else:
+                        print(f"we are about to save the PEFT modules")
+                    model.save_pretrained(train_config.output_dir)  
+                    if train_config.enable_fsdp:
+                        if rank==0: 
+                            print(f"PEFT modules are saved in {train_config.output_dir} directory")
+                    else:
+                        print(f"PEFT modules are saved in {train_config.output_dir} directory")
+                        
                 else:
                     if not train_config.use_peft and fsdp_config.checkpoint_type == StateDictType.FULL_STATE_DICT:
                         
@@ -171,8 +185,12 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
             val_loss.append(best_val_loss)
             val_prep.append(eval_ppl)
         
-        
-        print(f"Epoch {epoch+1}: train_perplexity={train_perplexity:.4f}, train_epoch_loss={train_epoch_loss:.4f}")
+        if train_config.enable_fsdp:
+            if rank==0:
+                print(f"Epoch {epoch+1}: train_perplexity={train_perplexity:.4f}, train_epoch_loss={train_epoch_loss:.4f}")
+        else:
+            print(f"Epoch {epoch+1}: train_perplexity={train_perplexity:.4f}, train_epoch_loss={train_epoch_loss:.4f}")
+            
         lr_scheduler.step()
 
     avg_train_prep = sum(train_prep)/len(train_prep)
@@ -207,7 +225,6 @@ def evaluation(model,train_config, eval_dataloader, local_rank, tokenizer):
     model.eval()
     eval_preds = []
     eval_loss = 0.0  # Initialize evaluation loss
-    eval_dataset_len = 0
     with MemoryTrace() as memtrace:
         for step, batch in enumerate(tqdm(eval_dataloader,colour="green", desc="evaluating Epoch")):
             for key in batch.keys():
@@ -221,9 +238,6 @@ def evaluation(model,train_config, eval_dataloader, local_rank, tokenizer):
                 outputs = model(**batch)
                 loss = outputs.loss
                 eval_loss += loss.detach().float()
-                first_key = next(iter(batch))
-                eval_dataset_len+= len(batch[first_key])
-                
             # Decode predictions and add to evaluation predictions list
             preds = torch.argmax(outputs.logits, -1)
             eval_preds.extend(
@@ -241,7 +255,12 @@ def evaluation(model,train_config, eval_dataloader, local_rank, tokenizer):
     eval_ppl = torch.exp(eval_epoch_loss)
     
     # Print evaluation metrics
-    print(f" {eval_ppl=} {eval_epoch_loss=}")
+    if train_config.enable_fsdp:
+        if local_rank==0:
+            print(f" {eval_ppl=} {eval_epoch_loss=}")
+    else:
+        print(f" {eval_ppl=} {eval_epoch_loss=}")
+        
     return eval_ppl, eval_epoch_loss
 
 def freeze_transformer_layers(model, num_layer):