瀏覽代碼

adding active mem stat (#44)

Geeta Chauhan 1 年之前
父節點
當前提交
021ed8e312
共有 3 個文件被更改,包括 3 次插入3 次删除
  1. 1 1
      configs/fsdp.py
  2. 1 0
      utils/memory_utils.py
  3. 1 2
      utils/train_utils.py

+ 1 - 1
configs/fsdp.py

@@ -13,7 +13,7 @@ class fsdp_config:
     sharding_strategy: ShardingStrategy = ShardingStrategy.FULL_SHARD
     checkpoint_type: StateDictType = StateDictType.SHARDED_STATE_DICT  # alternatively can use SHARDED_STATE_DICT save one file per rank, and can resize the world-size.
     fsdp_activation_checkpointing: bool=True
-    pure_bf16: bool = True
+    pure_bf16: bool = False
     optimizer: str= "AdamW"
     
     

+ 1 - 0
utils/memory_utils.py

@@ -50,6 +50,7 @@ class MemoryTrace:
         self.end = byte2gb(torch.cuda.memory_allocated())
         self.peak = byte2gb(torch.cuda.max_memory_allocated())
         cuda_info = torch.cuda.memory_stats()
+        self.peak_active_gb = byte2gb(cuda_info["active_bytes.all.peak"])
         self.cuda_malloc_retires = cuda_info.get("num_alloc_retries", 0)
         self.m_cuda_ooms = cuda_info.get("num_ooms", 0)
         self.used = byte2gb(self.end - self.begin)

+ 1 - 2
utils/train_utils.py

@@ -78,13 +78,11 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
             model.train()
             total_loss = 0.0
             data_set_len = 0
-            
             for step, batch in enumerate(tqdm(train_dataloader,colour="blue", desc=f"Training Epoch{epoch}")):
                 for key in batch.keys():
                     if train_config.enable_fsdp:
                         batch[key] = batch[key].to(local_rank)
                     else:
-
                         batch[key] = batch[key].to('cuda:0')              
                 loss = model(**batch).loss
                 loss = loss / gradient_accumulation_steps
@@ -117,6 +115,7 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
         
         print(f"Max CUDA memory allocated was {memtrace.peak} GB")
         print(f"Max CUDA memory reserved was {memtrace.max_reserved} GB")
+        print(f"Peak active CUDA memory was {memtrace.peak_active_gb} GB")
         print(f"Cuda Malloc retires : {memtrace.cuda_malloc_retires}")
         print(f"CPU Total Peak Memory consumed during the train (max): {memtrace.cpu_peaked + memtrace.cpu_begin} GB")