# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.

import gc
import psutil
import threading
import torch
from accelerate.utils import is_xpu_available


def byte2gb(x):
    """Convert a byte count to whole GB (2**30 bytes); the fractional part is truncated."""
    return int(x / 2**30)


# This context manager is used to track the peak memory usage of the process
class MemoryTrace:
    def __enter__(self):
        gc.collect()
        if is_xpu_available():
            torch.xpu.empty_cache()
            torch.xpu.reset_max_memory_allocated()  # reset the peak gauge to zero
            self.begin = byte2gb(torch.xpu.memory_allocated())
        elif torch.cuda.is_available():
            torch.cuda.empty_cache()
            torch.cuda.reset_max_memory_allocated()  # reset the peak gauge to zero
            self.begin = byte2gb(torch.cuda.memory_allocated())
        self.process = psutil.Process()
        self.cpu_begin = byte2gb(self.cpu_mem_used())
        self.peak_monitoring = True
        peak_monitor_thread = threading.Thread(target=self.peak_monitor_func)
        peak_monitor_thread.daemon = True
        peak_monitor_thread.start()
        return self
    def cpu_mem_used(self):
        """get resident set size memory for the current process"""
        return self.process.memory_info().rss

    def peak_monitor_func(self):
        self.cpu_peak = -1

        while True:
            self.cpu_peak = max(self.cpu_mem_used(), self.cpu_peak)

            # can't sleep or will not catch the peak right (this comment is here on purpose)
            # time.sleep(0.001) # 1msec

            if not self.peak_monitoring:
                break
    def __exit__(self, *exc):
        self.peak_monitoring = False

        gc.collect()
        if is_xpu_available():
            torch.xpu.empty_cache()
            self.end = byte2gb(torch.xpu.memory_allocated())
            self.peak = byte2gb(torch.xpu.max_memory_allocated())
            xpu_info = torch.xpu.memory_stats()
            self.peak_active_gb = byte2gb(xpu_info["active_bytes.all.peak"])
            self.malloc_retries = xpu_info.get("num_alloc_retries", 0)
            self.m_ooms = xpu_info.get("num_ooms", 0)
            # self.end, self.peak and self.begin are already in GB, so difference directly
            self.used = self.end - self.begin
            self.peaked = self.peak - self.begin
            self.max_reserved = byte2gb(torch.xpu.max_memory_reserved())
        elif torch.cuda.is_available():
            torch.cuda.empty_cache()
            self.end = byte2gb(torch.cuda.memory_allocated())
            self.peak = byte2gb(torch.cuda.max_memory_allocated())
            cuda_info = torch.cuda.memory_stats()
            self.peak_active_gb = byte2gb(cuda_info["active_bytes.all.peak"])
            self.malloc_retries = cuda_info.get("num_alloc_retries", 0)
            self.m_ooms = cuda_info.get("num_ooms", 0)
            # self.end, self.peak and self.begin are already in GB, so difference directly
            self.used = self.end - self.begin
            self.peaked = self.peak - self.begin
            self.max_reserved = byte2gb(torch.cuda.max_memory_reserved())

        # convert CPU readings to GB before differencing so the units match self.cpu_begin
        self.cpu_end = byte2gb(self.cpu_mem_used())
        self.cpu_used = self.cpu_end - self.cpu_begin
        self.cpu_peaked = byte2gb(self.cpu_peak) - self.cpu_begin
        # print(f"delta used/peak {self.used:4d}/{self.peaked:4d}")
    def print_stats(self):
        device_str = None
        if is_xpu_available():
            device_str = "XPU"
        elif torch.cuda.is_available():
            device_str = "CUDA"

        if device_str:
            print(f"Max {device_str} memory allocated was {self.peak} GB")
            print(f"Max {device_str} memory reserved was {self.max_reserved} GB")
            print(f"Peak active {device_str} memory was {self.peak_active_gb} GB")
            print(f"{device_str} malloc retries: {self.malloc_retries}")
        print(f"CPU total peak memory consumed during training (max): {self.cpu_peaked + self.cpu_begin} GB")
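

# Example usage (a minimal sketch, not part of the original module): wrap any
# workload in MemoryTrace and print the collected stats afterwards. The
# __main__ guard and the matrix sizes below are illustrative assumptions only.
if __name__ == "__main__":
    with MemoryTrace() as memtrace:
        # Any workload goes here; a small matmul stands in for a training step.
        device = "cuda" if torch.cuda.is_available() else "cpu"
        x = torch.randn(2048, 2048, device=device)
        y = x @ x  # keep a second allocation live so the peak exceeds the baseline
    memtrace.print_stats()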