checkpoint_handler.py

# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.

from pathlib import Path
from datetime import datetime
import torch
import time

from torch.distributed.fsdp import (
    FullyShardedDataParallel as FSDP,
    StateDictType,
    FullStateDictConfig,  # general model non-sharded, non-flattened params
    LocalStateDictConfig,  # flattened params, usable only by FSDP
    # ShardedStateDictConfig,  # un-flattened params but sharded, usable by other parallel schemes.
)

from torch.distributed._shard.checkpoint import (
    FileSystemReader,
    FileSystemWriter,
    save_state_dict,
    load_state_dict,
)
from torch.distributed.checkpoint.default_planner import (
    DefaultSavePlanner,
    DefaultLoadPlanner,
)

import torch.distributed._shard.checkpoint as dist_cp
import torch.distributed as dist


def get_date_of_run():
    """create a date/time string for file-save uniqueness
    example: 2022-05-07-08:31:12_PM
    """
    date_of_run = datetime.now().strftime("%Y-%m-%d-%I:%M:%S_%p")
    print(f"--> current date and time of run = {date_of_run}")
    return date_of_run


# create singleton saving policies to avoid remaking them over and over
fullstate_save_policy = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)


def load_model_sharded(model, rank, cfg, verbose=True):
    """load a model saved via sharded_state_dict from the distributed checkpoint folder"""
    # torch.manual_seed(103)
    folder_name = (
        cfg.dist_checkpoint_root_folder
        + "/"
        + cfg.dist_checkpoint_folder
        + "-"
        + cfg.model_name
    )

    load_dir = Path.cwd() / folder_name

    if not load_dir.exists():
        if rank == 0:
            print(f"No sharded_state_dict checkpoint directory found...skipping")
        return
    if rank == 0:
        print(f"loading model from model path: {load_dir} ")
    reader = FileSystemReader(load_dir)

    with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT):
        checkpoint = model.state_dict()
        if rank == 0:
            ck = checkpoint.keys()
            print(f" checkpoint key len = {len(ck)} and \n keys = {ck}")

        dist_cp.load_state_dict(
            state_dict=checkpoint,
            storage_reader=reader,
        )
        if rank == 0:
            print(f"checkpoint after load_state_dict()")
            ck = checkpoint.keys()
            print(f" checkpoint key len = {len(ck)} and \n keys = {ck}")
        model.load_state_dict(checkpoint)
    if rank == 0:
        print(f"Sharded state checkpoint loaded from {load_dir}")
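
# Usage sketch (illustrative, not part of the original flow): load_model_sharded is a
# collective call -- every rank restores from the same sharded checkpoint folder after
# the model has been wrapped with FSDP. The `cfg` fields referenced above
# (dist_checkpoint_root_folder, dist_checkpoint_folder, model_name) are assumed to
# come from the training config, e.g.:
#
#   model = FSDP(model)                      # wrap first; hypothetical setup
#   load_model_sharded(model, rank, cfg)     # then restore the sharded weights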


def save_model_and_optimizer_sharded(model, rank, cfg, optim=None, verbose=True):
    """save model and optimizer via sharded_state_dict to save_dir"""
    folder_name = (
        cfg.dist_checkpoint_root_folder
        + "/"
        + cfg.dist_checkpoint_folder
        + "-"
        + cfg.model_name
    )

    save_dir = Path.cwd() / folder_name
    if rank == 0:
        print(f"Saving model to {save_dir}")

    distributed_writer = dist_cp.FileSystemWriter(
        save_dir,
    )
    t0 = time.perf_counter()

    with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT):
        state_dict = {"model": model.state_dict()}
        if optim is not None:
            state_dict["optim"] = FSDP.optim_state_dict(model, optim)

        dist_cp.save_state_dict(
            state_dict=state_dict,
            storage_writer=distributed_writer,
            planner=DefaultSavePlanner(),
        )
    dist.barrier()
    t1 = time.perf_counter()
    if rank == 0:
        print(f"Sharded state checkpoint saved to {save_dir}")
        print(f"Checkpoint Time = {t1-t0:.4f}\n")
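
# Usage sketch (illustrative): called from all ranks inside the training loop, e.g.
# once per epoch; passing the optimizer is optional and adds its sharded state under
# the "optim" key. `optimizer` here is a hypothetical torch.optim instance:
#
#   save_model_and_optimizer_sharded(model, rank, cfg, optim=optimizer)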


def save_model_checkpoint(
    model,
    optimizer,
    rank,
    cfg,
    epoch=1,
):
    """saving model via rank0 cpu streaming and full_state_dict"""
    with FSDP.state_dict_type(
        model, StateDictType.FULL_STATE_DICT, fullstate_save_policy
    ):
        cpu_state = model.state_dict()

        print(f"saving process: rank {rank} done w model state_dict\n")

    if rank == 0:
        print(f"--> saving model ...")
        # create save path
        save_dir = Path.cwd() / cfg.checkpoint_folder
        save_dir.mkdir(parents=True, exist_ok=True)
        save_name = cfg.model_name + "-" + str(epoch) + ".pt"
        save_full_path = str(save_dir) + "/" + save_name

        # save model
        torch.save(cpu_state, save_full_path)

        if cfg.verbose:
            print(f"model checkpoint saved for epoch {epoch} at {save_full_path}\n")


def load_model_checkpoint(model, rank, cfg, verbose=True):
    """load local checkpoint to rank0 cpu
    must be called *before* passing to FSDP"""
    if rank != 0:
        return

    # where is the checkpoint at...
    full_state_dict_model_path = (
        Path.cwd() / cfg.checkpoint_folder / cfg.checkpoint_model_filename
    )
    # is it present...
    if not full_state_dict_model_path.is_file():
        print(
            f"model checkpoint {full_state_dict_model_path} not present. Returning..."
        )
        return

    model_checkpoint = torch.load(full_state_dict_model_path)
    # integrate into loaded model
    model.load_state_dict(model_checkpoint)

    if cfg.verbose:
        print(f"model checkpoint loaded to rank0 cpu")
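
# Usage sketch (illustrative): because the full_state_dict is loaded onto rank0 cpu,
# this must happen before FSDP wrapping, whereas the sharded load/save functions above
# operate on an already-wrapped model. `build_model` is a hypothetical constructor:
#
#   model = build_model(cfg)                 # plain nn.Module
#   load_model_checkpoint(model, rank, cfg)  # populate weights on rank 0 only
#   model = FSDP(model)                      # wrap afterwards on every rank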


def save_optimizer_checkpoint(model, optimizer, rank, cfg, epoch=1):
    """save optimizer state via full state dict"""
    print(f"--> optim state call on rank {rank}\n")

    # pull all sharded optimizer states to rank0 cpu...
    optim_state = FSDP.full_optim_state_dict(model, optimizer)

    if cfg.verbose:
        print(f"optim state dict ready on {rank} and len of {len(optim_state)}\n")

    if rank == 0:
        save_dir = Path.cwd() / cfg.checkpoint_folder
        save_dir.mkdir(parents=True, exist_ok=True)

        opt_save_name = (
            cfg.optimizer_name + "-" + cfg.model_name + "-" + str(epoch) + ".pt"
        )
        opt_save_full_path = save_dir / opt_save_name

        print(f"--> saving optimizer state...")
        torch.save(optim_state, opt_save_full_path)
        print(f"--> saved {opt_save_full_path} to disk")


def load_optimizer_checkpoint(model, optimizer, rank, cfg):
    """load an FSDP optimizer full_state checkpoint using the scatter method
    this ensures only rank 0 loads the optimizer state dict and scatters it to the other ranks
    """
    opt_file_path = Path.cwd() / cfg.checkpoint_folder / cfg.optimizer_checkpoint_file

    if not opt_file_path.is_file():
        print(
            f"warning - optimizer checkpoint not present {opt_file_path}. Returning. "
        )
        return

    full_osd = None

    if rank == 0:
        full_osd = torch.load(opt_file_path)

        if cfg.verbose:
            print(f"loaded full osd on rank 0")

    # called from all ranks, though only rank0 has a valid param for full_osd
    sharded_osd = FSDP.scatter_full_optim_state_dict(full_osd, model)

    # load the re-sharded state into this rank's optimizer
    optimizer.load_state_dict(sharded_osd)

    if cfg.verbose:
        print(f"optimizer shard loaded on rank {rank}")
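
# Usage sketch (illustrative): all ranks call this after FSDP wrapping and after the
# optimizer has been constructed; rank 0 reads the full optim state from disk and
# scatter_full_optim_state_dict hands each rank only its shard:
#
#   load_optimizer_checkpoint(model, optimizer, rank, cfg)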


def load_distributed_model_checkpoint(model, rank, cfg):
    """load a distributed (local_state_dict) checkpoint into an FSDP-wrapped model"""
    if cfg.checkpoint_type == StateDictType.LOCAL_STATE_DICT:
        print(f"loading distributed checkpoint, rank {rank}...")
        folder_name = (
            cfg.dist_checkpoint_root_folder
            + "/"
            + cfg.dist_checkpoint_folder
            + "-"
            + cfg.model_name
        )

        checkdir = Path.cwd() / folder_name

        if not checkdir.exists():
            if rank == 0:
                print(f"No checkpoint directory found...skipping")
            return

        reader = FileSystemReader(checkdir)

        with FSDP.state_dict_type(
            model,
            StateDictType.LOCAL_STATE_DICT,
        ):
            state_dict = model.state_dict()
            load_state_dict(state_dict, reader)
            model.load_state_dict(state_dict)

        print(f"--> local state loaded on rank {rank}")

        return


def save_distributed_model_checkpoint(model, rank, cfg, epoch=1):
    # distributed checkpoint saving

    # confirm type of checkpoint and save
    if cfg.checkpoint_type == StateDictType.LOCAL_STATE_DICT:
        # create writer to current path
        folder_name = (
            cfg.dist_checkpoint_root_folder
            + "/"
            + cfg.dist_checkpoint_folder
            + "-"
            + cfg.model_name
        )
        save_dir = Path.cwd() / folder_name

        writer = FileSystemWriter(
            save_dir,
        )

        with FSDP.state_dict_type(
            model,
            StateDictType.LOCAL_STATE_DICT,
        ):
            state_dict = model.state_dict()
            # write out distributed checkpoint
            save_state_dict(state_dict, writer)

        return
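

# Dispatch sketch (illustrative, assumes cfg.checkpoint_type holds a StateDictType
# value as the functions above expect):
#
#   if cfg.checkpoint_type == StateDictType.FULL_STATE_DICT:
#       save_model_checkpoint(model, optimizer, rank, cfg, epoch=epoch)
#       save_optimizer_checkpoint(model, optimizer, rank, cfg, epoch=epoch)
#   elif cfg.checkpoint_type == StateDictType.SHARDED_STATE_DICT:
#       save_model_and_optimizer_sharded(model, rank, cfg, optim=optimizer)
#   elif cfg.checkpoint_type == StateDictType.LOCAL_STATE_DICT:
#       save_distributed_model_checkpoint(model, rank, cfg, epoch=epoch)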