checkpoint_handler.py

# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.

from pathlib import Path
from datetime import datetime
import torch
import time

from torch.distributed.fsdp import (
    FullyShardedDataParallel as FSDP,
    StateDictType,
    FullStateDictConfig,  # general model non-sharded, non-flattened params
    LocalStateDictConfig,  # flattened params, usable only by FSDP
    # ShardedStateDictConfig, # un-flattened param but shards, usable by other parallel schemes.
)

from torch.distributed._shard.checkpoint import (
    FileSystemReader,
    FileSystemWriter,
    save_state_dict,
    load_state_dict,
)
from torch.distributed.checkpoint.default_planner import (
    DefaultSavePlanner,
    DefaultLoadPlanner,
)

from torch.distributed.fsdp.fully_sharded_data_parallel import StateDictType
import torch.distributed._shard.checkpoint as dist_cp
import torch.distributed as dist


def get_date_of_run():
    """create date and time for file save uniqueness
    example: 2022-05-07-08:31:12_PM
    """
    date_of_run = datetime.now().strftime("%Y-%m-%d-%I:%M:%S_%p")
    print(f"--> current date and time of run = {date_of_run}")
    return date_of_run


# create singleton saving policies to avoid making over and over
fullstate_save_policy = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)


def load_model_sharded(model, rank, cfg):
    """load model weights via sharded_state_dict from the distributed checkpoint folder
    (expects an FSDP-wrapped model)"""
    # torch.manual_seed(103)
    folder_name = (
        cfg.dist_checkpoint_root_folder
        + "/"
        + cfg.dist_checkpoint_folder
        + "-"
        + cfg.model_name
    )

    load_dir = Path.cwd() / folder_name

    if not load_dir.exists():
        if rank == 0:
            print("No sharded_state_dict checkpoint directory found...skipping")
        return
    if rank == 0:
        print(f"loading model from model path: {load_dir} ")
    reader = FileSystemReader(load_dir)

    with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT):
        checkpoint = model.state_dict()
        if rank == 0:
            ck = checkpoint.keys()
            print(f" checkpoint key len = {len(ck)} and \n keys = {ck}")

        dist_cp.load_state_dict(
            state_dict=checkpoint,
            storage_reader=reader,
        )
        if rank == 0:
            print("checkpoint after load_state_dict()")
            ck = checkpoint.keys()
            print(f" checkpoint key len = {len(ck)} and \n keys = {ck}")
        model.load_state_dict(checkpoint)
    if rank == 0:
        print(f"Sharded state checkpoint loaded from {load_dir}")


def save_model_and_optimizer_sharded(model, rank, cfg, optim=None):
    """save model and optimizer via sharded_state_dict to save_dir"""
    folder_name = (
        cfg.dist_checkpoint_root_folder
        + "/"
        + cfg.dist_checkpoint_folder
        + "-"
        + cfg.model_name
    )

    save_dir = Path.cwd() / folder_name
    if rank == 0:
        print(f"Saving model to {save_dir}")

    distributed_writer = dist_cp.FileSystemWriter(
        save_dir,
    )
    t0 = time.perf_counter()

    with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT):
        state_dict = {"model": model.state_dict()}
        if optim is not None:
            state_dict["optim"] = FSDP.optim_state_dict(model, optim)

        dist_cp.save_state_dict(
            state_dict=state_dict,
            storage_writer=distributed_writer,
            planner=DefaultSavePlanner(),
        )
    dist.barrier()
    t1 = time.perf_counter()
    if rank == 0:
        print(f"Sharded state checkpoint saved to {save_dir}")
        print(f"Checkpoint Time = {t1-t0:.4f}\n")
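
# Usage sketch (illustrative, not part of the original module): inside an FSDP training
# loop, assuming `model` is FSDP-wrapped, `optimizer` is its optimizer, `rank` is the
# process rank, and `train_config` is the same hypothetical config object as above, a
# periodic save of model plus optimizer might look like:
#
#     save_model_and_optimizer_sharded(model, rank, train_config, optim=optimizer)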


def save_model_checkpoint(
    model,
    optimizer,
    rank,
    cfg,
    epoch=1,
):
    """saving model via rank0 cpu streaming and full_state_dict"""

    with FSDP.state_dict_type(
        model, StateDictType.FULL_STATE_DICT, fullstate_save_policy
    ):
        cpu_state = model.state_dict()

        print(f"saving process: rank {rank} done w model state_dict\n")

    if rank == 0:
        print("--> saving model ...")
        # create save path
        folder_name = (
            cfg.dist_checkpoint_root_folder
            + "/"
            + cfg.dist_checkpoint_folder
            + "-"
            + cfg.model_name
        )
        save_dir = Path.cwd() / folder_name
        save_dir.mkdir(parents=True, exist_ok=True)
        save_name = cfg.model_name + "-" + str(epoch) + ".pt"
        save_full_path = str(save_dir) + "/" + save_name

        # save model
        torch.save(cpu_state, save_full_path)

        print(f"model checkpoint saved for epoch {epoch} at {save_full_path}\n")


def load_model_checkpoint(model, rank, cfg):
    """load local checkpoint to rank0 cpu
    must be called * before * passing to FSDP"""

    if rank != 0:
        return

    # where is the checkpoint at...
    full_state_dict_model_path = (
        Path.cwd() / cfg.checkpoint_folder / cfg.checkpoint_model_filename
    )
    # is it present...
    if not full_state_dict_model_path.is_file():
        print(
            f"model checkpoint {full_state_dict_model_path} not present. Returning..."
        )
        return

    model_checkpoint = torch.load(full_state_dict_model_path)
    # integrate into loaded model
    model.load_state_dict(model_checkpoint)

    print("model checkpoint loaded to rank0 cpu")


def save_optimizer_checkpoint(model, optimizer, rank, cfg, epoch=1):
    """save optimizer state via full state dict"""

    print(f"--> optim state call on rank {rank}\n")

    # pull all sharded optimizer states to rank0 cpu...
    optim_state = FSDP.full_optim_state_dict(model, optimizer)

    print(f"optim state dict ready on {rank} and len of {len(optim_state)}\n")

    if rank == 0:
        folder_name = (
            cfg.dist_checkpoint_root_folder
            + "/"
            + cfg.dist_checkpoint_folder
            + "-"
            + cfg.model_name
        )
        save_dir = Path.cwd() / folder_name
        save_dir.mkdir(parents=True, exist_ok=True)

        opt_save_name = "optimizer" + "-" + cfg.model_name + "-" + str(epoch) + ".pt"
        opt_save_full_path = save_dir / opt_save_name

        print("--> saving optimizer state...")
        torch.save(optim_state, opt_save_full_path)
        print(f"--> saved {opt_save_full_path} to disk")


def load_optimizer_checkpoint(model, optimizer_checkpoint_path, rank):
    """load an fsdp optimizer full_state checkpoint using scatter method
    this ensures only rank 0 loads the optimizer state dict and scatters to other ranks
    """

    if not optimizer_checkpoint_path.is_file():
        print(
            f"warning - optimizer checkpoint not present {optimizer_checkpoint_path}. Returning. "
        )
        return

    full_osd = None

    if rank == 0:
        full_osd = torch.load(optimizer_checkpoint_path)

    # called from all ranks, though only rank0 has a valid param for full_osd
    sharded_osd = FSDP.scatter_full_optim_state_dict(full_osd, model)

    print(f"optimizer shard loaded on rank {rank}")


def load_sharded_model_single_gpu(model, model_path):
    """load a sharded (distributed) checkpoint into a regular model on a single process"""
    reader = FileSystemReader(model_path)

    state_dict = {
        "model": model.state_dict()
    }

    dist_cp.load_state_dict(
        state_dict=state_dict,
        storage_reader=reader,
        no_dist=True,
    )

    model.load_state_dict(state_dict["model"])

    print(f"Sharded state checkpoint loaded from {model_path}")
    return model
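
# Usage sketch (illustrative, not part of the original module): useful for turning a
# sharded FSDP checkpoint folder back into a single consolidated model, e.g. for
# inference or for re-saving in another format. The model id and checkpoint folder below
# are hypothetical:
#
#     from transformers import LlamaForCausalLM
#     model = LlamaForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")
#     model = load_sharded_model_single_gpu(model, "model_checkpoints/fine-tuned-llama")
#     model.save_pretrained("consolidated_model_path")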