# memory_utils.py
  1. # Copyright (c) Meta Platforms, Inc. and affiliates.
  2. # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
  3. import gc
  4. import os
  5. import sys
  6. import threading
  7. import numpy as np
  8. import psutil
  9. import torch
  10. from accelerate.utils import is_xpu_available
  11. def byte2gb(x):
  12. return int(x / 2**30)
  13. # This context manager is used to track the peak memory usage of the process
  14. class MemoryTrace:
  15. def __enter__(self):
  16. gc.collect()
  17. if is_xpu_available():
  18. torch.xpu.empty_cache()
  19. torch.xpu.reset_max_memory_allocated() # reset the peak gauge to zero
  20. self.begin = byte2gb(torch.xpu.memory_allocated())
  21. elif torch.cuda.is_available():
  22. torch.cuda.empty_cache()
  23. torch.cuda.reset_max_memory_allocated() # reset the peak gauge to zero
  24. self.begin = byte2gb(torch.cuda.memory_allocated())
  25. self.process = psutil.Process()
  26. self.cpu_begin = byte2gb(self.cpu_mem_used())
  27. self.peak_monitoring = True
  28. peak_monitor_thread = threading.Thread(target=self.peak_monitor_func)
  29. peak_monitor_thread.daemon = True
  30. peak_monitor_thread.start()
  31. return self
  32. def cpu_mem_used(self):
  33. """get resident set size memory for the current process"""
  34. return self.process.memory_info().rss
  35. def peak_monitor_func(self):
  36. self.cpu_peak = -1
  37. while True:
  38. self.cpu_peak = max(self.cpu_mem_used(), self.cpu_peak)
  39. # can't sleep or will not catch the peak right (this comment is here on purpose)
  40. # time.sleep(0.001) # 1msec
  41. if not self.peak_monitoring:
  42. break
  43. def __exit__(self, *exc):
  44. self.peak_monitoring = False
  45. gc.collect()
  46. if is_xpu_available():
  47. torch.xpu.empty_cache()
  48. self.end = byte2gb(torch.xpu.memory_allocated())
  49. self.peak = byte2gb(torch.xpu.max_memory_allocated())
  50. xpu_info = torch.xpu.memory_stats()
  51. self.peak_active_gb = byte2gb(xpu_info["active_bytes.all.peak"])
  52. self.xpu_malloc_retires = xpu_info.get("num_alloc_retries", 0)
  53. self.peak_active_gb = byte2gb(xpu_info["active_bytes.all.peak"])
  54. self.m_xpu_ooms = xpu_info.get("num_ooms", 0)
  55. self.used = byte2gb(self.end - self.begin)
  56. self.peaked = byte2gb(self.peak - self.begin)
  57. self.max_reserved = byte2gb(torch.xpu.max_memory_reserved())
  58. else:
  59. torch.cuda.empty_cache()
  60. self.end = byte2gb(torch.cuda.memory_allocated())
  61. self.peak = byte2gb(torch.cuda.max_memory_allocated())
  62. cuda_info = torch.cuda.memory_stats()
  63. self.peak_active_gb = byte2gb(cuda_info["active_bytes.all.peak"])
  64. self.cuda_malloc_retires = cuda_info.get("num_alloc_retries", 0)
  65. self.peak_active_gb = byte2gb(cuda_info["active_bytes.all.peak"])
  66. self.m_cuda_ooms = cuda_info.get("num_ooms", 0)
  67. self.used = byte2gb(self.end - self.begin)
  68. self.peaked = byte2gb(self.peak - self.begin)
  69. self.max_reserved = byte2gb(torch.cuda.max_memory_reserved())
  70. self.cpu_end = self.cpu_mem_used()
  71. self.cpu_used = byte2gb(self.cpu_end - self.cpu_begin)
  72. self.cpu_peaked = byte2gb(self.cpu_peak - self.cpu_begin)
  73. # print(f"delta used/peak {self.used:4d}/{self.peaked:4d}")