|
@@ -44,7 +44,7 @@ def get_date_of_run():
|
|
|
fullstate_save_policy = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
|
|
|
|
|
|
|
|
|
-def load_model_sharded(model, rank, cfg, verbose=True):
|
|
|
+def load_model_sharded(model, rank, cfg):
|
|
|
# torch.manual_seed(103)
|
|
|
folder_name = (
|
|
|
cfg.dist_checkpoint_root_folder
|
|
@@ -83,7 +83,7 @@ def load_model_sharded(model, rank, cfg, verbose=True):
|
|
|
print(f"Sharded state checkpoint loaded from {load_dir}")
|
|
|
|
|
|
|
|
|
-def save_model_and_optimizer_sharded(model, rank, cfg,optim=None, verbose=True):
|
|
|
+def save_model_and_optimizer_sharded(model, rank, cfg,optim=None):
|
|
|
"""save model and optimizer via sharded_state_dict to save_dir"""
|
|
|
|
|
|
folder_name = (
|
|
@@ -142,7 +142,14 @@ def save_model_checkpoint(
|
|
|
if rank == 0:
|
|
|
print(f"--> saving model ...")
|
|
|
# create save path
|
|
|
- save_dir = Path.cwd() / cfg.checkpoint_folder
|
|
|
+ folder_name = (
|
|
|
+ cfg.dist_checkpoint_root_folder
|
|
|
+ + "/"
|
|
|
+ + cfg.dist_checkpoint_folder
|
|
|
+ + "-"
|
|
|
+ + cfg.model_name
|
|
|
+ )
|
|
|
+ save_dir = Path.cwd() / folder_name
|
|
|
save_dir.mkdir(parents=True, exist_ok=True)
|
|
|
save_name = cfg.model_name + "-" + str(epoch) + ".pt"
|
|
|
save_full_path = str(save_dir) + "/" + save_name
|
|
@@ -150,12 +157,12 @@ def save_model_checkpoint(
|
|
|
# save model
|
|
|
torch.save(cpu_state, save_full_path)
|
|
|
|
|
|
- if cfg.verbose:
|
|
|
- print(f"model checkpoint saved for epoch {epoch} at {save_full_path}\n")
|
|
|
+
|
|
|
+ print(f"model checkpoint saved for epoch {epoch} at {save_full_path}\n")
|
|
|
|
|
|
|
|
|
|
|
|
-def load_model_checkpoint(model, rank, cfg, verbose=True):
|
|
|
+def load_model_checkpoint(model, rank, cfg):
|
|
|
"""load local checkpoint to rank0 cpu
|
|
|
must be called * before * passing to FSDP"""
|
|
|
|
|
@@ -178,8 +185,8 @@ def load_model_checkpoint(model, rank, cfg, verbose=True):
|
|
|
# integrate into loaded model
|
|
|
model.load_state_dict(model_checkpoint)
|
|
|
|
|
|
- if cfg.verbose:
|
|
|
- print(f"model checkpoint loaded to rank0 cpu")
|
|
|
+
|
|
|
+ print(f"model checkpoint loaded to rank0 cpu")
|
|
|
|
|
|
|
|
|
def save_optimizer_checkpoint(model, optimizer, rank, cfg, epoch=1):
|
|
@@ -192,15 +199,22 @@ def save_optimizer_checkpoint(model, optimizer, rank, cfg, epoch=1):
|
|
|
|
|
|
optim_state = FSDP.full_optim_state_dict(model, optimizer)
|
|
|
|
|
|
- if cfg.verbose:
|
|
|
- print(f"optim state dict ready on {rank} and len of {len(optim_state)}\n")
|
|
|
+
|
|
|
+ print(f"optim state dict ready on {rank} and len of {len(optim_state)}\n")
|
|
|
|
|
|
if rank == 0:
|
|
|
- save_dir = Path.cwd() / cfg.checkpoint_folder
|
|
|
+ folder_name = (
|
|
|
+ cfg.dist_checkpoint_root_folder
|
|
|
+ + "/"
|
|
|
+ + cfg.dist_checkpoint_folder
|
|
|
+ + "-"
|
|
|
+ + cfg.model_name
|
|
|
+ )
|
|
|
+ save_dir = Path.cwd() / folder_name
|
|
|
save_dir.mkdir(parents=True, exist_ok=True)
|
|
|
|
|
|
opt_save_name = (
|
|
|
- cfg.optimizer_name + "-" + cfg.model_name + "-" + str(epoch) + ".pt"
|
|
|
+ "optimizer" + "-" + cfg.model_name + "-" + str(epoch) + ".pt"
|
|
|
)
|
|
|
opt_save_full_path = save_dir / opt_save_name
|
|
|
|
|
@@ -211,96 +225,25 @@ def save_optimizer_checkpoint(model, optimizer, rank, cfg, epoch=1):
|
|
|
print(f"--> saved {opt_save_full_path} to disk")
|
|
|
|
|
|
|
|
|
-def load_optimizer_checkpoint(model, optimizer, rank, cfg):
|
|
|
+def load_optimizer_checkpoint(model, optimizer_checkpoint_path, rank):
|
|
|
"""load an fsdp optimizer full_state checkpoint using scatter method
|
|
|
this ensures only rank 0 loads the optimizer state dict and scatters to other ranks
|
|
|
"""
|
|
|
|
|
|
- opt_file_path = Path.cwd() / cfg.checkpoint_folder / cfg.optimizer_checkpoint_file
|
|
|
|
|
|
- if not opt_file_path.is_file():
|
|
|
+ if not optimizer_checkpoint_path.is_file():
|
|
|
print(
|
|
|
- f"warning - optimizer checkpoint not present {opt_file_path}. Returning. "
|
|
|
+ f"warning - optimizer checkpoint not present {optimizer_checkpoint_path}. Returning. "
|
|
|
)
|
|
|
return
|
|
|
|
|
|
full_osd = None
|
|
|
|
|
|
if rank == 0:
|
|
|
- full_osd = torch.load(opt_file_path)
|
|
|
-
|
|
|
- if cfg.verbose:
|
|
|
- print(f"loaded full osd on rank 0")
|
|
|
+ full_osd = torch.load(optimizer_checkpoint_path)
|
|
|
|
|
|
# called from all ranks, though only rank0 has a valid param for full_osd
|
|
|
sharded_osd = FSDP.scatter_full_optim_state_dict(full_osd, model)
|
|
|
|
|
|
- if cfg.verbose:
|
|
|
- print(f"optimizer shard loaded on rank {rank}")
|
|
|
-
|
|
|
-
|
|
|
+ print(f"optimizer shard loaded on rank {rank}")
|
|
|
|
|
|
-def load_distributed_model_checkpoint(model, rank, cfg):
|
|
|
- if cfg.checkpoint_type == StateDictType.LOCAL_STATE_DICT:
|
|
|
- print(f"loading distributed checkpoint, rank {rank}...")
|
|
|
- folder_name = (
|
|
|
- cfg.dist_checkpoint_root_folder
|
|
|
- + "/"
|
|
|
- + cfg.dist_checkpoint_folder
|
|
|
- + "-"
|
|
|
- + cfg.model_name
|
|
|
- )
|
|
|
-
|
|
|
- checkdir = Path.cwd() / folder_name
|
|
|
-
|
|
|
- if not checkdir.exists():
|
|
|
- if rank == 0:
|
|
|
- print(f"No checkpoint directory found...skipping")
|
|
|
- return
|
|
|
-
|
|
|
-
|
|
|
- reader = FileSystemReader(checkdir)
|
|
|
-
|
|
|
- with FSDP.state_dict_type(
|
|
|
- model,
|
|
|
- StateDictType.LOCAL_STATE_DICT,
|
|
|
- ):
|
|
|
- state_dict = model.state_dict()
|
|
|
- load_state_dict(state_dict, reader)
|
|
|
- model.load_state_dict(state_dict)
|
|
|
-
|
|
|
- print(f"--> local state loaded on rank {rank}")
|
|
|
-
|
|
|
- return
|
|
|
-
|
|
|
-
|
|
|
-def save_distributed_model_checkpoint(model, rank, cfg, epoch=1):
|
|
|
- # distributed checkpoint saving
|
|
|
-
|
|
|
- # confirm type of checkpoint and save
|
|
|
- if cfg.checkpoint_type == StateDictType.LOCAL_STATE_DICT:
|
|
|
- # create writer to current path
|
|
|
- folder_name = (
|
|
|
- cfg.dist_checkpoint_root_folder
|
|
|
- + "/"
|
|
|
- + cfg.dist_checkpoint_folder
|
|
|
- + "-"
|
|
|
- + cfg.model_name
|
|
|
- )
|
|
|
- save_dir = Path.cwd() / folder_name
|
|
|
-
|
|
|
- writer = FileSystemWriter(
|
|
|
- save_dir,
|
|
|
- )
|
|
|
-
|
|
|
- with FSDP.state_dict_type(
|
|
|
- model,
|
|
|
- StateDictType.LOCAL_STATE_DICT,
|
|
|
- ):
|
|
|
- state_dict = model.state_dict()
|
|
|
-
|
|
|
-
|
|
|
- # write out distributed checkpoint
|
|
|
- save_state_dict(state_dict, writer)
|
|
|
-
|
|
|
- return
|