Browse Source

update with active memory and removing rank0 for eval score

Hamid Shojanazeri 1 year ago
parent
commit
3d887ea483
2 changed files with 3 additions and 2 deletions
  1. 1 0
      utils/memory_utils.py
  2. 2 2
      utils/train_utils.py

+ 1 - 0
utils/memory_utils.py

@@ -51,6 +51,7 @@ class MemoryTrace:
         self.peak = byte2gb(torch.cuda.max_memory_allocated())
         cuda_info = torch.cuda.memory_stats()
         self.cuda_malloc_retires = cuda_info.get("num_alloc_retries", 0)
+        self.peak_active_gb = byte2gb(cuda_info["active_bytes.all.peak"])
         self.m_cuda_ooms = cuda_info.get("num_ooms", 0)
         self.used = byte2gb(self.end - self.begin)
         self.peaked = byte2gb(self.peak - self.begin)

+ 2 - 2
utils/train_utils.py

@@ -78,7 +78,6 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
             model.train()
             total_loss = 0.0
             data_set_len = 0
-            
             for step, batch in enumerate(tqdm(train_dataloader,colour="blue", desc=f"Training Epoch{epoch}")):
                 for key in batch.keys():
                     if train_config.enable_fsdp:
@@ -116,6 +115,7 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
         
         print(f"Max CUDA memory allocated was {memtrace.peak} GB")
         print(f"Max CUDA memory reserved was {memtrace.max_reserved} GB")
+        print(f"Peak active CUDA memory was {memtrace.peak_active_gb} GB")
         print(f"Cuda Malloc retires : {memtrace.cuda_malloc_retires}")
         print(f"CPU Total Peak Memory consumed during the train (max): {memtrace.cpu_peaked + memtrace.cpu_begin} GB")
         
@@ -151,7 +151,7 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
                         )   
                                 
             
-            if local_rank == 0 and eval_epoch_loss < best_val_loss:
+            if eval_epoch_loss < best_val_loss:
                 best_val_loss = eval_epoch_loss
                 print(f"best eval loss on epoch {epoch} is {best_val_loss}")
             val_loss.append(best_val_loss)