lchu 1 year ago
parent
commit
e216c6f1f3
1 changed files with 2 additions and 2 deletions
  1. 2 2
      model_checkpointing/checkpoint_handler.py

+ 2 - 2
model_checkpointing/checkpoint_handler.py

@@ -65,7 +65,7 @@ def load_model_sharded(model, rank, cfg):
     reader = FileSystemReader(load_dir)
 
     with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT):
-        checkpoint = model.state_dict()
+        checkpoint = {"model": model.state_dict()}
         if rank == 0:
             ck = checkpoint.keys()
             print(f" checkpoint key len = {len(ck)} and \n keys =  {ck}")
@@ -78,7 +78,7 @@ def load_model_sharded(model, rank, cfg):
             print(f"checkpoint after load_state_dict()")
             ck = checkpoint.keys()
             print(f" checkpoint key len = {len(ck)} and \n keys =  {ck}")
-        model.load_state_dict(checkpoint)
+        model.load_state_dict(checkpoint["model"])
     if rank == 0:
         print(f"Sharded state checkpoint loaded from {load_dir}")