memory_utils.py

# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.

import gc
import threading

import psutil
import torch

def byte2gb(x):
    """Convert a byte count to whole gigabytes (truncated toward zero)."""
    return int(x / 2**30)

# This context manager is used to track the peak memory usage of the process
class MemoryTrace:
    def __enter__(self):
        gc.collect()
        torch.cuda.empty_cache()
        torch.cuda.reset_peak_memory_stats()  # reset the peak gauge to zero (replaces the deprecated reset_max_memory_allocated)
        self.begin = byte2gb(torch.cuda.memory_allocated())
        self.process = psutil.Process()
        self.cpu_begin = byte2gb(self.cpu_mem_used())
        self.peak_monitoring = True
        peak_monitor_thread = threading.Thread(target=self.peak_monitor_func)
        peak_monitor_thread.daemon = True  # don't block interpreter exit if the monitor is still running
        peak_monitor_thread.start()
        return self

    def cpu_mem_used(self):
        """Get resident set size (RSS) memory for the current process."""
        return self.process.memory_info().rss

    def peak_monitor_func(self):
        self.cpu_peak = -1
        while True:
            self.cpu_peak = max(self.cpu_mem_used(), self.cpu_peak)
            # can't sleep or will not catch the peak right (this comment is here on purpose)
            # time.sleep(0.001)  # 1msec
            if not self.peak_monitoring:
                break

    def __exit__(self, *exc):
        self.peak_monitoring = False
        gc.collect()
        torch.cuda.empty_cache()
        self.end = byte2gb(torch.cuda.memory_allocated())
        self.peak = byte2gb(torch.cuda.max_memory_allocated())
        cuda_info = torch.cuda.memory_stats()
        self.peak_active_gb = byte2gb(cuda_info["active_bytes.all.peak"])
        self.cuda_malloc_retries = cuda_info.get("num_alloc_retries", 0)
        self.m_cuda_ooms = cuda_info.get("num_ooms", 0)
        self.used = self.end - self.begin  # both values are already in GB
        self.peaked = self.peak - self.begin  # both values are already in GB
        self.max_reserved = byte2gb(torch.cuda.max_memory_reserved())
        self.cpu_end = byte2gb(self.cpu_mem_used())
        self.cpu_used = self.cpu_end - self.cpu_begin  # both values are already in GB
        self.cpu_peaked = byte2gb(self.cpu_peak) - self.cpu_begin
        # print(f"delta used/peak {self.used:4d}/{self.peaked:4d}")
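

# Minimal usage sketch (illustrative, not part of the original module): wrap a
# CUDA workload in MemoryTrace to report GPU/CPU memory deltas in GB.
# `run_training_step` is a hypothetical placeholder for any GPU workload.
#
#     with MemoryTrace() as memtrace:
#         run_training_step()
#     print(f"GPU used/peaked: {memtrace.used}GB / {memtrace.peaked}GB")
#     print(f"Max CUDA reserved: {memtrace.max_reserved}GB")
#     print(f"CPU used/peaked: {memtrace.cpu_used}GB / {memtrace.cpu_peaked}GB")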