
fixing the full state path in checkpoint handler

Hamid Shojanazeri, 1 year ago
parent commit bedb96b78a
3 changed files with 33 additions and 93 deletions
  1. model_checkpointing/__init__.py (+0 -2)
  2. model_checkpointing/checkpoint_handler.py (+31 -88)
  3. utils/train_utils.py (+2 -3)
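
The core of the fix is that full-state checkpoints are now written under the same folder layout used for distributed checkpoints. A minimal sketch of the path construction this commit introduces; the cfg field names come from the diff below, while the example values are made up:

    from pathlib import Path

    # The cfg fields mirror the ones used in the diff below; values are illustrative.
    class Cfg:
        dist_checkpoint_root_folder = "model_checkpoints"  # assumed example value
        dist_checkpoint_folder = "fine-tuned"              # assumed example value
        model_name = "llama-7b"                            # assumed example value

    cfg = Cfg()
    epoch = 1

    # Same construction as save_model_checkpoint / save_optimizer_checkpoint below.
    folder_name = (
        cfg.dist_checkpoint_root_folder
        + "/"
        + cfg.dist_checkpoint_folder
        + "-"
        + cfg.model_name
    )
    save_dir = Path.cwd() / folder_name
    save_full_path = str(save_dir) + "/" + cfg.model_name + "-" + str(epoch) + ".pt"
    print(save_full_path)
    # e.g. <cwd>/model_checkpoints/fine-tuned-llama-7b/llama-7b-1.pt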

+ 0 - 2
model_checkpointing/__init__.py

@@ -4,8 +4,6 @@
 from .checkpoint_handler import (
     load_model_checkpoint,
     save_model_checkpoint,
-    save_distributed_model_checkpoint,
-    load_distributed_model_checkpoint,
     load_optimizer_checkpoint,
     save_optimizer_checkpoint,
     save_model_and_optimizer_sharded,
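
With the LOCAL_STATE_DICT helpers removed, the package's public entry points narrow to the full-state and sharded paths. A sketch of the imports that remain visible in this hunk (the hunk may not show the complete export list):

    # Only the names visible in the hunk above; assumes the package is importable.
    from model_checkpointing import (
        load_model_checkpoint,             # full state dict, loaded on rank 0 CPU
        save_model_checkpoint,             # full state dict, written by rank 0
        load_optimizer_checkpoint,         # full optimizer state, scattered to ranks
        save_optimizer_checkpoint,
        save_model_and_optimizer_sharded,  # sharded_state_dict save
    )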

+ 31 - 88
model_checkpointing/checkpoint_handler.py

@@ -44,7 +44,7 @@ def get_date_of_run():
 fullstate_save_policy = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
 
 
-def load_model_sharded(model, rank, cfg, verbose=True):
+def load_model_sharded(model, rank, cfg):
     # torch.manual_seed(103)
     folder_name = (
         cfg.dist_checkpoint_root_folder
@@ -83,7 +83,7 @@ def load_model_sharded(model, rank, cfg, verbose=True):
         print(f"Sharded state checkpoint loaded from {load_dir}")
 
 
-def save_model_and_optimizer_sharded(model, rank, cfg,optim=None, verbose=True):
+def save_model_and_optimizer_sharded(model, rank, cfg,optim=None):
     """save model and optimizer via sharded_state_dict to save_dir"""
     
     folder_name = (
@@ -142,7 +142,14 @@ def save_model_checkpoint(
     if rank == 0:
         print(f"--> saving model ...")
         # create save path
-        save_dir = Path.cwd() / cfg.checkpoint_folder
+        folder_name = (
+            cfg.dist_checkpoint_root_folder
+            + "/"
+            + cfg.dist_checkpoint_folder
+            + "-"
+            + cfg.model_name
+        )
+        save_dir = Path.cwd() / folder_name
         save_dir.mkdir(parents=True, exist_ok=True)
         save_name = cfg.model_name + "-" + str(epoch) + ".pt"
         save_full_path = str(save_dir) + "/" + save_name
@@ -150,12 +157,12 @@ def save_model_checkpoint(
         # save model
         torch.save(cpu_state, save_full_path)
 
-        if cfg.verbose:
-            print(f"model checkpoint saved for epoch {epoch} at {save_full_path}\n")
+        
+        print(f"model checkpoint saved for epoch {epoch} at {save_full_path}\n")
       
 
 
-def load_model_checkpoint(model, rank, cfg, verbose=True):
+def load_model_checkpoint(model, rank, cfg):
     """load local checkpoint to rank0 cpu
     must be called * before * passing to FSDP"""
 
@@ -178,8 +185,8 @@ def load_model_checkpoint(model, rank, cfg, verbose=True):
     # integrate into loaded model
     model.load_state_dict(model_checkpoint)
 
-    if cfg.verbose:
-        print(f"model checkpoint loaded to rank0 cpu")
+    
+    print(f"model checkpoint loaded to rank0 cpu")
 
 
 def save_optimizer_checkpoint(model, optimizer, rank, cfg, epoch=1):
@@ -192,15 +199,22 @@ def save_optimizer_checkpoint(model, optimizer, rank, cfg, epoch=1):
 
     optim_state = FSDP.full_optim_state_dict(model, optimizer)
 
-    if cfg.verbose:
-        print(f"optim state dict ready on {rank} and len of {len(optim_state)}\n")
+    
+    print(f"optim state dict ready on {rank} and len of {len(optim_state)}\n")
 
     if rank == 0:
-        save_dir = Path.cwd() / cfg.checkpoint_folder
+        folder_name = (
+            cfg.dist_checkpoint_root_folder
+            + "/"
+            + cfg.dist_checkpoint_folder
+            + "-"
+            + cfg.model_name
+        )
+        save_dir = Path.cwd() / folder_name
         save_dir.mkdir(parents=True, exist_ok=True)
 
         opt_save_name = (
-            cfg.optimizer_name + "-" + cfg.model_name + "-" + str(epoch) + ".pt"
+            "optimizer" + "-" + cfg.model_name + "-" + str(epoch) + ".pt"
         )
         opt_save_full_path = save_dir / opt_save_name
 
@@ -211,96 +225,25 @@ def save_optimizer_checkpoint(model, optimizer, rank, cfg, epoch=1):
         print(f"--> saved {opt_save_full_path} to disk")
 
 
-def load_optimizer_checkpoint(model, optimizer, rank, cfg):
+def load_optimizer_checkpoint(model, optimizer_checkpoint_path, rank):
     """load an fsdp optimizer full_state checkpoint using scatter method
     this ensures only rank 0 loads the optimizer state dict and scatters to other ranks
     """
 
-    opt_file_path = Path.cwd() / cfg.checkpoint_folder / cfg.optimizer_checkpoint_file
 
-    if not opt_file_path.is_file():
+    if not optimizer_checkpoint_path.is_file():
         print(
-            f"warning - optimizer checkpoint not present {opt_file_path}. Returning. "
+            f"warning - optimizer checkpoint not present {optimizer_checkpoint_path}. Returning. "
         )
         return
 
     full_osd = None
 
     if rank == 0:
-        full_osd = torch.load(opt_file_path)
-
-        if cfg.verbose:
-            print(f"loaded full osd on rank 0")
+        full_osd = torch.load(optimizer_checkpoint_path)
 
     # called from all ranks, though only rank0 has a valid param for full_osd
     sharded_osd = FSDP.scatter_full_optim_state_dict(full_osd, model)
 
-    if cfg.verbose:
-        print(f"optimizer shard loaded on rank {rank}")
-
-
+    print(f"optimizer shard loaded on rank {rank}")
 
-def load_distributed_model_checkpoint(model, rank, cfg):
-    if cfg.checkpoint_type == StateDictType.LOCAL_STATE_DICT:
-        print(f"loading distributed checkpoint, rank {rank}...")
-        folder_name = (
-            cfg.dist_checkpoint_root_folder
-            + "/"
-            + cfg.dist_checkpoint_folder
-            + "-"
-            + cfg.model_name
-        )
-
-        checkdir = Path.cwd() / folder_name
-
-        if not checkdir.exists():
-            if rank == 0:
-                print(f"No checkpoint directory found...skipping")
-            return
-
-
-        reader = FileSystemReader(checkdir)
-
-        with FSDP.state_dict_type(
-            model,
-            StateDictType.LOCAL_STATE_DICT,
-        ):
-            state_dict = model.state_dict()
-            load_state_dict(state_dict, reader)
-            model.load_state_dict(state_dict)
-
-        print(f"--> local state loaded on rank {rank}")
-
-        return
-
-
-def save_distributed_model_checkpoint(model, rank, cfg, epoch=1):
-    # distributed checkpoint saving
-
-    # confirm type of checkpoint and save
-    if cfg.checkpoint_type == StateDictType.LOCAL_STATE_DICT:
-        # create writer to current path
-        folder_name = (
-            cfg.dist_checkpoint_root_folder
-            + "/"
-            + cfg.dist_checkpoint_folder
-            + "-"
-            + cfg.model_name
-        )
-        save_dir = Path.cwd() / folder_name
-
-        writer = FileSystemWriter(
-            save_dir,
-        )
-
-        with FSDP.state_dict_type(
-            model,
-            StateDictType.LOCAL_STATE_DICT,
-        ):
-            state_dict = model.state_dict()
-       
-
-        # write out distributed checkpoint
-        save_state_dict(state_dict, writer)
-
-        return
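
Since load_optimizer_checkpoint now receives the checkpoint path instead of deriving it from cfg, the caller has to build that path itself. A minimal sketch of a caller-side helper, assuming the folder layout and the "optimizer-<model_name>-<epoch>.pt" naming that save_optimizer_checkpoint uses above; the helper name and cfg handling are illustrative, not part of this commit:

    from pathlib import Path

    def build_optimizer_checkpoint_path(cfg, epoch):
        # Hypothetical helper: mirrors the layout save_optimizer_checkpoint writes to.
        folder_name = (
            cfg.dist_checkpoint_root_folder
            + "/"
            + cfg.dist_checkpoint_folder
            + "-"
            + cfg.model_name
        )
        opt_save_name = "optimizer" + "-" + cfg.model_name + "-" + str(epoch) + ".pt"
        return Path.cwd() / folder_name / opt_save_name

    # Usage sketch: call on every rank; only rank 0 actually reads the file.
    # optimizer_checkpoint_path = build_optimizer_checkpoint_path(cfg, epoch=1)
    # load_optimizer_checkpoint(model, optimizer_checkpoint_path, rank)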

+ 2 - 3
utils/train_utils.py

@@ -84,7 +84,6 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
                     if train_config.enable_fsdp:
                         batch[key] = batch[key].to(local_rank)
                     else:
-
                         batch[key] = batch[key].to('cuda:0')              
                 loss = model(**batch).loss
                 loss = loss / gradient_accumulation_steps
@@ -137,7 +136,7 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
                     if not train_config.use_peft and fsdp_config.checkpoint_type == StateDictType.FULL_STATE_DICT:
                         
                         model_checkpointing.save_model_checkpoint(
-                            model, optimizer, rank, train_config, epoch=1
+                            model, optimizer, rank, train_config, epoch=epoch
                         )
                     elif not train_config.use_peft and fsdp_config.checkpoint_type == StateDictType.SHARDED_STATE_DICT:
                         print(" we are about to save the models *******")
@@ -148,7 +147,7 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
 
                     if not train_config.use_peft and  train_config.save_optimizer:
                         model_checkpointing.save_optimizer_checkpoint(
-                            model, optimizer, rank, train_config, epoch=1
+                            model, optimizer, rank, train_config, epoch=epoch
                         )
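
The epoch=epoch change means each epoch now writes its own checkpoint files instead of overwriting a single "-1" file. An illustrative print of the resulting names, following the folder layout from checkpoint_handler.py; the model name and folder values are made up:

    # Illustrative only: filenames produced across epochs after this change.
    model_name = "llama-7b"                                 # assumed example value
    folder = "model_checkpoints/fine-tuned-" + model_name   # root/folder-model layout
    for epoch in range(3):
        print(folder + "/" + model_name + "-" + str(epoch) + ".pt")            # model
        print(folder + "/optimizer-" + model_name + "-" + str(epoch) + ".pt")  # optimizer
    # Previously both saves were hard-coded to epoch=1, so every epoch overwrote
    # the same pair of files.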