
Make tests run on CPU-only machines

Matthias Reso committed 1 year ago
commit c5a382e509

+ 10 - 6
src/llama_recipes/finetuning.py

@@ -57,8 +57,6 @@ def main(**kwargs):
     # Set the seeds for reproducibility
     if is_xpu_available():
         torch.xpu.manual_seed(train_config.seed)
-    else:
-        torch.cuda.manual_seed(train_config.seed)
     torch.manual_seed(train_config.seed)
     random.seed(train_config.seed)
 
@@ -72,7 +70,7 @@ def main(**kwargs):
     if torch.distributed.is_initialized():
         if is_xpu_available():
             torch.xpu.set_device(local_rank)
-        else:
+        elif torch.cuda.is_available():
             torch.cuda.set_device(local_rank)
         clear_gpu_cache(local_rank)
         setup_environ_flags(rank)
@@ -135,7 +133,7 @@ def main(**kwargs):
         
     hsdp_device_mesh = None
     if fsdp_config.hsdp and fsdp_config.sharding_strategy == ShardingStrategy.HYBRID_SHARD:
-        hsdp_device_mesh = hdsp_device_mesh(replica_group_size=fsdp_config.replica_group_size, sharding_group_size=fsdp_config.sharding_group_size)
+        hsdp_device_mesh = hsdp_device_mesh(replica_group_size=fsdp_config.replica_group_size, sharding_group_size=fsdp_config.sharding_group_size)
         print("HSDP device mesh is ready")
         
     #setting up FSDP if enable_fsdp is enabled
@@ -146,6 +144,12 @@ def main(**kwargs):
 
         mixed_precision_policy, wrapping_policy = get_policies(fsdp_config, rank)
         my_auto_wrapping_policy = fsdp_auto_wrap_policy(model, LlamaDecoderLayer)
+        
+        device_id = 0
+        if is_xpu_available():
+            device_id = torch.xpu.current_device()
+        elif torch.cuda.is_available():
+            device_id = torch.cuda.current_device()
 
         model = FSDP(
             model,
@@ -154,7 +158,7 @@ def main(**kwargs):
             mixed_precision=mixed_precision_policy if not fsdp_config.pure_bf16 else None,
             sharding_strategy=fsdp_config.sharding_strategy,
             device_mesh=hsdp_device_mesh,
-            device_id=torch.xpu.current_device() if is_xpu_available() else torch.cuda.current_device(),
+            device_id=device_id,
             limit_all_gathers=True,
             sync_module_states=train_config.low_cpu_fsdp,
             param_init_fn=lambda module: module.to_empty(device=torch.device("cuda"), recurse=False)
@@ -165,7 +169,7 @@ def main(**kwargs):
     elif not train_config.quantization and not train_config.enable_fsdp:
         if is_xpu_available():
             model.to("xpu:0")
-        else:
+        elif torch.cuda.is_available():
             model.to("cuda")
 
     dataset_config = generate_dataset_config(train_config, kwargs)
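
The pattern applied throughout finetuning.py above is uniform: every torch.cuda call is guarded by torch.cuda.is_available() so the script falls back to CPU instead of raising on machines without an accelerator. A minimal sketch of the resulting device resolution, where the helper name resolve_device_id is illustrative and not part of the patch:

    import torch
    from accelerate.utils import is_xpu_available

    def resolve_device_id() -> int:
        # Prefer XPU, then CUDA; fall back to index 0 on CPU-only machines,
        # mirroring the device_id default introduced in the hunk above.
        if is_xpu_available():
            return torch.xpu.current_device()
        if torch.cuda.is_available():
            return torch.cuda.current_device()
        return 0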

+ 20 - 6
src/llama_recipes/utils/memory_utils.py

@@ -56,21 +56,21 @@ class MemoryTrace:
             self.peak = byte2gb(torch.xpu.max_memory_allocated())
             xpu_info = torch.xpu.memory_stats()
             self.peak_active_gb = byte2gb(xpu_info["active_bytes.all.peak"])
-            self.xpu_malloc_retires = xpu_info.get("num_alloc_retries", 0)
+            self.malloc_retries = xpu_info.get("num_alloc_retries", 0)
             self.peak_active_gb = byte2gb(xpu_info["active_bytes.all.peak"])
-            self.m_xpu_ooms = xpu_info.get("num_ooms", 0)
+            self.m_ooms = xpu_info.get("num_ooms", 0)
             self.used = byte2gb(self.end - self.begin)
             self.peaked = byte2gb(self.peak - self.begin)
             self.max_reserved = byte2gb(torch.xpu.max_memory_reserved())
-        else:
+        elif torch.cuda.is_available():
             torch.cuda.empty_cache()
             self.end = byte2gb(torch.cuda.memory_allocated())
             self.peak = byte2gb(torch.cuda.max_memory_allocated())
             cuda_info = torch.cuda.memory_stats()
             self.peak_active_gb = byte2gb(cuda_info["active_bytes.all.peak"])
-            self.cuda_malloc_retires = cuda_info.get("num_alloc_retries", 0)
+            self.malloc_retries = cuda_info.get("num_alloc_retries", 0)
             self.peak_active_gb = byte2gb(cuda_info["active_bytes.all.peak"])
-            self.m_cuda_ooms = cuda_info.get("num_ooms", 0)
+            self.m_ooms = cuda_info.get("num_ooms", 0)
             self.used = byte2gb(self.end - self.begin)
             self.peaked = byte2gb(self.peak - self.begin)
             self.max_reserved = byte2gb(torch.cuda.max_memory_reserved())
@@ -78,4 +78,18 @@ class MemoryTrace:
         self.cpu_end = self.cpu_mem_used()
         self.cpu_used = byte2gb(self.cpu_end - self.cpu_begin)
         self.cpu_peaked = byte2gb(self.cpu_peak - self.cpu_begin)
-        # print(f"delta used/peak {self.used:4d}/{self.peaked:4d}")
+        # print(f"delta used/peak {self.used:4d}/{self.peaked:4d}")
+        
+    def print_stats(self):
+        device_str = None
+        if is_xpu_available():
+            device_str = "XPU"
+        elif torch.cuda.is_available():
+            device_str = "CUDA"
+            
+        if device_str:
+            print(f"Max {device_str} memory allocated was {self.peak} GB")
+            print(f"Max {device_str} memory reserved was {self.max_reserved} GB")
+            print(f"Peak active {device_str} memory was {self.peak_active_gb} GB")
+            print(f"{device_str} Malloc retries : {self.malloc_retries}")
+        print(f"CPU Total Peak Memory consumed during the train (max): {self.cpu_peaked + self.cpu_begin} GB")
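
The new print_stats() method consolidates the memory reporting that train_utils.py used to duplicate, and the renamed attributes (malloc_retries, m_ooms) are device-agnostic, so callers no longer need to know whether the run used XPU or CUDA. A minimal usage sketch, assuming the package import path implied by the file header above:

    from llama_recipes.utils.memory_utils import MemoryTrace

    with MemoryTrace() as memtrace:
        _ = sum(i * i for i in range(1_000_000))  # placeholder workload

    # On a CPU-only machine device_str stays None inside print_stats(),
    # so only the CPU peak-memory line is printed.
    memtrace.print_stats()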

+ 2 - 25
src/llama_recipes/utils/train_utils.py

@@ -154,31 +154,8 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
         train_prep.append(float(train_perplexity))
         train_loss.append(float(train_epoch_loss))
         
-        if train_config.enable_fsdp:
-            if rank==0:
-                if is_xpu_available():
-                    print(f"Max XPU memory allocated was {memtrace.peak} GB")
-                    print(f"Max XPU memory reserved was {memtrace.max_reserved} GB")
-                    print(f"Peak active XPU memory was {memtrace.peak_active_gb} GB")
-                    print(f"Xpu Malloc retires : {memtrace.xpu_malloc_retires}")
-                else:
-                    print(f"Max CUDA memory allocated was {memtrace.peak} GB")
-                    print(f"Max CUDA memory reserved was {memtrace.max_reserved} GB")
-                    print(f"Peak active CUDA memory was {memtrace.peak_active_gb} GB")
-                    print(f"Cuda Malloc retires : {memtrace.cuda_malloc_retires}")
-                print(f"CPU Total Peak Memory consumed during the train (max): {memtrace.cpu_peaked + memtrace.cpu_begin} GB")
-        else:
-            if is_xpu_available():
-                print(f"Max XPU memory allocated was {memtrace.peak} GB")
-                print(f"Max XPU memory reserved was {memtrace.max_reserved} GB")
-                print(f"Peak active XPU memory was {memtrace.peak_active_gb} GB")
-                print(f"Xpu Malloc retires : {memtrace.xpu_malloc_retires}")
-            else:
-                print(f"Max CUDA memory allocated was {memtrace.peak} GB")
-                print(f"Max CUDA memory reserved was {memtrace.max_reserved} GB")
-                print(f"Peak active CUDA memory was {memtrace.peak_active_gb} GB")
-                print(f"Cuda Malloc retires : {memtrace.cuda_malloc_retires}")
-            print(f"CPU Total Peak Memory consumed during the train (max): {memtrace.cpu_peaked + memtrace.cpu_begin} GB")
+        if not train_config.enable_fsdp or rank==0:
+            memtrace.print_stats()
 
         # Update the learning rate as needed
         lr_scheduler.step()
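
With the device-specific prints moved into MemoryTrace.print_stats(), the two nested branches collapse into one guard: report on every process when FSDP is disabled, and only on rank 0 when it is enabled. Written out long-hand for clarity (illustrative only, not part of the patch):

    report_on_this_process = (not train_config.enable_fsdp) or rank == 0
    if report_on_this_process:
        memtrace.print_stats()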

+ 16 - 4
tests/test_finetuning.py

@@ -6,7 +6,6 @@ from pytest import approx
 from unittest.mock import patch
 
 import torch
-from torch.nn import Linear
 from torch.optim import AdamW
 from torch.utils.data.dataloader import DataLoader
 from torch.utils.data.sampler import BatchSampler
@@ -45,7 +44,11 @@ def test_finetuning_no_validation(step_lr, optimizer, get_dataset, tokenizer, ge
     assert isinstance(train_dataloader, DataLoader)
     assert eval_dataloader is None
 
-    assert get_model.return_value.to.call_args.args[0] == "cuda"
+    if torch.cuda.is_available():
+        assert get_model.return_value.to.call_count == 1
+        assert get_model.return_value.to.call_args.args[0] == "cuda"
+    else:
+        assert get_model.return_value.to.call_count == 0
 
 
 @patch('llama_recipes.finetuning.train')
@@ -69,7 +72,11 @@ def test_finetuning_with_validation(step_lr, optimizer, get_dataset, tokenizer,
     assert isinstance(train_dataloader, DataLoader)
     assert isinstance(eval_dataloader, DataLoader)
 
-    assert get_model.return_value.to.call_args.args[0] == "cuda"
+    if torch.cuda.is_available():
+        assert get_model.return_value.to.call_count == 1
+        assert get_model.return_value.to.call_args.args[0] == "cuda"
+    else:
+        assert get_model.return_value.to.call_count == 0
 
 
 @patch('llama_recipes.finetuning.train')
@@ -87,7 +94,12 @@ def test_finetuning_peft(step_lr, optimizer, get_peft_model, gen_peft_config, ge
 
     main(**kwargs)
 
-    assert get_peft_model.return_value.to.call_args.args[0] == "cuda"
+    if torch.cuda.is_available():
+        assert get_peft_model.return_value.to.call_count == 1
+        assert get_peft_model.return_value.to.call_args.args[0] == "cuda"
+    else:
+        assert get_peft_model.return_value.to.call_count == 0
+    
     assert get_peft_model.return_value.print_trainable_parameters.call_count == 1
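
The same availability branch now appears in three tests; under the assumption that more tests will need it, it could be factored into a small helper such as the hypothetical assert_moved_to_cuda_if_available below (not part of the commit):

    import torch

    def assert_moved_to_cuda_if_available(mocked_factory):
        # The mocked model should be moved to CUDA exactly once when CUDA
        # exists, and never on a CPU-only machine.
        if torch.cuda.is_available():
            assert mocked_factory.return_value.to.call_count == 1
            assert mocked_factory.return_value.to.call_args.args[0] == "cuda"
        else:
            assert mocked_factory.return_value.to.call_count == 0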
 
 

+ 0 - 83
utils/memory_utils.py

@@ -1,83 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
-import gc
-import os
-import sys
-import threading
-
-import numpy as np
-import psutil
-import torch
-from accelerate.utils import is_xpu_available
-
-def byte2gb(x):
-    return int(x / 2**30)
-# This context manager is used to track the peak memory usage of the process
-class MemoryTrace:
-    def __enter__(self):
-        gc.collect()
-        if is_xpu_available():
-            torch.xpu.empty_cache()
-            torch.xpu.reset_max_memory_allocated()   # reset the peak gauge to zero
-            self.begin = byte2gb(torch.xpu.memory_allocated())
-        elif torch.cuda.is_available():
-            torch.cuda.empty_cache()
-            torch.cuda.reset_max_memory_allocated()  # reset the peak gauge to zero
-            self.begin = byte2gb(torch.cuda.memory_allocated())
-        self.process = psutil.Process()
-        self.cpu_begin = byte2gb(self.cpu_mem_used())
-        self.peak_monitoring = True
-        peak_monitor_thread = threading.Thread(target=self.peak_monitor_func)
-        peak_monitor_thread.daemon = True
-        peak_monitor_thread.start()
-        return self
-
-    def cpu_mem_used(self):
-        """get resident set size memory for the current process"""
-        return self.process.memory_info().rss
-
-    def peak_monitor_func(self):
-        self.cpu_peak = -1
-
-        while True:
-            self.cpu_peak = max(self.cpu_mem_used(), self.cpu_peak)
-
-            # can't sleep or will not catch the peak right (this comment is here on purpose)
-            # time.sleep(0.001) # 1msec
-
-            if not self.peak_monitoring:
-                break
-
-    def __exit__(self, *exc):
-        self.peak_monitoring = False
-
-        gc.collect()
-        if is_xpu_available():
-            torch.xpu.empty_cache()
-            self.end = byte2gb(torch.xpu.memory_allocated())
-            self.peak = byte2gb(torch.xpu.max_memory_allocated())
-            xpu_info = torch.xpu.memory_stats()
-            self.peak_active_gb = byte2gb(xpu_info["active_bytes.all.peak"])
-            self.xpu_malloc_retires = xpu_info.get("num_alloc_retries", 0)
-            self.peak_active_gb = byte2gb(xpu_info["active_bytes.all.peak"])
-            self.m_xpu_ooms = xpu_info.get("num_ooms", 0)
-            self.used = byte2gb(self.end - self.begin)
-            self.peaked = byte2gb(self.peak - self.begin)
-            self.max_reserved = byte2gb(torch.xpu.max_memory_reserved())
-        else:
-            torch.cuda.empty_cache()
-            self.end = byte2gb(torch.cuda.memory_allocated())
-            self.peak = byte2gb(torch.cuda.max_memory_allocated())
-            cuda_info = torch.cuda.memory_stats()
-            self.peak_active_gb = byte2gb(cuda_info["active_bytes.all.peak"])
-            self.cuda_malloc_retires = cuda_info.get("num_alloc_retries", 0)
-            self.peak_active_gb = byte2gb(cuda_info["active_bytes.all.peak"])
-            self.m_cuda_ooms = cuda_info.get("num_ooms", 0)
-            self.used = byte2gb(self.end - self.begin)
-            self.peaked = byte2gb(self.peak - self.begin)
-            self.max_reserved = byte2gb(torch.cuda.max_memory_reserved())
-
-        self.cpu_end = self.cpu_mem_used()
-        self.cpu_used = byte2gb(self.cpu_end - self.cpu_begin)
-        self.cpu_peaked = byte2gb(self.cpu_peak - self.cpu_begin)
-        # print(f"delta used/peak {self.used:4d}/{self.peaked:4d}")